From 17080f6484aad9b35575e1b0bb8c5eaf89b2f1d1 Mon Sep 17 00:00:00 2001 From: HecreReed <821896444@qq.com> Date: Tue, 26 May 2026 11:56:12 +0800 Subject: [PATCH 1/8] refactor: split nbnc hotspot files --- lib/PTO/IR/PTO.cpp | 12925 +-------------- lib/PTO/IR/PTO.def | 12933 ++++++++++++++++ .../Transforms/GraphSyncSolver/SyncSolver.cpp | 2568 +-- .../Transforms/GraphSyncSolver/SyncSolver.def | 2576 +++ lib/PTO/Transforms/PTOToEmitC.cpp | 12895 +-------------- lib/PTO/Transforms/PTOToEmitC.def | 12903 +++++++++++++++ lib/PTO/Transforms/PTOViewToMemref.cpp | 3607 +---- lib/PTO/Transforms/PTOViewToMemref.def | 3615 +++++ tools/ptobc/generated/ptobc_opcodes_v0.def | 722 + tools/ptobc/generated/ptobc_opcodes_v0.h | 714 +- 10 files changed, 32754 insertions(+), 32704 deletions(-) create mode 100644 lib/PTO/IR/PTO.def create mode 100644 lib/PTO/Transforms/GraphSyncSolver/SyncSolver.def create mode 100644 lib/PTO/Transforms/PTOToEmitC.def create mode 100644 lib/PTO/Transforms/PTOViewToMemref.def create mode 100644 tools/ptobc/generated/ptobc_opcodes_v0.def diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index 376b9c017..e9dc72235 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -6,12928 +6,5 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. 
-//===- PTO.cpp - PTO Dialect ----------------------------------------------===// -//===----------------------------------------------------------------------===// -#include "PTO/IR/PTO.h" -#include "PTO/IR/PTOTypeUtils.h" -#include "PTO/IR/PTOSyncUtils.h" - -#include "mlir/AsmParser/AsmParser.h" -#include "mlir/Conversion/LLVMCommon/TypeConverter.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/LLVMIR/LLVMTypes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/DialectImplementation.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/IR/Value.h" -#include "mlir/IR/AsmState.h" -#include "mlir/IR/SymbolTable.h" -#include "mlir/IR/Types.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Parser/Parser.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/TypeSwitch.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/AffineMap.h" -#include "mlir/Dialect/Utils/StaticValueUtils.h" -#include "llvm/Support/ErrorHandling.h" - -#include -#include -#include -#include -#include - -using namespace mlir; -using namespace mlir::pto; - -// Forward declarations for custom shape/type printers used by tensor_view and -// partition_tensor_view. 
-namespace mlir { -namespace pto { -static LogicalResult parseShapeAndElem(AsmParser &parser, - SmallVectorImpl &shape, - Type &elementType, - bool allowDynamic = true); -static void printShapeAndElem(AsmPrinter &printer, - ArrayRef shape, - Type elementType); -} // namespace pto -} // namespace mlir - -// ============================================================================= -// TileBufType 的自定义 Shape 解析与打印函数 -// ============================================================================= - -// 解析逻辑:解析形如 "32x32" 的维度列表 -[[maybe_unused]] static ParseResult parseShape(AsmParser &parser, SmallVectorImpl &shape) { - // parseDimensionList 会解析 "dim x dim x ...", 遇到无法解析为维度的字符停止 - // 参数 allowDynamic=true (允许 ?), withTrailingX=false (不吞掉末尾的 x) - if (parser.parseDimensionList(shape, /*allowDynamic=*/true, /*withTrailingX=*/false)) - return failure(); - return success(); -} - -// 打印逻辑:打印形如 "32x32" 的维度列表 -[[maybe_unused]] static void printShape(AsmPrinter &printer, ArrayRef shape) { - for (auto it = shape.begin(); it != shape.end(); ++it) { - if (it != shape.begin()) printer << "x"; // 维度间的分隔符 - if (*it == ShapedType::kDynamic) - printer << "?"; - else - printer << *it; - } - // 注意:我们不在这里打印末尾的 'x',因为 assemblyFormat 中已经写了 `x` $elementType -} - -static std::optional getPTOMemorySpaceEnum(Type ty); -enum class VerifierTargetArch { - A2A3, - A5, -}; -static VerifierTargetArch getVerifierTargetArch(Operation *op); -static std::optional getVerifierArchName(Operation *op); -static bool isSupportedVecElemType(Type ty, bool allowBf16 = true, - bool allowInt8 = true); -static bool isSupportedLoadStoreElemTypeA2A3(Type ty); -static bool isSupportedGatherElemTypeA2A3(Type ty); -static bool isSupportedGatherElemTypeA5(Type ty); -static bool isA5TLoadStoreTransferElemType(Type ty); -static bool isA5AccStorePreQuantDstType(Type srcElem, Type dstElem); -static bool isA5LowPrecisionTCvtPair(Type srcElem, Type dstElem); -static bool isA5SupportedTCvtPair(Type srcElem, Type dstElem); 
-static ParseResult parseSyncEventOpCommon(OpAsmParser &parser, - OperationState &result, - StringAttr pipeAttrName, - StringAttr eventIdAttrName); -static void printSyncEventOpCommon(OpAsmPrinter &p, Operation *op, - PipeAttr pipeAttr, IntegerAttr eventAttr, - Value eventDyn, StringRef pipeAttrName, - StringRef eventIdAttrName); -static bool isTileLikeType(Type ty); -static SmallVector getShapeVec(Type ty); -static SmallVector getValidShapeVec(Type ty); -static SmallVector getValidShapeVec(Value value); -static bool isByteIntegerType(Type ty); -static LogicalResult verifyTileBufCommon(Operation *op, Type ty, StringRef name, - bool allowLowPrecision = false); -static LogicalResult verifyTileBufSameElemType(Operation *op, Type lhs, Type rhs, - StringRef lhsName, - StringRef rhsName); -static LogicalResult verifyTileBufSameLogicalExtent(Operation *op, Type lhs, - Type rhs, StringRef lhsName, - StringRef rhsName, - bool compareValidShape); - -static LogicalResult verifyTileBufSameValidShape(Operation *op, Type lhs, Type rhs, - StringRef lhsName, StringRef rhsName); -static LogicalResult verifyVecTileCommon(Operation *op, Type ty, StringRef name); -static LogicalResult verifyVecTileCommonA2A3(Operation *op, Type ty, - StringRef name); -static LogicalResult verifyVecTileCommonA5(Operation *op, Type ty, - StringRef name); -static LogicalResult verifyVecTileUnaryOp(Operation *op, Type srcTy, Type dstTy, - StringRef srcName = "src", - StringRef dstName = "dst", - bool allowBf16 = true, - bool allowInt8 = true); -static LogicalResult verifyAccTileCommon(Operation *op, Type ty, StringRef name); -static LogicalResult verifyAccTileCommonA2A3(Operation *op, Type ty, - StringRef name); -static LogicalResult verifyAccTileCommonA5(Operation *op, Type ty, - StringRef name); -static LogicalResult verifyMatTileOperands(Operation *op, Type lhsTy, Type rhsTy, - Type dstTy); -static LogicalResult verifyMatTileOperandsA2A3(Operation *op, Type lhsTy, - Type rhsTy, Type dstTy); -static 
LogicalResult verifyMatTileOperandsA5(Operation *op, Type lhsTy, - Type rhsTy, Type dstTy); -static LogicalResult verifyGemvTileOperands(Operation *op, Type lhsTy, Type rhsTy, - Type dstTy); -static LogicalResult verifyAsyncFlatContiguous1DGMViewLike(Operation *op, - Value value, - StringRef name); -static LogicalResult verifyGemvTileOperandsA2A3(Operation *op, Type lhsTy, - Type rhsTy, Type dstTy); -static LogicalResult verifyGemvTileOperandsA5(Operation *op, Type lhsTy, - Type rhsTy, Type dstTy); -static LogicalResult verifyMatBiasTile(Operation *op, Type biasTy, Type dstTy, - bool requireFloatBias = false); -static LogicalResult verifyMatBiasTileA2A3(Operation *op, Type biasTy, Type dstTy, - bool requireFloatBias = false); -static LogicalResult verifyMatBiasTileA5(Operation *op, Type biasTy, Type dstTy, - bool requireFloatBias = false); -static LogicalResult verifyMatmulTypeTriple(Operation *op, Type lhsElemTy, - Type rhsElemTy, Type dstElemTy); -static std::optional getLogicalViewLayout(Value value); -static std::optional getTileBufLogicalLayout(pto::TileBufType type); -static std::optional getConstantIntegerValue(Value value); -static LogicalResult verifyPartialValidPattern(Operation *op, Type src0Ty, - Type src1Ty, Type dstTy); -static Type getElemTy(Type ty); -static FailureOr -verifyMatchingRowMajorBinaryTileOpCommon(Operation *op, Type src0Ty, - Type src1Ty, Type dstTy); -static FailureOr -verifyNumericScalarTileOpCommon(Operation *op, Type srcTy, Type dstTy, - Type scalarTy, bool requireValidRowsEqual); -static FailureOr -verifyShiftLikeBinaryTileOpCommon(Operation *op, Type src0Ty, Type src1Ty, - Type dstTy); -static LogicalResult verifyArithmeticElemTypeForArch( - Operation *op, Type elemTy, PTOArch targetArch, bool allowInt8OnA5, - bool allowBf16OnA5, StringRef a2a3Error, StringRef a5Error); -static bool isRowMajorTileBuf(Type ty); - -#define GET_ENUM_CLASSES -#include "PTO/IR/PTOEnums.cpp.inc" - -#define GET_TYPEDEF_CLASSES -#include 
"PTO/IR/PTOTypeDefs.cpp.inc" - -#define GET_ATTRDEF_CLASSES -#include "PTO/IR/PTOAttrs.cpp.inc" - -#include "PTO/IR/PTODialect.cpp.inc" - -[[maybe_unused]] static LogicalResult parseShapeAndElemStable(mlir::AsmParser &parser, - llvm::SmallVectorImpl &shape, - mlir::Type &elementType) { - if (failed(parser.parseLess())) - return failure(); - - if (failed(parser.parseDimensionList(shape, /*allowDynamic=*/true))) - return failure(); - - if (failed(parser.parseType(elementType))) - return failure(); - - if (failed(parser.parseGreater())) - return failure(); - - return success(); -} - -static int64_t getPTOTypeRank(Type type) { - // 1. 处理标准的 MLIR 类型 (MemRef, Tensor, Vector) - if (auto shapedTy = dyn_cast(type)) { - if (shapedTy.hasRank()) - return shapedTy.getRank(); - return -1; // Unranked type - } - - // 2. 处理 PTO 自定义类型 - if (auto tvTy = dyn_cast(type)) - return tvTy.getRank(); - - if (auto tileTy = dyn_cast(type)) - return tileTy.getRank(); - - if (auto tileViewTy = dyn_cast(type)) - return tileViewTy.getRank(); - - if (auto tileBufTy = dyn_cast(type)) - return tileBufTy.getRank(); - - // 3. 
不支持的类型 - return -1; -} - -static bool isGmAddressSpaceAttr(Attribute memorySpace) { - if (!memorySpace) - return true; - if (auto addr = mlir::dyn_cast(memorySpace)) - return addr.getAddressSpace() == pto::AddressSpace::GM; - if (auto intAttr = mlir::dyn_cast(memorySpace)) - return intAttr.getInt() == 0; - return false; -} - -PTOArch mlir::pto::getTargetArch(ModuleOp module) { - if (!module) - return PTOArch::A3; - - auto arch = module->getAttrOfType(kPTOTargetArchAttrName); - if (arch && arch.getValue().equals_insensitive("a5")) - return PTOArch::A5; - return PTOArch::A3; -} - -PTOArch mlir::pto::getTargetArch(Operation *op) { - if (!op) - return PTOArch::A3; - return getTargetArch(op->getParentOfType()); -} - -bool mlir::pto::isTargetArchA3(ModuleOp module) { - return getTargetArch(module) == PTOArch::A3; -} - -bool mlir::pto::isTargetArchA5(ModuleOp module) { - return getTargetArch(module) == PTOArch::A5; -} - -bool mlir::pto::isTargetArchA3(Operation *op) { - return getTargetArch(op) == PTOArch::A3; -} - -bool mlir::pto::isTargetArchA5(Operation *op) { - return getTargetArch(op) == PTOArch::A5; -} - -static llvm::TypeSize getOneByteTypeSize() { - return llvm::TypeSize::getFixed(8); -} - -llvm::TypeSize mlir::pto::HiF8Type::getTypeSizeInBits( - const DataLayout &, DataLayoutEntryListRef) const { - return getOneByteTypeSize(); -} - -uint64_t mlir::pto::HiF8Type::getABIAlignment(const DataLayout &, - DataLayoutEntryListRef) const { - return 1; -} - -uint64_t mlir::pto::HiF8Type::getPreferredAlignment( - const DataLayout &, DataLayoutEntryListRef) const { - return 1; -} - -llvm::TypeSize mlir::pto::F4E1M2x2Type::getTypeSizeInBits( - const DataLayout &, DataLayoutEntryListRef) const { - return getOneByteTypeSize(); -} - -uint64_t mlir::pto::F4E1M2x2Type::getABIAlignment( - const DataLayout &, DataLayoutEntryListRef) const { - return 1; -} - -uint64_t mlir::pto::F4E1M2x2Type::getPreferredAlignment( - const DataLayout &, DataLayoutEntryListRef) const { - return 1; -} 
- -llvm::TypeSize mlir::pto::F4E2M1x2Type::getTypeSizeInBits( - const DataLayout &, DataLayoutEntryListRef) const { - return getOneByteTypeSize(); -} - -uint64_t mlir::pto::F4E2M1x2Type::getABIAlignment( - const DataLayout &, DataLayoutEntryListRef) const { - return 1; -} - -uint64_t mlir::pto::F4E2M1x2Type::getPreferredAlignment( - const DataLayout &, DataLayoutEntryListRef) const { - return 1; -} - -static VerifierTargetArch getVerifierTargetArch(Operation *op) { - if (auto archName = getVerifierArchName(op)) { - return archName->equals_insensitive("a5") ? VerifierTargetArch::A5 - : VerifierTargetArch::A2A3; - } - - switch (getPTOParserTargetArch(op ? op->getContext() : nullptr)) { - case PTOParserTargetArch::A5: - return VerifierTargetArch::A5; - case PTOParserTargetArch::A3: - case PTOParserTargetArch::Unspecified: - return VerifierTargetArch::A2A3; - } - - return VerifierTargetArch::A2A3; -} - -static std::optional getVerifierArchName(Operation *op) { - auto module = op ? op->getParentOfType() : ModuleOp(); - if (!module) - return std::nullopt; - if (auto arch = module->getAttrOfType(kPTOTargetArchAttrName)) - return arch.getValue(); - return std::nullopt; -} - -static bool shouldBypassDecodedMemrefVerifier(Operation *op) { - if (!op) - return false; - for (Value operand : op->getOperands()) { - if (isa(operand.getType())) - return true; - if (operand.getDefiningOp()) - return true; - } - return false; -} - -static SmallVector canonicalizeTileBufValidShape(ArrayRef validShape) { - SmallVector canonical; - canonical.reserve(validShape.size()); - for (int64_t dim : validShape) - canonical.push_back(dim < 0 ? 
ShapedType::kDynamic : dim); - return canonical; -} - -template -static LogicalResult dispatchVerifierByArch(Operation *op, FnA2A3 &&verifyA2A3, - FnA5 &&verifyA5) { - if (shouldBypassDecodedMemrefVerifier(op)) - return success(); - switch (getVerifierTargetArch(op)) { - case VerifierTargetArch::A2A3: - return verifyA2A3(); - case VerifierTargetArch::A5: - return verifyA5(); - } - return failure(); -} - -static ParseResult parseSyncEventOpCommon(OpAsmParser &parser, - OperationState &result, - StringAttr pipeAttrName, - StringAttr eventIdAttrName) { - PipeAttr pipeAttr; - if (succeeded(parser.parseOptionalLess())) { - StringRef pipeTok; - if (parser.parseKeyword(&pipeTok) || parser.parseGreater()) - return failure(); - auto pipeOr = symbolizePIPE(pipeTok); - if (!pipeOr) - return parser.emitError(parser.getCurrentLocation()) - << "unknown pipe token: " << pipeTok; - pipeAttr = PipeAttr::get(parser.getContext(), *pipeOr); - result.addAttribute(pipeAttrName, pipeAttr); - } else if (parser.parseAttribute(pipeAttr, pipeAttrName, - result.attributes)) { - return failure(); - } - if (parser.parseComma()) - return failure(); - - OpAsmParser::UnresolvedOperand eventOperand; - OptionalParseResult parseEventOperand = - parser.parseOptionalOperand(eventOperand); - if (parseEventOperand.has_value()) { - if (failed(*parseEventOperand)) - return failure(); - if (parser.resolveOperand(eventOperand, parser.getBuilder().getIndexType(), - result.operands)) - return failure(); - } else { - IntegerAttr eventAttr; - if (parser.parseAttribute(eventAttr, parser.getBuilder().getI32Type(), - eventIdAttrName, result.attributes)) - return failure(); - } - - if (parser.parseOptionalAttrDict(result.attributes)) - return failure(); - return success(); -} - -static void printSyncEventOpCommon(OpAsmPrinter &p, Operation *op, - PipeAttr pipeAttr, IntegerAttr eventAttr, - Value eventDyn, StringRef pipeAttrName, - StringRef eventIdAttrName) { - p << " <" << stringifyPIPE(pipeAttr.getPipe()) << ">, 
"; - if (eventAttr) - p << eventAttr.getInt(); - else - p << eventDyn; - p.printOptionalAttrDict(op->getAttrs(), {pipeAttrName, eventIdAttrName}); -} - -[[maybe_unused]] static mlir::Type parsePTOTypeAllowNoBang(mlir::OpAsmParser &parser) { - mlir::Type ty; - - mlir::OptionalParseResult opt = parser.parseOptionalType(ty); - - if (opt.has_value()) { - if (failed(*opt)) - return mlir::Type(); - return ty; - } - - - llvm::StringRef head; - if (failed(parser.parseKeyword(&head))) - return mlir::Type(); - - mlir::MLIRContext *ctx = parser.getContext(); - - auto parseShapeElemForOpParser = - [&](llvm::SmallVectorImpl &shape, mlir::Type &elem) -> mlir::LogicalResult { - if (failed(parser.parseLess())) - return failure(); - if (failed(parser.parseDimensionList(shape, /*allowDynamic=*/true))) - return failure(); - if (failed(parser.parseType(elem))) - return failure(); - if (failed(parser.parseGreater())) - return failure(); - return success(); - }; - - if (head == "pto.tile_view") { - llvm::SmallVector shape; - mlir::Type elem; - if (failed(parseShapeElemForOpParser(shape, elem))) - return mlir::Type(); - return mlir::pto::PartitionTensorViewType::get(ctx, shape, elem); - } - - if (head == "pto.tile") { - llvm::SmallVector shape; - mlir::Type elem; - if (failed(parseShapeElemForOpParser(shape, elem))) - return mlir::Type(); - return mlir::pto::TileType::get(ctx, shape, elem); - } - - if (head == "pto.ptr") { - if (failed(parser.parseLess())) - return mlir::Type(); - mlir::Type elem; - if (failed(parser.parseType(elem))) - return mlir::Type(); - if (succeeded(parser.parseOptionalComma())) { - // ptr no longer accepts an address space; consume the attr for recovery. 
- mlir::Attribute memorySpace; - (void)parser.parseAttribute(memorySpace); - parser.emitError(parser.getCurrentLocation(), - "!pto.ptr no longer accepts address space; use !pto.ptr"); - return mlir::Type(); - } - if (failed(parser.parseGreater())) - return mlir::Type(); - return mlir::pto::PtrType::get(ctx, elem); - } - - if (head == "pto.tensor_view") { - llvm::SmallVector shape; - mlir::Type elem; - if (failed(parseShapeElemForOpParser(shape, elem))) - return mlir::Type(); - return mlir::pto::TensorViewType::get(ctx, shape, elem); - } - - return mlir::Type(); -} - -mlir::Type TensorViewType::parse(::mlir::AsmParser &parser) { - SmallVector shape; - Type elementType; - if (failed(parseShapeAndElem(parser, shape, elementType, /*allowDynamic=*/true))) - return Type(); - return TensorViewType::get(parser.getContext(), shape, elementType); -} - -void TensorViewType::print(::mlir::AsmPrinter &printer) const { - printShapeAndElem(printer, getShape(), getElementType()); -} - -//===----------------------------------------------------------------------===// -// pto.tdivs custom asm to support both: -// pto.tdivs ins(%src, %scalar : !pto.tile_buf<...>, f32) outs(%dst : !pto.tile_buf<...>) -// pto.tdivs ins(%scalar, %src : f32, !pto.tile_buf<...>) outs(%dst : !pto.tile_buf<...>) -// The operand order in the op follows textual input order. 
-//===----------------------------------------------------------------------===// - -ParseResult mlir::pto::TDivSOp::parse(OpAsmParser &parser, OperationState &result) { - OpAsmParser::UnresolvedOperand op0, op1, dst; - Type ty0, ty1, dstTy; - - if (parser.parseKeyword("ins") || parser.parseLParen() || - parser.parseOperand(op0) || parser.parseComma() || - parser.parseOperand(op1) || parser.parseColonType(ty0) || - parser.parseComma() || parser.parseType(ty1) || parser.parseRParen()) - return failure(); - - if (parser.parseKeyword("outs") || parser.parseLParen() || - parser.parseOperand(dst) || parser.parseColonType(dstTy) || - parser.parseRParen()) - return failure(); - - NamedAttrList attrs; - if (parser.parseOptionalAttrDict(attrs)) - return failure(); - - auto tile0 = dyn_cast(ty0); - auto tile1 = dyn_cast(ty1); - if ((tile0 && tile1) || (!tile0 && !tile1)) - return parser.emitError(parser.getCurrentLocation(), - "expected exactly one tile_buf operand and one scalar operand"); - - if (!dyn_cast(dstTy)) - return parser.emitError(parser.getCurrentLocation(), - "expected outs type to be !pto.tile_buf<...>"); - - // Keep textual order so later lowering can distinguish the two APIs by the - // first ins operand type. 
- if (parser.resolveOperand(op0, ty0, result.operands) || - parser.resolveOperand(op1, ty1, result.operands)) - return failure(); - - if (parser.resolveOperand(dst, dstTy, result.operands)) - return failure(); - - result.addAttributes(attrs); - return success(); -} - -void mlir::pto::TDivSOp::print(OpAsmPrinter &p) { - p << " ins("; - p << getSrc() << ", " << getScalar() << " : " - << getSrc().getType() << ", " << getScalar().getType(); - p << ") outs(" << getDst() << " : " << getDst().getType() << ")"; - - p.printOptionalAttrDict((*this)->getAttrs()); -} - - -//===----------------------------------------------------------------------===// -// pto.tgather custom asm supports three PTO-ISA forms: -// 1) index+tmp : ins(%src, %indices, %tmp : srcTy, indicesTy, tmpTy) outs(%dst : dstTy) -// 2) compare+tmp : ins(%src, %kValue, %tmp : srcTy, scalarTy, tmpTy) -// outs(%dst, %cdst : dstTy, cdstTy) {cmpMode = #pto.cmp, offset = 7} -// 3) mask : ins(%src, {maskPattern = #pto.mask_pattern} : srcTy) outs(%dst : dstTy) -//===----------------------------------------------------------------------===// - -ParseResult mlir::pto::TGatherOp::parse(OpAsmParser &parser, OperationState &result) { - OpAsmParser::UnresolvedOperand src, dst, cdst; - SmallVector insOps; - SmallVector insTypes; - Type srcTy, dstTy, cdstTy; - bool hasCdst = false; - bool hasMask = false; - bool hasIndices = false; - bool hasTmp = false; - bool hasKValue = false; - - if (parser.parseKeyword("ins") || parser.parseLParen() || parser.parseOperand(src)) - return failure(); - - if (!succeeded(parser.parseOptionalComma())) { - return parser.emitError(parser.getCurrentLocation(), - "expected ',' after src operand in ins(...)"); - } - - if (succeeded(parser.parseOptionalLBrace())) { - if (parser.parseKeyword("maskPattern") || parser.parseEqual()) - return failure(); - - Attribute rawMaskAttr; - if (parser.parseAttribute(rawMaskAttr) || parser.parseRBrace()) - return failure(); - - auto mp = 
llvm::dyn_cast(rawMaskAttr); - if (!mp) { - return parser.emitError(parser.getCurrentLocation(), - "expected #pto.mask_pattern for maskPattern"); - } - - result.addAttribute("maskPattern", mp); - hasMask = true; - - if (parser.parseColonType(srcTy) || parser.parseRParen()) - return failure(); - } else { - OpAsmParser::UnresolvedOperand extra; - if (parser.parseOperand(extra)) - return failure(); - insOps.push_back(extra); - while (succeeded(parser.parseOptionalComma())) { - if (insOps.size() == 3) { - return parser.emitError(parser.getCurrentLocation(), - "expected at most 3 extra operands in tgather ins(...)"); - } - if (parser.parseOperand(extra)) - return failure(); - insOps.push_back(extra); - } - - if (parser.parseColon() || parser.parseType(srcTy)) - return failure(); - for (size_t i = 0; i < insOps.size(); ++i) { - Type ty; - if (parser.parseComma() || parser.parseType(ty)) - return failure(); - insTypes.push_back(ty); - } - if (parser.parseRParen()) - return failure(); - } - - if (parser.parseKeyword("outs") || parser.parseLParen() || parser.parseOperand(dst)) - return failure(); - if (succeeded(parser.parseOptionalComma())) { - if (parser.parseOperand(cdst)) - return failure(); - hasCdst = true; - } - if (parser.parseColonType(dstTy)) - return failure(); - if (hasCdst && (parser.parseComma() || parser.parseType(cdstTy))) - return failure(); - if (parser.parseRParen()) - return failure(); - - if (succeeded(parser.parseOptionalKeyword("maskPattern"))) { - if (hasMask) - return parser.emitError(parser.getCurrentLocation(), - "maskPattern may only be specified once"); - if (parser.parseEqual()) - return failure(); - Attribute rawMaskAttr; - if (parser.parseAttribute(rawMaskAttr)) - return failure(); - auto mp = llvm::dyn_cast(rawMaskAttr); - if (!mp) { - return parser.emitError(parser.getCurrentLocation(), - "expected #pto.mask_pattern for maskPattern"); - } - result.addAttribute("maskPattern", mp); - hasMask = true; - } - - if 
(parser.parseOptionalAttrDict(result.attributes)) - return failure(); - - if (hasMask) { - if (!insOps.empty()) - return parser.emitError(parser.getCurrentLocation(), - "mask-pattern tgather does not take extra ins operands"); - if (hasCdst) - return parser.emitError(parser.getCurrentLocation(), - "mask-pattern tgather expects a single outs operand"); - } else if (hasCdst) { - if (insOps.empty() || - !(mlir::isa(insTypes.front()) || - mlir::isa(insTypes.front()))) - return parser.emitError(parser.getCurrentLocation(), - "compare-form tgather expects a scalar kValue operand"); - hasKValue = true; - if (insOps.size() >= 2) { - if (!isTileLikeType(insTypes[1])) - return parser.emitError(parser.getCurrentLocation(), - "compare-form tgather tmp must be tile-like"); - hasTmp = true; - } - if (insOps.size() == 3) { - return parser.emitError(parser.getCurrentLocation(), - "compare-form tgather expects at most src, kValue, tmp in ins(...)"); - } - } else { - if (!insOps.empty() && !isTileLikeType(insTypes.front())) { - return parser.emitError(parser.getCurrentLocation(), - "index-form tgather expects tile-like indices; " - "compare-form must use outs(dst, cdst)"); - } - if (!insOps.empty()) { - hasIndices = true; - if (insOps.size() >= 2) { - if (!isTileLikeType(insTypes[1])) - return parser.emitError(parser.getCurrentLocation(), - "index-form tgather tmp must be tile-like"); - hasTmp = true; - } - } - if (insOps.size() == 3) { - return parser.emitError(parser.getCurrentLocation(), - "index-form tgather expects at most src, indices, tmp in ins(...)"); - } - } - - if (parser.resolveOperand(src, srcTy, result.operands) || - parser.resolveOperand(dst, dstTy, result.operands)) - return failure(); - if (hasCdst && parser.resolveOperand(cdst, cdstTy, result.operands)) - return failure(); - if (hasIndices && parser.resolveOperand(insOps[0], insTypes[0], result.operands)) - return failure(); - if (hasTmp && parser.resolveOperand(insOps[hasIndices ? 
1 : 1], insTypes[1], result.operands)) - return failure(); - if (hasKValue && parser.resolveOperand(insOps[0], insTypes[0], result.operands)) - return failure(); - - result.addAttribute("operandSegmentSizes", - parser.getBuilder().getDenseI32ArrayAttr( - {1, 1, hasCdst ? 1 : 0, hasIndices ? 1 : 0, - hasTmp ? 1 : 0, hasKValue ? 1 : 0})); - return success(); -} - -void mlir::pto::TGatherOp::print(OpAsmPrinter &p) { - p << " ins(" << getSrc() << ", "; - if (auto mp = getMaskPatternAttr()) { - p << "{maskPattern = " << mp << "} : " << getSrc().getType(); - } else if (getCdst()) { - p << getKValue(); - if (getTmp()) { - p << ", " << getTmp(); - p << " : " << getSrc().getType() << ", " << getKValue().getType() - << ", " << getTmp().getType(); - } else { - p << " : " << getSrc().getType() << ", " << getKValue().getType(); - } - } else { - p << getIndices(); - if (getTmp()) { - p << ", " << getTmp(); - p << " : " << getSrc().getType() << ", " << getIndices().getType() - << ", " << getTmp().getType(); - } else { - p << " : " << getSrc().getType() << ", " << getIndices().getType(); - } - } - p << ") outs(" << getDst(); - if (getCdst()) - p << ", " << getCdst(); - p << " : " << getDst().getType(); - if (getCdst()) - p << ", " << getCdst().getType(); - p << ")"; - - if (getMaskPatternAttr()) { - p.printOptionalAttrDict((*this)->getAttrs(), - /*elidedAttrs=*/{"maskPattern", "operandSegmentSizes"}); - } else { - p.printOptionalAttrDict((*this)->getAttrs(), - /*elidedAttrs=*/{"operandSegmentSizes"}); - } -} - -ParseResult mlir::pto::TScatterOp::parse(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::UnresolvedOperand src, indexes, dst; - Type srcTy, idxTy, dstTy; - bool hasMask = false; - bool hasIndexes = false; - - if (parser.parseKeyword("ins") || parser.parseLParen() || - parser.parseOperand(src)) - return failure(); - - if (!succeeded(parser.parseOptionalComma())) - return parser.emitError(parser.getCurrentLocation(), - "expected ',' after src operand in 
ins(...)"); - - if (succeeded(parser.parseOptionalLBrace())) { - if (parser.parseKeyword("maskPattern") || parser.parseEqual()) - return failure(); - Attribute rawMaskAttr; - if (parser.parseAttribute(rawMaskAttr) || parser.parseRBrace()) - return failure(); - auto mp = llvm::dyn_cast(rawMaskAttr); - if (!mp) - return parser.emitError(parser.getCurrentLocation(), - "expected #pto.mask_pattern for maskPattern"); - result.addAttribute("maskPattern", mp); - hasMask = true; - if (parser.parseColonType(srcTy) || parser.parseRParen()) - return failure(); - } else { - if (parser.parseOperand(indexes)) - return failure(); - hasIndexes = true; - if (parser.parseColon() || parser.parseType(srcTy) || parser.parseComma() || - parser.parseType(idxTy) || parser.parseRParen()) - return failure(); - } - - if (parser.parseKeyword("outs") || parser.parseLParen() || - parser.parseOperand(dst) || parser.parseColonType(dstTy) || - parser.parseRParen()) - return failure(); - - if (parser.parseOptionalAttrDict(result.attributes)) - return failure(); - - if (result.attributes.get("maskPattern")) - hasMask = true; - - if (hasMask && hasIndexes) - return parser.emitError(parser.getCurrentLocation(), - "mask-pattern tscatter does not take indexes"); - if (!hasMask && !hasIndexes) - return parser.emitError(parser.getCurrentLocation(), - "expected indexes operand or maskPattern for tscatter"); - - if (parser.resolveOperand(src, srcTy, result.operands) || - parser.resolveOperand(dst, dstTy, result.operands) || - (hasIndexes && parser.resolveOperand(indexes, idxTy, result.operands))) - return failure(); - return success(); -} - -void mlir::pto::TScatterOp::print(OpAsmPrinter &p) { - p << " ins(" << getSrc() << ", "; - if (getMaskPatternAttr()) { - p << "{maskPattern = " << getMaskPatternAttr() << "} : " - << getSrc().getType(); - } else { - p << getIndexes() << " : " << getSrc().getType() << ", " - << getIndexes().getType(); - } - p << ") outs(" << getDst() << " : " << getDst().getType() << ")"; 
- p.printOptionalAttrDict((*this)->getAttrs(), - /*elidedAttrs=*/{"maskPattern"}); -} - -namespace { - -struct CommRecvClause { - OpAsmParser::UnresolvedOperand ping; - std::optional pong; - Type pingTy; - Type pongTy; -}; - -static ParseResult parseCommRecvClause(OpAsmParser &parser, - CommRecvClause &recvClause) { - if (parser.parseKeyword("recv") || parser.parseLParen() || - parser.parseOperand(recvClause.ping)) - return failure(); - if (succeeded(parser.parseOptionalComma())) { - OpAsmParser::UnresolvedOperand pong; - if (parser.parseOperand(pong)) - return failure(); - recvClause.pong = pong; - } - return parser.parseRParen(); -} - -static ParseResult parseCommCollectiveTail( - OpAsmParser &parser, OperationState &result, - ArrayRef fixedOperands, - SmallVectorImpl &fixedTypes, CommRecvClause &recvClause, - SmallVectorImpl &groupOps, - SmallVectorImpl &groupTypes, ArrayRef operandSegmentsPrefix, - ArrayRef requiredAttrs) { - if (parser.parseComma() || parser.parseKeyword("group") || parser.parseLParen()) - return failure(); - - OpAsmParser::UnresolvedOperand group; - if (parser.parseOperand(group)) - return failure(); - groupOps.push_back(group); - while (succeeded(parser.parseOptionalComma())) { - if (parser.parseOperand(group)) - return failure(); - groupOps.push_back(group); - } - - if (parser.parseRParen()) - return failure(); - - if (parser.parseColon()) - return failure(); - - for (size_t i = 0; i < fixedTypes.size(); ++i) { - if (i != 0 && parser.parseComma()) - return failure(); - if (parser.parseType(fixedTypes[i])) - return failure(); - } - if (parser.parseComma() || parser.parseType(recvClause.pingTy)) - return failure(); - if (recvClause.pong) { - if (parser.parseComma() || parser.parseType(recvClause.pongTy)) - return failure(); - } - for (size_t i = 0; i < groupOps.size(); ++i) { - Type groupTy; - if (parser.parseComma() || parser.parseType(groupTy)) - return failure(); - groupTypes.push_back(groupTy); - } - if (parser.parseRParen()) - return 
failure(); - - NamedAttrList attrs; - if (parser.parseOptionalAttrDict(attrs)) - return failure(); - for (StringRef attrName : requiredAttrs) { - if (!attrs.get(attrName)) { - return parser.emitError(parser.getCurrentLocation()) - << "expected '" << attrName << "' attribute"; - } - } - result.addAttributes(attrs); - - for (auto [operand, type] : llvm::zip_equal(fixedOperands, fixedTypes)) { - if (parser.resolveOperand(operand, type, result.operands)) - return failure(); - } - if (parser.resolveOperand(recvClause.ping, recvClause.pingTy, result.operands)) - return failure(); - if (recvClause.pong && - parser.resolveOperand(*recvClause.pong, recvClause.pongTy, result.operands)) - return failure(); - if (parser.resolveOperands(groupOps, groupTypes, parser.getCurrentLocation(), - result.operands)) - return failure(); - - SmallVector segmentSizes(operandSegmentsPrefix.begin(), - operandSegmentsPrefix.end()); - segmentSizes.push_back(static_cast(groupOps.size())); - result.addAttribute("operandSegmentSizes", - parser.getBuilder().getDenseI32ArrayAttr(segmentSizes)); - return success(); -} - -static void printCommRecvClause(OpAsmPrinter &p, Value ping, Value pong) { - p << "recv(" << ping; - if (pong) - p << ", " << pong; - p << ")"; -} - -static void printCommGroupTypes(OpAsmPrinter &p, ValueRange group) { - for (Value groupValue : group) - p << ", " << groupValue.getType(); -} - -static void printCommGroupClause(OpAsmPrinter &p, ValueRange group) { - p << "group("; - p.printOperands(group); - p << ")"; -} - -} // namespace - -ParseResult mlir::pto::TBroadcastOp::parse(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::UnresolvedOperand src; - CommRecvClause recvClause; - SmallVector groupOps; - SmallVector groupTypes; - - if (parser.parseLParen() || parser.parseOperand(src) || parser.parseComma()) - return failure(); - if (failed(parseCommRecvClause(parser, recvClause))) - return failure(); - - SmallVector fixedOperands{src}; - SmallVector fixedTypes(1); - 
if (failed(parseCommCollectiveTail(parser, result, fixedOperands, fixedTypes, - recvClause, groupOps, groupTypes, - {1, 1, recvClause.pong ? 1 : 0}, {"root"}))) - return failure(); - return success(); -} - -void mlir::pto::TBroadcastOp::print(OpAsmPrinter &p) { - p << "(" << getSrc() << ", "; - printCommRecvClause(p, getPing(), getPong()); - p << ", "; - printCommGroupClause(p, getGroup()); - p << " : " << getSrc().getType() << ", " << getPing().getType(); - if (getPong()) - p << ", " << getPong().getType(); - printCommGroupTypes(p, getGroup()); - p << ")"; - p.printOptionalAttrDict((*this)->getAttrs(), - /*elidedAttrs=*/{"operandSegmentSizes"}); -} - -ParseResult mlir::pto::CommTGatherOp::parse(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::UnresolvedOperand dst; - CommRecvClause recvClause; - SmallVector groupOps; - SmallVector groupTypes; - - if (parser.parseLParen() || parser.parseOperand(dst) || parser.parseComma()) - return failure(); - if (failed(parseCommRecvClause(parser, recvClause))) - return failure(); - - SmallVector fixedOperands{dst}; - SmallVector fixedTypes(1); - if (failed(parseCommCollectiveTail( - parser, result, fixedOperands, fixedTypes, recvClause, groupOps, - groupTypes, {1, 1, recvClause.pong ? 
1 : 0}, - {"root"}))) - return failure(); - return success(); -} - -void mlir::pto::CommTGatherOp::print(OpAsmPrinter &p) { - p << "(" << getDst() << ", "; - printCommRecvClause(p, getPing(), getPong()); - p << ", "; - printCommGroupClause(p, getGroup()); - p << " : " << getDst().getType() << ", " << getPing().getType(); - if (getPong()) - p << ", " << getPong().getType(); - printCommGroupTypes(p, getGroup()); - p << ")"; - p.printOptionalAttrDict((*this)->getAttrs(), - /*elidedAttrs=*/{"operandSegmentSizes"}); -} - -ParseResult mlir::pto::CommTScatterOp::parse(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::UnresolvedOperand src; - CommRecvClause recvClause; - SmallVector groupOps; - SmallVector groupTypes; - - if (parser.parseLParen() || parser.parseOperand(src) || parser.parseComma()) - return failure(); - if (failed(parseCommRecvClause(parser, recvClause))) - return failure(); - - SmallVector fixedOperands{src}; - SmallVector fixedTypes(1); - if (failed(parseCommCollectiveTail( - parser, result, fixedOperands, fixedTypes, recvClause, groupOps, - groupTypes, {1, 1, recvClause.pong ? 
1 : 0}, - {"root"}))) - return failure(); - return success(); -} - -void mlir::pto::CommTScatterOp::print(OpAsmPrinter &p) { - p << "(" << getSrc() << ", "; - printCommRecvClause(p, getPing(), getPong()); - p << ", "; - printCommGroupClause(p, getGroup()); - p << " : " << getSrc().getType() << ", " << getPing().getType(); - if (getPong()) - p << ", " << getPong().getType(); - printCommGroupTypes(p, getGroup()); - p << ")"; - p.printOptionalAttrDict((*this)->getAttrs(), - /*elidedAttrs=*/{"operandSegmentSizes"}); -} - -ParseResult mlir::pto::TReduceOp::parse(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::UnresolvedOperand dst, acc; - CommRecvClause recvClause; - SmallVector groupOps; - SmallVector groupTypes; - - if (parser.parseLParen() || parser.parseOperand(dst) || parser.parseComma() || - parser.parseOperand(acc) || parser.parseComma()) - return failure(); - if (failed(parseCommRecvClause(parser, recvClause))) - return failure(); - - SmallVector fixedOperands{dst, acc}; - SmallVector fixedTypes(2); - if (failed(parseCommCollectiveTail( - parser, result, fixedOperands, fixedTypes, recvClause, groupOps, - groupTypes, {1, 1, 1, recvClause.pong ? 
1 : 0}, - {"reduceOp", "root"}))) - return failure(); - return success(); -} - -void mlir::pto::TReduceOp::print(OpAsmPrinter &p) { - p << "(" << getDst() << ", " << getAcc() << ", "; - printCommRecvClause(p, getRecvPing(), getRecvPong()); - p << ", "; - printCommGroupClause(p, getGroup()); - p << " : " << getDst().getType() << ", " << getAcc().getType() << ", " - << getRecvPing().getType(); - if (getRecvPong()) - p << ", " << getRecvPong().getType(); - printCommGroupTypes(p, getGroup()); - p << ")"; - p.printOptionalAttrDict((*this)->getAttrs(), - /*elidedAttrs=*/{"operandSegmentSizes"}); -} - -ParseResult mlir::pto::MakeTensorViewOp::parse(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::UnresolvedOperand ptr; - SmallVector shapeOps; - SmallVector strideOps; - - Type resultTy; - - // %ptr - if (parser.parseOperand(ptr)) - return failure(); - - // , shape = [ ... ] - if (parser.parseComma() || parser.parseKeyword("shape") || parser.parseEqual() || - parser.parseLSquare() || - parser.parseOperandList(shapeOps) || - parser.parseRSquare()) - return failure(); - - // strides = [ ... 
] - if (parser.parseComma() || parser.parseKeyword("strides") || parser.parseEqual() || - parser.parseLSquare() || - parser.parseOperandList(strideOps) || - parser.parseRSquare()) - return failure(); - - // attr-dict - if (parser.parseOptionalAttrDict(result.attributes)) - return failure(); - - // : result-type - if (parser.parseColonType(resultTy)) - return failure(); - result.addTypes(resultTy); - - auto tvTy = llvm::dyn_cast(resultTy); - if (!tvTy) - return parser.emitError(parser.getCurrentLocation(), - "expected result type pto.tensor_view<...>"); - - Type elemTy = tvTy.getElementType(); - - Type ptrTy = mlir::pto::PtrType::get(parser.getContext(), elemTy); - - // resolve %ptr - if (parser.resolveOperand(ptr, ptrTy, result.operands)) - return failure(); - - // resolve shape/strides 为 index - Type indexTy = parser.getBuilder().getIndexType(); - if (parser.resolveOperands(shapeOps, indexTy, result.operands)) - return failure(); - if (parser.resolveOperands(strideOps, indexTy, result.operands)) - return failure(); - - auto segAttr = parser.getBuilder().getDenseI32ArrayAttr( - {1, (int32_t)shapeOps.size(), (int32_t)strideOps.size()}); - result.addAttribute("operandSegmentSizes", segAttr); - - return success(); -} - -void mlir::pto::MakeTensorViewOp::print(OpAsmPrinter &p) { - p << " " << getPtr(); - - p << ", shape = ["; - p.printOperands(getShape()); - p << "]"; - - p << ", strides = ["; - p.printOperands(getStrides()); - p << "]"; - - p.printOptionalAttrDict((*this)->getAttrs(), - /*elidedAttrs=*/{"operandSegmentSizes"}); - - p << " : " << getResult().getType(); -} - -// Layout inference helpers for make_tensor_view -static std::optional getConstIndexValue(Value v) { - if (auto c = v.getDefiningOp()) - return c.value(); - if (auto c = v.getDefiningOp()) { - if (auto ia = dyn_cast(c.getValue())) - return ia.getInt(); - } - return std::nullopt; -} - -static FailureOr -inferPartitionViewResultTypeFromSizes(mlir::pto::TensorViewType sourceType, - ValueRange sizes) { 
- if (!sourceType) - return failure(); - - if ((int64_t)sizes.size() != sourceType.getRank()) - return failure(); - - SmallVector shape; - shape.reserve(sizes.size()); - for (Value size : sizes) { - auto constSize = getConstIndexValue(size); - if (constSize && *constSize >= 0) - shape.push_back(*constSize); - else - shape.push_back(ShapedType::kDynamic); - } - - return mlir::pto::PartitionTensorViewType::get( - sourceType.getContext(), shape, sourceType.getElementType()); -} - -ParseResult mlir::pto::PartitionViewOp::parse(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::UnresolvedOperand source; - SmallVector offsets; - SmallVector sizes; - Type sourceTy; - Type resultTy; - bool hasExplicitResultTy = false; - - if (parser.parseOperand(source) || parser.parseComma() || - parser.parseKeyword("offsets") || parser.parseEqual() || - parser.parseLSquare() || parser.parseOperandList(offsets) || - parser.parseRSquare() || parser.parseComma() || - parser.parseKeyword("sizes") || parser.parseEqual() || - parser.parseLSquare() || parser.parseOperandList(sizes) || - parser.parseRSquare() || parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(sourceTy)) - return failure(); - - if (succeeded(parser.parseOptionalArrow())) { - if (parser.parseType(resultTy)) - return failure(); - hasExplicitResultTy = true; - } - - if (parser.resolveOperand(source, sourceTy, result.operands)) - return failure(); - - Type indexTy = parser.getBuilder().getIndexType(); - if (parser.resolveOperands(offsets, indexTy, result.operands) || - parser.resolveOperands(sizes, indexTy, result.operands)) - return failure(); - - auto &properties = result.getOrAddProperties(); - llvm::copy(ArrayRef( - {1, static_cast(offsets.size()), - static_cast(sizes.size())}), - properties.operandSegmentSizes.begin()); - - if (hasExplicitResultTy) { - result.addTypes(resultTy); - return success(); - } - - ValueRange allOperands(result.operands); - ValueRange sizeOperands = - 
allOperands.slice(1 + offsets.size(), sizes.size()); - auto inferredResultType = inferPartitionViewResultTypeFromSizes( - dyn_cast(sourceTy), sizeOperands); - if (failed(inferredResultType)) { - return parser.emitError(parser.getCurrentLocation(), - "failed to infer pto.partition_view result type"); - } - - result.addTypes(*inferredResultType); - return success(); -} - -void mlir::pto::PartitionViewOp::print(OpAsmPrinter &printer) { - printer << " " << getSource() << ", offsets = ["; - printer.printOperands(getOffsets()); - printer << "], sizes = ["; - printer.printOperands(getSizes()); - printer << "]"; - printer.printOptionalAttrDict((*this)->getAttrs(), - /*elidedAttrs=*/{"operandSegmentSizes"}); - printer << " : " << getSource().getType(); - - auto inferredResultType = inferPartitionViewResultTypeFromSizes( - dyn_cast(getSource().getType()), getSizes()); - if (succeeded(inferredResultType) && *inferredResultType == getResult().getType()) - return; - - printer << " -> " << getResult().getType(); -} - -static std::optional getConstantIntegerValueEx( - Value v, bool includeIndexAndIntOpsInConstFold) { - if (includeIndexAndIntOpsInConstFold) { - if (auto c = v.getDefiningOp()) - return c.value(); - if (auto c = v.getDefiningOp()) - return c.value(); - } - if (auto c = v.getDefiningOp()) { - if (auto ia = dyn_cast(c.getValue())) - return ia.getInt(); - } - return std::nullopt; -} - -static LogicalResult verifyNonNegativeIndexRowCol( - Operation &op, Value indexRow, Value indexCol, - bool includeIndexAndIntOpsInConstFold) { - if (!indexRow.getType().isIndex() || !indexCol.getType().isIndex()) - return op.emitOpError("expects indexRow and indexCol to be index type"); - auto row = - getConstantIntegerValueEx(indexRow, includeIndexAndIntOpsInConstFold); - auto col = - getConstantIntegerValueEx(indexCol, includeIndexAndIntOpsInConstFold); - if (row && *row < 0) - return op.emitOpError("expects indexRow to be non-negative"); - if (col && *col < 0) - return 
op.emitOpError("expects indexCol to be non-negative"); - return success(); -} - -static LogicalResult verifyExtractStaticBoundsCommon( - Operation &op, Value indexRow, Value indexCol, Type srcTy, Type dstTy, - bool includeIndexAndIntOpsInConstFold) { - auto row = - getConstantIntegerValueEx(indexRow, includeIndexAndIntOpsInConstFold); - auto col = - getConstantIntegerValueEx(indexCol, includeIndexAndIntOpsInConstFold); - auto srcShape = getShapeVec(srcTy); - auto dstShape = getShapeVec(dstTy); - if (srcShape.size() != 2 || dstShape.size() != 2) - return op.emitOpError("expects src and dst to be rank-2 tile_buf"); - if (row && srcShape[0] != ShapedType::kDynamic && - dstShape[0] != ShapedType::kDynamic && - *row + dstShape[0] > srcShape[0]) - return op.emitOpError("expects indexRow + dst.rows <= src.rows"); - if (col && srcShape[1] != ShapedType::kDynamic && - dstShape[1] != ShapedType::kDynamic && - *col + dstShape[1] > srcShape[1]) - return op.emitOpError("expects indexCol + dst.cols <= src.cols"); - return success(); -} - -static LogicalResult verifyInsertStaticBoundsCommon( - Operation &op, Value indexRow, Value indexCol, Type srcTy, Type dstTy, - bool includeIndexAndIntOpsInConstFold) { - auto row = - getConstantIntegerValueEx(indexRow, includeIndexAndIntOpsInConstFold); - auto col = - getConstantIntegerValueEx(indexCol, includeIndexAndIntOpsInConstFold); - auto srcShape = getValidShapeVec(srcTy); - auto dstShape = getShapeVec(dstTy); - if (srcShape.size() != 2 || dstShape.size() != 2) - return op.emitOpError("expects src and dst to be rank-2 tile_buf"); - if (row && srcShape[0] != ShapedType::kDynamic && - dstShape[0] != ShapedType::kDynamic && - *row + srcShape[0] > dstShape[0]) - return op.emitOpError("expects indexRow + src.rows <= dst.rows"); - if (col && srcShape[1] != ShapedType::kDynamic && - dstShape[1] != ShapedType::kDynamic && - *col + srcShape[1] > dstShape[1]) - return op.emitOpError("expects indexCol + src.cols <= dst.cols"); - return success(); 
-} - -static unsigned getElemByteSize(Type ty) { - return getPTOStorageElemByteSize(ty); -} - -static LogicalResult verifyTileBufLayoutConstraints(Operation *op, - pto::TileBufType tb, - StringRef name) { - auto shape = tb.getShape(); - if (shape.size() != 2) - return op->emitOpError() << "expects " << name << " to be rank-2"; - - int64_t rows = shape[0]; - int64_t cols = shape[1]; - if (rows != ShapedType::kDynamic && rows <= 0) - return op->emitOpError() << "expects " << name << " rows to be positive"; - if (cols != ShapedType::kDynamic && cols <= 0) - return op->emitOpError() << "expects " << name << " cols to be positive"; - - unsigned elemBytes = getElemByteSize(tb.getElementType()); - if (elemBytes == 0) - return op->emitOpError() << "expects " << name - << " element type to have a byte size"; - - auto cfg = tb.getConfigAttr(); - if (!cfg) - cfg = TileBufConfigAttr::getDefault(tb.getContext()); - auto readBLayout = [](Attribute attr, int32_t &out) -> bool { - if (auto layout = dyn_cast_or_null(attr)) { - out = static_cast(layout.getValue()); - return true; - } - if (auto value = dyn_cast_or_null(attr)) { - out = static_cast(value.getInt()); - return true; - } - return false; - }; - auto readSLayout = [](Attribute attr, int32_t &out) -> bool { - if (auto layout = dyn_cast_or_null(attr)) { - out = static_cast(layout.getValue()); - return true; - } - if (auto value = dyn_cast_or_null(attr)) { - out = static_cast(value.getInt()); - return true; - } - return false; - }; - int32_t blayout = 0; - int32_t slayout = 0; - if (!readBLayout(cfg.getBLayout(), blayout) || - !readSLayout(cfg.getSLayout(), slayout)) - return op->emitOpError() << "expects " << name - << " to have concrete tile layout attributes"; - constexpr int64_t kAlignedBytes = 32; - - auto checkByteAlignment = [&](int64_t dim, StringRef layoutName, - StringRef byteExpr) -> LogicalResult { - if (dim == ShapedType::kDynamic) - return success(); - int64_t bytes = dim * static_cast(elemBytes); - if (bytes % 
kAlignedBytes == 0) - return success(); - return op->emitOpError() - << "expects " << name << " " << layoutName - << " none_box tile " << byteExpr - << " to be 32-byte aligned, but got " << bytes << " bytes"; - }; - - if (slayout == static_cast(SLayout::NoneBox)) { - if (blayout == static_cast(BLayout::RowMajor)) - return checkByteAlignment(cols, "row-major", - "row byte size (cols * sizeof(dtype))"); - return checkByteAlignment(rows, "col-major", - "column byte size (rows * sizeof(dtype))"); - } - - int64_t innerRows = 0; - int64_t innerCols = 0; - int32_t fractal = static_cast(cfg.getSFractalSize().getInt()); - switch (fractal) { - case 1024: - innerRows = 16; - innerCols = 16; - break; - case 32: - innerRows = 16; - innerCols = 2; - break; - case 512: - if (kAlignedBytes % elemBytes != 0) - return op->emitOpError() << "expects " << name - << " element byte size to divide 32 for boxed " - "fractal-512 tile layout"; - if (slayout == static_cast(SLayout::RowMajor)) { - innerRows = 16; - innerCols = kAlignedBytes / static_cast(elemBytes); - } else if (slayout == static_cast(SLayout::ColMajor)) { - innerRows = kAlignedBytes / static_cast(elemBytes); - innerCols = 16; - } - break; - default: - break; - } - if (innerRows <= 0 || innerCols <= 0) - return op->emitOpError() << "expects " << name - << " to use a supported boxed tile layout"; - - auto loc = getPTOMemorySpaceEnum(tb); - bool allowUnalignedRows = - (loc && *loc == pto::AddressSpace::VEC) || fractal == 32 || rows == 1; - if (!allowUnalignedRows && rows != ShapedType::kDynamic && - rows % innerRows != 0) - return op->emitOpError() - << "expects " << name - << " boxed tile rows to be a multiple of innerRows (" << innerRows - << "), but got " << rows; - if (cols != ShapedType::kDynamic && cols % innerCols != 0) - return op->emitOpError() - << "expects " << name - << " boxed tile cols to be a multiple of innerCols (" << innerCols - << "), but got " << cols; - - return success(); -} - -[[maybe_unused]] static bool 
isSupportedLoadStoreElemTypeA2A3(Type ty) { - if (ty.isF16() || ty.isBF16() || ty.isF32()) - return true; - if (auto it = dyn_cast(ty)) { - unsigned width = it.getWidth(); - return width == 8 || width == 16 || width == 32 || width == 64; - } - return false; -} - -static bool isSupportedGatherElemTypeA2A3(Type ty) { - if (ty.isF16() || ty.isF32()) - return true; - if (auto it = dyn_cast(ty)) { - unsigned width = it.getWidth(); - return width == 16 || width == 32; - } - return false; -} - -static bool isSupportedGatherElemTypeA5(Type ty) { - if (isSupportedGatherElemTypeA2A3(ty) || ty.isBF16()) - return true; - if (auto ft = dyn_cast(ty)) { - unsigned width = ft.getWidth(); - return width == 8; - } - if (auto it = dyn_cast(ty)) - return it.getWidth() == 8 || it.getWidth() == 16 || it.getWidth() == 32; - return false; -} - -static std::optional -inferLayout(ArrayRef shape, ArrayRef strides, - unsigned elemBytes) { - if (shape.size() != strides.size() || elemBytes == 0) - return std::nullopt; - - // NZ / fractal: rank>=5, check middle dims (sh3/sh4/sh5 per spec) - if (shape.size() >= 5) { - int64_t sh3 = shape[2], sh4 = shape[3], sh5 = shape[4]; - int64_t st4 = strides[3], st5 = strides[4]; - bool alignMatch = (sh3 == 16) && (sh3 * sh4 * elemBytes == 512); - bool strideMatch = (st5 == 1) && (st4 == sh5); - if (alignMatch && strideMatch) - return mlir::pto::Layout::NZ; - } - - // ND: row-major contiguous - bool isRowMajor = true; - for (int i = 0, e = (int)shape.size() - 1; i < e; ++i) { - if (strides[i] != strides[i + 1] * shape[i + 1]) { - isRowMajor = false; - break; - } - } - if (isRowMajor && strides.back() == 1) - return mlir::pto::Layout::ND; - - // DN: col-major - bool isColMajor = true; - for (int i = 0, e = (int)shape.size() - 1; i < e; ++i) { - if (strides[i + 1] != strides[i] * shape[i]) { - isColMajor = false; - break; - } - } - if (isColMajor && strides.front() == 1) - return mlir::pto::Layout::DN; - - return mlir::pto::Layout::ND; // fallback -} - -static 
std::optional getLogicalViewLayout(Value value) { - if (!value) - return std::nullopt; - if (auto part = value.getDefiningOp()) - return getLogicalViewLayout(part.getSource()); - if (auto make = value.getDefiningOp()) { - auto tvTy = dyn_cast(make.getResult().getType()); - if (!tvTy) - return std::nullopt; - SmallVector shape(tvTy.getShape().begin(), tvTy.getShape().end()); - SmallVector strides; - strides.reserve(make.getStrides().size()); - for (Value stride : make.getStrides()) { - auto cst = getConstIndexValue(stride); - if (!cst) - return std::nullopt; - strides.push_back(*cst); - } - return inferLayout(shape, strides, getElemByteSize(tvTy.getElementType())); - } - return std::nullopt; -} - -static std::optional getTileBufLogicalLayout(pto::TileBufType type) { - if (!type) - return std::nullopt; - int32_t sl = type.getSLayoutValueI32(); - int32_t bl = type.getBLayoutValueI32(); - if (sl != static_cast(pto::SLayout::NoneBox)) - return pto::Layout::NZ; - if (bl == static_cast(pto::BLayout::RowMajor)) - return pto::Layout::ND; - if (bl == static_cast(pto::BLayout::ColMajor)) - return pto::Layout::DN; - return std::nullopt; -} - -static bool isRowMajorTileBuf(Type ty) { - auto tb = mlir::dyn_cast(ty); - return tb && tb.getBLayoutValueI32() == static_cast(pto::BLayout::RowMajor); -} - -static LogicalResult verifyRowReductionSrcLayout(Operation *op, Type ty, - StringRef name) { - if (failed(verifyTileBufCommon(op, ty, name))) - return failure(); - auto as = getPTOMemorySpaceEnum(ty); - if (!as || *as != pto::AddressSpace::VEC) - return op->emitOpError() << "expects " << name << " to be in the vec address space"; - if (auto tb = dyn_cast(ty)) { - if (tb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor)) - return op->emitOpError() << "expects " << name << " to use the row_major blayout"; - } - if (auto mr = dyn_cast(ty)) - (void)mr; - if (auto tb = dyn_cast(ty)) { - if (tb.getSLayoutValueI32() != static_cast(pto::SLayout::NoneBox)) - return op->emitOpError() 
<< "expects " << name - << " to use the none_box slayout"; - } - if (auto tb = dyn_cast(ty)) { - auto layout = getTileBufLogicalLayout(tb); - if (layout && *layout != pto::Layout::ND) - return op->emitOpError() << "expects " << name - << " to use an ND-style tile layout"; - } - return success(); -} - -static LogicalResult verifyRowReductionDstLayout(Operation *op, Type ty, - StringRef name) { - if (failed(verifyTileBufCommon(op, ty, name))) - return failure(); - auto as = getPTOMemorySpaceEnum(ty); - if (!as || *as != pto::AddressSpace::VEC) - return op->emitOpError() << "expects " << name << " to be in the vec address space"; - if (auto tb = dyn_cast(ty)) { - if (tb.getSLayoutValueI32() != static_cast(pto::SLayout::NoneBox)) - return op->emitOpError() << "expects " << name - << " to use the none_box slayout"; - } - if (auto tb = dyn_cast(ty)) { - if (tb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor) && - tb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor)) - return op->emitOpError() << "expects " << name - << " to use the row_major or col_major blayout"; - } - if (auto mr = dyn_cast(ty)) - (void)mr; - if (auto tb = dyn_cast(ty)) { - auto layout = getTileBufLogicalLayout(tb); - if (layout && *layout == pto::Layout::DN) { - auto shape = getShapeVec(ty); - if (shape.size() == 2 && shape[1] != ShapedType::kDynamic && shape[1] != 1) - return op->emitOpError() << "expects DN-style " << name - << " to have shape[1] == 1"; - return success(); - } - if (layout && *layout == pto::Layout::ND) - return success(); - if (layout) - return op->emitOpError() << "expects " << name - << " to use a DN-style column vector tile or legacy ND-style tile"; - } - return success(); - auto valid = getValidShapeVec(ty); - if (valid.size() != 2) - return op->emitOpError() << "expects " << name << " to have rank-2 valid_shape"; - if (valid[1] != ShapedType::kDynamic && valid[1] != 1) - return op->emitOpError() << "expects " << name << " valid_shape[1] to be 1"; - return 
success(); -} - -static LogicalResult verifyRowReductionValidRegion(Operation *op, Type srcTy, - Type dstTy) { - auto srcValid = getValidShapeVec(srcTy); - auto dstValid = getValidShapeVec(dstTy); - if (srcValid.size() != 2 || dstValid.size() != 2) - return op->emitOpError("expects src and dst to have rank-2 valid_shape"); - if (srcValid[0] != ShapedType::kDynamic && srcValid[0] == 0) - return op->emitOpError("expects src valid_shape[0] to be non-zero"); - if (srcValid[1] != ShapedType::kDynamic && srcValid[1] == 0) - return op->emitOpError("expects src valid_shape[1] to be non-zero"); - if (srcValid[0] != ShapedType::kDynamic && dstValid[0] != ShapedType::kDynamic && - srcValid[0] != dstValid[0]) - return op->emitOpError("expects src and dst to have the same valid_shape[0]"); - if (dstValid[1] != ShapedType::kDynamic && dstValid[1] != 1) - return op->emitOpError("expects dst valid_shape[1] to be 1"); - return success(); -} - -static bool isSupportedRowReductionElemType(Type elem) { - return elem.isInteger(16) || elem.isInteger(32) || elem.isF16() || - elem.isF32(); -} - -static LogicalResult verifyTRowReductionNoTmpCommon(Operation *op, Type srcTy, - Type dstTy, - StringRef elemTypeError) { - if (failed(verifyRowReductionSrcLayout(op, srcTy, "src")) || - failed(verifyRowReductionDstLayout(op, dstTy, "dst"))) - return failure(); - if (getElemTy(srcTy) != getElemTy(dstTy)) - return op->emitOpError("expects src and dst to have the same element type"); - if (failed(verifyRowReductionValidRegion(op, srcTy, dstTy))) - return failure(); - if (!isSupportedRowReductionElemType(getElemTy(srcTy))) - return op->emitOpError(elemTypeError); - return success(); -} - -static LogicalResult verifyTRowReductionWithTmpCommon(Operation *op, Type srcTy, - Type tmpTy, Type dstTy, - StringRef elemTypeError) { - if (failed(verifyRowReductionSrcLayout(op, srcTy, "src")) || - failed(verifyVecTileCommon(op, tmpTy, "tmp")) || - failed(verifyRowReductionDstLayout(op, dstTy, "dst"))) - return 
failure(); - if (failed(verifyTileBufSameElemType(op, srcTy, tmpTy, "src", "tmp")) || - failed(verifyTileBufSameValidShape(op, srcTy, tmpTy, "src", "tmp"))) - return failure(); - if (getElemTy(srcTy) != getElemTy(dstTy)) - return op->emitOpError("expects src and dst to have the same element type"); - if (failed(verifyRowReductionValidRegion(op, srcTy, dstTy))) - return failure(); - if (!isSupportedRowReductionElemType(getElemTy(srcTy))) - return op->emitOpError(elemTypeError); - return success(); -} - -static LogicalResult verifyTRowArgReductionCommon(Operation *op, Type srcTy, - Type tmpTy, Type dstTy) { - if (failed(verifyRowReductionSrcLayout(op, srcTy, "src")) || - failed(verifyVecTileCommon(op, tmpTy, "tmp")) || - failed(verifyRowReductionDstLayout(op, dstTy, "dst"))) - return failure(); - if (failed(verifyTileBufSameElemType(op, srcTy, tmpTy, "src", "tmp")) || - failed(verifyTileBufSameValidShape(op, srcTy, tmpTy, "src", "tmp"))) - return failure(); - if (failed(verifyRowReductionValidRegion(op, srcTy, dstTy))) - return failure(); - Type srcElem = getElemTy(srcTy); - if (!isSupportedRowReductionElemType(srcElem)) - return op->emitOpError("expects src element type to be i16/i32/f16/f32"); - auto dstInt = dyn_cast(getElemTy(dstTy)); - if (!dstInt || dstInt.getWidth() != 32) - return op->emitOpError("expects dst element type to be i32 or ui32"); - return success(); -} - -static LogicalResult verifyNDStyleVecTile(Operation *op, Type ty, StringRef name) { - if (failed(verifyTileBufCommon(op, ty, name))) - return failure(); - auto as = getPTOMemorySpaceEnum(ty); - if (!as || *as != pto::AddressSpace::VEC) - return op->emitOpError() << "expects " << name << " to be in the vec address space"; - if (auto tb = dyn_cast(ty)) { - if (tb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor)) - return op->emitOpError() << "expects " << name << " to use the row_major blayout"; - if (tb.getSLayoutValueI32() != static_cast(pto::SLayout::NoneBox)) - return 
op->emitOpError() << "expects " << name << " to use the none_box slayout"; - } - return success(); -} - -static LogicalResult verifyColReductionValidRegion(Operation *op, Type srcTy, - Type dstTy, - bool requireNonZeroSrc) { - auto srcValid = getValidShapeVec(srcTy); - auto dstValid = getValidShapeVec(dstTy); - if (srcValid.size() != 2 || dstValid.size() != 2) - return op->emitOpError("expects src and dst to have rank-2 valid_shape"); - if (requireNonZeroSrc) { - if (srcValid[0] != ShapedType::kDynamic && srcValid[0] == 0) - return op->emitOpError("expects src valid_shape[0] to be non-zero"); - if (srcValid[1] != ShapedType::kDynamic && srcValid[1] == 0) - return op->emitOpError("expects src valid_shape[1] to be non-zero"); - } - if (srcValid[1] != ShapedType::kDynamic && dstValid[1] != ShapedType::kDynamic && - srcValid[1] != dstValid[1]) - return op->emitOpError("expects src and dst to have the same valid_shape[1]"); - return success(); -} - -static LogicalResult verifyColArgReductionDstLayout(Operation *op, Type ty, - StringRef name) { - if (failed(verifyNDStyleVecTile(op, ty, name))) - return failure(); - auto valid = getValidShapeVec(ty); - if (valid.size() != 2) - return op->emitOpError() << "expects " << name - << " to have rank-2 valid_shape"; - if (valid[0] != ShapedType::kDynamic && valid[0] != 1) - return op->emitOpError() << "expects " << name - << " valid_shape[0] to be 1"; - return success(); -} - -static std::optional getConstantIntegerValue(Value value) { - if (!value) - return std::nullopt; - if (auto arithCst = value.getDefiningOp()) { - if (auto intAttr = dyn_cast(arithCst.getValue())) - return intAttr.getInt(); - } - return std::nullopt; -} - -LogicalResult mlir::pto::MakeTensorViewOp::verify() { - auto tvTy = dyn_cast(getResult().getType()); - if (!tvTy) - return emitOpError("result must be pto.tensor_view<...>"); - - auto pty = dyn_cast(getPtr().getType()); - if (!pty) - return emitOpError("ptr operand must be !pto.ptr<...>"); - - if 
(pty.getElementType() != tvTy.getElementType()) - return emitOpError() << "ptr element type must match tensor_view element type, but got ptr=" - << pty.getElementType() << " view=" << tvTy.getElementType(); - - int64_t rank = tvTy.getRank(); - - if ((int64_t)getShape().size() != rank || (int64_t)getStrides().size() != rank) - return emitOpError() << "shape/strides operand counts must match tensor_view rank=" - << rank; - - // Detect dynamic shape/stride. - bool hasDynamicShape = llvm::any_of(tvTy.getShape(), [](int64_t v) { - return v == ShapedType::kDynamic; - }); - bool hasDynamicStride = llvm::any_of(getStrides(), [](Value s) { - return !getConstIndexValue(s).has_value(); - }); - - auto layoutAttr = getLayoutAttr(); - - // 1) Dynamic shape/stride without explicit layout: warn and keep going. - if ((hasDynamicShape || hasDynamicStride) && !layoutAttr) { - return success(); - } - - // 2) Static shape/stride with explicit layout: verify correctness. - bool allStaticStride = true; - SmallVector strideInts; - strideInts.reserve(getStrides().size()); - for (Value s : getStrides()) { - auto val = getConstIndexValue(s); - if (!val) { - allStaticStride = false; - break; - } - strideInts.push_back(*val); - } - - bool allStaticShape = - llvm::none_of(tvTy.getShape(), [](int64_t v) { return v == ShapedType::kDynamic; }); - - if (layoutAttr && allStaticShape && allStaticStride) { - SmallVector shapeInts(tvTy.getShape().begin(), tvTy.getShape().end()); - if (auto inferred = inferLayout(shapeInts, strideInts, - getElemByteSize(tvTy.getElementType()))) { - (void)inferred; - } - } - - return success(); -} - -LogicalResult mlir::pto::PartitionViewOp::verify() { - auto srcTy = dyn_cast(getSource().getType()); - auto resTy = dyn_cast(getResult().getType()); - if (!srcTy || !resTy) - return emitOpError("expects tensor_view source and partition_tensor_view result"); - - if (srcTy.getElementType() != resTy.getElementType()) - return emitOpError() << "element type mismatch between 
source and result: src=" - << srcTy.getElementType() << " result=" - << resTy.getElementType(); - - int64_t srcRank = srcTy.getRank(); - if ((int64_t)getOffsets().size() != srcRank) - return emitOpError() << "offset count (" << getOffsets().size() - << ") must match source rank (" << srcRank << ")"; - - if ((int64_t)getSizes().size() != srcRank) - return emitOpError() << "size count (" << getSizes().size() - << ") must match source rank (" << srcRank << ")"; - - ArrayRef srcShape = srcTy.getShape(); - ArrayRef resShape = resTy.getShape(); - bool sameRank = resTy.getRank() == srcRank; - - for (int64_t i = 0; i < srcRank; ++i) { - auto offVal = getConstIndexValue(getOffsets()[i]); - auto sizeVal = getConstIndexValue(getSizes()[i]); - - if (offVal && *offVal < 0) - return emitOpError() << "offset at dim " << i - << " must be non-negative, got " << *offVal; - - if (sizeVal && *sizeVal <= 0) - return emitOpError() << "size at dim " << i - << " must be positive, got " << *sizeVal; - - if (sameRank && sizeVal) { - int64_t resDim = resShape[i]; - if (resDim != ShapedType::kDynamic && *sizeVal != resDim) - return emitOpError() << "size/result mismatch at dim " << i - << ": size operand=" << *sizeVal - << " result type dim=" << resDim; - } - - int64_t srcDim = srcShape[i]; - if (srcDim == ShapedType::kDynamic) - continue; - - if (sizeVal && *sizeVal > srcDim) - return emitOpError() << "size at dim " << i << " (" << *sizeVal - << ") exceeds static source dim (" << srcDim << ")"; - - if (offVal && sizeVal && (*offVal + *sizeVal > srcDim)) - return emitOpError() << "offset+size at dim " << i << " (" - << (*offVal + *sizeVal) - << ") exceeds static source dim (" << srcDim << ")"; - } - - return success(); -} - -LogicalResult mlir::pto::AddPtrOp::verify() { - Value ptr = getOperation()->getOperand(0); - Value result = getOperation()->getResult(0); - - auto ptrTy = dyn_cast(ptr.getType()); - if (!ptrTy) - return emitOpError("ptr operand must be !pto.ptr<...>"); - - auto resTy = 
dyn_cast(result.getType()); - if (!resTy) - return emitOpError("result must be !pto.ptr<...>"); - - if (ptrTy != resTy) - return emitOpError("result type must match ptr operand type"); - - return success(); -} - -static LogicalResult verifyPtrLikeForAddressCast(Operation *op, Type type, - StringRef name) { - if (isa(type)) - return success(); - - auto memTy = dyn_cast(type); - if (!memTy) - return op->emitOpError() - << "expects " << name << " to be !pto.ptr<...> or a GM memref"; - - if (memTy.getRank() != 1) - return op->emitOpError() - << "expects lowered memref " << name << " to be rank-1"; - - if (!isGmAddressSpaceAttr(memTy.getMemorySpace())) - return op->emitOpError() - << "expects lowered memref " << name << " to use GM address space"; - - return success(); -} - -static Type getPointerLikeElementType(Type type) { - if (auto ptrTy = dyn_cast(type)) - return ptrTy.getElementType(); - if (auto memTy = dyn_cast(type)) - return memTy.getElementType(); - return Type(); -} - -static bool isEmitCSupportedScalarType(Type type) { - if (!type) - return false; - if (type.isF16() || type.isBF16() || type.isF32() || type.isF64()) - return true; - if (auto intTy = dyn_cast(type)) - return intTy.getWidth() == 8 || intTy.getWidth() == 16 || - intTy.getWidth() == 32 || intTy.getWidth() == 64; - if (mlir::pto::isPTOFloat8Type(type)) - return true; - if (isa(type)) - return true; - return false; -} - -LogicalResult mlir::pto::PtrToIntOp::verify() { - Type resultTy = getResult().getType(); - auto intTy = dyn_cast(resultTy); - if (!intTy || intTy.getWidth() != 64) - return emitOpError("result must be i64"); - - return verifyPtrLikeForAddressCast(getOperation(), getPtr().getType(), - "ptr operand"); -} - -LogicalResult mlir::pto::IntToPtrOp::verify() { - auto addrTy = dyn_cast(getAddr().getType()); - if (!addrTy || addrTy.getWidth() != 64) - return emitOpError("address operand must be i64"); - - if (failed(verifyPtrLikeForAddressCast(getOperation(), getResult().getType(), - 
"result"))) - return failure(); - - Type dstElem = getPointerLikeElementType(getResult().getType()); - if (!isEmitCSupportedScalarType(dstElem)) - return emitOpError("result element type is not supported by EmitC: ") - << dstElem; - - return success(); -} - -LogicalResult mlir::pto::LocalArrayGetOp::verify() { - auto arrayTy = getArray().getType(); - int64_t rank = arrayTy.getRank(); - int64_t numIdx = static_cast(getIndices().size()); - if (numIdx != rank) - return emitOpError() << "expects " << rank - << " indices for !pto.local_array of rank " << rank - << ", got " << numIdx; - if (getResult().getType() != arrayTy.getElementType()) - return emitOpError() - << "result type " << getResult().getType() - << " does not match array element type " - << arrayTy.getElementType(); - return success(); -} - -LogicalResult mlir::pto::LocalArraySetOp::verify() { - auto arrayTy = getArray().getType(); - int64_t rank = arrayTy.getRank(); - int64_t numIdx = static_cast(getIndices().size()); - if (numIdx != rank) - return emitOpError() << "expects " << rank - << " indices for !pto.local_array of rank " << rank - << ", got " << numIdx; - if (getValue().getType() != arrayTy.getElementType()) - return emitOpError() << "value type " << getValue().getType() - << " does not match array element type " - << arrayTy.getElementType(); - return success(); -} - - - - -void PTODialect::initialize() { - addTypes< -#define GET_TYPEDEF_LIST -#include "PTO/IR/PTOTypeDefs.cpp.inc" - >(); - - addOperations< -#define GET_OP_LIST -#include "PTO/IR/PTOOps.cpp.inc" - >(); - - addAttributes< -#define GET_ATTRDEF_LIST -#include "PTO/IR/PTOAttrs.cpp.inc" - >(); -} - - -AddressSpaceAttr mlir::pto::getPTOAddressSpaceAttr(Type type) { - auto memRefType = dyn_cast(type); - if (!memRefType) - return {}; - auto scopeAttr = dyn_cast(memRefType.getMemorySpace()); - if (!scopeAttr) - return {}; - return scopeAttr; -} - -bool mlir::pto::isScalarPtrOrMemRef(Type type) { - if (auto pty = dyn_cast(type)) - return 
true; - if (auto memTy = dyn_cast(type)) - return isGmAddressSpaceAttr(memTy.getMemorySpace()); - return false; -} - -bool mlir::pto::hasExplicitPTOEntryAttr(func::FuncOp func) { - return func && (func->hasAttrOfType(kPTOEntryAttrName) || - func->hasAttrOfType(kLegacyHACCEntryAttrName)); -} - -static constexpr StringLiteral kEffectivePTOEntryAttrName = - "pto.internal.entry"; - -static SmallVector getPTOFunctionDefinitions(ModuleOp module) { - SmallVector defs; - if (!module) - return defs; - for (auto func : module.getOps()) { - if (!func.isDeclaration()) - defs.push_back(func); - } - return defs; -} - -bool mlir::pto::isPTOEntryFunction(func::FuncOp func) { - if (!func || func.isDeclaration()) - return false; - if (auto attr = func->getAttrOfType(kEffectivePTOEntryAttrName)) - return attr.getValue(); - if (hasExplicitPTOEntryAttr(func)) - return true; - - ModuleOp module = func->getParentOfType(); - if (!module) - return false; - SmallVector defs = getPTOFunctionDefinitions(module); - return defs.size() == 1 && defs.front() == func; -} - -LogicalResult mlir::pto::validatePTOEntryFunctions(ModuleOp module) { - if (!module) - return success(); - - for (auto func : module.getOps()) { - if (!hasExplicitPTOEntryAttr(func)) - continue; - if (func.isDeclaration()) { - return func.emitOpError() - << "`" << kPTOEntryAttrName - << "` is only valid on function definitions"; - } - } - - for (auto func : module.getOps()) { - if (!isPTOEntryFunction(func)) - continue; - if (func.getFunctionType().getNumResults() != 0) { - return func.emitOpError() - << "PTO entry functions must return void"; - } - } - return success(); -} - -void mlir::pto::annotatePTOEntryFunctions(ModuleOp module) { - if (!module) - return; - - SmallVector defs = getPTOFunctionDefinitions(module); - for (auto func : module.getOps()) - func->removeAttr(kEffectivePTOEntryAttrName); - - if (defs.empty()) - return; - if (defs.size() == 1) { - defs.front()->setAttr(kEffectivePTOEntryAttrName, - 
BoolAttr::get(module.getContext(), true)); - return; - } - - for (auto func : defs) { - func->setAttr(kEffectivePTOEntryAttrName, - BoolAttr::get(module.getContext(), - hasExplicitPTOEntryAttr(func))); - } -} - -//===----------------------------------------------------------------------===// -// PTO Load/Store/Addf (non-DPS polymorphic) verification + inference. -// - If operands are memref/tensor: verify strictly. -// - Otherwise (tile_view/tile etc): accept (so old IR can still parse). -//===----------------------------------------------------------------------===// - -[[maybe_unused]] static LogicalResult verifyMemrefToTensorLoad(Operation *op, Value src, Value res) { - auto mr = dyn_cast(src.getType()); - auto rt = dyn_cast(res.getType()); - if (!mr) - return success(); // non-memref case: don't block old IR - if (!rt) - return op->emitOpError("when src is memref, result must be ranked tensor"); - - if (mr.getElementType() != rt.getElementType()) - return op->emitOpError() << "memref/tensor element type mismatch: memref=" - << mr.getElementType() << " tensor=" << rt.getElementType(); - - if (mr.getRank() != rt.getRank()) - return op->emitOpError() << "rank mismatch: memref rank=" << mr.getRank() - << " tensor rank=" << rt.getRank(); - - if (mr.hasStaticShape()) { - if (!rt.hasStaticShape()) - return op->emitOpError("memref has static shape but result tensor is not static"); - if (mr.getShape() != rt.getShape()) - return op->emitOpError() << "shape mismatch: memref=" << mr << " tensor=" << rt; - } else { - // For dynamic memref dims: if tensor dim is static, allow it; if it's dynamic too, also fine. - // We only reject when a memref static dim conflicts with tensor static dim. 
- for (int64_t i = 0; i < mr.getRank(); ++i) { - int64_t md = mr.getDimSize(i); - int64_t td = rt.getDimSize(i); - if (md != ShapedType::kDynamic && td != ShapedType::kDynamic && md != td) - return op->emitOpError() << "dim mismatch at " << i << ": memref=" << md << " tensor=" << td; - } - } - return success(); -} - -[[maybe_unused]] static LogicalResult verifyMemrefTensorStore(Operation *op, Value dst, Value src) { - auto mr = dyn_cast(dst.getType()); - if (!mr) - return success(); // non-memref case: old tile IR allowed - auto rt = dyn_cast(src.getType()); - if (!rt) - return op->emitOpError("when dst is memref, src must be ranked tensor"); - - if (mr.getElementType() != rt.getElementType()) - return op->emitOpError() << "memref/tensor element type mismatch: memref=" - << mr.getElementType() << " tensor=" << rt.getElementType(); - - if (mr.getRank() != rt.getRank()) - return op->emitOpError() << "rank mismatch: memref rank=" << mr.getRank() - << " tensor rank=" << rt.getRank(); - - for (int64_t i = 0; i < mr.getRank(); ++i) { - int64_t md = mr.getDimSize(i); - int64_t td = rt.getDimSize(i); - if (md != ShapedType::kDynamic && td != ShapedType::kDynamic && md != td) - return op->emitOpError() << "dim mismatch at " << i << ": memref=" << md << " tensor=" << td; - } - return success(); -} - -LogicalResult AllocTileOp::verify() { - auto ty = getResult().getType(); // TileBufType - - if (failed(verifyTileBufLayoutConstraints(*this, ty, "result"))) - return failure(); - - // op 上有没有传 operands - bool hasVR = getValidRow() != nullptr; - bool hasVC = getValidCol() != nullptr; - - // type 上的 validShape - auto vs = ty.getValidShape(); - if (vs.size() != 2) - return emitOpError("result tile_buf must have rank-2 validShape"); - - // TileBuf valid dims use a negative sentinel (e.g. '?' / -1). Be robust to - // any negative value (some code may materialize MLIR dynamic sentinels). - bool needVR = (vs[0] < 0); - bool needVC = (vs[1] < 0); - - // 你要求的:v_row=?, v_col=? 
时必须同时给两个 - // (这条规则由下面两句自然实现) - if (hasVR != needVR) - return emitOpError() << "valid_row operand " - << (needVR ? "is required" : "must be absent") - << " because result type v_row is " - << (needVR ? "?" : std::to_string(vs[0])); - - if (hasVC != needVC) - return emitOpError() << "valid_col operand " - << (needVC ? "is required" : "must be absent") - << " because result type v_col is " - << (needVC ? "?" : std::to_string(vs[1])); - - return success(); -} - -LogicalResult MaterializeTileOp::verify() { - auto sourceTy = cast(getSource().getType()); - auto resultTy = cast(getResult().getType()); - - if (sourceTy.getRank() != 2) - return emitOpError("source memref must be rank-2 to materialize a tile handle"); - if (resultTy.getRank() != 2) - return emitOpError("result tile_buf must be rank-2"); - if (failed(verifyTileBufLayoutConstraints(*this, resultTy, "result"))) - return failure(); - - auto viewSemantics = (*this)->getAttrOfType("pto.view_semantics"); - bool isSubview = viewSemantics && viewSemantics.getValue() == "subview"; - if (!isSubview && sourceTy.getShape() != resultTy.getShape()) - return emitOpError() << "source/result shape mismatch: source=" - << sourceTy << " result=" << resultTy; - - if (sourceTy.getElementType() != resultTy.getElementType()) - return emitOpError() << "source/result element type mismatch: source=" - << sourceTy.getElementType() - << " result=" << resultTy.getElementType(); - - if (sourceTy.getMemorySpace() != resultTy.getMemorySpace()) - return emitOpError() << "source/result memory space mismatch"; - - if (getConfig() != resultTy.getConfigAttr()) - return emitOpError("config attribute must match the result tile_buf config"); - - auto shape = resultTy.getShape(); - auto validShape = resultTy.getValidShape(); - if (validShape.size() != 2) - return emitOpError("result tile_buf must have rank-2 validShape"); - for (unsigned i = 0; i < 2; ++i) { - if (shape[i] != ShapedType::kDynamic && - validShape[i] != ShapedType::kDynamic && 
validShape[i] > shape[i]) { - return emitOpError() << "valid_shape[" << i << "] must be <= shape[" - << i << "]"; - } - } - - return success(); -} - -LogicalResult TAssignOp::verify() { - if (getTile().getType() != getResult().getType()) { - return emitOpError("result type must match tile operand type"); - } - return success(); -} - -LogicalResult TLoadOp::verify() { - auto verifyCommon = - [&](bool allowLowPrecision) - -> FailureOr> { - auto srcPart = dyn_cast(getSrc().getType()); - auto dstTile = dyn_cast(getDst().getType()); - if (!srcPart || !dstTile) { - emitOpError("expects src to be !pto.partition_tensor_view and dst to be !pto.tile_buf"); - return failure(); - } - if (failed(verifyTileBufCommon(*this, dstTile, "dst", allowLowPrecision))) - return failure(); - - auto srcShape = srcPart.getShape(); - for (unsigned i = 0; i < srcShape.size(); ++i) { - if (srcShape[i] != ShapedType::kDynamic && srcShape[i] <= 0) { - emitOpError() << "expects src shape[" << i << "] to be positive"; - return failure(); - } - } - auto dstValid = dstTile.getValidShape(); - for (unsigned i = 0; i < dstValid.size(); ++i) { - if (dstValid[i] != ShapedType::kDynamic && dstValid[i] <= 0) { - emitOpError() << "expects dst valid_shape[" << i << "] to be positive"; - return failure(); - } - } - return std::make_pair(srcPart, dstTile); - }; - - auto verifyA2A3 = [&]() -> LogicalResult { - auto common = verifyCommon(/*allowLowPrecision=*/false); - if (failed(common)) - return failure(); - auto [srcPart, dstTile] = *common; - - Type srcElem = srcPart.getElementType(); - Type dstElem = dstTile.getElementType(); - if (isPTOLowPrecisionType(srcElem) || isPTOLowPrecisionType(dstElem)) - return emitOpError("expects A2/A3 tload low-precision element types to be unsupported"); - if (!(dstElem.isInteger(8) || dstElem.isInteger(16) || dstElem.isInteger(32) || - dstElem.isInteger(64) || dstElem.isF16() || dstElem.isBF16() || dstElem.isF32())) - return emitOpError("expects A2/A3 tload dst element type 
to be i8/i16/i32/i64/u64/f16/bf16/f32"); - - auto dstSpace = getPTOMemorySpaceEnum(dstTile); - if (!dstSpace || (*dstSpace != pto::AddressSpace::VEC && - *dstSpace != pto::AddressSpace::MAT)) - return emitOpError("expects A2/A3 tload dst to use loc=vec or loc=mat"); - - if (getElemByteSize(srcElem) != getElemByteSize(dstElem)) - return emitOpError("expects src and dst element types to have the same bitwidth"); - return success(); - }; - - auto verifyA5 = [&]() -> LogicalResult { - auto common = verifyCommon(/*allowLowPrecision=*/true); - if (failed(common)) - return failure(); - auto [srcPart, dstTile] = *common; - - Type srcElem = srcPart.getElementType(); - Type dstElem = dstTile.getElementType(); - unsigned srcBytes = getElemByteSize(srcElem); - unsigned dstBytes = getElemByteSize(dstElem); - if (srcBytes != dstBytes) - return emitOpError("expects src and dst element types to have the same element size"); - if (!(dstBytes == 1 || dstBytes == 2 || dstBytes == 4 || dstBytes == 8)) - return emitOpError("expects A5 tload dst element size to be 1, 2, 4, or 8 bytes"); - if (!isA5TLoadStoreTransferElemType(srcElem)) - return emitOpError("expects A5 tload src element type to be i8/i16/i32/i64/f16/bf16/f32/f8/hif8/fp4"); - if (!isA5TLoadStoreTransferElemType(dstElem)) - return emitOpError("expects A5 tload dst element type to be i8/i16/i32/i64/f16/bf16/f32/f8/hif8/fp4"); - - if (dstElem.isInteger(64)) { - auto pad = dstTile.getPadValueI32(); - if (pad != static_cast(pto::PadValue::Null) && - pad != static_cast(pto::PadValue::Zero)) - return emitOpError("expects A5 i64/u64 tload dst pad to be null or zero"); - } - return success(); - }; - - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -LogicalResult TPrefetchOp::verify() { - auto verifyImpl = [&](bool allowLowPrecision) -> LogicalResult { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - - Type srcElem; - Type dstElem; - - if (auto srcPart = dyn_cast(srcTy)) { - auto 
srcShape = srcPart.getShape(); - for (unsigned i = 0; i < srcShape.size(); ++i) { - if (srcShape[i] != ShapedType::kDynamic && srcShape[i] <= 0) - return emitOpError() << "expects src shape[" << i << "] to be positive"; - } - srcElem = srcPart.getElementType(); - } else if (auto srcMr = dyn_cast(srcTy)) { - if (!srcMr.hasRank()) - return emitOpError("expects src memref to be ranked"); - for (int64_t dim : srcMr.getShape()) { - if (dim != ShapedType::kDynamic && dim <= 0) - return emitOpError("expects src memref shape to be positive"); - } - srcElem = srcMr.getElementType(); - } else { - return emitOpError("expects src to be !pto.partition_tensor_view or memref"); - } - - if (auto dstTile = dyn_cast(dstTy)) { - if (failed(verifyTileBufCommon(*this, dstTile, "dst", allowLowPrecision))) - return failure(); - auto dstValid = dstTile.getValidShape(); - for (unsigned i = 0; i < dstValid.size(); ++i) { - if (dstValid[i] != ShapedType::kDynamic && dstValid[i] <= 0) - return emitOpError() << "expects dst valid_shape[" << i << "] to be positive"; - } - auto dstSpace = getPTOMemorySpaceEnum(dstTile); - if (!dstSpace || (*dstSpace != pto::AddressSpace::VEC && - *dstSpace != pto::AddressSpace::MAT)) - return emitOpError("expects dst to use loc=vec or loc=mat"); - dstElem = dstTile.getElementType(); - } else if (auto dstMr = dyn_cast(dstTy)) { - auto dstSpace = getPTOMemorySpaceEnum(dstMr); - if (!dstSpace || (*dstSpace != pto::AddressSpace::VEC && - *dstSpace != pto::AddressSpace::MAT)) - return emitOpError("expects dst memref to use loc=vec or loc=mat"); - if (!dstMr.hasRank()) - return emitOpError("expects dst memref to be ranked"); - if (failed(verifyTileBufCommon(*this, dstMr, "dst", allowLowPrecision))) - return failure(); - dstElem = dstMr.getElementType(); - } else { - return emitOpError("expects dst to be !pto.tile_buf or memref"); - } - - if (getElemByteSize(srcElem) != getElemByteSize(dstElem)) - return emitOpError("expects src and dst element types to have the same 
element size"); - if (!allowLowPrecision && - (isPTOLowPrecisionType(srcElem) || isPTOLowPrecisionType(dstElem))) - return emitOpError("expects A2/A3 tprefetch low-precision element types to be unsupported"); - if (allowLowPrecision && - (!isA5TLoadStoreTransferElemType(srcElem) || - !isA5TLoadStoreTransferElemType(dstElem))) - return emitOpError("expects A5 tprefetch element types to be i8/i16/i32/i64/f16/bf16/f32/f8/hif8/fp4"); - return success(); - }; - - auto verifyA2A3 = [&]() -> LogicalResult { - return verifyImpl(/*allowLowPrecision=*/false); - }; - auto verifyA5 = [&]() -> LogicalResult { - return verifyImpl(/*allowLowPrecision=*/true); - }; - switch (getVerifierTargetArch(getOperation())) { - case VerifierTargetArch::A2A3: - return verifyA2A3(); - case VerifierTargetArch::A5: - return verifyA5(); - } - return failure(); -} - -LogicalResult MakePrefetchAsyncContextOp::verify() { - Type workspaceTy = getWorkspace().getType(); - Type elemTy = nullptr; - if (auto ptrTy = dyn_cast(workspaceTy)) { - elemTy = ptrTy.getElementType(); - } else if (auto memTy = dyn_cast(workspaceTy)) { - if (!isGmAddressSpaceAttr(memTy.getMemorySpace())) - return emitOpError("expects workspace memref to be in GM address space"); - elemTy = memTy.getElementType(); - } else { - return emitOpError("expects workspace to be !pto.ptr or GM memref"); - } - if (!isByteIntegerType(elemTy)) - return emitOpError("expects workspace element type to be an 8-bit integer"); - return success(); -} - -LogicalResult TPrefetchAsyncOp::verify() { - if (failed(verifyAsyncFlatContiguous1DGMViewLike(getOperation(), getSrc(), - "src"))) - return failure(); - return success(); -} - -LogicalResult mlir::pto::SetFFTsOp::verify() { - auto mr = llvm::dyn_cast(getFfts().getType()); - if (!mr) - return emitOpError("expects a memref operand"); - - if (!mr.getElementType().isInteger(64) && !mr.getElementType().isInteger(8)) - return emitOpError("expects element type i64 (or i8)"); - - return mlir::success(); -} - 
-ParseResult mlir::pto::SyncSetOp::parse(OpAsmParser &parser, - OperationState &result) { - return parseSyncEventOpCommon(parser, result, - SyncSetOp::getPipeAttrName(result.name), - SyncSetOp::getEventIdAttrName(result.name)); -} - -void mlir::pto::SyncSetOp::print(OpAsmPrinter &p) { - printSyncEventOpCommon(p, getOperation(), getPipe(), getEventIdAttr(), - getEventIdDyn(), getPipeAttrName().getValue(), - getEventIdAttrName().getValue()); -} - -LogicalResult mlir::pto::SyncSetOp::verify() { - bool hasStatic = getEventIdAttr() != nullptr; - bool hasDynamic = static_cast(getEventIdDyn()); - if (hasStatic == hasDynamic) - return emitOpError() - << "expects exactly one event-id form: static attr or dynamic index operand"; - if (IntegerAttr fftsModeAttr = getFftsModeAttr()) { - int64_t fftsMode = fftsModeAttr.getInt(); - if (fftsMode < 0 || fftsMode > 2) - return emitOpError() << "requires ffts_mode in range [0, 2], but got " - << fftsMode; - } - - auto verifyA2A3 = [&]() -> LogicalResult { return success(); }; - auto verifyA5 = [&]() -> LogicalResult { - switch (getPipe().getPipe()) { - case PIPE::PIPE_FIX: - case PIPE::PIPE_MTE3: - return success(); - default: - return emitOpError() - << "A5 sync.set expects pipe to be one of , "; - } - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -ParseResult mlir::pto::SyncWaitOp::parse(OpAsmParser &parser, - OperationState &result) { - return parseSyncEventOpCommon(parser, result, - SyncWaitOp::getPipeAttrName(result.name), - SyncWaitOp::getEventIdAttrName(result.name)); -} - -void mlir::pto::SyncWaitOp::print(OpAsmPrinter &p) { - printSyncEventOpCommon(p, getOperation(), getPipe(), getEventIdAttr(), - getEventIdDyn(), getPipeAttrName().getValue(), - getEventIdAttrName().getValue()); -} - -ParseResult mlir::pto::SyncAllOp::parse(OpAsmParser &parser, - OperationState &result) { - SmallVector operands; - SmallVector operandTypes; - Attribute modeAttr; - Attribute coreTypeAttr; - - if 
(parser.parseLParen()) - return failure(); - - if (failed(parser.parseOptionalRParen())) { - if (parser.parseOperandList(operands) || parser.parseColonTypeList(operandTypes) || - parser.parseRParen()) - return failure(); - if (operands.size() != operandTypes.size()) - return parser.emitError(parser.getCurrentLocation()) - << "expects the same number of operands and operand types"; - } - - if (parser.parseKeyword("mode") || parser.parseEqual() || - parser.parseAttribute(modeAttr) || parser.parseComma() || - parser.parseKeyword("core_type") || parser.parseEqual() || - parser.parseAttribute(coreTypeAttr)) - return failure(); - - auto mode = dyn_cast(modeAttr); - if (!mode) - return parser.emitError(parser.getCurrentLocation()) - << "expects mode to be #pto.sync_all_mode<...>"; - - auto coreType = dyn_cast(coreTypeAttr); - if (!coreType) - return parser.emitError(parser.getCurrentLocation()) - << "expects core_type to be #pto.sync_core_type<...>"; - - result.addAttribute("mode", mode); - result.addAttribute("core_type", coreType); - - if (parser.parseOptionalAttrDict(result.attributes)) - return failure(); - - auto addSegmentSizes = [&](int32_t gm, int32_t ub, int32_t l1, - int32_t used) { - result.addAttribute("operandSegmentSizes", - parser.getBuilder().getDenseI32ArrayAttr( - {gm, ub, l1, used})); - }; - - switch (mode.getValue()) { - case pto::SyncAllMode::Hard: - if (!operands.empty()) - return parser.emitError(parser.getCurrentLocation()) - << "expects hard syncall to have no operands"; - addSegmentSizes(0, 0, 0, 0); - return success(); - case pto::SyncAllMode::Soft: - break; - } - - switch (coreType.getValue()) { - case pto::SyncCoreType::AIVOnly: - if (operands.size() != 2 && operands.size() != 3) - return parser.emitError(parser.getCurrentLocation()) - << "expects soft AIV-only syncall to have gm_workspace, " - "ub_workspace, and optional used_cores"; - if (parser.resolveOperand(operands[0], operandTypes[0], result.operands) || - 
parser.resolveOperand(operands[1], operandTypes[1], result.operands)) - return failure(); - if (operands.size() == 3 && - parser.resolveOperand(operands[2], operandTypes[2], result.operands)) - return failure(); - addSegmentSizes(1, 1, 0, operands.size() == 3 ? 1 : 0); - return success(); - case pto::SyncCoreType::AICOnly: - if (operands.size() != 2 && operands.size() != 3) - return parser.emitError(parser.getCurrentLocation()) - << "expects soft AIC-only syncall to have gm_workspace, " - "l1_workspace, and optional used_cores"; - if (parser.resolveOperand(operands[0], operandTypes[0], result.operands) || - parser.resolveOperand(operands[1], operandTypes[1], result.operands)) - return failure(); - if (operands.size() == 3 && - parser.resolveOperand(operands[2], operandTypes[2], result.operands)) - return failure(); - addSegmentSizes(1, 0, 1, operands.size() == 3 ? 1 : 0); - return success(); - case pto::SyncCoreType::Mix: - if (operands.size() != 3 && operands.size() != 4) - return parser.emitError(parser.getCurrentLocation()) - << "expects soft mixed syncall to have gm_workspace, " - "ub_workspace, l1_workspace, and optional used_cores"; - if (parser.resolveOperand(operands[0], operandTypes[0], result.operands) || - parser.resolveOperand(operands[1], operandTypes[1], result.operands) || - parser.resolveOperand(operands[2], operandTypes[2], result.operands)) - return failure(); - if (operands.size() == 4 && - parser.resolveOperand(operands[3], operandTypes[3], result.operands)) - return failure(); - addSegmentSizes(1, 1, 1, operands.size() == 4 ? 
1 : 0); - return success(); - } - - llvm_unreachable("unhandled SyncCoreType"); -} - -void mlir::pto::SyncAllOp::print(OpAsmPrinter &p) { - SmallVector operands; - if (getGmWorkspace()) - operands.push_back(getGmWorkspace()); - if (getUbWorkspace()) - operands.push_back(getUbWorkspace()); - if (getL1Workspace()) - operands.push_back(getL1Workspace()); - if (getUsedCores()) - operands.push_back(getUsedCores()); - - p << "("; - if (!operands.empty()) { - p.printOperands(operands); - p << " : "; - llvm::interleaveComma(operands, p, - [&](Value operand) { p.printType(operand.getType()); }); - } - p << ") mode = " << getMode() << ", core_type = " << getCoreType(); - p.printOptionalAttrDict((*this)->getAttrs(), - /*elidedAttrs=*/{"operandSegmentSizes", "mode", - "core_type"}); -} - -LogicalResult mlir::pto::SyncWaitOp::verify() { - bool hasStatic = getEventIdAttr() != nullptr; - bool hasDynamic = static_cast(getEventIdDyn()); - if (hasStatic == hasDynamic) - return emitOpError() - << "expects exactly one event-id form: static attr or dynamic index operand"; - - auto verifyA2A3 = [&]() -> LogicalResult { return success(); }; - auto verifyA5 = [&]() -> LogicalResult { - switch (getPipe().getPipe()) { - case PIPE::PIPE_FIX: - case PIPE::PIPE_MTE1: - case PIPE::PIPE_MTE2: - case PIPE::PIPE_MTE3: - case PIPE::PIPE_V: - return success(); - default: - return emitOpError() << "A5 sync.wait expects pipe to be one of " - ", , , " - ", "; - } - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -LogicalResult TStoreOp::verify() { - auto verifyCommon = - [&](bool allowLowPrecision) - -> FailureOr> { - auto srcTile = dyn_cast(getSrc().getType()); - auto dstPart = dyn_cast(getDst().getType()); - if (!srcTile || !dstPart) { - emitOpError("expects src to be !pto.tile_buf and dst to be !pto.partition_tensor_view"); - return failure(); - } - if (failed(verifyTileBufCommon(*this, srcTile, "src", allowLowPrecision))) - return failure(); - for (auto [idx, dim] : 
llvm::enumerate(dstPart.getShape())) { - if (dim != ShapedType::kDynamic && dim <= 0) { - emitOpError() << "expects dst shape[" << idx << "] to be positive"; - return failure(); - } - } - auto srcValid = srcTile.getValidShape(); - for (auto [idx, dim] : llvm::enumerate(srcValid)) { - if (dim != ShapedType::kDynamic && dim <= 0) { - emitOpError() << "expects src valid_shape[" << idx << "] to be positive"; - return failure(); - } - } - - // Keep TSTORE contract explicit while preserving existing legal layout - // reinterpretation paths (e.g. 1x1024 <-> 32x32, 5D partition views). - // When both sides are fully static, require equal element counts between - // dst shape and src valid_shape. - auto getStaticElemCount = [](ArrayRef shape) -> std::optional { - int64_t total = 1; - for (int64_t dim : shape) { - if (dim == ShapedType::kDynamic) - return std::nullopt; - if (dim <= 0) - return std::nullopt; - if (total > std::numeric_limits::max() / dim) - return std::nullopt; - total *= dim; - } - return total; - }; - - auto dstElemCount = getStaticElemCount(dstPart.getShape()); - auto srcValidElemCount = getStaticElemCount(srcValid); - if (dstElemCount && srcValidElemCount && *dstElemCount != *srcValidElemCount) { - emitOpError() << "expects dst static element count (" << *dstElemCount - << ") to match src valid_shape static element count (" - << *srcValidElemCount << ")"; - return failure(); - } - return std::make_pair(srcTile, dstPart); - }; - - auto isLoadStoreElemType = [&](Type ty) -> bool { - return ty.isInteger(8) || ty.isInteger(16) || ty.isInteger(32) || - ty.isInteger(64) || ty.isF16() || ty.isBF16() || ty.isF32(); - }; - auto isI8Like = [&](Type ty) -> bool { return ty.isInteger(8); }; - bool hasPreQuant = static_cast(getPreQuantScalar()); - auto reluMode = getReluPreMode(); - - auto verifyA2A3 = [&]() -> LogicalResult { - auto common = verifyCommon(/*allowLowPrecision=*/false); - if (failed(common)) - return failure(); - auto [srcTile, dstPart] = *common; - 
auto srcSpace = getPTOMemorySpaceEnum(srcTile); - if (!srcSpace || (*srcSpace != pto::AddressSpace::VEC && - *srcSpace != pto::AddressSpace::MAT && - *srcSpace != pto::AddressSpace::ACC)) - return emitOpError("expects A2/A3 tstore src to use loc=vec, loc=mat, or loc=acc"); - if (hasPreQuant && *srcSpace != pto::AddressSpace::ACC) - return emitOpError("expects preQuantScalar form to use loc=acc src"); - if (reluMode != pto::ReluPreMode::NoRelu && *srcSpace != pto::AddressSpace::ACC) - return emitOpError("expects reluPreMode form to use loc=acc src"); - - Type srcElem = srcTile.getElementType(); - Type dstElem = dstPart.getElementType(); - if (*srcSpace == pto::AddressSpace::VEC || *srcSpace == pto::AddressSpace::MAT) { - if (hasPreQuant) - return emitOpError("expects preQuantScalar form to use loc=acc src"); - if (isPTOLowPrecisionType(dstElem)) - return emitOpError("expects A2/A3 vec/mat tstore low-precision dst element types to be unsupported"); - if (!isLoadStoreElemType(srcElem)) - return emitOpError("expects A2/A3 vec/mat tstore src element type to be i8/i16/i32/i64/u64/f16/bf16/f32"); - if (getElemByteSize(srcElem) != getElemByteSize(dstElem)) - return emitOpError("expects A2/A3 vec/mat tstore src and dst element types to have the same bitwidth"); - return success(); - } - - if (!(srcElem.isInteger(32) || srcElem.isF32())) - return emitOpError("expects A2/A3 acc tstore src element type to be i32 or f32"); - if (hasPreQuant) { - if (srcElem.isInteger(32)) { - if (!(isI8Like(dstElem) || dstElem.isF16())) - return emitOpError("expects A2/A3 acc preQuantScalar tstore dst type to be i8/ui8/f16"); - } else if (srcElem.isF32()) { - if (!isI8Like(dstElem)) - return emitOpError("expects A2/A3 acc preQuantScalar tstore dst type to be i8/ui8"); - } - } else { - if (!(dstElem.isInteger(32) || dstElem.isF32() || dstElem.isF16() || - dstElem.isBF16())) - return emitOpError("expects A2/A3 acc tstore dst element type to be i32/f32/f16/bf16"); - } - - auto srcShape = 
srcTile.getShape(); - if (srcShape[1] != ShapedType::kDynamic && - (srcShape[1] < 1 || srcShape[1] > 4095)) - return emitOpError("expects A2/A3 acc tstore src cols to be in [1, 4095]"); - auto srcValid = srcTile.getValidShape(); - if (srcValid[1] != ShapedType::kDynamic && - (srcValid[1] < 1 || srcValid[1] > 4095)) - return emitOpError("expects A2/A3 acc tstore src valid_shape[1] to be in [1, 4095]"); - return success(); - }; - - auto verifyA5 = [&]() -> LogicalResult { - auto common = verifyCommon(/*allowLowPrecision=*/true); - if (failed(common)) - return failure(); - auto [srcTile, dstPart] = *common; - auto srcSpace = getPTOMemorySpaceEnum(srcTile); - if (!srcSpace || (*srcSpace != pto::AddressSpace::VEC && - *srcSpace != pto::AddressSpace::ACC)) - return emitOpError("expects A5 tstore src to use loc=vec or loc=acc"); - if (hasPreQuant && *srcSpace != pto::AddressSpace::ACC) - return emitOpError("expects preQuantScalar form to use loc=acc src"); - if (reluMode != pto::ReluPreMode::NoRelu && *srcSpace != pto::AddressSpace::ACC) - return emitOpError("expects reluPreMode form to use loc=acc src"); - - Type srcElem = srcTile.getElementType(); - Type dstElem = dstPart.getElementType(); - if (*srcSpace == pto::AddressSpace::VEC) { - if (hasPreQuant) - return emitOpError("expects preQuantScalar form to use loc=acc src"); - if (!isA5TLoadStoreTransferElemType(srcElem)) - return emitOpError("expects A5 vec tstore src element type to be i8/i16/i32/i64/f16/bf16/f32/f8/hif8/fp4"); - if (getElemByteSize(srcElem) != getElemByteSize(dstElem)) - return emitOpError("expects A5 vec tstore src and dst element types to have the same bitwidth"); - return success(); - } - - if (!(srcElem.isInteger(32) || srcElem.isF32())) - return emitOpError("expects A5 acc tstore src element type to be i32 or f32"); - if (hasPreQuant) { - if (!isA5AccStorePreQuantDstType(srcElem, dstElem)) - return emitOpError("expects A5 acc preQuantScalar tstore dst type to be i8/ui8/f16/bf16/f32/hif8/f8E4M3"); 
- } else { - if (!(dstElem.isInteger(32) || dstElem.isF32() || dstElem.isF16() || - dstElem.isBF16())) - return emitOpError("expects A5 acc tstore dst element type to be i32/f32/f16/bf16"); - } - return success(); - }; - - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -LogicalResult pto::TAbsOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyVecTileCommon(*this, srcTy, "src")) || - failed(verifyVecTileCommon(*this, dstTy, "dst"))) - return failure(); - if (failed(verifyTileBufSameElemType(*this, srcTy, dstTy, "src", "dst")) || - failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) - return failure(); - - Type elemTy; - if (auto tb = dyn_cast(srcTy)) - elemTy = tb.getElementType(); - else if (auto mr = dyn_cast(srcTy)) - elemTy = mr.getElementType(); - if (!(elemTy.isF16() || elemTy.isF32())) - return emitOpError() << "expects element type to be f16 or f32"; - - return success(); -} -// PTO.cpp - -static bool isPTOShapedLike(Type ty) { - return mlir::isa(ty); -} - -static bool isTileLikeType(Type ty) { - return isa(ty); -} - -static Type getElemTy(Type ty) { - if (auto mr = mlir::dyn_cast(ty)) return mr.getElementType(); - if (auto tt = mlir::dyn_cast(ty)) return tt.getElementType(); - if (auto tv = mlir::dyn_cast(ty)) return tv.getElementType(); - if (auto tb = mlir::dyn_cast(ty)) return tb.getElementType(); - if (auto tv = mlir::dyn_cast(ty)) return tv.getElementType(); - return Type(); -} - -static SmallVector getShapeVec(Type ty) { - SmallVector s; - if (auto mr = mlir::dyn_cast(ty)) - return SmallVector(mr.getShape().begin(), mr.getShape().end()); - if (auto tt = mlir::dyn_cast(ty)) - return SmallVector(tt.getShape().begin(), tt.getShape().end()); - if (auto tv = mlir::dyn_cast(ty)) - return SmallVector(tv.getShape().begin(), tv.getShape().end()); - if (auto tb = mlir::dyn_cast(ty)) - 
return SmallVector(tb.getShape().begin(), tb.getShape().end()); - if (auto tv = mlir::dyn_cast(ty)) - return SmallVector(tv.getShape().begin(), tv.getShape().end()); - return {}; -} - -static SmallVector getValidShapeVec(Type ty) { - if (auto tb = dyn_cast(ty)) - return SmallVector(tb.getValidShape().begin(), tb.getValidShape().end()); - return getShapeVec(ty); -} - -static int64_t getLogicalTileDim(int64_t rawDim, Type elemTy, - std::optional blayout, - unsigned dimIdx) { - if (rawDim == ShapedType::kDynamic || !isPTOFloat4PackedType(elemTy)) - return rawDim; - pto::BLayout layout = blayout.value_or(pto::BLayout::RowMajor); - unsigned packedDim = layout == pto::BLayout::ColMajor ? 0 : 1; - return dimIdx == packedDim ? rawDim * 2 : rawDim; -} - -static std::optional getTileBufBLayout(Type ty) { - if (auto tb = dyn_cast(ty)) - return static_cast(tb.getBLayoutValueI32()); - return std::nullopt; -} - -static SmallVector getLogicalTileExtentVec(Type ty, - bool useValidShape) { - SmallVector dims = - useValidShape ? 
getValidShapeVec(ty) : getShapeVec(ty); - if (!isTileLikeType(ty) || dims.size() != 2) - return dims; - - Type elemTy = getElemTy(ty); - auto blayout = getTileBufBLayout(ty); - for (unsigned i = 0; i < dims.size(); ++i) - dims[i] = getLogicalTileDim(dims[i], elemTy, blayout, i); - return dims; -} - -static int64_t getConstantIndexOrDynamic(Value value) { - if (!value) - return ShapedType::kDynamic; - if (auto cst = value.getDefiningOp()) - return cst.value(); - if (auto cst = value.getDefiningOp()) - return cst.value(); - return ShapedType::kDynamic; -} - -static SmallVector getValidShapeVec(Value value) { - if (!value) - return {}; - auto valid = getValidShapeVec(value.getType()); - if (auto bind = value.getDefiningOp()) { - if (valid.size() >= 1 && bind.getValidRow()) - valid[0] = getConstantIndexOrDynamic(bind.getValidRow()); - if (valid.size() >= 2 && bind.getValidCol()) - valid[1] = getConstantIndexOrDynamic(bind.getValidCol()); - } - return valid; -} - -static SmallVector getMatmulLogicalShapeVec(Type ty) { - auto shape = getShapeVec(ty); - auto valid = getValidShapeVec(ty); - if (!isa(ty) || shape.size() != valid.size()) - return shape; - - for (size_t i = 0, e = shape.size(); i < e; ++i) { - if (valid[i] != ShapedType::kDynamic) - shape[i] = valid[i]; - } - return shape; -} - -static bool isByteIntegerType(Type ty) { - auto intTy = dyn_cast(ty); - return intTy && intTy.getWidth() == 8; -} - -static LogicalResult verifyAsyncFlatContiguous1DGMMemRef(Operation *op, - Value value, - StringRef name) { - auto memTy = dyn_cast(value.getType()); - if (!memTy) - return op->emitOpError() << "expects " << name << " to be a memref"; - if (!memTy.hasRank()) - return op->emitOpError() << "expects " << name << " to be a ranked memref"; - if (!isGmAddressSpaceAttr(memTy.getMemorySpace())) - return op->emitOpError() << "expects " << name - << " to be in GM address space"; - - ArrayRef shape = memTy.getShape(); - if (shape.empty()) - return op->emitOpError() << "expects " << 
name - << " to have rank >= 1"; - for (int64_t dim : shape) { - if (dim == ShapedType::kDynamic) - return op->emitOpError() << "expects " << name - << " to have a static shape"; - } - - SmallVector strides; - int64_t offset = 0; - if (failed(getStridesAndOffset(memTy, strides, offset))) - return op->emitOpError() << "expects " << name - << " to be a strided memref with a known layout"; - - bool hasDynamicLayout = - offset == ShapedType::kDynamic || - llvm::any_of(strides, [](int64_t stride) { - return stride == ShapedType::kDynamic; - }); - if (hasDynamicLayout) - return success(); - - bool packed = !strides.empty() && strides.back() == 1; - for (int i = static_cast(shape.size()) - 2; i >= 0 && packed; --i) - packed &= strides[i] == strides[i + 1] * shape[i + 1]; - if (!packed) - return op->emitOpError() - << "expects " << name - << " to be a static flat contiguous logical 1D GM memref"; - - bool logical1D = true; - for (int i = 0, e = static_cast(shape.size()) - 1; i < e; ++i) - logical1D &= shape[i] == 1; - if (!logical1D) - return op->emitOpError() - << "expects " << name - << " to be a static flat contiguous logical 1D GM memref"; - - return success(); -} - -static LogicalResult verifyAsyncFlatContiguous1DGMViewLike(Operation *op, - Value value, - StringRef name) { - Type ty = value.getType(); - if (isa(ty)) - return verifyAsyncFlatContiguous1DGMMemRef(op, value, name); - - if (!isa(ty)) - return op->emitOpError() << "expects " << name - << " to be a memref/tensor_view/partition_view"; - - SmallVector shape = getShapeVec(ty); - if (shape.empty()) - return op->emitOpError() << "expects " << name << " to have rank >= 1"; - for (int64_t dim : shape) { - if (dim == ShapedType::kDynamic) - return op->emitOpError() << "expects " << name - << " to have a static shape"; - } - - bool logical1D = true; - for (int i = 0, e = static_cast(shape.size()) - 1; i < e; ++i) - logical1D &= shape[i] == 1; - if (!logical1D) - return op->emitOpError() - << "expects " << name - << " 
to be a static flat contiguous logical 1D GM view"; - - return success(); -} - -static bool isCommGlobalLikeType(Type ty) { - if (auto memTy = dyn_cast(ty)) - return isGmAddressSpaceAttr(memTy.getMemorySpace()); - return isa(ty); -} - -static LogicalResult verifyCommGlobalLike(Operation *op, Value value, - StringRef name) { - Type ty = value.getType(); - if (!isCommGlobalLikeType(ty)) - return op->emitOpError() << "expects " << name - << " to be a GM memref/tensor_view/partition_view"; - - SmallVector shape = getShapeVec(ty); - if (shape.empty()) - return op->emitOpError() << "expects " << name << " to have rank >= 1"; - for (int64_t dim : shape) { - if (dim == ShapedType::kDynamic || dim <= 0) - return op->emitOpError() << "expects " << name - << " to have a positive static shape"; - } - return success(); -} - -static LogicalResult verifyCommSignalLike(Operation *op, Value value, - StringRef name) { - if (failed(verifyCommGlobalLike(op, value, name))) - return failure(); - Type elemTy = getElemTy(value.getType()); - if (!elemTy || !elemTy.isSignlessInteger(32)) - return op->emitOpError() << "expects " << name - << " element type to be i32"; - return success(); -} - -static LogicalResult verifyCommStagingTileLike(Operation *op, Value value, - StringRef name) { - Type ty = value.getType(); - if (!isa(ty)) - return op->emitOpError() << "expects " << name - << " to be a tile_buf or memref tile"; - auto as = getPTOMemorySpaceEnum(ty); - if (!as || *as != pto::AddressSpace::VEC) - return op->emitOpError() << "expects " << name - << " to be in vec address space"; - SmallVector shape = getShapeVec(ty); - if (shape.empty()) - return op->emitOpError() << "expects " << name << " to have rank >= 1"; - for (int64_t dim : shape) { - if (dim == ShapedType::kDynamic || dim <= 0) - return op->emitOpError() << "expects " << name - << " to have a positive static shape"; - } - return success(); -} - -static LogicalResult verifyCommGlobalGroup(Operation *op, ValueRange group, - 
StringRef name) { - if (group.empty()) - return op->emitOpError() << "expects at least one " << name << " operand"; - Type groupTy = group.front().getType(); - for (auto it : llvm::enumerate(group)) { - if (failed(verifyCommGlobalLike(op, it.value(), - (name + "[" + Twine(it.index()) + "]").str()))) - return failure(); - if (it.value().getType() != groupTy) - return op->emitOpError() << "expects all " << name - << " operands to have identical types"; - } - return success(); -} - -static LogicalResult verifyCommPingPongSameType(Operation *op, Value ping, - Value pong, StringRef pingName, - StringRef pongName) { - if (!pong) - return success(); - if (failed(verifyCommStagingTileLike(op, ping, pingName)) || - failed(verifyCommStagingTileLike(op, pong, pongName))) - return failure(); - if (ping.getType() != pong.getType()) - return op->emitOpError() << "expects " << pingName << " and " << pongName - << " to have identical types"; - return success(); -} - -static std::optional getStaticByteSize(Type ty) { - SmallVector shape = getShapeVec(ty); - if (shape.empty()) - return std::nullopt; - for (int64_t dim : shape) { - if (dim == ShapedType::kDynamic || dim < 0) - return std::nullopt; - } - - Type elemTy = getElemTy(ty); - uint64_t elemBytes = getElemByteSize(elemTy); - if (elemBytes == 0) - return std::nullopt; - - uint64_t total = elemBytes; - for (int64_t dim : shape) { - total *= static_cast(dim); - } - return total; -} - -static std::optional getPTOMemorySpaceEnum(Type ty) { - if (auto tb = dyn_cast(ty)) { - if (auto as = dyn_cast_or_null(tb.getMemorySpace())) - return as.getAddressSpace(); - return std::nullopt; - } - if (auto mr = dyn_cast(ty)) { - if (auto as = dyn_cast_or_null(mr.getMemorySpace())) - return as.getAddressSpace(); - if (!mr.getMemorySpace()) - return pto::AddressSpace::GM; - } - return std::nullopt; -} - -[[maybe_unused]] static bool isRank2TileBuf(Type ty) { - auto tb = dyn_cast(ty); - return tb && tb.getRank() == 2 && tb.getValidShape().size() 
== 2; -} - -static bool isSupportedVecElemType(Type ty, bool allowBf16, - bool allowInt8) { - if (ty.isF16() || ty.isF32()) - return true; - if (allowBf16 && ty.isBF16()) - return true; - if (auto it = dyn_cast(ty)) { - switch (it.getWidth()) { - case 32: - case 16: - return true; - case 8: - return allowInt8; - default: - return false; - } - } - return false; -} - -static bool isSupportedMGatherMScatterIndexElemType(Type ty) { - auto it = dyn_cast(ty); - if (!it || it.getWidth() != 32) - return false; - return true; -} - -static bool isSupportedMGatherMScatterPayloadElemType(Operation *op, Type ty) { - if (isSupportedVecElemType(ty, /*allowBf16=*/true, /*allowInt8=*/true)) - return true; - if (!isTargetArchA5(op)) - return false; - return ty.isFloat8E4M3() || ty.isFloat8E4M3FN() || ty.isFloat8E4M3FNUZ() || - ty.isFloat8E4M3B11FNUZ() || ty.isFloat8E5M2() || ty.isFloat8E5M2FNUZ(); -} - -static bool isSupportedMScatterAtomicPayloadElemType(Type ty, - pto::ScatterAtomicOp atomic) { - auto intTy = dyn_cast(ty); - switch (atomic) { - case pto::ScatterAtomicOp::None: - return true; - case pto::ScatterAtomicOp::Add: - return ty.isF16() || ty.isF32() || - (intTy && intTy.getWidth() == 32); - case pto::ScatterAtomicOp::Max: - case pto::ScatterAtomicOp::Min: - return ty.isF32() || - (intTy && intTy.getWidth() == 32); - } - llvm_unreachable("Unknown ScatterAtomicOp"); -} - -static LogicalResult verifyMGatherMScatterMemOperand(Operation *op, - Value memValue, - Type dataElemTy, - StringRef dataOperandLabel) { - Type memTy = memValue.getType(); - Type memElem = getElemTy(memTy); - if (!memElem || memElem != dataElemTy) - return op->emitOpError() << "expects mem element type to match " - << dataOperandLabel << " element type"; - - if (isa(memTy)) { - if (auto layout = getLogicalViewLayout(memValue)) { - if (*layout != pto::Layout::ND) - return op->emitOpError( - "expects mem partition view to use ND logical layout when layout " - "can be inferred"); - } - return success(); - } - 
- if (auto mr = dyn_cast(memTy)) { - auto as = getPTOMemorySpaceEnum(mr); - if (!as || (*as != pto::AddressSpace::GM && - *as != pto::AddressSpace::Zero)) - return op->emitOpError( - "expects mem memref to use GM or zero address space"); - if (mr.getRank() == 5) { - auto shape = mr.getShape(); - bool allStatic = true; - for (int64_t d : shape) - if (d == ShapedType::kDynamic) - allStatic = false; - if (allStatic && (shape[0] != 1 || shape[1] != 1 || shape[2] != 1)) - return op->emitOpError( - "expects rank-5 GM memref leading dimensions to be [1,1,1,...] " - "(GlobalTensor table shape)"); - } - return success(); - } - - return op->emitOpError( - "expects mem to be !pto.partition_tensor_view or a GM/ZERO memref"); -} - -static bool hasCompatibleKnownExtent(int64_t lhs, int64_t rhs); -static bool isKnownUnitExtent(int64_t value); - -static LogicalResult verifyMGatherMScatterTileShape(Operation *op, Type dataTy, - Type idxTy, - StringRef dataName) { - auto dataValid = getValidShapeVec(dataTy); - auto idxValid = getValidShapeVec(idxTy); - if (dataValid.size() != 2 || idxValid.size() != 2) - return op->emitOpError() << "expects " << dataName - << " and idx to have rank-2 valid_shape"; - - auto idxTile = dyn_cast(idxTy); - if (!idxTile) - return op->emitOpError("expects idx to be a tile_buf type"); - - const bool idxRowMajor = - idxTile.getBLayoutValueI32() == - static_cast(pto::BLayout::RowMajor); - const bool idxColMajor = - idxTile.getBLayoutValueI32() == - static_cast(pto::BLayout::ColMajor); - - const bool rowCoalesce1xR = - idxRowMajor && isKnownUnitExtent(idxValid[0]) && - hasCompatibleKnownExtent(idxValid[1], dataValid[0]); - const bool rowCoalesceRx1 = - idxColMajor && hasCompatibleKnownExtent(idxValid[0], dataValid[0]) && - isKnownUnitExtent(idxValid[1]); - const bool elemCoalesce = - hasCompatibleKnownExtent(idxValid[0], dataValid[0]) && - hasCompatibleKnownExtent(idxValid[1], dataValid[1]); - - if (!(rowCoalesce1xR || rowCoalesceRx1 || elemCoalesce)) - return 
op->emitOpError() - << "expects idx valid_shape to be [1, " << dataName - << ".valid_row], [" << dataName - << ".valid_row, 1], or match " << dataName << " valid_shape"; - - return success(); -} - -static LogicalResult verifyMGatherMScatterIdxTile(Operation *op, Type ty, - StringRef name) { - if (failed(verifyTileBufCommon(op, ty, name))) - return failure(); - auto as = getPTOMemorySpaceEnum(ty); - if (!as || *as != pto::AddressSpace::VEC) - return op->emitOpError() << "expects " << name - << " to be in the vec address space"; - auto tb = dyn_cast(ty); - if (!tb) - return op->emitOpError() << "expects " << name << " to be a tile_buf type"; - int32_t blayout = tb.getBLayoutValueI32(); - if (blayout != static_cast(pto::BLayout::RowMajor) && - blayout != static_cast(pto::BLayout::ColMajor)) - return op->emitOpError() << "expects " << name - << " to use row_major or col_major blayout"; - if (tb.getSLayoutValueI32() != static_cast(pto::SLayout::NoneBox)) - return op->emitOpError() << "expects " << name - << " to use the none_box slayout"; - return success(); -} - -static bool isA5TLoadStoreTransferElemType(Type ty) { - return ty.isInteger(8) || ty.isInteger(16) || ty.isInteger(32) || - ty.isInteger(64) || ty.isF16() || ty.isBF16() || ty.isF32() || - isPTOLowPrecisionType(ty); -} - -static bool isA5AccStorePreQuantDstType(Type srcElem, Type dstElem) { - if (srcElem.isInteger(32)) - return dstElem.isInteger(8) || dstElem.isF16() || dstElem.isBF16(); - if (!srcElem.isF32()) - return false; - return dstElem.isInteger(8) || dstElem.isF16() || dstElem.isBF16() || - dstElem.isF32() || isPTOHiFloat8Type(dstElem) || - dstElem.isFloat8E4M3() || dstElem.isFloat8E4M3FN() || - dstElem.isFloat8E4M3FNUZ() || dstElem.isFloat8E4M3B11FNUZ(); -} - -static bool isA5LowPrecisionTCvtPair(Type srcElem, Type dstElem) { - if (srcElem.isF32()) - return isPTOFloat8Type(dstElem) || isPTOHiFloat8Type(dstElem); - if (srcElem.isF16()) - return isPTOHiFloat8Type(dstElem); - if (srcElem.isBF16()) - 
return isPTOFloat4PackedType(dstElem); - if (isPTOFloat4PackedType(srcElem)) - return dstElem.isBF16(); - if (isPTOFloat8Type(srcElem) || isPTOHiFloat8Type(srcElem)) - return dstElem.isF32(); - return false; -} - -static bool isA5SupportedTCvtPair(Type srcElem, Type dstElem) { - if (isPTOLowPrecisionType(srcElem) || isPTOLowPrecisionType(dstElem)) - return isA5LowPrecisionTCvtPair(srcElem, dstElem); - return true; -} - -static LogicalResult verifyTileBufCommon(Operation *op, Type ty, StringRef name, - bool allowLowPrecision) { - auto tb = dyn_cast(ty); - if (tb) { - if (tb.getRank() != 2) - return op->emitOpError() << "expects " << name << " to be a rank-2 tile_buf"; - Type elemTy = tb.getElementType(); - if (!allowLowPrecision && isPTOLowPrecisionType(elemTy)) - return op->emitOpError() << name << ": dtype " << elemTy - << " is not supported by this op yet"; - } else if (auto mr = dyn_cast(ty)) { - if (mr.getRank() != 2) - return op->emitOpError() << "expects " << name << " to be a rank-2 memref"; - if (!allowLowPrecision && isPTOLowPrecisionType(mr.getElementType())) - return op->emitOpError() << name << ": dtype " << mr.getElementType() - << " is not supported by this op yet"; - } else { - return op->emitOpError() << "expects " << name << " to be a !pto.tile_buf or rank-2 memref"; - } - - auto validShape = getValidShapeVec(ty); - if (validShape.size() != 2) - return op->emitOpError() << "expects " << name << " to have a rank-2 valid_shape"; - auto shape = getShapeVec(ty); - for (unsigned i = 0; i < 2; ++i) { - if (shape[i] != ShapedType::kDynamic && validShape[i] != ShapedType::kDynamic && - validShape[i] > shape[i]) - return op->emitOpError() << "expects " << name << " to satisfy valid_shape[" << i - << "] <= shape[" << i << "]"; - } - return success(); -} - -static LogicalResult verifyTileBufSameElemType(Operation *op, Type lhs, Type rhs, - StringRef lhsName, - StringRef rhsName) { - if (!isTileLikeType(lhs) || !isTileLikeType(rhs)) - return op->emitOpError() 
<< "expects " << lhsName << " and " << rhsName - << " to be !pto.tile_buf or memref"; - if (getElemTy(lhs) != getElemTy(rhs)) - return op->emitOpError() << "expects " << lhsName << " and " << rhsName - << " to have the same element type"; - return success(); -} - -static LogicalResult verifyTileBufSameValidShape(Operation *op, Type lhs, Type rhs, - StringRef lhsName, StringRef rhsName) { - if (!isTileLikeType(lhs) || !isTileLikeType(rhs)) - return success(); - auto lhsValid = getValidShapeVec(lhs); - auto rhsValid = getValidShapeVec(rhs); - for (size_t i = 0; i < lhsValid.size() && i < rhsValid.size(); ++i) { - if (lhsValid[i] != ShapedType::kDynamic && rhsValid[i] != ShapedType::kDynamic && - lhsValid[i] != rhsValid[i]) - return op->emitOpError() << "expects " << lhsName << " and " << rhsName - << " to have the same valid_shape"; - } - if (lhsValid.size() != rhsValid.size()) - return op->emitOpError() << "expects " << lhsName << " and " << rhsName - << " to have the same valid_shape"; - return success(); -} - -static LogicalResult verifyTileBufSameLogicalExtent(Operation *op, Type lhs, - Type rhs, StringRef lhsName, - StringRef rhsName, - bool compareValidShape) { - if (!isTileLikeType(lhs) || !isTileLikeType(rhs)) - return success(); - - auto lhsExtent = getLogicalTileExtentVec(lhs, compareValidShape); - auto rhsExtent = getLogicalTileExtentVec(rhs, compareValidShape); - auto emitMismatch = [&]() -> LogicalResult { - if (compareValidShape) - return op->emitOpError() << "expects " << lhsName << " and " << rhsName - << " to have the same valid_shape"; - return op->emitOpError() << "expects " << lhsName << " and " << rhsName - << " to have compatible shapes"; - }; - if (lhsExtent.size() != rhsExtent.size()) - return emitMismatch(); - - for (size_t i = 0, e = lhsExtent.size(); i < e; ++i) { - if (lhsExtent[i] != ShapedType::kDynamic && - rhsExtent[i] != ShapedType::kDynamic && lhsExtent[i] != rhsExtent[i]) - return emitMismatch(); - } - return success(); -} - -static 
LogicalResult verifyScaleTileMatchesOperand(Operation *op, Type scaleTy, - Type operandTy, - StringRef scaleName, - StringRef operandName) { - if (failed(verifyTileBufCommon(op, scaleTy, scaleName))) - return failure(); - auto scaleSpace = getPTOMemorySpaceEnum(scaleTy); - if (!scaleSpace || *scaleSpace != pto::AddressSpace::SCALING) - return op->emitOpError() << "expects " << scaleName - << " to be in the scaling address space"; - - auto scaleShape = getShapeVec(scaleTy); - auto operandShape = getShapeVec(operandTy); - if (scaleShape.size() != operandShape.size()) - return op->emitOpError() << "expects " << scaleName << " and " << operandName - << " to have the same rank"; - for (size_t i = 0; i < scaleShape.size(); ++i) { - if (scaleShape[i] != ShapedType::kDynamic && - operandShape[i] != ShapedType::kDynamic && - scaleShape[i] != operandShape[i]) - return op->emitOpError() << "expects " << scaleName << " and " << operandName - << " to have the same shape"; - } - - auto scaleValid = getValidShapeVec(scaleTy); - auto operandValid = getValidShapeVec(operandTy); - if (scaleValid.size() != operandValid.size()) - return op->emitOpError() << "expects " << scaleName << " and " << operandName - << " to have the same valid_shape"; - for (size_t i = 0; i < scaleValid.size(); ++i) { - if (scaleValid[i] != ShapedType::kDynamic && - operandValid[i] != ShapedType::kDynamic && - scaleValid[i] != operandValid[i]) - return op->emitOpError() << "expects " << scaleName << " and " << operandName - << " to have the same valid_shape"; - } - return success(); -} - -static LogicalResult verifyPartialValidPattern(Operation *op, Type src0Ty, - Type src1Ty, Type dstTy) { - auto src0Valid = getValidShapeVec(src0Ty); - auto src1Valid = getValidShapeVec(src1Ty); - auto dstValid = getValidShapeVec(dstTy); - if (src0Valid.size() != 2 || src1Valid.size() != 2 || dstValid.size() != 2) - return op->emitOpError("expects src0, src1, and dst to have rank-2 valid_shape"); - - auto lessEqualKnown = 
[](int64_t lhs, int64_t rhs) { - return lhs == ShapedType::kDynamic || rhs == ShapedType::kDynamic || lhs <= rhs; - }; - auto equalsKnown = [](ArrayRef lhs, ArrayRef rhs) { - for (auto [a, b] : llvm::zip(lhs, rhs)) { - if (a != ShapedType::kDynamic && b != ShapedType::kDynamic && a != b) - return false; - } - return true; - }; - - for (unsigned i = 0; i < 2; ++i) { - if (!lessEqualKnown(src0Valid[i], dstValid[i]) || - !lessEqualKnown(src1Valid[i], dstValid[i])) - return op->emitOpError( - "expects src0/src1 valid_shape to be less than or equal to dst valid_shape"); - } - if (!equalsKnown(src0Valid, dstValid) && !equalsKnown(src1Valid, dstValid)) - return op->emitOpError( - "expects at least one of src0/src1 valid_shape to match dst valid_shape"); - return success(); -} - -[[maybe_unused]] static bool hasKnownZeroValidRegion(Type ty) { - auto valid = getValidShapeVec(ty); - if (valid.size() != 2) - return false; - return valid[0] == 0 || valid[1] == 0; -} - -static LogicalResult verifyScalarTileOp(Operation *op, Type srcTy, Type dstTy, - StringRef srcName, StringRef dstName, - bool requireValidRowsEqual, - bool requireValidColsEqual) { - if (failed(verifyTileBufCommon(op, srcTy, srcName)) || - failed(verifyTileBufCommon(op, dstTy, dstName))) - return failure(); - auto srcSpace = getPTOMemorySpaceEnum(srcTy); - auto dstSpace = getPTOMemorySpaceEnum(dstTy); - if (!srcSpace || *srcSpace != pto::AddressSpace::VEC) - return op->emitOpError() << "expects " << srcName - << " to be in the vec address space"; - if (!dstSpace || *dstSpace != pto::AddressSpace::VEC) - return op->emitOpError() << "expects " << dstName - << " to be in the vec address space"; - if (failed(verifyTileBufSameElemType(op, srcTy, dstTy, srcName, dstName))) - return failure(); - - auto srcValid = getValidShapeVec(srcTy); - auto dstValid = getValidShapeVec(dstTy); - if (srcValid.size() != 2 || dstValid.size() != 2) - return op->emitOpError() - << "expects " << srcName << " and " << dstName - << " to 
have rank-2 valid_shape"; - if (requireValidRowsEqual && - srcValid[0] != ShapedType::kDynamic && dstValid[0] != ShapedType::kDynamic && - srcValid[0] != dstValid[0]) - return op->emitOpError() - << "expects " << srcName << " and " << dstName - << " to have the same valid_shape[0]"; - if (requireValidColsEqual && - srcValid[1] != ShapedType::kDynamic && dstValid[1] != ShapedType::kDynamic && - srcValid[1] != dstValid[1]) - return op->emitOpError() - << "expects " << srcName << " and " << dstName - << " to have the same valid_shape[1]"; - return success(); -} - -static FailureOr -verifyMatchingRowMajorBinaryTileOpCommon(Operation *op, Type src0Ty, Type src1Ty, - Type dstTy) { - if (failed(verifyTileBufCommon(op, src0Ty, "src0")) || - failed(verifyTileBufCommon(op, src1Ty, "src1")) || - failed(verifyTileBufCommon(op, dstTy, "dst"))) - return failure(); - if (failed(verifyTileBufSameElemType(op, src0Ty, src1Ty, "src0", "src1")) || - failed(verifyTileBufSameElemType(op, src0Ty, dstTy, "src0", "dst")) || - failed(verifyTileBufSameValidShape(op, src0Ty, src1Ty, "src0", "src1")) || - failed(verifyTileBufSameValidShape(op, src0Ty, dstTy, "src0", "dst"))) - return failure(); - if (!isRowMajorTileBuf(src0Ty) || !isRowMajorTileBuf(src1Ty) || - !isRowMajorTileBuf(dstTy)) { - op->emitOpError("expects src0, src1, and dst to use row-major layout"); - return failure(); - } - return getElemTy(src0Ty); -} - -static FailureOr -verifyNumericScalarTileOpCommon(Operation *op, Type srcTy, Type dstTy, - Type scalarTy, bool requireValidRowsEqual) { - if (failed(verifyScalarTileOp(op, srcTy, dstTy, "src", "dst", - requireValidRowsEqual, - /*requireValidColsEqual=*/true))) - return failure(); - if (!mlir::isa(scalarTy)) { - op->emitOpError("scalar must be a scalar type (integer/float)"); - return failure(); - } - return getElemTy(srcTy); -} - -static FailureOr -verifyShiftLikeBinaryTileOpCommon(Operation *op, Type src0Ty, Type src1Ty, - Type dstTy) { - if (failed(verifyTileBufCommon(op, 
src0Ty, "src0")) || - failed(verifyTileBufCommon(op, src1Ty, "src1")) || - failed(verifyTileBufCommon(op, dstTy, "dst"))) - return failure(); - Type e0 = getElemTy(src0Ty); - Type e1 = getElemTy(src1Ty); - if (!e0 || !e1) { - op->emitOpError("failed to get element type for operands"); - return failure(); - } - if (e0 != e1) { - op->emitOpError("expects src0 and src1 to have the same element type"); - return failure(); - } - if (!isRowMajorTileBuf(src0Ty) || !isRowMajorTileBuf(src1Ty) || - !isRowMajorTileBuf(dstTy)) { - op->emitOpError("expects src0, src1, and dst to use row-major layout"); - return failure(); - } - if (failed(verifyTileBufSameValidShape(op, src0Ty, dstTy, "src0", "dst")) || - failed(verifyTileBufSameValidShape(op, src1Ty, dstTy, "src1", "dst"))) - return failure(); - return e0; -} - -static FailureOr verifyDistinctRowMajorUnaryTileOpCommon( - Operation *op, Value src, Value dst, StringRef srcName = "src", - StringRef dstName = "dst") { - if (src == dst) { - op->emitOpError("expects src and dst to use different storage"); - return failure(); - } - Type srcTy = src.getType(); - Type dstTy = dst.getType(); - if (failed(verifyTileBufCommon(op, srcTy, srcName)) || - failed(verifyTileBufCommon(op, dstTy, dstName))) - return failure(); - - Type srcElem = getElemTy(srcTy); - Type dstElem = getElemTy(dstTy); - if (!srcElem || !dstElem) { - op->emitOpError("failed to get element type for src/dst"); - return failure(); - } - if (srcElem != dstElem) { - op->emitOpError("expects src and dst to have the same element type"); - return failure(); - } - if (!isRowMajorTileBuf(srcTy) || !isRowMajorTileBuf(dstTy)) { - op->emitOpError("expects src and dst to use row-major layout"); - return failure(); - } - if (failed(verifyTileBufSameValidShape(op, srcTy, dstTy, srcName, dstName))) - return failure(); - return srcElem; -} - -static LogicalResult verifyArithmeticElemTypeForArch( - Operation *op, Type elemTy, PTOArch targetArch, bool allowInt8OnA5, - bool allowBf16OnA5, 
StringRef a2a3Error, StringRef a5Error) { - bool supported = elemTy.isInteger(32) || elemTy.isInteger(16) || - elemTy.isF16() || elemTy.isF32(); - if (targetArch == PTOArch::A5) - supported = supported || (allowInt8OnA5 && elemTy.isInteger(8)) || - (allowBf16OnA5 && elemTy.isBF16()); - if (supported) - return success(); - return op->emitOpError(targetArch == PTOArch::A5 ? a5Error : a2a3Error); -} - -static LogicalResult verifyArithmeticBinaryTileOpWithArchDispatch( - Operation *op, Type src0Ty, Type src1Ty, Type dstTy, bool allowInt8OnA5, - bool allowBf16OnA5, StringRef a2a3Error, StringRef a5Error) { - auto verifyByArch = [&](PTOArch targetArch) -> LogicalResult { - FailureOr elemOr = - verifyMatchingRowMajorBinaryTileOpCommon(op, src0Ty, src1Ty, dstTy); - if (failed(elemOr)) - return failure(); - return verifyArithmeticElemTypeForArch(op, *elemOr, targetArch, - allowInt8OnA5, allowBf16OnA5, - a2a3Error, a5Error); - }; - auto verifyA2A3 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A3); }; - auto verifyA5 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A5); }; - return dispatchVerifierByArch(op, verifyA2A3, verifyA5); -} - -static LogicalResult verifyArithmeticScalarTileOpWithArchDispatch( - Operation *op, Type srcTy, Type dstTy, Type scalarTy, bool allowInt8OnA5, - bool allowBf16OnA5, StringRef a2a3Error, StringRef a5Error, - bool requireValidRowsEqualOnA2A3 = true, - bool requireValidRowsEqualOnA5 = false) { - auto verifyByArch = [&](PTOArch targetArch, - bool requireValidRowsEqual) -> LogicalResult { - FailureOr elemOr = verifyNumericScalarTileOpCommon( - op, srcTy, dstTy, scalarTy, requireValidRowsEqual); - if (failed(elemOr)) - return failure(); - return verifyArithmeticElemTypeForArch(op, *elemOr, targetArch, - allowInt8OnA5, allowBf16OnA5, - a2a3Error, a5Error); - }; - auto verifyA2A3 = [&]() -> LogicalResult { - return verifyByArch(PTOArch::A3, requireValidRowsEqualOnA2A3); - }; - auto verifyA5 = [&]() -> LogicalResult { - return 
verifyByArch(PTOArch::A5, requireValidRowsEqualOnA5); - }; - return dispatchVerifierByArch(op, verifyA2A3, verifyA5); -} - -static LogicalResult verifyTColReductionElemTypeForArch( - Operation *op, Type elemTy, PTOArch targetArch, bool allowInt8OnA5, - bool allowBf16OnA5, StringRef a2a3Error, StringRef a5Error) { - bool ok = elemTy.isF16() || elemTy.isF32() || elemTy.isInteger(16) || - elemTy.isInteger(32); - if (targetArch == PTOArch::A5) - ok = ok || (allowInt8OnA5 && elemTy.isInteger(8)) || - (allowBf16OnA5 && elemTy.isBF16()); - if (ok) - return success(); - return op->emitOpError(targetArch == PTOArch::A5 ? a5Error : a2a3Error); -} - -static LogicalResult verifyTColReductionOpWithArchDispatch( - Operation *op, Type srcTy, Type dstTy, bool requireNonZeroSrcOnA2A3, - bool requireNonZeroSrcOnA5, bool allowInt8OnA5, bool allowBf16OnA5, - StringRef a2a3Error, StringRef a5Error) { - auto verifyByArch = [&](PTOArch targetArch, - bool requireNonZeroSrc) -> LogicalResult { - if (failed(verifyNDStyleVecTile(op, srcTy, "src")) || - failed(verifyNDStyleVecTile(op, dstTy, "dst"))) - return failure(); - if (getElemTy(srcTy) != getElemTy(dstTy)) - return op->emitOpError("expects src and dst to have the same element type"); - if (failed(verifyColReductionValidRegion(op, srcTy, dstTy, requireNonZeroSrc))) - return failure(); - Type elem = getElemTy(srcTy); - return verifyTColReductionElemTypeForArch(op, elem, targetArch, allowInt8OnA5, - allowBf16OnA5, a2a3Error, a5Error); - }; - auto verifyA2A3 = [&]() -> LogicalResult { - return verifyByArch(PTOArch::A3, requireNonZeroSrcOnA2A3); - }; - auto verifyA5 = [&]() -> LogicalResult { - return verifyByArch(PTOArch::A5, requireNonZeroSrcOnA5); - }; - return dispatchVerifierByArch(op, verifyA2A3, verifyA5); -} - -static LogicalResult verifyTColArgReductionOpCommon(Operation *op, Type srcTy, - Type tmpTy, Type dstTy) { - if (failed(verifyNDStyleVecTile(op, srcTy, "src")) || - failed(verifyVecTileCommon(op, tmpTy, "tmp")) || - 
failed(verifyColArgReductionDstLayout(op, dstTy, "dst"))) - return failure(); - if (failed(verifyTileBufSameElemType(op, srcTy, tmpTy, "src", "tmp")) || - failed(verifyTileBufSameValidShape(op, srcTy, tmpTy, "src", "tmp"))) - return failure(); - if (failed(verifyColReductionValidRegion(op, srcTy, dstTy, - /*requireNonZeroSrc=*/true))) - return failure(); - Type srcElemTy = getElemTy(srcTy); - unsigned srcElemBits = srcElemTy ? getPTOStorageElemBitWidth(srcElemTy) : 0; - if (!(mlir::isa(srcElemTy) && - (srcElemBits == 8 || srcElemBits == 16 || srcElemBits == 32))) - return op->emitOpError( - "expects src/tmp element type to be 1, 2, or 4 bytes wide"); - auto dstInt = dyn_cast(getElemTy(dstTy)); - if (!dstInt || dstInt.getWidth() != 32) - return op->emitOpError("expects dst element type to be i32 or ui32"); - return success(); -} - -static bool hasCompatibleKnownExtent(int64_t lhs, int64_t rhs) { - return lhs == ShapedType::kDynamic || rhs == ShapedType::kDynamic || lhs == rhs; -} - -static bool isKnownUnitExtent(int64_t value) { - return value == ShapedType::kDynamic || value == 1; -} - -static LogicalResult verifyVecTileStorage(Operation *op, Type ty, StringRef name) { - if (failed(verifyTileBufCommon(op, ty, name))) - return failure(); - auto as = getPTOMemorySpaceEnum(ty); - if (!as || *as != pto::AddressSpace::VEC) - return op->emitOpError() << "expects " << name << " to be in the vec address space"; - return success(); -} - -static LogicalResult verifyVecTileCommonA2A3(Operation *op, Type ty, - StringRef name) { - if (failed(verifyTileBufCommon(op, ty, name))) - return failure(); - auto tb = dyn_cast(ty); - auto as = getPTOMemorySpaceEnum(ty); - if (as && *as != pto::AddressSpace::VEC) - return op->emitOpError() << "expects " << name << " to be in the vec address space"; - if (tb && tb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor)) - return op->emitOpError() << "expects " << name << " to use the row_major blayout"; - return success(); -} - -static 
LogicalResult verifyVecTileCommonA5(Operation *op, Type ty, - StringRef name) { - return verifyVecTileCommonA2A3(op, ty, name); -} - -static LogicalResult verifyVecTileCommon(Operation *op, Type ty, StringRef name) { - switch (getVerifierTargetArch(op)) { - case VerifierTargetArch::A2A3: - return verifyVecTileCommonA2A3(op, ty, name); - case VerifierTargetArch::A5: - return verifyVecTileCommonA5(op, ty, name); - } - return failure(); -} - -static LogicalResult verifyVecTileUnaryOp(Operation *op, Type srcTy, Type dstTy, - StringRef srcName, - StringRef dstName, - bool allowBf16, - bool allowInt8) { - if (failed(verifyVecTileCommon(op, srcTy, srcName)) || - failed(verifyVecTileCommon(op, dstTy, dstName))) - return failure(); - if (failed(verifyTileBufSameElemType(op, srcTy, dstTy, srcName, dstName))) - return failure(); - if (!isSupportedVecElemType(getElemTy(srcTy), allowBf16, allowInt8)) - return op->emitOpError() << "expects vec tile element types to be supported"; - return success(); -} - -static LogicalResult verifyAccTileCommonA2A3(Operation *op, Type ty, - StringRef name) { - if (failed(verifyTileBufCommon(op, ty, name))) - return failure(); - auto as = getPTOMemorySpaceEnum(ty); - if (!as || *as != pto::AddressSpace::ACC) - return op->emitOpError() << "expects " << name << " to be in the acc address space"; - return success(); -} - -static LogicalResult verifyAccTileCommonA5(Operation *op, Type ty, - StringRef name) { - return verifyAccTileCommonA2A3(op, ty, name); -} - -static LogicalResult verifyAccTileCommon(Operation *op, Type ty, StringRef name) { - switch (getVerifierTargetArch(op)) { - case VerifierTargetArch::A2A3: - return verifyAccTileCommonA2A3(op, ty, name); - case VerifierTargetArch::A5: - return verifyAccTileCommonA5(op, ty, name); - } - return failure(); -} - -static LogicalResult verifyMatTileOperandsA2A3(Operation *op, Type lhsTy, - Type rhsTy, Type dstTy) { - if (failed(verifyTileBufCommon(op, lhsTy, "lhs")) || - 
failed(verifyTileBufCommon(op, rhsTy, "rhs")) || - failed(verifyAccTileCommon(op, dstTy, "dst"))) - return failure(); - auto lhsSpace = getPTOMemorySpaceEnum(lhsTy); - auto rhsSpace = getPTOMemorySpaceEnum(rhsTy); - auto dstSpace = getPTOMemorySpaceEnum(dstTy); - if (!lhsSpace || !rhsSpace || !dstSpace) - return op->emitOpError("expects lhs, rhs, and dst to have explicit address spaces"); - if (*lhsSpace != pto::AddressSpace::LEFT || *rhsSpace != pto::AddressSpace::RIGHT || - *dstSpace != pto::AddressSpace::ACC) - return op->emitOpError( - "expects lhs, rhs, and dst to use the left, right, and acc address spaces"); - auto lhsShape = getMatmulLogicalShapeVec(lhsTy); - auto rhsShape = getMatmulLogicalShapeVec(rhsTy); - auto dstShape = getMatmulLogicalShapeVec(dstTy); - if ((lhsShape[0] != dstShape[0] || rhsShape[1] != dstShape[1] || lhsShape[1] != rhsShape[0])) - return op->emitOpError( - "expects static matmul tile shapes lhs[M,K], rhs[K,N], and dst[M,N]"); - auto lhsValid = getValidShapeVec(lhsTy); - auto rhsValid = getValidShapeVec(rhsTy); - if (lhsValid.size() == 2 && rhsValid.size() == 2) { - int64_t m = lhsValid[0]; - int64_t k = lhsValid[1]; - int64_t n = rhsValid[1]; - if ((m != ShapedType::kDynamic && (m < 1 || m > 4095)) || - (k != ShapedType::kDynamic && (k < 1 || k > 4095)) || - (n != ShapedType::kDynamic && (n < 1 || n > 4095))) - return op->emitOpError("expects m, k, and n valid sizes to be in [1, 4095]"); - } - return success(); -} - -static LogicalResult verifyMatTileOperandsA5(Operation *op, Type lhsTy, - Type rhsTy, Type dstTy) { - if (failed(verifyMatTileOperandsA2A3(op, lhsTy, rhsTy, dstTy))) - return failure(); - - auto lhsTb = mlir::dyn_cast(lhsTy); - auto rhsTb = mlir::dyn_cast(rhsTy); - auto dstTb = mlir::dyn_cast(dstTy); - if (!lhsTb || !rhsTb || !dstTb) - return success(); - - if (lhsTb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor)) - return op->emitOpError("expects lhs to use the col_major blayout on A5"); - if 
(rhsTb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor)) - return op->emitOpError("expects rhs to use the row_major blayout on A5"); - if (dstTb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor)) - return op->emitOpError("expects dst to use the col_major blayout on A5"); - - if (lhsTb.getSLayoutValueI32() != static_cast(pto::SLayout::RowMajor)) - return op->emitOpError("expects lhs to use the row_major slayout on A5"); - if (rhsTb.getSLayoutValueI32() != static_cast(pto::SLayout::ColMajor)) - return op->emitOpError("expects rhs to use the col_major slayout on A5"); - if (dstTb.getSLayoutValueI32() != static_cast(pto::SLayout::RowMajor)) - return op->emitOpError("expects dst to use the row_major slayout on A5"); - return success(); -} - -static LogicalResult verifyMatTileOperands(Operation *op, Type lhsTy, Type rhsTy, - Type dstTy) { - switch (getVerifierTargetArch(op)) { - case VerifierTargetArch::A2A3: - return verifyMatTileOperandsA2A3(op, lhsTy, rhsTy, dstTy); - case VerifierTargetArch::A5: - return verifyMatTileOperandsA5(op, lhsTy, rhsTy, dstTy); - } - return failure(); -} - -static LogicalResult verifyGemvTileOperandsA2A3(Operation *op, Type lhsTy, - Type rhsTy, Type dstTy) { - if (failed(verifyTileBufCommon(op, lhsTy, "lhs")) || - failed(verifyTileBufCommon(op, rhsTy, "rhs")) || - failed(verifyAccTileCommon(op, dstTy, "dst"))) - return failure(); - - auto lhsSpace = getPTOMemorySpaceEnum(lhsTy); - auto rhsSpace = getPTOMemorySpaceEnum(rhsTy); - if (!lhsSpace || !rhsSpace) - return op->emitOpError("expects lhs and rhs to have explicit address spaces"); - if (*lhsSpace != pto::AddressSpace::LEFT || *rhsSpace != pto::AddressSpace::RIGHT) - return op->emitOpError( - "expects lhs and rhs to use the left and right address spaces"); - - auto lhsValid = getValidShapeVec(lhsTy); - auto rhsValid = getValidShapeVec(rhsTy); - auto dstValid = getValidShapeVec(dstTy); - if (lhsValid[0] != ShapedType::kDynamic && lhsValid[0] != 1) - return 
op->emitOpError("expects lhs valid_shape[0] to be 1 for tgemv"); - if (isa(dstTy) && dstValid[0] != ShapedType::kDynamic && - dstValid[0] != 1) - return op->emitOpError("expects dst valid_shape[0] to be 1 for tgemv"); - if (lhsValid[1] != ShapedType::kDynamic && rhsValid[0] != ShapedType::kDynamic && - lhsValid[1] != rhsValid[0]) - return op->emitOpError() - << "expects lhs valid_shape[1] to equal rhs valid_shape[0], but got " - << lhsValid[1] << " vs " << rhsValid[0]; - if (rhsValid[1] != ShapedType::kDynamic && dstValid[1] != ShapedType::kDynamic && - rhsValid[1] != dstValid[1]) - return op->emitOpError() - << "expects rhs valid_shape[1] to equal dst valid_shape[1], but got " - << rhsValid[1] << " vs " << dstValid[1]; - return success(); -} - -static LogicalResult verifyGemvTileOperandsA5(Operation *op, Type lhsTy, - Type rhsTy, Type dstTy) { - if (failed(verifyGemvTileOperandsA2A3(op, lhsTy, rhsTy, dstTy))) - return failure(); - return verifyMatTileOperandsA5(op, lhsTy, rhsTy, dstTy); -} - -static LogicalResult verifyGemvTileOperands(Operation *op, Type lhsTy, Type rhsTy, - Type dstTy) { - switch (getVerifierTargetArch(op)) { - case VerifierTargetArch::A2A3: - return verifyGemvTileOperandsA2A3(op, lhsTy, rhsTy, dstTy); - case VerifierTargetArch::A5: - return verifyGemvTileOperandsA5(op, lhsTy, rhsTy, dstTy); - } - return failure(); -} - -static LogicalResult verifyMatBiasTileA2A3(Operation *op, Type biasTy, Type dstTy, - bool requireFloatBias) { - if (failed(verifyTileBufCommon(op, biasTy, "bias"))) - return failure(); - auto biasSpace = getPTOMemorySpaceEnum(biasTy); - if (!biasSpace || *biasSpace != pto::AddressSpace::BIAS) - return op->emitOpError("expects bias to be in the bias address space"); - auto biasShape = getShapeVec(biasTy); - if (biasShape[0] != ShapedType::kDynamic && biasShape[0] != 1) - return op->emitOpError("expects bias to have 1 row"); - if (requireFloatBias) { - if (!getElemTy(biasTy).isF32()) - return op->emitOpError("expects bias to have 
element type f32"); - } else if (getElemTy(biasTy) != getElemTy(dstTy)) { - return op->emitOpError("expects bias and dst to have the same element type"); - } - return success(); -} - -static LogicalResult verifyMatBiasTileA5(Operation *op, Type biasTy, Type dstTy, - bool requireFloatBias) { - if (failed(verifyMatBiasTileA2A3(op, biasTy, dstTy, requireFloatBias))) - return failure(); - if (auto biasTb = dyn_cast(biasTy)) { - if (biasTb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor)) - return op->emitOpError("expects bias to use the row_major blayout on A5"); - } - return success(); -} - -static LogicalResult verifyMatBiasTile(Operation *op, Type biasTy, Type dstTy, - bool requireFloatBias) { - switch (getVerifierTargetArch(op)) { - case VerifierTargetArch::A2A3: - return verifyMatBiasTileA2A3(op, biasTy, dstTy, requireFloatBias); - case VerifierTargetArch::A5: - return verifyMatBiasTileA5(op, biasTy, dstTy, requireFloatBias); - } - return failure(); -} - -static LogicalResult verifyMatmulTypeTriple(Operation *op, Type lhsElemTy, - Type rhsElemTy, Type dstElemTy) { - bool isA5 = getVerifierTargetArch(op) == VerifierTargetArch::A5; - auto isInt8 = [](Type ty) { - return ty.isInteger(8); - }; - if (dstElemTy.isInteger(32) && isInt8(lhsElemTy) && isInt8(rhsElemTy)) - return success(); - - auto isSupportedFpInput = [](Type ty) { - return ty.isF16() || ty.isBF16() || ty.isF32(); - }; - if (dstElemTy.isF32() && lhsElemTy == rhsElemTy && isSupportedFpInput(lhsElemTy)) - return success(); - - if (isA5 && dstElemTy.isF32() && lhsElemTy == rhsElemTy) { - if (auto ft = mlir::dyn_cast(lhsElemTy)) { - unsigned width = ft.getWidth(); - if (width == 8 || width == 16 || width == 32) - return success(); - } - } - - return op->emitOpError() - << "expects (dst, lhs, rhs) element types to match one of " - "(i32, i8, i8), (f32, f16, f16), (f32, bf16, bf16), (f32, f32, f32)" - << (isA5 ? 
", or an A5-supported fp8 pair" : ""); -} - -LogicalResult pto::TAddOp::verify() { - return verifyArithmeticBinaryTileOpWithArchDispatch( - getOperation(), getSrc0().getType(), getSrc1().getType(), getDst().getType(), - /*allowInt8OnA5=*/true, /*allowBf16OnA5=*/true, - "expects A2/A3 tadd element type to be i32/i16/f16/f32", - "expects A5 tadd element type to be i32/i16/i8/f16/bf16/f32"); -} - -LogicalResult pto::TAddCOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - Type t0 = getSrc0().getType(); - Type t1 = getSrc1().getType(); - Type t2 = getSrc2().getType(); - Type td = getDst().getType(); - - if (!isPTOShapedLike(t0) || !isPTOShapedLike(t1) || - !isPTOShapedLike(t2) || !isPTOShapedLike(td)) - return emitOpError("expects src0/src1/src2/dst to be memref/tile_buf types"); - - auto s0 = getShapeVec(t0); - auto s1 = getShapeVec(t1); - auto s2 = getShapeVec(t2); - auto sd = getShapeVec(td); - if (s0 != s1 || s0 != s2 || s0 != sd) - return emitOpError("expects src0/src1/src2/dst to have the same shape"); - return success(); -} -LogicalResult pto::TAddSOp::verify() { - return verifyArithmeticScalarTileOpWithArchDispatch( - getOperation(), getSrc().getType(), getDst().getType(), getScalar().getType(), - /*allowInt8OnA5=*/true, /*allowBf16OnA5=*/true, - "expects A2/A3 tadds element type to be i32/i16/f16/f32", - "expects A5 tadds element type to be i32/i16/i8/f16/bf16/f32", - /*requireValidRowsEqualOnA2A3=*/true, - /*requireValidRowsEqualOnA5=*/false); -} - -LogicalResult pto::TAxpyOp::verify() { - auto verifyCommon = [&]() -> LogicalResult { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyVecTileCommon(*this, srcTy, "src")) || - failed(verifyVecTileCommon(*this, dstTy, "dst"))) - return failure(); - if (failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) - return failure(); - - Type scalarTy = getScalar().getType(); - Type srcElem = getElemTy(srcTy); - if (scalarTy 
!= srcElem) - return emitOpError("expects scalar type to match src element type"); - if (getShapeVec(srcTy) != getShapeVec(dstTy)) - return emitOpError("expects src and dst to have the same shape"); - return success(); - }; - - auto verifyA2A3 = [&]() -> LogicalResult { - if (failed(verifyCommon())) - return failure(); - Type srcElem = getElemTy(getSrc().getType()); - Type dstElem = getElemTy(getDst().getType()); - bool sameType = srcElem == dstElem; - bool widenF16ToF32 = srcElem.isF16() && dstElem.isF32(); - if (!(sameType || widenF16ToF32)) - return emitOpError( - "expects dst/src element types to match, or dst=f32 and src=f16"); - if (!(dstElem.isF16() || dstElem.isF32())) - return emitOpError("expects A2/A3 taxpy dst element type to be f16/f32"); - if (!(srcElem.isF16() || srcElem.isF32())) - return emitOpError("expects A2/A3 taxpy src element type to be f16/f32"); - return success(); - }; - - auto verifyA5 = [&]() -> LogicalResult { - if (failed(verifyCommon())) - return failure(); - Type srcElem = getElemTy(getSrc().getType()); - Type dstElem = getElemTy(getDst().getType()); - bool sameType = srcElem == dstElem; - bool widenF16ToF32 = srcElem.isF16() && dstElem.isF32(); - if (!(sameType || widenF16ToF32)) - return emitOpError( - "expects dst/src element types to match, or dst=f32 and src=f16"); - if (!(dstElem.isF16() || dstElem.isF32() || dstElem.isBF16())) - return emitOpError("expects A5 taxpy dst element type to be f16/bf16/f32"); - if (!(srcElem.isF16() || srcElem.isF32() || srcElem.isBF16())) - return emitOpError("expects A5 taxpy src element type to be f16/bf16/f32"); - return success(); - }; - - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -LogicalResult pto::TAddSCOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - Type ts0 = getSrc0().getType(); - Type ts1 = getSrc1().getType(); - Type td = getDst().getType(); - if (!isPTOShapedLike(ts0) || !isPTOShapedLike(ts1) || 
!isPTOShapedLike(td)) - return emitOpError("expects src0/src1/dst to be PTO shaped-like types"); - - auto s0 = getShapeVec(ts0); - auto s1 = getShapeVec(ts1); - auto sd = getShapeVec(td); - if (s0 != s1 || s0 != sd) - return emitOpError("expects src0/src1/dst to have the same shape"); - return success(); -} - -LogicalResult pto::TAndOp::verify() { - auto verifyCommon = [&]() -> FailureOr { - return verifyMatchingRowMajorBinaryTileOpCommon( - getOperation(), getSrc0().getType(), getSrc1().getType(), - getDst().getType()); - }; - - auto verifyA2A3 = [&]() -> LogicalResult { - FailureOr elemOr = verifyCommon(); - if (failed(elemOr)) - return failure(); - auto it = mlir::dyn_cast(*elemOr); - if (!it || (it.getWidth() != 8 && it.getWidth() != 16)) - return emitOpError( - "expects A2/A3 tand src0, src1, and dst element type to be i8/i16"); - return success(); - }; - - auto verifyA5 = [&]() -> LogicalResult { - FailureOr elemOr = verifyCommon(); - if (failed(elemOr)) - return failure(); - auto it = mlir::dyn_cast(*elemOr); - if (!it || (it.getWidth() != 8 && it.getWidth() != 16 && - it.getWidth() != 32)) - return emitOpError( - "expects A5 tand src0, src1, and dst element type to be i8/i16/i32"); - return success(); - }; - - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TConcatOp::verify() { - auto verifyCommon = [&]() -> FailureOr { - Type t0 = getSrc0().getType(); - Type t1 = getSrc1().getType(); - Type td = getDst().getType(); - if (failed(verifyTileBufCommon(*this, t0, "src0")) || - failed(verifyTileBufCommon(*this, t1, "src1")) || - failed(verifyTileBufCommon(*this, td, "dst"))) - return failure(); - - Type e0 = getElemTy(t0); - Type e1 = getElemTy(t1); - Type ed = getElemTy(td); - if (!e0 || !e1 || !ed) { - emitOpError("failed to get element type for operands"); - return failure(); - } - if (e0 != e1 || e0 != ed) { - emitOpError("expects src0, src1, and dst to have the same element type"); - return 
failure(); - } - - auto v0 = getValidShapeVec(getSrc0()); - auto v1 = getValidShapeVec(getSrc1()); - auto vd = getValidShapeVec(getDst()); - if (v0.size() != 2 || v1.size() != 2 || vd.size() != 2) - return emitOpError("expects src0, src1, and dst to have rank-2 valid_shape"); - - // validRow must match dst (when known). - if (v0[0] != ShapedType::kDynamic && vd[0] != ShapedType::kDynamic && v0[0] != vd[0]) - return emitOpError("expects src0 valid row to match dst valid row"); - if (v1[0] != ShapedType::kDynamic && vd[0] != ShapedType::kDynamic && v1[0] != vd[0]) - return emitOpError("expects src1 valid row to match dst valid row"); - - // Total valid columns must fit within dst static cols (when known). - auto sd = getShapeVec(td); - if (sd.size() == 2 && sd[1] != ShapedType::kDynamic && - v0[1] != ShapedType::kDynamic && v1[1] != ShapedType::kDynamic) { - if (v0[1] + v1[1] > sd[1]) - return emitOpError("expects src0.valid_col + src1.valid_col <= dst.cols"); - } - - return e0; - }; - - auto verifyElemType = [&](Type elem) -> LogicalResult { - if (elem.isF16() || elem.isF32() || elem.isBF16()) - return success(); - auto it = mlir::dyn_cast(elem); - if (!it || - (it.getWidth() != 8 && it.getWidth() != 16 && it.getWidth() != 32)) - return emitOpError("expects element type to be i8, i16, i32, f16, f32, or bf16"); - return success(); - }; - - auto verifyLocVec = [&](Type ty, StringRef name) -> LogicalResult { - auto as = getPTOMemorySpaceEnum(ty); - if (!as || *as != pto::AddressSpace::VEC) - return emitOpError() << "expects " << name << " to use loc=vec"; - return success(); - }; - - auto verifyA2A3 = [&]() -> LogicalResult { - FailureOr elemOr = verifyCommon(); - if (failed(elemOr)) - return failure(); - if (failed(verifyLocVec(getSrc0().getType(), "src0")) || - failed(verifyLocVec(getSrc1().getType(), "src1")) || - failed(verifyLocVec(getDst().getType(), "dst"))) - return failure(); - return verifyElemType(*elemOr); - }; - - auto verifyA5 = [&]() -> LogicalResult { - 
FailureOr elemOr = verifyCommon(); - if (failed(elemOr)) - return failure(); - if (failed(verifyLocVec(getSrc0().getType(), "src0")) || - failed(verifyLocVec(getSrc1().getType(), "src1")) || - failed(verifyLocVec(getDst().getType(), "dst"))) - return failure(); - if (!isRowMajorTileBuf(getSrc0().getType()) || !isRowMajorTileBuf(getSrc1().getType()) || - !isRowMajorTileBuf(getDst().getType())) - return emitOpError("expects src0, src1, and dst to use row-major layout"); - return verifyElemType(*elemOr); - }; - - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -LogicalResult pto::TConcatidxOp::verify() { - auto verifyCommon = [&]() -> FailureOr> { - Type t0 = getSrc0().getType(); - Type t1 = getSrc1().getType(); - Type ti0 = getSrc0Idx().getType(); - Type ti1 = getSrc1Idx().getType(); - Type td = getDst().getType(); - if (failed(verifyTileBufCommon(*this, t0, "src0")) || - failed(verifyTileBufCommon(*this, t1, "src1")) || - failed(verifyTileBufCommon(*this, ti0, "src0Idx")) || - failed(verifyTileBufCommon(*this, ti1, "src1Idx")) || - failed(verifyTileBufCommon(*this, td, "dst"))) - return failure(); - - // Check data element type consistency. - Type e0 = getElemTy(t0); - Type e1 = getElemTy(t1); - Type ed = getElemTy(td); - if (!e0 || !e1 || !ed) { - emitOpError("failed to get element type for data operands"); - return failure(); - } - if (e0 != e1 || e0 != ed) { - emitOpError("expects src0, src1, and dst to have the same element type"); - return failure(); - } - - // Check index element type consistency. - Type ei0 = getElemTy(ti0); - Type ei1 = getElemTy(ti1); - if (!ei0 || !ei1) { - emitOpError("failed to get element type for index operands"); - return failure(); - } - if (ei0 != ei1) { - emitOpError("expects src0Idx and src1Idx to have the same element type"); - return failure(); - } - - // All five tiles must be rank-2. 
- auto v0 = getValidShapeVec(getSrc0()); - auto v1 = getValidShapeVec(getSrc1()); - auto vi0 = getValidShapeVec(getSrc0Idx()); - auto vi1 = getValidShapeVec(getSrc1Idx()); - auto vd = getValidShapeVec(getDst()); - if (v0.size() != 2 || v1.size() != 2 || vi0.size() != 2 || - vi1.size() != 2 || vd.size() != 2) - return emitOpError("expects all operands to have rank-2 valid_shape"); - - // validRow must match dst (when known). - auto checkValidRow = [&](const auto &v, StringRef name) -> LogicalResult { - if (v[0] != ShapedType::kDynamic && vd[0] != ShapedType::kDynamic && - v[0] != vd[0]) - return emitOpError("expects ") << name << " valid row to match dst valid row"; - return success(); - }; - if (failed(checkValidRow(v0, "src0")) || - failed(checkValidRow(v1, "src1")) || - failed(checkValidRow(vi0, "src0Idx")) || - failed(checkValidRow(vi1, "src1Idx"))) - return failure(); - - // Index tile must have cols >= 1 (when known). - if (vi0[1] != ShapedType::kDynamic && vi0[1] < 1) - return emitOpError("expects src0Idx valid_col >= 1"); - if (vi1[1] != ShapedType::kDynamic && vi1[1] < 1) - return emitOpError("expects src1Idx valid_col >= 1"); - - return std::make_pair(e0, ei0); - }; - - auto verifyElementTypes = [&](Type dataElem, Type idxElem) -> LogicalResult { - // Data element type: f16, f32, bf16, i8, i16, i32 (signless). - if (!dataElem.isF16() && !dataElem.isF32() && !dataElem.isBF16()) { - auto it = mlir::dyn_cast(dataElem); - if (!it || !it.isSignless() || - (it.getWidth() != 8 && it.getWidth() != 16 && it.getWidth() != 32)) - return emitOpError() - << "expects data element type to be i8, i16, i32, f16, f32, or bf16"; - } - - // Index element type: i8, i16, i32 (signless). 
- auto it = mlir::dyn_cast(idxElem); - if (!it || !it.isSignless() || - (it.getWidth() != 8 && it.getWidth() != 16 && it.getWidth() != 32)) - return emitOpError() - << "expects index element type to be i8, i16, or i32"; - return success(); - }; - - auto verifyLocVec = [&](Type ty, StringRef name) -> LogicalResult { - auto as = getPTOMemorySpaceEnum(ty); - if (!as || *as != pto::AddressSpace::VEC) - return emitOpError() << "expects " << name << " to use loc=vec"; - return success(); - }; - - auto verifyA2A3 = [&]() -> LogicalResult { - auto elemOr = verifyCommon(); - if (failed(elemOr)) - return failure(); - if (failed(verifyLocVec(getSrc0().getType(), "src0")) || - failed(verifyLocVec(getSrc1().getType(), "src1")) || - failed(verifyLocVec(getSrc0Idx().getType(), "src0Idx")) || - failed(verifyLocVec(getSrc1Idx().getType(), "src1Idx")) || - failed(verifyLocVec(getDst().getType(), "dst"))) - return failure(); - return verifyElementTypes(elemOr->first, elemOr->second); - }; - - auto verifyA5 = [&]() -> LogicalResult { - auto elemOr = verifyCommon(); - if (failed(elemOr)) - return failure(); - if (failed(verifyLocVec(getSrc0().getType(), "src0")) || - failed(verifyLocVec(getSrc1().getType(), "src1")) || - failed(verifyLocVec(getSrc0Idx().getType(), "src0Idx")) || - failed(verifyLocVec(getSrc1Idx().getType(), "src1Idx")) || - failed(verifyLocVec(getDst().getType(), "dst"))) - return failure(); - if (!isRowMajorTileBuf(getSrc0().getType()) || - !isRowMajorTileBuf(getSrc1().getType()) || - !isRowMajorTileBuf(getSrc0Idx().getType()) || - !isRowMajorTileBuf(getSrc1Idx().getType()) || - !isRowMajorTileBuf(getDst().getType())) - return emitOpError( - "expects all operands to use row-major layout"); - return verifyElementTypes(elemOr->first, elemOr->second); - }; - - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -LogicalResult pto::TAndSOp::verify() { - auto verifyCommon = [&]() -> FailureOr { - return 
verifyDistinctRowMajorUnaryTileOpCommon(getOperation(), getSrc(), - getDst(), "src", "dst"); - }; - - auto verifyA2A3 = [&]() -> LogicalResult { - FailureOr elemOr = verifyCommon(); - if (failed(elemOr)) - return failure(); - auto it = mlir::dyn_cast(*elemOr); - if (!it || (it.getWidth() != 8 && it.getWidth() != 16)) - return emitOpError( - "expects A2/A3 tands src, scalar, and dst element type to be i8/i16"); - return success(); - }; - - auto verifyA5 = [&]() -> LogicalResult { - FailureOr elemOr = verifyCommon(); - if (failed(elemOr)) - return failure(); - auto it = mlir::dyn_cast(*elemOr); - if (!it || (it.getWidth() != 8 && it.getWidth() != 16 && - it.getWidth() != 32)) - return emitOpError( - "expects A5 tands src, scalar, and dst element type to be i8/i16/i32"); - return success(); - }; - - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -LogicalResult pto::TCIOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - Type dstTy = getDst().getType(); - if (failed(verifyTileBufCommon(*this, dstTy, "dst"))) - return failure(); - - auto elemTy = mlir::dyn_cast(getElemTy(dstTy)); - if (!elemTy) - return emitOpError("expects dst element type to be integer"); - - unsigned bw = elemTy.getWidth(); - if (bw != 16 && bw != 32) - return emitOpError("expects dst element type to be i16/i32"); - - auto sTy = mlir::dyn_cast(getOperand(0).getType()); - if (!sTy) - return emitOpError("expects S to be integer"); - - if (sTy != elemTy) - return emitOpError("expects S and dst element type to be exactly the same type"); - auto shape = getShapeVec(dstTy); - if (shape.size() != 2) - return emitOpError("expects dst to be rank-2"); - if (shape[1] != ShapedType::kDynamic && shape[1] == 1) - return emitOpError("expects dst cols to be different from 1"); - - return success(); -} - -LogicalResult pto::TTriOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - - Type dstTy = 
getDst().getType(); - if (failed(verifyVecTileCommon(*this, dstTy, "dst"))) - return failure(); - - auto diagonalTy = mlir::dyn_cast(getDiagonal().getType()); - if (!diagonalTy) - return emitOpError("expects diagonal to be an integer operand"); - - int32_t upperOrLower = getUpperOrLower(); - if (upperOrLower != 0 && upperOrLower != 1) - return emitOpError("expects upperOrLower to be 0 (lower) or 1 (upper)"); - - Type elemTy = getElemTy(dstTy); - return dispatchVerifierByArch( - getOperation(), - [&]() -> LogicalResult { - if (!isSupportedVecElemType(elemTy, /*allowBf16=*/false, - /*allowInt8=*/false)) - return emitOpError() - << "expects A2/A3 dst element type to be f16/f32/i16/i32/u16/u32"; - return success(); - }, - [&]() -> LogicalResult { - if (!isSupportedVecElemType(elemTy, /*allowBf16=*/true, - /*allowInt8=*/true)) - return emitOpError() - << "expects A5 dst element type to be f16/f32/bf16/i8/i16/i32/u8/u16/u32"; - return success(); - }); -} - -LogicalResult pto::TCmpOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - Type t0 = getSrc0().getType(); - Type t1 = getSrc1().getType(); - Type td = getDst().getType(); - if (failed(verifyVecTileStorage(*this, t0, "src0")) || - failed(verifyVecTileStorage(*this, t1, "src1")) || - failed(verifyVecTileStorage(*this, td, "dst"))) - return failure(); - - Type e0 = getElemTy(t0); - Type e1 = getElemTy(t1); - Type ed = getElemTy(td); - if (!e0 || !e1 || !ed) - return emitOpError("failed to get element type for src0/src1/dst"); - if (e0 != e1) - return emitOpError("expects src0 and src1 to have the same element type"); - if (!(e0.isInteger(32) || e0.isF16() || e0.isF32())) - return emitOpError("expects A2/A3 tcmp input element type to be i32/f16/f32"); - if (!ed.isInteger(8)) - return emitOpError("expects dst element type to be i8"); - - auto valid0 = getValidShapeVec(t0); - auto valid1 = getValidShapeVec(t1); - auto validd = getValidShapeVec(td); - if (valid0.size() != 2 || valid1.size() != 2 || validd.size() != 
2) - return emitOpError("expects src0, src1, and dst to have rank-2 valid_shape"); - if (!hasCompatibleKnownExtent(valid0[0], valid1[0])) - return emitOpError("expects src0 and src1 to have the same valid row"); - if (!hasCompatibleKnownExtent(valid0[1], valid1[1])) - return emitOpError("expects src0 and src1 to have the same valid column"); - if (!hasCompatibleKnownExtent(valid0[0], validd[0])) - return emitOpError("expects src0 valid row to equal dst valid row"); - return success(); - }; - - auto verifyA5 = [&]() -> LogicalResult { - Type t0 = getSrc0().getType(); - Type t1 = getSrc1().getType(); - Type td = getDst().getType(); - if (failed(verifyTileBufCommon(*this, t0, "src0")) || - failed(verifyTileBufCommon(*this, t1, "src1")) || - failed(verifyTileBufCommon(*this, td, "dst"))) - return failure(); - - Type e0 = getElemTy(t0); - Type e1 = getElemTy(t1); - Type ed = getElemTy(td); - if (!e0 || !e1 || !ed) - return emitOpError("failed to get element type for src0/src1/dst"); - if (e0 != e1) - return emitOpError("expects src0 and src1 to have the same element type"); - bool inputOk = e0.isF16() || e0.isF32() || e0.isBF16() || - e0.isInteger(8) || e0.isInteger(16) || e0.isInteger(32); - if (!inputOk) - return emitOpError("expects A5 tcmp input element type to be i8/i16/i32/f16/bf16/f32"); - if (auto it = dyn_cast(ed)) { - if (it.getWidth() != 8) - return emitOpError("expects dst element type to be i8"); - } else { - return emitOpError("expects dst element type to be i8"); - } - - if (getShapeVec(t0) != getShapeVec(t1) || getShapeVec(t0) != getShapeVec(td)) - return emitOpError("expects src0, src1, and dst to have the same shape"); - return success(); - }; - - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -// ---- TCMPS verify ---- -LogicalResult pto::TCmpSOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyVecTileStorage(*this, srcTy, 
"src")) || - failed(verifyVecTileStorage(*this, dstTy, "dst"))) - return failure(); - - Type elemTy = getElemTy(srcTy); - if (!(elemTy.isInteger(16) || elemTy.isInteger(32) || - elemTy.isF16() || elemTy.isF32())) - return emitOpError("expects A2/A3 tcmps input element type to be i16/i32/f16/f32"); - - auto scalarTy = getScalar().getType(); - if (!(scalarTy.isIntOrIndexOrFloat())) - return emitOpError("expects scalar to be integer, index, or float"); - - auto srcValid = getValidShapeVec(srcTy); - auto dstValid = getValidShapeVec(dstTy); - if (srcValid.size() != 2 || dstValid.size() != 2) - return emitOpError("expects src and dst to have rank-2 valid_shape"); - if (srcValid[0] != ShapedType::kDynamic && dstValid[0] != ShapedType::kDynamic && - srcValid[0] != dstValid[0]) - return emitOpError("expects src and dst to have the same valid_shape[0]"); - return success(); - }; - - auto verifyA5 = [&]() -> LogicalResult { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyVecTileStorage(*this, srcTy, "src")) || - failed(verifyVecTileStorage(*this, dstTy, "dst"))) - return failure(); - - Type elemTy = getElemTy(srcTy); - if (!(elemTy.isInteger(8) || elemTy.isInteger(16) || elemTy.isInteger(32) || - elemTy.isF16() || elemTy.isF32())) - return emitOpError("expects A5 tcmps input element type to be i8/i16/i32/f16/f32"); - - auto scalarTy = getScalar().getType(); - if (!(scalarTy.isIntOrIndexOrFloat())) - return emitOpError("expects scalar to be integer, index, or float"); - - auto srcValid = getValidShapeVec(srcTy); - auto dstValid = getValidShapeVec(dstTy); - if (srcValid.size() != 2 || dstValid.size() != 2) - return emitOpError("expects src and dst to have rank-2 valid_shape"); - if (srcValid[0] != ShapedType::kDynamic && dstValid[0] != ShapedType::kDynamic && - srcValid[0] != dstValid[0]) - return emitOpError("expects src and dst to have the same valid_shape[0]"); - return success(); - }; - - return dispatchVerifierByArch(getOperation(), 
verifyA2A3, verifyA5); -} -LogicalResult pto::TColExpandOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyNDStyleVecTile(*this, srcTy, "src")) || - failed(verifyNDStyleVecTile(*this, dstTy, "dst"))) - return failure(); - if (getElemTy(srcTy) != getElemTy(dstTy)) - return emitOpError("expects src and dst to have the same element type"); - if (!isSupportedVecElemType(getElemTy(srcTy), /*allowBf16=*/true, - /*allowInt8=*/true)) - return emitOpError("expects tcolexpand element type to be supported"); - auto srcValid = getValidShapeVec(getSrc()); - auto dstValid = getValidShapeVec(getDst()); - if (srcValid.size() != 2 || dstValid.size() != 2) - return emitOpError("expects src and dst to have rank-2 valid_shape"); - if (srcValid[1] != ShapedType::kDynamic && dstValid[1] != ShapedType::kDynamic && - srcValid[1] != dstValid[1]) - return emitOpError("expects src and dst to have the same valid_shape[1]"); - return success(); -} -static LogicalResult verifyTColExpandBinaryLikeOp(Operation *op, Type t0, Type t1, - Type td, PTOArch targetArch, - StringRef opName, - bool allowIntegerTypes) { - if (!isPTOShapedLike(t0) || !isPTOShapedLike(t1) || !isPTOShapedLike(td)) - return op->emitOpError("expects src0/src1/dst to be PTO shaped-like types"); - - Type e0 = getElemTy(t0); - Type e1 = getElemTy(t1); - Type ed = getElemTy(td); - if (!e0 || !e1 || !ed) - return op->emitOpError("failed to get element type for src0/src1/dst"); - - auto isSupportedElem = [&](Type elemTy) { - if (elemTy.isF16() || elemTy.isF32()) - return true; - if (!allowIntegerTypes) - return false; - if (elemTy.isInteger(16) || elemTy.isInteger(32)) - return true; - return targetArch == PTOArch::A5 && elemTy.isInteger(8); - }; - if (!isSupportedElem(e0) || !isSupportedElem(e1) || !isSupportedElem(ed)) { - if (!allowIntegerTypes) - return op->emitOpError() << "expects " << opName - << " 
element type to be f16 or f32"; - if (targetArch == PTOArch::A5) - return op->emitOpError() << "expects A5 " << opName - << " element type to be i8/i16/i32/f16/f32"; - return op->emitOpError() << "expects A2/A3 " << opName - << " element type to be i16/i32/f16/f32"; - } - - if (getShapeVec(t0) != getShapeVec(td)) - return op->emitOpError("expects src0/dst to have same shape"); - if (failed(verifyTileBufSameValidShape(op, t0, td, "src0", "dst"))) - return failure(); - - if (auto src0TileTy = dyn_cast(t0)) { - if (src0TileTy.getBLayoutValueI32() != 0) - return op->emitOpError("expects src0 to use row-major layout"); - } - - if (auto src1TileTy = dyn_cast(t1)) { - if (src1TileTy.getBLayoutValueI32() != 0) - return op->emitOpError("expects src1 to use row-major layout"); - } - if (auto dstTileTy = dyn_cast(td)) { - if (dstTileTy.getBLayoutValueI32() != 0) - return op->emitOpError("expects dst to use row-major layout"); - } - - auto src1Valid = getValidShapeVec(t1); - auto dstValid = getValidShapeVec(td); - if (src1Valid.size() == 2 && dstValid.size() == 2 && - src1Valid[1] != ShapedType::kDynamic && dstValid[1] != ShapedType::kDynamic && - src1Valid[1] != dstValid[1]) - return op->emitOpError("expects src1 valid_shape[1] to equal dst valid_shape[1]"); - - return success(); -} -LogicalResult pto::TColExpandMulOp::verify() { - PTOArch arch = getTargetArch(getOperation()); - return verifyTColExpandBinaryLikeOp(getOperation(), getSrc0().getType(), - getSrc1().getType(), getDst().getType(), - arch, "tcolexpandmul", - /*allowIntegerTypes=*/true); -} -LogicalResult pto::TColExpandAddOp::verify() { - PTOArch arch = getTargetArch(getOperation()); - return verifyTColExpandBinaryLikeOp(getOperation(), getSrc0().getType(), - getSrc1().getType(), getDst().getType(), - arch, "tcolexpandadd", - /*allowIntegerTypes=*/true); -} -LogicalResult pto::TColExpandDivOp::verify() { - auto verifyByArch = [&](PTOArch targetArch) -> LogicalResult { - bool allowIntegerTypes = (targetArch == 
PTOArch::A5); - return verifyTColExpandBinaryLikeOp(getOperation(), getSrc0().getType(), - getSrc1().getType(), getDst().getType(), - targetArch, "tcolexpanddiv", - /*allowIntegerTypes=*/allowIntegerTypes); - }; - auto verifyA2A3 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A3); }; - auto verifyA5 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A5); }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} -LogicalResult pto::TColExpandSubOp::verify() { - PTOArch arch = getTargetArch(getOperation()); - return verifyTColExpandBinaryLikeOp(getOperation(), getSrc0().getType(), - getSrc1().getType(), getDst().getType(), - arch, "tcolexpandsub", - /*allowIntegerTypes=*/true); -} -LogicalResult pto::TColExpandExpdifOp::verify() { - PTOArch arch = getTargetArch(getOperation()); - return verifyTColExpandBinaryLikeOp(getOperation(), getSrc0().getType(), - getSrc1().getType(), getDst().getType(), - arch, "tcolexpandexpdif", - /*allowIntegerTypes=*/false); -} -LogicalResult pto::TColExpandMaxOp::verify() { - PTOArch arch = getTargetArch(getOperation()); - return verifyTColExpandBinaryLikeOp(getOperation(), getSrc0().getType(), - getSrc1().getType(), getDst().getType(), - arch, "tcolexpandmax", - /*allowIntegerTypes=*/true); -} -LogicalResult pto::TColExpandMinOp::verify() { - PTOArch arch = getTargetArch(getOperation()); - return verifyTColExpandBinaryLikeOp(getOperation(), getSrc0().getType(), - getSrc1().getType(), getDst().getType(), - arch, "tcolexpandmin", - /*allowIntegerTypes=*/true); -} -LogicalResult pto::TColMaxOp::verify() { - return verifyTColReductionOpWithArchDispatch( - getOperation(), getSrc().getType(), getDst().getType(), - /*requireNonZeroSrcOnA2A3=*/false, /*requireNonZeroSrcOnA5=*/true, - /*allowInt8OnA5=*/true, /*allowBf16OnA5=*/true, - "expects A2/A3 tcolmax element type to be f16/f32/i16/i32", - "expects A5 tcolmax element type to be i8/i16/i32/f16/bf16/f32"); -} - -LogicalResult pto::TColArgMaxOp::verify() { - 
auto verifyByArch = [&]() -> LogicalResult { - return verifyTColArgReductionOpCommon(*this, getSrc().getType(), - getTmp().getType(), getDst().getType()); - }; - - return dispatchVerifierByArch(getOperation(), verifyByArch, verifyByArch); -} - -LogicalResult pto::TColMinOp::verify() { - return verifyTColReductionOpWithArchDispatch( - getOperation(), getSrc().getType(), getDst().getType(), - /*requireNonZeroSrcOnA2A3=*/false, /*requireNonZeroSrcOnA5=*/true, - /*allowInt8OnA5=*/true, /*allowBf16OnA5=*/true, - "expects A2/A3 tcolmin element type to be f16/f32/i16/i32", - "expects A5 tcolmin element type to be i8/i16/i32/f16/bf16/f32"); -} - -LogicalResult pto::TColArgMinOp::verify() { - auto verifyByArch = [&]() -> LogicalResult { - return verifyTColArgReductionOpCommon(*this, getSrc().getType(), - getTmp().getType(), getDst().getType()); - }; - - return dispatchVerifierByArch(getOperation(), verifyByArch, verifyByArch); -} - - - -ParseResult mlir::pto::TColSumOp::parse(OpAsmParser &parser, OperationState &result) { - OpAsmParser::UnresolvedOperand src; - OpAsmParser::UnresolvedOperand tmp; - OpAsmParser::UnresolvedOperand dst; - Type srcTy, tmpTy, dstTy; - bool hasTmp = false; - - // Parse: ins(%src : type) or ins(%src, %tmp {isBinary = ...}: type, type) - if (parser.parseKeyword("ins") || parser.parseLParen() || parser.parseOperand(src)) - return failure(); - - // Check for optional tmp operand (format 2) - if (succeeded(parser.parseOptionalComma())) { - // Format 2: ins(%src, %tmp {isBinary = ...}: type, type) - if (parser.parseOperand(tmp)) - return failure(); - hasTmp = true; - - // Parse attributes (isBinary) - if (parser.parseOptionalAttrDict(result.attributes)) - return failure(); - - // Parse types: : type, type - if (parser.parseColonType(srcTy) || parser.parseComma() || parser.parseType(tmpTy)) - return failure(); - } else { - // Format 1: ins(%src : type) - if (parser.parseColonType(srcTy)) - return failure(); - } - - if (parser.parseRParen()) - return 
failure(); - - // Parse: outs(%dst : type) - if (parser.parseKeyword("outs") || parser.parseLParen() || - parser.parseOperand(dst) || parser.parseColonType(dstTy) || - parser.parseRParen()) - return failure(); - - // Parse any remaining attributes (for format 1) - if (!hasTmp) { - if (parser.parseOptionalAttrDict(result.attributes)) - return failure(); - } - - // Resolve operands - if (parser.resolveOperand(src, srcTy, result.operands)) - return failure(); - - if (hasTmp) { - if (parser.resolveOperand(tmp, tmpTy, result.operands)) - return failure(); - } - - if (parser.resolveOperand(dst, dstTy, result.operands)) - return failure(); - - return success(); -} - -void mlir::pto::TColSumOp::print(OpAsmPrinter &p) { - if (getTmp()) { - // Format 2: ins(%src, %tmp {isBinary = ...}: type, type) outs(%dst : type) - p << " ins(" << getSrc() << ", " << getTmp(); - // Print isBinary attribute if present - SmallVector elidedAttrs; - if (!getIsBinaryAttr() || getIsBinaryAttr().getValue() == false) { - elidedAttrs.push_back("isBinary"); - } - p.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); - p << " : " << getSrc().getType() << ", " << getTmp().getType() << ")"; - } else { - // Format 1: ins(%src : type) outs(%dst : type) - p << " ins(" << getSrc() << " : " << getSrc().getType() << ")"; - } - - p << " outs(" << getDst() << " : " << getDst().getType() << ")"; - - // Print remaining attributes for format 1 (excluding isBinary) - if (!getTmp()) { - SmallVector elidedAttrs = {"isBinary"}; - p.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); - } -} - -LogicalResult pto::TColSumOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyNDStyleVecTile(*this, srcTy, "src")) || - failed(verifyNDStyleVecTile(*this, dstTy, "dst"))) - return failure(); - bool hasTmp = (bool)getTmp(); - bool hasIsBinary = (bool)getIsBinaryAttr(); - if (hasTmp != hasIsBinary) { - if (hasTmp) - return 
emitOpError("tmp operand requires isBinary attribute"); - return emitOpError("isBinary attribute requires tmp operand"); - } - if (getTmp()) { - Type tmpTy = getTmp().getType(); - if (failed(verifyNDStyleVecTile(*this, tmpTy, "tmp"))) - return failure(); - if (getElemTy(srcTy) != getElemTy(dstTy) || getElemTy(srcTy) != getElemTy(tmpTy)) - return emitOpError("expects src/tmp/dst element types to match"); - } - if (getElemTy(srcTy) != getElemTy(dstTy)) - return emitOpError("expects src/dst element types to match"); - if (failed(verifyColReductionValidRegion(*this, srcTy, dstTy, - /*requireNonZeroSrc=*/false))) - return failure(); - Type elem = getElemTy(srcTy); - if (!(elem.isF16() || elem.isF32() || elem.isInteger(16) || elem.isInteger(32))) - return emitOpError("expects A2/A3 tcolsum element type to be f16/f32/i16/i32"); - return success(); - }; - auto verifyA5 = [&]() -> LogicalResult { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyNDStyleVecTile(*this, srcTy, "src")) || - failed(verifyNDStyleVecTile(*this, dstTy, "dst"))) - return failure(); - bool hasTmp = (bool)getTmp(); - bool hasIsBinary = (bool)getIsBinaryAttr(); - if (hasTmp != hasIsBinary) { - if (hasTmp) - return emitOpError("tmp operand requires isBinary attribute"); - return emitOpError("isBinary attribute requires tmp operand"); - } - if (getTmp()) { - Type tmpTy = getTmp().getType(); - if (failed(verifyNDStyleVecTile(*this, tmpTy, "tmp"))) - return failure(); - if (getElemTy(srcTy) != getElemTy(dstTy) || getElemTy(srcTy) != getElemTy(tmpTy)) - return emitOpError("expects src/tmp/dst element types to match"); - } - if (getElemTy(srcTy) != getElemTy(dstTy)) - return emitOpError("expects src/dst element types to match"); - if (failed(verifyColReductionValidRegion(*this, srcTy, dstTy, - /*requireNonZeroSrc=*/true))) - return failure(); - Type elem = getElemTy(srcTy); - if (!(elem.isF16() || elem.isF32() || elem.isBF16() || elem.isInteger(8) || - 
elem.isInteger(16) || elem.isInteger(32))) - return emitOpError("expects A5 tcolsum element type to be i8/i16/i32/f16/bf16/f32"); - return success(); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -LogicalResult pto::TColProdOp::verify() { - return verifyTColReductionOpWithArchDispatch( - getOperation(), getSrc().getType(), getDst().getType(), - /*requireNonZeroSrcOnA2A3=*/false, /*requireNonZeroSrcOnA5=*/false, - /*allowInt8OnA5=*/false, /*allowBf16OnA5=*/true, - "expects A2/A3 tcolprod element type to be f16/f32/i16/i32", - "expects A5 tcolprod element type to be i16/ui16/i32/ui32/f16/bf16/f32"); -} - -llvm::LogicalResult mlir::pto::TCvtOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyTileBufCommon(*this, srcTy, "src", /*allowLowPrecision=*/true)) || - failed(verifyTileBufCommon(*this, dstTy, "dst", /*allowLowPrecision=*/true))) - return failure(); - if (failed(verifyTileBufSameLogicalExtent(*this, srcTy, dstTy, "src", "dst", - /*compareValidShape=*/false))) - return failure(); - if (failed(verifyTileBufSameLogicalExtent(*this, srcTy, dstTy, "src", "dst", - /*compareValidShape=*/true))) - return failure(); - Type srcElem = getElemTy(srcTy); - Type dstElem = getElemTy(dstTy); - auto verifyA2A3 = [&]() -> LogicalResult { - if (isPTOLowPrecisionType(srcElem) || isPTOLowPrecisionType(dstElem)) - return emitOpError("expects A2/A3 tcvt low-precision element types to be unsupported"); - return success(); - }; - auto verifyA5 = [&]() -> LogicalResult { - if (!isA5SupportedTCvtPair(srcElem, dstElem)) - return emitOpError("expects A5 tcvt low-precision type pairs to match PTO-ISA support"); - return success(); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -llvm::LogicalResult mlir::pto::TRandomOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - return 
emitOpError("trandom is only supported for A5 targets"); - }; - auto verifyA5 = [&]() -> LogicalResult { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - - Type dstTy = getDst().getType(); - if (failed(verifyTileBufCommon(*this, dstTy, "dst"))) - return failure(); - if (!isRowMajorTileBuf(dstTy)) - return emitOpError("expects dst to use row-major layout"); - - Type elemTy = getElemTy(dstTy); - if (!elemTy.isInteger(32)) - return emitOpError("expects dst element type to be i32 or ui32"); - - auto checkWord = [&](Value v, StringRef name) -> LogicalResult { - auto ty = dyn_cast(v.getType()); - if (!ty || ty.getWidth() != 32) - return emitOpError() << "expects " << name << " to be i32/ui32"; - return success(); - }; - if (failed(checkWord(getKey0(), "key0")) || - failed(checkWord(getKey1(), "key1")) || - failed(checkWord(getCounter0(), "counter0")) || - failed(checkWord(getCounter1(), "counter1")) || - failed(checkWord(getCounter2(), "counter2")) || - failed(checkWord(getCounter3(), "counter3"))) - return failure(); - - int32_t rounds = getRounds(); - if (rounds != 7 && rounds != 10) - return emitOpError("expects rounds to be 7 or 10"); - - return success(); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -LogicalResult mlir::pto::TDivOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - FailureOr elemOr = verifyMatchingRowMajorBinaryTileOpCommon( - getOperation(), getSrc0().getType(), getSrc1().getType(), - getDst().getType()); - if (failed(elemOr)) - return failure(); - auto elem0 = *elemOr; - if (!(elem0.isF16() || elem0.isF32())) - return emitOpError("expects A2/A3 tdiv element type to be f16 or f32"); - return success(); - }; - auto verifyA5 = [&]() -> LogicalResult { - FailureOr elemOr = verifyMatchingRowMajorBinaryTileOpCommon( - getOperation(), getSrc0().getType(), getSrc1().getType(), - getDst().getType()); - if (failed(elemOr)) - return failure(); - auto elem0 = *elemOr; - if 
(!(elem0.isF16() || elem0.isF32() || elem0.isInteger(16) || elem0.isInteger(32))) - return emitOpError("expects A5 tdiv element type to be i32/i16/f16/f32"); - return success(); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TDivSOp::verify() { - auto isTileLike = [](Type ty) -> bool { - return isa(ty); - }; - auto isScalarLike = [](Type ty) -> bool { - return mlir::isa(ty); - }; - - auto verifyByArch = [&](PTOArch targetArch) -> LogicalResult { - Type srcTy = getSrc().getType(); - Type rhsTy = getScalar().getType(); - Type dstTy = getDst().getType(); - - bool srcTile = isTileLike(srcTy); - bool rhsTile = isTileLike(rhsTy); - bool srcScalar = isScalarLike(srcTy); - bool rhsScalar = isScalarLike(rhsTy); - - if (!(srcTile && rhsScalar) && !(srcScalar && rhsTile)) - return emitOpError("expects one tile-like operand and one scalar operand in ins(...)"); - - Type tileTy = srcTile ? srcTy : rhsTy; - Type scalarTy = srcTile ? rhsTy : srcTy; - - if (failed(verifyScalarTileOp(*this, tileTy, dstTy, "src", "dst", - /*requireValidRowsEqual=*/true, - /*requireValidColsEqual=*/true))) - return failure(); - if (!mlir::isa(scalarTy)) - return emitOpError("scalar must be a scalar type (integer/float)"); - Type elem = getElemTy(tileTy); - if (targetArch == PTOArch::A3 && - !(elem.isInteger(32) || elem.isInteger(16) || elem.isF16() || - elem.isF32())) - return emitOpError("expects A2/A3 tdivs element type to be i32/i16/f16/f32"); - if (targetArch == PTOArch::A5 && - !(elem.isInteger(32) || elem.isInteger(16) || elem.isInteger(8) || - elem.isF16() || elem.isF32())) - return emitOpError("expects A5 tdivs element type to be i32/i16/i8/f16/f32"); - return success(); - }; - auto verifyA2A3 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A3); }; - auto verifyA5 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A5); }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - 
-mlir::LogicalResult mlir::pto::TExpOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyVecTileUnaryOp(*this, srcTy, dstTy, "src", "dst", - /*allowBf16=*/false, /*allowInt8=*/false))) - return failure(); - if (failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) - return failure(); - Type srcElem = getElemTy(srcTy); - if (!srcElem.isF16() && !srcElem.isF32()) - return emitOpError("expects element type to be f16 or f32"); - return mlir::success(); - }; - auto verifyA5 = [&]() -> LogicalResult { return verifyA2A3(); }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TExpandsOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - Type dstTy = getDst().getType(); - if (failed(verifyTileBufCommon(*this, dstTy, "dst"))) - return failure(); - auto dstSpace = getPTOMemorySpaceEnum(dstTy); - if (!dstSpace || (*dstSpace != pto::AddressSpace::VEC && - *dstSpace != pto::AddressSpace::MAT)) - return emitOpError("expects dst to be in the vec or mat address space"); - Type dstElem = getElemTy(dstTy); - Type scalarTy = getScalar().getType(); - if (scalarTy != dstElem) - return emitOpError("expects scalar type == dst element type"); - if (*dstSpace == pto::AddressSpace::VEC && !isRowMajorTileBuf(dstTy)) - return emitOpError("expects vec dst to use row-major layout on A2/A3"); - if (dstElem.isF16() || dstElem.isBF16() || dstElem.isF32()) - return mlir::success(); - if (auto it = mlir::dyn_cast(dstElem)) { - unsigned w = it.getWidth(); - if (w == 16 || w == 32) - return mlir::success(); - } - return emitOpError("expects A2/A3 texpands dst element type to be i16/i32/f16/bf16/f32"); - }; - auto verifyA5 = [&]() -> LogicalResult { - Type dstTy = getDst().getType(); - if (failed(verifyTileBufCommon(*this, dstTy, "dst"))) - return failure(); - auto dstSpace = getPTOMemorySpaceEnum(dstTy); - if (!dstSpace || 
(*dstSpace != pto::AddressSpace::VEC && - *dstSpace != pto::AddressSpace::MAT)) - return emitOpError("expects dst to be in the vec or mat address space"); - Type dstElem = getElemTy(dstTy); - Type scalarTy = getScalar().getType(); - if (scalarTy != dstElem) - return emitOpError("expects scalar type == dst element type"); - if (dstElem.isF16() || dstElem.isBF16() || dstElem.isF32()) - return mlir::success(); - if (auto it = mlir::dyn_cast(dstElem)) { - unsigned w = it.getWidth(); - if (w == 8 || w == 16 || w == 32) - return mlir::success(); - } - return emitOpError("expects A5 texpands dst element type to be i8/i16/i32/f16/bf16/f32"); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TExtractOp::verify() { - auto hasMatExtractSourceLayoutA2A3 = [&](pto::TileBufType srcTy) -> bool { - int32_t bl = srcTy.getBLayoutValueI32(); - int32_t sl = srcTy.getSLayoutValueI32(); - return bl == static_cast(pto::BLayout::RowMajor) || - (bl != static_cast(pto::BLayout::RowMajor) && - sl == static_cast(pto::SLayout::RowMajor)); - }; - auto hasMatExtractSourceLayoutA5 = [&](pto::TileBufType srcTy, - pto::AddressSpace dstSpace) -> bool { - int32_t bl = srcTy.getBLayoutValueI32(); - int32_t sl = srcTy.getSLayoutValueI32(); - if (dstSpace == pto::AddressSpace::LEFT) { - return (bl == static_cast(pto::BLayout::RowMajor) && - sl == static_cast(pto::SLayout::ColMajor)) || - (bl != static_cast(pto::BLayout::RowMajor) && - sl == static_cast(pto::SLayout::RowMajor)) || - bl == static_cast(pto::BLayout::RowMajor); - } - return (bl == static_cast(pto::BLayout::RowMajor) && - sl == static_cast(pto::SLayout::ColMajor)) || - (bl != static_cast(pto::BLayout::RowMajor) && - sl == static_cast(pto::SLayout::RowMajor)); - }; - auto isA2A3ExtractElemType = [&](Type ty) -> bool { - return ty.isInteger(8) || ty.isF16() || ty.isBF16() || ty.isF32(); - }; - auto isA5ExtractElemType = [&](Type ty) -> bool { - if (auto it = dyn_cast(ty)) - return 
it.getWidth() == 8; - if (auto ft = dyn_cast(ty)) - return ft.getWidth() == 8 || ft.isF16() || ft.isBF16() || ft.isF32(); - return false; - }; - auto isRowMajorNoneBoxND = [&](pto::TileBufType ty) -> bool { - return ty.getBLayoutValueI32() == static_cast(pto::BLayout::RowMajor) && - ty.getSLayoutValueI32() == static_cast(pto::SLayout::NoneBox); - }; - auto verifyCommon = [&]() -> FailureOr, - std::optional>> { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - auto srcTb = dyn_cast(srcTy); - auto dstTb = dyn_cast(dstTy); - if (!srcTb || !dstTb) - return emitOpError("expects src and dst to be !pto.tile_buf"); - if (failed(verifyTileBufCommon(*this, srcTy, "src")) || - failed(verifyTileBufCommon(*this, dstTy, "dst")) || - failed(verifyNonNegativeIndexRowCol( - *getOperation(), getIndexRow(), getIndexCol(), - /*includeIndexAndIntOpsInConstFold=*/false)) || - failed(verifyExtractStaticBoundsCommon( - *getOperation(), getIndexRow(), getIndexCol(), srcTy, dstTy, - /*includeIndexAndIntOpsInConstFold=*/false))) - return failure(); - Type srcElem = getElemTy(srcTy); - Type dstElem = getElemTy(dstTy); - if (!srcElem || !dstElem || srcElem != dstElem) - return emitOpError("expects src and dst to have the same element type"); - auto srcSpace = getPTOMemorySpaceEnum(srcTy); - auto dstSpace = getPTOMemorySpaceEnum(dstTy); - return std::make_tuple(srcTy, dstTy, srcTb, dstTb, srcElem, dstElem, - srcSpace, dstSpace); - }; - auto verifyA2A3 = [&]() -> LogicalResult { - auto common = verifyCommon(); - if (failed(common)) - return failure(); - auto [srcTy, dstTy, srcTb, dstTb, srcElem, dstElem, srcSpace, dstSpace] = - *common; - (void)srcTy; - (void)dstTy; - (void)srcElem; - if (!isA2A3ExtractElemType(dstElem)) - return emitOpError("expects A2/A3 textract element type to be i8/f16/bf16/f32"); - if (srcSpace && dstSpace && *srcSpace == pto::AddressSpace::VEC && - *dstSpace == pto::AddressSpace::VEC) - return mlir::success(); - if (!srcSpace || *srcSpace != 
pto::AddressSpace::MAT) - return emitOpError("expects A2/A3 textract src to use loc=mat or vec"); - if (!dstSpace || (*dstSpace != pto::AddressSpace::LEFT && - *dstSpace != pto::AddressSpace::RIGHT)) - return emitOpError("expects A2/A3 textract dst to use loc=left, loc=right, or loc=vec"); - if (!hasMatExtractSourceLayoutA2A3(srcTb)) - return emitOpError("expects A2/A3 textract src to use a supported mat blayout/slayout combination"); - if (*dstSpace == pto::AddressSpace::LEFT) { - if (dstTb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor) || - dstTb.getSLayoutValueI32() != static_cast(pto::SLayout::RowMajor)) - return emitOpError("expects A2/A3 left dst to use row_major blayout and row_major slayout"); - } else { - if (dstTb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor) || - dstTb.getSLayoutValueI32() != static_cast(pto::SLayout::ColMajor)) - return emitOpError("expects A2/A3 right dst to use row_major blayout and col_major slayout"); - } - return mlir::success(); - }; - auto verifyA5 = [&]() -> LogicalResult { - auto common = verifyCommon(); - if (failed(common)) - return failure(); - auto [srcTy, dstTy, srcTb, dstTb, srcElem, dstElem, srcSpace, dstSpace] = - *common; - (void)srcTy; - (void)dstTy; - (void)srcElem; - if (!isA5ExtractElemType(dstElem)) - return emitOpError("expects A5 textract element type to be an fp8/f16/bf16/f32 or int8 family type"); - if (!srcSpace || !dstSpace) - return emitOpError("expects src and dst to have explicit loc"); - bool okPair = - (*srcSpace == pto::AddressSpace::MAT && - (*dstSpace == pto::AddressSpace::LEFT || - *dstSpace == pto::AddressSpace::RIGHT || - *dstSpace == pto::AddressSpace::SCALING)) || - (*srcSpace == pto::AddressSpace::VEC && - (*dstSpace == pto::AddressSpace::MAT || - *dstSpace == pto::AddressSpace::VEC)); - if (!okPair) - return emitOpError("expects A5 textract to use a supported src/dst loc pair"); - if (*srcSpace == pto::AddressSpace::MAT) { - if (!hasMatExtractSourceLayoutA5(srcTb, 
*dstSpace)) - return emitOpError("expects A5 textract src to use a supported mat blayout/slayout combination"); - if (*dstSpace == pto::AddressSpace::LEFT) { - if (dstTb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor) || - dstTb.getSLayoutValueI32() != static_cast(pto::SLayout::RowMajor)) - return emitOpError("expects A5 left dst to use col_major blayout and row_major slayout"); - } else if (*dstSpace == pto::AddressSpace::RIGHT) { - if (dstTb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor) || - dstTb.getSLayoutValueI32() != static_cast(pto::SLayout::ColMajor)) - return emitOpError("expects A5 right dst to use row_major blayout and col_major slayout"); - } - } else if (*srcSpace == pto::AddressSpace::VEC && - *dstSpace == pto::AddressSpace::VEC) { - if (!isRowMajorNoneBoxND(srcTb) || !isRowMajorNoneBoxND(dstTb)) - return emitOpError( - "expects A5 vec->vec textract src/dst to use ND layout " - "(blayout=row_major, slayout=none_box)"); - } - return success(); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} -mlir::LogicalResult mlir::pto::TInsertOp::verify() { - auto isColMajorRowMajorNZ = [&](pto::TileBufType ty) -> bool { - return ty.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor) && - ty.getSLayoutValueI32() == static_cast(pto::SLayout::RowMajor); - }; - auto isRowMajorNoneBoxND = [&](pto::TileBufType ty) -> bool { - return ty.getBLayoutValueI32() == static_cast(pto::BLayout::RowMajor) && - ty.getSLayoutValueI32() == static_cast(pto::SLayout::NoneBox); - }; - auto isA5SupportedVecElemType = [&](Type ty) -> bool { - if (auto it = dyn_cast(ty)) - return it.getWidth() == 8 || it.getWidth() == 32; - if (auto ft = dyn_cast(ty)) - return ft.getWidth() == 8 || ft.isF16() || ft.isBF16() || ft.isF32(); - return false; - }; - auto isA2A3VecInsertElemType = [&](Type ty) -> bool { - return ty.isInteger(8) || ty.isF16() || ty.isBF16() || ty.isF32(); - }; - auto verifyCommon = [&]() -> FailureOr, - 
std::optional>> { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - auto srcTb = dyn_cast(srcTy); - auto dstTb = dyn_cast(dstTy); - if (!srcTb || !dstTb) - return emitOpError("expects src and dst to be !pto.tile_buf"); - if (failed(verifyTileBufCommon(*this, srcTy, "src")) || - failed(verifyTileBufCommon(*this, dstTy, "dst")) || - failed(verifyNonNegativeIndexRowCol( - *getOperation(), getIndexRow(), getIndexCol(), - /*includeIndexAndIntOpsInConstFold=*/true)) || - failed(verifyInsertStaticBoundsCommon( - *getOperation(), getIndexRow(), getIndexCol(), srcTy, dstTy, - /*includeIndexAndIntOpsInConstFold=*/true))) - return failure(); - Type srcElem = getElemTy(srcTy); - Type dstElem = getElemTy(dstTy); - auto srcSpace = getPTOMemorySpaceEnum(srcTy); - auto dstSpace = getPTOMemorySpaceEnum(dstTy); - return std::make_tuple(srcTy, dstTy, srcTb, dstTb, srcElem, dstElem, - srcSpace, dstSpace); - }; - auto verifyA2A3 = [&]() -> LogicalResult { - auto common = verifyCommon(); - if (failed(common)) - return failure(); - auto [srcTy, dstTy, srcTb, dstTb, srcElem, dstElem, srcSpace, dstSpace] = - *common; - if (srcSpace && dstSpace && *srcSpace == pto::AddressSpace::VEC && - *dstSpace == pto::AddressSpace::VEC) { - if (srcElem != dstElem || !isA2A3VecInsertElemType(srcElem)) - return emitOpError( - "expects A2/A3 vec->vec tinsert src/dst to have same supported dtype " - "(i8/f16/bf16/f32)"); - return success(); - } - if (!srcSpace || !dstSpace || *srcSpace != pto::AddressSpace::ACC || - *dstSpace != pto::AddressSpace::MAT) - return emitOpError("expects A2/A3 tinsert to use acc->mat or vec->vec"); - - if (!isColMajorRowMajorNZ(srcTb)) - return emitOpError("expects A2/A3 tinsert src to use blayout=col_major and slayout=row_major"); - if (!isColMajorRowMajorNZ(dstTb)) - return emitOpError("expects A2/A3 tinsert dst to use blayout=col_major and slayout=row_major"); - if (dstTb.getSFractalSizeI32() != 512) - return emitOpError("expects A2/A3 tinsert dst fractal 
size to be 512"); - - if (!(srcElem.isF32() && (dstElem.isF16() || dstElem.isBF16()))) - return emitOpError("expects A2/A3 tinsert element types to be src=f32, dst=f16/bf16"); - return success(); - }; - auto verifyA5 = [&]() -> LogicalResult { - auto common = verifyCommon(); - if (failed(common)) - return failure(); - auto [srcTy, dstTy, srcTb, dstTb, srcElem, dstElem, srcSpace, dstSpace] = - *common; - if (!srcSpace || !dstSpace) - return emitOpError("expects A5 tinsert src/dst to have explicit loc"); - - // A5 regular acc->mat path. - if (*srcSpace == pto::AddressSpace::ACC && *dstSpace == pto::AddressSpace::MAT) { - if (!isColMajorRowMajorNZ(srcTb)) - return emitOpError("expects A5 acc->mat tinsert src to use blayout=col_major and slayout=row_major"); - if (!isColMajorRowMajorNZ(dstTb)) - return emitOpError("expects A5 acc->mat tinsert dst to use blayout=col_major and slayout=row_major"); - bool okTypes = (srcElem.isF32() && - (dstElem.isF16() || dstElem.isBF16() || dstElem.isF32())) || - (srcElem.isInteger(32) && dstElem.isInteger(32)); - if (!okTypes) - return emitOpError( - "expects A5 acc->mat tinsert element types to be " - "(src=f32,dst=f16/bf16/f32) or (src=i32,dst=i32)"); - return success(); - } - - // A5 vec->mat path (ND/NZ modes in pto-isa). - if (*srcSpace == pto::AddressSpace::VEC && *dstSpace == pto::AddressSpace::MAT) { - if (!isColMajorRowMajorNZ(dstTb)) - return emitOpError("expects A5 vec->mat tinsert dst to use blayout=col_major and slayout=row_major"); - bool srcIsND = isRowMajorNoneBoxND(srcTb); - bool srcIsNZ = isColMajorRowMajorNZ(srcTb); - if (!srcIsND && !srcIsNZ) - return emitOpError( - "expects A5 vec->mat tinsert src to use ND(row_major/none_box) or NZ(col_major/row_major) layout"); - if (srcElem != dstElem || !isA5SupportedVecElemType(srcElem)) - return emitOpError( - "expects A5 vec->mat tinsert src/dst to have same supported dtype " - "(fp8/f16/bf16/f32/i8/i32)"); - return success(); - } - - // A5 vec->vec path (PR561 ND_VEC). 
- if (*srcSpace == pto::AddressSpace::VEC && *dstSpace == pto::AddressSpace::VEC) { - if (!isRowMajorNoneBoxND(srcTb) || !isRowMajorNoneBoxND(dstTb)) - return emitOpError( - "expects A5 vec->vec tinsert src/dst to use ND layout " - "(blayout=row_major, slayout=none_box)"); - if (srcElem != dstElem || !isA5SupportedVecElemType(srcElem)) - return emitOpError( - "expects A5 vec->vec tinsert src/dst to have same supported dtype " - "(fp8/f16/bf16/f32/i8/i32)"); - return success(); - } - - return emitOpError( - "expects A5 tinsert to use a supported src/dst loc pair: " - "acc->mat, vec->mat, or vec->vec"); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -static bool isColMajorRowMajorNZTileBuf(pto::TileBufType ty) { - return ty.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor) && - ty.getSLayoutValueI32() == static_cast(pto::SLayout::RowMajor); -} - -static bool isA2A3VectorPreQuantTypePair(Type srcElem, Type dstElem) { - if (srcElem.isF32()) - return dstElem.isInteger(8); - if (srcElem.isInteger(32)) - return dstElem.isInteger(8) || dstElem.isF16() || dstElem.isInteger(16); - return false; -} - -static bool isA5Fp8LikeType(Type ty) { - if (auto ft = dyn_cast(ty)) - return ft.getWidth() == 8; - return false; -} - -static bool isA5MxInputType(Type ty) { - return isA5Fp8LikeType(ty); -} - -static LogicalResult verifyA5MxTypeTriple(Operation *op, Type lhsTy, Type rhsTy, - Type dstTy, StringRef lhsName, - StringRef rhsName, StringRef dstName) { - Type lhsElem = getElemTy(lhsTy); - Type rhsElem = getElemTy(rhsTy); - Type dstElem = getElemTy(dstTy); - - if (!isA5MxInputType(lhsElem) || !isA5MxInputType(rhsElem)) - return op->emitOpError() - << "expects A5 mx operands " << lhsName << " and " << rhsName - << " to use fp8 element types"; - - if (!dstElem.isF32()) - return op->emitOpError() - << "expects A5 mx result " << dstName << " to use f32 element type"; - - return success(); -} - -static bool isA5VectorPreQuantTypePair(Type 
srcElem, Type dstElem) { - if (srcElem.isF32()) - return dstElem.isInteger(8) || isA5Fp8LikeType(dstElem) || dstElem.isF16() || - dstElem.isBF16() || dstElem.isF32(); - if (srcElem.isInteger(32)) - return dstElem.isInteger(8) || dstElem.isF16() || dstElem.isBF16(); - return false; -} - -mlir::LogicalResult mlir::pto::TExtractFPOp::verify() { - auto verifyCommon = [&]() -> FailureOr> { - Type srcTy = getSrc().getType(); - Type fpTy = getFp().getType(); - Type dstTy = getDst().getType(); - auto srcTb = dyn_cast(srcTy); - auto fpTb = dyn_cast(fpTy); - auto dstTb = dyn_cast(dstTy); - if (!srcTb || !fpTb || !dstTb) - return emitOpError("expects src, fp, and dst to be !pto.tile_buf"); - if (failed(verifyTileBufCommon(*this, srcTy, "src")) || - failed(verifyTileBufCommon(*this, fpTy, "fp")) || - failed(verifyTileBufCommon(*this, dstTy, "dst")) || - failed(verifyNonNegativeIndexRowCol( - *getOperation(), getIndexRow(), getIndexCol(), - /*includeIndexAndIntOpsInConstFold=*/true)) || - failed(verifyExtractStaticBoundsCommon( - *getOperation(), getIndexRow(), getIndexCol(), srcTy, dstTy, - /*includeIndexAndIntOpsInConstFold=*/true))) - return failure(); - auto srcSpace = getPTOMemorySpaceEnum(srcTy); - auto fpSpace = getPTOMemorySpaceEnum(fpTy); - auto dstSpace = getPTOMemorySpaceEnum(dstTy); - if (!srcSpace || !fpSpace || !dstSpace) - return emitOpError("expects src, fp, and dst to have explicit loc"); - if (*srcSpace != pto::AddressSpace::ACC) - return emitOpError("expects src to use loc=acc"); - if (*fpSpace != pto::AddressSpace::SCALING) - return emitOpError("expects fp to use loc=scaling"); - if (*dstSpace != pto::AddressSpace::MAT) - return emitOpError("expects dst to use loc=mat"); - if (!isColMajorRowMajorNZTileBuf(srcTb)) - return emitOpError("expects src to use blayout=col_major and slayout=row_major"); - if (!isColMajorRowMajorNZTileBuf(dstTb)) - return emitOpError("expects dst to use blayout=col_major and slayout=row_major"); - return std::make_tuple(srcTy, fpTy, 
dstTy, srcTb, fpTb, dstTb, *srcSpace, - *fpSpace, *dstSpace); - }; - auto verifyA2A3 = [&]() -> LogicalResult { - auto common = verifyCommon(); - if (failed(common)) - return failure(); - auto [srcTy, fpTy, dstTy, srcTb, fpTb, dstTb, srcSpace, fpSpace, dstSpace] = - *common; - (void)fpTy; - (void)srcSpace; - (void)fpSpace; - (void)dstSpace; - if (dstTb.getSFractalSizeI32() != 512) - return emitOpError("expects dst fractal size to be 512"); - Type srcElem = getElemTy(srcTy); - Type dstElem = getElemTy(dstTy); - if (!isA2A3VectorPreQuantTypePair(srcElem, dstElem)) - return emitOpError( - "expects A2/A3 textract_fp element types to be (src=f32,dst=i8) " - "or (src=i32,dst=i8/f16/i16)"); - return success(); - }; - auto verifyA5 = [&]() -> LogicalResult { - auto common = verifyCommon(); - if (failed(common)) - return failure(); - auto [srcTy, fpTy, dstTy, srcTb, fpTb, dstTb, srcSpace, fpSpace, dstSpace] = - *common; - (void)fpTy; - (void)srcTb; - (void)fpTb; - (void)dstTb; - (void)srcSpace; - (void)fpSpace; - (void)dstSpace; - Type srcElem = getElemTy(srcTy); - Type dstElem = getElemTy(dstTy); - if (!isA5VectorPreQuantTypePair(srcElem, dstElem)) - return emitOpError( - "expects A5 textract_fp element types to be (src=f32,dst=i8/fp8/f16/bf16/f32) " - "or (src=i32,dst=i8/f16/bf16)"); - return success(); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TInsertFPOp::verify() { - auto verifyCommon = [&]() -> FailureOr> { - Type srcTy = getSrc().getType(); - Type fpTy = getFp().getType(); - Type dstTy = getDst().getType(); - auto srcTb = dyn_cast(srcTy); - auto fpTb = dyn_cast(fpTy); - auto dstTb = dyn_cast(dstTy); - if (!srcTb || !fpTb || !dstTb) - return emitOpError("expects src, fp, and dst to be !pto.tile_buf"); - if (failed(verifyTileBufCommon(*this, srcTy, "src")) || - failed(verifyTileBufCommon(*this, fpTy, "fp")) || - failed(verifyTileBufCommon(*this, dstTy, "dst")) || - 
failed(verifyNonNegativeIndexRowCol( - *getOperation(), getIndexRow(), getIndexCol(), - /*includeIndexAndIntOpsInConstFold=*/true)) || - failed(verifyInsertStaticBoundsCommon( - *getOperation(), getIndexRow(), getIndexCol(), srcTy, dstTy, - /*includeIndexAndIntOpsInConstFold=*/true))) - return failure(); - auto srcSpace = getPTOMemorySpaceEnum(srcTy); - auto fpSpace = getPTOMemorySpaceEnum(fpTy); - auto dstSpace = getPTOMemorySpaceEnum(dstTy); - if (!srcSpace || !fpSpace || !dstSpace) - return emitOpError("expects src, fp, and dst to have explicit loc"); - if (*srcSpace != pto::AddressSpace::ACC) - return emitOpError("expects src to use loc=acc"); - if (*fpSpace != pto::AddressSpace::SCALING) - return emitOpError("expects fp to use loc=scaling"); - if (*dstSpace != pto::AddressSpace::MAT) - return emitOpError("expects dst to use loc=mat"); - if (!isColMajorRowMajorNZTileBuf(srcTb)) - return emitOpError("expects src to use blayout=col_major and slayout=row_major"); - if (!isColMajorRowMajorNZTileBuf(dstTb)) - return emitOpError("expects dst to use blayout=col_major and slayout=row_major"); - return std::make_tuple(srcTy, fpTy, dstTy, srcTb, fpTb, dstTb, *srcSpace, - *fpSpace, *dstSpace); - }; - auto verifyA2A3 = [&]() -> LogicalResult { - auto common = verifyCommon(); - if (failed(common)) - return failure(); - auto [srcTy, fpTy, dstTy, srcTb, fpTb, dstTb, srcSpace, fpSpace, dstSpace] = - *common; - (void)fpTy; - (void)srcTb; - (void)fpTb; - (void)srcSpace; - (void)fpSpace; - (void)dstSpace; - if (dstTb.getSFractalSizeI32() != 512) - return emitOpError("expects dst fractal size to be 512"); - Type srcElem = getElemTy(srcTy); - Type dstElem = getElemTy(dstTy); - if (!isA2A3VectorPreQuantTypePair(srcElem, dstElem)) - return emitOpError( - "expects A2/A3 tinsert_fp element types to be (src=f32,dst=i8) " - "or (src=i32,dst=i8/f16/i16)"); - return success(); - }; - auto verifyA5 = [&]() -> LogicalResult { - auto common = verifyCommon(); - if (failed(common)) - return 
failure(); - auto [srcTy, fpTy, dstTy, srcTb, fpTb, dstTb, srcSpace, fpSpace, dstSpace] = - *common; - (void)fpTy; - (void)srcTb; - (void)fpTb; - (void)dstTb; - (void)srcSpace; - (void)fpSpace; - (void)dstSpace; - Type srcElem = getElemTy(srcTy); - Type dstElem = getElemTy(dstTy); - if (!isA5VectorPreQuantTypePair(srcElem, dstElem)) - return emitOpError( - "expects A5 tinsert_fp element types to be (src=f32,dst=i8/fp8/f16/bf16/f32) " - "or (src=i32,dst=i8/f16/bf16)"); - return success(); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -static mlir::LogicalResult verifyTFillPadLike(Operation *op, Type srcTy, Type dstTy, - bool allowDstExpand, - llvm::StringRef opName) { - if (!isPTOShapedLike(srcTy) || !isPTOShapedLike(dstTy)) - return op->emitError("expects src/dst to be PTO shaped-like types"); - - auto srcShape = getShapeVec(srcTy); - auto dstShape = getShapeVec(dstTy); - if (srcShape.size() != 2 || dstShape.size() != 2) - return op->emitError("expects rank-2 shaped types for src/dst"); - - auto srcElem = getElemTy(srcTy); - auto dstElem = getElemTy(dstTy); - - auto getElemBytes = [](mlir::Type t) -> int64_t { - unsigned elemBytes = getPTOStorageElemByteSize(t); - return elemBytes == 0 ? -1 : static_cast(elemBytes); - }; - - int64_t srcB = getElemBytes(srcElem); - int64_t dstB = getElemBytes(dstElem); - if (srcB < 0 || dstB < 0) - return op->emitError("unsupported element type (expects int/float element types)"); - if (srcB != dstB) - return op->emitError("expects sizeof(src element) == sizeof(dst element)"); - if (!(srcB == 1 || srcB == 2 || srcB == 4)) - return op->emitError("expects element size to be 1, 2, or 4 bytes"); - - // pto.tfillpad lowers to TFILLPAD(dst, src). For loc=mat, pto-isa only - // exposes the homogeneous overload, so src/dst must use the same Tile<...> - // specialization (including valid_shape and pad). 
- // Note: tfillpad_expand is intentionally not covered here because its - // cross-layer ABI contract for loc=mat heterogeneous shape expansion is not - // finalized yet. - if (opName == "tfillpad") { - auto srcTb = mlir::dyn_cast(srcTy); - auto dstTb = mlir::dyn_cast(dstTy); - auto srcSpace = getPTOMemorySpaceEnum(srcTy); - auto dstSpace = getPTOMemorySpaceEnum(dstTy); - if (srcTb && dstTb && srcSpace && dstSpace && - *srcSpace == mlir::pto::AddressSpace::MAT && - *dstSpace == mlir::pto::AddressSpace::MAT && srcTb != dstTb) { - auto dimToStr = [](int64_t dim) -> std::string { - return dim == ShapedType::kDynamic ? "?" : std::to_string(dim); - }; - SmallVector mismatchFields; - auto srcValid = getValidShapeVec(srcTy); - auto dstValid = getValidShapeVec(dstTy); - if (srcValid.size() == 2 && dstValid.size() == 2) { - if (srcValid[0] != dstValid[0]) - mismatchFields.push_back("v_row (" + dimToStr(srcValid[0]) + " vs " + - dimToStr(dstValid[0]) + ")"); - if (srcValid[1] != dstValid[1]) - mismatchFields.push_back("v_col (" + dimToStr(srcValid[1]) + " vs " + - dimToStr(dstValid[1]) + ")"); - } - if (srcTb.getPadValueI32() != dstTb.getPadValueI32()) - mismatchFields.push_back("pad (" + std::to_string(srcTb.getPadValueI32()) + - " vs " + std::to_string(dstTb.getPadValueI32()) + - ")"); - - auto diag = op->emitError() - << "expects src/dst tile types to be lowerable to TFILLPAD " - "for loc=mat"; - if (!mismatchFields.empty()) - diag << "; mismatching fields: " << llvm::join(mismatchFields, ", "); - diag << "\n src: " << srcTy; - diag << "\n dst: " << dstTy; - diag << "\n note: heterogeneous TFILLPAD overload is only available for loc=vec"; - return failure(); - } - } - - if (auto dstTileTy = mlir::dyn_cast(dstTy)) { - auto padAttr = mlir::dyn_cast(dstTileTy.getPadValueAttr()); - if (!padAttr || padAttr.getValue() == mlir::pto::PadValue::Null) - return op->emitError() << "expects dst PadVal != Null for " << opName; - } - - if (!allowDstExpand) { - if (srcShape != dstShape) 
- return op->emitError() - << "expects src and dst to have the same static shape for " << opName; - return mlir::success(); - } - - if (srcShape[0] > dstShape[0] || srcShape[1] > dstShape[1]) { - return op->emitError() - << "expects dst static shape to be >= src static shape for " << opName; - } - - return mlir::success(); -} - -mlir::LogicalResult mlir::pto::TFillPadOp::verify() { - return verifyTFillPadLike(getOperation(), getSrc().getType(), getDst().getType(), - /*allowDstExpand=*/false, "tfillpad"); -} - -mlir::LogicalResult mlir::pto::TFillPadExpandOp::verify() { - return verifyTFillPadLike(getOperation(), getSrc().getType(), getDst().getType(), - /*allowDstExpand=*/true, "tfillpad_expand"); -} - -mlir::LogicalResult mlir::pto::TFillPadInplaceOp::verify() { - return verifyTFillPadLike(getOperation(), getSrc().getType(), getDst().getType(), - /*allowDstExpand=*/false, "tfillpad_inplace"); -} - - -llvm::LogicalResult mlir::pto::TGatherOp::verify() { - auto isSupportedGatherElemTypeA5Index = [&](Type ty) -> bool { - if (ty.isF16() || ty.isF32()) - return true; - if (auto it = dyn_cast(ty)) { - unsigned width = it.getWidth(); - return width == 8 || width == 16 || width == 32; - } - return false; - }; - - auto verifyMaskForm = [&](bool allowA5MaskTypes) -> LogicalResult { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyTileBufCommon(*this, srcTy, "src")) || - failed(verifyTileBufCommon(*this, dstTy, "dst"))) - return failure(); - - Type srcElem = getElemTy(srcTy); - Type dstElem = getElemTy(dstTy); - if (!srcElem || !dstElem) - return emitOpError("failed to get element type for src/dst"); - if (!isRowMajorTileBuf(srcTy) || !isRowMajorTileBuf(dstTy)) - return emitOpError("expects src and dst to use row-major layout"); - auto srcSpace = getPTOMemorySpaceEnum(srcTy); - auto dstSpace = getPTOMemorySpaceEnum(dstTy); - if (!srcSpace || !dstSpace || *srcSpace != pto::AddressSpace::VEC || - *dstSpace != pto::AddressSpace::VEC) - 
return emitOpError("expects src and dst to be in the vec address space"); - unsigned srcElemBytes = getPTOStorageElemByteSize(srcElem); - unsigned dstElemBytes = getPTOStorageElemByteSize(dstElem); - if (srcElemBytes == 0 || dstElemBytes == 0) - return emitOpError("failed to get element size for src/dst"); - if (srcElemBytes != dstElemBytes) - return emitOpError("expects src and dst element sizes to match"); - - auto dstValid = getValidShapeVec(dstTy); - auto dstShape = getShapeVec(dstTy); - if (dstValid.size() == 2 && dstShape.size() == 2 && - dstValid[1] != ShapedType::kDynamic && dstShape[1] != ShapedType::kDynamic && - dstValid[1] != dstShape[1]) { - return emitOpError("expects dst valid_shape[1] to equal dst cols"); - } - - if (allowA5MaskTypes) { - if (!(srcElemBytes == 1 || srcElemBytes == 2 || srcElemBytes == 4)) - return emitOpError("expects A5 mask-pattern gather element size to be 1, 2, or 4 bytes"); - if (!isSupportedGatherElemTypeA5(srcElem) || !isSupportedGatherElemTypeA5(dstElem)) - return emitOpError( - "expects A5 mask-pattern gather src/dst element type to be i8/i16/i32/f16/bf16/f32/fp8-like"); - } else { - if (!(srcElemBytes == 2 || srcElemBytes == 4)) - return emitOpError("expects A2/A3 mask-pattern gather element size to be 2 or 4 bytes"); - } - return success(); - }; - - auto verifyIndexForm = [&](bool allow16BitIndices, bool allowA5ElemTypes) -> LogicalResult { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - Type idxTy = getIndices().getType(); - Type tmpTy = getTmp().getType(); - if (failed(verifyTileBufCommon(*this, srcTy, "src")) || - failed(verifyTileBufCommon(*this, dstTy, "dst")) || - failed(verifyTileBufCommon(*this, idxTy, "indices")) || - failed(verifyTileBufCommon(*this, tmpTy, "tmp"))) - return failure(); - - Type srcElem = getElemTy(srcTy); - Type dstElem = getElemTy(dstTy); - if (!srcElem || !dstElem) - return emitOpError("failed to get element type for src/dst"); - if (srcElem != dstElem) - return 
emitOpError("expects src and dst to have the same element type"); - if (allowA5ElemTypes) { - if (!isSupportedGatherElemTypeA5Index(srcElem) || - !isSupportedGatherElemTypeA5Index(dstElem)) - return emitOpError( - "expects A5 gather src/dst element type to be i8/i16/i32/f16/f32"); - } else if (!isSupportedGatherElemTypeA2A3(srcElem) || - !isSupportedGatherElemTypeA2A3(dstElem)) { - return emitOpError("expects gather src/dst element type to be i16/i32/f16/f32"); - } - - auto idxElem = dyn_cast(getElemTy(idxTy)); - if (!idxElem) - return emitOpError("indices element type must be integer"); - unsigned width = idxElem.getWidth(); - if (!(width == 32 || (allow16BitIndices && width == 16))) { - return emitOpError() << "expects indices element type to be i32" - << (allow16BitIndices ? " or i16" : ""); - } - - auto dstValid = getValidShapeVec(dstTy); - auto dstShape = getShapeVec(dstTy); - if (dstValid.size() == 2 && dstShape.size() == 2 && - dstValid[1] != ShapedType::kDynamic && dstShape[1] != ShapedType::kDynamic && - dstValid[1] != dstShape[1]) { - return emitOpError("expects dst valid_shape[1] to equal dst cols"); - } - - auto idxValid = getValidShapeVec(idxTy); - auto idxShape = getShapeVec(idxTy); - if (idxValid.size() == 2 && idxShape.size() == 2 && - idxValid[1] != ShapedType::kDynamic && idxShape[1] != ShapedType::kDynamic && - idxValid[1] != idxShape[1]) { - return emitOpError("expects indices valid_shape[1] to equal indices cols"); - } - - if (!allowA5ElemTypes) { - Type tmpElem = getElemTy(tmpTy); - if (tmpElem != idxElem) - return emitOpError("expects tmp and indices to have the same element type"); - if (failed(verifyTileBufSameValidShape(*this, idxTy, tmpTy, "indices", "tmp"))) - return failure(); - } - return success(); - }; - - auto verifyCompareForm = [&](bool allowA5SrcTypes) -> LogicalResult { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - Type cdstTy = getCdst().getType(); - Type tmpTy = getTmp().getType(); - if 
(failed(verifyTileBufCommon(*this, srcTy, "src")) || - failed(verifyTileBufCommon(*this, dstTy, "dst")) || - failed(verifyTileBufCommon(*this, cdstTy, "cdst")) || - failed(verifyTileBufCommon(*this, tmpTy, "tmp"))) - return failure(); - - Type srcElem = getElemTy(srcTy); - Type dstElem = getElemTy(dstTy); - Type cdstElem = getElemTy(cdstTy); - if (!srcElem || !dstElem || !cdstElem) - return emitOpError("failed to get element type for src/dst/cdst"); - auto dstInt = dyn_cast(dstElem); - if (!dstInt || dstInt.getWidth() != 32) - return emitOpError("expects dst element type to be i32"); - if (cdstElem != dstElem) - return emitOpError("expects cdst to have the same element type as dst"); - if (getKValue().getType() != srcElem) - return emitOpError("expects kValue to have the same type as src element type"); - - auto cmpAttr = getCmpModeAttr(); - auto cmpMode = cmpAttr ? cmpAttr.getValue() : pto::CmpMode::EQ; - if (cmpMode != pto::CmpMode::EQ && cmpMode != pto::CmpMode::GT) - return emitOpError("expects compare-form tgather cmpMode to be eq or gt"); - - if (allowA5SrcTypes) { - if (!(srcElem.isF16() || srcElem.isF32() || srcElem.isInteger(16) || - srcElem.isInteger(32))) { - return emitOpError( - "expects A5 compare-form tgather src element type to be i16/i32/f16/f32"); - } - } else { - if (!(srcElem.isF16() || srcElem.isF32() || - (srcElem.isInteger(32) && cmpMode == pto::CmpMode::EQ))) { - return emitOpError( - "expects A2/A3 compare-form tgather src element type to be f16/f32, or i32 when cmpMode=eq"); - } - } - - if (failed(verifyVecTileCommonA2A3(*this, srcTy, "src")) || - failed(verifyVecTileCommonA2A3(*this, dstTy, "dst")) || - failed(verifyVecTileCommonA2A3(*this, cdstTy, "cdst")) || - failed(verifyVecTileCommonA2A3(*this, tmpTy, "tmp"))) - return failure(); - return success(); - }; - - auto verifyA2A3 = [&]() -> LogicalResult { - if (getMaskPatternAttr()) { - if (getCdst() || getIndices() || getTmp() || getKValue()) - return emitOpError("mask-pattern tgather 
only allows src and dst operands"); - return verifyMaskForm(/*allowA5MaskTypes=*/false); - } - if (getCdst() || getKValue()) { - if (!getCdst() || !getKValue() || !getTmp()) - return emitOpError("compare-form tgather expects dst, cdst, kValue, and tmp"); - if (getIndices()) - return emitOpError("compare-form tgather does not take indices"); - return verifyCompareForm(/*allowA5SrcTypes=*/false); - } - if (!getIndices() || !getTmp()) - return emitOpError("index-form tgather expects both indices and tmp"); - return verifyIndexForm(/*allow16BitIndices=*/false, /*allowA5ElemTypes=*/false); - }; - - auto verifyA5 = [&]() -> LogicalResult { - if (getMaskPatternAttr()) { - if (getCdst() || getIndices() || getTmp() || getKValue()) - return emitOpError("mask-pattern tgather only allows src and dst operands"); - return verifyMaskForm(/*allowA5MaskTypes=*/true); - } - if (getCdst() || getKValue()) { - if (!getCdst() || !getKValue() || !getTmp()) - return emitOpError("compare-form tgather expects dst, cdst, kValue, and tmp"); - if (getIndices()) - return emitOpError("compare-form tgather does not take indices"); - return verifyCompareForm(/*allowA5SrcTypes=*/true); - } - if (!getIndices() || !getTmp()) - return emitOpError("index-form tgather expects both indices and tmp"); - return verifyIndexForm(/*allow16BitIndices=*/true, /*allowA5ElemTypes=*/true); - }; - - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} -mlir::LogicalResult mlir::pto::TGatherBOp::verify() { - auto verifyCommon = [&]() -> FailureOr> { - Type srcTy = getSrc().getType(); - Type offTy = getOffsets().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyTileBufCommon(*this, srcTy, "src")) || - failed(verifyTileBufCommon(*this, offTy, "offsets")) || - failed(verifyTileBufCommon(*this, dstTy, "dst"))) - return failure(); - auto srcElemTy = getElemTy(srcTy); - auto dstElemTy = getElemTy(dstTy); - if (!srcElemTy || !dstElemTy) - return emitOpError() << "failed to get element type 
for src/dst"; - return std::make_pair(srcElemTy, dstElemTy); - }; - - auto getElemBytes = [](Type ty) -> std::optional { - unsigned elemBytes = getPTOStorageElemByteSize(ty); - if (elemBytes == 0) - return std::nullopt; - return elemBytes; - }; - - auto verifyA2A3 = [&]() -> LogicalResult { - FailureOr> elems = verifyCommon(); - if (failed(elems)) - return failure(); - Type dstTy = getDst().getType(); - Type dstElemTy = elems->second; - if (!isRowMajorTileBuf(dstTy)) - return emitOpError() << "expects dst to use row-major layout"; - auto dstBytes = getElemBytes(dstElemTy); - if (!dstBytes || (*dstBytes != 1 && *dstBytes != 2 && *dstBytes != 4)) - return emitOpError() << "expects dst element size to be 1, 2, or 4 bytes"; - return mlir::success(); - }; - - auto verifyA5 = [&]() -> LogicalResult { - FailureOr> elems = verifyCommon(); - if (failed(elems)) - return failure(); - Type dstElemTy = elems->second; - auto dstBytes = getElemBytes(dstElemTy); - if (!dstBytes || (*dstBytes != 1 && *dstBytes != 2 && *dstBytes != 4)) - return emitOpError() << "expects dst element size to be 1, 2, or 4 bytes"; - return mlir::success(); - }; - - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TLogOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyVecTileUnaryOp(*this, srcTy, dstTy, "src", "dst", - /*allowBf16=*/false, /*allowInt8=*/false))) - return failure(); - if (failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) - return failure(); - auto elemTy = getElemTy(srcTy); - if (!(elemTy.isF16() || elemTy.isF32())) - return emitOpError() << "expects element type to be f16 or f32"; - return mlir::success(); -} - -mlir::LogicalResult mlir::pto::TLReluOp::verify() { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - auto verifyA2A3 = [&]() -> LogicalResult { - if 
(failed(verifyVecTileStorage(*this, srcTy, "src")) || - failed(verifyVecTileStorage(*this, dstTy, "dst"))) - return failure(); - if (failed(verifyTileBufSameElemType(*this, srcTy, dstTy, "src", "dst")) || - failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) - return failure(); - auto valid = getValidShapeVec(srcTy); - if (valid.size() != 2) - return emitOpError("expects src to have rank-2 valid_shape"); - if (valid[0] != ShapedType::kDynamic && valid[0] <= 0) - return emitOpError("expects src valid_shape[0] to be positive"); - if (valid[1] != ShapedType::kDynamic && valid[1] <= 0) - return emitOpError("expects src valid_shape[1] to be positive"); - Type elemTy = getElemTy(srcTy); - if (!(elemTy.isF16() || elemTy.isF32())) - return emitOpError() << "expects A2/A3 tlrelu element type to be f16 or f32"; - return success(); - }; - auto verifyA5 = [&]() -> LogicalResult { - if (failed(verifyVecTileStorage(*this, srcTy, "src")) || - failed(verifyVecTileStorage(*this, dstTy, "dst"))) - return failure(); - if (failed(verifyTileBufSameElemType(*this, srcTy, dstTy, "src", "dst")) || - failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) - return failure(); - Type elemTy = getElemTy(srcTy); - if (!(elemTy.isF16() || elemTy.isF32())) - return emitOpError() << "expects A5 tlrelu element type to be f16 or f32"; - if (!getSlope().getType().isF32()) - return emitOpError() << "expects slope to have type f32"; - return success(); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TMaxOp::verify() { - return verifyArithmeticBinaryTileOpWithArchDispatch( - getOperation(), getSrc0().getType(), getSrc1().getType(), getDst().getType(), - /*allowInt8OnA5=*/true, /*allowBf16OnA5=*/false, - "expects A2/A3 tmax element type to be i32/i16/f16/f32", - "expects A5 tmax element type to be i32/i16/i8/f16/f32"); -} - -mlir::LogicalResult mlir::pto::TMaxSOp::verify() { - return 
verifyArithmeticScalarTileOpWithArchDispatch( - getOperation(), getSrc().getType(), getDst().getType(), getScalar().getType(), - /*allowInt8OnA5=*/true, /*allowBf16OnA5=*/true, - "expects A2/A3 tmaxs element type to be i32/i16/f16/f32", - "expects A5 tmaxs element type to be i32/i16/i8/f16/bf16/f32", - /*requireValidRowsEqualOnA2A3=*/true, - /*requireValidRowsEqualOnA5=*/true); -} - -mlir::LogicalResult mlir::pto::TMinOp::verify() { - return verifyArithmeticBinaryTileOpWithArchDispatch( - getOperation(), getSrc0().getType(), getSrc1().getType(), getDst().getType(), - /*allowInt8OnA5=*/true, /*allowBf16OnA5=*/true, - "expects A2/A3 tmin element type to be i32/i16/f16/f32", - "expects A5 tmin element type to be i32/i16/i8/f16/bf16/f32"); -} - -mlir::LogicalResult mlir::pto::TMinSOp::verify() { - return verifyArithmeticScalarTileOpWithArchDispatch( - getOperation(), getSrc().getType(), getDst().getType(), getScalar().getType(), - /*allowInt8OnA5=*/true, /*allowBf16OnA5=*/true, - "expects A2/A3 tmins element type to be i32/i16/f16/f32", - "expects A5 tmins element type to be i32/i16/i8/f16/bf16/f32", - /*requireValidRowsEqualOnA2A3=*/true, - /*requireValidRowsEqualOnA5=*/false); -} - -mlir::LogicalResult mlir::pto::TMovOp::verify() { - auto verifyImpl = [&](bool isA5) -> LogicalResult { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - Value fp = getFp(); - Value preQuantScalar = getPreQuantScalar(); - auto accToVecModeAttr = getAccToVecModeAttr(); - auto reluMode = getReluPreMode(); - const bool hasFp = static_cast(fp); - const bool hasPreQuantScalar = static_cast(preQuantScalar); - - if (failed(verifyTileBufCommon(*this, srcTy, "src")) || - failed(verifyTileBufCommon(*this, dstTy, "dst"))) - return failure(); - if (hasFp && failed(verifyTileBufCommon(*this, fp.getType(), "fp"))) - return failure(); - if (hasFp && hasPreQuantScalar) - return emitOpError() << "expects fp and preQuantScalar forms to be mutually exclusive"; - - auto srcSpace = 
getPTOMemorySpaceEnum(srcTy); - auto dstSpace = getPTOMemorySpaceEnum(dstTy); - if (!srcSpace || !dstSpace) - return emitOpError() << "expects src and dst to have explicit address spaces"; - - auto srcShape = getShapeVec(srcTy); - auto dstShape = getShapeVec(dstTy); - if (*srcSpace == pto::AddressSpace::MAT && srcShape != dstShape) - return emitOpError() << "expects mat-source tmov to use matching src/dst shapes"; - if (!isA5 && *srcSpace != pto::AddressSpace::MAT && srcShape != dstShape) - return emitOpError() << "expects A2/A3 non-mat tmov to use matching src/dst shapes"; - - const bool isMatToTile = - *srcSpace == pto::AddressSpace::MAT && - (*dstSpace == pto::AddressSpace::LEFT || - *dstSpace == pto::AddressSpace::RIGHT || - *dstSpace == pto::AddressSpace::BIAS || - *dstSpace == pto::AddressSpace::SCALING); - const bool isVecToVec = - *srcSpace == pto::AddressSpace::VEC && - *dstSpace == pto::AddressSpace::VEC; - const bool isVecToMat = - *srcSpace == pto::AddressSpace::VEC && - *dstSpace == pto::AddressSpace::MAT; - const bool isAccToMat = - *srcSpace == pto::AddressSpace::ACC && - *dstSpace == pto::AddressSpace::MAT; - const bool isAccToVec = - *srcSpace == pto::AddressSpace::ACC && - *dstSpace == pto::AddressSpace::VEC; - - bool okPair = isMatToTile || isVecToVec || isAccToMat || isAccToVec; - if (isA5) - okPair = okPair || isVecToMat; - if (!okPair) - return emitOpError() - << "expects a supported tmov address-space pair for this target"; - - if (accToVecModeAttr && !isAccToVec) - return emitOpError() - << "expects accToVecMode to be used only for acc-to-vec tmov"; - - if (reluMode != pto::ReluPreMode::NoRelu && !(isAccToMat || isAccToVec)) - return emitOpError() - << "expects reluPreMode form to use loc=acc src"; - - if (hasPreQuantScalar && !(isAccToMat || isAccToVec)) - return emitOpError() - << "expects preQuantScalar form to use loc=acc src"; - - if (hasFp) { - auto fpTy = fp.getType(); - auto fpSpace = getPTOMemorySpaceEnum(fpTy); - if (!fpSpace || 
*fpSpace != pto::AddressSpace::SCALING) - return emitOpError() << "expects fp to be in the scaling address space"; - auto srcElemTy = getElemTy(srcTy); - auto srcIntTy = dyn_cast(srcElemTy); - if (!(srcElemTy.isF32() || (srcIntTy && srcIntTy.getWidth() == 32))) - return emitOpError() - << "expects fp form src to have element type f32, i32"; - if (!(isAccToMat || isAccToVec)) - return emitOpError() << "expects fp form to use loc=acc src"; - } - - if ((hasFp || hasPreQuantScalar) && accToVecModeAttr) { - switch (accToVecModeAttr.getValue()) { - case pto::AccToVecMode::SingleModeVec0: - case pto::AccToVecMode::SingleModeVec1: - break; - case pto::AccToVecMode::DualModeSplitM: - case pto::AccToVecMode::DualModeSplitN: - return emitOpError() - << "expects fp/preQuantScalar acc-to-vec forms to use single-mode accToVecMode"; - } - } - - auto srcTb = dyn_cast(srcTy); - auto dstTb = dyn_cast(dstTy); - if (srcTb && *srcSpace == pto::AddressSpace::ACC && - (hasFp || reluMode != pto::ReluPreMode::NoRelu)) { - if (srcTb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor) || - srcTb.getSLayoutValueI32() != static_cast(pto::SLayout::RowMajor)) - return emitOpError() - << "expects acc-source fp/relu tmov src to use blayout=col_major and slayout=row_major"; - } - if (srcTb && dstTb && isAccToMat && !isA5 && - dstTb.getSFractalSizeI32() != 512) - return emitOpError() << "expects A2/A3 acc-to-mat tmov destination fractal to be 512"; - - return success(); - }; - auto verifyA2A3 = [&]() -> LogicalResult { return verifyImpl(/*isA5=*/false); }; - auto verifyA5 = [&]() -> LogicalResult { return verifyImpl(/*isA5=*/true); }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TMovFPOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - Type srcTy = getSrc().getType(); - Type fpTy = getFp().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyTileBufCommon(*this, srcTy, "src")) || - 
failed(verifyTileBufCommon(*this, fpTy, "fp")) || - failed(verifyTileBufCommon(*this, dstTy, "dst"))) - return failure(); - auto srcElemTy = getElemTy(srcTy); - auto srcIntTy = dyn_cast(srcElemTy); - if (!(srcElemTy.isF32() || - (srcIntTy && srcIntTy.getWidth() == 32))) - return emitOpError() - << "expects src to have element type f32, i32"; - auto fpSpace = getPTOMemorySpaceEnum(fpTy); - if (!fpSpace || *fpSpace != mlir::pto::AddressSpace::SCALING) - return emitOpError() << "expects fp to be in the scaling address space"; - auto srcSpace = getPTOMemorySpaceEnum(srcTy); - if (!srcSpace || *srcSpace != mlir::pto::AddressSpace::ACC) - return emitOpError() << "expects src to be in the acc address space"; - auto dstSpace = getPTOMemorySpaceEnum(dstTy); - if (!dstSpace || *dstSpace != mlir::pto::AddressSpace::MAT) - return emitOpError() << "expects dst to be in the mat address space"; - auto srcTb = dyn_cast(srcTy); - auto dstTb = dyn_cast(dstTy); - if (srcTb && - (srcTb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor) || - srcTb.getSLayoutValueI32() != static_cast(pto::SLayout::RowMajor))) - return emitOpError() - << "expects src to use blayout=col_major and slayout=row_major"; - if (dstTb && - (dstTb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor) || - dstTb.getSLayoutValueI32() != static_cast(pto::SLayout::RowMajor))) - return emitOpError() - << "expects dst to use blayout=col_major and slayout=row_major"; - if (dstTb && dstTb.getSFractalSizeI32() != 512) - return emitOpError() << "expects dst to use fractal size 512"; - return mlir::success(); - }; - auto verifyA5 = [&]() -> LogicalResult { - Type srcTy = getSrc().getType(); - Type fpTy = getFp().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyTileBufCommon(*this, srcTy, "src")) || - failed(verifyTileBufCommon(*this, fpTy, "fp")) || - failed(verifyTileBufCommon(*this, dstTy, "dst"))) - return failure(); - auto srcElemTy = getElemTy(srcTy); - auto srcIntTy = 
dyn_cast(srcElemTy); - if (!(srcElemTy.isF32() || - (srcIntTy && srcIntTy.getWidth() == 32))) - return emitOpError() - << "expects src to have element type f32, i32"; - auto fpSpace = getPTOMemorySpaceEnum(fpTy); - if (!fpSpace || *fpSpace != mlir::pto::AddressSpace::SCALING) - return emitOpError() << "expects fp to be in the scaling address space"; - auto srcTb = dyn_cast(srcTy); - if (srcTb && - (srcTb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor) || - srcTb.getSLayoutValueI32() != static_cast(pto::SLayout::RowMajor))) - return emitOpError() - << "expects src to use blayout=col_major and slayout=row_major"; - return success(); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} -// 辅助函数:获取 Rank,支持 ShapedType 和 PTO TileTypes -static int64_t getRankHelper(Type t) { - if (auto s = dyn_cast(t)) return s.getRank(); - if (auto tile = dyn_cast(t)) return tile.getRank(); - if (auto view = dyn_cast(t)) return view.getRank(); - return -1; -} - -static LogicalResult verifyMatmulLike(Operation *op, Type aTy, Type bTy, Type dstTy, bool checkRank = true) { - // 1. 
检查类型 (ShapedType 或 Tile 类型) - bool aValid = isa(aTy); - bool bValid = isa(bTy); - bool dValid = isa(dstTy); - - if (!aValid || !bValid || !dValid) - return op->emitOpError("expects inputs/outputs to be shaped types or PTO tile types"); - - if (checkRank) { - int64_t aRank = getRankHelper(aTy); - int64_t bRank = getRankHelper(bTy); - int64_t dRank = getRankHelper(dstTy); - - // 检查 Rank 一致性 - if (aRank != -1 && dRank != -1 && aRank != dRank) - return op->emitOpError("expects a and dst to have the same rank"); - if (bRank != -1 && dRank != -1 && bRank != dRank) - return op->emitOpError("expects b and dst to have the same rank"); - } - - return success(); -} - -// ---- LoadScalarOp ---- -LogicalResult LoadScalarOp::verify() { - Type ptrTy = getPtr().getType(); - Type elemTy; - if (auto pty = dyn_cast(ptrTy)) { - elemTy = pty.getElementType(); - } else if (auto memTy = dyn_cast(ptrTy)) { - elemTy = memTy.getElementType(); - if (!isGmAddressSpaceAttr(memTy.getMemorySpace())) - return emitOpError() << "scalar load only supports GM address space pointers"; - } else { - return emitOpError("expects ptr to be !pto.ptr or memref type"); - } - - if (getValue().getType() != elemTy) - return emitOpError("expects result type to match ptr element type"); - - return success(); -} -// ---- StoreScalarOp ---- -LogicalResult StoreScalarOp::verify() { - Type ptrTy = getPtr().getType(); - Type elemTy; - if (auto pty = dyn_cast(ptrTy)) { - elemTy = pty.getElementType(); - } else if (auto memTy = dyn_cast(ptrTy)) { - elemTy = memTy.getElementType(); - if (!isGmAddressSpaceAttr(memTy.getMemorySpace())) - return emitOpError() << "scalar store only supports GM address space pointers"; - } else { - return emitOpError("expects ptr to be !pto.ptr or memref type"); - } - - if (getValue().getType() != elemTy) - return emitOpError("expects value type to match ptr element type"); - - return success(); -} - -// ---- GetBufOp / RlsBufOp ---- -static LogicalResult verifyBufSyncOp(Operation *op, 
Attribute opTypeAttr, - IntegerAttr bufIdAttr, IntegerAttr modeAttr) { - if (!opTypeAttr) - return op->emitOpError("expects 'op_type' attribute"); - - auto opTypeOr = parseSyncOpTypeLikeAttr(opTypeAttr); - if (failed(opTypeOr)) { - auto diag = - op->emitOpError("expects 'op_type' to be pipe_event_type/sync_op_type, got "); - diag << opTypeAttr; - return failure(); - } - pto::PIPE pipe = mapSyncOpTypeToPipe(*opTypeOr); - if (!isConcreteSyncPipe(pipe)) - return op->emitOpError("expects 'op_type' to map to a concrete pipe, not PIPE_ALL/PIPE_UNASSIGNED"); - - if (!bufIdAttr) - return op->emitOpError("expects 'buf_id' attribute"); - int64_t bufId = bufIdAttr.getInt(); - if (bufId < 0 || bufId > 31) - return op->emitOpError("expects 'buf_id' in range [0, 31]"); - - if (modeAttr) { - int64_t mode = modeAttr.getInt(); - if (mode < 0) - return op->emitOpError("expects 'mode' to be non-negative"); - } - - return success(); -} - -LogicalResult GetBufOp::verify() { - return verifyBufSyncOp(getOperation(), getOpTypeAttr(), getBufIdAttr(), - getModeAttr()); -} - -LogicalResult RlsBufOp::verify() { - return verifyBufSyncOp(getOperation(), getOpTypeAttr(), getBufIdAttr(), - getModeAttr()); -} -// ---- TOp ---- -LogicalResult TGemvBiasOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - if (failed(verifyGemvTileOperands(*this, getA().getType(), getB().getType(), - getDst().getType())) || - failed(verifyMatBiasTile(*this, getBias().getType(), getDst().getType()))) - return failure(); - if (failed(verifyMatmulTypeTriple(*this, getElemTy(getA().getType()), - getElemTy(getB().getType()), - getElemTy(getDst().getType())))) - return failure(); - return verifyMatmulLike(*this, getA().getType(), getB().getType(), - getDst().getType()); - }; - auto verifyA5 = [&]() -> LogicalResult { return verifyA2A3(); }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -LogicalResult TGemvMxOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - return 
emitOpError("tgemv.mx is only supported on A5 targets"); - }; - auto verifyA5 = [&]() -> LogicalResult { - if (failed(verifyScaleTileMatchesOperand(*this, getAScale().getType(), - getA().getType(), "a_scale", "a")) || - failed(verifyScaleTileMatchesOperand(*this, getBScale().getType(), - getB().getType(), "b_scale", "b")) || - failed(verifyGemvTileOperands(*this, getA().getType(), getB().getType(), - getDst().getType()))) - return failure(); - if (failed(verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), - getDst().getType(), "a", "b", "dst"))) - return failure(); - return verifyMatmulLike(*this, getA().getType(), getB().getType(), - getDst().getType()); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -LogicalResult TGemvMxAccOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - return emitOpError("tgemv.mx.acc is only supported on A5 targets"); - }; - auto verifyA5 = [&]() -> LogicalResult { - if (failed(verifyAccTileCommon(*this, getCIn().getType(), "c_in")) || - failed(verifyScaleTileMatchesOperand(*this, getAScale().getType(), - getA().getType(), "a_scale", "a")) || - failed(verifyScaleTileMatchesOperand(*this, getBScale().getType(), - getB().getType(), "b_scale", "b")) || - failed(verifyGemvTileOperands(*this, getA().getType(), getB().getType(), - getDst().getType()))) - return failure(); - if (failed(verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), - getDst().getType(), "a", "b", "dst"))) - return failure(); - if (failed(verifyTileBufSameElemType(*this, getCIn().getType(), - getDst().getType(), "c_in", "dst")) || - failed(verifyTileBufSameValidShape(*this, getCIn().getType(), - getDst().getType(), "c_in", "dst"))) - return failure(); - return verifyMatmulLike(*this, getA().getType(), getB().getType(), - getDst().getType()); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -LogicalResult TGemvMxBiasOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - 
return emitOpError("tgemv.mx.bias is only supported on A5 targets"); - }; - auto verifyA5 = [&]() -> LogicalResult { - if (failed(verifyScaleTileMatchesOperand(*this, getAScale().getType(), - getA().getType(), "a_scale", "a")) || - failed(verifyScaleTileMatchesOperand(*this, getBScale().getType(), - getB().getType(), "b_scale", "b")) || - failed(verifyGemvTileOperands(*this, getA().getType(), getB().getType(), - getDst().getType())) || - failed(verifyMatBiasTile(*this, getBias().getType(), getDst().getType(), - /*requireFloatBias=*/true))) - return failure(); - if (failed(verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), - getDst().getType(), "a", "b", "dst"))) - return failure(); - auto biasShape = getShapeVec(getBias().getType()); - auto dstShape = getShapeVec(getDst().getType()); - if (biasShape.size() != 2 || dstShape.size() != 2) - return emitOpError("expects bias and dst to be rank-2 for tgemv.mx.bias"); - if (biasShape[1] != ShapedType::kDynamic && dstShape[1] != ShapedType::kDynamic && - biasShape[1] != dstShape[1]) - return emitOpError("expects bias and dst to have the same column shape"); - if (failed(verifyTileBufSameValidShape(*this, getBias().getType(), - getDst().getType(), "bias", "dst"))) - return failure(); - return verifyMatmulLike(*this, getA().getType(), getB().getType(), - getDst().getType()); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -LogicalResult TMatmulBiasOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - if (failed(verifyMatTileOperands(*this, getA().getType(), getB().getType(), - getDst().getType())) || - failed(verifyMatBiasTile(*this, getBias().getType(), getDst().getType()))) - return failure(); - if (failed(verifyMatmulTypeTriple(*this, getElemTy(getA().getType()), - getElemTy(getB().getType()), - getElemTy(getDst().getType())))) - return failure(); - return verifyMatmulLike(*this, getA().getType(), getB().getType(), - getDst().getType()); - }; - auto verifyA5 = [&]() -> 
LogicalResult { return verifyA2A3(); }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -LogicalResult TMatmulMxOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - if (failed(verifyTileBufCommon(*this, getAScale().getType(), "a_scale")) || - failed(verifyTileBufCommon(*this, getBScale().getType(), "b_scale"))) - return failure(); - return verifyMatmulLike(*this, getA().getType(), getB().getType(), - getDst().getType()); - }; - auto verifyA5 = [&]() -> LogicalResult { - if (failed(verifyA2A3())) - return failure(); - return verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), - getDst().getType(), "a", "b", "dst"); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -LogicalResult TMatmulMxAccOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - if (failed(verifyAccTileCommon(*this, getCIn().getType(), "c_in")) || - failed(verifyTileBufCommon(*this, getAScale().getType(), "a_scale")) || - failed(verifyTileBufCommon(*this, getBScale().getType(), "b_scale"))) - return failure(); - return success(); - }; - auto verifyA5 = [&]() -> LogicalResult { - if (failed(verifyA2A3())) - return failure(); - if (failed(verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), - getDst().getType(), "a", "b", "dst"))) - return failure(); - if (failed(verifyTileBufSameElemType(*this, getCIn().getType(), - getDst().getType(), "c_in", "dst")) || - failed(verifyTileBufSameValidShape(*this, getCIn().getType(), - getDst().getType(), "c_in", "dst"))) - return failure(); - return success(); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} -LogicalResult TMatmulMxBiasOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - if (failed(verifyTileBufCommon(*this, getAScale().getType(), "a_scale")) || - failed(verifyTileBufCommon(*this, getBScale().getType(), "b_scale")) || - failed(verifyMatTileOperands(*this, getA().getType(), getB().getType(), - getDst().getType())) || 
- failed(verifyMatBiasTile(*this, getBias().getType(), getDst().getType(), - /*requireFloatBias=*/true))) - return failure(); - return verifyMatmulLike(*this, getA().getType(), getB().getType(), - getDst().getType()); - }; - auto verifyA5 = [&]() -> LogicalResult { - if (failed(verifyA2A3())) - return failure(); - return verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), - getDst().getType(), "a", "b", "dst"); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} -// ---- TSetValOp ---- -LogicalResult TSetValOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - // dst can be tile/tensor/tilebuf (PTODpsType). Keep checks minimal. - if (auto shaped = dyn_cast(getDst().getType())) { - if (shaped.getElementType() != getVal().getType()) - return emitOpError("expects val type to match dst element type"); - } - return success(); -} -// ---- TGetValOp ---- -LogicalResult TGetValOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - Type srcTy = getSrc().getType(); - if (!mlir::isa(srcTy)) - return emitOpError("expects src to be tile_buf or memref type"); - - // Memory space must be vec (Ascend does not support getval from MAT etc.). - Attribute memSpace = - isa(srcTy) - ? 
cast(srcTy).getMemorySpace() - : cast(srcTy).getMemorySpace(); - auto addrSpaceAttr = dyn_cast_or_null(memSpace); - if (!addrSpaceAttr || - addrSpaceAttr.getAddressSpace() != pto::AddressSpace::VEC) { - if (addrSpaceAttr && - addrSpaceAttr.getAddressSpace() == pto::AddressSpace::MAT) - return emitOpError( - "Ascend hardware does not support reading from Mat tile_buf to Scalar unit"); - return emitOpError("expects src memory space to be vec"); - } - - if (getElemTy(srcTy) != getDst().getType()) - return emitOpError("expects dst type to match src element type"); - return success(); -} - -LogicalResult THistogramOp::verify() { - auto isIntegerWidth = [](Type ty, unsigned width) { - auto it = dyn_cast(ty); - return it && it.getWidth() == width; - }; - - auto verifyA2A3 = [&]() -> LogicalResult { - return emitOpError("thistogram is only supported on A5"); - }; - auto verifyA5 = [&]() -> LogicalResult { - Type srcTy = getSrc().getType(); - Type idxTy = getIdx().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyTileBufCommon(*this, srcTy, "src")) || - failed(verifyTileBufCommon(*this, idxTy, "idx")) || - failed(verifyTileBufCommon(*this, dstTy, "dst"))) - return failure(); - - auto srcSpace = getPTOMemorySpaceEnum(srcTy); - auto idxSpace = getPTOMemorySpaceEnum(idxTy); - auto dstSpace = getPTOMemorySpaceEnum(dstTy); - if (!srcSpace || *srcSpace != pto::AddressSpace::VEC) - return emitOpError("expects src to be in the vec address space"); - if (!idxSpace || *idxSpace != pto::AddressSpace::VEC) - return emitOpError("expects idx to be in the vec address space"); - if (!dstSpace || *dstSpace != pto::AddressSpace::VEC) - return emitOpError("expects dst to be in the vec address space"); - - auto srcTB = dyn_cast(srcTy); - auto idxTB = dyn_cast(idxTy); - auto dstTB = dyn_cast(dstTy); - if (!srcTB || !idxTB || !dstTB) - return emitOpError("expects src, idx, and dst to be tile_buf types"); - - if (srcTB.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor) || 
- srcTB.getSLayoutValueI32() != static_cast(pto::SLayout::NoneBox)) - return emitOpError("expects src to use row_major + none_box layout"); - if (dstTB.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor) || - dstTB.getSLayoutValueI32() != static_cast(pto::SLayout::NoneBox)) - return emitOpError("expects dst to use row_major + none_box layout"); - if (idxTB.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor) || - idxTB.getSLayoutValueI32() != static_cast(pto::SLayout::NoneBox)) - return emitOpError( - "expects idx to use DN layout (col_major + none_box)"); - - if (!isIntegerWidth(getElemTy(srcTy), 16)) - return emitOpError("expects src element type to be ui16"); - if (!isIntegerWidth(getElemTy(idxTy), 8)) - return emitOpError("expects idx element type to be ui8"); - if (!isIntegerWidth(getElemTy(dstTy), 32)) - return emitOpError("expects dst element type to be ui32"); - - auto srcShape = getShapeVec(srcTy); - auto idxShape = getShapeVec(idxTy); - auto dstShape = getShapeVec(dstTy); - auto srcValid = getValidShapeVec(srcTy); - auto idxValid = getValidShapeVec(idxTy); - auto dstValid = getValidShapeVec(dstTy); - if (srcShape.size() != 2 || idxShape.size() != 2 || dstShape.size() != 2 || - srcValid.size() != 2 || idxValid.size() != 2 || dstValid.size() != 2) - return emitOpError( - "expects src, idx, and dst to have rank-2 shape and valid_shape"); - - if (!hasCompatibleKnownExtent(srcShape[0], idxShape[0]) || - !hasCompatibleKnownExtent(srcValid[0], idxValid[0])) - return emitOpError("expects idx rows and valid rows to match src"); - if (!hasCompatibleKnownExtent(srcShape[0], dstShape[0]) || - !hasCompatibleKnownExtent(srcValid[0], dstValid[0])) - return emitOpError("expects dst rows and valid rows to match src"); - - if (!isKnownUnitExtent(idxShape[1]) || !isKnownUnitExtent(idxValid[1])) - return emitOpError("expects idx to have exactly one column"); - if (dstShape[1] != ShapedType::kDynamic && dstShape[1] < 256) - return emitOpError("expects dst 
shape[1] to be at least 256"); - if (dstValid[1] != ShapedType::kDynamic && dstValid[1] < 256) - return emitOpError("expects dst valid_shape[1] to be at least 256"); - return success(); - }; - - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -LogicalResult TGetScaleAddrOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - return emitOpError("tget_scale_addr is only supported on A5"); - }; - auto verifyA5 = [&]() -> LogicalResult { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyTileBufCommon(*this, srcTy, "src"))) - return failure(); - if (failed(verifyScaleTileMatchesOperand(*this, dstTy, srcTy, "dst", "src"))) - return failure(); - return success(); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -// ---- MScatterOp ---- -LogicalResult MScatterOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - - if (!isTargetArchA5(getOperation())) - return emitOpError("pto.mscatter is only supported on A5 targets"); - - Type srcTy = getSrc().getType(); - Type idxTy = getIdx().getType(); - Type memTy = getMem().getType(); - - if (getPTOTypeRank(srcTy) == -1 || getPTOTypeRank(idxTy) == -1 || - getPTOTypeRank(memTy) == -1) - return emitOpError("expects src, idx, and mem to use supported PTO shapes"); - - if (failed(verifyNDStyleVecTile(*this, srcTy, "src")) || - failed(verifyMGatherMScatterIdxTile(getOperation(), idxTy, "idx"))) - return failure(); - - Type srcElem = getElemTy(srcTy); - Type idxElem = getElemTy(idxTy); - if (!srcElem || !idxElem) - return emitOpError("failed to resolve element types for src or idx"); - - if (!isSupportedMGatherMScatterPayloadElemType(getOperation(), srcElem)) - return emitOpError( - "expects src element type to be i8/ui8/i16/ui16/i32/ui32/f16/bf16/f32 " - "(and on A5 targets also float8_e4m3/float8_e5m2 family types)"); - - if (!isSupportedMGatherMScatterIndexElemType(idxElem)) - return 
emitOpError("expects idx element type to be signless i32"); - - if (failed(verifyMGatherMScatterMemOperand(getOperation(), getMem(), srcElem, - "src"))) - return failure(); - - if (getScatterAtomicOp() != pto::ScatterAtomicOp::None || - getScatterOob() != pto::ScatterOOB::Undefined) { - if (!isTargetArchA5(getOperation())) - return emitOpError( - "expects non-default scatterAtomicOp/scatterOob only on A5 targets"); - } - - if (!isSupportedMScatterAtomicPayloadElemType(srcElem, getScatterAtomicOp())) - return emitOpError( - "expects scatterAtomicOp-compatible src element type: add supports " - "i32/ui32/f16/f32, max/min support signless i32/f32"); - - if (failed(verifyMGatherMScatterTileShape(getOperation(), srcTy, idxTy, "src"))) - return failure(); - - return success(); -} - -// ---- MGatherOp ---- -LogicalResult MGatherOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - - if (!isTargetArchA5(getOperation())) - return emitOpError("pto.mgather is only supported on A5 targets"); - - Type memTy = getMem().getType(); - Type idxTy = getIdx().getType(); - Type dstTy = getDst().getType(); - - if (getPTOTypeRank(memTy) == -1 || getPTOTypeRank(idxTy) == -1 || - getPTOTypeRank(dstTy) == -1) - return emitOpError("expects mem, idx, and dst to use supported PTO shapes"); - - if (failed(verifyNDStyleVecTile(*this, dstTy, "dst")) || - failed(verifyMGatherMScatterIdxTile(getOperation(), idxTy, "idx"))) - return failure(); - - Type dstElem = getElemTy(dstTy); - Type idxElem = getElemTy(idxTy); - if (!dstElem || !idxElem) - return emitOpError("failed to resolve element types for dst or idx"); - - if (!isSupportedMGatherMScatterPayloadElemType(getOperation(), dstElem)) - return emitOpError( - "expects dst element type to be i8/ui8/i16/ui16/i32/ui32/f16/bf16/f32 " - "(and on A5 targets also float8_e4m3/float8_e5m2 family types)"); - - if (!isSupportedMGatherMScatterIndexElemType(idxElem)) - return emitOpError("expects idx element type to be 
signless i32"); - - if (failed(verifyMGatherMScatterMemOperand(getOperation(), getMem(), dstElem, - "dst"))) - return failure(); - - if (getGatherOob() != pto::GatherOOB::Undefined && - !isTargetArchA5(getOperation())) - return emitOpError( - "expects non-default gatherOob only on A5 targets"); - - if (failed(verifyMGatherMScatterTileShape(getOperation(), dstTy, idxTy, "dst"))) - return failure(); - - return success(); -} - -void mlir::pto::TCvtOp::print(OpAsmPrinter &p) { - p << " ins(" << getSrc(); - Builder builder(getContext()); - NamedAttrList attrs; - for (auto attr : (*this)->getAttrs()) { - if (attr.getName() == "sat_mode") { - attrs.set(builder.getStringAttr("satmode"), attr.getValue()); - continue; - } - attrs.set(attr.getName(), attr.getValue()); - } - p.printOptionalAttrDict(attrs.getAttrs()); - p << " : " << getSrc().getType(); - p << ") outs(" << getDst() << " : " << getDst().getType() << ")"; -} - -ParseResult mlir::pto::TCvtOp::parse(OpAsmParser &parser, OperationState &result) { - OpAsmParser::UnresolvedOperand src, dst; - Type srcTy, dstTy; - - if (parser.parseKeyword("ins") || parser.parseLParen() || parser.parseOperand(src)) - return failure(); - NamedAttrList attrs; - if (parser.parseOptionalAttrDict(attrs) || parser.parseColonType(srcTy)) - return failure(); - if (auto satmode = attrs.get("satmode")) { - attrs.erase("satmode"); - if (attrs.get("sat_mode")) - return parser.emitError(parser.getCurrentLocation(), - "cannot specify both satmode and sat_mode"); - attrs.set("sat_mode", satmode); - } - result.attributes = attrs; - if (parser.parseRParen() || parser.parseKeyword("outs") || parser.parseLParen() || - parser.parseOperand(dst) || parser.parseColonType(dstTy) || parser.parseRParen()) - return failure(); - - if (parser.resolveOperand(src, srcTy, result.operands) || - parser.resolveOperand(dst, dstTy, result.operands)) - return failure(); - return success(); -} - -void mlir::pto::TMrgSortOp::print(OpAsmPrinter &p) { - if (isFormat1()) { - p 
<< " ins(" << getSrc() << ", " << getBlockLen() << " : " << getSrc().getType() - << ", " << getBlockLen().getType() << ") outs(" << getDst() << " : " - << getDst().getType() << ")"; - } else if (isFormat2()) { - p << " ins("; - llvm::interleaveComma(getSrcs(), p, [&](Value src) { p << src; }); - p << ", " << getTmp(); - p << " {exhausted = " << (getExhausted() ? "true" : "false") << "} : "; - llvm::interleaveComma(getSrcs().getTypes(), p, [&](Type ty) { p << ty; }); - p << ", " << getTmp().getType(); - p << ") outs(" << getDst() << ", " << getExcuted() - << " : " << getDst().getType() << ", " << getExcuted().getType() << ")"; - } else { - llvm::report_fatal_error("TMrgSortOp print expects format1 or format2"); - } - p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"operandSegmentSizes", "exhausted"}); -} - -ParseResult mlir::pto::TMrgSortOp::parse(OpAsmParser &parser, OperationState &result) { - if (parser.parseKeyword("ins") || parser.parseLParen()) - return failure(); - OpAsmParser::UnresolvedOperand first, second; - if (parser.parseOperand(first) || parser.parseComma() || parser.parseOperand(second)) - return failure(); - - if (parser.parseOptionalColon().succeeded()) { - Type srcTy, blockLenTy, dstTy; - if (parser.parseType(srcTy) || parser.parseComma() || parser.parseType(blockLenTy) || - parser.parseRParen() || parser.parseKeyword("outs") || parser.parseLParen()) - return failure(); - OpAsmParser::UnresolvedOperand dstOp; - if (parser.parseOperand(dstOp) || parser.parseColon() || parser.parseType(dstTy) || - parser.parseRParen()) - return failure(); - result.addAttribute("operandSegmentSizes", - parser.getBuilder().getDenseI32ArrayAttr({1, 1, 1, 0, 0})); - if (parser.resolveOperand(first, srcTy, result.operands) || - parser.resolveOperand(second, blockLenTy, result.operands) || - parser.resolveOperand(dstOp, dstTy, result.operands)) - return failure(); - if (parser.parseOptionalAttrDict(result.attributes)) - return failure(); - if 
(!result.attributes.get("exhausted")) - result.addAttribute("exhausted", parser.getBuilder().getBoolAttr(false)); - return success(); - } - - SmallVector srcs = {first, second}; - while (parser.parseOptionalComma().succeeded()) { - OpAsmParser::UnresolvedOperand next; - if (parser.parseOperand(next)) - return failure(); - srcs.push_back(next); - } - if (srcs.size() < 3 || srcs.size() > 5) - return parser.emitError(parser.getCurrentLocation(), - "tmrgsort format2 expects 2 to 4 src operands plus one tmp operand"); - OpAsmParser::UnresolvedOperand tmpOp = srcs.pop_back_val(); - bool exhaustedVal = false; - if (parser.parseOptionalLBrace().succeeded()) { - if (parser.parseKeyword("exhausted") || parser.parseEqual()) - return failure(); - StringRef kw; - if (parser.parseKeyword(&kw) || parser.parseRBrace()) - return failure(); - exhaustedVal = (kw == "true"); - } - SmallVector srcTypes; - srcTypes.reserve(srcs.size()); - if (parser.parseColon()) - return failure(); - Type firstSrcTy; - if (parser.parseType(firstSrcTy)) - return failure(); - srcTypes.push_back(firstSrcTy); - while (parser.parseOptionalComma().succeeded()) { - Type nextTy; - if (parser.parseType(nextTy)) - return failure(); - srcTypes.push_back(nextTy); - } - if (srcTypes.size() != srcs.size() + 1 || parser.parseRParen() || - parser.parseKeyword("outs") || parser.parseLParen()) - return failure(); - Type tmpTy = srcTypes.pop_back_val(); - OpAsmParser::UnresolvedOperand dstOp, excutedOp; - Type dstTy, excutedTy; - if (parser.parseOperand(dstOp) || parser.parseComma() || parser.parseOperand(excutedOp) || - parser.parseColon() || parser.parseType(dstTy) || parser.parseComma() || - parser.parseType(excutedTy) || parser.parseRParen()) - return failure(); - result.addAttribute("operandSegmentSizes", - parser.getBuilder().getDenseI32ArrayAttr( - {static_cast(srcs.size()), 0, 1, 1, 1})); - if (parser.resolveOperands(srcs, srcTypes, parser.getCurrentLocation(), result.operands) || - parser.resolveOperand(dstOp, 
dstTy, result.operands) || - parser.resolveOperand(tmpOp, tmpTy, result.operands) || - parser.resolveOperand(excutedOp, excutedTy, result.operands)) - return failure(); - if (parser.parseOptionalAttrDict(result.attributes)) - return failure(); - if (!result.attributes.get("exhausted")) - result.addAttribute("exhausted", parser.getBuilder().getBoolAttr(exhaustedVal)); - return success(); -} - -mlir::LogicalResult mlir::pto::TMrgSortOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - if (isFormat1()) { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if (!isPTOShapedLike(srcTy) || !isPTOShapedLike(dstTy)) - return emitOpError() << "format1 expects PTO shaped-like types for src/dst"; - if (getElemTy(srcTy) != getElemTy(dstTy)) - return emitOpError() << "expects src/dst to have the same element type"; - if (!getElemTy(srcTy).isF16() && !getElemTy(srcTy).isF32()) - return emitOpError() << "expects element type to be f16 or f32"; - auto ss = getShapeVec(srcTy); - auto ds = getShapeVec(dstTy); - if (ss.size() != 2 || ds.size() != 2) - return emitOpError() << "expects src/dst to be rank-2 tile-shaped"; - if (ss[0] != mlir::ShapedType::kDynamic && ss[0] != 1) - return emitOpError() << "expects src rows == 1"; - if (ds[0] != mlir::ShapedType::kDynamic && ds[0] != 1) - return emitOpError() << "expects dst rows == 1"; - if (ss[1] != mlir::ShapedType::kDynamic && ds[1] != mlir::ShapedType::kDynamic && ss[1] != ds[1]) - return emitOpError() << "expects src/dst cols to match"; - if (getBlockLen()) { - if (auto cstOp = getBlockLen().getDefiningOp()) { - if (auto intAttr = mlir::dyn_cast(cstOp.getValue())) { - int64_t v = intAttr.getValue().getSExtValue(); - if (v <= 0 || (v % 64) != 0) - return emitOpError() << "expects blockLen > 0 and multiple of 64"; - } - } - } - return mlir::success(); - } - if (isFormat2()) { - for (Value v : getSrcs()) - if (!isPTOShapedLike(v.getType())) - return emitOpError() << "format2 
expects PTO shaped-like type for each src"; - if (getSrcs().size() < 2u || getSrcs().size() > 4u) - return emitOpError() << "format2 expects 2 to 4 srcs"; - if (getDsts().size() != 1u || !getTmp() || !getExcuted()) - return emitOpError() << "format2 expects ins(srcs..., tmp), outs(dst), and excuted=vector"; - Type dstTy = getDst().getType(); - Type tmpTy = getTmp().getType(); - if (!isPTOShapedLike(dstTy) || !isPTOShapedLike(tmpTy)) - return emitOpError() << "format2 dst/tmp must be PTO shaped-like"; - auto excutedTy = mlir::dyn_cast(getExcuted().getType()); - if (!excutedTy || excutedTy.getRank() != 1 || excutedTy.getNumElements() != 4 || - !excutedTy.getElementType().isInteger(16)) - return emitOpError() << "format2 excuted must be vector<4xi16>"; - Type elemTy = getElemTy(dstTy); - if (elemTy != getElemTy(tmpTy)) - return emitOpError() << "format2 expects dst/tmp element types to match"; - auto dstShape = getShapeVec(dstTy); - auto tmpShape = getShapeVec(tmpTy); - if (dstShape.size() != 2 || tmpShape.size() != 2) - return emitOpError() << "format2 expects dst/tmp to be rank-2 tile-shaped"; - if ((dstShape[0] != mlir::ShapedType::kDynamic && dstShape[0] != 1) || - (tmpShape[0] != mlir::ShapedType::kDynamic && tmpShape[0] != 1)) - return emitOpError() << "format2 expects dst/tmp rows == 1"; - if (dstShape[1] != mlir::ShapedType::kDynamic && - tmpShape[1] != mlir::ShapedType::kDynamic && - tmpShape[1] < dstShape[1]) - return emitOpError() << "format2 expects tmp.cols >= dst.cols"; - for (Value src : getSrcs()) { - Type srcTy = src.getType(); - auto srcShape = getShapeVec(srcTy); - if (srcShape.size() != 2) - return emitOpError() << "format2 expects src to be rank-2 tile-shaped"; - if (srcShape[0] != mlir::ShapedType::kDynamic && srcShape[0] != 1) - return emitOpError() << "format2 expects src rows == 1"; - if (getElemTy(srcTy) != elemTy) - return emitOpError() << "format2 expects src/dst/tmp element types to match"; - } - return mlir::success(); - } - return 
emitOpError() << "tmrgsort expects format1 (1 src + blockLen + 1 dst) or " - "format2 (2 to 4 srcs + tmp, outs dst, excuted)"; -} - -mlir::LogicalResult mlir::pto::TMulOp::verify() { - return verifyArithmeticBinaryTileOpWithArchDispatch( - getOperation(), getSrc0().getType(), getSrc1().getType(), getDst().getType(), - /*allowInt8OnA5=*/false, /*allowBf16OnA5=*/false, - "expects A2/A3 tmul element type to be i32/i16/f16/f32", - "expects A5 tmul element type to be i32/i16/f16/f32"); -} - -mlir::LogicalResult mlir::pto::TMulSOp::verify() { - return verifyArithmeticScalarTileOpWithArchDispatch( - getOperation(), getSrc0().getType(), getDst().getType(), - getScalar().getType(), /*allowInt8OnA5=*/true, /*allowBf16OnA5=*/true, - "expects A2/A3 tmuls element type to be i32/i16/f16/f32", - "expects A5 tmuls element type to be i32/i16/i8/f16/bf16/f32", - /*requireValidRowsEqualOnA2A3=*/true, - /*requireValidRowsEqualOnA5=*/false); -} - -mlir::LogicalResult mlir::pto::TShlSOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyTileBufCommon(*this, srcTy, "src")) || - failed(verifyTileBufCommon(*this, dstTy, "dst"))) - return failure(); - - Type srcElem = getElemTy(srcTy); - Type dstElem = getElemTy(dstTy); - if (!srcElem || !dstElem) - return emitOpError() << "failed to get element type for src/dst"; - if (srcElem != dstElem) - return emitOpError() << "expects src and dst to have the same element type"; - if (!mlir::isa(srcElem)) - return emitOpError() << "expects integral element types"; - if (auto scalarValue = getConstantIntegerValue(getScalar()); scalarValue && *scalarValue < 0) - return emitOpError("expects tshls scalar to be non-negative"); - return mlir::success(); -} - -mlir::LogicalResult mlir::pto::TShrSOp::verify() { - auto verifyCommon = [&]() -> FailureOr { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if 
(failed(verifyVecTileCommon(*this, srcTy, "src")) || - failed(verifyVecTileCommon(*this, dstTy, "dst"))) - return failure(); - if (failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) - return failure(); - - Type srcElem = getElemTy(srcTy); - Type dstElem = getElemTy(dstTy); - if (!srcElem || !dstElem) { - emitOpError("failed to get element type for src/dst"); - return failure(); - } - if (srcElem != dstElem) { - emitOpError("expects src and dst to have the same element type"); - return failure(); - } - return srcElem; - }; - - auto verifyA2A3 = [&]() -> LogicalResult { - FailureOr elemOr = verifyCommon(); - if (failed(elemOr)) - return failure(); - auto it = mlir::dyn_cast(*elemOr); - if (!it || (it.getWidth() != 16 && it.getWidth() != 32)) - return emitOpError( - "expects A2/A3 tshrs src and dst element type to be i16/i32"); - return success(); - }; - - auto verifyA5 = [&]() -> LogicalResult { - FailureOr elemOr = verifyCommon(); - if (failed(elemOr)) - return failure(); - auto it = mlir::dyn_cast(*elemOr); - if (!it || (it.getWidth() != 8 && it.getWidth() != 16 && - it.getWidth() != 32)) - return emitOpError( - "expects A5 tshrs src and dst element type to be i8/i16/i32"); - return success(); - }; - - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TNegOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyVecTileStorage(*this, srcTy, "src")) || - failed(verifyVecTileStorage(*this, dstTy, "dst"))) - return failure(); - if (failed(verifyTileBufSameElemType(*this, srcTy, dstTy, "src", "dst")) || - failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) - return failure(); - - Type elemTy = getElemTy(srcTy); - if (!(elemTy.isInteger(16) || elemTy.isInteger(32) || elemTy.isF16() || - elemTy.isF32())) - return emitOpError() - << "expects A2/A3 tneg element type to be 
i16/i32/f16/f32"; - return success(); - }; - - auto verifyA5 = [&]() -> LogicalResult { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyVecTileStorage(*this, srcTy, "src")) || - failed(verifyVecTileStorage(*this, dstTy, "dst"))) - return failure(); - if (failed(verifyTileBufSameElemType(*this, srcTy, dstTy, "src", "dst"))) - return failure(); - - auto srcValid = getValidShapeVec(srcTy); - auto dstValid = getValidShapeVec(dstTy); - if (srcValid.size() != 2 || dstValid.size() != 2) - return emitOpError() << "expects src and dst to have rank-2 valid_shape"; - if (srcValid[1] != ShapedType::kDynamic && - dstValid[1] != ShapedType::kDynamic && - srcValid[1] != dstValid[1]) - return emitOpError() - << "expects src and dst to have the same valid_shape[1]"; - - Type elemTy = getElemTy(srcTy); - if (!(elemTy.isInteger(8) || elemTy.isInteger(16) || elemTy.isInteger(32) || - elemTy.isF16() || elemTy.isF32() || elemTy.isBF16())) - return emitOpError() - << "expects A5 tneg element type to be i8/i16/i32/f16/f32/bf16"; - return success(); - }; - - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TNotOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyVecTileCommon(*this, srcTy, "src")) || - failed(verifyVecTileCommon(*this, dstTy, "dst"))) - return failure(); - if (failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) - return failure(); - auto elemTy = getElemTy(srcTy); - if (elemTy != getElemTy(dstTy)) - return emitOpError() << "expects src and dst to have the same element type"; - if (!elemTy.isInteger(16)) - return emitOpError() << "expects A2/A3 tnot element type to be i16"; - return success(); - }; - auto verifyA5 = [&]() -> LogicalResult { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyVecTileCommon(*this, srcTy, "src")) 
|| - failed(verifyVecTileCommon(*this, dstTy, "dst"))) - return failure(); - if (failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) - return failure(); - auto elemTy = getElemTy(srcTy); - if (elemTy != getElemTy(dstTy)) - return emitOpError() << "expects src and dst to have the same element type"; - if (!(elemTy.isInteger(8) || elemTy.isInteger(16) || elemTy.isInteger(32))) - return emitOpError() << "expects A5 tnot element type to be i8/i16/i32"; - return success(); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TOrOp::verify() { - auto verifyCommon = [&]() -> FailureOr { - return verifyMatchingRowMajorBinaryTileOpCommon( - getOperation(), getSrc0().getType(), getSrc1().getType(), - getDst().getType()); - }; - - auto verifyA2A3 = [&]() -> LogicalResult { - FailureOr elemOr = verifyCommon(); - if (failed(elemOr)) - return failure(); - auto it = mlir::dyn_cast(*elemOr); - if (!it || (it.getWidth() != 8 && it.getWidth() != 16)) - return emitOpError( - "expects A2/A3 tor src0, src1, and dst element type to be i8/i16"); - return success(); - }; - - auto verifyA5 = [&]() -> LogicalResult { - FailureOr elemOr = verifyCommon(); - if (failed(elemOr)) - return failure(); - auto it = mlir::dyn_cast(*elemOr); - if (!it || (it.getWidth() != 8 && it.getWidth() != 16 && - it.getWidth() != 32)) - return emitOpError( - "expects A5 tor src0, src1, and dst element type to be i8/i16/i32"); - return success(); - }; - - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TOrSOp::verify() { - auto verifyCommon = [&]() -> FailureOr { - return verifyDistinctRowMajorUnaryTileOpCommon(getOperation(), getSrc(), - getDst(), "src", "dst"); - }; - - auto verifyA2A3 = [&]() -> LogicalResult { - FailureOr elemOr = verifyCommon(); - if (failed(elemOr)) - return failure(); - auto it = mlir::dyn_cast(*elemOr); - if (!it || (it.getWidth() != 8 && it.getWidth() 
!= 16)) - return emitOpError( - "expects A2/A3 tors src and dst element type to be i8/i16"); - return success(); - }; - - auto verifyA5 = [&]() -> LogicalResult { - FailureOr elemOr = verifyCommon(); - if (failed(elemOr)) - return failure(); - auto it = mlir::dyn_cast(*elemOr); - if (!it || (it.getWidth() != 8 && it.getWidth() != 16 && - it.getWidth() != 32)) - return emitOpError( - "expects A5 tors src and dst element type to be i8/i16/i32"); - return success(); - }; - - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -static FailureOr verifyPTOShapedBinarySameElemAndShape(Operation *op, - Type src0Ty, - Type src1Ty, - Type dstTy) { - if (!isPTOShapedLike(src0Ty) || !isPTOShapedLike(src1Ty) || - !isPTOShapedLike(dstTy)) - return op->emitOpError( - "expects src0/src1/dst to be memref/tensor/tile_buf/tile_view types"), - failure(); - Type e0 = getElemTy(src0Ty), e1 = getElemTy(src1Ty), ed = getElemTy(dstTy); - if (!e0 || !e1 || !ed) - return op->emitOpError("failed to get element type for operands"), failure(); - if (e0 != e1 || e0 != ed) - return op->emitOpError("expects src0/src1/dst to have the same element type"), - failure(); - auto s0 = getShapeVec(src0Ty), s1 = getShapeVec(src1Ty), sd = getShapeVec(dstTy); - if (s0 != s1 || s0 != sd) - return op->emitOpError("expects src0/src1/dst to have the same shape"), - failure(); - return e0; -} - -mlir::LogicalResult mlir::pto::TPartAddOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - Type src0Ty = getSrc0().getType(); - Type src1Ty = getSrc1().getType(); - Type dstTy = getDst().getType(); - if (!isPTOShapedLike(src0Ty) || !isPTOShapedLike(src1Ty) || - !isPTOShapedLike(dstTy)) - return emitOpError() << "expects PTO shaped-like src0/src1/dst"; - if (getElemTy(src0Ty) != getElemTy(src1Ty) || - getElemTy(src0Ty) != getElemTy(dstTy)) - return emitOpError() << "expects src0/src1/dst to have the same element type"; - auto s0 = getShapeVec(src0Ty); - auto s1 = getShapeVec(src1Ty); - auto 
d = getShapeVec(dstTy); - if (s0.size() != 2 || s1.size() != 2 || d.size() != 2) - return emitOpError() << "expects src0/src1/dst to be rank-2 (tile-shaped)"; - if (failed(verifyPartialValidPattern(*this, src0Ty, src1Ty, dstTy))) - return failure(); - Type elem = getElemTy(src0Ty); - if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isF16() || elem.isF32())) - return emitOpError("expects A2/A3 tpartadd element type to be i32/i16/f16/f32"); - return mlir::success(); - }; - auto verifyA5 = [&]() -> LogicalResult { - Type src0Ty = getSrc0().getType(); - Type src1Ty = getSrc1().getType(); - Type dstTy = getDst().getType(); - if (!isPTOShapedLike(src0Ty) || !isPTOShapedLike(src1Ty) || - !isPTOShapedLike(dstTy)) - return emitOpError() << "expects PTO shaped-like src0/src1/dst"; - if (getElemTy(src0Ty) != getElemTy(src1Ty) || - getElemTy(src0Ty) != getElemTy(dstTy)) - return emitOpError() << "expects src0/src1/dst to have the same element type"; - Type elem = getElemTy(src0Ty); - if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isInteger(8) || - elem.isF16() || elem.isBF16() || elem.isF32())) - return emitOpError("expects A5 tpartadd element type to be i32/i16/i8/f16/bf16/f32"); - auto s0 = getShapeVec(src0Ty); - auto s1 = getShapeVec(src1Ty); - auto d = getShapeVec(dstTy); - if (s0.size() != 2 || s1.size() != 2 || d.size() != 2) - return emitOpError() << "expects src0/src1/dst to be rank-2 (tile-shaped)"; - return mlir::success(); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TPartMaxOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - Type t0 = getSrc0().getType(); - Type t1 = getSrc1().getType(); - Type td = getDst().getType(); - FailureOr elemOr = - verifyPTOShapedBinarySameElemAndShape(getOperation(), t0, t1, td); - if (failed(elemOr)) - return failure(); - if (failed(verifyPartialValidPattern(*this, t0, t1, td))) - return failure(); - Type e0 = *elemOr; - if (!(e0.isInteger(32) || 
e0.isInteger(16) || e0.isF16() || e0.isF32())) - return emitOpError("expects A2/A3 tpartmax element type to be i32/i16/f16/f32"); - return mlir::success(); - }; - auto verifyA5 = [&]() -> LogicalResult { - Type t0 = getSrc0().getType(); - Type t1 = getSrc1().getType(); - Type td = getDst().getType(); - FailureOr elemOr = - verifyPTOShapedBinarySameElemAndShape(getOperation(), t0, t1, td); - if (failed(elemOr)) - return failure(); - Type e0 = *elemOr; - if (!(e0.isInteger(32) || e0.isInteger(16) || e0.isInteger(8) || - e0.isF16() || e0.isBF16() || e0.isF32())) - return emitOpError("expects A5 tpartmax element type to be i32/i16/i8/f16/bf16/f32"); - return mlir::success(); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TPartMinOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - Type t0 = getSrc0().getType(); - Type t1 = getSrc1().getType(); - Type td = getDst().getType(); - FailureOr elemOr = - verifyPTOShapedBinarySameElemAndShape(getOperation(), t0, t1, td); - if (failed(elemOr)) - return failure(); - if (failed(verifyPartialValidPattern(*this, t0, t1, td))) - return failure(); - Type e0 = *elemOr; - if (!(e0.isInteger(32) || e0.isInteger(16) || e0.isF16() || e0.isF32())) - return emitOpError("expects A2/A3 tpartmin element type to be i32/i16/f16/f32"); - return mlir::success(); - }; - auto verifyA5 = [&]() -> LogicalResult { - Type t0 = getSrc0().getType(); - Type t1 = getSrc1().getType(); - Type td = getDst().getType(); - FailureOr elemOr = - verifyPTOShapedBinarySameElemAndShape(getOperation(), t0, t1, td); - if (failed(elemOr)) - return failure(); - Type e0 = *elemOr; - if (!(e0.isInteger(32) || e0.isInteger(16) || e0.isInteger(8) || - e0.isF16() || e0.isBF16() || e0.isF32())) - return emitOpError("expects A5 tpartmin element type to be i32/i16/i8/f16/bf16/f32"); - return mlir::success(); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -static 
LogicalResult verifyTPartArgOpCommon(Operation *op, Type src0Ty, - Type src1Ty, Type src0IdxTy, - Type src1IdxTy, Type dstTy, - Type dstIdxTy, StringRef opName) { - FailureOr dataElemOr = - verifyPTOShapedBinarySameElemAndShape(op, src0Ty, src1Ty, dstTy); - if (failed(dataElemOr)) - return failure(); - if (failed(verifyPartialValidPattern(op, src0Ty, src1Ty, dstTy))) - return failure(); - - if (!isPTOShapedLike(src0IdxTy) || !isPTOShapedLike(src1IdxTy) || - !isPTOShapedLike(dstIdxTy)) - return op->emitOpError("expects PTO shaped-like src0Idx/src1Idx/dstIdx"); - Type idxElem = getElemTy(src0IdxTy); - if (!idxElem || idxElem != getElemTy(src1IdxTy) || - idxElem != getElemTy(dstIdxTy)) - return op->emitOpError( - "expects src0Idx/src1Idx/dstIdx to have the same element type"); - auto idxInt = dyn_cast(idxElem); - if (!idxInt || idxInt.getWidth() != 32) - return op->emitOpError( - "expects src0Idx/src1Idx/dstIdx element type to be i32 or ui32"); - - auto dataShape = getShapeVec(src0Ty); - if (dataShape != getShapeVec(src0IdxTy) || - dataShape != getShapeVec(src1IdxTy) || - dataShape != getShapeVec(dstIdxTy)) - return op->emitOpError( - "expects data and index operands to have the same shape"); - if (getValidShapeVec(src0Ty) != getValidShapeVec(src0IdxTy) || - getValidShapeVec(src1Ty) != getValidShapeVec(src1IdxTy) || - getValidShapeVec(dstTy) != getValidShapeVec(dstIdxTy)) - return op->emitOpError( - "expects each data operand and its index operand to have the same valid_shape"); - - Type elem = *dataElemOr; - PTOArch arch = getTargetArch(op); - if (arch == PTOArch::A5) { - if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isInteger(8) || - elem.isF16() || elem.isBF16() || elem.isF32())) - return op->emitOpError() << "expects A5 " << opName - << " element type to be i32/i16/i8/f16/bf16/f32"; - } else { - if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isF16() || - elem.isF32())) - return op->emitOpError() << "expects A2/A3 " << opName - << " element type to 
be i32/i16/f16/f32"; - } - return success(); -} - -mlir::LogicalResult mlir::pto::TPartArgMaxOp::verify() { - auto verifyByArch = [&]() -> LogicalResult { - return verifyTPartArgOpCommon( - getOperation(), getSrc0().getType(), getSrc1().getType(), - getSrc0Idx().getType(), getSrc1Idx().getType(), getDst().getType(), - getDstIdx().getType(), "tpartargmax"); - }; - return dispatchVerifierByArch(getOperation(), verifyByArch, verifyByArch); -} - -mlir::LogicalResult mlir::pto::TPartArgMinOp::verify() { - auto verifyByArch = [&]() -> LogicalResult { - return verifyTPartArgOpCommon( - getOperation(), getSrc0().getType(), getSrc1().getType(), - getSrc0Idx().getType(), getSrc1Idx().getType(), getDst().getType(), - getDstIdx().getType(), "tpartargmin"); - }; - return dispatchVerifierByArch(getOperation(), verifyByArch, verifyByArch); -} - -mlir::LogicalResult mlir::pto::TPartMulOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - Type src0Ty = getSrc0().getType(); - Type src1Ty = getSrc1().getType(); - Type dstTy = getDst().getType(); - if (!isPTOShapedLike(src0Ty) || !isPTOShapedLike(src1Ty) || - !isPTOShapedLike(dstTy)) - return emitOpError() << "expects PTO shaped-like src0/src1/dst"; - if (getElemTy(src0Ty) != getElemTy(src1Ty) || - getElemTy(src0Ty) != getElemTy(dstTy)) - return emitOpError() - << "expects src0/src1/dst to have the same element type"; - auto s0 = getShapeVec(src0Ty); - auto s1 = getShapeVec(src1Ty); - auto d = getShapeVec(dstTy); - if (s0.size() != 2 || s1.size() != 2 || d.size() != 2) - return emitOpError() - << "expects src0/src1/dst to be rank-2 (tile-shaped)"; - if (failed(verifyPartialValidPattern(*this, src0Ty, src1Ty, dstTy))) - return failure(); - Type elem = getElemTy(src0Ty); - if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isF16() || - elem.isF32())) - return emitOpError( - "expects A2/A3 tpartmul element type to be i32/i16/f16/f32"); - return mlir::success(); - }; - auto verifyA5 = [&]() -> LogicalResult { - Type src0Ty = 
getSrc0().getType(); - Type src1Ty = getSrc1().getType(); - Type dstTy = getDst().getType(); - if (!isPTOShapedLike(src0Ty) || !isPTOShapedLike(src1Ty) || - !isPTOShapedLike(dstTy)) - return emitOpError() << "expects PTO shaped-like src0/src1/dst"; - if (getElemTy(src0Ty) != getElemTy(src1Ty) || - getElemTy(src0Ty) != getElemTy(dstTy)) - return emitOpError() - << "expects src0/src1/dst to have the same element type"; - Type elem = getElemTy(src0Ty); - if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isInteger(8) || - elem.isF16() || elem.isBF16() || elem.isF32())) - return emitOpError( - "expects A5 tpartmul element type to be i32/i16/i8/f16/bf16/f32"); - auto s0 = getShapeVec(src0Ty); - auto s1 = getShapeVec(src1Ty); - auto d = getShapeVec(dstTy); - if (s0.size() != 2 || s1.size() != 2 || d.size() != 2) - return emitOpError() - << "expects src0/src1/dst to be rank-2 (tile-shaped)"; - return mlir::success(); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TPReluOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - auto verifyCommon = [&]() -> FailureOr> { - Type t0 = getSrc0().getType(); - Type t1 = getSrc1().getType(); - Type tt = getTmp().getType(); - Type td = getDst().getType(); - if (failed(verifyTileBufCommon(*this, t0, "src0")) || - failed(verifyTileBufCommon(*this, t1, "src1")) || - failed(verifyTileBufCommon(*this, tt, "tmp")) || - failed(verifyTileBufCommon(*this, td, "dst"))) - return failure(); - - Type e0 = getElemTy(t0), e1 = getElemTy(t1), et = getElemTy(tt), ed = getElemTy(td); - if (!e0 || !e1 || !et || !ed) { - emitOpError("failed to get element type for operands"); - return failure(); - } - if (e0 != e1 || e0 != ed) { - emitOpError("expects dst/src0/src1 to have the same element type"); - return failure(); - } - if (!(e0.isF16() || e0.isF32())) { - emitOpError("expects dst/src0/src1 element type to be f16 or f32"); - return failure(); - 
} - if (!isRowMajorTileBuf(t0) || !isRowMajorTileBuf(t1) || !isRowMajorTileBuf(td)) { - emitOpError("expects src0, src1, and dst to use row-major layout"); - return failure(); - } - if (failed(verifyTileBufSameValidShape(*this, t0, td, "src0", "dst")) || - failed(verifyTileBufSameValidShape(*this, t1, td, "src1", "dst"))) - return failure(); - - auto s0 = getShapeVec(t0), s1 = getShapeVec(t1), st = getShapeVec(tt), sd = getShapeVec(td); - if (s0 != s1 || s0 != st || s0 != sd) { - emitOpError("expects src0/src1/tmp/dst to have the same shape"); - return failure(); - } - return std::make_tuple(t0, t1, tt, td); - }; - - auto verifyA2A3 = [&]() -> LogicalResult { - auto tysOr = verifyCommon(); - if (failed(tysOr)) - return failure(); - auto [t0, t1, tt, td] = *tysOr; - Type tmpElem = getElemTy(tt); - auto tmpIntTy = mlir::dyn_cast(tmpElem); - if (!tmpIntTy || tmpIntTy.getWidth() != 8) - return emitOpError("expects A2/A3 tmp element type to be u8"); - if (!isRowMajorTileBuf(tt)) - return emitOpError("expects tmp to use row-major layout"); - if (auto arch = getVerifierArchName(getOperation()); - arch && arch->equals_insensitive("a3")) { - if (getSrc0() == getSrc1() || getSrc0() == getTmp() || getSrc0() == getDst() || - getSrc1() == getTmp() || getSrc1() == getDst() || getTmp() == getDst()) - return emitOpError( - "expects A3 src0, src1, tmp, and dst to use different storage"); - } - return success(); - }; - - auto verifyA5 = [&]() -> LogicalResult { - auto tysOr = verifyCommon(); - if (failed(tysOr)) - return failure(); - return success(); - }; - - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TQuantOp::verify() { - // Structural checks: always run regardless of operand representation - // (applies both before and after PTOViewToMemref lowering). - auto verifyStructural = [&]() -> LogicalResult { - // dst elem type and offset presence must be consistent with quant_type. 
- Type dstTy = getDst().getType(); - Type dstElemTy = getElemTy(dstTy); - auto dstIntTy = dyn_cast(dstElemTy); - if (getQuantType() == mlir::pto::QuantType::INT8_SYM) { - if (!dstIntTy || dstIntTy.getWidth() != 8) - return emitOpError() - << "expects dst element type i8/ui8 for INT8_SYM quantization"; - if (getOffset()) - return emitOpError() - << "INT8_SYM quantization must not have an offset operand"; - } else { - // INT8_ASYM - if (!dstIntTy || dstIntTy.getWidth() != 8) - return emitOpError() - << "expects dst element type i8/ui8 for INT8_ASYM quantization"; - if (!getOffset()) - return emitOpError() - << "INT8_ASYM quantization requires an offset operand"; - } - return success(); - }; - - if (failed(verifyStructural())) - return failure(); - - // Layout/tile-buffer checks: only meaningful for pre-lowering tile types. - // Skip when operands are already plain MemRefs (post PTOViewToMemref). - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - - auto verifyCommon = [&]() -> LogicalResult { - Type srcTy = getSrc().getType(); - Type fpTy = getFp().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyTileBufCommon(*this, srcTy, "src")) || - failed(verifyTileBufCommon(*this, fpTy, "fp")) || - failed(verifyTileBufCommon(*this, dstTy, "dst"))) - return failure(); - // src must be f32 (ISA static_assert) - if (!getElemTy(srcTy).isF32()) - return emitOpError() << "expects src to have element type f32"; - if (getOffset()) { - Type offsetTy = getOffset().getType(); - if (failed(verifyTileBufCommon(*this, offsetTy, "offset"))) - return failure(); - if (!getElemTy(offsetTy).isF32()) - return emitOpError() << "expects offset to have element type f32"; - } - return success(); - }; - - auto verifyA2A3 = [&]() -> LogicalResult { - if (failed(verifyCommon())) - return failure(); - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if (!isRowMajorTileBuf(srcTy) || !isRowMajorTileBuf(dstTy)) - return emitOpError() << "expects 
A2/A3 src and dst to use row-major layout"; - return success(); - }; - - auto verifyA5 = [&]() -> LogicalResult { - return verifyCommon(); - }; - - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TDequantOp::verify() { - // Structural checks: src must be i8 or i16, dst/scale/offset must be f32. - auto verifyStructural = [&]() -> LogicalResult { - Type srcElemTy = getElemTy(getSrc().getType()); - auto srcIntTy = dyn_cast(srcElemTy); - if (!srcIntTy || !(srcIntTy.getWidth() == 8 || srcIntTy.getWidth() == 16)) - return emitOpError() - << "expects src element type i8 or i16"; - if (!getElemTy(getDst().getType()).isF32()) - return emitOpError() << "expects dst element type f32"; - if (!getElemTy(getScale().getType()).isF32()) - return emitOpError() << "expects scale element type f32"; - if (!getElemTy(getOffset().getType()).isF32()) - return emitOpError() << "expects offset element type f32"; - return success(); - }; - - if (failed(verifyStructural())) - return failure(); - - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - - auto verifyCommon = [&]() -> LogicalResult { - if (failed(verifyTileBufCommon(*this, getSrc().getType(), "src")) || - failed(verifyTileBufCommon(*this, getScale().getType(), "scale")) || - failed(verifyTileBufCommon(*this, getOffset().getType(), "offset")) || - failed(verifyTileBufCommon(*this, getDst().getType(), "dst"))) - return failure(); - return success(); - }; - - auto verifyA2A3 = [&]() -> LogicalResult { - if (failed(verifyCommon())) - return failure(); - if (!isRowMajorTileBuf(getSrc().getType()) || - !isRowMajorTileBuf(getDst().getType())) - return emitOpError() - << "expects A2/A3 src and dst to use row-major layout"; - return success(); - }; - - auto verifyA5 = [&]() -> LogicalResult { return verifyCommon(); }; - - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TRecipOp::verify() { - if 
(shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - Type ts = getSrc().getType(); - Type td = getDst().getType(); - if (failed(verifyVecTileUnaryOp(*this, ts, td, "src", "dst", - /*allowBf16=*/false, /*allowInt8=*/false))) - return failure(); - if (failed(verifyTileBufSameValidShape(*this, ts, td, "src", "dst"))) - return failure(); - Type elemTy = getElemTy(ts); - if (!(elemTy.isF16() || elemTy.isF32())) - return emitOpError() << "expects element type to be f16 or f32"; - if (auto arch = getVerifierArchName(getOperation()); - arch && arch->equals_insensitive("a3") && getSrc() == getDst()) - return emitOpError("expects A3 trecip src and dst to use different storage"); - return mlir::success(); -} - -mlir::LogicalResult mlir::pto::TReluOp::verify() { - auto verifyByArch = [&](StringRef errorMessage) -> LogicalResult { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyVecTileCommon(*this, srcTy, "src")) || - failed(verifyVecTileCommon(*this, dstTy, "dst"))) - return failure(); - if (failed(verifyTileBufSameElemType(*this, srcTy, dstTy, "src", "dst")) || - failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) - return failure(); - Type elemTy = getElemTy(srcTy); - if (!(elemTy.isInteger(32) || elemTy.isF16() || elemTy.isF32())) - return emitOpError() << errorMessage; - return success(); - }; - auto verifyA2A3 = [&]() -> LogicalResult { - return verifyByArch("expects A2/A3 trelu element type to be i32/f16/f32"); - }; - auto verifyA5 = [&]() -> LogicalResult { - return verifyByArch("expects A5 trelu element type to be i32/f16/f32"); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - - -mlir::LogicalResult mlir::pto::TRemOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - - Type src0Ty = getSrc0().getType(); - Type src1Ty = getSrc1().getType(); - Type tmpTy = getTmp().getType(); - Type dstTy = getDst().getType(); - if 
(failed(verifyTileBufCommon(*this, src0Ty, "src0")) || - failed(verifyTileBufCommon(*this, src1Ty, "src1")) || - failed(verifyTileBufCommon(*this, tmpTy, "tmp")) || - failed(verifyTileBufCommon(*this, dstTy, "dst"))) - return failure(); - if (failed(verifyTileBufSameElemType(*this, src0Ty, src1Ty, "src0", "src1")) || - failed(verifyTileBufSameElemType(*this, src0Ty, dstTy, "src0", "dst")) || - failed(verifyTileBufSameValidShape(*this, src0Ty, src1Ty, "src0", "src1")) || - failed(verifyTileBufSameValidShape(*this, src0Ty, dstTy, "src0", "dst"))) - return failure(); - if (getElemTy(tmpTy) != getElemTy(dstTy)) - return emitOpError("expects tmp and dst to have the same element type"); - if (!isRowMajorTileBuf(src0Ty) || !isRowMajorTileBuf(src1Ty) || - !isRowMajorTileBuf(tmpTy) || !isRowMajorTileBuf(dstTy)) - return emitOpError("expects src0, src1, tmp, and dst to use row-major layout"); - auto dstValid = getValidShapeVec(dstTy); - auto tmpValid = getValidShapeVec(tmpTy); - if (dstValid.size() != 2 || tmpValid.size() != 2) - return emitOpError("expects tmp and dst to be rank-2 tiles"); - if (tmpValid[0] != ShapedType::kDynamic && tmpValid[0] < 1) - return emitOpError("expects tmp to have at least 1 valid row"); - if (dstValid[1] != ShapedType::kDynamic && tmpValid[1] != ShapedType::kDynamic && - tmpValid[1] < dstValid[1]) - return emitOpError("expects tmp valid columns to cover dst valid columns"); - - Type elem = getElemTy(src0Ty); - auto verifyA2A3 = [&]() -> LogicalResult { - if (!(elem.isInteger(32) || elem.isF32())) - return emitOpError("expects A2/A3 trem element type to be i32/f32"); - return success(); - }; - auto verifyA5 = [&]() -> LogicalResult { - if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isF16() || elem.isF32())) - return emitOpError("expects A5 trem element type to be i32/i16/f16/f32"); - return success(); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TFModOp::verify() { - 
return verifyArithmeticBinaryTileOpWithArchDispatch( - getOperation(), getSrc0().getType(), getSrc1().getType(), getDst().getType(), - /*allowInt8OnA5=*/false, /*allowBf16OnA5=*/false, - "expects A2/A3 tfmod element type to be i32/i16/f16/f32", - "expects A5 tfmod element type to be i32/i16/f16/f32"); -} - -mlir::LogicalResult mlir::pto::TRemSOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - Type ts = getSrc().getType(); - Type tt = getTmp().getType(); - Type td = getDst().getType(); - Type scalarTy = getScalar().getType(); - if (failed(verifyTileBufCommon(*this, ts, "src")) || - failed(verifyTileBufCommon(*this, tt, "tmp")) || - failed(verifyTileBufCommon(*this, td, "dst"))) - return failure(); - if (failed(verifyTileBufSameElemType(*this, ts, td, "src", "dst")) || - failed(verifyTileBufSameValidShape(*this, ts, td, "src", "dst"))) - return failure(); - if (getElemTy(tt) != getElemTy(td)) - return emitOpError("expects tmp and dst to have the same element type"); - if (!isRowMajorTileBuf(ts) || !isRowMajorTileBuf(tt) || !isRowMajorTileBuf(td)) - return emitOpError("expects src, tmp, and dst to use row-major layout"); - Type elem = getElemTy(ts); - if (scalarTy != elem) - return emitOpError("expects scalar type to match the tile element type"); - auto dstValid = getValidShapeVec(td); - auto tmpValid = getValidShapeVec(tt); - if (dstValid.size() != 2 || tmpValid.size() != 2) - return emitOpError("expects tmp and dst to be rank-2 tiles"); - if (tmpValid[0] != ShapedType::kDynamic && tmpValid[0] < 1) - return emitOpError("expects tmp to have at least 1 valid row"); - if (dstValid[1] != ShapedType::kDynamic && tmpValid[1] != ShapedType::kDynamic && - tmpValid[1] < dstValid[1]) - return emitOpError("expects tmp valid columns to cover dst valid columns"); - auto verifyA2A3 = [&]() -> LogicalResult { - if (!(elem.isInteger(32) || elem.isF32())) - return emitOpError("expects A2/A3 trems element type to be i32/f32"); - return 
success(); - }; - auto verifyA5 = [&]() -> LogicalResult { - if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isF16() || elem.isF32())) - return emitOpError("expects A5 trems element type to be i32/i16/f16/f32"); - return success(); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TFModSOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - Type scalarTy = getScalar().getType(); - if (failed(verifyTileBufCommon(*this, srcTy, "src")) || - failed(verifyTileBufCommon(*this, dstTy, "dst"))) - return failure(); - if (failed(verifyTileBufSameElemType(*this, srcTy, dstTy, "src", "dst")) || - failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) - return failure(); - if (!isRowMajorTileBuf(srcTy) || !isRowMajorTileBuf(dstTy)) - return emitOpError("expects src and dst to use row-major layout"); - - Type elem = getElemTy(srcTy); - if (scalarTy != elem) - return emitOpError("expects scalar type to match the tile element type"); - - auto verifyA2A3 = [&]() -> LogicalResult { - if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isF16() || elem.isF32())) - return emitOpError("expects A2/A3 tfmods element type to be i32/i16/f16/f32"); - return success(); - }; - auto verifyA5 = [&]() -> LogicalResult { - if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isF16() || elem.isF32())) - return emitOpError("expects A5 tfmods element type to be i32/i16/f16/f32"); - return success(); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - - -static std::optional getStaticNumElements(ArrayRef shape) { - int64_t numel = 1; - for (int64_t d : shape) { - if (d == ShapedType::kDynamic) - return std::nullopt; - if (d < 0) - return std::nullopt; - numel *= d; - } - return numel; -} - -static std::optional getElemBytes(Type elemTy) { - if (!elemTy) - return 
std::nullopt; - if (auto ft = dyn_cast(elemTy)) { - if (ft.isF16() || ft.isBF16()) - return 2; - if (ft.isF32()) - return 4; - if (ft.isF64()) - return 8; - return std::nullopt; - } - if (auto it = dyn_cast(elemTy)) { - int64_t bits = it.getWidth(); - if (bits <= 0) - return std::nullopt; - return std::max(1, bits / 8); - } - return std::nullopt; -} - -[[maybe_unused]] static bool isTileBufOrMemref(Type ty) { - return mlir::isa(ty); -} - -static constexpr llvm::StringLiteral kLoweredSetValidShapeAttrName = - "__pto.lowered_set_validshape"; - -static bool isLocallyBoundTileSource(Value value) { - if (!value || isa(value)) - return false; - - if (isa( - value.getDefiningOp())) - return true; - - if (auto bitcast = value.getDefiningOp()) - return isLocallyBoundTileSource(bitcast.getSrc()); - if (auto reshape = value.getDefiningOp()) - return isLocallyBoundTileSource(reshape.getSrc()); - - return false; -} - -static std::optional getConstIndexLike(Value v) { - if (auto cOp = v.getDefiningOp()) - return cOp.value(); - if (auto cInt = v.getDefiningOp()) - return cInt.value(); - if (auto cOp = v.getDefiningOp()) { - if (auto ia = dyn_cast(cOp.getValue())) - return ia.getInt(); - } - if (auto castOp = v.getDefiningOp()) - return getConstIndexLike(castOp.getIn()); - if (auto extOp = v.getDefiningOp()) - return getConstIndexLike(extOp.getIn()); - if (auto extOp = v.getDefiningOp()) - return getConstIndexLike(extOp.getIn()); - if (auto truncOp = v.getDefiningOp()) - return getConstIndexLike(truncOp.getIn()); - return std::nullopt; -} - -mlir::LogicalResult mlir::pto::SetValidShapeOp::verify() { - SmallVector shape; - if (auto srcTy = llvm::dyn_cast(getSource().getType())) { - if (srcTy.getRank() != 2) - return emitOpError("expects rank-2 tile_buf source"); - - ArrayRef validShape = srcTy.getValidShape(); - if (validShape.size() != 2) - return emitOpError("expects source validShape to be rank-2"); - if (!srcTy.hasDynamicValid()) - return emitOpError("expects source tile_buf to 
have dynamic validShape (?, ?)"); - - shape.assign(srcTy.getShape().begin(), srcTy.getShape().end()); - - if (!isLocallyBoundTileSource(getSource())) - return emitOpError( - "requires a locally bound tile source; function arguments/results " - "are unsupported"); - } else if (auto srcTy = llvm::dyn_cast(getSource().getType())) { - if (!(*this)->hasAttr(kLoweredSetValidShapeAttrName)) - return emitOpError( - "expects tile_buf source; memref source is only valid for the internal lowered form"); - if (srcTy.getRank() != 2) - return emitOpError("expects rank-2 memref source after tile lowering"); - shape.assign(srcTy.getShape().begin(), srcTy.getShape().end()); - } else { - return emitOpError("expects tile_buf source (or lowered memref source)"); - } - - auto checkDim = [&](Value operand, unsigned dimIdx, - StringRef dimName) -> LogicalResult { - int64_t maxStatic = shape[dimIdx]; - - auto constVal = getConstIndexLike(operand); - if (!constVal) - return success(); - - if (*constVal < 0) - return emitOpError() << "expects " << dimName << " operand to be non-negative"; - if (maxStatic != ShapedType::kDynamic && *constVal > maxStatic) - return emitOpError() << "expects " << dimName << " operand <= shape dim (" - << maxStatic << ")"; - return success(); - }; - - if (failed(checkDim(getValidRow(), /*dimIdx=*/0, "row"))) - return failure(); - if (failed(checkDim(getValidCol(), /*dimIdx=*/1, "col"))) - return failure(); - - return success(); -} - -mlir::LogicalResult mlir::pto::GetValidShapeOp::verify() { - if (auto srcTy = llvm::dyn_cast(getSource().getType())) { - if (srcTy.getRank() != 2) - return emitOpError("expects rank-2 tile_buf source"); - if (srcTy.getValidShape().size() != 2) - return emitOpError("expects source validShape to be rank-2"); - return success(); - } - if (auto srcTy = llvm::dyn_cast(getSource().getType())) { - if (srcTy.getRank() != 2) - return emitOpError("expects rank-2 memref source after tile lowering"); - return success(); - } - return 
emitOpError("expects tile_buf source (or lowered memref source)"); -} - - -mlir::LogicalResult mlir::pto::TReshapeOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - Type ts = getSrc().getType(); - Type tr = getResult().getType(); - auto srcTb = dyn_cast(ts); - auto dstTb = dyn_cast(tr); - if (!srcTb || !dstTb) - return emitOpError("expects src/result to be !pto.tile_buf types"); - - if (failed(verifyTileBufCommon(*this, ts, "src")) || - failed(verifyTileBufCommon(*this, tr, "dst"))) - return failure(); - - if (srcTb.getMemorySpace() != dstTb.getMemorySpace()) - return emitOpError("expects src and dst to use the same loc"); - - Type srcElem = srcTb.getElementType(); - Type dstElem = dstTb.getElementType(); - auto srcElemBytes = getElemBytes(srcElem); - auto dstElemBytes = getElemBytes(dstElem); - if (!srcElem || !dstElem || !srcElemBytes.has_value() || !dstElemBytes.has_value()) - return emitOpError("failed to get element byte width for src/dst"); - - auto srcNumel = getStaticNumElements(getShapeVec(ts)); - auto dstNumel = getStaticNumElements(getShapeVec(tr)); - if (!srcNumel.has_value() || !dstNumel.has_value()) - return emitOpError("expects static shapes for treshape"); - - if (srcElemBytes.value() * srcNumel.value() != - dstElemBytes.value() * dstNumel.value()) - return emitOpError("expects src and dst to have the same total byte size"); - - bool srcBoxed = - srcTb.getSLayoutValueI32() != static_cast(pto::SLayout::NoneBox); - bool dstBoxed = - dstTb.getSLayoutValueI32() != static_cast(pto::SLayout::NoneBox); - if (srcBoxed != dstBoxed) - return emitOpError("cannot reshape between boxed and non-boxed tile layouts"); - - return success(); -} - -mlir::LogicalResult mlir::pto::BitcastOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - auto srcTy = llvm::dyn_cast(getSrc().getType()); - auto dstTy = llvm::dyn_cast(getResult().getType()); - if (!srcTy || !dstTy) - return 
emitOpError("expects tile_buf src and tile_buf result"); - - if (srcTy.getMemorySpace() != dstTy.getMemorySpace()) - return emitOpError("expects src/result to have the same memorySpace"); - - if (srcTy.getElementType() == dstTy.getElementType()) - return emitOpError( - "expects src/result to have different element types; use " - "pto.treshape for shape/config changes"); - - if (srcTy.getShape() != dstTy.getShape()) - return emitOpError("expects src/result to have the same shape; use pto.treshape for shape changes"); - - if (srcTy.getValidShape() != dstTy.getValidShape()) - return emitOpError("expects src/result to have the same validShape"); - - auto srcCfg = srcTy.getConfigAttr(); - auto dstCfg = dstTy.getConfigAttr(); - if (srcCfg != dstCfg) - return emitOpError("expects src/result to have the same tile config"); - - auto numel = getStaticNumElements(srcTy.getShape()); - if (!numel.has_value()) - return emitOpError("expects static shapes for bitcast"); - - auto srcBytes = getElemBytes(srcTy.getElementType()); - auto dstBytes = getElemBytes(dstTy.getElementType()); - if (!srcBytes.has_value() || !dstBytes.has_value()) - return emitOpError("unsupported element type for bitcast"); - - int64_t srcTotalBytes = numel.value() * srcBytes.value(); - int64_t dstTotalBytes = numel.value() * dstBytes.value(); - if (dstTotalBytes > srcTotalBytes) - return emitOpError("bitcast result requires more bytes than source storage"); - - return success(); -} - - -mlir::LogicalResult mlir::pto::TRowExpandOp::verify() { - auto verifyCommon = [&]() -> LogicalResult { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyTileBufCommon(*this, srcTy, "src")) || - failed(verifyNDStyleVecTile(*this, dstTy, "dst"))) - return failure(); - auto srcSpace = getPTOMemorySpaceEnum(srcTy); - if (!srcSpace || *srcSpace != pto::AddressSpace::VEC) - return emitOpError("expects src to be in the vec address space"); - if (auto srcTb = dyn_cast(srcTy)) { - if 
(srcTb.getSLayoutValueI32() != static_cast(pto::SLayout::NoneBox)) - return emitOpError("expects src to use the none_box slayout"); - } - if (getElemTy(srcTy) != getElemTy(dstTy)) - return emitOpError("expects src and dst to have the same element type"); - if (!isSupportedVecElemType(getElemTy(srcTy), /*allowBf16=*/true, - /*allowInt8=*/true)) - return emitOpError("expects trowexpand element type to be supported"); - auto srcValid = getValidShapeVec(getSrc()); - auto dstValid = getValidShapeVec(getDst()); - if (srcValid.size() != 2 || dstValid.size() != 2) - return emitOpError("expects src and dst to have rank-2 valid_shape"); - if (srcValid[0] != ShapedType::kDynamic && dstValid[0] != ShapedType::kDynamic && - srcValid[0] != dstValid[0]) - return emitOpError("expects src and dst to have the same valid_shape[0]"); - if (srcValid[0] != ShapedType::kDynamic && srcValid[0] == 0) - return emitOpError("expects src valid_shape[0] to be non-zero"); - if (srcValid[1] != ShapedType::kDynamic && srcValid[1] == 0) - return emitOpError("expects src valid_shape[1] to be non-zero"); - if (dstValid[0] != ShapedType::kDynamic && dstValid[0] == 0) - return emitOpError("expects dst valid_shape[0] to be non-zero"); - if (dstValid[1] != ShapedType::kDynamic && dstValid[1] == 0) - return emitOpError("expects dst valid_shape[1] to be non-zero"); - return success(); - }; - auto verifyA2A3 = [&]() -> LogicalResult { - return verifyCommon(); - }; - auto verifyA5 = [&]() -> LogicalResult { - return verifyCommon(); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - - -ParseResult mlir::pto::TSort32Op::parse(OpAsmParser &parser, OperationState &result) { - OpAsmParser::UnresolvedOperand src, idx, tmp, dst; - Type srcTy, dstTy, idxTy, tmpTy; - bool hasTmp = false; - - if (parser.parseKeyword("ins") || parser.parseLParen() || parser.parseOperand(src)) - return failure(); - if (succeeded(parser.parseOptionalComma())) { - if (parser.parseOperand(idx)) - return 
failure(); - if (succeeded(parser.parseOptionalComma())) { - if (parser.parseOperand(tmp)) - return failure(); - hasTmp = true; - } - } else { - return failure(); - } - if (parser.parseColonType(srcTy) || parser.parseComma() || parser.parseType(idxTy)) - return failure(); - if (hasTmp) { - if (parser.parseComma() || parser.parseType(tmpTy)) - return failure(); - } - if (parser.parseRParen()) - return failure(); - - if (parser.parseKeyword("outs") || parser.parseLParen() || - parser.parseOperand(dst) || parser.parseColonType(dstTy) || - parser.parseRParen()) - return failure(); - if (parser.parseOptionalAttrDict(result.attributes)) - return failure(); - - if (parser.resolveOperand(src, srcTy, result.operands) || - parser.resolveOperand(idx, idxTy, result.operands)) - return failure(); - if (hasTmp) { - if (parser.resolveOperand(tmp, tmpTy, result.operands)) - return failure(); - } - if (parser.resolveOperand(dst, dstTy, result.operands)) - return failure(); - - result.addAttribute( - "operandSegmentSizes", - parser.getBuilder().getDenseI32ArrayAttr({1, 1, hasTmp ? 
1 : 0, 1})); - return success(); -} - -void mlir::pto::TSort32Op::print(OpAsmPrinter &p) { - p << " ins(" << getSrc() << ", " << getIdx(); - if (getTmp()) { - p << ", " << getTmp(); - p << " : " << getSrc().getType() << ", " << getIdx().getType() - << ", " << getTmp().getType() << ")"; - } else { - p << " : " << getSrc().getType() << ", " << getIdx().getType() << ")"; - } - p << " outs(" << getDst() << " : " << getDst().getType() << ")"; - p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"operandSegmentSizes"}); -} - -ParseResult mlir::pto::TRsqrtOp::parse(OpAsmParser &parser, OperationState &result) { - OpAsmParser::UnresolvedOperand src, tmp, dst; - Type srcTy, tmpTy, dstTy; - bool hasTmp = false; - - if (parser.parseKeyword("ins") || parser.parseLParen() || parser.parseOperand(src)) - return failure(); - if (succeeded(parser.parseOptionalComma())) { - if (parser.parseOperand(tmp)) - return failure(); - hasTmp = true; - } - if (parser.parseColonType(srcTy)) - return failure(); - if (hasTmp) { - if (parser.parseComma() || parser.parseType(tmpTy)) - return failure(); - } - if (parser.parseRParen()) - return failure(); - - if (parser.parseKeyword("outs") || parser.parseLParen() || - parser.parseOperand(dst) || parser.parseColonType(dstTy) || - parser.parseRParen()) - return failure(); - if (parser.parseOptionalAttrDict(result.attributes)) - return failure(); - - if (parser.resolveOperand(src, srcTy, result.operands) || - parser.resolveOperand(dst, dstTy, result.operands)) - return failure(); - if (hasTmp && parser.resolveOperand(tmp, tmpTy, result.operands)) - return failure(); - - return success(); -} - -void mlir::pto::TRsqrtOp::print(OpAsmPrinter &p) { - p << " ins(" << getSrc(); - if (getTmp()) - p << ", " << getTmp(); - p << " : " << getSrc().getType(); - if (getTmp()) - p << ", " << getTmp().getType(); - p << ")"; - p << " outs(" << getDst() << " : " << getDst().getType() << ")"; - p.printOptionalAttrDict((*this)->getAttrs()); -} - -static 
ParseResult parseTRowExpandBinaryLikeOp(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::UnresolvedOperand src0, src1, tmp, dst; - Type src0Ty, src1Ty, tmpTy, dstTy; - bool hasTmp = false; - - if (parser.parseKeyword("ins") || parser.parseLParen() || - parser.parseOperand(src0) || parser.parseComma() || parser.parseOperand(src1)) - return failure(); - if (succeeded(parser.parseOptionalComma())) { - if (parser.parseOperand(tmp)) - return failure(); - hasTmp = true; - } - if (parser.parseColon()) - return failure(); - if (parser.parseType(src0Ty) || parser.parseComma() || parser.parseType(src1Ty)) - return failure(); - if (hasTmp) { - if (parser.parseComma() || parser.parseType(tmpTy)) - return failure(); - } - if (parser.parseRParen()) - return failure(); - if (parser.parseKeyword("outs") || parser.parseLParen() || - parser.parseOperand(dst) || parser.parseColonType(dstTy) || - parser.parseRParen()) - return failure(); - if (parser.parseOptionalAttrDict(result.attributes)) - return failure(); - - if (parser.resolveOperand(src0, src0Ty, result.operands) || - parser.resolveOperand(src1, src1Ty, result.operands)) - return failure(); - if (hasTmp) { - if (parser.resolveOperand(tmp, tmpTy, result.operands)) - return failure(); - } - if (parser.resolveOperand(dst, dstTy, result.operands)) - return failure(); - - result.addAttribute( - "operandSegmentSizes", - parser.getBuilder().getDenseI32ArrayAttr({1, 1, hasTmp ? 
1 : 0, 1})); - return success(); -} - -static void printTRowExpandBinaryLikeOp(OpAsmPrinter &p, Operation *op, Value src0, - Value src1, Value tmp, Value dst) { - p << " ins(" << src0 << ", " << src1; - if (tmp) { - p << ", " << tmp; - p << " : " << src0.getType() << ", " << src1.getType() << ", " - << tmp.getType() << ")"; - } else { - p << " : " << src0.getType() << ", " << src1.getType() << ")"; - } - p << " outs(" << dst << " : " << dst.getType() << ")"; - p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"operandSegmentSizes"}); -} - -ParseResult mlir::pto::TRowExpandDivOp::parse(OpAsmParser &parser, OperationState &result) { - return parseTRowExpandBinaryLikeOp(parser, result); -} - -void mlir::pto::TRowExpandDivOp::print(OpAsmPrinter &p) { - printTRowExpandBinaryLikeOp(p, getOperation(), getSrc0(), getSrc1(), getTmp(), - getDst()); -} - -ParseResult mlir::pto::TRowExpandMulOp::parse(OpAsmParser &parser, OperationState &result) { - return parseTRowExpandBinaryLikeOp(parser, result); -} - -void mlir::pto::TRowExpandMulOp::print(OpAsmPrinter &p) { - printTRowExpandBinaryLikeOp(p, getOperation(), getSrc0(), getSrc1(), getTmp(), - getDst()); -} - -ParseResult mlir::pto::TRowExpandSubOp::parse(OpAsmParser &parser, OperationState &result) { - return parseTRowExpandBinaryLikeOp(parser, result); -} - -void mlir::pto::TRowExpandSubOp::print(OpAsmPrinter &p) { - printTRowExpandBinaryLikeOp(p, getOperation(), getSrc0(), getSrc1(), getTmp(), - getDst()); -} - -ParseResult mlir::pto::TRowExpandExpdifOp::parse(OpAsmParser &parser, - OperationState &result) { - return parseTRowExpandBinaryLikeOp(parser, result); -} - -void mlir::pto::TRowExpandExpdifOp::print(OpAsmPrinter &p) { - printTRowExpandBinaryLikeOp(p, getOperation(), getSrc0(), getSrc1(), getTmp(), - getDst()); -} - -ParseResult mlir::pto::TRowExpandMaxOp::parse(OpAsmParser &parser, - OperationState &result) { - return parseTRowExpandBinaryLikeOp(parser, result); -} - -void 
mlir::pto::TRowExpandMaxOp::print(OpAsmPrinter &p) { - printTRowExpandBinaryLikeOp(p, getOperation(), getSrc0(), getSrc1(), getTmp(), - getDst()); -} - -ParseResult mlir::pto::TRowExpandMinOp::parse(OpAsmParser &parser, - OperationState &result) { - return parseTRowExpandBinaryLikeOp(parser, result); -} - -void mlir::pto::TRowExpandMinOp::print(OpAsmPrinter &p) { - printTRowExpandBinaryLikeOp(p, getOperation(), getSrc0(), getSrc1(), getTmp(), - getDst()); -} - -static FailureOr verifyTRowExpandBinaryCore(Operation *op, Type src0Ty, - Type src1Ty, Type dstTy, - Type tmpTy, bool hasTmp) { - if (failed(verifyTileBufCommon(op, src0Ty, "src0")) || - failed(verifyTileBufCommon(op, src1Ty, "src1")) || - failed(verifyTileBufCommon(op, dstTy, "dst"))) - return failure(); - if (hasTmp && failed(verifyTileBufCommon(op, tmpTy, "tmp"))) - return failure(); - if (failed(verifyTileBufSameElemType(op, src0Ty, dstTy, "src0", "dst"))) - return failure(); - if (getElemTy(src0Ty) != getElemTy(src1Ty)) { - op->emitOpError("expects src0 and src1 to have the same element type"); - return failure(); - } - if (!isRowMajorTileBuf(dstTy)) { - op->emitOpError("expects dst to use row-major layout"); - return failure(); - } - return getElemTy(src0Ty); -} - -mlir::LogicalResult mlir::pto::TRowExpandDivOp::verify() { - auto verifyByArch = [&](PTOArch targetArch) -> LogicalResult { - Type src0Ty = getSrc0().getType(); - Type src1Ty = getSrc1().getType(); - Type dstTy = getDst().getType(); - FailureOr elemOr = verifyTRowExpandBinaryCore( - *this, src0Ty, src1Ty, dstTy, getTmp() ? 
getTmp().getType() : Type{}, - static_cast(getTmp())); - if (failed(elemOr)) - return failure(); - Type elem = *elemOr; - bool supported = - elem.isF16() || elem.isF32() || - (targetArch == PTOArch::A5 && - (elem.isInteger(8) || elem.isInteger(16) || elem.isInteger(32))); - if (!supported) { - if (targetArch == PTOArch::A5) - return emitOpError( - "expects A5 trowexpanddiv element type to be i8/i16/i32/f16/f32"); - return emitOpError("expects element type to be f16 or f32"); - } - return mlir::success(); - }; - auto verifyA2A3 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A3); }; - auto verifyA5 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A5); }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - - -mlir::LogicalResult mlir::pto::TRowExpandMulOp::verify() { - auto verifyByArch = [&](PTOArch targetArch) -> LogicalResult { - Type src0Ty = getSrc0().getType(); - Type src1Ty = getSrc1().getType(); - Type dstTy = getDst().getType(); - FailureOr elemOr = verifyTRowExpandBinaryCore( - *this, src0Ty, src1Ty, dstTy, getTmp() ? 
getTmp().getType() : Type{}, - static_cast(getTmp())); - if (failed(elemOr)) - return failure(); - Type elem = *elemOr; - bool supported = elem.isF16() || elem.isF32() || elem.isInteger(16) || - elem.isInteger(32) || - (targetArch == PTOArch::A5 && elem.isInteger(8)); - if (!supported) { - if (targetArch == PTOArch::A5) - return emitOpError( - "expects A5 trowexpandmul element type to be i8/i16/i32/f16/f32"); - return emitOpError( - "expects A2/A3 trowexpandmul element type to be i16/i32/f16/f32"); - } - return mlir::success(); - }; - auto verifyA2A3 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A3); }; - auto verifyA5 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A5); }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - - -mlir::LogicalResult mlir::pto::TRowExpandSubOp::verify() { - auto verifyByArch = [&](PTOArch targetArch) -> LogicalResult { - Type src0Ty = getSrc0().getType(); - Type src1Ty = getSrc1().getType(); - Type dstTy = getDst().getType(); - FailureOr elemOr = verifyTRowExpandBinaryCore( - *this, src0Ty, src1Ty, dstTy, getTmp() ? 
getTmp().getType() : Type{}, - static_cast(getTmp())); - if (failed(elemOr)) - return failure(); - Type elem = *elemOr; - bool supported = elem.isF16() || elem.isF32() || elem.isInteger(16) || - elem.isInteger(32) || - (targetArch == PTOArch::A5 && elem.isInteger(8)); - if (!supported) { - if (targetArch == PTOArch::A5) - return emitOpError( - "expects A5 trowexpandsub element type to be i8/i16/i32/f16/f32"); - return emitOpError( - "expects A2/A3 trowexpandsub element type to be i16/i32/f16/f32"); - } - return mlir::success(); - }; - auto verifyA2A3 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A3); }; - auto verifyA5 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A5); }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TRowExpandAddOp::verify() { - auto verifyByArch = [&](PTOArch targetArch) -> LogicalResult { - Type src0Ty = getSrc0().getType(); - Type src1Ty = getSrc1().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyTileBufCommon(*this, src0Ty, "src0")) || - failed(verifyTileBufCommon(*this, src1Ty, "src1")) || - failed(verifyTileBufCommon(*this, dstTy, "dst"))) - return failure(); - if (failed(verifyTileBufSameElemType(*this, src0Ty, dstTy, "src0", "dst"))) - return failure(); - if (failed(verifyTileBufSameValidShape(*this, src0Ty, dstTy, "src0", "dst"))) - return failure(); - if (getElemTy(src0Ty) != getElemTy(src1Ty)) - return emitOpError("expects src0 and src1 to have the same element type"); - if (!isRowMajorTileBuf(src0Ty)) - return emitOpError("expects src0 to use row-major layout"); - if (!isRowMajorTileBuf(dstTy)) - return emitOpError("expects dst to use row-major layout"); - Type elem = getElemTy(src0Ty); - bool supported = elem.isF16() || elem.isF32() || elem.isInteger(16) || - elem.isInteger(32) || - (targetArch == PTOArch::A5 && elem.isInteger(8)); - if (!supported) { - if (targetArch == PTOArch::A5) - return emitOpError( - "expects A5 trowexpandadd 
element type to be i8/i16/i32/f16/f32"); - return emitOpError( - "expects A2/A3 trowexpandadd element type to be i16/i32/f16/f32"); - } - auto src1Valid = getValidShapeVec(src1Ty); - auto dstValid = getValidShapeVec(dstTy); - if (src1Valid.size() != 2 || dstValid.size() != 2) - return emitOpError("expects src1 and dst to have rank-2 valid_shape"); - if (src1Valid[0] != ShapedType::kDynamic && dstValid[0] != ShapedType::kDynamic && - src1Valid[0] != dstValid[0]) - return emitOpError("expects src1 valid_shape[0] to equal dst valid_shape[0]"); - bool src1IsRowMajor = isRowMajorTileBuf(src1Ty); - int64_t expectedCol = elem.isInteger(8) - ? 32 - : ((elem.isF16() || elem.isInteger(16)) ? 16 : 8); - int64_t src1Col = src1Valid[1]; - if (src1IsRowMajor) { - if (src1Col != ShapedType::kDynamic && src1Col != expectedCol) - return emitOpError("expects row-major src1 valid_shape[1] to be 32/sizeof(dtype)"); - } else { - if (src1Col != ShapedType::kDynamic && src1Col != 1) - return emitOpError("expects non-row-major src1 valid_shape[1] to be 1"); - } - return mlir::success(); - }; - auto verifyA2A3 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A3); }; - auto verifyA5 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A5); }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -static LogicalResult verifyTRowExpandReduceLikeOp(Operation *op, Type src0Ty, - Type src1Ty, Type dstTy, - Type tmpTy, bool hasTmp, - PTOArch targetArch, - StringRef opName, - bool allowIntegerTypes) { - if (failed(verifyTileBufCommon(op, src0Ty, "src0")) || - failed(verifyTileBufCommon(op, src1Ty, "src1")) || - failed(verifyTileBufCommon(op, dstTy, "dst"))) - return failure(); - if (hasTmp) { - if (failed(verifyTileBufCommon(op, tmpTy, "tmp"))) - return failure(); - if (getElemTy(tmpTy) != getElemTy(dstTy)) - return op->emitOpError() << "expects tmp and dst to have the same element type"; - } - - Type elem = getElemTy(dstTy); - if (!elem || getElemTy(src0Ty) != elem 
|| getElemTy(src1Ty) != elem) - return op->emitOpError("expects src0, src1, and dst to have the same element type"); - bool supported = elem.isF16() || elem.isF32() || - (allowIntegerTypes && - (elem.isInteger(16) || elem.isInteger(32) || - (targetArch == PTOArch::A5 && elem.isInteger(8)))); - if (!supported) { - if (!allowIntegerTypes) - return op->emitOpError() << "expects " << opName - << " element type to be f16 or f32"; - if (targetArch == PTOArch::A5) - return op->emitOpError() << "expects A5 " << opName - << " element type to be i8/i16/i32/f16/f32"; - return op->emitOpError() << "expects A2/A3 " << opName - << " element type to be i16/i32/f16/f32"; - } - - if (!isRowMajorTileBuf(dstTy)) - return op->emitOpError("expects dst to use row-major layout"); - - auto src0Valid = getValidShapeVec(src0Ty); - auto src1Valid = getValidShapeVec(src1Ty); - auto dstValid = getValidShapeVec(dstTy); - if (src0Valid.size() != 2 || src1Valid.size() != 2 || dstValid.size() != 2) - return op->emitOpError("expects src0, src1, and dst to have rank-2 valid_shape"); - - if (dstValid[0] != ShapedType::kDynamic && dstValid[0] == 0) - return op->emitOpError("expects dst valid_shape[0] to be non-zero"); - if (dstValid[1] != ShapedType::kDynamic && dstValid[1] == 0) - return op->emitOpError("expects dst valid_shape[1] to be non-zero"); - - auto validShapeMatches = [](ArrayRef lhs, - ArrayRef rhs) -> bool { - if (lhs.size() != rhs.size()) - return false; - for (auto [l, r] : llvm::zip(lhs, rhs)) { - if (l != ShapedType::kDynamic && r != ShapedType::kDynamic && l != r) - return false; - } - return true; - }; - - const bool src0MatchesDst = validShapeMatches(src0Valid, dstValid); - const bool src1MatchesDst = validShapeMatches(src1Valid, dstValid); - - auto checkBroadcastOperand = [&](Type operandTy, ArrayRef operandValid, - StringRef operandName, - bool requireNonRowMajor) -> LogicalResult { - if (operandValid[0] != ShapedType::kDynamic && dstValid[0] != ShapedType::kDynamic && - 
operandValid[0] != dstValid[0]) { - return op->emitOpError() << "expects " << operandName - << " valid_shape[0] to equal dst valid_shape[0]"; - } - int64_t expectedCol = elem.isInteger(8) ? 32 : ((elem.isF16() || elem.isInteger(16)) ? 16 : 8); - int64_t operandCol = operandValid[1]; - bool operandIsRowMajor = isRowMajorTileBuf(operandTy); - if (requireNonRowMajor && operandIsRowMajor) { - return op->emitOpError() << "expects " << operandName - << " to use a non-row-major layout when tmp is present"; - } - if (operandIsRowMajor) { - if (operandCol != ShapedType::kDynamic && operandCol != expectedCol) { - return op->emitOpError() - << "expects row-major " << operandName - << " valid_shape[1] to be 32/sizeof(dtype)"; - } - return success(); - } - if (operandCol != ShapedType::kDynamic && operandCol != 1) { - return op->emitOpError() << "expects non-row-major " << operandName - << " valid_shape[1] to be 1"; - } - return success(); - }; - - auto checkFullAndBroadcast = [&](Type fullTy, ArrayRef fullValid, - StringRef fullName, Type broadcastTy, - ArrayRef broadcastValid, - StringRef broadcastName) -> LogicalResult { - if (!isRowMajorTileBuf(fullTy)) - return op->emitOpError() << "expects " << fullName - << " to use row-major layout when it matches dst"; - if (fullValid[0] != ShapedType::kDynamic && dstValid[0] != ShapedType::kDynamic && - fullValid[0] != dstValid[0]) - return op->emitOpError() << "expects " << fullName - << " valid_shape[0] to equal dst valid_shape[0]"; - if (fullValid[1] != ShapedType::kDynamic && dstValid[1] != ShapedType::kDynamic && - fullValid[1] != dstValid[1]) - return op->emitOpError() << "expects " << fullName - << " valid_shape[1] to equal dst valid_shape[1]"; - return checkBroadcastOperand(broadcastTy, broadcastValid, broadcastName, - /*requireNonRowMajor=*/hasTmp && - targetArch == PTOArch::A3); - }; - - if (hasTmp && targetArch == PTOArch::A5) - return op->emitOpError("expects A5 form to omit tmp"); - - if (src0MatchesDst) { - if 
(succeeded(checkFullAndBroadcast(src0Ty, src0Valid, "src0", src1Ty, - src1Valid, "src1"))) - return success(); - } - if (src1MatchesDst) { - if (succeeded(checkFullAndBroadcast(src1Ty, src1Valid, "src1", src0Ty, - src0Valid, "src0"))) - return success(); - } - - return op->emitOpError() << "expects one of src0/src1 to match dst valid_shape" - << " and the other to be a per-row scalar vector"; -} - -mlir::LogicalResult mlir::pto::TRowExpandExpdifOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - return verifyTRowExpandReduceLikeOp(getOperation(), getSrc0().getType(), - getSrc1().getType(), getDst().getType(), - getTmp() ? getTmp().getType() : Type{}, - (bool)getTmp(), PTOArch::A3, - "trowexpandexpdif", - /*allowIntegerTypes=*/false); - }; - auto verifyA5 = [&]() -> LogicalResult { - return verifyTRowExpandReduceLikeOp(getOperation(), getSrc0().getType(), - getSrc1().getType(), getDst().getType(), - getTmp() ? getTmp().getType() : Type{}, - (bool)getTmp(), PTOArch::A5, - "trowexpandexpdif", - /*allowIntegerTypes=*/false); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TRowExpandMaxOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - return verifyTRowExpandReduceLikeOp(getOperation(), getSrc0().getType(), - getSrc1().getType(), getDst().getType(), - getTmp() ? getTmp().getType() : Type{}, - (bool)getTmp(), PTOArch::A3, - "trowexpandmax", - /*allowIntegerTypes=*/true); - }; - auto verifyA5 = [&]() -> LogicalResult { - return verifyTRowExpandReduceLikeOp(getOperation(), getSrc0().getType(), - getSrc1().getType(), getDst().getType(), - getTmp() ? 
getTmp().getType() : Type{}, - (bool)getTmp(), PTOArch::A5, - "trowexpandmax", - /*allowIntegerTypes=*/true); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TRowExpandMinOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - return verifyTRowExpandReduceLikeOp(getOperation(), getSrc0().getType(), - getSrc1().getType(), getDst().getType(), - getTmp() ? getTmp().getType() : Type{}, - (bool)getTmp(), PTOArch::A3, - "trowexpandmin", - /*allowIntegerTypes=*/true); - }; - auto verifyA5 = [&]() -> LogicalResult { - return verifyTRowExpandReduceLikeOp(getOperation(), getSrc0().getType(), - getSrc1().getType(), getDst().getType(), - getTmp() ? getTmp().getType() : Type{}, - (bool)getTmp(), PTOArch::A5, - "trowexpandmin", - /*allowIntegerTypes=*/true); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - - -mlir::LogicalResult mlir::pto::TRowMaxOp::verify() { - auto verifyByArch = [&]() -> LogicalResult { - return verifyTRowReductionNoTmpCommon(*this, getSrc().getType(), - getDst().getType(), - "expects element type to be i16/i32/f16/f32"); - }; - return dispatchVerifierByArch(getOperation(), verifyByArch, verifyByArch); -} - -mlir::LogicalResult mlir::pto::TRowArgMaxOp::verify() { - auto verifyByArch = [&]() -> LogicalResult { - return verifyTRowArgReductionCommon(*this, getSrc().getType(), - getTmp().getType(), getDst().getType()); - }; - - return dispatchVerifierByArch(getOperation(), verifyByArch, verifyByArch); -} - - -mlir::LogicalResult mlir::pto::TRowMinOp::verify() { - auto verifyByArch = [&]() -> LogicalResult { - return verifyTRowReductionWithTmpCommon( - *this, getSrc().getType(), getTmp().getType(), getDst().getType(), - "expects element type to be i16/i32/f16/f32"); - }; - return dispatchVerifierByArch(getOperation(), verifyByArch, verifyByArch); -} - -mlir::LogicalResult mlir::pto::TRowArgMinOp::verify() { - auto verifyByArch = [&]() -> LogicalResult { - return 
verifyTRowArgReductionCommon(*this, getSrc().getType(), - getTmp().getType(), getDst().getType()); - }; - - return dispatchVerifierByArch(getOperation(), verifyByArch, verifyByArch); -} - - -mlir::LogicalResult mlir::pto::TRowSumOp::verify() { - auto verifyByArch = [&]() -> LogicalResult { - return verifyTRowReductionNoTmpCommon(*this, getSrc().getType(), - getDst().getType(), - "expects element type to be i16/i32/f16/f32"); - }; - return dispatchVerifierByArch(getOperation(), verifyByArch, verifyByArch); -} - -mlir::LogicalResult mlir::pto::TRowProdOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - return verifyTRowReductionWithTmpCommon( - *this, getSrc().getType(), getTmp().getType(), getDst().getType(), - "expects A2/A3 trowprod element type to be i16/i32/f16/f32"); - }; - auto verifyA5 = [&]() -> LogicalResult { - return verifyTRowReductionWithTmpCommon( - *this, getSrc().getType(), getTmp().getType(), getDst().getType(), - "expects A5 trowprod element type to be i16/i32/f16/f32"); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - - -mlir::LogicalResult mlir::pto::TRsqrtOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - Type ts = getSrc().getType(); - Type td = getDst().getType(); - if (failed(verifyVecTileUnaryOp(*this, ts, td, "src", "dst", - /*allowBf16=*/false, /*allowInt8=*/false))) - return failure(); - if (failed(verifyTileBufSameValidShape(*this, ts, td, "src", "dst"))) - return failure(); - auto ft = mlir::dyn_cast(getElemTy(ts)); - if (!ft || (!ft.isF16() && !ft.isF32())) - return emitOpError("expects element type to be f16 or f32"); - if (auto tmp = getTmp()) { - Type tt = tmp.getType(); - if (failed(verifyVecTileCommon(*this, tt, "tmp"))) - return failure(); - - auto tmpElemTy = getElemTy(tt); - auto tmpElemBytes = getElemBytes(tmpElemTy); - auto tmpNumel = getStaticNumElements(getShapeVec(tt)); - if (!tmpElemBytes.has_value() || !tmpNumel.has_value()) - return 
emitOpError("expects tmp to have a static, byte-addressable tile type"); - if (tmpElemBytes.value() * tmpNumel.value() < 32) - return emitOpError("expects tmp to be at least 32 bytes when provided"); - } - return mlir::success(); -} - - -mlir::LogicalResult mlir::pto::TScatterOp::verify() { - const bool hasIndexes = static_cast(getIndexes()); - const bool hasMaskPattern = static_cast(getMaskPatternAttr()); - if (hasIndexes == hasMaskPattern) { - return emitOpError( - "expects exactly one of indexes operand or maskPattern attribute"); - } - - auto isAllowedDataElem = [&](mlir::Type t) -> bool { - if (t.isF16() || t.isF32() || t.isBF16()) return true; - if (auto it = mlir::dyn_cast(t)) - return (it.getWidth() == 8 || it.getWidth() == 16 || it.getWidth() == 32); - return false; - }; - auto isAllowedIndexElem = [&](mlir::Type t) -> bool { - if (auto it = mlir::dyn_cast(t)) - return (it.getWidth() == 16 || it.getWidth() == 32); - return false; - }; - auto getMaskScatterTimes = [&](mlir::pto::MaskPatternAttr mp) -> unsigned { - switch (mp.getValue()) { - case mlir::pto::MaskPattern::P1111: - return 1; - case mlir::pto::MaskPattern::P0101: - case mlir::pto::MaskPattern::P1010: - return 2; - default: - return 4; - } - }; - - auto verifyIndexedForm = [&]() -> LogicalResult { - Type ts = getSrc().getType(); - Type ti = getIndexes().getType(); - Type td = getDst().getType(); - if (failed(verifyVecTileStorage(*this, ts, "src")) || - failed(verifyVecTileStorage(*this, ti, "indexes")) || - failed(verifyVecTileStorage(*this, td, "dst"))) - return failure(); - - Type srcElem = getElemTy(ts), dstElem = getElemTy(td), idxElem = getElemTy(ti); - if (!srcElem || !dstElem || !idxElem) - return emitOpError("failed to get element type for operands"); - if (srcElem != dstElem) - return emitOpError("expects src/dst to have the same element type"); - - if (!isAllowedDataElem(srcElem)) - return emitOpError("expects src/dst element type to be i8/i16/i32/f16/bf16/f32"); - if 
(!isAllowedIndexElem(idxElem)) - return emitOpError("expects indexes element type to be i16/i32"); - - auto bwData = getPTOStorageElemBitWidth(srcElem); - auto bwIdx = getPTOStorageElemBitWidth(idxElem); - if (bwData != 8 && bwData != 16 && bwData != 32) - return emitOpError("unexpected src/dst element bitwidth"); - - unsigned dataBytes = bwData / 8; - unsigned idxBytes = bwIdx / 8; - unsigned expectedIdxBytes = (dataBytes == 1) ? 2 : dataBytes; - if (idxBytes != expectedIdxBytes) - return emitOpError("expects indexes element size to match the documented scatter rule"); - return mlir::success(); - }; - - auto verifyMaskForm = [&]() -> LogicalResult { - Type ts = getSrc().getType(); - Type td = getDst().getType(); - if (failed(verifyVecTileCommon(*this, ts, "src")) || - failed(verifyVecTileCommon(*this, td, "dst"))) - return failure(); - - auto srcTB = dyn_cast(ts); - auto dstTB = dyn_cast(td); - if (!srcTB || !dstTB) - return emitOpError("expects src and dst to be tile_buf types"); - - if (getElemTy(ts) != getElemTy(td)) - return emitOpError("expects src and dst to have the same element type"); - if (!isAllowedDataElem(getElemTy(ts))) - return emitOpError("expects src/dst element type to be i8/i16/i32/f16/bf16/f32"); - - auto srcValid = getValidShapeVec(ts); - auto dstValid = getValidShapeVec(td); - if (srcValid.size() != 2 || dstValid.size() != 2) - return emitOpError("expects src and dst to have rank-2 valid_shape"); - - auto mp = getMaskPatternAttr(); - if (!mp) - return emitOpError("expects mask-pattern tscatter to provide maskPattern"); - const unsigned times = getMaskScatterTimes(mp); - if (srcValid[0] != ShapedType::kDynamic && dstValid[0] != ShapedType::kDynamic && - srcValid[0] != dstValid[0]) - return emitOpError("expects src and dst to have the same valid rows"); - if (srcValid[1] != ShapedType::kDynamic && dstValid[1] != ShapedType::kDynamic && - srcValid[1] != static_cast(dstValid[1] * times)) - return emitOpError("expects src valid cols to equal dst 
valid cols times the mask expansion factor"); - - if (srcTB.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor) || - dstTB.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor)) - return emitOpError("expects mask-pattern tscatter to use row_major blayout"); - return mlir::success(); - }; - - auto verifyA2A3 = [&]() -> LogicalResult { - if (hasMaskPattern) - return verifyMaskForm(); - return verifyIndexedForm(); - }; - auto verifyA5 = [&]() -> LogicalResult { - if (hasMaskPattern) - return emitOpError("mask-pattern tscatter is not supported on A5 yet"); - return verifyIndexedForm(); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - - -mlir::LogicalResult mlir::pto::TSelOp::verify() { - auto verifyCommon = [&]() -> FailureOr { - Type t0 = getSrc0().getType(); - Type t1 = getSrc1().getType(); - Type td = getDst().getType(); - if (failed(verifyTileBufCommon(*this, t0, "src0")) || - failed(verifyTileBufCommon(*this, t1, "src1")) || - failed(verifyTileBufCommon(*this, td, "dst"))) - return failure(); - - Type srcElem = getElemTy(t0); - Type src1Elem = getElemTy(t1); - Type dstElem = getElemTy(td); - if (!srcElem || !src1Elem || !dstElem) { - emitOpError("failed to get element type for operands"); - return failure(); - } - if (srcElem != src1Elem || srcElem != dstElem) { - emitOpError("expects src0, src1, and dst to have the same element type"); - return failure(); - } - - if (!isRowMajorTileBuf(t0) || !isRowMajorTileBuf(t1) || - !isRowMajorTileBuf(td)) { - emitOpError( - "expects src0, src1, and dst to use row-major layout"); - return failure(); - } - return srcElem; - }; - - auto verifyA2A3 = [&]() -> LogicalResult { - FailureOr srcElem = verifyCommon(); - if (failed(srcElem)) - return failure(); - Type elem = *srcElem; - bool ok = elem.isF16() || elem.isBF16() || elem.isF32(); - if (auto it = dyn_cast(elem)) - ok = it.getWidth() == 16 || it.getWidth() == 32; - if (!ok) - return emitOpError( - "expects A2/A3 tsel src0, src1, 
and dst element type to be i16/i32/f16/bf16/f32"); - return success(); - }; - - auto verifyA5 = [&]() -> LogicalResult { - FailureOr srcElem = verifyCommon(); - if (failed(srcElem)) - return failure(); - Type elem = *srcElem; - bool ok = elem.isF16() || elem.isBF16() || elem.isF32(); - if (auto it = dyn_cast(elem)) - ok = it.getWidth() == 8 || it.getWidth() == 16 || it.getWidth() == 32; - if (!ok) - return emitOpError( - "expects A5 tsel src0, src1, and dst element type to be i8/i16/i32/f16/bf16/f32"); - return success(); - }; - - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - - -mlir::LogicalResult mlir::pto::TSelSOp::verify() { - // Constraints & Verification per PTO_IR_manual.md pto.tsels: - // - src and dst same element type; A2A3: i16/i32/f16/f32; A5: i8/i16/i32/f16/f32 - // - src and dst row-major; src and dst same valid region - auto verifyCommon = [&]() -> FailureOr { - Type tMask = getMask().getType(); - Type tSrc = getSrc().getType(); - Type tTmp = getTmp().getType(); - Type tDst = getDst().getType(); - if (failed(verifyTileBufCommon(*this, tMask, "mask")) || - failed(verifyTileBufCommon(*this, tSrc, "src")) || - failed(verifyTileBufCommon(*this, tTmp, "tmp")) || - failed(verifyTileBufCommon(*this, tDst, "dst"))) - return failure(); - Type eMask = getElemTy(tMask), eSrc = getElemTy(tSrc); - Type eTmp = getElemTy(tTmp), eDst = getElemTy(tDst); - if (!eMask || !eSrc || !eTmp || !eDst) { - emitOpError("failed to get element type for operands"); - return failure(); - } - if (eSrc != eDst) - return emitOpError("expects src and dst to have the same element type"); - if (failed(verifyTileBufSameValidShape(*this, tSrc, tDst, "src", "dst"))) - return failure(); - return eDst; - }; - - auto verifyA2A3 = [&]() -> LogicalResult { - FailureOr elemOr = verifyCommon(); - if (failed(elemOr)) - return failure(); - Type tSrc = getSrc().getType(); - Type tDst = getDst().getType(); - if (!isRowMajorTileBuf(tSrc) || !isRowMajorTileBuf(tDst)) - 
return emitOpError("expects src and dst to use row-major layout"); - Type elem = *elemOr; - bool ok = elem.isF16() || elem.isF32(); - if (auto it = mlir::dyn_cast(elem)) - ok = (it.getWidth() == 16 || it.getWidth() == 32); - if (!ok) - return emitOpError( - "expects A2/A3 tsels src and dst element type to be i16, i32, f16, or f32"); - return success(); - }; - - auto verifyA5 = [&]() -> LogicalResult { - FailureOr elemOr = verifyCommon(); - if (failed(elemOr)) - return failure(); - Type tSrc = getSrc().getType(); - Type tDst = getDst().getType(); - if (!isRowMajorTileBuf(tSrc) || !isRowMajorTileBuf(tDst)) - return emitOpError("expects src and dst to use row-major layout"); - Type elem = *elemOr; - bool ok = elem.isF16() || elem.isF32(); - if (auto it = mlir::dyn_cast(elem)) - ok = (it.getWidth() == 8 || it.getWidth() == 16 || it.getWidth() == 32); - if (!ok) - return emitOpError( - "expects A5 tsels src and dst element type to be i8, i16, i32, f16, or f32"); - return success(); - }; - - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - - -mlir::LogicalResult mlir::pto::TShlOp::verify() { - auto verify = [&]() -> LogicalResult { - FailureOr elemOr = verifyShiftLikeBinaryTileOpCommon( - *this, getSrc0().getType(), getSrc1().getType(), getDst().getType()); - if (failed(elemOr)) - return failure(); - auto it = mlir::dyn_cast(*elemOr); - if (!it || (it.getWidth() != 8 && it.getWidth() != 16 && - it.getWidth() != 32)) - return emitOpError( - "expects tshl src0 and src1 element type to be i8/i16/i32"); - return success(); - }; - - return dispatchVerifierByArch(getOperation(), verify, verify); -} - - -mlir::LogicalResult mlir::pto::TShrOp::verify() { - auto verify = [&]() -> LogicalResult { - FailureOr elemOr = verifyShiftLikeBinaryTileOpCommon( - *this, getSrc0().getType(), getSrc1().getType(), getDst().getType()); - if (failed(elemOr)) - return failure(); - auto it = mlir::dyn_cast(*elemOr); - if (!it || (it.getWidth() != 8 && it.getWidth() != 16 
&& - it.getWidth() != 32)) - return emitOpError( - "expects tshr src0 and src1 element type to be i8/i16/i32"); - return success(); - }; - - return dispatchVerifierByArch(getOperation(), verify, verify); -} - - -mlir::LogicalResult mlir::pto::TSort32Op::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - Type idxTy = getIdx().getType(); - if (failed(verifyVecTileCommon(*this, srcTy, "src")) || - failed(verifyVecTileCommon(*this, dstTy, "dst")) || - failed(verifyVecTileCommon(*this, idxTy, "idx"))) - return failure(); - if (getTmp() && - failed(verifyVecTileCommon(*this, getTmp().getType(), "tmp"))) - return failure(); - - auto srcElem = getElemTy(srcTy); - auto dstElem = getElemTy(dstTy); - if (!srcElem || !dstElem || srcElem != dstElem) - return emitOpError() << "expects src and dst to have the same element type"; - if (!(srcElem.isF16() || srcElem.isF32())) - return emitOpError() << "expects src and dst element type to be f16 or f32"; - - auto idxElem = getElemTy(idxTy); - auto idxInt = dyn_cast(idxElem); - if (!idxInt || idxInt.getWidth() != 32) - return emitOpError() << "expects idx element type to be i32/u32"; - return mlir::success(); -} - - -mlir::LogicalResult mlir::pto::TSqrtOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyVecTileUnaryOp(*this, srcTy, dstTy, "src", "dst", - /*allowBf16=*/false, /*allowInt8=*/false))) - return failure(); - if (failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) - return failure(); - - auto srcElem = getElemTy(srcTy); - if (!(mlir::isa(srcElem) || mlir::isa(srcElem))) - return emitOpError() << "expects src and dst element type to be float or half"; - - return mlir::success(); -} - - - -mlir::LogicalResult mlir::pto::TStoreFPOp::verify() { - auto shouldBypassDecoded = 
[&]() -> bool { - Value src = getSrc(); - Value fp = getFp(); - return isa(src.getType()) || isa(fp.getType()) || - src.getDefiningOp() || - fp.getDefiningOp(); - }; - - auto verifyDstType = [&]() -> LogicalResult { - Type dstTy = getDst().getType(); - if (!isa(dstTy)) - return emitOpError() - << "expects dst to be a memref or !pto.partition_tensor_view"; - if (auto dstPart = dyn_cast(dstTy)) { - for (auto [idx, dim] : llvm::enumerate(dstPart.getShape())) { - if (dim != ShapedType::kDynamic && dim <= 0) - return emitOpError() - << "expects dst shape[" << idx << "] to be positive"; - } - } - return success(); - }; - - auto verifyA2A3 = [&]() -> LogicalResult { - Type srcTy = getSrc().getType(); - Type fpTy = getFp().getType(); - if (!isa(srcTy)) - return emitOpError() << "expects src to be a !pto.tile_buf"; - if (!isa(fpTy)) - return emitOpError() << "expects fp to be a !pto.tile_buf"; - if (failed(verifyTileBufCommon(*this, srcTy, "src")) || - failed(verifyTileBufCommon(*this, fpTy, "fp"))) - return failure(); - if (failed(verifyDstType())) - return failure(); - auto srcSpace = getPTOMemorySpaceEnum(srcTy); - if (!srcSpace || *srcSpace != pto::AddressSpace::ACC) - return emitOpError() << "expects src to be in the acc address space"; - auto srcElemTy = getElemTy(srcTy); - auto srcIntTy = dyn_cast(srcElemTy); - if (!(srcElemTy.isF32() || - (srcIntTy && srcIntTy.getWidth() == 32))) - return emitOpError() - << "expects src to have element type f32, i32"; - auto srcShape = getShapeVec(srcTy); - if (srcShape.size() != 2) - return emitOpError() << "expects src to have rank 2"; - if (srcShape[1] != ShapedType::kDynamic && - (srcShape[1] < 1 || srcShape[1] > 4095)) - return emitOpError() << "expects src.cols to be in the range [1, 4095]"; - auto srcValid = getValidShapeVec(srcTy); - if (srcValid.size() != 2) - return emitOpError() << "expects src to have a rank-2 valid_shape"; - if (srcValid[1] != ShapedType::kDynamic && - (srcValid[1] < 1 || srcValid[1] > 4095)) - return 
emitOpError() - << "expects src.valid_shape[1] to be in the range [1, 4095]"; - return mlir::success(); - }; - auto verifyA5 = [&]() -> LogicalResult { - Type srcTy = getSrc().getType(); - Type fpTy = getFp().getType(); - if (!isa(srcTy)) - return emitOpError() << "expects src to be a !pto.tile_buf"; - if (!isa(fpTy)) - return emitOpError() << "expects fp to be a !pto.tile_buf"; - if (failed(verifyTileBufCommon(*this, srcTy, "src")) || - failed(verifyTileBufCommon(*this, fpTy, "fp"))) - return failure(); - if (failed(verifyDstType())) - return failure(); - auto srcSpace = getPTOMemorySpaceEnum(srcTy); - if (!srcSpace || *srcSpace != pto::AddressSpace::ACC) - return emitOpError() << "expects src to be in the acc address space"; - return mlir::success(); - }; - if (shouldBypassDecoded()) - return success(); - switch (getVerifierTargetArch(getOperation())) { - case VerifierTargetArch::A2A3: - return verifyA2A3(); - case VerifierTargetArch::A5: - return verifyA5(); - } - return failure(); -} - - -mlir::LogicalResult mlir::pto::TSubOp::verify() { - return verifyArithmeticBinaryTileOpWithArchDispatch( - getOperation(), getSrc0().getType(), getSrc1().getType(), getDst().getType(), - /*allowInt8OnA5=*/true, /*allowBf16OnA5=*/false, - "expects A2/A3 tsub element type to be i32/i16/f16/f32", - "expects A5 tsub element type to be i32/i16/i8/f16/f32"); -} - - -mlir::LogicalResult mlir::pto::TSubCOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - Type src0Ty = getSrc0().getType(); - Type src1Ty = getSrc1().getType(); - Type src2Ty = getSrc2().getType(); - Type dstTy = getDst().getType(); - if (!isPTOShapedLike(src0Ty) || !isPTOShapedLike(src1Ty) || !isPTOShapedLike(src2Ty) || !isPTOShapedLike(dstTy)) - return emitOpError() << "expects PTO shaped-like src0, src1, src2, and dst"; - - auto d = getShapeVec(dstTy); - if (getShapeVec(src0Ty).size() != d.size() || getShapeVec(src1Ty).size() != d.size() || getShapeVec(src2Ty).size() != 
d.size()) - return emitOpError() << "expects all tensors to have the same rank"; - return mlir::success(); -} - - -mlir::LogicalResult mlir::pto::TSubSOp::verify() { - return verifyArithmeticScalarTileOpWithArchDispatch( - getOperation(), getSrc().getType(), getDst().getType(), getScalar().getType(), - /*allowInt8OnA5=*/true, /*allowBf16OnA5=*/true, - "expects A2/A3 tsubs element type to be i32/i16/f16/f32", - "expects A5 tsubs element type to be i32/i16/i8/f16/bf16/f32", - /*requireValidRowsEqualOnA2A3=*/true, - /*requireValidRowsEqualOnA5=*/false); -} - - -mlir::LogicalResult mlir::pto::TSubSCOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - Type src0Ty = getSrc0().getType(); - Type src1Ty = getSrc1().getType(); - Type dstTy = getDst().getType(); - if (!isPTOShapedLike(src0Ty) || !isPTOShapedLike(src1Ty) || !isPTOShapedLike(dstTy)) - return emitOpError() << "expects PTO shaped-like src0, src1, and dst"; - - auto d = getShapeVec(dstTy); - if (getShapeVec(src0Ty).size() != d.size() || getShapeVec(src1Ty).size() != d.size()) - return emitOpError() << "expects src0, src1, and dst to have the same rank"; - return mlir::success(); -} -mlir::LogicalResult mlir::pto::TTransOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - Type srcTy = getSrc().getType(); - Type tmpTy = getTmp().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyTileBufCommon(*this, srcTy, "src")) || - failed(verifyTileBufCommon(*this, tmpTy, "tmp")) || - failed(verifyTileBufCommon(*this, dstTy, "dst"))) - return failure(); - Type srcElem = getElemTy(srcTy); - Type tmpElem = getElemTy(tmpTy); - Type dstElem = getElemTy(dstTy); - if (!srcElem || !tmpElem || !dstElem || srcElem != dstElem || srcElem != tmpElem) - return emitOpError() << "expects src and dst to have the same element type"; - if (auto srcTb = dyn_cast(srcTy)) { - if (srcTb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor)) - return emitOpError() << "expects A2/A3 
transpose src to use the row_major blayout"; - } - unsigned elemBytes = getPTOStorageElemByteSize(srcElem); - if (elemBytes == 0) - return emitOpError() << "failed to get transpose element size"; - if (elemBytes != 1 && elemBytes != 2 && elemBytes != 4) - return emitOpError() << "expects transpose element size to be 1, 2, or 4 bytes"; - auto isAllowedWidthType = [&](Type ty) { - if (elemBytes == 4) - return ty.isInteger(32) || ty.isF32(); - if (elemBytes == 2) - return ty.isInteger(16) || ty.isF16() || ty.isBF16(); - return ty.isInteger(8); - }; - if (!isAllowedWidthType(srcElem)) - return emitOpError() << "expects transpose element type to match the supported set for its width"; - return mlir::success(); - }; - auto verifyA5 = [&]() -> LogicalResult { - Type srcTy = getSrc().getType(); - Type tmpTy = getTmp().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyTileBufCommon(*this, srcTy, "src")) || - failed(verifyTileBufCommon(*this, tmpTy, "tmp")) || - failed(verifyTileBufCommon(*this, dstTy, "dst"))) - return failure(); - Type srcElem = getElemTy(srcTy); - Type tmpElem = getElemTy(tmpTy); - Type dstElem = getElemTy(dstTy); - if (!srcElem || !tmpElem || !dstElem || srcElem != dstElem || srcElem != tmpElem) - return emitOpError() << "expects src, tmp, and dst to have the same element type"; - unsigned elemBytes = getPTOStorageElemByteSize(srcElem); - if (elemBytes == 0) - return emitOpError() << "failed to get transpose element size"; - if (elemBytes != 1 && elemBytes != 2 && elemBytes != 4) - return emitOpError() << "expects transpose element size to be 1, 2, or 4 bytes"; - auto isAllowedWidthType = [&](Type ty) { - if (elemBytes == 4) - return ty.isInteger(32) || ty.isF32(); - if (elemBytes == 2) - return ty.isInteger(16) || ty.isF16() || ty.isBF16(); - return ty.isInteger(8); - }; - if (!isAllowedWidthType(srcElem)) - return emitOpError() << "expects transpose element type to match the supported set for its width"; - auto checkAlignedMajor = 
[&](Type ty, StringRef name) -> LogicalResult { - auto tb = mlir::dyn_cast(ty); - if (!tb) - return success(); - auto shape = getShapeVec(ty); - if (shape.size() != 2) - return success(); - bool rowMajor = tb.getBLayoutValueI32() == static_cast(pto::BLayout::RowMajor); - int64_t major = rowMajor ? shape[1] : shape[0]; - if (major != ShapedType::kDynamic && (major * static_cast(elemBytes)) % 32 != 0) - return emitOpError() << "expects " << name << " major dimension times element size to be 32-byte aligned on A5"; - return success(); - }; - if (failed(checkAlignedMajor(srcTy, "src")) || failed(checkAlignedMajor(dstTy, "dst"))) - return failure(); - return mlir::success(); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TXorOp::verify() { - auto verifyBase = [&]() -> FailureOr { - return verifyMatchingRowMajorBinaryTileOpCommon( - getOperation(), getSrc0().getType(), getSrc1().getType(), - getDst().getType()); - }; - - auto verifyA2A3 = [&]() -> LogicalResult { - FailureOr elemOr = verifyBase(); - if (failed(elemOr)) - return failure(); - Type tmpTy = getTmp().getType(); - if (failed(verifyTileBufCommon(*this, tmpTy, "tmp"))) - return failure(); - Type elem = *elemOr; - if (getElemTy(tmpTy) != elem) - return emitOpError("expects tmp to have the same element type as src0, src1, and dst"); - if (!isRowMajorTileBuf(tmpTy)) - return emitOpError("expects tmp to use row-major layout"); - if (failed(verifyTileBufSameValidShape(*this, tmpTy, getDst().getType(), "tmp", "dst"))) - return failure(); - auto it = mlir::dyn_cast(elem); - if (!it || (it.getWidth() != 8 && it.getWidth() != 16)) - return emitOpError( - "expects A2/A3 txor src0, src1, tmp, and dst element type to be i8/i16"); - return success(); - }; - - auto verifyA5 = [&]() -> LogicalResult { - FailureOr elemOr = verifyBase(); - if (failed(elemOr)) - return failure(); - auto it = mlir::dyn_cast(*elemOr); - if (!it || (it.getWidth() != 8 && it.getWidth() 
!= 16 && - it.getWidth() != 32)) - return emitOpError( - "expects A5 txor src0, src1, and dst element type to be i8/i16/i32"); - return success(); - }; - - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - - -mlir::LogicalResult mlir::pto::TXorSOp::verify() { - auto verifyCommon = [&]() -> FailureOr { - return verifyDistinctRowMajorUnaryTileOpCommon(getOperation(), getSrc(), - getDst(), "src", "dst"); - }; - - auto verifyA2A3 = [&]() -> LogicalResult { - FailureOr elemOr = verifyCommon(); - if (failed(elemOr)) - return failure(); - auto it = mlir::dyn_cast(*elemOr); - if (!it || (it.getWidth() != 8 && it.getWidth() != 16)) - return emitOpError( - "expects A2/A3 txors src and dst element type to be i8/i16"); - return success(); - }; - - auto verifyA5 = [&]() -> LogicalResult { - FailureOr elemOr = verifyCommon(); - if (failed(elemOr)) - return failure(); - auto it = mlir::dyn_cast(*elemOr); - if (!it || (it.getWidth() != 8 && it.getWidth() != 16 && - it.getWidth() != 32)) - return emitOpError( - "expects A5 txors src and dst element type to be i8/i16/i32"); - return success(); - }; - - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} -mlir::LogicalResult mlir::pto::TPrintOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - auto srcType = getSrc().getType(); - if (auto tb = mlir::dyn_cast(srcType)) { - auto elem = tb.getElementType(); - if (!(elem.isF16() || elem.isF32() || - elem.isInteger(8) || elem.isInteger(16) || elem.isInteger(32))) - return emitOpError() << "expects printable tile element type"; - auto space = getPTOMemorySpaceEnum(srcType); - if (!space || *space != pto::AddressSpace::VEC) - return emitOpError() << "expects printable tile_buf to be in vec address space"; - return success(); - } - if (mlir::dyn_cast(srcType) || - mlir::dyn_cast(srcType)) - return mlir::success(); - return emitOpError() << "expects tile_buf, memref, or partition_tensor_view for src"; -} - - 
- -[[maybe_unused]] static LogicalResult verifyMatmulCommon(Operation *op, Value lhs, Value rhs, - Value biasOpt, Type maybeDstElemTy, - Type maybeResultElemTy) { - // ---- case A: tensor/memref (ShapedType) ---- - if (auto lhsTy = dyn_cast(lhs.getType())) { - auto rhsTy = dyn_cast(rhs.getType()); - if (!rhsTy || !lhsTy.hasRank() || !rhsTy.hasRank()) - return op->emitOpError("expects lhs and rhs to be ranked tensors or memrefs"); - - if (lhsTy.getElementType() != rhsTy.getElementType()) - return op->emitOpError() - << "expects lhs and rhs to have the same element type, but got lhs=" - << lhsTy.getElementType() << " rhs=" << rhsTy.getElementType(); - - if (biasOpt) { - auto biasTy = dyn_cast(biasOpt.getType()); - if (!biasTy || !biasTy.hasRank()) - return op->emitOpError("expects bias to be a ranked tensor or memref"); - if (biasTy.getElementType() != lhsTy.getElementType()) - return op->emitOpError() - << "expects bias to have the same element type as lhs and rhs, but got bias=" - << biasTy.getElementType() << " vs " << lhsTy.getElementType(); - } - - if (maybeDstElemTy && maybeDstElemTy != lhsTy.getElementType()) - return op->emitOpError() - << "expects dst to have the same element type as lhs and rhs, but got dst=" - << maybeDstElemTy << " vs " << lhsTy.getElementType(); - - if (maybeResultElemTy && maybeResultElemTy != lhsTy.getElementType()) - return op->emitOpError() - << "expects result to have the same element type as lhs and rhs, but got result=" - << maybeResultElemTy << " vs " << lhsTy.getElementType(); - - return success(); - } - - // ---- case B: tile ---- - auto lhsTile = dyn_cast(lhs.getType()); - auto rhsTile = dyn_cast(rhs.getType()); - if (!lhsTile || !rhsTile) - return op->emitOpError("expects lhs and rhs to be ranked tensors, memrefs, or !pto.tile"); - - if (lhsTile.getElementType() != rhsTile.getElementType()) - return op->emitOpError() << "expects lhs and rhs tiles to have the same element type, but got lhs=" - << lhsTile.getElementType() << " 
rhs=" << rhsTile.getElementType(); - - if ((int64_t)lhsTile.getShape().size() != 2 || (int64_t)rhsTile.getShape().size() != 2) - return op->emitOpError("expects lhs and rhs tiles to be 2D"); - - if (lhsTile.getShape()[1] != rhsTile.getShape()[0]) - return op->emitOpError() << "expects lhs dim1 to equal rhs dim0, but got " - << lhsTile.getShape()[1] << " vs " << rhsTile.getShape()[0]; - - if (biasOpt) { - auto biasTile = dyn_cast(biasOpt.getType()); - if (!biasTile) - return op->emitOpError("expects bias to be !pto.tile when lhs and rhs are !pto.tile"); - if (biasTile.getElementType() != lhsTile.getElementType()) - return op->emitOpError("expects bias to have the same element type as lhs and rhs"); - } - - if (maybeDstElemTy && maybeDstElemTy != lhsTile.getElementType()) - return op->emitOpError() << "expects dst to have the same element type as lhs and rhs"; - - if (maybeResultElemTy && maybeResultElemTy != lhsTile.getElementType()) - return op->emitOpError() << "expects result to have the same element type as lhs and rhs"; - - return success(); -} - -LogicalResult mlir::pto::TMatmulOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - if (failed(verifyMatTileOperands(*this, getLhs().getType(), getRhs().getType(), - getDst().getType()))) - return failure(); - if (failed(verifyMatmulTypeTriple(*this, getElemTy(getLhs().getType()), - getElemTy(getRhs().getType()), - getElemTy(getDst().getType())))) - return failure(); - return verifyMatmulLike(*this, getLhs().getType(), getRhs().getType(), - getDst().getType()); - }; - auto verifyA5 = [&]() -> LogicalResult { return verifyA2A3(); }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -LogicalResult mlir::pto::TGemvOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - if (failed(verifyGemvTileOperands(*this, getLhs().getType(), getRhs().getType(), - getDst().getType()))) - return failure(); - if (failed(verifyMatmulTypeTriple(*this, getElemTy(getLhs().getType()), - 
getElemTy(getRhs().getType()), - getElemTy(getDst().getType())))) - return failure(); - return verifyMatmulLike(*this, getLhs().getType(), getRhs().getType(), - getDst().getType()); - }; - auto verifyA5 = [&]() -> LogicalResult { return verifyA2A3(); }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -LogicalResult mlir::pto::TMatmulAccOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - if (failed(verifyAccTileCommon(*this, getAccIn().getType(), "acc_in")) || - failed(verifyMatTileOperands(*this, getLhs().getType(), getRhs().getType(), - getDst().getType()))) - return failure(); - return success(); -} - -LogicalResult mlir::pto::TGemvAccOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - if (failed(verifyAccTileCommon(*this, getAccIn().getType(), "acc_in")) || - failed(verifyGemvTileOperands(*this, getLhs().getType(), getRhs().getType(), - getDst().getType()))) - return failure(); - return success(); -} - -//===----------------------------------------------------------------------===// -// inferReturnTypes() for matmul ops (keep your existing code) -//===----------------------------------------------------------------------=== -[[maybe_unused]] static mlir::Type inferMatmulTileResult2DFromAB(MLIRContext *context, ValueRange operands) { - if (operands.size() < 2) - return mlir::Type(); - - auto lhsTile = dyn_cast(operands[0].getType()); - auto rhsTile = dyn_cast(operands[1].getType()); - if (!lhsTile || !rhsTile) - return mlir::Type(); - - Type elemTy = lhsTile.getElementType(); - - if (operands.size() >= 3) { - if (auto biasTile = dyn_cast(operands[2].getType())) { - return mlir::pto::TileType::get(context, biasTile.getShape(), elemTy); - } - } - - auto lhsShape = lhsTile.getShape(); - auto rhsShape = rhsTile.getShape(); - if (lhsShape.size() >= 2 && rhsShape.size() >= 2) { - int64_t M = lhsShape[0]; - int64_t N = rhsShape[1]; - llvm::SmallVector outShape 
= {M, N}; - return mlir::pto::TileType::get(context, outShape, elemTy); - } - - return mlir::Type(); -} - -[[maybe_unused]] static RankedTensorType inferMatmulResult2DFromAB(ValueRange operands) { - if (operands.size() < 2) - return RankedTensorType(); - - auto lhsTy = dyn_cast(operands[0].getType()); - auto rhsTy = dyn_cast(operands[1].getType()); - if (!lhsTy || !rhsTy || !lhsTy.hasRank() || !rhsTy.hasRank()) - return RankedTensorType(); - - Type elemTy = lhsTy.getElementType(); - - if (operands.size() >= 3) { - if (auto biasRT = dyn_cast(operands[2].getType())) - return RankedTensorType::get(biasRT.getShape(), elemTy); - if (auto biasMR = dyn_cast(operands[2].getType())) { - if (biasMR.hasStaticShape()) - return RankedTensorType::get(biasMR.getShape(), elemTy); - } - } - - if (lhsTy.getRank() >= 2 && rhsTy.getRank() >= 2) { - int64_t M = lhsTy.getDimSize(0); - int64_t N = rhsTy.getDimSize(1); - return RankedTensorType::get({M, N}, elemTy); - } - - return RankedTensorType(); -} - -[[maybe_unused]] static RankedTensorType inferAccReturnFromAccIn(ValueRange operands) { - if (operands.empty()) - return RankedTensorType(); - if (auto accRT = dyn_cast(operands[0].getType())) - return accRT; - return RankedTensorType(); -} - -namespace mlir { -namespace pto { - -static LogicalResult parseShapeAndElem(AsmParser &parser, - SmallVectorImpl &shape, - Type &elementType, - bool allowDynamic) { - if (parser.parseLess()) - return failure(); - - if (parser.parseDimensionList(shape, allowDynamic)) - return failure(); - - if (parser.parseType(elementType)) - return failure(); - - if (parser.parseGreater()) - return failure(); - - return success(); -} - -static void printShapeAndElem(AsmPrinter &printer, - ArrayRef shape, - Type elementType) { - printer << "<"; - for (auto d : shape) { - if (d == ShapedType::kDynamic) - printer << "?"; - else - printer << d; - printer << "x"; - } - printer.printType(elementType); - printer << ">"; -} - -// 
============================================================================= -// PartitionTensorViewType Implementation -// ============================================================================= - -Type PartitionTensorViewType::parse(AsmParser &parser) { - SmallVector shape; - Type elemTy; - if (failed(parseShapeAndElem(parser, shape, elemTy, /*allowDynamic=*/true))) - return Type(); - - return PartitionTensorViewType::get(parser.getContext(), shape, elemTy); -} - -void PartitionTensorViewType::print(AsmPrinter &printer) const { - printShapeAndElem(printer, getShape(), getElementType()); -} - -// ---- TileType ---- -Type TileType::parse(AsmParser &parser) { - SmallVector shape; - Type elemTy; - if (failed(parseShapeAndElem(parser, shape, elemTy, /*allowDynamic=*/true))) - return Type(); - return TileType::get(parser.getContext(), shape, elemTy); -} - -void TileType::print(AsmPrinter &printer) const { - printShapeAndElem(printer, getShape(), getElementType()); -} - -// ---- LocalArrayType ---- -// Asm form: !pto.local_array -// Static shape only (no '?'). Element type must be a scalar; this is enforced -// by the type verifier below. 
-Type LocalArrayType::parse(AsmParser &parser) { - SmallVector shape; - Type elemTy; - if (failed(parseShapeAndElem(parser, shape, elemTy, /*allowDynamic=*/false))) - return Type(); - return LocalArrayType::getChecked( - [&]() { return parser.emitError(parser.getNameLoc()); }, - parser.getContext(), shape, elemTy); -} - -void LocalArrayType::print(AsmPrinter &printer) const { - printShapeAndElem(printer, getShape(), getElementType()); -} - -LogicalResult LocalArrayType::verify( - llvm::function_ref emitError, - llvm::ArrayRef shape, Type elementType) { - if (shape.empty()) - return emitError() << "'!pto.local_array' requires at least one dimension"; - for (auto [i, d] : llvm::enumerate(shape)) { - if (d <= 0) - return emitError() - << "'!pto.local_array' dimension " << i - << " must be a positive static size, got " << d; - } - if (!elementType.isIntOrFloat()) - return emitError() - << "'!pto.local_array' element type must be a scalar integer or " - "float, got " - << elementType; - return success(); -} - -// ============================================================================= -// Decompose Helper (Reverse Engineering AffineMap -> Strides) -// ============================================================================= - -// Helper: 递归地将 Add 表达式拆解为单独的项列表 -static void flattenAddExpr(AffineExpr expr, SmallVectorImpl &terms) { - if (auto add = llvm::dyn_cast(expr)) { - if (add.getKind() == AffineExprKind::Add) { - flattenAddExpr(add.getLHS(), terms); - flattenAddExpr(add.getRHS(), terms); - return; - } - } - terms.push_back(expr); -} - -// Helper: 从 AffineMap 中提取 Strides -static void decomposeStridedLayout(AffineMap map, SmallVectorImpl &strides) { - // 1. 初始化 - strides.assign(map.getNumDims(), 0); - - if (map.getNumResults() != 1) return; - - // 2. 摊平表达式 - SmallVector terms; - flattenAddExpr(map.getResult(0), terms); - - // 3. 
分析每一项 - for (auto term : terms) { - // 情况 A: dN * Const 或 Const * dN - if (auto mul = llvm::dyn_cast(term)) { - if (mul.getKind() == AffineExprKind::Mul) { - AffineExpr lhs = mul.getLHS(); - AffineExpr rhs = mul.getRHS(); - - // 尝试匹配 LHS=Dim, RHS=Const - if (auto dim = llvm::dyn_cast(lhs)) { - if (auto cst = llvm::dyn_cast(rhs)) { - strides[dim.getPosition()] = cst.getValue(); - continue; - } - } - - // 尝试匹配 LHS=Const, RHS=Dim (乘法交换律) - if (auto dim = llvm::dyn_cast(rhs)) { - if (auto cst = llvm::dyn_cast(lhs)) { - strides[dim.getPosition()] = cst.getValue(); - continue; - } - } - } - } - // 情况 B: 单独的 dN (隐含 Stride = 1) - else if (auto dim = llvm::dyn_cast(term)) { - strides[dim.getPosition()] = 1; - } - } -} - -// ============================================================================= -// [Critical] Strict Alignment Protocol Helper -// ============================================================================= -// This function is the SINGLE source of truth for building the AffineMap. -// Both the Parser and the Op Inference MUST use this exact function. -// It ensures that the order of AffineExpr addition is: -// 0 + (d0*str0 + d1*str1...) + (s0*str0 + s1*str1...) -// This guarantees bitwise-identical AffineMaps for verification. -static AffineMap buildStrictBitwiseAffineMap(MLIRContext *ctx, - ArrayRef strides, - bool isMultiDimSymbol) { - unsigned rank = strides.size(); - - // Step 1: Initialize with Constant(0) - AffineExpr totalExpr = getAffineConstantExpr(0, ctx); - - // Step 2: Add Dimensions (d0*str0 + d1*str1...) - // Strictly in order: 0, 1, 2... - for (unsigned i = 0; i < rank; ++i) { - auto dim = getAffineDimExpr(i, ctx); - auto str = getAffineConstantExpr(strides[i], ctx); - totalExpr = totalExpr + (dim * str); - } - - // Step 3: Add Symbols (s0*str0 + s1*str1...) - // Strictly in order: 0, 1, 2... 
- if (isMultiDimSymbol) { - for (unsigned i = 0; i < rank; ++i) { - auto sym = getAffineSymbolExpr(i, ctx); - auto str = getAffineConstantExpr(strides[i], ctx); - totalExpr = totalExpr + (sym * str); - } - } - // (Optional: handle single dynamic offset case if needed, omitted for clarity) - - // numSymbols is rank if multi-dim (for offsets), else 0 - unsigned numSymbols = isMultiDimSymbol ? rank : 0; - return AffineMap::get(rank, numSymbols, totalExpr); -} - - -// ============================================================================= -// Parser Implementation -// ============================================================================= - -// Helper for parsing [64, 1] -static ParseResult parseStrideList(AsmParser &parser, SmallVectorImpl &strides) { - if (parser.parseLSquare()) return failure(); - do { - int64_t stride; - if (parser.parseInteger(stride)) return failure(); - strides.push_back(stride); - } while (succeeded(parser.parseOptionalComma())); - if (parser.parseRSquare()) return failure(); - return success(); -} - -// The custom attribute parser for: strided<[64, 1], offset: [?, ?]> -[[maybe_unused]] static ParseResult parseStridedLayout(AsmParser &parser, Attribute &layout) { - if (parser.parseLess()) return failure(); - - // 1. Parse Strides - SmallVector strides; - if (parseStrideList(parser, strides)) return failure(); - - bool isMultiDim = false; - unsigned numSymbols = 0; - - // 2. Parse Offset - if (succeeded(parser.parseOptionalComma())) { - if (parser.parseKeyword("offset") || parser.parseColon()) return failure(); - - // Check for multi-dim syntax: [?, ?] - if (succeeded(parser.parseOptionalLSquare())) { - isMultiDim = true; - do { - if (parser.parseQuestion()) return failure(); - numSymbols++; - } while (succeeded(parser.parseOptionalComma())); - if (parser.parseRSquare()) return failure(); - } else { - // Fallback for old scalar syntax '?' 
- if (parser.parseOptionalQuestion()) { /* handle single scalar */ } - } - } - - if (parser.parseGreater()) return failure(); - - // 3. Validation - if (isMultiDim && numSymbols != strides.size()) { - return parser.emitError(parser.getCurrentLocation(), - "Number of offset symbols must match rank"); - } - - // 4. [CALL SHARED BUILDER] - // Delegate to the strict builder - MLIRContext *ctx = parser.getContext(); - AffineMap map = buildStrictBitwiseAffineMap(ctx, strides, isMultiDim); - - layout = AffineMapAttr::get(map); - return success(); -} - -// ============================================================================= -// Printer Implementation -// ============================================================================= - -[[maybe_unused]] static void printLayout(AsmPrinter &printer, Attribute layoutAttr) { - if (!layoutAttr) return; - auto mapAttr = llvm::dyn_cast(layoutAttr); - if (!mapAttr) { printer << ", " << layoutAttr; return; } - - AffineMap map = mapAttr.getValue(); - if (map.isIdentity()) return; - - // 1. [核心修改] 反解 Strides - SmallVector strides; - decomposeStridedLayout(map, strides); - - printer << ", strided<["; - // 2. 打印真实的 strides - llvm::interleaveComma(strides, printer); - printer << "]"; - - // Print Offset: [?, ?] 
- unsigned numSyms = map.getNumSymbols(); - if (numSyms > 0) { - printer << ", offset: ["; - for (unsigned i = 0; i < numSyms; ++i) { - printer << "?"; - if (i < numSyms - 1) printer << ", "; - } - printer << "]"; - } - printer << ">"; -} - -// ---- TileBuf --- - - -// Tile subview 相关实现 - -// ============================================================================= -// Op Interface Implementation: SubViewOp -// ============================================================================= - -ParseResult mlir::pto::SubViewOp::parse(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::UnresolvedOperand source; - SmallVector offsets; - SmallVector valids; - Type sourceTy; - Type resultTy; - bool hasExplicitResultTy = false; - - if (parser.parseOperand(source) || parser.parseLSquare() || - parser.parseOperandList(offsets) || parser.parseRSquare() || - parser.parseKeyword("sizes")) - return failure(); - - ArrayAttr sizesAttr; - if (parser.parseAttribute(sizesAttr, "sizes", result.attributes)) - return failure(); - - if (succeeded(parser.parseOptionalKeyword("valid"))) { - OpAsmParser::UnresolvedOperand vrow, vcol; - if (parser.parseLSquare() || parser.parseOperand(vrow) || parser.parseComma() || - parser.parseOperand(vcol) || parser.parseRSquare()) - return failure(); - valids.push_back(vrow); - valids.push_back(vcol); - } - - if (parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(sourceTy)) - return failure(); - - if (succeeded(parser.parseOptionalArrow())) { - if (parser.parseType(resultTy)) - return failure(); - hasExplicitResultTy = true; - } - - if (parser.resolveOperand(source, sourceTy, result.operands)) - return failure(); - - Type indexTy = parser.getBuilder().getIndexType(); - if (parser.resolveOperands(offsets, indexTy, result.operands)) - return failure(); - if (!valids.empty() && - parser.resolveOperands(valids, indexTy, result.operands)) - return failure(); - - int32_t hasValid = valids.empty() ? 
0 : 1; - result.addAttribute( - "operandSegmentSizes", - parser.getBuilder().getDenseI32ArrayAttr( - {1, static_cast(offsets.size()), hasValid, hasValid})); - - if (hasExplicitResultTy) { - result.addTypes(resultTy); - return success(); - } - - SmallVector inferredReturnTypes; - DictionaryAttr attrs = result.attributes.getDictionary(parser.getContext()); - if (failed(SubViewOp::inferReturnTypes( - parser.getContext(), std::nullopt, result.operands, attrs, nullptr, - RegionRange(), inferredReturnTypes))) { - return parser.emitError(parser.getCurrentLocation(), - "failed to infer pto.subview result type"); - } - result.addTypes(inferredReturnTypes); - return success(); -} - -void mlir::pto::SubViewOp::print(OpAsmPrinter &printer) { - printer << " " << getSource() << "["; - printer.printOperands(getOffsets()); - printer << "] sizes " << getSizes(); - if (getValidRow()) { - printer << " valid [" << getValidRow() << ", " << getValidCol() << "]"; - } - printer.printOptionalAttrDict((*this)->getAttrs(), - /*elidedAttrs=*/{"operandSegmentSizes", - "sizes"}); - printer << " : " << getSource().getType() << " -> " << getResult().getType(); -} - -LogicalResult SubViewOp::inferReturnTypes( - MLIRContext *context, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - - // 1. 获取 Source Type - if (operands.empty()) return failure(); - auto sourceType = llvm::dyn_cast(operands[0].getType()); - if (!sourceType) return failure(); - - // 2. 
获取 subview 逻辑窗口(sizes) - ArrayAttr sizeAttr; - if (properties) { - const auto *prop = properties.as(); - if (prop) sizeAttr = prop->sizes; - } - if (!sizeAttr && attributes) { - sizeAttr = attributes.getAs("sizes"); - } - if (!sizeAttr) return failure(); - - SmallVector subviewShape; - for (auto attr : sizeAttr) { - int64_t dim = llvm::cast(attr).getInt(); - subviewShape.push_back(dim); - } - - // Design: subview 的结果 tile 类型显式表达逻辑子窗口 shape(sizes)。 - ArrayRef parentShape = sourceType.getShape(); - if (subviewShape.size() != parentShape.size()) - return failure(); - - // Derive valid shape from explicit valid_row/valid_col when provided. - // Otherwise default to subview shape (no parent valid-shape inheritance). - SmallVector validShape; - constexpr int64_t kDynamicValidDim = -1; - int64_t rank = static_cast(subviewShape.size()); - Value explicitVRow; - Value explicitVCol; - - // Robustly decode optional valid operands using AttrSizedOperandSegments: - // [source, offsets..., valid_row?, valid_col?] - if (attributes) { - if (auto segAttr = - attributes.getAs("operandSegmentSizes")) { - ArrayRef segs = segAttr.asArrayRef(); - if (segs.size() == 4) { - int32_t srcSeg = segs[0]; - int32_t offSeg = segs[1]; - int32_t vRowSeg = segs[2]; - int32_t vColSeg = segs[3]; - if (srcSeg == 1 && offSeg >= 0 && (vRowSeg == 0 || vRowSeg == 1) && - (vColSeg == 0 || vColSeg == 1)) { - size_t idx = static_cast(srcSeg + offSeg); - if (vRowSeg == 1 && idx < operands.size()) - explicitVRow = operands[idx++]; - if (vColSeg == 1 && idx < operands.size()) - explicitVCol = operands[idx]; - } - } - } - } - - // Fallback for legacy callers that may not provide operandSegmentSizes. 
- if (!explicitVRow && !explicitVCol && rank == 2) { - size_t expectedWithoutValid = static_cast(1 + rank); - if (operands.size() >= expectedWithoutValid + 2) { - explicitVRow = operands[expectedWithoutValid]; - explicitVCol = operands[expectedWithoutValid + 1]; - } - } - - for (size_t i = 0, e = subviewShape.size(); i < e; ++i) { - int64_t vdim = subviewShape[i]; - Value explicitV = (i == 0) ? explicitVRow : (i == 1 ? explicitVCol : Value()); - if (explicitV) { - auto cst = getConstIndexValue(explicitV); - vdim = cst ? std::min(*cst, subviewShape[i]) : kDynamicValidDim; - } - validShape.push_back(vdim); - } - - // 3. 继承 Config (若为空使用默认) - auto cfg = sourceType.getConfigAttr(); - if (!cfg) cfg = TileBufConfigAttr::getDefault(context); - - // 4. 构建 Result Type - auto canonicalValidShape = canonicalizeTileBufValidShape(validShape); - auto resultType = TileBufType::get( - context, subviewShape, sourceType.getElementType(), - sourceType.getMemorySpace(), canonicalValidShape, cfg); - - inferredReturnTypes.push_back(resultType); - return success(); -} - -// ============================================================================= -// SubViewOp verifier -// ============================================================================= -static bool getConstIndex(Value v, int64_t &out) { - if (auto cOp = v.getDefiningOp()) { - out = cOp.value(); - return true; - } - if (auto cInt = v.getDefiningOp()) { - out = cInt.value(); - return true; - } - if (auto cOp = v.getDefiningOp()) { - if (auto ia = dyn_cast(cOp.getValue())) { - out = ia.getInt(); - return true; - } - } - if (auto castOp = v.getDefiningOp()) - return getConstIndex(castOp.getIn(), out); - if (auto extOp = v.getDefiningOp()) - return getConstIndex(extOp.getIn(), out); - if (auto extOp = v.getDefiningOp()) - return getConstIndex(extOp.getIn(), out); - if (auto truncOp = v.getDefiningOp()) - return getConstIndex(truncOp.getIn(), out); - return false; -} - -static LogicalResult computeInnerShape(TileBufConfigAttr 
cfg, Type elemTy, - int64_t &innerRows, int64_t &innerCols, - bool &boxed, int32_t &bl, int32_t &sl) { - auto readBLayoutI32 = [](Attribute attr, int32_t &out) -> bool { - if (auto a = dyn_cast(attr)) { - out = (int32_t)a.getValue(); - return true; - } - if (auto a = dyn_cast(attr)) { - out = (int32_t)a.getInt(); - return true; - } - return false; - }; - auto readSLayoutI32 = [](Attribute attr, int32_t &out) -> bool { - if (auto a = dyn_cast(attr)) { - out = (int32_t)a.getValue(); - return true; - } - if (auto a = dyn_cast(attr)) { - out = (int32_t)a.getInt(); - return true; - } - return false; - }; - bl = 0; - sl = 0; - int32_t fr = 512; - (void)readBLayoutI32(cfg.getBLayout(), bl); - (void)readSLayoutI32(cfg.getSLayout(), sl); - if (auto attr = dyn_cast(cfg.getSFractalSize())) fr = (int32_t)attr.getInt(); - - boxed = (sl != 0); - if (!boxed) { - innerRows = 1; - innerCols = 1; - return success(); - } - - int64_t elemBytes = static_cast(getElemByteSize(elemTy)); - if (elemBytes <= 0) return failure(); - - if (fr == 1024) { - innerRows = 16; - innerCols = 16; - return success(); - } - if (fr == 32) { - innerRows = 16; - innerCols = 2; - return success(); - } - if (fr == 512) { - if (sl == 1) { - innerRows = 16; - innerCols = 32 / elemBytes; - return success(); - } - if (sl == 2) { - innerRows = 32 / elemBytes; - innerCols = 16; - return success(); - } - } - return failure(); -} - -mlir::LogicalResult mlir::pto::SubViewOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - auto srcTy = llvm::dyn_cast(getSource().getType()); - auto dstTy = llvm::dyn_cast(getResult().getType()); - if (!srcTy || !dstTy) - return emitOpError("expects tile_buf src and tile_buf result"); - if (srcTy.getRank() != 2 || dstTy.getRank() != 2) - return emitOpError("expects rank-2 tilebuf for src/dst"); - - auto sizesAttr = getSizes(); - if (!sizesAttr || sizesAttr.size() != 2) - return emitOpError("subview expects 2D sizes"); - int64_t sizeR = 
cast(sizesAttr[0]).getInt(); - int64_t sizeC = cast(sizesAttr[1]).getInt(); - if (sizeR <= 0 || sizeC <= 0) - return emitOpError("subview sizes must be positive"); - if (getOffsets().size() != 2) - return emitOpError("subview expects 2D offsets"); - - int64_t offR = 0, offC = 0; - bool offRConst = getConstIndex(getOffsets()[0], offR); - bool offCConst = getConstIndex(getOffsets()[1], offC); - if (offRConst && offR < 0) - return emitOpError("subview offsets must be non-negative"); - if (offCConst && offC < 0) - return emitOpError("subview offsets must be non-negative"); - - bool hasValidRow = static_cast(getValidRow()); - bool hasValidCol = static_cast(getValidCol()); - if (hasValidRow != hasValidCol) - return emitOpError( - "subview expects valid_row and valid_col to be both present or both absent"); - - if (hasValidRow) { - int64_t vRow = 0, vCol = 0; - if (getConstIndex(getValidRow(), vRow)) { - if (vRow <= 0) - return emitOpError("valid_row must be positive when constant"); - if (vRow > sizeR) - return emitOpError("valid_row must be <= subview row size"); - } - if (getConstIndex(getValidCol(), vCol)) { - if (vCol <= 0) - return emitOpError("valid_col must be positive when constant"); - if (vCol > sizeC) - return emitOpError("valid_col must be <= subview col size"); - } - } - - auto dstShape = dstTy.getShape(); - if (dstShape.size() != 2) - return emitOpError("expects result to be rank-2"); - auto srcShape = srcTy.getShape(); - if (srcShape.size() != 2) - return emitOpError("expects source to be rank-2"); - if (dstShape[0] != sizeR || dstShape[1] != sizeC) - return emitOpError("expects result shape to match subview sizes"); - - if (dstTy.getElementType() != srcTy.getElementType()) - return emitOpError("expects result element type to match source"); - if (dstTy.getMemorySpace() != srcTy.getMemorySpace()) - return emitOpError("expects result address space to match source"); - auto srcCfg = srcTy.getConfigAttr(); - if (!srcCfg) srcCfg = 
TileBufConfigAttr::getDefault(getContext()); - auto dstCfg = dstTy.getConfigAttr(); - if (!dstCfg) dstCfg = TileBufConfigAttr::getDefault(getContext()); - if (dstCfg != srcCfg) - return emitOpError("expects result tile config to match source"); - - // Design choice: when valid[...] is omitted, infer result valid_shape from - // subview sizes directly. We intentionally do not constrain it by source - // valid_shape to allow user-controlled subview semantics. - - auto expectedValidDim = [&](Value explicitValid, int64_t defaultSize) { - if (!explicitValid) - return defaultSize; - int64_t c = 0; - if (getConstIndex(explicitValid, c)) - return std::min(c, defaultSize); - return ShapedType::kDynamic; - }; - int64_t expectedVRow = expectedValidDim(getValidRow(), sizeR); - int64_t expectedVCol = expectedValidDim(getValidCol(), sizeC); - auto dstValid = dstTy.getValidShape(); - if (dstValid.size() != 2) - return emitOpError("expects result to have rank-2 valid_shape"); - if (dstValid[0] != expectedVRow) - return emitOpError("expects result valid_shape[0] to match inferred/explicit valid_row"); - if (dstValid[1] != expectedVCol) - return emitOpError("expects result valid_shape[1] to match inferred/explicit valid_col"); - - auto cfg = srcTy.getConfigAttr(); - if (!cfg) cfg = TileBufConfigAttr::getDefault(getContext()); - - int64_t innerRows = 1, innerCols = 1; - bool boxed = false; - int32_t bl = 0, sl = 0; - if (failed(computeInnerShape(cfg, srcTy.getElementType(), innerRows, innerCols, - boxed, bl, sl))) - return emitOpError("unsupported tile layout for subview"); - - if (!boxed) - return success(); - - // Boxed layout: require static 2D sizes with inner alignment. Offsets may be - // dynamic, but static offsets must be aligned. 
- if (sizeR % innerRows != 0 || sizeC % innerCols != 0) - return emitOpError("boxed layout subview sizes must be multiples of inner shape"); - - if (offRConst) { - if (offR % innerRows != 0) - return emitOpError("boxed layout subview offsets must be multiples of inner shape"); - } - if (offCConst) { - if (offC % innerCols != 0) - return emitOpError("boxed layout subview offsets must be multiples of inner shape"); - } - - (void)bl; - if (srcShape.size() != 2 || - srcShape[0] == ShapedType::kDynamic || - srcShape[1] == ShapedType::kDynamic) { - return emitOpError("boxed layout subview requires static source shape"); - } - - return success(); -} - -} // namespace pto -} // namespace mlir - -using namespace mlir; -using namespace mlir::pto; - -// ============================================================================= -// Helper Functions -// ============================================================================= - -[[maybe_unused]] static AddressSpace getAddressSpace(Value val) { - auto type = llvm::dyn_cast(val.getType()); - if (!type) return AddressSpace::Zero; // Default - - // 假设你的 AddressSpaceAttr 存储在 MemRef 的 memorySpace 中 - // 需要根据你的 getPTOAddressSpaceAttr 实现来调整 - auto attr = llvm::dyn_cast_or_null(type.getMemorySpace()); - if (attr) return attr.getAddressSpace(); - return AddressSpace::Zero; -} - -// ============================================================================= -// Side Effects Implementation -// ============================================================================= - -// [Fix] 辅助函数:重载以支持 OpOperand* 和 OpResult,避免直接传 Value - -// 针对操作数 (Operand) 的重载 -static void addEffect( - SmallVectorImpl> &effects, - OpOperand *operand, MemoryEffects::Effect *effect) { - if (operand) - effects.emplace_back(effect, operand, SideEffects::DefaultResource::get()); -} - -// 针对结果 (Result) 的重载 -static void addEffect( - SmallVectorImpl> &effects, - OpResult result, MemoryEffects::Effect *effect) { - if (result) - effects.emplace_back(effect, result, 
SideEffects::DefaultResource::get()); -} - -// === TLoadOp === -// Read: src, Write: dst -// 针对 OpOperand* 的重载 -void TLoadOp::getEffects(SmallVectorImpl> &effects) { - // [Fix] 单个操作数,直接取地址 - addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); -} - -void TPrefetchOp::getEffects( - SmallVectorImpl> &effects) { - addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); -} - -// === TAbsOp === -// Read: src, Write: dst -void TAbsOp::getEffects( - SmallVectorImpl> &effects) { - addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); -} - -// === TStoreOp === -// Read: src, Write: dst (GM) -void TStoreOp::getEffects(SmallVectorImpl> &effects) { - addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); - auto preQuantRange = getPreQuantScalarMutable(); - if (!preQuantRange.empty()) - addEffect(effects, &*preQuantRange.begin(), MemoryEffects::Read::get()); - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); -} - -// === TMovOp === -// Read: src, Write: dst -void TMovOp::getEffects(SmallVectorImpl> &effects) { - addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); - auto fpRange = getFpMutable(); - if (!fpRange.empty()) - addEffect(effects, &*fpRange.begin(), MemoryEffects::Read::get()); - auto preQuantRange = getPreQuantScalarMutable(); - if (!preQuantRange.empty()) - addEffect(effects, &*preQuantRange.begin(), MemoryEffects::Read::get()); - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); -} - -#define PTO_ADD_READ(operand) addEffect(effects, &(operand), MemoryEffects::Read::get()) -#define PTO_ADD_WRITE(operand) addEffect(effects, &(operand), MemoryEffects::Write::get()) - -#define PTO_DEFINE_UNARY_EFFECTS(OpClass, srcOperand, dstOperand) \ - void OpClass::getEffects( \ - 
SmallVectorImpl> &effects) { \ - PTO_ADD_READ(srcOperand); \ - PTO_ADD_WRITE(dstOperand); \ - } - -#define PTO_DEFINE_BINARY_EFFECTS(OpClass, lhsOperand, rhsOperand, dstOperand) \ - void OpClass::getEffects( \ - SmallVectorImpl> &effects) { \ - PTO_ADD_READ(lhsOperand); \ - PTO_ADD_READ(rhsOperand); \ - PTO_ADD_WRITE(dstOperand); \ - } - -#define PTO_DEFINE_TERNARY_EFFECTS(OpClass, op0, op1, op2, dstOperand) \ - void OpClass::getEffects( \ - SmallVectorImpl> &effects) { \ - PTO_ADD_READ(op0); \ - PTO_ADD_READ(op1); \ - PTO_ADD_READ(op2); \ - PTO_ADD_WRITE(dstOperand); \ - } - -#define PTO_DEFINE_QUATERNARY_EFFECTS(OpClass, op0, op1, op2, op3, dstOperand) \ - void OpClass::getEffects( \ - SmallVectorImpl> &effects) { \ - PTO_ADD_READ(op0); \ - PTO_ADD_READ(op1); \ - PTO_ADD_READ(op2); \ - PTO_ADD_READ(op3); \ - PTO_ADD_WRITE(dstOperand); \ - } - -void LoadScalarOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getPtrMutable()); -} - -void StoreScalarOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_WRITE(getPtrMutable()); -} - -// === Tile/Device ops added for InsertSync === - -// MGATHER: Read(mem, idx) -> Write(dst) -void MGatherOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getMemMutable()); - PTO_ADD_READ(getIdxMutable()); - PTO_ADD_WRITE(getDstMutable()); -} - -// MSCATTER: Read(src, idx) -> Write(mem) -void MScatterOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_READ(getIdxMutable()); - PTO_ADD_WRITE(getMemMutable()); -} - -// TGETVAL: Read(src) -> scalar result -void TGetValOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); -} - -void THistogramOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_READ(getIdxMutable()); - PTO_ADD_WRITE(getDstMutable()); -} - -void TGetScaleAddrOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_WRITE(getDstMutable()); -} - -// TSETVAL: Write(dst) 
(single element update) -void TSetValOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_WRITE(getDstMutable()); -} - -// SET_VALIDSHAPE: update runtime valid row/col metadata on source tile in-place. -void SetValidShapeOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_WRITE(getSourceMutable()); -} - -// GET_VALIDSHAPE: read runtime valid row/col metadata from source tile. -void GetValidShapeOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSourceMutable()); -} - -// Elementwise + reductions: mostly PIPE_V tilebuf ops -PTO_DEFINE_BINARY_EFFECTS(TAddOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -PTO_DEFINE_TERNARY_EFFECTS(TAddCOp, getSrc0Mutable(), getSrc1Mutable(), getSrc2Mutable(), getDstMutable()) -PTO_DEFINE_UNARY_EFFECTS(TAddSOp, getSrcMutable(), getDstMutable()) -PTO_DEFINE_BINARY_EFFECTS(TAddSCOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -void TAxpyOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_READ(getScalarMutable()); - PTO_ADD_WRITE(getDstMutable()); -} - -PTO_DEFINE_BINARY_EFFECTS(TAndOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -PTO_DEFINE_BINARY_EFFECTS(TConcatOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -PTO_DEFINE_QUATERNARY_EFFECTS(TConcatidxOp, getSrc0Mutable(), getSrc1Mutable(), getSrc0IdxMutable(), getSrc1IdxMutable(), getDstMutable()) -PTO_DEFINE_UNARY_EFFECTS(TAndSOp, getSrcMutable(), getDstMutable()) - -// TCI: Write(dst) (generates sequence) -void TCIOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_WRITE(getDstMutable()); -} - -// TTRI: Write(dst) (generates triangular mask) -void TTriOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_WRITE(getDstMutable()); -} - -PTO_DEFINE_BINARY_EFFECTS(TCmpOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -PTO_DEFINE_UNARY_EFFECTS(TCmpSOp, getSrcMutable(), getDstMutable()) - -PTO_DEFINE_UNARY_EFFECTS(TColExpandOp, getSrcMutable(), getDstMutable()) 
-PTO_DEFINE_BINARY_EFFECTS(TColExpandAddOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -PTO_DEFINE_BINARY_EFFECTS(TColExpandMulOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -PTO_DEFINE_BINARY_EFFECTS(TColExpandDivOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -PTO_DEFINE_BINARY_EFFECTS(TColExpandSubOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -PTO_DEFINE_BINARY_EFFECTS(TColExpandExpdifOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -PTO_DEFINE_BINARY_EFFECTS(TColExpandMaxOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -PTO_DEFINE_BINARY_EFFECTS(TColExpandMinOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -PTO_DEFINE_UNARY_EFFECTS(TColMaxOp, getSrcMutable(), getDstMutable()) -PTO_DEFINE_UNARY_EFFECTS(TColMinOp, getSrcMutable(), getDstMutable()) -PTO_DEFINE_UNARY_EFFECTS(TColProdOp, getSrcMutable(), getDstMutable()) - -void TColArgMaxOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_WRITE(getTmpMutable()); - PTO_ADD_WRITE(getDstMutable()); -} - -void TColArgMinOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_WRITE(getTmpMutable()); - PTO_ADD_WRITE(getDstMutable()); -} - -void TColSumOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - auto tmp = getTmpMutable(); - if (!tmp.empty()) { - PTO_ADD_WRITE(tmp[0]); - } - PTO_ADD_WRITE(getDstMutable()); -} - -void TCvtOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_WRITE(getDstMutable()); -} -void TRandomOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_WRITE(getDstMutable()); -} -PTO_DEFINE_BINARY_EFFECTS(TDivOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) - -// TDIVS has custom assembly format; conservatively treat first 2 operands as reads. 
-void TDivSOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_READ(getScalarMutable()); - PTO_ADD_WRITE(getDstMutable()); -} - -PTO_DEFINE_UNARY_EFFECTS(TExpOp, getSrcMutable(), getDstMutable()) - -// TEXPANDS: Write(dst) (broadcast scalar) -void TExpandsOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_WRITE(getDstMutable()); -} - -// TEXTRACT: Read(src) -> Write(dst) -void TExtractOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_WRITE(getDstMutable()); -} - -// TINSERT: Read(src) -> Write(dst) -void TInsertOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_WRITE(getDstMutable()); -} - -// TEXTRACT_FP: Read(src), Read(fp) -> Write(dst) -void TExtractFPOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_READ(getFpMutable()); - PTO_ADD_WRITE(getDstMutable()); -} - -// TINSERT_FP: Read(src), Read(fp) -> Write(dst) -void TInsertFPOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_READ(getFpMutable()); - PTO_ADD_WRITE(getDstMutable()); -} - -PTO_DEFINE_UNARY_EFFECTS(TFillPadOp, getSrcMutable(), getDstMutable()) -PTO_DEFINE_UNARY_EFFECTS(TFillPadExpandOp, getSrcMutable(), getDstMutable()) -PTO_DEFINE_UNARY_EFFECTS(TFillPadInplaceOp, getSrcMutable(), getDstMutable()) - -void TGatherOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - if (auto cdst = getCdstMutable(); !cdst.empty()) - PTO_ADD_WRITE(cdst[0]); - if (auto indices = getIndicesMutable(); !indices.empty()) - PTO_ADD_READ(indices[0]); - if (auto tmp = getTmpMutable(); !tmp.empty()) - PTO_ADD_READ(tmp[0]); - PTO_ADD_WRITE(getDstMutable()); -} - -PTO_DEFINE_BINARY_EFFECTS(TGatherBOp, getSrcMutable(), getOffsetsMutable(), getDstMutable()) -PTO_DEFINE_UNARY_EFFECTS(TLogOp, getSrcMutable(), getDstMutable()) -PTO_DEFINE_UNARY_EFFECTS(TLReluOp, getSrcMutable(), getDstMutable()) - 
-PTO_DEFINE_BINARY_EFFECTS(TMaxOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -PTO_DEFINE_UNARY_EFFECTS(TMaxSOp, getSrcMutable(), getDstMutable()) -PTO_DEFINE_BINARY_EFFECTS(TMinOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -PTO_DEFINE_UNARY_EFFECTS(TMinSOp, getSrcMutable(), getDstMutable()) - -PTO_DEFINE_BINARY_EFFECTS(TMovFPOp, getSrcMutable(), getFpMutable(), getDstMutable()) - -void TMrgSortOp::getEffects( - SmallVectorImpl> &effects) { - for (auto &opnd : getSrcsMutable()) { - PTO_ADD_READ(opnd); - } - auto tmp = getTmpMutable(); - if (!tmp.empty()) - PTO_ADD_WRITE(tmp[0]); - for (auto &opnd : getDstsMutable()) { - PTO_ADD_WRITE(opnd); - } - auto executed = getExcutedMutable(); - if (!executed.empty()) { - PTO_ADD_WRITE(executed[0]); - } -} - -PTO_DEFINE_BINARY_EFFECTS(TMulOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -PTO_DEFINE_UNARY_EFFECTS(TMulSOp, getSrc0Mutable(), getDstMutable()) -PTO_DEFINE_UNARY_EFFECTS(TNegOp, getSrcMutable(), getDstMutable()) -PTO_DEFINE_UNARY_EFFECTS(TNotOp, getSrcMutable(), getDstMutable()) -PTO_DEFINE_BINARY_EFFECTS(TOrOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -PTO_DEFINE_UNARY_EFFECTS(TOrSOp, getSrcMutable(), getDstMutable()) - -PTO_DEFINE_BINARY_EFFECTS(TPartAddOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -PTO_DEFINE_BINARY_EFFECTS(TPartMaxOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -PTO_DEFINE_BINARY_EFFECTS(TPartMinOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -void TPartArgMaxOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrc0Mutable()); - PTO_ADD_READ(getSrc1Mutable()); - PTO_ADD_READ(getSrc0IdxMutable()); - PTO_ADD_READ(getSrc1IdxMutable()); - PTO_ADD_WRITE(getDstMutable()); - PTO_ADD_WRITE(getDstIdxMutable()); -} -void TPartArgMinOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrc0Mutable()); - PTO_ADD_READ(getSrc1Mutable()); - PTO_ADD_READ(getSrc0IdxMutable()); - PTO_ADD_READ(getSrc1IdxMutable()); - 
PTO_ADD_WRITE(getDstMutable()); - PTO_ADD_WRITE(getDstIdxMutable()); -} -PTO_DEFINE_BINARY_EFFECTS(TPartMulOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -// TPRELU: Read(src0, src1) -> Write(tmp, dst) -void TPReluOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrc0Mutable()); - PTO_ADD_READ(getSrc1Mutable()); - // A5 pto-isa TPRELU implementation does not consume tmp; modeling tmp as a - // write-only scratch on A5 incorrectly inflates local-memory planning and - // can trigger false vec-overflow diagnostics. - if (getTargetArch(getOperation()) != PTOArch::A5) - PTO_ADD_WRITE(getTmpMutable()); - PTO_ADD_WRITE(getDstMutable()); -} - -void TQuantOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_READ(getFpMutable()); - auto offsetRange = getOffsetMutable(); - if (!offsetRange.empty()) - PTO_ADD_READ(offsetRange[0]); - PTO_ADD_WRITE(getDstMutable()); -} -PTO_DEFINE_TERNARY_EFFECTS(TDequantOp, getSrcMutable(), getScaleMutable(), - getOffsetMutable(), getDstMutable()) -PTO_DEFINE_UNARY_EFFECTS(TRecipOp, getSrcMutable(), getDstMutable()) -PTO_DEFINE_UNARY_EFFECTS(TReluOp, getSrcMutable(), getDstMutable()) -PTO_DEFINE_BINARY_EFFECTS(TFModOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -PTO_DEFINE_UNARY_EFFECTS(TFModSOp, getSrcMutable(), getDstMutable()) -void TRemOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrc0Mutable()); - PTO_ADD_READ(getSrc1Mutable()); - PTO_ADD_WRITE(getTmpMutable()); - PTO_ADD_WRITE(getDstMutable()); -} - -void TRemSOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_WRITE(getTmpMutable()); - PTO_ADD_WRITE(getDstMutable()); -} -PTO_DEFINE_UNARY_EFFECTS(TRowExpandOp, getSrcMutable(), getDstMutable()) - -void TRowExpandDivOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrc0Mutable()); - PTO_ADD_READ(getSrc1Mutable()); - auto tmp = getTmpMutable(); - if (!tmp.empty()) - PTO_ADD_WRITE(tmp[0]); - 
PTO_ADD_WRITE(getDstMutable()); -} - -void TRowExpandMulOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrc0Mutable()); - PTO_ADD_READ(getSrc1Mutable()); - auto tmp = getTmpMutable(); - if (!tmp.empty()) - PTO_ADD_WRITE(tmp[0]); - PTO_ADD_WRITE(getDstMutable()); -} - -void TRowExpandSubOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrc0Mutable()); - PTO_ADD_READ(getSrc1Mutable()); - auto tmp = getTmpMutable(); - if (!tmp.empty()) - PTO_ADD_WRITE(tmp[0]); - PTO_ADD_WRITE(getDstMutable()); -} - -PTO_DEFINE_BINARY_EFFECTS(TRowExpandAddOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) - -void TRowExpandExpdifOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrc0Mutable()); - PTO_ADD_READ(getSrc1Mutable()); - auto tmp = getTmpMutable(); - if (!tmp.empty()) - PTO_ADD_WRITE(tmp[0]); - PTO_ADD_WRITE(getDstMutable()); -} - -void TRowExpandMaxOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrc0Mutable()); - PTO_ADD_READ(getSrc1Mutable()); - auto tmp = getTmpMutable(); - if (!tmp.empty()) - PTO_ADD_WRITE(tmp[0]); - PTO_ADD_WRITE(getDstMutable()); -} - -void TRowExpandMinOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrc0Mutable()); - PTO_ADD_READ(getSrc1Mutable()); - auto tmp = getTmpMutable(); - if (!tmp.empty()) - PTO_ADD_WRITE(tmp[0]); - PTO_ADD_WRITE(getDstMutable()); -} - -// Row reductions use tmp scratch tile. -void TRowMaxOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_WRITE(getTmpMutable()); - PTO_ADD_WRITE(getDstMutable()); -} - -void TRowArgMaxOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - // A5 lowering does not consume tmp for TROWARGMAX; modeling tmp as a - // scratch write inflates local-memory planning and can trigger false - // vec-overflow diagnostics, mirroring the fixed A5 TPRELU issue. 
- if (getTargetArch(getOperation()) != PTOArch::A5) - PTO_ADD_WRITE(getTmpMutable()); - PTO_ADD_WRITE(getDstMutable()); -} - -void TRowMinOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_WRITE(getTmpMutable()); - PTO_ADD_WRITE(getDstMutable()); -} - -void TRowArgMinOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - // A5 lowering does not consume tmp for TROWARGMIN; modeling tmp as a - // scratch write inflates local-memory planning and can trigger false - // vec-overflow diagnostics, mirroring the fixed A5 TPRELU issue. - if (getTargetArch(getOperation()) != PTOArch::A5) - PTO_ADD_WRITE(getTmpMutable()); - PTO_ADD_WRITE(getDstMutable()); -} - -void TRowSumOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_WRITE(getTmpMutable()); - PTO_ADD_WRITE(getDstMutable()); -} - -void TRowProdOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_WRITE(getTmpMutable()); - PTO_ADD_WRITE(getDstMutable()); -} -void TRsqrtOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - auto tmp = getTmpMutable(); - if (!tmp.empty()) - PTO_ADD_WRITE(tmp[0]); - PTO_ADD_WRITE(getDstMutable()); -} - -void TScatterOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - if (getIndexes()) { - auto idx = getIndexesMutable(); - if (!idx.empty()) - PTO_ADD_READ(idx[0]); - } - PTO_ADD_WRITE(getDstMutable()); -} - -// Select: Read(mask, src0, src1) -> Write(tmp, dst) -void TSelOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getMaskMutable()); - PTO_ADD_READ(getSrc0Mutable()); - PTO_ADD_READ(getSrc1Mutable()); - PTO_ADD_WRITE(getTmpMutable()); - PTO_ADD_WRITE(getDstMutable()); -} - -// TSELS: Read(src0, src1) -> Write(tmp, dst) -void TSelSOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getMaskMutable()); - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_WRITE(getTmpMutable()); - 
PTO_ADD_WRITE(getDstMutable()); -} - -PTO_DEFINE_BINARY_EFFECTS(TShlOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -PTO_DEFINE_BINARY_EFFECTS(TShrOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -PTO_DEFINE_UNARY_EFFECTS(TShlSOp, getSrcMutable(), getDstMutable()) -PTO_DEFINE_UNARY_EFFECTS(TShrSOp, getSrcMutable(), getDstMutable()) - -// TSORT32: Read(src, idx) -> Write(dst [, tmp]) -void TSort32Op::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_READ(getIdxMutable()); - auto tmp = getTmpMutable(); - if (!tmp.empty()) - PTO_ADD_WRITE(tmp[0]); - PTO_ADD_WRITE(getDstMutable()); -} - -PTO_DEFINE_UNARY_EFFECTS(TSqrtOp, getSrcMutable(), getDstMutable()) -PTO_DEFINE_BINARY_EFFECTS(TSubOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -PTO_DEFINE_TERNARY_EFFECTS(TSubCOp, getSrc0Mutable(), getSrc1Mutable(), getSrc2Mutable(), getDstMutable()) -PTO_DEFINE_UNARY_EFFECTS(TSubSOp, getSrcMutable(), getDstMutable()) -PTO_DEFINE_BINARY_EFFECTS(TSubSCOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) - -// TXORS: Read(src) -> Write(tmp, dst) -void TXorSOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_WRITE(getTmpMutable()); - PTO_ADD_WRITE(getDstMutable()); -} - -// TXOR: Read(src0, src1) -> Write(tmp?, dst) -void TXorOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrc0Mutable()); - PTO_ADD_READ(getSrc1Mutable()); - PTO_ADD_WRITE(getTmpMutable()); - PTO_ADD_WRITE(getDstMutable()); -} - -// TTRANS: Read(src) -> Write(tmp, dst) -void TTransOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_WRITE(getTmpMutable()); - PTO_ADD_WRITE(getDstMutable()); -} - -void TPrintOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_WRITE(getSrcMutable()); -} - -#undef PTO_DEFINE_TERNARY_EFFECTS -#undef PTO_DEFINE_BINARY_EFFECTS -#undef PTO_DEFINE_UNARY_EFFECTS -#undef PTO_ADD_WRITE -#undef PTO_ADD_READ 
- -// === TMatmulOp === -// Read: lhs, rhs, (bias), Write: dst -void TMatmulOp::getEffects(SmallVectorImpl> &effects) { - // Singleton -> 直接取地址 - addEffect(effects, &getLhsMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getRhsMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); -} - -// === TMatmulAccOp === -// Read: acc_in, lhs, rhs, Write: dst -void TMatmulAccOp::getEffects(SmallVectorImpl> &effects) { - addEffect(effects, &getAccInMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getLhsMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getRhsMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); -} - -// === TMatmulBiasOp === -// Read: a, b, bias, Write: dst -void TMatmulBiasOp::getEffects(SmallVectorImpl> &effects) { - addEffect(effects, &getAMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getBMutable(), MemoryEffects::Read::get()); - // 这里的 bias 是必选的 AnyType:$bias,所以是 Singleton - addEffect(effects, &getBiasMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); -} - -// === TGemvOp === -// Read: lhs, rhs, Write: dst -void TGemvOp::getEffects(SmallVectorImpl> &effects) { - addEffect(effects, &getLhsMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getRhsMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); -} - -// === TGemvAccOp === -// Read: acc_in, lhs, rhs, Write: dst -void TGemvAccOp::getEffects(SmallVectorImpl> &effects) { - addEffect(effects, &getAccInMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getLhsMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getRhsMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); -} - -// === TGemvBiasOp === -// Read: a, b, bias, Write: dst -void 
TGemvBiasOp::getEffects(SmallVectorImpl> &effects) { - addEffect(effects, &getAMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getBMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getBiasMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); -} - -// === TGemvMxOp === -// Read: a, a_scale, b, b_scale, Write: dst -void TGemvMxOp::getEffects(SmallVectorImpl> &effects) { - addEffect(effects, &getAMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getAScaleMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getBMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getBScaleMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); -} - -// === TGemvMxAccOp === -// Read: c_in, a, a_scale, b, b_scale, Write: dst -void TGemvMxAccOp::getEffects(SmallVectorImpl> &effects) { - addEffect(effects, &getCInMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getAMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getAScaleMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getBMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getBScaleMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); -} - -// === TGemvMxBiasOp === -// Read: a, a_scale, b, b_scale, bias, Write: dst -void TGemvMxBiasOp::getEffects(SmallVectorImpl> &effects) { - addEffect(effects, &getAMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getAScaleMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getBMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getBScaleMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getBiasMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); -} - -// === TMatmulOp === -void TMatmulMxOp::getEffects(SmallVectorImpl> &effects) { - 
addEffect(effects, &getAMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getAScaleMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getBMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getBScaleMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); -} - -// === TMatmulAccMxOp === -// Read: acc_in, lhs, rhs, Write: dst -void TMatmulMxAccOp::getEffects(SmallVectorImpl> &effects) { - addEffect(effects, &getCInMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getAMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getAScaleMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getBMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getBScaleMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); -} - -// === TMatmulBiasMxOp === -// Read: a, b, bias, Write: dst -void TMatmulMxBiasOp::getEffects(SmallVectorImpl> &effects) { - addEffect(effects, &getAMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getAScaleMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getBMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getBScaleMutable(), MemoryEffects::Read::get()); - // 这里的 bias 是必选的 AnyType:$bias,所以是 Singleton - addEffect(effects, &getBiasMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); -} - -static bool isInsideSectionCube(Operation *op) { - return op->getParentOfType() != nullptr; -} - -static bool isInsideSectionVector(Operation *op) { - return op->getParentOfType() != nullptr; -} - -static std::optional -getEnclosingFunctionKernelKind(Operation *op) { - auto funcOp = op->getParentOfType(); - if (!funcOp) - return std::nullopt; - - auto kernelKindAttr = - funcOp->getAttrOfType( - FunctionKernelKindAttr::name); - if (!kernelKindAttr) - return std::nullopt; - - return kernelKindAttr.getKernelKind(); -} 
- -static bool isInsideSectionOrAttributedKernel(Operation *op) { - return isInsideSectionCube(op) || isInsideSectionVector(op) || - getEnclosingFunctionKernelKind(op).has_value(); -} - -static LogicalResult verifySplitAttr(Operation *op, int64_t split) { - if (split < 0 || split > 2) - return op->emitOpError("expects 'split' to be 0, 1, or 2"); - return success(); -} - -static LogicalResult verifyFrontendKernelKind(Operation *op, - FunctionKernelKind expected, - StringRef kernelName) { - auto kernelKind = getEnclosingFunctionKernelKind(op); - if (!kernelKind || *kernelKind != expected) { - return op->emitOpError("must be inside a ") - << kernelName << " kernel function"; - } - return success(); -} - -static ParseResult parseFrontendInitializePipeOp(OpAsmParser &parser, - OperationState &result) { - NamedAttrList attrs; - bool sawId = false; - bool sawDirMask = false; - bool sawSlotSize = false; - bool sawLocalSlotNum = false; - bool sawNoSplit = false; - - if (parser.parseLBrace()) - return failure(); - - while (failed(parser.parseOptionalRBrace())) { - StringRef keyword; - if (parser.parseKeyword(&keyword) || parser.parseEqual()) - return failure(); - - if (keyword == "id") { - if (sawId) - return parser.emitError(parser.getCurrentLocation(), - "duplicate 'id' clause"); - IntegerAttr idAttr; - if (parser.parseAttribute(idAttr, parser.getBuilder().getI32Type(), "id", - attrs)) - return failure(); - sawId = true; - } else if (keyword == "dir_mask") { - if (sawDirMask) - return parser.emitError(parser.getCurrentLocation(), - "duplicate 'dir_mask' clause"); - IntegerAttr dirMaskAttr; - if (parser.parseAttribute(dirMaskAttr, parser.getBuilder().getI8Type(), - "dir_mask", attrs)) - return failure(); - sawDirMask = true; - } else if (keyword == "slot_size") { - if (sawSlotSize) - return parser.emitError(parser.getCurrentLocation(), - "duplicate 'slot_size' clause"); - IntegerAttr slotSizeAttr; - if (parser.parseAttribute(slotSizeAttr, parser.getBuilder().getI32Type(), - 
"slot_size", attrs)) - return failure(); - sawSlotSize = true; - } else if (keyword == "local_slot_num") { - if (sawLocalSlotNum) - return parser.emitError(parser.getCurrentLocation(), - "duplicate 'local_slot_num' clause"); - IntegerAttr localSlotNumAttr; - if (parser.parseAttribute(localSlotNumAttr, parser.getBuilder().getI32Type(), - "local_slot_num", attrs)) - return failure(); - sawLocalSlotNum = true; - } else if (keyword == "nosplit") { - if (sawNoSplit) - return parser.emitError(parser.getCurrentLocation(), - "duplicate 'nosplit' clause"); - BoolAttr noSplitAttr; - if (parser.parseAttribute(noSplitAttr, "nosplit", attrs)) - return failure(); - sawNoSplit = true; - } else { - return parser.emitError(parser.getCurrentLocation()) - << "unexpected keyword '" << keyword << "'"; - } - - if (succeeded(parser.parseOptionalRBrace())) - break; - if (parser.parseComma()) - return failure(); - } - - if (!sawDirMask) - return parser.emitError(parser.getNameLoc(), "expected 'dir_mask' clause"); - if (!sawSlotSize) - return parser.emitError(parser.getNameLoc(), "expected 'slot_size' clause"); - if (!sawId) - attrs.set("id", parser.getBuilder().getI32IntegerAttr(0)); - - OpAsmParser::UnresolvedOperand gmSlotBuffer; - OpAsmParser::UnresolvedOperand gmSlotTensor; - OpAsmParser::UnresolvedOperand c2vConsumerBuf; - OpAsmParser::UnresolvedOperand v2cConsumerBuf; - Type gmSlotBufferTy; - Type gmSlotTensorTy; - Type c2vConsumerBufTy; - Type v2cConsumerBufTy; - bool hasGmSlotBuffer = false; - bool hasGmSlotTensor = false; - bool hasC2vConsumerBuf = false; - bool hasV2cConsumerBuf = false; - - if (parser.parseLParen()) - return failure(); - while (failed(parser.parseOptionalRParen())) { - StringRef keyword; - if (parser.parseKeyword(&keyword) || parser.parseEqual()) - return failure(); - - if (keyword == "gm_slot_buffer") { - if (hasGmSlotBuffer) - return parser.emitError(parser.getCurrentLocation(), - "duplicate 'gm_slot_buffer' operand"); - if (parser.parseOperand(gmSlotBuffer) 
|| - parser.parseColonType(gmSlotBufferTy)) - return failure(); - hasGmSlotBuffer = true; - } else if (keyword == "gm_slot_tensor") { - if (hasGmSlotTensor) - return parser.emitError(parser.getCurrentLocation(), - "duplicate 'gm_slot_tensor' operand"); - if (parser.parseOperand(gmSlotTensor) || - parser.parseColonType(gmSlotTensorTy)) - return failure(); - hasGmSlotTensor = true; - } else if (keyword == "c2v_consumer_buf") { - if (hasC2vConsumerBuf) - return parser.emitError(parser.getCurrentLocation(), - "duplicate 'c2v_consumer_buf' operand"); - if (parser.parseOperand(c2vConsumerBuf) || - parser.parseColonType(c2vConsumerBufTy)) - return failure(); - hasC2vConsumerBuf = true; - } else if (keyword == "v2c_consumer_buf") { - if (hasV2cConsumerBuf) - return parser.emitError(parser.getCurrentLocation(), - "duplicate 'v2c_consumer_buf' operand"); - if (parser.parseOperand(v2cConsumerBuf) || - parser.parseColonType(v2cConsumerBufTy)) - return failure(); - hasV2cConsumerBuf = true; - } else { - return parser.emitError(parser.getCurrentLocation()) - << "unexpected initialize_pipe operand '" << keyword << "'"; - } - - if (succeeded(parser.parseOptionalRParen())) - break; - if (parser.parseComma()) - return failure(); - } - - if (parser.parseOptionalAttrDict(attrs)) - return failure(); - - result.addAttributes(attrs); - result.addAttribute("operandSegmentSizes", - parser.getBuilder().getDenseI32ArrayAttr( - {hasGmSlotBuffer ? 1 : 0, hasGmSlotTensor ? 1 : 0, - hasC2vConsumerBuf ? 1 : 0, - hasV2cConsumerBuf ? 
1 : 0})); - if (hasGmSlotBuffer && - parser.resolveOperand(gmSlotBuffer, gmSlotBufferTy, result.operands)) - return failure(); - if (hasGmSlotTensor && - parser.resolveOperand(gmSlotTensor, gmSlotTensorTy, result.operands)) - return failure(); - if (hasC2vConsumerBuf && - parser.resolveOperand(c2vConsumerBuf, c2vConsumerBufTy, result.operands)) - return failure(); - if (hasV2cConsumerBuf && - parser.resolveOperand(v2cConsumerBuf, v2cConsumerBufTy, result.operands)) - return failure(); - return success(); -} - -template -static void printFrontendInitializePipeOp(InitOpT op, OpAsmPrinter &p) { - p << " {"; - bool needsComma = false; - auto printClause = [&](StringRef keyword, auto value) { - if (needsComma) - p << ", "; - p << keyword << " = " << value; - needsComma = true; - }; - - if (op.getId() != 0) - printClause("id", op.getId()); - printClause("dir_mask", static_cast(op.getDirMask())); - printClause("slot_size", op.getSlotSize()); - if (auto localSlotNumAttr = op.getLocalSlotNumAttr()) - printClause("local_slot_num", localSlotNumAttr.getInt()); - if (auto noSplitAttr = op.getNosplitAttr()) - printClause("nosplit", noSplitAttr.getValue() ? 
"true" : "false"); - p << "}"; - - p << "("; - bool needsOperandComma = false; - auto printOperandClause = [&](StringRef keyword, Value value) { - if (needsOperandComma) - p << ", "; - p << keyword << " = " << value << " : " << value.getType(); - needsOperandComma = true; - }; - if (op.getGmSlotBuffer()) { - printOperandClause("gm_slot_buffer", op.getGmSlotBuffer()); - } - if (op.getGmSlotTensor()) - printOperandClause("gm_slot_tensor", op.getGmSlotTensor()); - if (op.getC2vConsumerBuf()) - printOperandClause("c2v_consumer_buf", op.getC2vConsumerBuf()); - if (op.getV2cConsumerBuf()) - printOperandClause("v2c_consumer_buf", op.getV2cConsumerBuf()); - p << ")"; - p.printOptionalAttrDict( - op->getAttrs(), - /*elidedAttrs=*/{"id", "dir_mask", "slot_size", "local_slot_num", - "nosplit", "operandSegmentSizes"}); -} - -static std::optional -getStaticElementCount(ArrayRef shape) { - uint64_t count = 1; - for (int64_t dim : shape) { - if (dim == ShapedType::kDynamic || dim < 0) - return std::nullopt; - count *= static_cast(dim); - } - return count; -} - -static bool isSameOrHalfSlotByteSize(uint64_t tensorBytes, uint64_t slotBytes) { - return tensorBytes == slotBytes || tensorBytes * 2 == slotBytes; -} - -static LogicalResult verifyFrontendGlobalSlotTensor(Operation *op, Value tensor, - int8_t dirMask, - int32_t slotSize) { - (void)dirMask; - auto tvTy = dyn_cast(tensor.getType()); - if (!tvTy) - return op->emitOpError("expects 'gm_slot_tensor' to be !pto.tensor_view"); - - ArrayRef shape = tvTy.getShape(); - if (shape.empty()) - return op->emitOpError( - "expects 'gm_slot_tensor' to describe one slot entry tensor"); - - if (auto elemCount = getStaticElementCount(shape)) { - uint64_t elemBytes = getElemByteSize(tvTy.getElementType()); - if (elemBytes != 0) { - uint64_t tensorBytes = *elemCount * elemBytes; - if (!isSameOrHalfSlotByteSize(tensorBytes, - static_cast(slotSize))) { - return op->emitOpError() - << "expects 'slot_size' to equal gm_slot_tensor byte size " - "or 
twice gm_slot_tensor byte size for split GlobalTensor " - "entries (got slot_size = " - << slotSize << ", gm_slot_tensor byte size = " << tensorBytes - << ")"; - } - } - } - - return success(); -} - -template -static LogicalResult verifyFrontendInitCommon(InitOpT op, - FunctionKernelKind expected, - StringRef kernelName) { - if (failed(verifyFrontendKernelKind(op.getOperation(), expected, kernelName))) - return failure(); - - auto funcOp = op->template getParentOfType(); - if (!funcOp) - return op.emitOpError("must be nested under a func.func"); - - if (op.getId() < 0) - return op.emitOpError("expects 'id' to be non-negative"); - - unsigned sameIdInitCount = 0; - funcOp.walk([&](Operation *candidate) { - if (auto aic = dyn_cast(candidate)) { - if (aic.getId() == op.getId()) - ++sameIdInitCount; - return; - } - if (auto aiv = dyn_cast(candidate)) - if (aiv.getId() == op.getId()) - ++sameIdInitCount; - }); - if (sameIdInitCount > 1) { - return op.emitOpError( - "requires 'id' to be unique across frontend initialize_pipe ops in the function"); - } - - int8_t dirMask = op.getDirMask(); - if (dirMask != 1 && dirMask != 2 && dirMask != 3) - return op.emitOpError("expects 'dir_mask' to be 1, 2, or 3"); - if (op.getSlotSize() <= 0) - return op.emitOpError("expects 'slot_size' to be greater than 0"); - - bool hasGlobalSlotTensor = static_cast(op.getGmSlotTensor()); - bool hasC2vConsumerBuf = static_cast(op.getC2vConsumerBuf()); - bool hasV2cConsumerBuf = static_cast(op.getV2cConsumerBuf()); - if (hasGlobalSlotTensor) { - if (op.getGmSlotBuffer() || hasC2vConsumerBuf || hasV2cConsumerBuf) { - return op.emitOpError( - "globaltensor pipe init expects only 'gm_slot_tensor' and no " - "'gm_slot_buffer', 'c2v_consumer_buf', or 'v2c_consumer_buf'"); - } - if (op.getLocalSlotNumAttr()) - return op.emitOpError( - "globaltensor pipe init does not use 'local_slot_num'"); - if (getTargetArch(op.getOperation()) == PTOArch::A5) { - return op.emitOpError( - "globaltensor pipe entries are 
supported for a2/a3 l2g2l pipes"); - } - return verifyFrontendGlobalSlotTensor( - op.getOperation(), op.getGmSlotTensor(), dirMask, op.getSlotSize()); - } - - if (hasC2vConsumerBuf != hasV2cConsumerBuf) { - return op.emitOpError( - "expects 'c2v_consumer_buf' and 'v2c_consumer_buf' to be provided together"); - } - if (!hasC2vConsumerBuf) { - return op.emitOpError( - "expects local pipe init to provide 'c2v_consumer_buf' and " - "'v2c_consumer_buf'; use 'gm_slot_tensor' for globaltensor pipe entries"); - } - - if (auto localSlotNumAttr = op.getLocalSlotNumAttr()) { - int32_t localSlotNum = localSlotNumAttr.getInt(); - if (localSlotNum <= 0) - return op.emitOpError("expects 'local_slot_num' to be greater than 0"); - int32_t loweredSlotNum = dirMask == 3 ? 4 : 8; - if (localSlotNum > loweredSlotNum) { - return op.emitOpError() - << "expects 'local_slot_num' to be less than or equal to " - << loweredSlotNum << " for dir_mask = " << static_cast(dirMask); - } - } - - return success(); -} - -ParseResult AicInitializePipeOp::parse(OpAsmParser &parser, - OperationState &result) { - return parseFrontendInitializePipeOp(parser, result); -} - -void AicInitializePipeOp::print(OpAsmPrinter &p) { - printFrontendInitializePipeOp(*this, p); -} - -ParseResult AivInitializePipeOp::parse(OpAsmParser &parser, - OperationState &result) { - return parseFrontendInitializePipeOp(parser, result); -} - -void AivInitializePipeOp::print(OpAsmPrinter &p) { - printFrontendInitializePipeOp(*this, p); -} - -static ReserveBufferOp findReserveBufferByName(func::FuncOp funcOp, - StringRef name) { - ReserveBufferOp found; - funcOp.walk([&](ReserveBufferOp reserveOp) { - if (reserveOp.getName() != name) - return WalkResult::advance(); - found = reserveOp; - return WalkResult::interrupt(); - }); - return found; -} - -LogicalResult ReserveBufferOp::verify() { - auto funcOp = getOperation()->getParentOfType(); - if (!funcOp) - return emitOpError("must be nested under a func.func"); - - if (getSize() <= 0) 
- return emitOpError("expects 'size' to be greater than 0"); - - auto location = getLocation().getAddressSpace(); - if (location != AddressSpace::VEC && location != AddressSpace::MAT) - return emitOpError("expects 'location' to be #pto.address_space or #pto.address_space"); - - if (!getAutoAlloc() && !getBaseAttr()) - return emitOpError("expects 'base' when 'auto' is false"); - - if (auto baseAttr = getBaseAttr(); baseAttr && baseAttr.getInt() < 0) - return emitOpError("expects 'base' to be non-negative when present"); - - unsigned sameNameCount = 0; - funcOp.walk([&](ReserveBufferOp reserveOp) { - if (reserveOp.getName() == getName()) - ++sameNameCount; - }); - if (sameNameCount > 1) - return emitOpError("requires 'name' to be unique within the function"); - - return success(); -} - -LogicalResult ImportReservedBufferOp::verify() { - auto funcOp = getOperation()->getParentOfType(); - if (!funcOp) - return emitOpError("must be nested under a func.func"); - - auto peerFunc = SymbolTable::lookupNearestSymbolFrom( - getOperation(), getPeerFuncAttr()); - if (!peerFunc) - return emitOpError("expects 'peer_func' to reference an existing func.func"); - - unsigned sameImportCount = 0; - funcOp.walk([&](ImportReservedBufferOp importOp) { - if (importOp.getName() == getName() && - importOp.getPeerFuncAttr() == getPeerFuncAttr()) { - ++sameImportCount; - } - }); - if (sameImportCount > 1) { - return emitOpError( - "requires (name, peer_func) to be unique within the function"); - } - - if (!findReserveBufferByName(peerFunc, getName())) - return emitOpError("expects matching peer reserve_buffer to exist"); - - return success(); -} - -static FailureOr lookupFrontendInitOpById(Operation *op, - func::FuncOp funcOp, - int32_t id) { - Operation *matchedInit = nullptr; - unsigned matchedInitCount = 0; - funcOp.walk([&](Operation *candidate) { - if (auto aic = dyn_cast(candidate)) { - if (aic.getId() == static_cast(id)) { - matchedInit = candidate; - ++matchedInitCount; - } - return 
WalkResult::advance(); - } - if (auto aiv = dyn_cast(candidate)) { - if (aiv.getId() == static_cast(id)) { - matchedInit = candidate; - ++matchedInitCount; - } - return WalkResult::advance(); - } - return WalkResult::advance(); - }); - - if (matchedInitCount == 0) { - op->emitOpError() << "expects 'id' = " << id - << " to match a frontend initialize_pipe op in the same function"; - return failure(); - } - if (matchedInitCount > 1) { - op->emitOpError() << "expects 'id' = " << id - << " to match exactly one frontend initialize_pipe op in the same function"; - return failure(); - } - return matchedInit; -} - -static LogicalResult verifyFrontendSplitOp(Operation *op, - FunctionKernelKind expected, - StringRef kernelName, - int32_t id, - int64_t split) { - if (failed(verifyFrontendKernelKind(op, expected, kernelName))) - return failure(); - if (id < 0) - return op->emitOpError("expects 'id' to be non-negative"); - return verifySplitAttr(op, split); -} - -static FailureOr lookupFrontendInitDirMaskById(Operation *op, - func::FuncOp funcOp, - int32_t id) { - auto initOr = lookupFrontendInitOpById(op, funcOp, id); - if (failed(initOr)) - return failure(); - if (auto aic = dyn_cast(*initOr)) - return aic.getDirMask(); - return cast(*initOr).getDirMask(); -} - -static LogicalResult verifyFrontendDataOpDirection(Operation *op, int32_t id, - bool expectC2V) { - auto funcOp = op->getParentOfType(); - if (!funcOp) - return op->emitOpError("must be nested under a func.func"); - - auto dirMaskOr = lookupFrontendInitDirMaskById(op, funcOp, id); - if (failed(dirMaskOr)) - return failure(); - - int8_t dirMask = *dirMaskOr; - if (expectC2V && dirMask != 1 && dirMask != 3) { - return op->emitOpError() - << "expects 'id' = " << id - << " to reference initialize_pipe with dir_mask = 1 or 3"; - } - if (!expectC2V && dirMask != 2 && dirMask != 3) { - return op->emitOpError() - << "expects 'id' = " << id - << " to reference initialize_pipe with dir_mask = 2 or 3"; - } - return success(); -} 
- -static Value getFrontendInitGmSlotTensor(Operation *initOp) { - if (auto aic = dyn_cast(initOp)) - return aic.getGmSlotTensor(); - return cast(initOp).getGmSlotTensor(); -} - -static LogicalResult verifyFrontendTensorEntryMatchesInit(Operation *op, - int32_t id, - Type entryTy) { - auto entryViewTy = dyn_cast(entryTy); - if (!entryViewTy) - return success(); - - auto funcOp = op->getParentOfType(); - if (!funcOp) - return op->emitOpError("must be nested under a func.func"); - - auto initOr = lookupFrontendInitOpById(op, funcOp, id); - if (failed(initOr)) - return failure(); - Value gmSlotTensor = getFrontendInitGmSlotTensor(*initOr); - if (!gmSlotTensor) { - return op->emitOpError() - << "expects 'id' = " << id - << " to reference initialize_pipe with 'gm_slot_tensor' when the " - "pipe entry is !pto.tensor_view"; - } - - auto slotTensorTy = dyn_cast(gmSlotTensor.getType()); - if (!slotTensorTy) - return op->emitOpError("expects 'gm_slot_tensor' to be !pto.tensor_view"); - if (slotTensorTy.getElementType() != entryViewTy.getElementType()) { - return op->emitOpError() - << "expects pipe entry element type to match gm_slot_tensor element type"; - } - if (slotTensorTy.getRank() != entryViewTy.getRank()) { - return op->emitOpError() - << "expects pipe entry rank to match gm_slot_tensor rank"; - } - - ArrayRef slotShape = slotTensorTy.getShape(); - ArrayRef entryShape = entryViewTy.getShape(); - for (auto [idx, entryDim] : llvm::enumerate(entryShape)) { - int64_t slotDim = slotShape[idx]; - if (slotDim == ShapedType::kDynamic || - entryDim == ShapedType::kDynamic || slotDim == entryDim) - continue; - return op->emitOpError() - << "expects pipe entry dimension " << idx - << " to match gm_slot_tensor dimension " << slotDim; - } - return success(); -} - -template -static LogicalResult verifyFrontendPopOp(FrontendPopOpT op, - FunctionKernelKind expected, - StringRef kernelName, - bool expectC2V) { - if (failed(verifyFrontendSplitOp(op.getOperation(), expected, 
kernelName, - op.getId(), - op.getSplit()))) - return failure(); - if (failed(verifyFrontendDataOpDirection(op.getOperation(), op.getId(), - expectC2V))) - return failure(); - if (failed(verifyFrontendTensorEntryMatchesInit(op.getOperation(), op.getId(), - op.getTile().getType()))) - return failure(); - - bool hasValidRow = static_cast(op.getValidRow()); - bool hasValidCol = static_cast(op.getValidCol()); - if (hasValidRow != hasValidCol) - return op.emitOpError( - "expects valid_row and valid_col operands to be provided together"); - if (!hasValidRow) - return success(); - - if (isa(op.getTile().getType())) - return op.emitOpError( - "does not accept valid_row/valid_col when result is !pto.tensor_view"); - - auto tileTy = dyn_cast(op.getTile().getType()); - if (!tileTy) - return op.emitOpError( - "expects tile result to be !pto.tile_buf when valid_row/valid_col operands are provided"); - if (!tileTy.hasDynamicValid()) - return op.emitOpError( - "expects tile result to have dynamic validShape (?, ?) when valid_row/valid_col operands are provided"); - return success(); -} - -static LogicalResult verifyPipeShape(Operation *op, int8_t dirMask, int32_t slotSize, - int32_t slotNum, - std::optional flagBase) { - constexpr int32_t kMaxHardwareFlagIds = 16; - if (dirMask != 1 && dirMask != 2 && dirMask != 3) - return op->emitOpError("expects 'dir_mask' to be 1, 2, or 3"); - if (slotSize <= 0) - return op->emitOpError("expects 'slot_size' to be greater than 0"); - if (slotNum != 4 && slotNum != 8) - return op->emitOpError("expects 'slot_num' to be 4 or 8"); - if (flagBase && *flagBase < 0) - return op->emitOpError("expects 'flag_base' to be non-negative when present"); - if (flagBase) { - int32_t flagWidth = dirMask == 3 ? 
4 : 2; - if (*flagBase + flagWidth > kMaxHardwareFlagIds) { - return op->emitOpError() - << "requires 'flag_base' and dir_mask to fit within " - << kMaxHardwareFlagIds << " hardware flag ids"; - } - } - - return success(); -} - -static LogicalResult verifyPipeHandleProducer(Operation *op, Value pipeHandle) { - if (!isa(pipeHandle.getType())) - return op->emitOpError("expects pipe operand type !pto.pipe"); - if (!pipeHandle.getDefiningOp() && - !pipeHandle.getDefiningOp()) { - return op->emitOpError( - "pipe_handle must be produced by pto.initialize_l2l_pipe or " - "pto.initialize_l2g2l_pipe"); - } - return success(); -} - -static bool getTensorLikeElementAndShape(Type ty, Type &elementType, - ArrayRef &shape) { - if (auto tvTy = dyn_cast(ty)) { - elementType = tvTy.getElementType(); - shape = tvTy.getShape(); - return true; - } - if (auto memrefTy = dyn_cast(ty)) { - elementType = memrefTy.getElementType(); - shape = memrefTy.getShape(); - return true; - } - return false; -} - -static LogicalResult verifyTensorEntryMatchesInternalPipeInit(Operation *op, - Value pipeHandle, - Type entryTy) { - auto entryViewTy = dyn_cast(entryTy); - if (!entryViewTy) - return success(); - - auto initOp = pipeHandle.getDefiningOp(); - if (!initOp) { - return op->emitOpError() - << "expects !pto.tensor_view pipe entry to use a pipe produced by " - "pto.initialize_l2g2l_pipe"; - } - if (initOp.getLocalAddr()) { - return op->emitOpError() - << "expects !pto.tensor_view pipe entry to use global-only " - "pto.initialize_l2g2l_pipe without local_addr"; - } - - Type slotElementType; - ArrayRef slotShape; - if (!getTensorLikeElementAndShape(initOp.getGmAddr().getType(), - slotElementType, slotShape)) { - return op->emitOpError() - << "expects !pto.tensor_view pipe entry to use " - "pto.initialize_l2g2l_pipe gm_addr with tensor/memref slot type"; - } - - if (slotElementType != entryViewTy.getElementType()) { - return op->emitOpError() - << "expects pipe entry element type to match 
initialize_l2g2l_pipe " - "gm_addr element type"; - } - if (slotShape.size() != static_cast(entryViewTy.getRank())) { - return op->emitOpError() - << "expects pipe entry rank to match initialize_l2g2l_pipe gm_addr " - "rank"; - } - - ArrayRef entryShape = entryViewTy.getShape(); - for (auto [idx, entryDim] : llvm::enumerate(entryShape)) { - int64_t slotDim = slotShape[idx]; - if (slotDim == ShapedType::kDynamic || - entryDim == ShapedType::kDynamic || slotDim == entryDim) - continue; - return op->emitOpError() - << "expects pipe entry dimension " << idx - << " to match initialize_l2g2l_pipe gm_addr dimension " - << slotDim; - } - - if (auto entryElemCount = getStaticElementCount(entryShape)) { - uint64_t elemBytes = getElemByteSize(entryViewTy.getElementType()); - uint64_t entryBytes = *entryElemCount * elemBytes; - if (elemBytes != 0) { - int8_t split = 0; - if (auto alloc = dyn_cast(op)) - split = alloc.getSplit(); - else if (auto push = dyn_cast(op)) - split = push.getSplit(); - else if (auto pop = dyn_cast(op)) - split = pop.getSplit(); - else if (auto free = dyn_cast(op)) - split = free.getSplit(); - - uint64_t slotBytes = static_cast(initOp.getSlotSize()); - bool isSplitEntry = split != 0; - bool byteSizeMatches = - entryBytes == slotBytes || (isSplitEntry && entryBytes * 2 == slotBytes); - if (!byteSizeMatches) { - return op->emitOpError() - << "expects pipe entry byte size to match initialize_l2g2l_pipe " - "slot_size" - << (isSplitEntry ? 
" or half slot_size for split entries" : "") - << " (got entry byte size = " << entryBytes - << ", slot_size = " << initOp.getSlotSize() << ")"; - } - } - } - - return success(); -} - -LogicalResult BuildAsyncSessionOp::verify() { - Type scratchTy = getScratch().getType(); - if (!isa(scratchTy)) - return emitOpError("expects scratch to be tile_buf or memref type"); - - auto scratchSpace = getPTOMemorySpaceEnum(scratchTy); - if (!scratchSpace || *scratchSpace != pto::AddressSpace::VEC) - return emitOpError("expects scratch to be in vec address space"); - - auto scratchShape = getShapeVec(scratchTy); - if (scratchShape.empty() || scratchShape.size() > 2) - return emitOpError("expects scratch to be rank-1 or rank-2"); - for (int64_t dim : scratchShape) { - if (dim == ShapedType::kDynamic) - return emitOpError("expects scratch to have a static shape"); - } - - auto scratchBytes = getStaticByteSize(scratchTy); - if (!scratchBytes) - return emitOpError("expects scratch byte size to be statically known"); - if (*scratchBytes < sizeof(uint64_t)) - return emitOpError("expects scratch to provide at least 8 bytes"); - - Type workspaceElemTy; - Type workspaceTy = getWorkspace().getType(); - if (auto ptrTy = dyn_cast(workspaceTy)) { - workspaceElemTy = ptrTy.getElementType(); - } else if (auto memTy = dyn_cast(workspaceTy)) { - workspaceElemTy = memTy.getElementType(); - if (!isGmAddressSpaceAttr(memTy.getMemorySpace())) - return emitOpError("expects workspace to be in GM address space"); - } else { - return emitOpError("expects workspace to be !pto.ptr or memref type"); - } - if (!isByteIntegerType(workspaceElemTy)) - return emitOpError("expects workspace element type to be an 8-bit integer"); - - if (auto syncIdAttr = getSyncIdAttr()) { - int64_t syncId = syncIdAttr.getInt(); - if (syncId < 0 || syncId > 7) - return emitOpError("expects sync_id in range [0, 7]"); - } - if (auto blockBytesAttr = getBlockBytesAttr()) { - if (blockBytesAttr.getInt() <= 0) - return 
emitOpError("expects block_bytes to be greater than 0"); - } - if (auto commBlockOffsetAttr = getCommBlockOffsetAttr()) { - if (commBlockOffsetAttr.getInt() < 0) - return emitOpError("expects comm_block_offset to be non-negative"); - } - if (auto queueNumAttr = getQueueNumAttr()) { - if (queueNumAttr.getInt() <= 0) - return emitOpError("expects queue_num to be greater than 0"); - } - if (auto channelGroupIdxAttr = getChannelGroupIdxAttr()) { - APInt value = channelGroupIdxAttr.getValue(); - if (value.isNegative()) - return emitOpError("expects channel_group_idx to be non-negative"); - if (value.ugt(UINT32_MAX)) - return emitOpError("expects channel_group_idx to fit in uint32"); - } - - return success(); -} - -static LogicalResult verifyAsyncTransferOp(Operation *op, Value dst, Value src) { - Type dstElemTy = getElemTy(dst.getType()); - Type srcElemTy = getElemTy(src.getType()); - if (!dstElemTy || !srcElemTy) - return op->emitOpError("expects src and dst to have element types"); - if (dstElemTy != srcElemTy) - return op->emitOpError("expects src and dst to have the same element type"); - if (failed(verifyAsyncFlatContiguous1DGMViewLike(op, dst, "dst")) || - failed(verifyAsyncFlatContiguous1DGMViewLike(op, src, "src"))) - return failure(); - if (getShapeVec(dst.getType()) != getShapeVec(src.getType())) - return op->emitOpError("expects src and dst to have the same static shape"); - return success(); -} - -LogicalResult TPutAsyncOp::verify() { - return verifyAsyncTransferOp(getOperation(), getDst(), getSrc()); -} - -LogicalResult TGetAsyncOp::verify() { - return verifyAsyncTransferOp(getOperation(), getDst(), getSrc()); -} - -LogicalResult TPutOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - if (failed(verifyCommGlobalLike(*this, getDst(), "dst")) || - failed(verifyCommGlobalLike(*this, getSrc(), "src")) || - failed(verifyCommStagingTileLike(*this, getPing(), "ping")) || - failed(verifyCommPingPongSameType(*this, 
getPing(), getPong(), "ping", - "pong"))) - return failure(); - if (getElemTy(getDst().getType()) != getElemTy(getSrc().getType())) - return emitOpError("expects src and dst to have the same element type"); - if (getShapeVec(getDst().getType()) != getShapeVec(getSrc().getType())) - return emitOpError("expects src and dst to have the same static shape"); - if (getElemTy(getPing().getType()) != getElemTy(getSrc().getType())) - return emitOpError("expects staging tile element type to match src/dst"); - return success(); -} - -LogicalResult TGetOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - if (failed(verifyCommGlobalLike(*this, getDst(), "dst")) || - failed(verifyCommGlobalLike(*this, getSrc(), "src")) || - failed(verifyCommStagingTileLike(*this, getPing(), "ping")) || - failed(verifyCommPingPongSameType(*this, getPing(), getPong(), "ping", - "pong"))) - return failure(); - if (getElemTy(getDst().getType()) != getElemTy(getSrc().getType())) - return emitOpError("expects src and dst to have the same element type"); - if (getShapeVec(getDst().getType()) != getShapeVec(getSrc().getType())) - return emitOpError("expects src and dst to have the same static shape"); - if (getElemTy(getPing().getType()) != getElemTy(getSrc().getType())) - return emitOpError("expects staging tile element type to match src/dst"); - return success(); -} - -LogicalResult TNotifyOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - if (failed(verifyCommSignalLike(*this, getSignal(), "signal"))) - return failure(); - auto valueTy = dyn_cast(getValue().getType()); - if (!valueTy || valueTy.getWidth() != 32) - return emitOpError("expects value to be i32"); - return success(); -} - -LogicalResult TWaitOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - if (failed(verifyCommSignalLike(*this, getSignal(), "signal"))) - return failure(); - auto cmpTy = 
dyn_cast(getCmpValue().getType()); - if (!cmpTy || cmpTy.getWidth() != 32) - return emitOpError("expects cmp_value to be i32"); - return success(); -} - -LogicalResult TTestOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - if (failed(verifyCommSignalLike(*this, getSignal(), "signal"))) - return failure(); - auto cmpTy = dyn_cast(getCmpValue().getType()); - if (!cmpTy || cmpTy.getWidth() != 32) - return emitOpError("expects cmp_value to be i32"); - return success(); -} - -static LogicalResult verifySyncAllGmWorkspace(Operation *op, Value workspace, - StringRef name) { - Type ty = workspace.getType(); - if (!isa(ty)) - return op->emitOpError() << "expects " << name - << " to be a GM memref/tensor_view/partition_view"; - - if (auto memTy = dyn_cast(ty)) { - if (!memTy.hasRank()) - return op->emitOpError() << "expects " << name << " to be ranked"; - if (!isGmAddressSpaceAttr(memTy.getMemorySpace())) - return op->emitOpError() << "expects " << name - << " to be in GM address space"; - } - - auto elemTy = dyn_cast(getElemTy(ty)); - if (!elemTy || elemTy.getWidth() != 32) - return op->emitOpError() << "expects " << name - << " element type to be i32"; - - SmallVector shape = getShapeVec(ty); - if (shape.empty()) - return op->emitOpError() << "expects " << name << " to have rank >= 1"; - for (int64_t dim : shape) { - if (dim != ShapedType::kDynamic && dim <= 0) - return op->emitOpError() << "expects " << name - << " shape to be positive"; - } - return success(); -} - -static LogicalResult verifySyncAllTileWorkspace(Operation *op, Value workspace, - StringRef name, - pto::AddressSpace expectedSpace) { - Type ty = workspace.getType(); - if (!isa(ty)) - return op->emitOpError() << "expects " << name - << " to be tile_buf or memref type"; - - if (isa(ty) && failed(verifyTileBufCommon(op, ty, name))) - return failure(); - - auto as = getPTOMemorySpaceEnum(ty); - if (!as || *as != expectedSpace) - return op->emitOpError() << "expects " 
<< name << " to be in " - << (expectedSpace == pto::AddressSpace::VEC - ? "vec" - : "mat") - << " address space"; - - Type elemTy = getElemTy(ty); - auto intTy = dyn_cast_or_null(elemTy); - if (!intTy || intTy.getWidth() != 32) - return op->emitOpError() << "expects " << name - << " element type to be i32"; - - auto shape = getShapeVec(ty); - if (shape.empty() || shape.size() > 2) - return op->emitOpError() << "expects " << name - << " to be rank-1 or rank-2"; - for (int64_t dim : shape) { - if (dim != ShapedType::kDynamic && dim <= 0) - return op->emitOpError() << "expects " << name - << " shape to be positive"; - } - return success(); -} - -LogicalResult SyncAllOp::verify() { - bool hasGm = static_cast(getGmWorkspace()); - bool hasUb = static_cast(getUbWorkspace()); - bool hasL1 = static_cast(getL1Workspace()); - auto mode = getMode().getValue(); - auto coreType = getCoreType().getValue(); - - if (mode == pto::SyncAllMode::Hard) { - if (hasGm || hasUb || hasL1 || getUsedCores()) - return emitOpError( - "expects hard syncall to have no workspace operands or used_cores"); - return success(); - } - - if (!hasGm) - return emitOpError("expects soft syncall to provide gm_workspace"); - if (failed(verifySyncAllGmWorkspace(getOperation(), getGmWorkspace(), - "gm_workspace"))) - return failure(); - - if (auto used = getUsedCores()) { - auto intTy = dyn_cast(used.getType()); - if (!intTy || intTy.getWidth() != 32) - return emitOpError("expects used_cores to be i32"); - } - - switch (coreType) { - case pto::SyncCoreType::AIVOnly: - if (!hasUb || hasL1) - return emitOpError("expects soft AIV-only syncall to use gm_workspace " - "+ ub_workspace only"); - return verifySyncAllTileWorkspace(getOperation(), getUbWorkspace(), - "ub_workspace", - pto::AddressSpace::VEC); - case pto::SyncCoreType::AICOnly: - if (hasUb || !hasL1) - return emitOpError("expects soft AIC-only syncall to use gm_workspace " - "+ l1_workspace only"); - return verifySyncAllTileWorkspace(getOperation(), 
getL1Workspace(), - "l1_workspace", - pto::AddressSpace::MAT); - case pto::SyncCoreType::Mix: - if (!hasUb || !hasL1) - return emitOpError("expects soft mixed syncall to use gm_workspace + " - "ub_workspace + l1_workspace"); - if (failed(verifySyncAllTileWorkspace(getOperation(), getUbWorkspace(), - "ub_workspace", - pto::AddressSpace::VEC))) - return failure(); - return verifySyncAllTileWorkspace(getOperation(), getL1Workspace(), - "l1_workspace", - pto::AddressSpace::MAT); - } - - llvm_unreachable("unhandled SyncCoreType"); -} - -LogicalResult TBroadcastOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - if (failed(verifyCommGlobalLike(*this, getSrc(), "src")) || - failed(verifyCommStagingTileLike(*this, getPing(), "ping")) || - failed(verifyCommPingPongSameType(*this, getPing(), getPong(), "ping", - "pong")) || - failed(verifyCommGlobalGroup(*this, getGroup(), "group"))) - return failure(); - if (getRoot() >= static_cast(getGroup().size())) - return emitOpError("expects root to index into group operands"); - if (getSrc().getType() != getGroup().front().getType()) - return emitOpError("expects src type to match group member type"); - if (getElemTy(getPing().getType()) != getElemTy(getSrc().getType())) - return emitOpError("expects staging tile element type to match src"); - return success(); -} - -LogicalResult CommTGatherOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - if (failed(verifyCommGlobalLike(*this, getDst(), "dst")) || - failed(verifyCommStagingTileLike(*this, getPing(), "ping")) || - failed(verifyCommPingPongSameType(*this, getPing(), getPong(), "ping", - "pong")) || - failed(verifyCommGlobalGroup(*this, getGroup(), "group"))) - return failure(); - if (getRoot() >= static_cast(getGroup().size())) - return emitOpError("expects root to index into group operands"); - if (getElemTy(getDst().getType()) != getElemTy(getGroup().front().getType())) - return 
emitOpError("expects dst element type to match group member type"); - if (getElemTy(getPing().getType()) != getElemTy(getDst().getType())) - return emitOpError("expects staging tile element type to match dst"); - return success(); -} - -LogicalResult CommTScatterOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - if (failed(verifyCommGlobalLike(*this, getSrc(), "src")) || - failed(verifyCommStagingTileLike(*this, getPing(), "ping")) || - failed(verifyCommPingPongSameType(*this, getPing(), getPong(), "ping", - "pong")) || - failed(verifyCommGlobalGroup(*this, getGroup(), "group"))) - return failure(); - if (getRoot() >= static_cast(getGroup().size())) - return emitOpError("expects root to index into group operands"); - if (getElemTy(getSrc().getType()) != getElemTy(getGroup().front().getType())) - return emitOpError("expects src element type to match group member type"); - if (getElemTy(getPing().getType()) != getElemTy(getSrc().getType())) - return emitOpError("expects staging tile element type to match src"); - return success(); -} - -LogicalResult TReduceOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - if (failed(verifyCommGlobalLike(*this, getDst(), "dst")) || - failed(verifyCommStagingTileLike(*this, getAcc(), "acc")) || - failed(verifyCommStagingTileLike(*this, getRecvPing(), "recv_ping")) || - failed(verifyCommPingPongSameType(*this, getRecvPing(), getRecvPong(), - "recv_ping", "recv_pong")) || - failed(verifyCommGlobalGroup(*this, getGroup(), "group"))) - return failure(); - if (getRoot() >= static_cast(getGroup().size())) - return emitOpError("expects root to index into group operands"); - if (getElemTy(getDst().getType()) != getElemTy(getGroup().front().getType())) - return emitOpError("expects dst element type to match group member type"); - if (getAcc().getType() != getRecvPing().getType()) - return emitOpError("expects acc and recv_ping to have identical types"); - if 
(getElemTy(getAcc().getType()) != getElemTy(getDst().getType())) - return emitOpError("expects accumulator/receive tiles to match dst element type"); - return success(); -} - -LogicalResult AicInitializePipeOp::verify() { - return verifyFrontendInitCommon(*this, FunctionKernelKind::Cube, "cube"); -} - -LogicalResult AivInitializePipeOp::verify() { - return verifyFrontendInitCommon(*this, FunctionKernelKind::Vector, "vector"); -} - -LogicalResult TAllocToAivOp::verify() { - if (failed(verifyFrontendSplitOp(getOperation(), FunctionKernelKind::Cube, - "cube", getId(), getSplit()))) - return failure(); - if (failed(verifyFrontendDataOpDirection(getOperation(), getId(), - /*expectC2V=*/true))) - return failure(); - return verifyFrontendTensorEntryMatchesInit(getOperation(), getId(), - getEntry().getType()); -} - -LogicalResult TAllocToAicOp::verify() { - if (failed(verifyFrontendSplitOp(getOperation(), FunctionKernelKind::Vector, - "vector", getId(), getSplit()))) - return failure(); - if (failed(verifyFrontendDataOpDirection(getOperation(), getId(), - /*expectC2V=*/false))) - return failure(); - return verifyFrontendTensorEntryMatchesInit(getOperation(), getId(), - getEntry().getType()); -} - -LogicalResult TPushToAivOp::verify() { - if (failed(verifyFrontendSplitOp(getOperation(), FunctionKernelKind::Cube, - "cube", getId(), getSplit()))) - return failure(); - if (failed(verifyFrontendDataOpDirection(getOperation(), getId(), - /*expectC2V=*/true))) - return failure(); - return verifyFrontendTensorEntryMatchesInit(getOperation(), getId(), - getTile().getType()); -} - -LogicalResult TPushToAicOp::verify() { - if (failed(verifyFrontendSplitOp(getOperation(), FunctionKernelKind::Vector, - "vector", getId(), getSplit()))) - return failure(); - if (failed(verifyFrontendDataOpDirection(getOperation(), getId(), - /*expectC2V=*/false))) - return failure(); - return verifyFrontendTensorEntryMatchesInit(getOperation(), getId(), - getTile().getType()); -} - -LogicalResult 
TPopFromAicOp::verify() { - return verifyFrontendPopOp(*this, FunctionKernelKind::Vector, "vector", - /*expectC2V=*/true); -} - -LogicalResult TPopFromAivOp::verify() { - return verifyFrontendPopOp(*this, FunctionKernelKind::Cube, "cube", - /*expectC2V=*/false); -} - -LogicalResult TFreeFromAicOp::verify() { - if (failed(verifyFrontendSplitOp(getOperation(), FunctionKernelKind::Vector, - "vector", getId(), getSplit()))) - return failure(); - if (failed(verifyFrontendDataOpDirection(getOperation(), getId(), - /*expectC2V=*/true))) - return failure(); - if (getEntry()) - return verifyFrontendTensorEntryMatchesInit(getOperation(), getId(), - getEntry().getType()); - return success(); -} - -LogicalResult TFreeFromAivOp::verify() { - if (failed(verifyFrontendSplitOp(getOperation(), FunctionKernelKind::Cube, - "cube", getId(), getSplit()))) - return failure(); - if (failed(verifyFrontendDataOpDirection(getOperation(), getId(), - /*expectC2V=*/false))) - return failure(); - if (getEntry()) - return verifyFrontendTensorEntryMatchesInit(getOperation(), getId(), - getEntry().getType()); - return success(); -} - -LogicalResult InitializeL2G2LPipeOp::verify() { - if (failed(verifyPipeShape(getOperation(), getDirMask(), getSlotSize(), - getSlotNum(), - getFlagBaseAttr() - ? 
std::optional(getFlagBaseAttr().getInt()) - : std::nullopt))) - return failure(); - - if (!getLocalAddr()) { - if (getPeerLocalAddr()) - return emitOpError("'peer_local_addr' requires 'local_addr'"); - if (getLocalSlotNumAttr()) - return emitOpError( - "'local_slot_num' is only allowed when 'local_addr' is present"); - return success(); - } - - if (auto localSlotNumAttr = getLocalSlotNumAttr()) { - int32_t localSlotNum = localSlotNumAttr.getInt(); - if (localSlotNum <= 0) - return emitOpError("expects 'local_slot_num' to be greater than 0"); - if (static_cast(localSlotNum) > getSlotNum()) - return emitOpError( - "expects 'local_slot_num' to be less than or equal to slot_num"); - } - - if (getDirMask() == 3 && !getPeerLocalAddr()) - return emitOpError("expects 'peer_local_addr' when dir_mask is 3"); - if (getDirMask() != 3 && getPeerLocalAddr()) - return emitOpError("'peer_local_addr' is only allowed when dir_mask is 3"); - return success(); -} - -LogicalResult InitializeL2LPipeOp::verify() { - if (failed(verifyPipeShape(getOperation(), getDirMask(), getSlotSize(), - getSlotNum(), - getFlagBaseAttr() - ? 
std::optional(getFlagBaseAttr().getInt()) - : std::nullopt))) - return failure(); - - if (getDirMask() == 3 && !getPeerLocalAddr()) - return emitOpError("expects 'peer_local_addr' when dir_mask is 3"); - if (getDirMask() != 3 && getPeerLocalAddr()) - return emitOpError("'peer_local_addr' is only allowed when dir_mask is 3"); - return success(); -} - -LogicalResult TPushOp::verify() { - if (!isInsideSectionOrAttributedKernel(getOperation())) - return emitOpError("must be inside pto.section.cube/vector or a kernel_kind function"); - if (failed(verifyPipeHandleProducer(getOperation(), getPipeHandle()))) - return failure(); - if (failed(verifySplitAttr(getOperation(), getSplit()))) - return failure(); - if (failed(verifyTensorEntryMatchesInternalPipeInit( - getOperation(), getPipeHandle(), getTile().getType()))) - return failure(); - if (!isa(getTile().getType()) && - getPipe() == pto::PIPE::PIPE_UNASSIGNED) - return emitOpError("tile type must map to a supported producer pipe"); - return success(); -} - -LogicalResult TAllocOp::verify() { - if (!isInsideSectionOrAttributedKernel(getOperation())) - return emitOpError("must be inside pto.section.cube/vector or a kernel_kind function"); - if (failed(verifyPipeHandleProducer(getOperation(), getPipeHandle()))) - return failure(); - if (failed(verifyTensorEntryMatchesInternalPipeInit( - getOperation(), getPipeHandle(), getEntry().getType()))) - return failure(); - return verifySplitAttr(getOperation(), getSplit()); -} - -LogicalResult TPopOp::verify() { - if (!isInsideSectionOrAttributedKernel(getOperation())) - return emitOpError("must be inside pto.section.cube/vector or a kernel_kind function"); - if (failed(verifyPipeHandleProducer(getOperation(), getPipeHandle()))) - return failure(); - if (failed(verifySplitAttr(getOperation(), getSplit()))) - return failure(); - if (failed(verifyTensorEntryMatchesInternalPipeInit( - getOperation(), getPipeHandle(), getTile().getType()))) - return failure(); - if 
(!isa(getTile().getType()) && - getPipe() == pto::PIPE::PIPE_UNASSIGNED) - return emitOpError( - "tile type and target arch must map to a supported consumer pipe"); - return success(); -} - -LogicalResult TFreeOp::verify() { - if (!isInsideSectionOrAttributedKernel(getOperation())) - return emitOpError("must be inside pto.section.cube/vector or a kernel_kind function"); - if (failed(verifyPipeHandleProducer(getOperation(), getPipeHandle()))) - return failure(); - if (getEntry() && - failed(verifyTensorEntryMatchesInternalPipeInit( - getOperation(), getPipeHandle(), getEntry().getType()))) - return failure(); - return verifySplitAttr(getOperation(), getSplit()); -} - -ParseResult TFreeOp::parse(OpAsmParser &parser, OperationState &result) { - OpAsmParser::UnresolvedOperand first; - OpAsmParser::UnresolvedOperand pipe; - Type firstTy; - Type pipeTy; - bool hasEntry = false; - - if (parser.parseLParen() || parser.parseOperand(first)) - return failure(); - - if (succeeded(parser.parseOptionalComma())) { - hasEntry = true; - if (parser.parseOperand(pipe) || parser.parseColonType(firstTy) || - parser.parseComma() || parser.parseType(pipeTy) || parser.parseRParen()) - return failure(); - } else { - if (parser.parseColonType(pipeTy) || parser.parseRParen()) - return failure(); - pipe = first; - } - - NamedAttrList attrs; - if (parser.parseLBrace() || parser.parseKeyword("split") || - parser.parseEqual()) - return failure(); - IntegerAttr splitAttr; - if (parser.parseAttribute(splitAttr, parser.getBuilder().getI8Type(), - "split", attrs) || - parser.parseRBrace() || parser.parseOptionalAttrDict(attrs)) - return failure(); - - result.addAttributes(attrs); - if (hasEntry && - parser.resolveOperand(first, firstTy, result.operands)) - return failure(); - if (parser.resolveOperand(pipe, pipeTy, result.operands)) - return failure(); - return success(); -} - -void TFreeOp::print(OpAsmPrinter &p) { - p << "("; - if (getEntry()) { - p << getEntry() << ", " << getPipeHandle() << " : 
" - << getEntry().getType() << ", " << getPipeHandle().getType(); - } else { - p << getPipeHandle() << " : " << getPipeHandle().getType(); - } - p << ") {split = " << static_cast(getSplit()) << "}"; - p.printOptionalAttrDict((*this)->getAttrs(), - /*elidedAttrs=*/{"split"}); -} - -void BuildAsyncSessionOp::getEffects( - SmallVectorImpl> - &effects) { - addEffect(effects, &getScratchMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getWorkspaceMutable(), MemoryEffects::Read::get()); - addEffect(effects, getOperation()->getOpResult(0), MemoryEffects::Write::get()); -} - -void TPutAsyncOp::getEffects( - SmallVectorImpl> - &effects) { - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); - addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getSessionMutable(), MemoryEffects::Read::get()); - addEffect(effects, getOperation()->getOpResult(0), MemoryEffects::Write::get()); -} - -void TGetAsyncOp::getEffects( - SmallVectorImpl> - &effects) { - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); - addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getSessionMutable(), MemoryEffects::Read::get()); - addEffect(effects, getOperation()->getOpResult(0), MemoryEffects::Write::get()); -} - -void TPutOp::getEffects( - SmallVectorImpl> - &effects) { - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); - addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getPingMutable(), MemoryEffects::Read::get()); -} - -void TGetOp::getEffects( - SmallVectorImpl> - &effects) { - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); - addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getPingMutable(), MemoryEffects::Read::get()); -} - -void TNotifyOp::getEffects( - SmallVectorImpl> - &effects) { - addEffect(effects, &getSignalMutable(), MemoryEffects::Write::get()); - 
addEffect(effects, &getValueMutable(), MemoryEffects::Read::get()); -} - -void TWaitOp::getEffects( - SmallVectorImpl> - &effects) { - addEffect(effects, &getSignalMutable(), MemoryEffects::Write::get()); - addEffect(effects, &getCmpValueMutable(), MemoryEffects::Read::get()); -} - -void TTestOp::getEffects( - SmallVectorImpl> - &effects) { - addEffect(effects, &getSignalMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getCmpValueMutable(), MemoryEffects::Read::get()); - addEffect(effects, getOperation()->getOpResult(0), MemoryEffects::Write::get()); -} - -void TBroadcastOp::getEffects( - SmallVectorImpl> - &effects) { - addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getPingMutable(), MemoryEffects::Write::get()); - if (getPong()) { - auto pongRange = getPongMutable(); - if (auto it = pongRange.begin(); it != pongRange.end()) - addEffect(effects, &*it, MemoryEffects::Write::get()); - } -} - -void CommTGatherOp::getEffects( - SmallVectorImpl> - &effects) { - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); - addEffect(effects, &getPingMutable(), MemoryEffects::Read::get()); -} - -void CommTScatterOp::getEffects( - SmallVectorImpl> - &effects) { - addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getPingMutable(), MemoryEffects::Write::get()); - if (getPong()) { - auto pongRange = getPongMutable(); - if (auto it = pongRange.begin(); it != pongRange.end()) - addEffect(effects, &*it, MemoryEffects::Write::get()); - } -} - -void TReduceOp::getEffects( - SmallVectorImpl> - &effects) { - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); - addEffect(effects, &getAccMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getAccMutable(), MemoryEffects::Write::get()); - addEffect(effects, &getRecvPingMutable(), MemoryEffects::Write::get()); -} - -void WaitAsyncEventOp::getEffects( - SmallVectorImpl> - &effects) { - addEffect(effects, 
&getEventMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getSessionMutable(), MemoryEffects::Read::get()); - addEffect(effects, getOperation()->getOpResult(0), MemoryEffects::Write::get()); -} - -void TestAsyncEventOp::getEffects( - SmallVectorImpl> - &effects) { - addEffect(effects, &getEventMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getSessionMutable(), MemoryEffects::Read::get()); - addEffect(effects, getOperation()->getOpResult(0), MemoryEffects::Write::get()); -} - -void InitializeL2G2LPipeOp::getEffects( - SmallVectorImpl> - &effects) { - addEffect(effects, &getGmAddrMutable(), MemoryEffects::Read::get()); - auto localAddr = getLocalAddrMutable(); - if (!localAddr.empty()) - addEffect(effects, &*localAddr.begin(), MemoryEffects::Read::get()); - auto peerLocalAddr = getPeerLocalAddrMutable(); - if (!peerLocalAddr.empty()) - addEffect(effects, &*peerLocalAddr.begin(), MemoryEffects::Read::get()); - addEffect(effects, getOperation()->getOpResult(0), MemoryEffects::Write::get()); -} - -void InitializeL2LPipeOp::getEffects( - SmallVectorImpl> - &effects) { - addEffect(effects, &getLocalAddrMutable(), MemoryEffects::Read::get()); - addEffect(effects, getOperation()->getOpResult(0), MemoryEffects::Write::get()); -} - -void TPushOp::getEffects( - SmallVectorImpl> - &effects) { - addEffect(effects, &getTileMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getPipeHandleMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getPipeHandleMutable(), MemoryEffects::Write::get()); -} - -void TAllocOp::getEffects( - SmallVectorImpl> - &effects) { - addEffect(effects, &getEntryMutable(), MemoryEffects::Write::get()); - addEffect(effects, &getPipeHandleMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getPipeHandleMutable(), MemoryEffects::Write::get()); -} - -void TPopOp::getEffects( - SmallVectorImpl> - &effects) { - addEffect(effects, &getPipeHandleMutable(), MemoryEffects::Read::get()); - addEffect(effects, 
&getPipeHandleMutable(), MemoryEffects::Write::get()); - addEffect(effects, &getTileMutable(), MemoryEffects::Write::get()); -} - -void TFreeOp::getEffects( - SmallVectorImpl> - &effects) { - auto entry = getEntryMutable(); - if (!entry.empty()) - addEffect(effects, &*entry.begin(), MemoryEffects::Read::get()); - addEffect(effects, &getPipeHandleMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getPipeHandleMutable(), MemoryEffects::Write::get()); -} - -// [Include 必须放在最后] -#include "PTO/IR/PTOInterfaces.cpp.inc" -#define GET_OP_CLASSES -#include "PTO/IR/PTOOps.cpp.inc" +#include "PTO.def" diff --git a/lib/PTO/IR/PTO.def b/lib/PTO/IR/PTO.def new file mode 100644 index 000000000..376b9c017 --- /dev/null +++ b/lib/PTO/IR/PTO.def @@ -0,0 +1,12933 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +//===- PTO.cpp - PTO Dialect ----------------------------------------------===// +//===----------------------------------------------------------------------===// + +#include "PTO/IR/PTO.h" +#include "PTO/IR/PTOTypeUtils.h" +#include "PTO/IR/PTOSyncUtils.h" + +#include "mlir/AsmParser/AsmParser.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/AsmState.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/Types.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Parser/Parser.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/TypeSwitch.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "llvm/Support/ErrorHandling.h" + +#include +#include +#include +#include +#include + +using namespace mlir; +using namespace mlir::pto; + +// Forward declarations for custom shape/type printers used by tensor_view and +// partition_tensor_view. 
+namespace mlir { +namespace pto { +static LogicalResult parseShapeAndElem(AsmParser &parser, + SmallVectorImpl &shape, + Type &elementType, + bool allowDynamic = true); +static void printShapeAndElem(AsmPrinter &printer, + ArrayRef shape, + Type elementType); +} // namespace pto +} // namespace mlir + +// ============================================================================= +// TileBufType 的自定义 Shape 解析与打印函数 +// ============================================================================= + +// 解析逻辑:解析形如 "32x32" 的维度列表 +[[maybe_unused]] static ParseResult parseShape(AsmParser &parser, SmallVectorImpl &shape) { + // parseDimensionList 会解析 "dim x dim x ...", 遇到无法解析为维度的字符停止 + // 参数 allowDynamic=true (允许 ?), withTrailingX=false (不吞掉末尾的 x) + if (parser.parseDimensionList(shape, /*allowDynamic=*/true, /*withTrailingX=*/false)) + return failure(); + return success(); +} + +// 打印逻辑:打印形如 "32x32" 的维度列表 +[[maybe_unused]] static void printShape(AsmPrinter &printer, ArrayRef shape) { + for (auto it = shape.begin(); it != shape.end(); ++it) { + if (it != shape.begin()) printer << "x"; // 维度间的分隔符 + if (*it == ShapedType::kDynamic) + printer << "?"; + else + printer << *it; + } + // 注意:我们不在这里打印末尾的 'x',因为 assemblyFormat 中已经写了 `x` $elementType +} + +static std::optional getPTOMemorySpaceEnum(Type ty); +enum class VerifierTargetArch { + A2A3, + A5, +}; +static VerifierTargetArch getVerifierTargetArch(Operation *op); +static std::optional getVerifierArchName(Operation *op); +static bool isSupportedVecElemType(Type ty, bool allowBf16 = true, + bool allowInt8 = true); +static bool isSupportedLoadStoreElemTypeA2A3(Type ty); +static bool isSupportedGatherElemTypeA2A3(Type ty); +static bool isSupportedGatherElemTypeA5(Type ty); +static bool isA5TLoadStoreTransferElemType(Type ty); +static bool isA5AccStorePreQuantDstType(Type srcElem, Type dstElem); +static bool isA5LowPrecisionTCvtPair(Type srcElem, Type dstElem); +static bool isA5SupportedTCvtPair(Type srcElem, Type dstElem); 
+static ParseResult parseSyncEventOpCommon(OpAsmParser &parser, + OperationState &result, + StringAttr pipeAttrName, + StringAttr eventIdAttrName); +static void printSyncEventOpCommon(OpAsmPrinter &p, Operation *op, + PipeAttr pipeAttr, IntegerAttr eventAttr, + Value eventDyn, StringRef pipeAttrName, + StringRef eventIdAttrName); +static bool isTileLikeType(Type ty); +static SmallVector getShapeVec(Type ty); +static SmallVector getValidShapeVec(Type ty); +static SmallVector getValidShapeVec(Value value); +static bool isByteIntegerType(Type ty); +static LogicalResult verifyTileBufCommon(Operation *op, Type ty, StringRef name, + bool allowLowPrecision = false); +static LogicalResult verifyTileBufSameElemType(Operation *op, Type lhs, Type rhs, + StringRef lhsName, + StringRef rhsName); +static LogicalResult verifyTileBufSameLogicalExtent(Operation *op, Type lhs, + Type rhs, StringRef lhsName, + StringRef rhsName, + bool compareValidShape); + +static LogicalResult verifyTileBufSameValidShape(Operation *op, Type lhs, Type rhs, + StringRef lhsName, StringRef rhsName); +static LogicalResult verifyVecTileCommon(Operation *op, Type ty, StringRef name); +static LogicalResult verifyVecTileCommonA2A3(Operation *op, Type ty, + StringRef name); +static LogicalResult verifyVecTileCommonA5(Operation *op, Type ty, + StringRef name); +static LogicalResult verifyVecTileUnaryOp(Operation *op, Type srcTy, Type dstTy, + StringRef srcName = "src", + StringRef dstName = "dst", + bool allowBf16 = true, + bool allowInt8 = true); +static LogicalResult verifyAccTileCommon(Operation *op, Type ty, StringRef name); +static LogicalResult verifyAccTileCommonA2A3(Operation *op, Type ty, + StringRef name); +static LogicalResult verifyAccTileCommonA5(Operation *op, Type ty, + StringRef name); +static LogicalResult verifyMatTileOperands(Operation *op, Type lhsTy, Type rhsTy, + Type dstTy); +static LogicalResult verifyMatTileOperandsA2A3(Operation *op, Type lhsTy, + Type rhsTy, Type dstTy); +static 
LogicalResult verifyMatTileOperandsA5(Operation *op, Type lhsTy, + Type rhsTy, Type dstTy); +static LogicalResult verifyGemvTileOperands(Operation *op, Type lhsTy, Type rhsTy, + Type dstTy); +static LogicalResult verifyAsyncFlatContiguous1DGMViewLike(Operation *op, + Value value, + StringRef name); +static LogicalResult verifyGemvTileOperandsA2A3(Operation *op, Type lhsTy, + Type rhsTy, Type dstTy); +static LogicalResult verifyGemvTileOperandsA5(Operation *op, Type lhsTy, + Type rhsTy, Type dstTy); +static LogicalResult verifyMatBiasTile(Operation *op, Type biasTy, Type dstTy, + bool requireFloatBias = false); +static LogicalResult verifyMatBiasTileA2A3(Operation *op, Type biasTy, Type dstTy, + bool requireFloatBias = false); +static LogicalResult verifyMatBiasTileA5(Operation *op, Type biasTy, Type dstTy, + bool requireFloatBias = false); +static LogicalResult verifyMatmulTypeTriple(Operation *op, Type lhsElemTy, + Type rhsElemTy, Type dstElemTy); +static std::optional getLogicalViewLayout(Value value); +static std::optional getTileBufLogicalLayout(pto::TileBufType type); +static std::optional getConstantIntegerValue(Value value); +static LogicalResult verifyPartialValidPattern(Operation *op, Type src0Ty, + Type src1Ty, Type dstTy); +static Type getElemTy(Type ty); +static FailureOr +verifyMatchingRowMajorBinaryTileOpCommon(Operation *op, Type src0Ty, + Type src1Ty, Type dstTy); +static FailureOr +verifyNumericScalarTileOpCommon(Operation *op, Type srcTy, Type dstTy, + Type scalarTy, bool requireValidRowsEqual); +static FailureOr +verifyShiftLikeBinaryTileOpCommon(Operation *op, Type src0Ty, Type src1Ty, + Type dstTy); +static LogicalResult verifyArithmeticElemTypeForArch( + Operation *op, Type elemTy, PTOArch targetArch, bool allowInt8OnA5, + bool allowBf16OnA5, StringRef a2a3Error, StringRef a5Error); +static bool isRowMajorTileBuf(Type ty); + +#define GET_ENUM_CLASSES +#include "PTO/IR/PTOEnums.cpp.inc" + +#define GET_TYPEDEF_CLASSES +#include 
"PTO/IR/PTOTypeDefs.cpp.inc" + +#define GET_ATTRDEF_CLASSES +#include "PTO/IR/PTOAttrs.cpp.inc" + +#include "PTO/IR/PTODialect.cpp.inc" + +[[maybe_unused]] static LogicalResult parseShapeAndElemStable(mlir::AsmParser &parser, + llvm::SmallVectorImpl &shape, + mlir::Type &elementType) { + if (failed(parser.parseLess())) + return failure(); + + if (failed(parser.parseDimensionList(shape, /*allowDynamic=*/true))) + return failure(); + + if (failed(parser.parseType(elementType))) + return failure(); + + if (failed(parser.parseGreater())) + return failure(); + + return success(); +} + +static int64_t getPTOTypeRank(Type type) { + // 1. 处理标准的 MLIR 类型 (MemRef, Tensor, Vector) + if (auto shapedTy = dyn_cast(type)) { + if (shapedTy.hasRank()) + return shapedTy.getRank(); + return -1; // Unranked type + } + + // 2. 处理 PTO 自定义类型 + if (auto tvTy = dyn_cast(type)) + return tvTy.getRank(); + + if (auto tileTy = dyn_cast(type)) + return tileTy.getRank(); + + if (auto tileViewTy = dyn_cast(type)) + return tileViewTy.getRank(); + + if (auto tileBufTy = dyn_cast(type)) + return tileBufTy.getRank(); + + // 3. 
不支持的类型 + return -1; +} + +static bool isGmAddressSpaceAttr(Attribute memorySpace) { + if (!memorySpace) + return true; + if (auto addr = mlir::dyn_cast(memorySpace)) + return addr.getAddressSpace() == pto::AddressSpace::GM; + if (auto intAttr = mlir::dyn_cast(memorySpace)) + return intAttr.getInt() == 0; + return false; +} + +PTOArch mlir::pto::getTargetArch(ModuleOp module) { + if (!module) + return PTOArch::A3; + + auto arch = module->getAttrOfType(kPTOTargetArchAttrName); + if (arch && arch.getValue().equals_insensitive("a5")) + return PTOArch::A5; + return PTOArch::A3; +} + +PTOArch mlir::pto::getTargetArch(Operation *op) { + if (!op) + return PTOArch::A3; + return getTargetArch(op->getParentOfType()); +} + +bool mlir::pto::isTargetArchA3(ModuleOp module) { + return getTargetArch(module) == PTOArch::A3; +} + +bool mlir::pto::isTargetArchA5(ModuleOp module) { + return getTargetArch(module) == PTOArch::A5; +} + +bool mlir::pto::isTargetArchA3(Operation *op) { + return getTargetArch(op) == PTOArch::A3; +} + +bool mlir::pto::isTargetArchA5(Operation *op) { + return getTargetArch(op) == PTOArch::A5; +} + +static llvm::TypeSize getOneByteTypeSize() { + return llvm::TypeSize::getFixed(8); +} + +llvm::TypeSize mlir::pto::HiF8Type::getTypeSizeInBits( + const DataLayout &, DataLayoutEntryListRef) const { + return getOneByteTypeSize(); +} + +uint64_t mlir::pto::HiF8Type::getABIAlignment(const DataLayout &, + DataLayoutEntryListRef) const { + return 1; +} + +uint64_t mlir::pto::HiF8Type::getPreferredAlignment( + const DataLayout &, DataLayoutEntryListRef) const { + return 1; +} + +llvm::TypeSize mlir::pto::F4E1M2x2Type::getTypeSizeInBits( + const DataLayout &, DataLayoutEntryListRef) const { + return getOneByteTypeSize(); +} + +uint64_t mlir::pto::F4E1M2x2Type::getABIAlignment( + const DataLayout &, DataLayoutEntryListRef) const { + return 1; +} + +uint64_t mlir::pto::F4E1M2x2Type::getPreferredAlignment( + const DataLayout &, DataLayoutEntryListRef) const { + return 1; +} 
+ +llvm::TypeSize mlir::pto::F4E2M1x2Type::getTypeSizeInBits( + const DataLayout &, DataLayoutEntryListRef) const { + return getOneByteTypeSize(); +} + +uint64_t mlir::pto::F4E2M1x2Type::getABIAlignment( + const DataLayout &, DataLayoutEntryListRef) const { + return 1; +} + +uint64_t mlir::pto::F4E2M1x2Type::getPreferredAlignment( + const DataLayout &, DataLayoutEntryListRef) const { + return 1; +} + +static VerifierTargetArch getVerifierTargetArch(Operation *op) { + if (auto archName = getVerifierArchName(op)) { + return archName->equals_insensitive("a5") ? VerifierTargetArch::A5 + : VerifierTargetArch::A2A3; + } + + switch (getPTOParserTargetArch(op ? op->getContext() : nullptr)) { + case PTOParserTargetArch::A5: + return VerifierTargetArch::A5; + case PTOParserTargetArch::A3: + case PTOParserTargetArch::Unspecified: + return VerifierTargetArch::A2A3; + } + + return VerifierTargetArch::A2A3; +} + +static std::optional getVerifierArchName(Operation *op) { + auto module = op ? op->getParentOfType() : ModuleOp(); + if (!module) + return std::nullopt; + if (auto arch = module->getAttrOfType(kPTOTargetArchAttrName)) + return arch.getValue(); + return std::nullopt; +} + +static bool shouldBypassDecodedMemrefVerifier(Operation *op) { + if (!op) + return false; + for (Value operand : op->getOperands()) { + if (isa(operand.getType())) + return true; + if (operand.getDefiningOp()) + return true; + } + return false; +} + +static SmallVector canonicalizeTileBufValidShape(ArrayRef validShape) { + SmallVector canonical; + canonical.reserve(validShape.size()); + for (int64_t dim : validShape) + canonical.push_back(dim < 0 ? 
ShapedType::kDynamic : dim); + return canonical; +} + +template +static LogicalResult dispatchVerifierByArch(Operation *op, FnA2A3 &&verifyA2A3, + FnA5 &&verifyA5) { + if (shouldBypassDecodedMemrefVerifier(op)) + return success(); + switch (getVerifierTargetArch(op)) { + case VerifierTargetArch::A2A3: + return verifyA2A3(); + case VerifierTargetArch::A5: + return verifyA5(); + } + return failure(); +} + +static ParseResult parseSyncEventOpCommon(OpAsmParser &parser, + OperationState &result, + StringAttr pipeAttrName, + StringAttr eventIdAttrName) { + PipeAttr pipeAttr; + if (succeeded(parser.parseOptionalLess())) { + StringRef pipeTok; + if (parser.parseKeyword(&pipeTok) || parser.parseGreater()) + return failure(); + auto pipeOr = symbolizePIPE(pipeTok); + if (!pipeOr) + return parser.emitError(parser.getCurrentLocation()) + << "unknown pipe token: " << pipeTok; + pipeAttr = PipeAttr::get(parser.getContext(), *pipeOr); + result.addAttribute(pipeAttrName, pipeAttr); + } else if (parser.parseAttribute(pipeAttr, pipeAttrName, + result.attributes)) { + return failure(); + } + if (parser.parseComma()) + return failure(); + + OpAsmParser::UnresolvedOperand eventOperand; + OptionalParseResult parseEventOperand = + parser.parseOptionalOperand(eventOperand); + if (parseEventOperand.has_value()) { + if (failed(*parseEventOperand)) + return failure(); + if (parser.resolveOperand(eventOperand, parser.getBuilder().getIndexType(), + result.operands)) + return failure(); + } else { + IntegerAttr eventAttr; + if (parser.parseAttribute(eventAttr, parser.getBuilder().getI32Type(), + eventIdAttrName, result.attributes)) + return failure(); + } + + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + return success(); +} + +static void printSyncEventOpCommon(OpAsmPrinter &p, Operation *op, + PipeAttr pipeAttr, IntegerAttr eventAttr, + Value eventDyn, StringRef pipeAttrName, + StringRef eventIdAttrName) { + p << " <" << stringifyPIPE(pipeAttr.getPipe()) << ">, 
"; + if (eventAttr) + p << eventAttr.getInt(); + else + p << eventDyn; + p.printOptionalAttrDict(op->getAttrs(), {pipeAttrName, eventIdAttrName}); +} + +[[maybe_unused]] static mlir::Type parsePTOTypeAllowNoBang(mlir::OpAsmParser &parser) { + mlir::Type ty; + + mlir::OptionalParseResult opt = parser.parseOptionalType(ty); + + if (opt.has_value()) { + if (failed(*opt)) + return mlir::Type(); + return ty; + } + + + llvm::StringRef head; + if (failed(parser.parseKeyword(&head))) + return mlir::Type(); + + mlir::MLIRContext *ctx = parser.getContext(); + + auto parseShapeElemForOpParser = + [&](llvm::SmallVectorImpl &shape, mlir::Type &elem) -> mlir::LogicalResult { + if (failed(parser.parseLess())) + return failure(); + if (failed(parser.parseDimensionList(shape, /*allowDynamic=*/true))) + return failure(); + if (failed(parser.parseType(elem))) + return failure(); + if (failed(parser.parseGreater())) + return failure(); + return success(); + }; + + if (head == "pto.tile_view") { + llvm::SmallVector shape; + mlir::Type elem; + if (failed(parseShapeElemForOpParser(shape, elem))) + return mlir::Type(); + return mlir::pto::PartitionTensorViewType::get(ctx, shape, elem); + } + + if (head == "pto.tile") { + llvm::SmallVector shape; + mlir::Type elem; + if (failed(parseShapeElemForOpParser(shape, elem))) + return mlir::Type(); + return mlir::pto::TileType::get(ctx, shape, elem); + } + + if (head == "pto.ptr") { + if (failed(parser.parseLess())) + return mlir::Type(); + mlir::Type elem; + if (failed(parser.parseType(elem))) + return mlir::Type(); + if (succeeded(parser.parseOptionalComma())) { + // ptr no longer accepts an address space; consume the attr for recovery. 
+ mlir::Attribute memorySpace; + (void)parser.parseAttribute(memorySpace); + parser.emitError(parser.getCurrentLocation(), + "!pto.ptr no longer accepts address space; use !pto.ptr"); + return mlir::Type(); + } + if (failed(parser.parseGreater())) + return mlir::Type(); + return mlir::pto::PtrType::get(ctx, elem); + } + + if (head == "pto.tensor_view") { + llvm::SmallVector shape; + mlir::Type elem; + if (failed(parseShapeElemForOpParser(shape, elem))) + return mlir::Type(); + return mlir::pto::TensorViewType::get(ctx, shape, elem); + } + + return mlir::Type(); +} + +mlir::Type TensorViewType::parse(::mlir::AsmParser &parser) { + SmallVector shape; + Type elementType; + if (failed(parseShapeAndElem(parser, shape, elementType, /*allowDynamic=*/true))) + return Type(); + return TensorViewType::get(parser.getContext(), shape, elementType); +} + +void TensorViewType::print(::mlir::AsmPrinter &printer) const { + printShapeAndElem(printer, getShape(), getElementType()); +} + +//===----------------------------------------------------------------------===// +// pto.tdivs custom asm to support both: +// pto.tdivs ins(%src, %scalar : !pto.tile_buf<...>, f32) outs(%dst : !pto.tile_buf<...>) +// pto.tdivs ins(%scalar, %src : f32, !pto.tile_buf<...>) outs(%dst : !pto.tile_buf<...>) +// The operand order in the op follows textual input order. 
+//===----------------------------------------------------------------------===// + +ParseResult mlir::pto::TDivSOp::parse(OpAsmParser &parser, OperationState &result) { + OpAsmParser::UnresolvedOperand op0, op1, dst; + Type ty0, ty1, dstTy; + + if (parser.parseKeyword("ins") || parser.parseLParen() || + parser.parseOperand(op0) || parser.parseComma() || + parser.parseOperand(op1) || parser.parseColonType(ty0) || + parser.parseComma() || parser.parseType(ty1) || parser.parseRParen()) + return failure(); + + if (parser.parseKeyword("outs") || parser.parseLParen() || + parser.parseOperand(dst) || parser.parseColonType(dstTy) || + parser.parseRParen()) + return failure(); + + NamedAttrList attrs; + if (parser.parseOptionalAttrDict(attrs)) + return failure(); + + auto tile0 = dyn_cast(ty0); + auto tile1 = dyn_cast(ty1); + if ((tile0 && tile1) || (!tile0 && !tile1)) + return parser.emitError(parser.getCurrentLocation(), + "expected exactly one tile_buf operand and one scalar operand"); + + if (!dyn_cast(dstTy)) + return parser.emitError(parser.getCurrentLocation(), + "expected outs type to be !pto.tile_buf<...>"); + + // Keep textual order so later lowering can distinguish the two APIs by the + // first ins operand type. 
+ if (parser.resolveOperand(op0, ty0, result.operands) || + parser.resolveOperand(op1, ty1, result.operands)) + return failure(); + + if (parser.resolveOperand(dst, dstTy, result.operands)) + return failure(); + + result.addAttributes(attrs); + return success(); +} + +void mlir::pto::TDivSOp::print(OpAsmPrinter &p) { + p << " ins("; + p << getSrc() << ", " << getScalar() << " : " + << getSrc().getType() << ", " << getScalar().getType(); + p << ") outs(" << getDst() << " : " << getDst().getType() << ")"; + + p.printOptionalAttrDict((*this)->getAttrs()); +} + + +//===----------------------------------------------------------------------===// +// pto.tgather custom asm supports three PTO-ISA forms: +// 1) index+tmp : ins(%src, %indices, %tmp : srcTy, indicesTy, tmpTy) outs(%dst : dstTy) +// 2) compare+tmp : ins(%src, %kValue, %tmp : srcTy, scalarTy, tmpTy) +// outs(%dst, %cdst : dstTy, cdstTy) {cmpMode = #pto.cmp, offset = 7} +// 3) mask : ins(%src, {maskPattern = #pto.mask_pattern} : srcTy) outs(%dst : dstTy) +//===----------------------------------------------------------------------===// + +ParseResult mlir::pto::TGatherOp::parse(OpAsmParser &parser, OperationState &result) { + OpAsmParser::UnresolvedOperand src, dst, cdst; + SmallVector insOps; + SmallVector insTypes; + Type srcTy, dstTy, cdstTy; + bool hasCdst = false; + bool hasMask = false; + bool hasIndices = false; + bool hasTmp = false; + bool hasKValue = false; + + if (parser.parseKeyword("ins") || parser.parseLParen() || parser.parseOperand(src)) + return failure(); + + if (!succeeded(parser.parseOptionalComma())) { + return parser.emitError(parser.getCurrentLocation(), + "expected ',' after src operand in ins(...)"); + } + + if (succeeded(parser.parseOptionalLBrace())) { + if (parser.parseKeyword("maskPattern") || parser.parseEqual()) + return failure(); + + Attribute rawMaskAttr; + if (parser.parseAttribute(rawMaskAttr) || parser.parseRBrace()) + return failure(); + + auto mp = 
llvm::dyn_cast(rawMaskAttr); + if (!mp) { + return parser.emitError(parser.getCurrentLocation(), + "expected #pto.mask_pattern for maskPattern"); + } + + result.addAttribute("maskPattern", mp); + hasMask = true; + + if (parser.parseColonType(srcTy) || parser.parseRParen()) + return failure(); + } else { + OpAsmParser::UnresolvedOperand extra; + if (parser.parseOperand(extra)) + return failure(); + insOps.push_back(extra); + while (succeeded(parser.parseOptionalComma())) { + if (insOps.size() == 3) { + return parser.emitError(parser.getCurrentLocation(), + "expected at most 3 extra operands in tgather ins(...)"); + } + if (parser.parseOperand(extra)) + return failure(); + insOps.push_back(extra); + } + + if (parser.parseColon() || parser.parseType(srcTy)) + return failure(); + for (size_t i = 0; i < insOps.size(); ++i) { + Type ty; + if (parser.parseComma() || parser.parseType(ty)) + return failure(); + insTypes.push_back(ty); + } + if (parser.parseRParen()) + return failure(); + } + + if (parser.parseKeyword("outs") || parser.parseLParen() || parser.parseOperand(dst)) + return failure(); + if (succeeded(parser.parseOptionalComma())) { + if (parser.parseOperand(cdst)) + return failure(); + hasCdst = true; + } + if (parser.parseColonType(dstTy)) + return failure(); + if (hasCdst && (parser.parseComma() || parser.parseType(cdstTy))) + return failure(); + if (parser.parseRParen()) + return failure(); + + if (succeeded(parser.parseOptionalKeyword("maskPattern"))) { + if (hasMask) + return parser.emitError(parser.getCurrentLocation(), + "maskPattern may only be specified once"); + if (parser.parseEqual()) + return failure(); + Attribute rawMaskAttr; + if (parser.parseAttribute(rawMaskAttr)) + return failure(); + auto mp = llvm::dyn_cast(rawMaskAttr); + if (!mp) { + return parser.emitError(parser.getCurrentLocation(), + "expected #pto.mask_pattern for maskPattern"); + } + result.addAttribute("maskPattern", mp); + hasMask = true; + } + + if 
(parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + if (hasMask) { + if (!insOps.empty()) + return parser.emitError(parser.getCurrentLocation(), + "mask-pattern tgather does not take extra ins operands"); + if (hasCdst) + return parser.emitError(parser.getCurrentLocation(), + "mask-pattern tgather expects a single outs operand"); + } else if (hasCdst) { + if (insOps.empty() || + !(mlir::isa(insTypes.front()) || + mlir::isa(insTypes.front()))) + return parser.emitError(parser.getCurrentLocation(), + "compare-form tgather expects a scalar kValue operand"); + hasKValue = true; + if (insOps.size() >= 2) { + if (!isTileLikeType(insTypes[1])) + return parser.emitError(parser.getCurrentLocation(), + "compare-form tgather tmp must be tile-like"); + hasTmp = true; + } + if (insOps.size() == 3) { + return parser.emitError(parser.getCurrentLocation(), + "compare-form tgather expects at most src, kValue, tmp in ins(...)"); + } + } else { + if (!insOps.empty() && !isTileLikeType(insTypes.front())) { + return parser.emitError(parser.getCurrentLocation(), + "index-form tgather expects tile-like indices; " + "compare-form must use outs(dst, cdst)"); + } + if (!insOps.empty()) { + hasIndices = true; + if (insOps.size() >= 2) { + if (!isTileLikeType(insTypes[1])) + return parser.emitError(parser.getCurrentLocation(), + "index-form tgather tmp must be tile-like"); + hasTmp = true; + } + } + if (insOps.size() == 3) { + return parser.emitError(parser.getCurrentLocation(), + "index-form tgather expects at most src, indices, tmp in ins(...)"); + } + } + + if (parser.resolveOperand(src, srcTy, result.operands) || + parser.resolveOperand(dst, dstTy, result.operands)) + return failure(); + if (hasCdst && parser.resolveOperand(cdst, cdstTy, result.operands)) + return failure(); + if (hasIndices && parser.resolveOperand(insOps[0], insTypes[0], result.operands)) + return failure(); + if (hasTmp && parser.resolveOperand(insOps[hasIndices ? 
1 : 1], insTypes[1], result.operands)) + return failure(); + if (hasKValue && parser.resolveOperand(insOps[0], insTypes[0], result.operands)) + return failure(); + + result.addAttribute("operandSegmentSizes", + parser.getBuilder().getDenseI32ArrayAttr( + {1, 1, hasCdst ? 1 : 0, hasIndices ? 1 : 0, + hasTmp ? 1 : 0, hasKValue ? 1 : 0})); + return success(); +} + +void mlir::pto::TGatherOp::print(OpAsmPrinter &p) { + p << " ins(" << getSrc() << ", "; + if (auto mp = getMaskPatternAttr()) { + p << "{maskPattern = " << mp << "} : " << getSrc().getType(); + } else if (getCdst()) { + p << getKValue(); + if (getTmp()) { + p << ", " << getTmp(); + p << " : " << getSrc().getType() << ", " << getKValue().getType() + << ", " << getTmp().getType(); + } else { + p << " : " << getSrc().getType() << ", " << getKValue().getType(); + } + } else { + p << getIndices(); + if (getTmp()) { + p << ", " << getTmp(); + p << " : " << getSrc().getType() << ", " << getIndices().getType() + << ", " << getTmp().getType(); + } else { + p << " : " << getSrc().getType() << ", " << getIndices().getType(); + } + } + p << ") outs(" << getDst(); + if (getCdst()) + p << ", " << getCdst(); + p << " : " << getDst().getType(); + if (getCdst()) + p << ", " << getCdst().getType(); + p << ")"; + + if (getMaskPatternAttr()) { + p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"maskPattern", "operandSegmentSizes"}); + } else { + p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"operandSegmentSizes"}); + } +} + +ParseResult mlir::pto::TScatterOp::parse(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::UnresolvedOperand src, indexes, dst; + Type srcTy, idxTy, dstTy; + bool hasMask = false; + bool hasIndexes = false; + + if (parser.parseKeyword("ins") || parser.parseLParen() || + parser.parseOperand(src)) + return failure(); + + if (!succeeded(parser.parseOptionalComma())) + return parser.emitError(parser.getCurrentLocation(), + "expected ',' after src operand in 
ins(...)"); + + if (succeeded(parser.parseOptionalLBrace())) { + if (parser.parseKeyword("maskPattern") || parser.parseEqual()) + return failure(); + Attribute rawMaskAttr; + if (parser.parseAttribute(rawMaskAttr) || parser.parseRBrace()) + return failure(); + auto mp = llvm::dyn_cast(rawMaskAttr); + if (!mp) + return parser.emitError(parser.getCurrentLocation(), + "expected #pto.mask_pattern for maskPattern"); + result.addAttribute("maskPattern", mp); + hasMask = true; + if (parser.parseColonType(srcTy) || parser.parseRParen()) + return failure(); + } else { + if (parser.parseOperand(indexes)) + return failure(); + hasIndexes = true; + if (parser.parseColon() || parser.parseType(srcTy) || parser.parseComma() || + parser.parseType(idxTy) || parser.parseRParen()) + return failure(); + } + + if (parser.parseKeyword("outs") || parser.parseLParen() || + parser.parseOperand(dst) || parser.parseColonType(dstTy) || + parser.parseRParen()) + return failure(); + + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + if (result.attributes.get("maskPattern")) + hasMask = true; + + if (hasMask && hasIndexes) + return parser.emitError(parser.getCurrentLocation(), + "mask-pattern tscatter does not take indexes"); + if (!hasMask && !hasIndexes) + return parser.emitError(parser.getCurrentLocation(), + "expected indexes operand or maskPattern for tscatter"); + + if (parser.resolveOperand(src, srcTy, result.operands) || + parser.resolveOperand(dst, dstTy, result.operands) || + (hasIndexes && parser.resolveOperand(indexes, idxTy, result.operands))) + return failure(); + return success(); +} + +void mlir::pto::TScatterOp::print(OpAsmPrinter &p) { + p << " ins(" << getSrc() << ", "; + if (getMaskPatternAttr()) { + p << "{maskPattern = " << getMaskPatternAttr() << "} : " + << getSrc().getType(); + } else { + p << getIndexes() << " : " << getSrc().getType() << ", " + << getIndexes().getType(); + } + p << ") outs(" << getDst() << " : " << getDst().getType() << ")"; 
+ p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"maskPattern"}); +} + +namespace { + +struct CommRecvClause { + OpAsmParser::UnresolvedOperand ping; + std::optional pong; + Type pingTy; + Type pongTy; +}; + +static ParseResult parseCommRecvClause(OpAsmParser &parser, + CommRecvClause &recvClause) { + if (parser.parseKeyword("recv") || parser.parseLParen() || + parser.parseOperand(recvClause.ping)) + return failure(); + if (succeeded(parser.parseOptionalComma())) { + OpAsmParser::UnresolvedOperand pong; + if (parser.parseOperand(pong)) + return failure(); + recvClause.pong = pong; + } + return parser.parseRParen(); +} + +static ParseResult parseCommCollectiveTail( + OpAsmParser &parser, OperationState &result, + ArrayRef fixedOperands, + SmallVectorImpl &fixedTypes, CommRecvClause &recvClause, + SmallVectorImpl &groupOps, + SmallVectorImpl &groupTypes, ArrayRef operandSegmentsPrefix, + ArrayRef requiredAttrs) { + if (parser.parseComma() || parser.parseKeyword("group") || parser.parseLParen()) + return failure(); + + OpAsmParser::UnresolvedOperand group; + if (parser.parseOperand(group)) + return failure(); + groupOps.push_back(group); + while (succeeded(parser.parseOptionalComma())) { + if (parser.parseOperand(group)) + return failure(); + groupOps.push_back(group); + } + + if (parser.parseRParen()) + return failure(); + + if (parser.parseColon()) + return failure(); + + for (size_t i = 0; i < fixedTypes.size(); ++i) { + if (i != 0 && parser.parseComma()) + return failure(); + if (parser.parseType(fixedTypes[i])) + return failure(); + } + if (parser.parseComma() || parser.parseType(recvClause.pingTy)) + return failure(); + if (recvClause.pong) { + if (parser.parseComma() || parser.parseType(recvClause.pongTy)) + return failure(); + } + for (size_t i = 0; i < groupOps.size(); ++i) { + Type groupTy; + if (parser.parseComma() || parser.parseType(groupTy)) + return failure(); + groupTypes.push_back(groupTy); + } + if (parser.parseRParen()) + return 
failure(); + + NamedAttrList attrs; + if (parser.parseOptionalAttrDict(attrs)) + return failure(); + for (StringRef attrName : requiredAttrs) { + if (!attrs.get(attrName)) { + return parser.emitError(parser.getCurrentLocation()) + << "expected '" << attrName << "' attribute"; + } + } + result.addAttributes(attrs); + + for (auto [operand, type] : llvm::zip_equal(fixedOperands, fixedTypes)) { + if (parser.resolveOperand(operand, type, result.operands)) + return failure(); + } + if (parser.resolveOperand(recvClause.ping, recvClause.pingTy, result.operands)) + return failure(); + if (recvClause.pong && + parser.resolveOperand(*recvClause.pong, recvClause.pongTy, result.operands)) + return failure(); + if (parser.resolveOperands(groupOps, groupTypes, parser.getCurrentLocation(), + result.operands)) + return failure(); + + SmallVector segmentSizes(operandSegmentsPrefix.begin(), + operandSegmentsPrefix.end()); + segmentSizes.push_back(static_cast(groupOps.size())); + result.addAttribute("operandSegmentSizes", + parser.getBuilder().getDenseI32ArrayAttr(segmentSizes)); + return success(); +} + +static void printCommRecvClause(OpAsmPrinter &p, Value ping, Value pong) { + p << "recv(" << ping; + if (pong) + p << ", " << pong; + p << ")"; +} + +static void printCommGroupTypes(OpAsmPrinter &p, ValueRange group) { + for (Value groupValue : group) + p << ", " << groupValue.getType(); +} + +static void printCommGroupClause(OpAsmPrinter &p, ValueRange group) { + p << "group("; + p.printOperands(group); + p << ")"; +} + +} // namespace + +ParseResult mlir::pto::TBroadcastOp::parse(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::UnresolvedOperand src; + CommRecvClause recvClause; + SmallVector groupOps; + SmallVector groupTypes; + + if (parser.parseLParen() || parser.parseOperand(src) || parser.parseComma()) + return failure(); + if (failed(parseCommRecvClause(parser, recvClause))) + return failure(); + + SmallVector fixedOperands{src}; + SmallVector fixedTypes(1); + 
if (failed(parseCommCollectiveTail(parser, result, fixedOperands, fixedTypes, + recvClause, groupOps, groupTypes, + {1, 1, recvClause.pong ? 1 : 0}, {"root"}))) + return failure(); + return success(); +} + +void mlir::pto::TBroadcastOp::print(OpAsmPrinter &p) { + p << "(" << getSrc() << ", "; + printCommRecvClause(p, getPing(), getPong()); + p << ", "; + printCommGroupClause(p, getGroup()); + p << " : " << getSrc().getType() << ", " << getPing().getType(); + if (getPong()) + p << ", " << getPong().getType(); + printCommGroupTypes(p, getGroup()); + p << ")"; + p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"operandSegmentSizes"}); +} + +ParseResult mlir::pto::CommTGatherOp::parse(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::UnresolvedOperand dst; + CommRecvClause recvClause; + SmallVector groupOps; + SmallVector groupTypes; + + if (parser.parseLParen() || parser.parseOperand(dst) || parser.parseComma()) + return failure(); + if (failed(parseCommRecvClause(parser, recvClause))) + return failure(); + + SmallVector fixedOperands{dst}; + SmallVector fixedTypes(1); + if (failed(parseCommCollectiveTail( + parser, result, fixedOperands, fixedTypes, recvClause, groupOps, + groupTypes, {1, 1, recvClause.pong ? 
1 : 0}, + {"root"}))) + return failure(); + return success(); +} + +void mlir::pto::CommTGatherOp::print(OpAsmPrinter &p) { + p << "(" << getDst() << ", "; + printCommRecvClause(p, getPing(), getPong()); + p << ", "; + printCommGroupClause(p, getGroup()); + p << " : " << getDst().getType() << ", " << getPing().getType(); + if (getPong()) + p << ", " << getPong().getType(); + printCommGroupTypes(p, getGroup()); + p << ")"; + p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"operandSegmentSizes"}); +} + +ParseResult mlir::pto::CommTScatterOp::parse(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::UnresolvedOperand src; + CommRecvClause recvClause; + SmallVector groupOps; + SmallVector groupTypes; + + if (parser.parseLParen() || parser.parseOperand(src) || parser.parseComma()) + return failure(); + if (failed(parseCommRecvClause(parser, recvClause))) + return failure(); + + SmallVector fixedOperands{src}; + SmallVector fixedTypes(1); + if (failed(parseCommCollectiveTail( + parser, result, fixedOperands, fixedTypes, recvClause, groupOps, + groupTypes, {1, 1, recvClause.pong ? 
1 : 0}, + {"root"}))) + return failure(); + return success(); +} + +void mlir::pto::CommTScatterOp::print(OpAsmPrinter &p) { + p << "(" << getSrc() << ", "; + printCommRecvClause(p, getPing(), getPong()); + p << ", "; + printCommGroupClause(p, getGroup()); + p << " : " << getSrc().getType() << ", " << getPing().getType(); + if (getPong()) + p << ", " << getPong().getType(); + printCommGroupTypes(p, getGroup()); + p << ")"; + p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"operandSegmentSizes"}); +} + +ParseResult mlir::pto::TReduceOp::parse(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::UnresolvedOperand dst, acc; + CommRecvClause recvClause; + SmallVector groupOps; + SmallVector groupTypes; + + if (parser.parseLParen() || parser.parseOperand(dst) || parser.parseComma() || + parser.parseOperand(acc) || parser.parseComma()) + return failure(); + if (failed(parseCommRecvClause(parser, recvClause))) + return failure(); + + SmallVector fixedOperands{dst, acc}; + SmallVector fixedTypes(2); + if (failed(parseCommCollectiveTail( + parser, result, fixedOperands, fixedTypes, recvClause, groupOps, + groupTypes, {1, 1, 1, recvClause.pong ? 
1 : 0}, + {"reduceOp", "root"}))) + return failure(); + return success(); +} + +void mlir::pto::TReduceOp::print(OpAsmPrinter &p) { + p << "(" << getDst() << ", " << getAcc() << ", "; + printCommRecvClause(p, getRecvPing(), getRecvPong()); + p << ", "; + printCommGroupClause(p, getGroup()); + p << " : " << getDst().getType() << ", " << getAcc().getType() << ", " + << getRecvPing().getType(); + if (getRecvPong()) + p << ", " << getRecvPong().getType(); + printCommGroupTypes(p, getGroup()); + p << ")"; + p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"operandSegmentSizes"}); +} + +ParseResult mlir::pto::MakeTensorViewOp::parse(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::UnresolvedOperand ptr; + SmallVector shapeOps; + SmallVector strideOps; + + Type resultTy; + + // %ptr + if (parser.parseOperand(ptr)) + return failure(); + + // , shape = [ ... ] + if (parser.parseComma() || parser.parseKeyword("shape") || parser.parseEqual() || + parser.parseLSquare() || + parser.parseOperandList(shapeOps) || + parser.parseRSquare()) + return failure(); + + // strides = [ ... 
] + if (parser.parseComma() || parser.parseKeyword("strides") || parser.parseEqual() || + parser.parseLSquare() || + parser.parseOperandList(strideOps) || + parser.parseRSquare()) + return failure(); + + // attr-dict + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + // : result-type + if (parser.parseColonType(resultTy)) + return failure(); + result.addTypes(resultTy); + + auto tvTy = llvm::dyn_cast(resultTy); + if (!tvTy) + return parser.emitError(parser.getCurrentLocation(), + "expected result type pto.tensor_view<...>"); + + Type elemTy = tvTy.getElementType(); + + Type ptrTy = mlir::pto::PtrType::get(parser.getContext(), elemTy); + + // resolve %ptr + if (parser.resolveOperand(ptr, ptrTy, result.operands)) + return failure(); + + // resolve shape/strides 为 index + Type indexTy = parser.getBuilder().getIndexType(); + if (parser.resolveOperands(shapeOps, indexTy, result.operands)) + return failure(); + if (parser.resolveOperands(strideOps, indexTy, result.operands)) + return failure(); + + auto segAttr = parser.getBuilder().getDenseI32ArrayAttr( + {1, (int32_t)shapeOps.size(), (int32_t)strideOps.size()}); + result.addAttribute("operandSegmentSizes", segAttr); + + return success(); +} + +void mlir::pto::MakeTensorViewOp::print(OpAsmPrinter &p) { + p << " " << getPtr(); + + p << ", shape = ["; + p.printOperands(getShape()); + p << "]"; + + p << ", strides = ["; + p.printOperands(getStrides()); + p << "]"; + + p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"operandSegmentSizes"}); + + p << " : " << getResult().getType(); +} + +// Layout inference helpers for make_tensor_view +static std::optional getConstIndexValue(Value v) { + if (auto c = v.getDefiningOp()) + return c.value(); + if (auto c = v.getDefiningOp()) { + if (auto ia = dyn_cast(c.getValue())) + return ia.getInt(); + } + return std::nullopt; +} + +static FailureOr +inferPartitionViewResultTypeFromSizes(mlir::pto::TensorViewType sourceType, + ValueRange sizes) { 
+ if (!sourceType) + return failure(); + + if ((int64_t)sizes.size() != sourceType.getRank()) + return failure(); + + SmallVector shape; + shape.reserve(sizes.size()); + for (Value size : sizes) { + auto constSize = getConstIndexValue(size); + if (constSize && *constSize >= 0) + shape.push_back(*constSize); + else + shape.push_back(ShapedType::kDynamic); + } + + return mlir::pto::PartitionTensorViewType::get( + sourceType.getContext(), shape, sourceType.getElementType()); +} + +ParseResult mlir::pto::PartitionViewOp::parse(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::UnresolvedOperand source; + SmallVector offsets; + SmallVector sizes; + Type sourceTy; + Type resultTy; + bool hasExplicitResultTy = false; + + if (parser.parseOperand(source) || parser.parseComma() || + parser.parseKeyword("offsets") || parser.parseEqual() || + parser.parseLSquare() || parser.parseOperandList(offsets) || + parser.parseRSquare() || parser.parseComma() || + parser.parseKeyword("sizes") || parser.parseEqual() || + parser.parseLSquare() || parser.parseOperandList(sizes) || + parser.parseRSquare() || parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(sourceTy)) + return failure(); + + if (succeeded(parser.parseOptionalArrow())) { + if (parser.parseType(resultTy)) + return failure(); + hasExplicitResultTy = true; + } + + if (parser.resolveOperand(source, sourceTy, result.operands)) + return failure(); + + Type indexTy = parser.getBuilder().getIndexType(); + if (parser.resolveOperands(offsets, indexTy, result.operands) || + parser.resolveOperands(sizes, indexTy, result.operands)) + return failure(); + + auto &properties = result.getOrAddProperties(); + llvm::copy(ArrayRef( + {1, static_cast(offsets.size()), + static_cast(sizes.size())}), + properties.operandSegmentSizes.begin()); + + if (hasExplicitResultTy) { + result.addTypes(resultTy); + return success(); + } + + ValueRange allOperands(result.operands); + ValueRange sizeOperands = + 
allOperands.slice(1 + offsets.size(), sizes.size()); + auto inferredResultType = inferPartitionViewResultTypeFromSizes( + dyn_cast(sourceTy), sizeOperands); + if (failed(inferredResultType)) { + return parser.emitError(parser.getCurrentLocation(), + "failed to infer pto.partition_view result type"); + } + + result.addTypes(*inferredResultType); + return success(); +} + +void mlir::pto::PartitionViewOp::print(OpAsmPrinter &printer) { + printer << " " << getSource() << ", offsets = ["; + printer.printOperands(getOffsets()); + printer << "], sizes = ["; + printer.printOperands(getSizes()); + printer << "]"; + printer.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"operandSegmentSizes"}); + printer << " : " << getSource().getType(); + + auto inferredResultType = inferPartitionViewResultTypeFromSizes( + dyn_cast(getSource().getType()), getSizes()); + if (succeeded(inferredResultType) && *inferredResultType == getResult().getType()) + return; + + printer << " -> " << getResult().getType(); +} + +static std::optional getConstantIntegerValueEx( + Value v, bool includeIndexAndIntOpsInConstFold) { + if (includeIndexAndIntOpsInConstFold) { + if (auto c = v.getDefiningOp()) + return c.value(); + if (auto c = v.getDefiningOp()) + return c.value(); + } + if (auto c = v.getDefiningOp()) { + if (auto ia = dyn_cast(c.getValue())) + return ia.getInt(); + } + return std::nullopt; +} + +static LogicalResult verifyNonNegativeIndexRowCol( + Operation &op, Value indexRow, Value indexCol, + bool includeIndexAndIntOpsInConstFold) { + if (!indexRow.getType().isIndex() || !indexCol.getType().isIndex()) + return op.emitOpError("expects indexRow and indexCol to be index type"); + auto row = + getConstantIntegerValueEx(indexRow, includeIndexAndIntOpsInConstFold); + auto col = + getConstantIntegerValueEx(indexCol, includeIndexAndIntOpsInConstFold); + if (row && *row < 0) + return op.emitOpError("expects indexRow to be non-negative"); + if (col && *col < 0) + return 
op.emitOpError("expects indexCol to be non-negative"); + return success(); +} + +static LogicalResult verifyExtractStaticBoundsCommon( + Operation &op, Value indexRow, Value indexCol, Type srcTy, Type dstTy, + bool includeIndexAndIntOpsInConstFold) { + auto row = + getConstantIntegerValueEx(indexRow, includeIndexAndIntOpsInConstFold); + auto col = + getConstantIntegerValueEx(indexCol, includeIndexAndIntOpsInConstFold); + auto srcShape = getShapeVec(srcTy); + auto dstShape = getShapeVec(dstTy); + if (srcShape.size() != 2 || dstShape.size() != 2) + return op.emitOpError("expects src and dst to be rank-2 tile_buf"); + if (row && srcShape[0] != ShapedType::kDynamic && + dstShape[0] != ShapedType::kDynamic && + *row + dstShape[0] > srcShape[0]) + return op.emitOpError("expects indexRow + dst.rows <= src.rows"); + if (col && srcShape[1] != ShapedType::kDynamic && + dstShape[1] != ShapedType::kDynamic && + *col + dstShape[1] > srcShape[1]) + return op.emitOpError("expects indexCol + dst.cols <= src.cols"); + return success(); +} + +static LogicalResult verifyInsertStaticBoundsCommon( + Operation &op, Value indexRow, Value indexCol, Type srcTy, Type dstTy, + bool includeIndexAndIntOpsInConstFold) { + auto row = + getConstantIntegerValueEx(indexRow, includeIndexAndIntOpsInConstFold); + auto col = + getConstantIntegerValueEx(indexCol, includeIndexAndIntOpsInConstFold); + auto srcShape = getValidShapeVec(srcTy); + auto dstShape = getShapeVec(dstTy); + if (srcShape.size() != 2 || dstShape.size() != 2) + return op.emitOpError("expects src and dst to be rank-2 tile_buf"); + if (row && srcShape[0] != ShapedType::kDynamic && + dstShape[0] != ShapedType::kDynamic && + *row + srcShape[0] > dstShape[0]) + return op.emitOpError("expects indexRow + src.rows <= dst.rows"); + if (col && srcShape[1] != ShapedType::kDynamic && + dstShape[1] != ShapedType::kDynamic && + *col + srcShape[1] > dstShape[1]) + return op.emitOpError("expects indexCol + src.cols <= dst.cols"); + return success(); 
+} + +static unsigned getElemByteSize(Type ty) { + return getPTOStorageElemByteSize(ty); +} + +static LogicalResult verifyTileBufLayoutConstraints(Operation *op, + pto::TileBufType tb, + StringRef name) { + auto shape = tb.getShape(); + if (shape.size() != 2) + return op->emitOpError() << "expects " << name << " to be rank-2"; + + int64_t rows = shape[0]; + int64_t cols = shape[1]; + if (rows != ShapedType::kDynamic && rows <= 0) + return op->emitOpError() << "expects " << name << " rows to be positive"; + if (cols != ShapedType::kDynamic && cols <= 0) + return op->emitOpError() << "expects " << name << " cols to be positive"; + + unsigned elemBytes = getElemByteSize(tb.getElementType()); + if (elemBytes == 0) + return op->emitOpError() << "expects " << name + << " element type to have a byte size"; + + auto cfg = tb.getConfigAttr(); + if (!cfg) + cfg = TileBufConfigAttr::getDefault(tb.getContext()); + auto readBLayout = [](Attribute attr, int32_t &out) -> bool { + if (auto layout = dyn_cast_or_null(attr)) { + out = static_cast(layout.getValue()); + return true; + } + if (auto value = dyn_cast_or_null(attr)) { + out = static_cast(value.getInt()); + return true; + } + return false; + }; + auto readSLayout = [](Attribute attr, int32_t &out) -> bool { + if (auto layout = dyn_cast_or_null(attr)) { + out = static_cast(layout.getValue()); + return true; + } + if (auto value = dyn_cast_or_null(attr)) { + out = static_cast(value.getInt()); + return true; + } + return false; + }; + int32_t blayout = 0; + int32_t slayout = 0; + if (!readBLayout(cfg.getBLayout(), blayout) || + !readSLayout(cfg.getSLayout(), slayout)) + return op->emitOpError() << "expects " << name + << " to have concrete tile layout attributes"; + constexpr int64_t kAlignedBytes = 32; + + auto checkByteAlignment = [&](int64_t dim, StringRef layoutName, + StringRef byteExpr) -> LogicalResult { + if (dim == ShapedType::kDynamic) + return success(); + int64_t bytes = dim * static_cast(elemBytes); + if (bytes % 
kAlignedBytes == 0) + return success(); + return op->emitOpError() + << "expects " << name << " " << layoutName + << " none_box tile " << byteExpr + << " to be 32-byte aligned, but got " << bytes << " bytes"; + }; + + if (slayout == static_cast(SLayout::NoneBox)) { + if (blayout == static_cast(BLayout::RowMajor)) + return checkByteAlignment(cols, "row-major", + "row byte size (cols * sizeof(dtype))"); + return checkByteAlignment(rows, "col-major", + "column byte size (rows * sizeof(dtype))"); + } + + int64_t innerRows = 0; + int64_t innerCols = 0; + int32_t fractal = static_cast(cfg.getSFractalSize().getInt()); + switch (fractal) { + case 1024: + innerRows = 16; + innerCols = 16; + break; + case 32: + innerRows = 16; + innerCols = 2; + break; + case 512: + if (kAlignedBytes % elemBytes != 0) + return op->emitOpError() << "expects " << name + << " element byte size to divide 32 for boxed " + "fractal-512 tile layout"; + if (slayout == static_cast(SLayout::RowMajor)) { + innerRows = 16; + innerCols = kAlignedBytes / static_cast(elemBytes); + } else if (slayout == static_cast(SLayout::ColMajor)) { + innerRows = kAlignedBytes / static_cast(elemBytes); + innerCols = 16; + } + break; + default: + break; + } + if (innerRows <= 0 || innerCols <= 0) + return op->emitOpError() << "expects " << name + << " to use a supported boxed tile layout"; + + auto loc = getPTOMemorySpaceEnum(tb); + bool allowUnalignedRows = + (loc && *loc == pto::AddressSpace::VEC) || fractal == 32 || rows == 1; + if (!allowUnalignedRows && rows != ShapedType::kDynamic && + rows % innerRows != 0) + return op->emitOpError() + << "expects " << name + << " boxed tile rows to be a multiple of innerRows (" << innerRows + << "), but got " << rows; + if (cols != ShapedType::kDynamic && cols % innerCols != 0) + return op->emitOpError() + << "expects " << name + << " boxed tile cols to be a multiple of innerCols (" << innerCols + << "), but got " << cols; + + return success(); +} + +[[maybe_unused]] static bool 
isSupportedLoadStoreElemTypeA2A3(Type ty) { + if (ty.isF16() || ty.isBF16() || ty.isF32()) + return true; + if (auto it = dyn_cast(ty)) { + unsigned width = it.getWidth(); + return width == 8 || width == 16 || width == 32 || width == 64; + } + return false; +} + +static bool isSupportedGatherElemTypeA2A3(Type ty) { + if (ty.isF16() || ty.isF32()) + return true; + if (auto it = dyn_cast(ty)) { + unsigned width = it.getWidth(); + return width == 16 || width == 32; + } + return false; +} + +static bool isSupportedGatherElemTypeA5(Type ty) { + if (isSupportedGatherElemTypeA2A3(ty) || ty.isBF16()) + return true; + if (auto ft = dyn_cast(ty)) { + unsigned width = ft.getWidth(); + return width == 8; + } + if (auto it = dyn_cast(ty)) + return it.getWidth() == 8 || it.getWidth() == 16 || it.getWidth() == 32; + return false; +} + +static std::optional +inferLayout(ArrayRef shape, ArrayRef strides, + unsigned elemBytes) { + if (shape.size() != strides.size() || elemBytes == 0) + return std::nullopt; + + // NZ / fractal: rank>=5, check middle dims (sh3/sh4/sh5 per spec) + if (shape.size() >= 5) { + int64_t sh3 = shape[2], sh4 = shape[3], sh5 = shape[4]; + int64_t st4 = strides[3], st5 = strides[4]; + bool alignMatch = (sh3 == 16) && (sh3 * sh4 * elemBytes == 512); + bool strideMatch = (st5 == 1) && (st4 == sh5); + if (alignMatch && strideMatch) + return mlir::pto::Layout::NZ; + } + + // ND: row-major contiguous + bool isRowMajor = true; + for (int i = 0, e = (int)shape.size() - 1; i < e; ++i) { + if (strides[i] != strides[i + 1] * shape[i + 1]) { + isRowMajor = false; + break; + } + } + if (isRowMajor && strides.back() == 1) + return mlir::pto::Layout::ND; + + // DN: col-major + bool isColMajor = true; + for (int i = 0, e = (int)shape.size() - 1; i < e; ++i) { + if (strides[i + 1] != strides[i] * shape[i]) { + isColMajor = false; + break; + } + } + if (isColMajor && strides.front() == 1) + return mlir::pto::Layout::DN; + + return mlir::pto::Layout::ND; // fallback +} + +static 
std::optional getLogicalViewLayout(Value value) { + if (!value) + return std::nullopt; + if (auto part = value.getDefiningOp()) + return getLogicalViewLayout(part.getSource()); + if (auto make = value.getDefiningOp()) { + auto tvTy = dyn_cast(make.getResult().getType()); + if (!tvTy) + return std::nullopt; + SmallVector shape(tvTy.getShape().begin(), tvTy.getShape().end()); + SmallVector strides; + strides.reserve(make.getStrides().size()); + for (Value stride : make.getStrides()) { + auto cst = getConstIndexValue(stride); + if (!cst) + return std::nullopt; + strides.push_back(*cst); + } + return inferLayout(shape, strides, getElemByteSize(tvTy.getElementType())); + } + return std::nullopt; +} + +static std::optional getTileBufLogicalLayout(pto::TileBufType type) { + if (!type) + return std::nullopt; + int32_t sl = type.getSLayoutValueI32(); + int32_t bl = type.getBLayoutValueI32(); + if (sl != static_cast(pto::SLayout::NoneBox)) + return pto::Layout::NZ; + if (bl == static_cast(pto::BLayout::RowMajor)) + return pto::Layout::ND; + if (bl == static_cast(pto::BLayout::ColMajor)) + return pto::Layout::DN; + return std::nullopt; +} + +static bool isRowMajorTileBuf(Type ty) { + auto tb = mlir::dyn_cast(ty); + return tb && tb.getBLayoutValueI32() == static_cast(pto::BLayout::RowMajor); +} + +static LogicalResult verifyRowReductionSrcLayout(Operation *op, Type ty, + StringRef name) { + if (failed(verifyTileBufCommon(op, ty, name))) + return failure(); + auto as = getPTOMemorySpaceEnum(ty); + if (!as || *as != pto::AddressSpace::VEC) + return op->emitOpError() << "expects " << name << " to be in the vec address space"; + if (auto tb = dyn_cast(ty)) { + if (tb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor)) + return op->emitOpError() << "expects " << name << " to use the row_major blayout"; + } + if (auto mr = dyn_cast(ty)) + (void)mr; + if (auto tb = dyn_cast(ty)) { + if (tb.getSLayoutValueI32() != static_cast(pto::SLayout::NoneBox)) + return op->emitOpError() 
<< "expects " << name + << " to use the none_box slayout"; + } + if (auto tb = dyn_cast(ty)) { + auto layout = getTileBufLogicalLayout(tb); + if (layout && *layout != pto::Layout::ND) + return op->emitOpError() << "expects " << name + << " to use an ND-style tile layout"; + } + return success(); +} + +static LogicalResult verifyRowReductionDstLayout(Operation *op, Type ty, + StringRef name) { + if (failed(verifyTileBufCommon(op, ty, name))) + return failure(); + auto as = getPTOMemorySpaceEnum(ty); + if (!as || *as != pto::AddressSpace::VEC) + return op->emitOpError() << "expects " << name << " to be in the vec address space"; + if (auto tb = dyn_cast(ty)) { + if (tb.getSLayoutValueI32() != static_cast(pto::SLayout::NoneBox)) + return op->emitOpError() << "expects " << name + << " to use the none_box slayout"; + } + if (auto tb = dyn_cast(ty)) { + if (tb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor) && + tb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor)) + return op->emitOpError() << "expects " << name + << " to use the row_major or col_major blayout"; + } + if (auto mr = dyn_cast(ty)) + (void)mr; + if (auto tb = dyn_cast(ty)) { + auto layout = getTileBufLogicalLayout(tb); + if (layout && *layout == pto::Layout::DN) { + auto shape = getShapeVec(ty); + if (shape.size() == 2 && shape[1] != ShapedType::kDynamic && shape[1] != 1) + return op->emitOpError() << "expects DN-style " << name + << " to have shape[1] == 1"; + return success(); + } + if (layout && *layout == pto::Layout::ND) + return success(); + if (layout) + return op->emitOpError() << "expects " << name + << " to use a DN-style column vector tile or legacy ND-style tile"; + } + return success(); + auto valid = getValidShapeVec(ty); + if (valid.size() != 2) + return op->emitOpError() << "expects " << name << " to have rank-2 valid_shape"; + if (valid[1] != ShapedType::kDynamic && valid[1] != 1) + return op->emitOpError() << "expects " << name << " valid_shape[1] to be 1"; + return 
success(); +} + +static LogicalResult verifyRowReductionValidRegion(Operation *op, Type srcTy, + Type dstTy) { + auto srcValid = getValidShapeVec(srcTy); + auto dstValid = getValidShapeVec(dstTy); + if (srcValid.size() != 2 || dstValid.size() != 2) + return op->emitOpError("expects src and dst to have rank-2 valid_shape"); + if (srcValid[0] != ShapedType::kDynamic && srcValid[0] == 0) + return op->emitOpError("expects src valid_shape[0] to be non-zero"); + if (srcValid[1] != ShapedType::kDynamic && srcValid[1] == 0) + return op->emitOpError("expects src valid_shape[1] to be non-zero"); + if (srcValid[0] != ShapedType::kDynamic && dstValid[0] != ShapedType::kDynamic && + srcValid[0] != dstValid[0]) + return op->emitOpError("expects src and dst to have the same valid_shape[0]"); + if (dstValid[1] != ShapedType::kDynamic && dstValid[1] != 1) + return op->emitOpError("expects dst valid_shape[1] to be 1"); + return success(); +} + +static bool isSupportedRowReductionElemType(Type elem) { + return elem.isInteger(16) || elem.isInteger(32) || elem.isF16() || + elem.isF32(); +} + +static LogicalResult verifyTRowReductionNoTmpCommon(Operation *op, Type srcTy, + Type dstTy, + StringRef elemTypeError) { + if (failed(verifyRowReductionSrcLayout(op, srcTy, "src")) || + failed(verifyRowReductionDstLayout(op, dstTy, "dst"))) + return failure(); + if (getElemTy(srcTy) != getElemTy(dstTy)) + return op->emitOpError("expects src and dst to have the same element type"); + if (failed(verifyRowReductionValidRegion(op, srcTy, dstTy))) + return failure(); + if (!isSupportedRowReductionElemType(getElemTy(srcTy))) + return op->emitOpError(elemTypeError); + return success(); +} + +static LogicalResult verifyTRowReductionWithTmpCommon(Operation *op, Type srcTy, + Type tmpTy, Type dstTy, + StringRef elemTypeError) { + if (failed(verifyRowReductionSrcLayout(op, srcTy, "src")) || + failed(verifyVecTileCommon(op, tmpTy, "tmp")) || + failed(verifyRowReductionDstLayout(op, dstTy, "dst"))) + return 
failure(); + if (failed(verifyTileBufSameElemType(op, srcTy, tmpTy, "src", "tmp")) || + failed(verifyTileBufSameValidShape(op, srcTy, tmpTy, "src", "tmp"))) + return failure(); + if (getElemTy(srcTy) != getElemTy(dstTy)) + return op->emitOpError("expects src and dst to have the same element type"); + if (failed(verifyRowReductionValidRegion(op, srcTy, dstTy))) + return failure(); + if (!isSupportedRowReductionElemType(getElemTy(srcTy))) + return op->emitOpError(elemTypeError); + return success(); +} + +static LogicalResult verifyTRowArgReductionCommon(Operation *op, Type srcTy, + Type tmpTy, Type dstTy) { + if (failed(verifyRowReductionSrcLayout(op, srcTy, "src")) || + failed(verifyVecTileCommon(op, tmpTy, "tmp")) || + failed(verifyRowReductionDstLayout(op, dstTy, "dst"))) + return failure(); + if (failed(verifyTileBufSameElemType(op, srcTy, tmpTy, "src", "tmp")) || + failed(verifyTileBufSameValidShape(op, srcTy, tmpTy, "src", "tmp"))) + return failure(); + if (failed(verifyRowReductionValidRegion(op, srcTy, dstTy))) + return failure(); + Type srcElem = getElemTy(srcTy); + if (!isSupportedRowReductionElemType(srcElem)) + return op->emitOpError("expects src element type to be i16/i32/f16/f32"); + auto dstInt = dyn_cast(getElemTy(dstTy)); + if (!dstInt || dstInt.getWidth() != 32) + return op->emitOpError("expects dst element type to be i32 or ui32"); + return success(); +} + +static LogicalResult verifyNDStyleVecTile(Operation *op, Type ty, StringRef name) { + if (failed(verifyTileBufCommon(op, ty, name))) + return failure(); + auto as = getPTOMemorySpaceEnum(ty); + if (!as || *as != pto::AddressSpace::VEC) + return op->emitOpError() << "expects " << name << " to be in the vec address space"; + if (auto tb = dyn_cast(ty)) { + if (tb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor)) + return op->emitOpError() << "expects " << name << " to use the row_major blayout"; + if (tb.getSLayoutValueI32() != static_cast(pto::SLayout::NoneBox)) + return 
op->emitOpError() << "expects " << name << " to use the none_box slayout"; + } + return success(); +} + +static LogicalResult verifyColReductionValidRegion(Operation *op, Type srcTy, + Type dstTy, + bool requireNonZeroSrc) { + auto srcValid = getValidShapeVec(srcTy); + auto dstValid = getValidShapeVec(dstTy); + if (srcValid.size() != 2 || dstValid.size() != 2) + return op->emitOpError("expects src and dst to have rank-2 valid_shape"); + if (requireNonZeroSrc) { + if (srcValid[0] != ShapedType::kDynamic && srcValid[0] == 0) + return op->emitOpError("expects src valid_shape[0] to be non-zero"); + if (srcValid[1] != ShapedType::kDynamic && srcValid[1] == 0) + return op->emitOpError("expects src valid_shape[1] to be non-zero"); + } + if (srcValid[1] != ShapedType::kDynamic && dstValid[1] != ShapedType::kDynamic && + srcValid[1] != dstValid[1]) + return op->emitOpError("expects src and dst to have the same valid_shape[1]"); + return success(); +} + +static LogicalResult verifyColArgReductionDstLayout(Operation *op, Type ty, + StringRef name) { + if (failed(verifyNDStyleVecTile(op, ty, name))) + return failure(); + auto valid = getValidShapeVec(ty); + if (valid.size() != 2) + return op->emitOpError() << "expects " << name + << " to have rank-2 valid_shape"; + if (valid[0] != ShapedType::kDynamic && valid[0] != 1) + return op->emitOpError() << "expects " << name + << " valid_shape[0] to be 1"; + return success(); +} + +static std::optional getConstantIntegerValue(Value value) { + if (!value) + return std::nullopt; + if (auto arithCst = value.getDefiningOp()) { + if (auto intAttr = dyn_cast(arithCst.getValue())) + return intAttr.getInt(); + } + return std::nullopt; +} + +LogicalResult mlir::pto::MakeTensorViewOp::verify() { + auto tvTy = dyn_cast(getResult().getType()); + if (!tvTy) + return emitOpError("result must be pto.tensor_view<...>"); + + auto pty = dyn_cast(getPtr().getType()); + if (!pty) + return emitOpError("ptr operand must be !pto.ptr<...>"); + + if 
(pty.getElementType() != tvTy.getElementType()) + return emitOpError() << "ptr element type must match tensor_view element type, but got ptr=" + << pty.getElementType() << " view=" << tvTy.getElementType(); + + int64_t rank = tvTy.getRank(); + + if ((int64_t)getShape().size() != rank || (int64_t)getStrides().size() != rank) + return emitOpError() << "shape/strides operand counts must match tensor_view rank=" + << rank; + + // Detect dynamic shape/stride. + bool hasDynamicShape = llvm::any_of(tvTy.getShape(), [](int64_t v) { + return v == ShapedType::kDynamic; + }); + bool hasDynamicStride = llvm::any_of(getStrides(), [](Value s) { + return !getConstIndexValue(s).has_value(); + }); + + auto layoutAttr = getLayoutAttr(); + + // 1) Dynamic shape/stride without explicit layout: warn and keep going. + if ((hasDynamicShape || hasDynamicStride) && !layoutAttr) { + return success(); + } + + // 2) Static shape/stride with explicit layout: verify correctness. + bool allStaticStride = true; + SmallVector strideInts; + strideInts.reserve(getStrides().size()); + for (Value s : getStrides()) { + auto val = getConstIndexValue(s); + if (!val) { + allStaticStride = false; + break; + } + strideInts.push_back(*val); + } + + bool allStaticShape = + llvm::none_of(tvTy.getShape(), [](int64_t v) { return v == ShapedType::kDynamic; }); + + if (layoutAttr && allStaticShape && allStaticStride) { + SmallVector shapeInts(tvTy.getShape().begin(), tvTy.getShape().end()); + if (auto inferred = inferLayout(shapeInts, strideInts, + getElemByteSize(tvTy.getElementType()))) { + (void)inferred; + } + } + + return success(); +} + +LogicalResult mlir::pto::PartitionViewOp::verify() { + auto srcTy = dyn_cast(getSource().getType()); + auto resTy = dyn_cast(getResult().getType()); + if (!srcTy || !resTy) + return emitOpError("expects tensor_view source and partition_tensor_view result"); + + if (srcTy.getElementType() != resTy.getElementType()) + return emitOpError() << "element type mismatch between 
source and result: src=" + << srcTy.getElementType() << " result=" + << resTy.getElementType(); + + int64_t srcRank = srcTy.getRank(); + if ((int64_t)getOffsets().size() != srcRank) + return emitOpError() << "offset count (" << getOffsets().size() + << ") must match source rank (" << srcRank << ")"; + + if ((int64_t)getSizes().size() != srcRank) + return emitOpError() << "size count (" << getSizes().size() + << ") must match source rank (" << srcRank << ")"; + + ArrayRef srcShape = srcTy.getShape(); + ArrayRef resShape = resTy.getShape(); + bool sameRank = resTy.getRank() == srcRank; + + for (int64_t i = 0; i < srcRank; ++i) { + auto offVal = getConstIndexValue(getOffsets()[i]); + auto sizeVal = getConstIndexValue(getSizes()[i]); + + if (offVal && *offVal < 0) + return emitOpError() << "offset at dim " << i + << " must be non-negative, got " << *offVal; + + if (sizeVal && *sizeVal <= 0) + return emitOpError() << "size at dim " << i + << " must be positive, got " << *sizeVal; + + if (sameRank && sizeVal) { + int64_t resDim = resShape[i]; + if (resDim != ShapedType::kDynamic && *sizeVal != resDim) + return emitOpError() << "size/result mismatch at dim " << i + << ": size operand=" << *sizeVal + << " result type dim=" << resDim; + } + + int64_t srcDim = srcShape[i]; + if (srcDim == ShapedType::kDynamic) + continue; + + if (sizeVal && *sizeVal > srcDim) + return emitOpError() << "size at dim " << i << " (" << *sizeVal + << ") exceeds static source dim (" << srcDim << ")"; + + if (offVal && sizeVal && (*offVal + *sizeVal > srcDim)) + return emitOpError() << "offset+size at dim " << i << " (" + << (*offVal + *sizeVal) + << ") exceeds static source dim (" << srcDim << ")"; + } + + return success(); +} + +LogicalResult mlir::pto::AddPtrOp::verify() { + Value ptr = getOperation()->getOperand(0); + Value result = getOperation()->getResult(0); + + auto ptrTy = dyn_cast(ptr.getType()); + if (!ptrTy) + return emitOpError("ptr operand must be !pto.ptr<...>"); + + auto resTy = 
dyn_cast(result.getType()); + if (!resTy) + return emitOpError("result must be !pto.ptr<...>"); + + if (ptrTy != resTy) + return emitOpError("result type must match ptr operand type"); + + return success(); +} + +static LogicalResult verifyPtrLikeForAddressCast(Operation *op, Type type, + StringRef name) { + if (isa(type)) + return success(); + + auto memTy = dyn_cast(type); + if (!memTy) + return op->emitOpError() + << "expects " << name << " to be !pto.ptr<...> or a GM memref"; + + if (memTy.getRank() != 1) + return op->emitOpError() + << "expects lowered memref " << name << " to be rank-1"; + + if (!isGmAddressSpaceAttr(memTy.getMemorySpace())) + return op->emitOpError() + << "expects lowered memref " << name << " to use GM address space"; + + return success(); +} + +static Type getPointerLikeElementType(Type type) { + if (auto ptrTy = dyn_cast(type)) + return ptrTy.getElementType(); + if (auto memTy = dyn_cast(type)) + return memTy.getElementType(); + return Type(); +} + +static bool isEmitCSupportedScalarType(Type type) { + if (!type) + return false; + if (type.isF16() || type.isBF16() || type.isF32() || type.isF64()) + return true; + if (auto intTy = dyn_cast(type)) + return intTy.getWidth() == 8 || intTy.getWidth() == 16 || + intTy.getWidth() == 32 || intTy.getWidth() == 64; + if (mlir::pto::isPTOFloat8Type(type)) + return true; + if (isa(type)) + return true; + return false; +} + +LogicalResult mlir::pto::PtrToIntOp::verify() { + Type resultTy = getResult().getType(); + auto intTy = dyn_cast(resultTy); + if (!intTy || intTy.getWidth() != 64) + return emitOpError("result must be i64"); + + return verifyPtrLikeForAddressCast(getOperation(), getPtr().getType(), + "ptr operand"); +} + +LogicalResult mlir::pto::IntToPtrOp::verify() { + auto addrTy = dyn_cast(getAddr().getType()); + if (!addrTy || addrTy.getWidth() != 64) + return emitOpError("address operand must be i64"); + + if (failed(verifyPtrLikeForAddressCast(getOperation(), getResult().getType(), + 
"result"))) + return failure(); + + Type dstElem = getPointerLikeElementType(getResult().getType()); + if (!isEmitCSupportedScalarType(dstElem)) + return emitOpError("result element type is not supported by EmitC: ") + << dstElem; + + return success(); +} + +LogicalResult mlir::pto::LocalArrayGetOp::verify() { + auto arrayTy = getArray().getType(); + int64_t rank = arrayTy.getRank(); + int64_t numIdx = static_cast(getIndices().size()); + if (numIdx != rank) + return emitOpError() << "expects " << rank + << " indices for !pto.local_array of rank " << rank + << ", got " << numIdx; + if (getResult().getType() != arrayTy.getElementType()) + return emitOpError() + << "result type " << getResult().getType() + << " does not match array element type " + << arrayTy.getElementType(); + return success(); +} + +LogicalResult mlir::pto::LocalArraySetOp::verify() { + auto arrayTy = getArray().getType(); + int64_t rank = arrayTy.getRank(); + int64_t numIdx = static_cast(getIndices().size()); + if (numIdx != rank) + return emitOpError() << "expects " << rank + << " indices for !pto.local_array of rank " << rank + << ", got " << numIdx; + if (getValue().getType() != arrayTy.getElementType()) + return emitOpError() << "value type " << getValue().getType() + << " does not match array element type " + << arrayTy.getElementType(); + return success(); +} + + + + +void PTODialect::initialize() { + addTypes< +#define GET_TYPEDEF_LIST +#include "PTO/IR/PTOTypeDefs.cpp.inc" + >(); + + addOperations< +#define GET_OP_LIST +#include "PTO/IR/PTOOps.cpp.inc" + >(); + + addAttributes< +#define GET_ATTRDEF_LIST +#include "PTO/IR/PTOAttrs.cpp.inc" + >(); +} + + +AddressSpaceAttr mlir::pto::getPTOAddressSpaceAttr(Type type) { + auto memRefType = dyn_cast(type); + if (!memRefType) + return {}; + auto scopeAttr = dyn_cast(memRefType.getMemorySpace()); + if (!scopeAttr) + return {}; + return scopeAttr; +} + +bool mlir::pto::isScalarPtrOrMemRef(Type type) { + if (auto pty = dyn_cast(type)) + return 
true; + if (auto memTy = dyn_cast(type)) + return isGmAddressSpaceAttr(memTy.getMemorySpace()); + return false; +} + +bool mlir::pto::hasExplicitPTOEntryAttr(func::FuncOp func) { + return func && (func->hasAttrOfType(kPTOEntryAttrName) || + func->hasAttrOfType(kLegacyHACCEntryAttrName)); +} + +static constexpr StringLiteral kEffectivePTOEntryAttrName = + "pto.internal.entry"; + +static SmallVector getPTOFunctionDefinitions(ModuleOp module) { + SmallVector defs; + if (!module) + return defs; + for (auto func : module.getOps()) { + if (!func.isDeclaration()) + defs.push_back(func); + } + return defs; +} + +bool mlir::pto::isPTOEntryFunction(func::FuncOp func) { + if (!func || func.isDeclaration()) + return false; + if (auto attr = func->getAttrOfType(kEffectivePTOEntryAttrName)) + return attr.getValue(); + if (hasExplicitPTOEntryAttr(func)) + return true; + + ModuleOp module = func->getParentOfType(); + if (!module) + return false; + SmallVector defs = getPTOFunctionDefinitions(module); + return defs.size() == 1 && defs.front() == func; +} + +LogicalResult mlir::pto::validatePTOEntryFunctions(ModuleOp module) { + if (!module) + return success(); + + for (auto func : module.getOps()) { + if (!hasExplicitPTOEntryAttr(func)) + continue; + if (func.isDeclaration()) { + return func.emitOpError() + << "`" << kPTOEntryAttrName + << "` is only valid on function definitions"; + } + } + + for (auto func : module.getOps()) { + if (!isPTOEntryFunction(func)) + continue; + if (func.getFunctionType().getNumResults() != 0) { + return func.emitOpError() + << "PTO entry functions must return void"; + } + } + return success(); +} + +void mlir::pto::annotatePTOEntryFunctions(ModuleOp module) { + if (!module) + return; + + SmallVector defs = getPTOFunctionDefinitions(module); + for (auto func : module.getOps()) + func->removeAttr(kEffectivePTOEntryAttrName); + + if (defs.empty()) + return; + if (defs.size() == 1) { + defs.front()->setAttr(kEffectivePTOEntryAttrName, + 
BoolAttr::get(module.getContext(), true)); + return; + } + + for (auto func : defs) { + func->setAttr(kEffectivePTOEntryAttrName, + BoolAttr::get(module.getContext(), + hasExplicitPTOEntryAttr(func))); + } +} + +//===----------------------------------------------------------------------===// +// PTO Load/Store/Addf (non-DPS polymorphic) verification + inference. +// - If operands are memref/tensor: verify strictly. +// - Otherwise (tile_view/tile etc): accept (so old IR can still parse). +//===----------------------------------------------------------------------===// + +[[maybe_unused]] static LogicalResult verifyMemrefToTensorLoad(Operation *op, Value src, Value res) { + auto mr = dyn_cast(src.getType()); + auto rt = dyn_cast(res.getType()); + if (!mr) + return success(); // non-memref case: don't block old IR + if (!rt) + return op->emitOpError("when src is memref, result must be ranked tensor"); + + if (mr.getElementType() != rt.getElementType()) + return op->emitOpError() << "memref/tensor element type mismatch: memref=" + << mr.getElementType() << " tensor=" << rt.getElementType(); + + if (mr.getRank() != rt.getRank()) + return op->emitOpError() << "rank mismatch: memref rank=" << mr.getRank() + << " tensor rank=" << rt.getRank(); + + if (mr.hasStaticShape()) { + if (!rt.hasStaticShape()) + return op->emitOpError("memref has static shape but result tensor is not static"); + if (mr.getShape() != rt.getShape()) + return op->emitOpError() << "shape mismatch: memref=" << mr << " tensor=" << rt; + } else { + // For dynamic memref dims: if tensor dim is static, allow it; if it's dynamic too, also fine. + // We only reject when a memref static dim conflicts with tensor static dim. 
+ for (int64_t i = 0; i < mr.getRank(); ++i) { + int64_t md = mr.getDimSize(i); + int64_t td = rt.getDimSize(i); + if (md != ShapedType::kDynamic && td != ShapedType::kDynamic && md != td) + return op->emitOpError() << "dim mismatch at " << i << ": memref=" << md << " tensor=" << td; + } + } + return success(); +} + +[[maybe_unused]] static LogicalResult verifyMemrefTensorStore(Operation *op, Value dst, Value src) { + auto mr = dyn_cast(dst.getType()); + if (!mr) + return success(); // non-memref case: old tile IR allowed + auto rt = dyn_cast(src.getType()); + if (!rt) + return op->emitOpError("when dst is memref, src must be ranked tensor"); + + if (mr.getElementType() != rt.getElementType()) + return op->emitOpError() << "memref/tensor element type mismatch: memref=" + << mr.getElementType() << " tensor=" << rt.getElementType(); + + if (mr.getRank() != rt.getRank()) + return op->emitOpError() << "rank mismatch: memref rank=" << mr.getRank() + << " tensor rank=" << rt.getRank(); + + for (int64_t i = 0; i < mr.getRank(); ++i) { + int64_t md = mr.getDimSize(i); + int64_t td = rt.getDimSize(i); + if (md != ShapedType::kDynamic && td != ShapedType::kDynamic && md != td) + return op->emitOpError() << "dim mismatch at " << i << ": memref=" << md << " tensor=" << td; + } + return success(); +} + +LogicalResult AllocTileOp::verify() { + auto ty = getResult().getType(); // TileBufType + + if (failed(verifyTileBufLayoutConstraints(*this, ty, "result"))) + return failure(); + + // op 上有没有传 operands + bool hasVR = getValidRow() != nullptr; + bool hasVC = getValidCol() != nullptr; + + // type 上的 validShape + auto vs = ty.getValidShape(); + if (vs.size() != 2) + return emitOpError("result tile_buf must have rank-2 validShape"); + + // TileBuf valid dims use a negative sentinel (e.g. '?' / -1). Be robust to + // any negative value (some code may materialize MLIR dynamic sentinels). + bool needVR = (vs[0] < 0); + bool needVC = (vs[1] < 0); + + // 你要求的:v_row=?, v_col=? 
时必须同时给两个 + // (这条规则由下面两句自然实现) + if (hasVR != needVR) + return emitOpError() << "valid_row operand " + << (needVR ? "is required" : "must be absent") + << " because result type v_row is " + << (needVR ? "?" : std::to_string(vs[0])); + + if (hasVC != needVC) + return emitOpError() << "valid_col operand " + << (needVC ? "is required" : "must be absent") + << " because result type v_col is " + << (needVC ? "?" : std::to_string(vs[1])); + + return success(); +} + +LogicalResult MaterializeTileOp::verify() { + auto sourceTy = cast(getSource().getType()); + auto resultTy = cast(getResult().getType()); + + if (sourceTy.getRank() != 2) + return emitOpError("source memref must be rank-2 to materialize a tile handle"); + if (resultTy.getRank() != 2) + return emitOpError("result tile_buf must be rank-2"); + if (failed(verifyTileBufLayoutConstraints(*this, resultTy, "result"))) + return failure(); + + auto viewSemantics = (*this)->getAttrOfType("pto.view_semantics"); + bool isSubview = viewSemantics && viewSemantics.getValue() == "subview"; + if (!isSubview && sourceTy.getShape() != resultTy.getShape()) + return emitOpError() << "source/result shape mismatch: source=" + << sourceTy << " result=" << resultTy; + + if (sourceTy.getElementType() != resultTy.getElementType()) + return emitOpError() << "source/result element type mismatch: source=" + << sourceTy.getElementType() + << " result=" << resultTy.getElementType(); + + if (sourceTy.getMemorySpace() != resultTy.getMemorySpace()) + return emitOpError() << "source/result memory space mismatch"; + + if (getConfig() != resultTy.getConfigAttr()) + return emitOpError("config attribute must match the result tile_buf config"); + + auto shape = resultTy.getShape(); + auto validShape = resultTy.getValidShape(); + if (validShape.size() != 2) + return emitOpError("result tile_buf must have rank-2 validShape"); + for (unsigned i = 0; i < 2; ++i) { + if (shape[i] != ShapedType::kDynamic && + validShape[i] != ShapedType::kDynamic && 
validShape[i] > shape[i]) { + return emitOpError() << "valid_shape[" << i << "] must be <= shape[" + << i << "]"; + } + } + + return success(); +} + +LogicalResult TAssignOp::verify() { + if (getTile().getType() != getResult().getType()) { + return emitOpError("result type must match tile operand type"); + } + return success(); +} + +LogicalResult TLoadOp::verify() { + auto verifyCommon = + [&](bool allowLowPrecision) + -> FailureOr> { + auto srcPart = dyn_cast(getSrc().getType()); + auto dstTile = dyn_cast(getDst().getType()); + if (!srcPart || !dstTile) { + emitOpError("expects src to be !pto.partition_tensor_view and dst to be !pto.tile_buf"); + return failure(); + } + if (failed(verifyTileBufCommon(*this, dstTile, "dst", allowLowPrecision))) + return failure(); + + auto srcShape = srcPart.getShape(); + for (unsigned i = 0; i < srcShape.size(); ++i) { + if (srcShape[i] != ShapedType::kDynamic && srcShape[i] <= 0) { + emitOpError() << "expects src shape[" << i << "] to be positive"; + return failure(); + } + } + auto dstValid = dstTile.getValidShape(); + for (unsigned i = 0; i < dstValid.size(); ++i) { + if (dstValid[i] != ShapedType::kDynamic && dstValid[i] <= 0) { + emitOpError() << "expects dst valid_shape[" << i << "] to be positive"; + return failure(); + } + } + return std::make_pair(srcPart, dstTile); + }; + + auto verifyA2A3 = [&]() -> LogicalResult { + auto common = verifyCommon(/*allowLowPrecision=*/false); + if (failed(common)) + return failure(); + auto [srcPart, dstTile] = *common; + + Type srcElem = srcPart.getElementType(); + Type dstElem = dstTile.getElementType(); + if (isPTOLowPrecisionType(srcElem) || isPTOLowPrecisionType(dstElem)) + return emitOpError("expects A2/A3 tload low-precision element types to be unsupported"); + if (!(dstElem.isInteger(8) || dstElem.isInteger(16) || dstElem.isInteger(32) || + dstElem.isInteger(64) || dstElem.isF16() || dstElem.isBF16() || dstElem.isF32())) + return emitOpError("expects A2/A3 tload dst element type 
to be i8/i16/i32/i64/u64/f16/bf16/f32"); + + auto dstSpace = getPTOMemorySpaceEnum(dstTile); + if (!dstSpace || (*dstSpace != pto::AddressSpace::VEC && + *dstSpace != pto::AddressSpace::MAT)) + return emitOpError("expects A2/A3 tload dst to use loc=vec or loc=mat"); + + if (getElemByteSize(srcElem) != getElemByteSize(dstElem)) + return emitOpError("expects src and dst element types to have the same bitwidth"); + return success(); + }; + + auto verifyA5 = [&]() -> LogicalResult { + auto common = verifyCommon(/*allowLowPrecision=*/true); + if (failed(common)) + return failure(); + auto [srcPart, dstTile] = *common; + + Type srcElem = srcPart.getElementType(); + Type dstElem = dstTile.getElementType(); + unsigned srcBytes = getElemByteSize(srcElem); + unsigned dstBytes = getElemByteSize(dstElem); + if (srcBytes != dstBytes) + return emitOpError("expects src and dst element types to have the same element size"); + if (!(dstBytes == 1 || dstBytes == 2 || dstBytes == 4 || dstBytes == 8)) + return emitOpError("expects A5 tload dst element size to be 1, 2, 4, or 8 bytes"); + if (!isA5TLoadStoreTransferElemType(srcElem)) + return emitOpError("expects A5 tload src element type to be i8/i16/i32/i64/f16/bf16/f32/f8/hif8/fp4"); + if (!isA5TLoadStoreTransferElemType(dstElem)) + return emitOpError("expects A5 tload dst element type to be i8/i16/i32/i64/f16/bf16/f32/f8/hif8/fp4"); + + if (dstElem.isInteger(64)) { + auto pad = dstTile.getPadValueI32(); + if (pad != static_cast(pto::PadValue::Null) && + pad != static_cast(pto::PadValue::Zero)) + return emitOpError("expects A5 i64/u64 tload dst pad to be null or zero"); + } + return success(); + }; + + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +LogicalResult TPrefetchOp::verify() { + auto verifyImpl = [&](bool allowLowPrecision) -> LogicalResult { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + + Type srcElem; + Type dstElem; + + if (auto srcPart = dyn_cast(srcTy)) { + auto 
srcShape = srcPart.getShape(); + for (unsigned i = 0; i < srcShape.size(); ++i) { + if (srcShape[i] != ShapedType::kDynamic && srcShape[i] <= 0) + return emitOpError() << "expects src shape[" << i << "] to be positive"; + } + srcElem = srcPart.getElementType(); + } else if (auto srcMr = dyn_cast(srcTy)) { + if (!srcMr.hasRank()) + return emitOpError("expects src memref to be ranked"); + for (int64_t dim : srcMr.getShape()) { + if (dim != ShapedType::kDynamic && dim <= 0) + return emitOpError("expects src memref shape to be positive"); + } + srcElem = srcMr.getElementType(); + } else { + return emitOpError("expects src to be !pto.partition_tensor_view or memref"); + } + + if (auto dstTile = dyn_cast(dstTy)) { + if (failed(verifyTileBufCommon(*this, dstTile, "dst", allowLowPrecision))) + return failure(); + auto dstValid = dstTile.getValidShape(); + for (unsigned i = 0; i < dstValid.size(); ++i) { + if (dstValid[i] != ShapedType::kDynamic && dstValid[i] <= 0) + return emitOpError() << "expects dst valid_shape[" << i << "] to be positive"; + } + auto dstSpace = getPTOMemorySpaceEnum(dstTile); + if (!dstSpace || (*dstSpace != pto::AddressSpace::VEC && + *dstSpace != pto::AddressSpace::MAT)) + return emitOpError("expects dst to use loc=vec or loc=mat"); + dstElem = dstTile.getElementType(); + } else if (auto dstMr = dyn_cast(dstTy)) { + auto dstSpace = getPTOMemorySpaceEnum(dstMr); + if (!dstSpace || (*dstSpace != pto::AddressSpace::VEC && + *dstSpace != pto::AddressSpace::MAT)) + return emitOpError("expects dst memref to use loc=vec or loc=mat"); + if (!dstMr.hasRank()) + return emitOpError("expects dst memref to be ranked"); + if (failed(verifyTileBufCommon(*this, dstMr, "dst", allowLowPrecision))) + return failure(); + dstElem = dstMr.getElementType(); + } else { + return emitOpError("expects dst to be !pto.tile_buf or memref"); + } + + if (getElemByteSize(srcElem) != getElemByteSize(dstElem)) + return emitOpError("expects src and dst element types to have the same 
element size"); + if (!allowLowPrecision && + (isPTOLowPrecisionType(srcElem) || isPTOLowPrecisionType(dstElem))) + return emitOpError("expects A2/A3 tprefetch low-precision element types to be unsupported"); + if (allowLowPrecision && + (!isA5TLoadStoreTransferElemType(srcElem) || + !isA5TLoadStoreTransferElemType(dstElem))) + return emitOpError("expects A5 tprefetch element types to be i8/i16/i32/i64/f16/bf16/f32/f8/hif8/fp4"); + return success(); + }; + + auto verifyA2A3 = [&]() -> LogicalResult { + return verifyImpl(/*allowLowPrecision=*/false); + }; + auto verifyA5 = [&]() -> LogicalResult { + return verifyImpl(/*allowLowPrecision=*/true); + }; + switch (getVerifierTargetArch(getOperation())) { + case VerifierTargetArch::A2A3: + return verifyA2A3(); + case VerifierTargetArch::A5: + return verifyA5(); + } + return failure(); +} + +LogicalResult MakePrefetchAsyncContextOp::verify() { + Type workspaceTy = getWorkspace().getType(); + Type elemTy = nullptr; + if (auto ptrTy = dyn_cast(workspaceTy)) { + elemTy = ptrTy.getElementType(); + } else if (auto memTy = dyn_cast(workspaceTy)) { + if (!isGmAddressSpaceAttr(memTy.getMemorySpace())) + return emitOpError("expects workspace memref to be in GM address space"); + elemTy = memTy.getElementType(); + } else { + return emitOpError("expects workspace to be !pto.ptr or GM memref"); + } + if (!isByteIntegerType(elemTy)) + return emitOpError("expects workspace element type to be an 8-bit integer"); + return success(); +} + +LogicalResult TPrefetchAsyncOp::verify() { + if (failed(verifyAsyncFlatContiguous1DGMViewLike(getOperation(), getSrc(), + "src"))) + return failure(); + return success(); +} + +LogicalResult mlir::pto::SetFFTsOp::verify() { + auto mr = llvm::dyn_cast(getFfts().getType()); + if (!mr) + return emitOpError("expects a memref operand"); + + if (!mr.getElementType().isInteger(64) && !mr.getElementType().isInteger(8)) + return emitOpError("expects element type i64 (or i8)"); + + return mlir::success(); +} + 
+ParseResult mlir::pto::SyncSetOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseSyncEventOpCommon(parser, result, + SyncSetOp::getPipeAttrName(result.name), + SyncSetOp::getEventIdAttrName(result.name)); +} + +void mlir::pto::SyncSetOp::print(OpAsmPrinter &p) { + printSyncEventOpCommon(p, getOperation(), getPipe(), getEventIdAttr(), + getEventIdDyn(), getPipeAttrName().getValue(), + getEventIdAttrName().getValue()); +} + +LogicalResult mlir::pto::SyncSetOp::verify() { + bool hasStatic = getEventIdAttr() != nullptr; + bool hasDynamic = static_cast(getEventIdDyn()); + if (hasStatic == hasDynamic) + return emitOpError() + << "expects exactly one event-id form: static attr or dynamic index operand"; + if (IntegerAttr fftsModeAttr = getFftsModeAttr()) { + int64_t fftsMode = fftsModeAttr.getInt(); + if (fftsMode < 0 || fftsMode > 2) + return emitOpError() << "requires ffts_mode in range [0, 2], but got " + << fftsMode; + } + + auto verifyA2A3 = [&]() -> LogicalResult { return success(); }; + auto verifyA5 = [&]() -> LogicalResult { + switch (getPipe().getPipe()) { + case PIPE::PIPE_FIX: + case PIPE::PIPE_MTE3: + return success(); + default: + return emitOpError() + << "A5 sync.set expects pipe to be one of , "; + } + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +ParseResult mlir::pto::SyncWaitOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseSyncEventOpCommon(parser, result, + SyncWaitOp::getPipeAttrName(result.name), + SyncWaitOp::getEventIdAttrName(result.name)); +} + +void mlir::pto::SyncWaitOp::print(OpAsmPrinter &p) { + printSyncEventOpCommon(p, getOperation(), getPipe(), getEventIdAttr(), + getEventIdDyn(), getPipeAttrName().getValue(), + getEventIdAttrName().getValue()); +} + +ParseResult mlir::pto::SyncAllOp::parse(OpAsmParser &parser, + OperationState &result) { + SmallVector operands; + SmallVector operandTypes; + Attribute modeAttr; + Attribute coreTypeAttr; + + if 
(parser.parseLParen()) + return failure(); + + if (failed(parser.parseOptionalRParen())) { + if (parser.parseOperandList(operands) || parser.parseColonTypeList(operandTypes) || + parser.parseRParen()) + return failure(); + if (operands.size() != operandTypes.size()) + return parser.emitError(parser.getCurrentLocation()) + << "expects the same number of operands and operand types"; + } + + if (parser.parseKeyword("mode") || parser.parseEqual() || + parser.parseAttribute(modeAttr) || parser.parseComma() || + parser.parseKeyword("core_type") || parser.parseEqual() || + parser.parseAttribute(coreTypeAttr)) + return failure(); + + auto mode = dyn_cast(modeAttr); + if (!mode) + return parser.emitError(parser.getCurrentLocation()) + << "expects mode to be #pto.sync_all_mode<...>"; + + auto coreType = dyn_cast(coreTypeAttr); + if (!coreType) + return parser.emitError(parser.getCurrentLocation()) + << "expects core_type to be #pto.sync_core_type<...>"; + + result.addAttribute("mode", mode); + result.addAttribute("core_type", coreType); + + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + auto addSegmentSizes = [&](int32_t gm, int32_t ub, int32_t l1, + int32_t used) { + result.addAttribute("operandSegmentSizes", + parser.getBuilder().getDenseI32ArrayAttr( + {gm, ub, l1, used})); + }; + + switch (mode.getValue()) { + case pto::SyncAllMode::Hard: + if (!operands.empty()) + return parser.emitError(parser.getCurrentLocation()) + << "expects hard syncall to have no operands"; + addSegmentSizes(0, 0, 0, 0); + return success(); + case pto::SyncAllMode::Soft: + break; + } + + switch (coreType.getValue()) { + case pto::SyncCoreType::AIVOnly: + if (operands.size() != 2 && operands.size() != 3) + return parser.emitError(parser.getCurrentLocation()) + << "expects soft AIV-only syncall to have gm_workspace, " + "ub_workspace, and optional used_cores"; + if (parser.resolveOperand(operands[0], operandTypes[0], result.operands) || + 
parser.resolveOperand(operands[1], operandTypes[1], result.operands)) + return failure(); + if (operands.size() == 3 && + parser.resolveOperand(operands[2], operandTypes[2], result.operands)) + return failure(); + addSegmentSizes(1, 1, 0, operands.size() == 3 ? 1 : 0); + return success(); + case pto::SyncCoreType::AICOnly: + if (operands.size() != 2 && operands.size() != 3) + return parser.emitError(parser.getCurrentLocation()) + << "expects soft AIC-only syncall to have gm_workspace, " + "l1_workspace, and optional used_cores"; + if (parser.resolveOperand(operands[0], operandTypes[0], result.operands) || + parser.resolveOperand(operands[1], operandTypes[1], result.operands)) + return failure(); + if (operands.size() == 3 && + parser.resolveOperand(operands[2], operandTypes[2], result.operands)) + return failure(); + addSegmentSizes(1, 0, 1, operands.size() == 3 ? 1 : 0); + return success(); + case pto::SyncCoreType::Mix: + if (operands.size() != 3 && operands.size() != 4) + return parser.emitError(parser.getCurrentLocation()) + << "expects soft mixed syncall to have gm_workspace, " + "ub_workspace, l1_workspace, and optional used_cores"; + if (parser.resolveOperand(operands[0], operandTypes[0], result.operands) || + parser.resolveOperand(operands[1], operandTypes[1], result.operands) || + parser.resolveOperand(operands[2], operandTypes[2], result.operands)) + return failure(); + if (operands.size() == 4 && + parser.resolveOperand(operands[3], operandTypes[3], result.operands)) + return failure(); + addSegmentSizes(1, 1, 1, operands.size() == 4 ? 
1 : 0); + return success(); + } + + llvm_unreachable("unhandled SyncCoreType"); +} + +void mlir::pto::SyncAllOp::print(OpAsmPrinter &p) { + SmallVector operands; + if (getGmWorkspace()) + operands.push_back(getGmWorkspace()); + if (getUbWorkspace()) + operands.push_back(getUbWorkspace()); + if (getL1Workspace()) + operands.push_back(getL1Workspace()); + if (getUsedCores()) + operands.push_back(getUsedCores()); + + p << "("; + if (!operands.empty()) { + p.printOperands(operands); + p << " : "; + llvm::interleaveComma(operands, p, + [&](Value operand) { p.printType(operand.getType()); }); + } + p << ") mode = " << getMode() << ", core_type = " << getCoreType(); + p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"operandSegmentSizes", "mode", + "core_type"}); +} + +LogicalResult mlir::pto::SyncWaitOp::verify() { + bool hasStatic = getEventIdAttr() != nullptr; + bool hasDynamic = static_cast(getEventIdDyn()); + if (hasStatic == hasDynamic) + return emitOpError() + << "expects exactly one event-id form: static attr or dynamic index operand"; + + auto verifyA2A3 = [&]() -> LogicalResult { return success(); }; + auto verifyA5 = [&]() -> LogicalResult { + switch (getPipe().getPipe()) { + case PIPE::PIPE_FIX: + case PIPE::PIPE_MTE1: + case PIPE::PIPE_MTE2: + case PIPE::PIPE_MTE3: + case PIPE::PIPE_V: + return success(); + default: + return emitOpError() << "A5 sync.wait expects pipe to be one of " + ", , , " + ", "; + } + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +LogicalResult TStoreOp::verify() { + auto verifyCommon = + [&](bool allowLowPrecision) + -> FailureOr> { + auto srcTile = dyn_cast(getSrc().getType()); + auto dstPart = dyn_cast(getDst().getType()); + if (!srcTile || !dstPart) { + emitOpError("expects src to be !pto.tile_buf and dst to be !pto.partition_tensor_view"); + return failure(); + } + if (failed(verifyTileBufCommon(*this, srcTile, "src", allowLowPrecision))) + return failure(); + for (auto [idx, dim] : 
llvm::enumerate(dstPart.getShape())) { + if (dim != ShapedType::kDynamic && dim <= 0) { + emitOpError() << "expects dst shape[" << idx << "] to be positive"; + return failure(); + } + } + auto srcValid = srcTile.getValidShape(); + for (auto [idx, dim] : llvm::enumerate(srcValid)) { + if (dim != ShapedType::kDynamic && dim <= 0) { + emitOpError() << "expects src valid_shape[" << idx << "] to be positive"; + return failure(); + } + } + + // Keep TSTORE contract explicit while preserving existing legal layout + // reinterpretation paths (e.g. 1x1024 <-> 32x32, 5D partition views). + // When both sides are fully static, require equal element counts between + // dst shape and src valid_shape. + auto getStaticElemCount = [](ArrayRef shape) -> std::optional { + int64_t total = 1; + for (int64_t dim : shape) { + if (dim == ShapedType::kDynamic) + return std::nullopt; + if (dim <= 0) + return std::nullopt; + if (total > std::numeric_limits::max() / dim) + return std::nullopt; + total *= dim; + } + return total; + }; + + auto dstElemCount = getStaticElemCount(dstPart.getShape()); + auto srcValidElemCount = getStaticElemCount(srcValid); + if (dstElemCount && srcValidElemCount && *dstElemCount != *srcValidElemCount) { + emitOpError() << "expects dst static element count (" << *dstElemCount + << ") to match src valid_shape static element count (" + << *srcValidElemCount << ")"; + return failure(); + } + return std::make_pair(srcTile, dstPart); + }; + + auto isLoadStoreElemType = [&](Type ty) -> bool { + return ty.isInteger(8) || ty.isInteger(16) || ty.isInteger(32) || + ty.isInteger(64) || ty.isF16() || ty.isBF16() || ty.isF32(); + }; + auto isI8Like = [&](Type ty) -> bool { return ty.isInteger(8); }; + bool hasPreQuant = static_cast(getPreQuantScalar()); + auto reluMode = getReluPreMode(); + + auto verifyA2A3 = [&]() -> LogicalResult { + auto common = verifyCommon(/*allowLowPrecision=*/false); + if (failed(common)) + return failure(); + auto [srcTile, dstPart] = *common; + 
auto srcSpace = getPTOMemorySpaceEnum(srcTile); + if (!srcSpace || (*srcSpace != pto::AddressSpace::VEC && + *srcSpace != pto::AddressSpace::MAT && + *srcSpace != pto::AddressSpace::ACC)) + return emitOpError("expects A2/A3 tstore src to use loc=vec, loc=mat, or loc=acc"); + if (hasPreQuant && *srcSpace != pto::AddressSpace::ACC) + return emitOpError("expects preQuantScalar form to use loc=acc src"); + if (reluMode != pto::ReluPreMode::NoRelu && *srcSpace != pto::AddressSpace::ACC) + return emitOpError("expects reluPreMode form to use loc=acc src"); + + Type srcElem = srcTile.getElementType(); + Type dstElem = dstPart.getElementType(); + if (*srcSpace == pto::AddressSpace::VEC || *srcSpace == pto::AddressSpace::MAT) { + if (hasPreQuant) + return emitOpError("expects preQuantScalar form to use loc=acc src"); + if (isPTOLowPrecisionType(dstElem)) + return emitOpError("expects A2/A3 vec/mat tstore low-precision dst element types to be unsupported"); + if (!isLoadStoreElemType(srcElem)) + return emitOpError("expects A2/A3 vec/mat tstore src element type to be i8/i16/i32/i64/u64/f16/bf16/f32"); + if (getElemByteSize(srcElem) != getElemByteSize(dstElem)) + return emitOpError("expects A2/A3 vec/mat tstore src and dst element types to have the same bitwidth"); + return success(); + } + + if (!(srcElem.isInteger(32) || srcElem.isF32())) + return emitOpError("expects A2/A3 acc tstore src element type to be i32 or f32"); + if (hasPreQuant) { + if (srcElem.isInteger(32)) { + if (!(isI8Like(dstElem) || dstElem.isF16())) + return emitOpError("expects A2/A3 acc preQuantScalar tstore dst type to be i8/ui8/f16"); + } else if (srcElem.isF32()) { + if (!isI8Like(dstElem)) + return emitOpError("expects A2/A3 acc preQuantScalar tstore dst type to be i8/ui8"); + } + } else { + if (!(dstElem.isInteger(32) || dstElem.isF32() || dstElem.isF16() || + dstElem.isBF16())) + return emitOpError("expects A2/A3 acc tstore dst element type to be i32/f32/f16/bf16"); + } + + auto srcShape = 
srcTile.getShape(); + if (srcShape[1] != ShapedType::kDynamic && + (srcShape[1] < 1 || srcShape[1] > 4095)) + return emitOpError("expects A2/A3 acc tstore src cols to be in [1, 4095]"); + auto srcValid = srcTile.getValidShape(); + if (srcValid[1] != ShapedType::kDynamic && + (srcValid[1] < 1 || srcValid[1] > 4095)) + return emitOpError("expects A2/A3 acc tstore src valid_shape[1] to be in [1, 4095]"); + return success(); + }; + + auto verifyA5 = [&]() -> LogicalResult { + auto common = verifyCommon(/*allowLowPrecision=*/true); + if (failed(common)) + return failure(); + auto [srcTile, dstPart] = *common; + auto srcSpace = getPTOMemorySpaceEnum(srcTile); + if (!srcSpace || (*srcSpace != pto::AddressSpace::VEC && + *srcSpace != pto::AddressSpace::ACC)) + return emitOpError("expects A5 tstore src to use loc=vec or loc=acc"); + if (hasPreQuant && *srcSpace != pto::AddressSpace::ACC) + return emitOpError("expects preQuantScalar form to use loc=acc src"); + if (reluMode != pto::ReluPreMode::NoRelu && *srcSpace != pto::AddressSpace::ACC) + return emitOpError("expects reluPreMode form to use loc=acc src"); + + Type srcElem = srcTile.getElementType(); + Type dstElem = dstPart.getElementType(); + if (*srcSpace == pto::AddressSpace::VEC) { + if (hasPreQuant) + return emitOpError("expects preQuantScalar form to use loc=acc src"); + if (!isA5TLoadStoreTransferElemType(srcElem)) + return emitOpError("expects A5 vec tstore src element type to be i8/i16/i32/i64/f16/bf16/f32/f8/hif8/fp4"); + if (getElemByteSize(srcElem) != getElemByteSize(dstElem)) + return emitOpError("expects A5 vec tstore src and dst element types to have the same bitwidth"); + return success(); + } + + if (!(srcElem.isInteger(32) || srcElem.isF32())) + return emitOpError("expects A5 acc tstore src element type to be i32 or f32"); + if (hasPreQuant) { + if (!isA5AccStorePreQuantDstType(srcElem, dstElem)) + return emitOpError("expects A5 acc preQuantScalar tstore dst type to be i8/ui8/f16/bf16/f32/hif8/f8E4M3"); 
+ } else { + if (!(dstElem.isInteger(32) || dstElem.isF32() || dstElem.isF16() || + dstElem.isBF16())) + return emitOpError("expects A5 acc tstore dst element type to be i32/f32/f16/bf16"); + } + return success(); + }; + + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +LogicalResult pto::TAbsOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyVecTileCommon(*this, srcTy, "src")) || + failed(verifyVecTileCommon(*this, dstTy, "dst"))) + return failure(); + if (failed(verifyTileBufSameElemType(*this, srcTy, dstTy, "src", "dst")) || + failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) + return failure(); + + Type elemTy; + if (auto tb = dyn_cast(srcTy)) + elemTy = tb.getElementType(); + else if (auto mr = dyn_cast(srcTy)) + elemTy = mr.getElementType(); + if (!(elemTy.isF16() || elemTy.isF32())) + return emitOpError() << "expects element type to be f16 or f32"; + + return success(); +} +// PTO.cpp + +static bool isPTOShapedLike(Type ty) { + return mlir::isa(ty); +} + +static bool isTileLikeType(Type ty) { + return isa(ty); +} + +static Type getElemTy(Type ty) { + if (auto mr = mlir::dyn_cast(ty)) return mr.getElementType(); + if (auto tt = mlir::dyn_cast(ty)) return tt.getElementType(); + if (auto tv = mlir::dyn_cast(ty)) return tv.getElementType(); + if (auto tb = mlir::dyn_cast(ty)) return tb.getElementType(); + if (auto tv = mlir::dyn_cast(ty)) return tv.getElementType(); + return Type(); +} + +static SmallVector getShapeVec(Type ty) { + SmallVector s; + if (auto mr = mlir::dyn_cast(ty)) + return SmallVector(mr.getShape().begin(), mr.getShape().end()); + if (auto tt = mlir::dyn_cast(ty)) + return SmallVector(tt.getShape().begin(), tt.getShape().end()); + if (auto tv = mlir::dyn_cast(ty)) + return SmallVector(tv.getShape().begin(), tv.getShape().end()); + if (auto tb = mlir::dyn_cast(ty)) + 
return SmallVector(tb.getShape().begin(), tb.getShape().end()); + if (auto tv = mlir::dyn_cast(ty)) + return SmallVector(tv.getShape().begin(), tv.getShape().end()); + return {}; +} + +static SmallVector getValidShapeVec(Type ty) { + if (auto tb = dyn_cast(ty)) + return SmallVector(tb.getValidShape().begin(), tb.getValidShape().end()); + return getShapeVec(ty); +} + +static int64_t getLogicalTileDim(int64_t rawDim, Type elemTy, + std::optional blayout, + unsigned dimIdx) { + if (rawDim == ShapedType::kDynamic || !isPTOFloat4PackedType(elemTy)) + return rawDim; + pto::BLayout layout = blayout.value_or(pto::BLayout::RowMajor); + unsigned packedDim = layout == pto::BLayout::ColMajor ? 0 : 1; + return dimIdx == packedDim ? rawDim * 2 : rawDim; +} + +static std::optional getTileBufBLayout(Type ty) { + if (auto tb = dyn_cast(ty)) + return static_cast(tb.getBLayoutValueI32()); + return std::nullopt; +} + +static SmallVector getLogicalTileExtentVec(Type ty, + bool useValidShape) { + SmallVector dims = + useValidShape ? 
getValidShapeVec(ty) : getShapeVec(ty); + if (!isTileLikeType(ty) || dims.size() != 2) + return dims; + + Type elemTy = getElemTy(ty); + auto blayout = getTileBufBLayout(ty); + for (unsigned i = 0; i < dims.size(); ++i) + dims[i] = getLogicalTileDim(dims[i], elemTy, blayout, i); + return dims; +} + +static int64_t getConstantIndexOrDynamic(Value value) { + if (!value) + return ShapedType::kDynamic; + if (auto cst = value.getDefiningOp()) + return cst.value(); + if (auto cst = value.getDefiningOp()) + return cst.value(); + return ShapedType::kDynamic; +} + +static SmallVector getValidShapeVec(Value value) { + if (!value) + return {}; + auto valid = getValidShapeVec(value.getType()); + if (auto bind = value.getDefiningOp()) { + if (valid.size() >= 1 && bind.getValidRow()) + valid[0] = getConstantIndexOrDynamic(bind.getValidRow()); + if (valid.size() >= 2 && bind.getValidCol()) + valid[1] = getConstantIndexOrDynamic(bind.getValidCol()); + } + return valid; +} + +static SmallVector getMatmulLogicalShapeVec(Type ty) { + auto shape = getShapeVec(ty); + auto valid = getValidShapeVec(ty); + if (!isa(ty) || shape.size() != valid.size()) + return shape; + + for (size_t i = 0, e = shape.size(); i < e; ++i) { + if (valid[i] != ShapedType::kDynamic) + shape[i] = valid[i]; + } + return shape; +} + +static bool isByteIntegerType(Type ty) { + auto intTy = dyn_cast(ty); + return intTy && intTy.getWidth() == 8; +} + +static LogicalResult verifyAsyncFlatContiguous1DGMMemRef(Operation *op, + Value value, + StringRef name) { + auto memTy = dyn_cast(value.getType()); + if (!memTy) + return op->emitOpError() << "expects " << name << " to be a memref"; + if (!memTy.hasRank()) + return op->emitOpError() << "expects " << name << " to be a ranked memref"; + if (!isGmAddressSpaceAttr(memTy.getMemorySpace())) + return op->emitOpError() << "expects " << name + << " to be in GM address space"; + + ArrayRef shape = memTy.getShape(); + if (shape.empty()) + return op->emitOpError() << "expects " << 
name + << " to have rank >= 1"; + for (int64_t dim : shape) { + if (dim == ShapedType::kDynamic) + return op->emitOpError() << "expects " << name + << " to have a static shape"; + } + + SmallVector strides; + int64_t offset = 0; + if (failed(getStridesAndOffset(memTy, strides, offset))) + return op->emitOpError() << "expects " << name + << " to be a strided memref with a known layout"; + + bool hasDynamicLayout = + offset == ShapedType::kDynamic || + llvm::any_of(strides, [](int64_t stride) { + return stride == ShapedType::kDynamic; + }); + if (hasDynamicLayout) + return success(); + + bool packed = !strides.empty() && strides.back() == 1; + for (int i = static_cast(shape.size()) - 2; i >= 0 && packed; --i) + packed &= strides[i] == strides[i + 1] * shape[i + 1]; + if (!packed) + return op->emitOpError() + << "expects " << name + << " to be a static flat contiguous logical 1D GM memref"; + + bool logical1D = true; + for (int i = 0, e = static_cast(shape.size()) - 1; i < e; ++i) + logical1D &= shape[i] == 1; + if (!logical1D) + return op->emitOpError() + << "expects " << name + << " to be a static flat contiguous logical 1D GM memref"; + + return success(); +} + +static LogicalResult verifyAsyncFlatContiguous1DGMViewLike(Operation *op, + Value value, + StringRef name) { + Type ty = value.getType(); + if (isa(ty)) + return verifyAsyncFlatContiguous1DGMMemRef(op, value, name); + + if (!isa(ty)) + return op->emitOpError() << "expects " << name + << " to be a memref/tensor_view/partition_view"; + + SmallVector shape = getShapeVec(ty); + if (shape.empty()) + return op->emitOpError() << "expects " << name << " to have rank >= 1"; + for (int64_t dim : shape) { + if (dim == ShapedType::kDynamic) + return op->emitOpError() << "expects " << name + << " to have a static shape"; + } + + bool logical1D = true; + for (int i = 0, e = static_cast(shape.size()) - 1; i < e; ++i) + logical1D &= shape[i] == 1; + if (!logical1D) + return op->emitOpError() + << "expects " << name + << " 
to be a static flat contiguous logical 1D GM view"; + + return success(); +} + +static bool isCommGlobalLikeType(Type ty) { + if (auto memTy = dyn_cast(ty)) + return isGmAddressSpaceAttr(memTy.getMemorySpace()); + return isa(ty); +} + +static LogicalResult verifyCommGlobalLike(Operation *op, Value value, + StringRef name) { + Type ty = value.getType(); + if (!isCommGlobalLikeType(ty)) + return op->emitOpError() << "expects " << name + << " to be a GM memref/tensor_view/partition_view"; + + SmallVector shape = getShapeVec(ty); + if (shape.empty()) + return op->emitOpError() << "expects " << name << " to have rank >= 1"; + for (int64_t dim : shape) { + if (dim == ShapedType::kDynamic || dim <= 0) + return op->emitOpError() << "expects " << name + << " to have a positive static shape"; + } + return success(); +} + +static LogicalResult verifyCommSignalLike(Operation *op, Value value, + StringRef name) { + if (failed(verifyCommGlobalLike(op, value, name))) + return failure(); + Type elemTy = getElemTy(value.getType()); + if (!elemTy || !elemTy.isSignlessInteger(32)) + return op->emitOpError() << "expects " << name + << " element type to be i32"; + return success(); +} + +static LogicalResult verifyCommStagingTileLike(Operation *op, Value value, + StringRef name) { + Type ty = value.getType(); + if (!isa(ty)) + return op->emitOpError() << "expects " << name + << " to be a tile_buf or memref tile"; + auto as = getPTOMemorySpaceEnum(ty); + if (!as || *as != pto::AddressSpace::VEC) + return op->emitOpError() << "expects " << name + << " to be in vec address space"; + SmallVector shape = getShapeVec(ty); + if (shape.empty()) + return op->emitOpError() << "expects " << name << " to have rank >= 1"; + for (int64_t dim : shape) { + if (dim == ShapedType::kDynamic || dim <= 0) + return op->emitOpError() << "expects " << name + << " to have a positive static shape"; + } + return success(); +} + +static LogicalResult verifyCommGlobalGroup(Operation *op, ValueRange group, + 
StringRef name) { + if (group.empty()) + return op->emitOpError() << "expects at least one " << name << " operand"; + Type groupTy = group.front().getType(); + for (auto it : llvm::enumerate(group)) { + if (failed(verifyCommGlobalLike(op, it.value(), + (name + "[" + Twine(it.index()) + "]").str()))) + return failure(); + if (it.value().getType() != groupTy) + return op->emitOpError() << "expects all " << name + << " operands to have identical types"; + } + return success(); +} + +static LogicalResult verifyCommPingPongSameType(Operation *op, Value ping, + Value pong, StringRef pingName, + StringRef pongName) { + if (!pong) + return success(); + if (failed(verifyCommStagingTileLike(op, ping, pingName)) || + failed(verifyCommStagingTileLike(op, pong, pongName))) + return failure(); + if (ping.getType() != pong.getType()) + return op->emitOpError() << "expects " << pingName << " and " << pongName + << " to have identical types"; + return success(); +} + +static std::optional getStaticByteSize(Type ty) { + SmallVector shape = getShapeVec(ty); + if (shape.empty()) + return std::nullopt; + for (int64_t dim : shape) { + if (dim == ShapedType::kDynamic || dim < 0) + return std::nullopt; + } + + Type elemTy = getElemTy(ty); + uint64_t elemBytes = getElemByteSize(elemTy); + if (elemBytes == 0) + return std::nullopt; + + uint64_t total = elemBytes; + for (int64_t dim : shape) { + total *= static_cast(dim); + } + return total; +} + +static std::optional getPTOMemorySpaceEnum(Type ty) { + if (auto tb = dyn_cast(ty)) { + if (auto as = dyn_cast_or_null(tb.getMemorySpace())) + return as.getAddressSpace(); + return std::nullopt; + } + if (auto mr = dyn_cast(ty)) { + if (auto as = dyn_cast_or_null(mr.getMemorySpace())) + return as.getAddressSpace(); + if (!mr.getMemorySpace()) + return pto::AddressSpace::GM; + } + return std::nullopt; +} + +[[maybe_unused]] static bool isRank2TileBuf(Type ty) { + auto tb = dyn_cast(ty); + return tb && tb.getRank() == 2 && tb.getValidShape().size() 
== 2; +} + +static bool isSupportedVecElemType(Type ty, bool allowBf16, + bool allowInt8) { + if (ty.isF16() || ty.isF32()) + return true; + if (allowBf16 && ty.isBF16()) + return true; + if (auto it = dyn_cast(ty)) { + switch (it.getWidth()) { + case 32: + case 16: + return true; + case 8: + return allowInt8; + default: + return false; + } + } + return false; +} + +static bool isSupportedMGatherMScatterIndexElemType(Type ty) { + auto it = dyn_cast(ty); + if (!it || it.getWidth() != 32) + return false; + return true; +} + +static bool isSupportedMGatherMScatterPayloadElemType(Operation *op, Type ty) { + if (isSupportedVecElemType(ty, /*allowBf16=*/true, /*allowInt8=*/true)) + return true; + if (!isTargetArchA5(op)) + return false; + return ty.isFloat8E4M3() || ty.isFloat8E4M3FN() || ty.isFloat8E4M3FNUZ() || + ty.isFloat8E4M3B11FNUZ() || ty.isFloat8E5M2() || ty.isFloat8E5M2FNUZ(); +} + +static bool isSupportedMScatterAtomicPayloadElemType(Type ty, + pto::ScatterAtomicOp atomic) { + auto intTy = dyn_cast(ty); + switch (atomic) { + case pto::ScatterAtomicOp::None: + return true; + case pto::ScatterAtomicOp::Add: + return ty.isF16() || ty.isF32() || + (intTy && intTy.getWidth() == 32); + case pto::ScatterAtomicOp::Max: + case pto::ScatterAtomicOp::Min: + return ty.isF32() || + (intTy && intTy.getWidth() == 32); + } + llvm_unreachable("Unknown ScatterAtomicOp"); +} + +static LogicalResult verifyMGatherMScatterMemOperand(Operation *op, + Value memValue, + Type dataElemTy, + StringRef dataOperandLabel) { + Type memTy = memValue.getType(); + Type memElem = getElemTy(memTy); + if (!memElem || memElem != dataElemTy) + return op->emitOpError() << "expects mem element type to match " + << dataOperandLabel << " element type"; + + if (isa(memTy)) { + if (auto layout = getLogicalViewLayout(memValue)) { + if (*layout != pto::Layout::ND) + return op->emitOpError( + "expects mem partition view to use ND logical layout when layout " + "can be inferred"); + } + return success(); + } + 
+ if (auto mr = dyn_cast(memTy)) { + auto as = getPTOMemorySpaceEnum(mr); + if (!as || (*as != pto::AddressSpace::GM && + *as != pto::AddressSpace::Zero)) + return op->emitOpError( + "expects mem memref to use GM or zero address space"); + if (mr.getRank() == 5) { + auto shape = mr.getShape(); + bool allStatic = true; + for (int64_t d : shape) + if (d == ShapedType::kDynamic) + allStatic = false; + if (allStatic && (shape[0] != 1 || shape[1] != 1 || shape[2] != 1)) + return op->emitOpError( + "expects rank-5 GM memref leading dimensions to be [1,1,1,...] " + "(GlobalTensor table shape)"); + } + return success(); + } + + return op->emitOpError( + "expects mem to be !pto.partition_tensor_view or a GM/ZERO memref"); +} + +static bool hasCompatibleKnownExtent(int64_t lhs, int64_t rhs); +static bool isKnownUnitExtent(int64_t value); + +static LogicalResult verifyMGatherMScatterTileShape(Operation *op, Type dataTy, + Type idxTy, + StringRef dataName) { + auto dataValid = getValidShapeVec(dataTy); + auto idxValid = getValidShapeVec(idxTy); + if (dataValid.size() != 2 || idxValid.size() != 2) + return op->emitOpError() << "expects " << dataName + << " and idx to have rank-2 valid_shape"; + + auto idxTile = dyn_cast(idxTy); + if (!idxTile) + return op->emitOpError("expects idx to be a tile_buf type"); + + const bool idxRowMajor = + idxTile.getBLayoutValueI32() == + static_cast(pto::BLayout::RowMajor); + const bool idxColMajor = + idxTile.getBLayoutValueI32() == + static_cast(pto::BLayout::ColMajor); + + const bool rowCoalesce1xR = + idxRowMajor && isKnownUnitExtent(idxValid[0]) && + hasCompatibleKnownExtent(idxValid[1], dataValid[0]); + const bool rowCoalesceRx1 = + idxColMajor && hasCompatibleKnownExtent(idxValid[0], dataValid[0]) && + isKnownUnitExtent(idxValid[1]); + const bool elemCoalesce = + hasCompatibleKnownExtent(idxValid[0], dataValid[0]) && + hasCompatibleKnownExtent(idxValid[1], dataValid[1]); + + if (!(rowCoalesce1xR || rowCoalesceRx1 || elemCoalesce)) + return 
op->emitOpError() + << "expects idx valid_shape to be [1, " << dataName + << ".valid_row], [" << dataName + << ".valid_row, 1], or match " << dataName << " valid_shape"; + + return success(); +} + +static LogicalResult verifyMGatherMScatterIdxTile(Operation *op, Type ty, + StringRef name) { + if (failed(verifyTileBufCommon(op, ty, name))) + return failure(); + auto as = getPTOMemorySpaceEnum(ty); + if (!as || *as != pto::AddressSpace::VEC) + return op->emitOpError() << "expects " << name + << " to be in the vec address space"; + auto tb = dyn_cast(ty); + if (!tb) + return op->emitOpError() << "expects " << name << " to be a tile_buf type"; + int32_t blayout = tb.getBLayoutValueI32(); + if (blayout != static_cast(pto::BLayout::RowMajor) && + blayout != static_cast(pto::BLayout::ColMajor)) + return op->emitOpError() << "expects " << name + << " to use row_major or col_major blayout"; + if (tb.getSLayoutValueI32() != static_cast(pto::SLayout::NoneBox)) + return op->emitOpError() << "expects " << name + << " to use the none_box slayout"; + return success(); +} + +static bool isA5TLoadStoreTransferElemType(Type ty) { + return ty.isInteger(8) || ty.isInteger(16) || ty.isInteger(32) || + ty.isInteger(64) || ty.isF16() || ty.isBF16() || ty.isF32() || + isPTOLowPrecisionType(ty); +} + +static bool isA5AccStorePreQuantDstType(Type srcElem, Type dstElem) { + if (srcElem.isInteger(32)) + return dstElem.isInteger(8) || dstElem.isF16() || dstElem.isBF16(); + if (!srcElem.isF32()) + return false; + return dstElem.isInteger(8) || dstElem.isF16() || dstElem.isBF16() || + dstElem.isF32() || isPTOHiFloat8Type(dstElem) || + dstElem.isFloat8E4M3() || dstElem.isFloat8E4M3FN() || + dstElem.isFloat8E4M3FNUZ() || dstElem.isFloat8E4M3B11FNUZ(); +} + +static bool isA5LowPrecisionTCvtPair(Type srcElem, Type dstElem) { + if (srcElem.isF32()) + return isPTOFloat8Type(dstElem) || isPTOHiFloat8Type(dstElem); + if (srcElem.isF16()) + return isPTOHiFloat8Type(dstElem); + if (srcElem.isBF16()) + 
return isPTOFloat4PackedType(dstElem); + if (isPTOFloat4PackedType(srcElem)) + return dstElem.isBF16(); + if (isPTOFloat8Type(srcElem) || isPTOHiFloat8Type(srcElem)) + return dstElem.isF32(); + return false; +} + +static bool isA5SupportedTCvtPair(Type srcElem, Type dstElem) { + if (isPTOLowPrecisionType(srcElem) || isPTOLowPrecisionType(dstElem)) + return isA5LowPrecisionTCvtPair(srcElem, dstElem); + return true; +} + +static LogicalResult verifyTileBufCommon(Operation *op, Type ty, StringRef name, + bool allowLowPrecision) { + auto tb = dyn_cast(ty); + if (tb) { + if (tb.getRank() != 2) + return op->emitOpError() << "expects " << name << " to be a rank-2 tile_buf"; + Type elemTy = tb.getElementType(); + if (!allowLowPrecision && isPTOLowPrecisionType(elemTy)) + return op->emitOpError() << name << ": dtype " << elemTy + << " is not supported by this op yet"; + } else if (auto mr = dyn_cast(ty)) { + if (mr.getRank() != 2) + return op->emitOpError() << "expects " << name << " to be a rank-2 memref"; + if (!allowLowPrecision && isPTOLowPrecisionType(mr.getElementType())) + return op->emitOpError() << name << ": dtype " << mr.getElementType() + << " is not supported by this op yet"; + } else { + return op->emitOpError() << "expects " << name << " to be a !pto.tile_buf or rank-2 memref"; + } + + auto validShape = getValidShapeVec(ty); + if (validShape.size() != 2) + return op->emitOpError() << "expects " << name << " to have a rank-2 valid_shape"; + auto shape = getShapeVec(ty); + for (unsigned i = 0; i < 2; ++i) { + if (shape[i] != ShapedType::kDynamic && validShape[i] != ShapedType::kDynamic && + validShape[i] > shape[i]) + return op->emitOpError() << "expects " << name << " to satisfy valid_shape[" << i + << "] <= shape[" << i << "]"; + } + return success(); +} + +static LogicalResult verifyTileBufSameElemType(Operation *op, Type lhs, Type rhs, + StringRef lhsName, + StringRef rhsName) { + if (!isTileLikeType(lhs) || !isTileLikeType(rhs)) + return op->emitOpError() 
<< "expects " << lhsName << " and " << rhsName + << " to be !pto.tile_buf or memref"; + if (getElemTy(lhs) != getElemTy(rhs)) + return op->emitOpError() << "expects " << lhsName << " and " << rhsName + << " to have the same element type"; + return success(); +} + +static LogicalResult verifyTileBufSameValidShape(Operation *op, Type lhs, Type rhs, + StringRef lhsName, StringRef rhsName) { + if (!isTileLikeType(lhs) || !isTileLikeType(rhs)) + return success(); + auto lhsValid = getValidShapeVec(lhs); + auto rhsValid = getValidShapeVec(rhs); + for (size_t i = 0; i < lhsValid.size() && i < rhsValid.size(); ++i) { + if (lhsValid[i] != ShapedType::kDynamic && rhsValid[i] != ShapedType::kDynamic && + lhsValid[i] != rhsValid[i]) + return op->emitOpError() << "expects " << lhsName << " and " << rhsName + << " to have the same valid_shape"; + } + if (lhsValid.size() != rhsValid.size()) + return op->emitOpError() << "expects " << lhsName << " and " << rhsName + << " to have the same valid_shape"; + return success(); +} + +static LogicalResult verifyTileBufSameLogicalExtent(Operation *op, Type lhs, + Type rhs, StringRef lhsName, + StringRef rhsName, + bool compareValidShape) { + if (!isTileLikeType(lhs) || !isTileLikeType(rhs)) + return success(); + + auto lhsExtent = getLogicalTileExtentVec(lhs, compareValidShape); + auto rhsExtent = getLogicalTileExtentVec(rhs, compareValidShape); + auto emitMismatch = [&]() -> LogicalResult { + if (compareValidShape) + return op->emitOpError() << "expects " << lhsName << " and " << rhsName + << " to have the same valid_shape"; + return op->emitOpError() << "expects " << lhsName << " and " << rhsName + << " to have compatible shapes"; + }; + if (lhsExtent.size() != rhsExtent.size()) + return emitMismatch(); + + for (size_t i = 0, e = lhsExtent.size(); i < e; ++i) { + if (lhsExtent[i] != ShapedType::kDynamic && + rhsExtent[i] != ShapedType::kDynamic && lhsExtent[i] != rhsExtent[i]) + return emitMismatch(); + } + return success(); +} + +static 
LogicalResult verifyScaleTileMatchesOperand(Operation *op, Type scaleTy, + Type operandTy, + StringRef scaleName, + StringRef operandName) { + if (failed(verifyTileBufCommon(op, scaleTy, scaleName))) + return failure(); + auto scaleSpace = getPTOMemorySpaceEnum(scaleTy); + if (!scaleSpace || *scaleSpace != pto::AddressSpace::SCALING) + return op->emitOpError() << "expects " << scaleName + << " to be in the scaling address space"; + + auto scaleShape = getShapeVec(scaleTy); + auto operandShape = getShapeVec(operandTy); + if (scaleShape.size() != operandShape.size()) + return op->emitOpError() << "expects " << scaleName << " and " << operandName + << " to have the same rank"; + for (size_t i = 0; i < scaleShape.size(); ++i) { + if (scaleShape[i] != ShapedType::kDynamic && + operandShape[i] != ShapedType::kDynamic && + scaleShape[i] != operandShape[i]) + return op->emitOpError() << "expects " << scaleName << " and " << operandName + << " to have the same shape"; + } + + auto scaleValid = getValidShapeVec(scaleTy); + auto operandValid = getValidShapeVec(operandTy); + if (scaleValid.size() != operandValid.size()) + return op->emitOpError() << "expects " << scaleName << " and " << operandName + << " to have the same valid_shape"; + for (size_t i = 0; i < scaleValid.size(); ++i) { + if (scaleValid[i] != ShapedType::kDynamic && + operandValid[i] != ShapedType::kDynamic && + scaleValid[i] != operandValid[i]) + return op->emitOpError() << "expects " << scaleName << " and " << operandName + << " to have the same valid_shape"; + } + return success(); +} + +static LogicalResult verifyPartialValidPattern(Operation *op, Type src0Ty, + Type src1Ty, Type dstTy) { + auto src0Valid = getValidShapeVec(src0Ty); + auto src1Valid = getValidShapeVec(src1Ty); + auto dstValid = getValidShapeVec(dstTy); + if (src0Valid.size() != 2 || src1Valid.size() != 2 || dstValid.size() != 2) + return op->emitOpError("expects src0, src1, and dst to have rank-2 valid_shape"); + + auto lessEqualKnown = 
[](int64_t lhs, int64_t rhs) { + return lhs == ShapedType::kDynamic || rhs == ShapedType::kDynamic || lhs <= rhs; + }; + auto equalsKnown = [](ArrayRef lhs, ArrayRef rhs) { + for (auto [a, b] : llvm::zip(lhs, rhs)) { + if (a != ShapedType::kDynamic && b != ShapedType::kDynamic && a != b) + return false; + } + return true; + }; + + for (unsigned i = 0; i < 2; ++i) { + if (!lessEqualKnown(src0Valid[i], dstValid[i]) || + !lessEqualKnown(src1Valid[i], dstValid[i])) + return op->emitOpError( + "expects src0/src1 valid_shape to be less than or equal to dst valid_shape"); + } + if (!equalsKnown(src0Valid, dstValid) && !equalsKnown(src1Valid, dstValid)) + return op->emitOpError( + "expects at least one of src0/src1 valid_shape to match dst valid_shape"); + return success(); +} + +[[maybe_unused]] static bool hasKnownZeroValidRegion(Type ty) { + auto valid = getValidShapeVec(ty); + if (valid.size() != 2) + return false; + return valid[0] == 0 || valid[1] == 0; +} + +static LogicalResult verifyScalarTileOp(Operation *op, Type srcTy, Type dstTy, + StringRef srcName, StringRef dstName, + bool requireValidRowsEqual, + bool requireValidColsEqual) { + if (failed(verifyTileBufCommon(op, srcTy, srcName)) || + failed(verifyTileBufCommon(op, dstTy, dstName))) + return failure(); + auto srcSpace = getPTOMemorySpaceEnum(srcTy); + auto dstSpace = getPTOMemorySpaceEnum(dstTy); + if (!srcSpace || *srcSpace != pto::AddressSpace::VEC) + return op->emitOpError() << "expects " << srcName + << " to be in the vec address space"; + if (!dstSpace || *dstSpace != pto::AddressSpace::VEC) + return op->emitOpError() << "expects " << dstName + << " to be in the vec address space"; + if (failed(verifyTileBufSameElemType(op, srcTy, dstTy, srcName, dstName))) + return failure(); + + auto srcValid = getValidShapeVec(srcTy); + auto dstValid = getValidShapeVec(dstTy); + if (srcValid.size() != 2 || dstValid.size() != 2) + return op->emitOpError() + << "expects " << srcName << " and " << dstName + << " to 
have rank-2 valid_shape"; + if (requireValidRowsEqual && + srcValid[0] != ShapedType::kDynamic && dstValid[0] != ShapedType::kDynamic && + srcValid[0] != dstValid[0]) + return op->emitOpError() + << "expects " << srcName << " and " << dstName + << " to have the same valid_shape[0]"; + if (requireValidColsEqual && + srcValid[1] != ShapedType::kDynamic && dstValid[1] != ShapedType::kDynamic && + srcValid[1] != dstValid[1]) + return op->emitOpError() + << "expects " << srcName << " and " << dstName + << " to have the same valid_shape[1]"; + return success(); +} + +static FailureOr +verifyMatchingRowMajorBinaryTileOpCommon(Operation *op, Type src0Ty, Type src1Ty, + Type dstTy) { + if (failed(verifyTileBufCommon(op, src0Ty, "src0")) || + failed(verifyTileBufCommon(op, src1Ty, "src1")) || + failed(verifyTileBufCommon(op, dstTy, "dst"))) + return failure(); + if (failed(verifyTileBufSameElemType(op, src0Ty, src1Ty, "src0", "src1")) || + failed(verifyTileBufSameElemType(op, src0Ty, dstTy, "src0", "dst")) || + failed(verifyTileBufSameValidShape(op, src0Ty, src1Ty, "src0", "src1")) || + failed(verifyTileBufSameValidShape(op, src0Ty, dstTy, "src0", "dst"))) + return failure(); + if (!isRowMajorTileBuf(src0Ty) || !isRowMajorTileBuf(src1Ty) || + !isRowMajorTileBuf(dstTy)) { + op->emitOpError("expects src0, src1, and dst to use row-major layout"); + return failure(); + } + return getElemTy(src0Ty); +} + +static FailureOr +verifyNumericScalarTileOpCommon(Operation *op, Type srcTy, Type dstTy, + Type scalarTy, bool requireValidRowsEqual) { + if (failed(verifyScalarTileOp(op, srcTy, dstTy, "src", "dst", + requireValidRowsEqual, + /*requireValidColsEqual=*/true))) + return failure(); + if (!mlir::isa(scalarTy)) { + op->emitOpError("scalar must be a scalar type (integer/float)"); + return failure(); + } + return getElemTy(srcTy); +} + +static FailureOr +verifyShiftLikeBinaryTileOpCommon(Operation *op, Type src0Ty, Type src1Ty, + Type dstTy) { + if (failed(verifyTileBufCommon(op, 
src0Ty, "src0")) || + failed(verifyTileBufCommon(op, src1Ty, "src1")) || + failed(verifyTileBufCommon(op, dstTy, "dst"))) + return failure(); + Type e0 = getElemTy(src0Ty); + Type e1 = getElemTy(src1Ty); + if (!e0 || !e1) { + op->emitOpError("failed to get element type for operands"); + return failure(); + } + if (e0 != e1) { + op->emitOpError("expects src0 and src1 to have the same element type"); + return failure(); + } + if (!isRowMajorTileBuf(src0Ty) || !isRowMajorTileBuf(src1Ty) || + !isRowMajorTileBuf(dstTy)) { + op->emitOpError("expects src0, src1, and dst to use row-major layout"); + return failure(); + } + if (failed(verifyTileBufSameValidShape(op, src0Ty, dstTy, "src0", "dst")) || + failed(verifyTileBufSameValidShape(op, src1Ty, dstTy, "src1", "dst"))) + return failure(); + return e0; +} + +static FailureOr verifyDistinctRowMajorUnaryTileOpCommon( + Operation *op, Value src, Value dst, StringRef srcName = "src", + StringRef dstName = "dst") { + if (src == dst) { + op->emitOpError("expects src and dst to use different storage"); + return failure(); + } + Type srcTy = src.getType(); + Type dstTy = dst.getType(); + if (failed(verifyTileBufCommon(op, srcTy, srcName)) || + failed(verifyTileBufCommon(op, dstTy, dstName))) + return failure(); + + Type srcElem = getElemTy(srcTy); + Type dstElem = getElemTy(dstTy); + if (!srcElem || !dstElem) { + op->emitOpError("failed to get element type for src/dst"); + return failure(); + } + if (srcElem != dstElem) { + op->emitOpError("expects src and dst to have the same element type"); + return failure(); + } + if (!isRowMajorTileBuf(srcTy) || !isRowMajorTileBuf(dstTy)) { + op->emitOpError("expects src and dst to use row-major layout"); + return failure(); + } + if (failed(verifyTileBufSameValidShape(op, srcTy, dstTy, srcName, dstName))) + return failure(); + return srcElem; +} + +static LogicalResult verifyArithmeticElemTypeForArch( + Operation *op, Type elemTy, PTOArch targetArch, bool allowInt8OnA5, + bool allowBf16OnA5, 
StringRef a2a3Error, StringRef a5Error) { + bool supported = elemTy.isInteger(32) || elemTy.isInteger(16) || + elemTy.isF16() || elemTy.isF32(); + if (targetArch == PTOArch::A5) + supported = supported || (allowInt8OnA5 && elemTy.isInteger(8)) || + (allowBf16OnA5 && elemTy.isBF16()); + if (supported) + return success(); + return op->emitOpError(targetArch == PTOArch::A5 ? a5Error : a2a3Error); +} + +static LogicalResult verifyArithmeticBinaryTileOpWithArchDispatch( + Operation *op, Type src0Ty, Type src1Ty, Type dstTy, bool allowInt8OnA5, + bool allowBf16OnA5, StringRef a2a3Error, StringRef a5Error) { + auto verifyByArch = [&](PTOArch targetArch) -> LogicalResult { + FailureOr elemOr = + verifyMatchingRowMajorBinaryTileOpCommon(op, src0Ty, src1Ty, dstTy); + if (failed(elemOr)) + return failure(); + return verifyArithmeticElemTypeForArch(op, *elemOr, targetArch, + allowInt8OnA5, allowBf16OnA5, + a2a3Error, a5Error); + }; + auto verifyA2A3 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A3); }; + auto verifyA5 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A5); }; + return dispatchVerifierByArch(op, verifyA2A3, verifyA5); +} + +static LogicalResult verifyArithmeticScalarTileOpWithArchDispatch( + Operation *op, Type srcTy, Type dstTy, Type scalarTy, bool allowInt8OnA5, + bool allowBf16OnA5, StringRef a2a3Error, StringRef a5Error, + bool requireValidRowsEqualOnA2A3 = true, + bool requireValidRowsEqualOnA5 = false) { + auto verifyByArch = [&](PTOArch targetArch, + bool requireValidRowsEqual) -> LogicalResult { + FailureOr elemOr = verifyNumericScalarTileOpCommon( + op, srcTy, dstTy, scalarTy, requireValidRowsEqual); + if (failed(elemOr)) + return failure(); + return verifyArithmeticElemTypeForArch(op, *elemOr, targetArch, + allowInt8OnA5, allowBf16OnA5, + a2a3Error, a5Error); + }; + auto verifyA2A3 = [&]() -> LogicalResult { + return verifyByArch(PTOArch::A3, requireValidRowsEqualOnA2A3); + }; + auto verifyA5 = [&]() -> LogicalResult { + return 
verifyByArch(PTOArch::A5, requireValidRowsEqualOnA5); + }; + return dispatchVerifierByArch(op, verifyA2A3, verifyA5); +} + +static LogicalResult verifyTColReductionElemTypeForArch( + Operation *op, Type elemTy, PTOArch targetArch, bool allowInt8OnA5, + bool allowBf16OnA5, StringRef a2a3Error, StringRef a5Error) { + bool ok = elemTy.isF16() || elemTy.isF32() || elemTy.isInteger(16) || + elemTy.isInteger(32); + if (targetArch == PTOArch::A5) + ok = ok || (allowInt8OnA5 && elemTy.isInteger(8)) || + (allowBf16OnA5 && elemTy.isBF16()); + if (ok) + return success(); + return op->emitOpError(targetArch == PTOArch::A5 ? a5Error : a2a3Error); +} + +static LogicalResult verifyTColReductionOpWithArchDispatch( + Operation *op, Type srcTy, Type dstTy, bool requireNonZeroSrcOnA2A3, + bool requireNonZeroSrcOnA5, bool allowInt8OnA5, bool allowBf16OnA5, + StringRef a2a3Error, StringRef a5Error) { + auto verifyByArch = [&](PTOArch targetArch, + bool requireNonZeroSrc) -> LogicalResult { + if (failed(verifyNDStyleVecTile(op, srcTy, "src")) || + failed(verifyNDStyleVecTile(op, dstTy, "dst"))) + return failure(); + if (getElemTy(srcTy) != getElemTy(dstTy)) + return op->emitOpError("expects src and dst to have the same element type"); + if (failed(verifyColReductionValidRegion(op, srcTy, dstTy, requireNonZeroSrc))) + return failure(); + Type elem = getElemTy(srcTy); + return verifyTColReductionElemTypeForArch(op, elem, targetArch, allowInt8OnA5, + allowBf16OnA5, a2a3Error, a5Error); + }; + auto verifyA2A3 = [&]() -> LogicalResult { + return verifyByArch(PTOArch::A3, requireNonZeroSrcOnA2A3); + }; + auto verifyA5 = [&]() -> LogicalResult { + return verifyByArch(PTOArch::A5, requireNonZeroSrcOnA5); + }; + return dispatchVerifierByArch(op, verifyA2A3, verifyA5); +} + +static LogicalResult verifyTColArgReductionOpCommon(Operation *op, Type srcTy, + Type tmpTy, Type dstTy) { + if (failed(verifyNDStyleVecTile(op, srcTy, "src")) || + failed(verifyVecTileCommon(op, tmpTy, "tmp")) || + 
failed(verifyColArgReductionDstLayout(op, dstTy, "dst"))) + return failure(); + if (failed(verifyTileBufSameElemType(op, srcTy, tmpTy, "src", "tmp")) || + failed(verifyTileBufSameValidShape(op, srcTy, tmpTy, "src", "tmp"))) + return failure(); + if (failed(verifyColReductionValidRegion(op, srcTy, dstTy, + /*requireNonZeroSrc=*/true))) + return failure(); + Type srcElemTy = getElemTy(srcTy); + unsigned srcElemBits = srcElemTy ? getPTOStorageElemBitWidth(srcElemTy) : 0; + if (!(mlir::isa(srcElemTy) && + (srcElemBits == 8 || srcElemBits == 16 || srcElemBits == 32))) + return op->emitOpError( + "expects src/tmp element type to be 1, 2, or 4 bytes wide"); + auto dstInt = dyn_cast(getElemTy(dstTy)); + if (!dstInt || dstInt.getWidth() != 32) + return op->emitOpError("expects dst element type to be i32 or ui32"); + return success(); +} + +static bool hasCompatibleKnownExtent(int64_t lhs, int64_t rhs) { + return lhs == ShapedType::kDynamic || rhs == ShapedType::kDynamic || lhs == rhs; +} + +static bool isKnownUnitExtent(int64_t value) { + return value == ShapedType::kDynamic || value == 1; +} + +static LogicalResult verifyVecTileStorage(Operation *op, Type ty, StringRef name) { + if (failed(verifyTileBufCommon(op, ty, name))) + return failure(); + auto as = getPTOMemorySpaceEnum(ty); + if (!as || *as != pto::AddressSpace::VEC) + return op->emitOpError() << "expects " << name << " to be in the vec address space"; + return success(); +} + +static LogicalResult verifyVecTileCommonA2A3(Operation *op, Type ty, + StringRef name) { + if (failed(verifyTileBufCommon(op, ty, name))) + return failure(); + auto tb = dyn_cast(ty); + auto as = getPTOMemorySpaceEnum(ty); + if (as && *as != pto::AddressSpace::VEC) + return op->emitOpError() << "expects " << name << " to be in the vec address space"; + if (tb && tb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor)) + return op->emitOpError() << "expects " << name << " to use the row_major blayout"; + return success(); +} + +static 
LogicalResult verifyVecTileCommonA5(Operation *op, Type ty, + StringRef name) { + return verifyVecTileCommonA2A3(op, ty, name); +} + +static LogicalResult verifyVecTileCommon(Operation *op, Type ty, StringRef name) { + switch (getVerifierTargetArch(op)) { + case VerifierTargetArch::A2A3: + return verifyVecTileCommonA2A3(op, ty, name); + case VerifierTargetArch::A5: + return verifyVecTileCommonA5(op, ty, name); + } + return failure(); +} + +static LogicalResult verifyVecTileUnaryOp(Operation *op, Type srcTy, Type dstTy, + StringRef srcName, + StringRef dstName, + bool allowBf16, + bool allowInt8) { + if (failed(verifyVecTileCommon(op, srcTy, srcName)) || + failed(verifyVecTileCommon(op, dstTy, dstName))) + return failure(); + if (failed(verifyTileBufSameElemType(op, srcTy, dstTy, srcName, dstName))) + return failure(); + if (!isSupportedVecElemType(getElemTy(srcTy), allowBf16, allowInt8)) + return op->emitOpError() << "expects vec tile element types to be supported"; + return success(); +} + +static LogicalResult verifyAccTileCommonA2A3(Operation *op, Type ty, + StringRef name) { + if (failed(verifyTileBufCommon(op, ty, name))) + return failure(); + auto as = getPTOMemorySpaceEnum(ty); + if (!as || *as != pto::AddressSpace::ACC) + return op->emitOpError() << "expects " << name << " to be in the acc address space"; + return success(); +} + +static LogicalResult verifyAccTileCommonA5(Operation *op, Type ty, + StringRef name) { + return verifyAccTileCommonA2A3(op, ty, name); +} + +static LogicalResult verifyAccTileCommon(Operation *op, Type ty, StringRef name) { + switch (getVerifierTargetArch(op)) { + case VerifierTargetArch::A2A3: + return verifyAccTileCommonA2A3(op, ty, name); + case VerifierTargetArch::A5: + return verifyAccTileCommonA5(op, ty, name); + } + return failure(); +} + +static LogicalResult verifyMatTileOperandsA2A3(Operation *op, Type lhsTy, + Type rhsTy, Type dstTy) { + if (failed(verifyTileBufCommon(op, lhsTy, "lhs")) || + 
failed(verifyTileBufCommon(op, rhsTy, "rhs")) || + failed(verifyAccTileCommon(op, dstTy, "dst"))) + return failure(); + auto lhsSpace = getPTOMemorySpaceEnum(lhsTy); + auto rhsSpace = getPTOMemorySpaceEnum(rhsTy); + auto dstSpace = getPTOMemorySpaceEnum(dstTy); + if (!lhsSpace || !rhsSpace || !dstSpace) + return op->emitOpError("expects lhs, rhs, and dst to have explicit address spaces"); + if (*lhsSpace != pto::AddressSpace::LEFT || *rhsSpace != pto::AddressSpace::RIGHT || + *dstSpace != pto::AddressSpace::ACC) + return op->emitOpError( + "expects lhs, rhs, and dst to use the left, right, and acc address spaces"); + auto lhsShape = getMatmulLogicalShapeVec(lhsTy); + auto rhsShape = getMatmulLogicalShapeVec(rhsTy); + auto dstShape = getMatmulLogicalShapeVec(dstTy); + if ((lhsShape[0] != dstShape[0] || rhsShape[1] != dstShape[1] || lhsShape[1] != rhsShape[0])) + return op->emitOpError( + "expects static matmul tile shapes lhs[M,K], rhs[K,N], and dst[M,N]"); + auto lhsValid = getValidShapeVec(lhsTy); + auto rhsValid = getValidShapeVec(rhsTy); + if (lhsValid.size() == 2 && rhsValid.size() == 2) { + int64_t m = lhsValid[0]; + int64_t k = lhsValid[1]; + int64_t n = rhsValid[1]; + if ((m != ShapedType::kDynamic && (m < 1 || m > 4095)) || + (k != ShapedType::kDynamic && (k < 1 || k > 4095)) || + (n != ShapedType::kDynamic && (n < 1 || n > 4095))) + return op->emitOpError("expects m, k, and n valid sizes to be in [1, 4095]"); + } + return success(); +} + +static LogicalResult verifyMatTileOperandsA5(Operation *op, Type lhsTy, + Type rhsTy, Type dstTy) { + if (failed(verifyMatTileOperandsA2A3(op, lhsTy, rhsTy, dstTy))) + return failure(); + + auto lhsTb = mlir::dyn_cast(lhsTy); + auto rhsTb = mlir::dyn_cast(rhsTy); + auto dstTb = mlir::dyn_cast(dstTy); + if (!lhsTb || !rhsTb || !dstTb) + return success(); + + if (lhsTb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor)) + return op->emitOpError("expects lhs to use the col_major blayout on A5"); + if 
(rhsTb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor)) + return op->emitOpError("expects rhs to use the row_major blayout on A5"); + if (dstTb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor)) + return op->emitOpError("expects dst to use the col_major blayout on A5"); + + if (lhsTb.getSLayoutValueI32() != static_cast(pto::SLayout::RowMajor)) + return op->emitOpError("expects lhs to use the row_major slayout on A5"); + if (rhsTb.getSLayoutValueI32() != static_cast(pto::SLayout::ColMajor)) + return op->emitOpError("expects rhs to use the col_major slayout on A5"); + if (dstTb.getSLayoutValueI32() != static_cast(pto::SLayout::RowMajor)) + return op->emitOpError("expects dst to use the row_major slayout on A5"); + return success(); +} + +static LogicalResult verifyMatTileOperands(Operation *op, Type lhsTy, Type rhsTy, + Type dstTy) { + switch (getVerifierTargetArch(op)) { + case VerifierTargetArch::A2A3: + return verifyMatTileOperandsA2A3(op, lhsTy, rhsTy, dstTy); + case VerifierTargetArch::A5: + return verifyMatTileOperandsA5(op, lhsTy, rhsTy, dstTy); + } + return failure(); +} + +static LogicalResult verifyGemvTileOperandsA2A3(Operation *op, Type lhsTy, + Type rhsTy, Type dstTy) { + if (failed(verifyTileBufCommon(op, lhsTy, "lhs")) || + failed(verifyTileBufCommon(op, rhsTy, "rhs")) || + failed(verifyAccTileCommon(op, dstTy, "dst"))) + return failure(); + + auto lhsSpace = getPTOMemorySpaceEnum(lhsTy); + auto rhsSpace = getPTOMemorySpaceEnum(rhsTy); + if (!lhsSpace || !rhsSpace) + return op->emitOpError("expects lhs and rhs to have explicit address spaces"); + if (*lhsSpace != pto::AddressSpace::LEFT || *rhsSpace != pto::AddressSpace::RIGHT) + return op->emitOpError( + "expects lhs and rhs to use the left and right address spaces"); + + auto lhsValid = getValidShapeVec(lhsTy); + auto rhsValid = getValidShapeVec(rhsTy); + auto dstValid = getValidShapeVec(dstTy); + if (lhsValid[0] != ShapedType::kDynamic && lhsValid[0] != 1) + return 
op->emitOpError("expects lhs valid_shape[0] to be 1 for tgemv"); + if (isa(dstTy) && dstValid[0] != ShapedType::kDynamic && + dstValid[0] != 1) + return op->emitOpError("expects dst valid_shape[0] to be 1 for tgemv"); + if (lhsValid[1] != ShapedType::kDynamic && rhsValid[0] != ShapedType::kDynamic && + lhsValid[1] != rhsValid[0]) + return op->emitOpError() + << "expects lhs valid_shape[1] to equal rhs valid_shape[0], but got " + << lhsValid[1] << " vs " << rhsValid[0]; + if (rhsValid[1] != ShapedType::kDynamic && dstValid[1] != ShapedType::kDynamic && + rhsValid[1] != dstValid[1]) + return op->emitOpError() + << "expects rhs valid_shape[1] to equal dst valid_shape[1], but got " + << rhsValid[1] << " vs " << dstValid[1]; + return success(); +} + +static LogicalResult verifyGemvTileOperandsA5(Operation *op, Type lhsTy, + Type rhsTy, Type dstTy) { + if (failed(verifyGemvTileOperandsA2A3(op, lhsTy, rhsTy, dstTy))) + return failure(); + return verifyMatTileOperandsA5(op, lhsTy, rhsTy, dstTy); +} + +static LogicalResult verifyGemvTileOperands(Operation *op, Type lhsTy, Type rhsTy, + Type dstTy) { + switch (getVerifierTargetArch(op)) { + case VerifierTargetArch::A2A3: + return verifyGemvTileOperandsA2A3(op, lhsTy, rhsTy, dstTy); + case VerifierTargetArch::A5: + return verifyGemvTileOperandsA5(op, lhsTy, rhsTy, dstTy); + } + return failure(); +} + +static LogicalResult verifyMatBiasTileA2A3(Operation *op, Type biasTy, Type dstTy, + bool requireFloatBias) { + if (failed(verifyTileBufCommon(op, biasTy, "bias"))) + return failure(); + auto biasSpace = getPTOMemorySpaceEnum(biasTy); + if (!biasSpace || *biasSpace != pto::AddressSpace::BIAS) + return op->emitOpError("expects bias to be in the bias address space"); + auto biasShape = getShapeVec(biasTy); + if (biasShape[0] != ShapedType::kDynamic && biasShape[0] != 1) + return op->emitOpError("expects bias to have 1 row"); + if (requireFloatBias) { + if (!getElemTy(biasTy).isF32()) + return op->emitOpError("expects bias to have 
element type f32"); + } else if (getElemTy(biasTy) != getElemTy(dstTy)) { + return op->emitOpError("expects bias and dst to have the same element type"); + } + return success(); +} + +static LogicalResult verifyMatBiasTileA5(Operation *op, Type biasTy, Type dstTy, + bool requireFloatBias) { + if (failed(verifyMatBiasTileA2A3(op, biasTy, dstTy, requireFloatBias))) + return failure(); + if (auto biasTb = dyn_cast(biasTy)) { + if (biasTb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor)) + return op->emitOpError("expects bias to use the row_major blayout on A5"); + } + return success(); +} + +static LogicalResult verifyMatBiasTile(Operation *op, Type biasTy, Type dstTy, + bool requireFloatBias) { + switch (getVerifierTargetArch(op)) { + case VerifierTargetArch::A2A3: + return verifyMatBiasTileA2A3(op, biasTy, dstTy, requireFloatBias); + case VerifierTargetArch::A5: + return verifyMatBiasTileA5(op, biasTy, dstTy, requireFloatBias); + } + return failure(); +} + +static LogicalResult verifyMatmulTypeTriple(Operation *op, Type lhsElemTy, + Type rhsElemTy, Type dstElemTy) { + bool isA5 = getVerifierTargetArch(op) == VerifierTargetArch::A5; + auto isInt8 = [](Type ty) { + return ty.isInteger(8); + }; + if (dstElemTy.isInteger(32) && isInt8(lhsElemTy) && isInt8(rhsElemTy)) + return success(); + + auto isSupportedFpInput = [](Type ty) { + return ty.isF16() || ty.isBF16() || ty.isF32(); + }; + if (dstElemTy.isF32() && lhsElemTy == rhsElemTy && isSupportedFpInput(lhsElemTy)) + return success(); + + if (isA5 && dstElemTy.isF32() && lhsElemTy == rhsElemTy) { + if (auto ft = mlir::dyn_cast(lhsElemTy)) { + unsigned width = ft.getWidth(); + if (width == 8 || width == 16 || width == 32) + return success(); + } + } + + return op->emitOpError() + << "expects (dst, lhs, rhs) element types to match one of " + "(i32, i8, i8), (f32, f16, f16), (f32, bf16, bf16), (f32, f32, f32)" + << (isA5 ? 
", or an A5-supported fp8 pair" : ""); +} + +LogicalResult pto::TAddOp::verify() { + return verifyArithmeticBinaryTileOpWithArchDispatch( + getOperation(), getSrc0().getType(), getSrc1().getType(), getDst().getType(), + /*allowInt8OnA5=*/true, /*allowBf16OnA5=*/true, + "expects A2/A3 tadd element type to be i32/i16/f16/f32", + "expects A5 tadd element type to be i32/i16/i8/f16/bf16/f32"); +} + +LogicalResult pto::TAddCOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + Type t0 = getSrc0().getType(); + Type t1 = getSrc1().getType(); + Type t2 = getSrc2().getType(); + Type td = getDst().getType(); + + if (!isPTOShapedLike(t0) || !isPTOShapedLike(t1) || + !isPTOShapedLike(t2) || !isPTOShapedLike(td)) + return emitOpError("expects src0/src1/src2/dst to be memref/tile_buf types"); + + auto s0 = getShapeVec(t0); + auto s1 = getShapeVec(t1); + auto s2 = getShapeVec(t2); + auto sd = getShapeVec(td); + if (s0 != s1 || s0 != s2 || s0 != sd) + return emitOpError("expects src0/src1/src2/dst to have the same shape"); + return success(); +} +LogicalResult pto::TAddSOp::verify() { + return verifyArithmeticScalarTileOpWithArchDispatch( + getOperation(), getSrc().getType(), getDst().getType(), getScalar().getType(), + /*allowInt8OnA5=*/true, /*allowBf16OnA5=*/true, + "expects A2/A3 tadds element type to be i32/i16/f16/f32", + "expects A5 tadds element type to be i32/i16/i8/f16/bf16/f32", + /*requireValidRowsEqualOnA2A3=*/true, + /*requireValidRowsEqualOnA5=*/false); +} + +LogicalResult pto::TAxpyOp::verify() { + auto verifyCommon = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyVecTileCommon(*this, srcTy, "src")) || + failed(verifyVecTileCommon(*this, dstTy, "dst"))) + return failure(); + if (failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) + return failure(); + + Type scalarTy = getScalar().getType(); + Type srcElem = getElemTy(srcTy); + if (scalarTy 
!= srcElem) + return emitOpError("expects scalar type to match src element type"); + if (getShapeVec(srcTy) != getShapeVec(dstTy)) + return emitOpError("expects src and dst to have the same shape"); + return success(); + }; + + auto verifyA2A3 = [&]() -> LogicalResult { + if (failed(verifyCommon())) + return failure(); + Type srcElem = getElemTy(getSrc().getType()); + Type dstElem = getElemTy(getDst().getType()); + bool sameType = srcElem == dstElem; + bool widenF16ToF32 = srcElem.isF16() && dstElem.isF32(); + if (!(sameType || widenF16ToF32)) + return emitOpError( + "expects dst/src element types to match, or dst=f32 and src=f16"); + if (!(dstElem.isF16() || dstElem.isF32())) + return emitOpError("expects A2/A3 taxpy dst element type to be f16/f32"); + if (!(srcElem.isF16() || srcElem.isF32())) + return emitOpError("expects A2/A3 taxpy src element type to be f16/f32"); + return success(); + }; + + auto verifyA5 = [&]() -> LogicalResult { + if (failed(verifyCommon())) + return failure(); + Type srcElem = getElemTy(getSrc().getType()); + Type dstElem = getElemTy(getDst().getType()); + bool sameType = srcElem == dstElem; + bool widenF16ToF32 = srcElem.isF16() && dstElem.isF32(); + if (!(sameType || widenF16ToF32)) + return emitOpError( + "expects dst/src element types to match, or dst=f32 and src=f16"); + if (!(dstElem.isF16() || dstElem.isF32() || dstElem.isBF16())) + return emitOpError("expects A5 taxpy dst element type to be f16/bf16/f32"); + if (!(srcElem.isF16() || srcElem.isF32() || srcElem.isBF16())) + return emitOpError("expects A5 taxpy src element type to be f16/bf16/f32"); + return success(); + }; + + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +LogicalResult pto::TAddSCOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + Type ts0 = getSrc0().getType(); + Type ts1 = getSrc1().getType(); + Type td = getDst().getType(); + if (!isPTOShapedLike(ts0) || !isPTOShapedLike(ts1) || 
!isPTOShapedLike(td)) + return emitOpError("expects src0/src1/dst to be PTO shaped-like types"); + + auto s0 = getShapeVec(ts0); + auto s1 = getShapeVec(ts1); + auto sd = getShapeVec(td); + if (s0 != s1 || s0 != sd) + return emitOpError("expects src0/src1/dst to have the same shape"); + return success(); +} + +LogicalResult pto::TAndOp::verify() { + auto verifyCommon = [&]() -> FailureOr { + return verifyMatchingRowMajorBinaryTileOpCommon( + getOperation(), getSrc0().getType(), getSrc1().getType(), + getDst().getType()); + }; + + auto verifyA2A3 = [&]() -> LogicalResult { + FailureOr elemOr = verifyCommon(); + if (failed(elemOr)) + return failure(); + auto it = mlir::dyn_cast(*elemOr); + if (!it || (it.getWidth() != 8 && it.getWidth() != 16)) + return emitOpError( + "expects A2/A3 tand src0, src1, and dst element type to be i8/i16"); + return success(); + }; + + auto verifyA5 = [&]() -> LogicalResult { + FailureOr elemOr = verifyCommon(); + if (failed(elemOr)) + return failure(); + auto it = mlir::dyn_cast(*elemOr); + if (!it || (it.getWidth() != 8 && it.getWidth() != 16 && + it.getWidth() != 32)) + return emitOpError( + "expects A5 tand src0, src1, and dst element type to be i8/i16/i32"); + return success(); + }; + + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TConcatOp::verify() { + auto verifyCommon = [&]() -> FailureOr { + Type t0 = getSrc0().getType(); + Type t1 = getSrc1().getType(); + Type td = getDst().getType(); + if (failed(verifyTileBufCommon(*this, t0, "src0")) || + failed(verifyTileBufCommon(*this, t1, "src1")) || + failed(verifyTileBufCommon(*this, td, "dst"))) + return failure(); + + Type e0 = getElemTy(t0); + Type e1 = getElemTy(t1); + Type ed = getElemTy(td); + if (!e0 || !e1 || !ed) { + emitOpError("failed to get element type for operands"); + return failure(); + } + if (e0 != e1 || e0 != ed) { + emitOpError("expects src0, src1, and dst to have the same element type"); + return 
failure(); + } + + auto v0 = getValidShapeVec(getSrc0()); + auto v1 = getValidShapeVec(getSrc1()); + auto vd = getValidShapeVec(getDst()); + if (v0.size() != 2 || v1.size() != 2 || vd.size() != 2) + return emitOpError("expects src0, src1, and dst to have rank-2 valid_shape"); + + // validRow must match dst (when known). + if (v0[0] != ShapedType::kDynamic && vd[0] != ShapedType::kDynamic && v0[0] != vd[0]) + return emitOpError("expects src0 valid row to match dst valid row"); + if (v1[0] != ShapedType::kDynamic && vd[0] != ShapedType::kDynamic && v1[0] != vd[0]) + return emitOpError("expects src1 valid row to match dst valid row"); + + // Total valid columns must fit within dst static cols (when known). + auto sd = getShapeVec(td); + if (sd.size() == 2 && sd[1] != ShapedType::kDynamic && + v0[1] != ShapedType::kDynamic && v1[1] != ShapedType::kDynamic) { + if (v0[1] + v1[1] > sd[1]) + return emitOpError("expects src0.valid_col + src1.valid_col <= dst.cols"); + } + + return e0; + }; + + auto verifyElemType = [&](Type elem) -> LogicalResult { + if (elem.isF16() || elem.isF32() || elem.isBF16()) + return success(); + auto it = mlir::dyn_cast(elem); + if (!it || + (it.getWidth() != 8 && it.getWidth() != 16 && it.getWidth() != 32)) + return emitOpError("expects element type to be i8, i16, i32, f16, f32, or bf16"); + return success(); + }; + + auto verifyLocVec = [&](Type ty, StringRef name) -> LogicalResult { + auto as = getPTOMemorySpaceEnum(ty); + if (!as || *as != pto::AddressSpace::VEC) + return emitOpError() << "expects " << name << " to use loc=vec"; + return success(); + }; + + auto verifyA2A3 = [&]() -> LogicalResult { + FailureOr elemOr = verifyCommon(); + if (failed(elemOr)) + return failure(); + if (failed(verifyLocVec(getSrc0().getType(), "src0")) || + failed(verifyLocVec(getSrc1().getType(), "src1")) || + failed(verifyLocVec(getDst().getType(), "dst"))) + return failure(); + return verifyElemType(*elemOr); + }; + + auto verifyA5 = [&]() -> LogicalResult { + 
FailureOr elemOr = verifyCommon(); + if (failed(elemOr)) + return failure(); + if (failed(verifyLocVec(getSrc0().getType(), "src0")) || + failed(verifyLocVec(getSrc1().getType(), "src1")) || + failed(verifyLocVec(getDst().getType(), "dst"))) + return failure(); + if (!isRowMajorTileBuf(getSrc0().getType()) || !isRowMajorTileBuf(getSrc1().getType()) || + !isRowMajorTileBuf(getDst().getType())) + return emitOpError("expects src0, src1, and dst to use row-major layout"); + return verifyElemType(*elemOr); + }; + + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +LogicalResult pto::TConcatidxOp::verify() { + auto verifyCommon = [&]() -> FailureOr> { + Type t0 = getSrc0().getType(); + Type t1 = getSrc1().getType(); + Type ti0 = getSrc0Idx().getType(); + Type ti1 = getSrc1Idx().getType(); + Type td = getDst().getType(); + if (failed(verifyTileBufCommon(*this, t0, "src0")) || + failed(verifyTileBufCommon(*this, t1, "src1")) || + failed(verifyTileBufCommon(*this, ti0, "src0Idx")) || + failed(verifyTileBufCommon(*this, ti1, "src1Idx")) || + failed(verifyTileBufCommon(*this, td, "dst"))) + return failure(); + + // Check data element type consistency. + Type e0 = getElemTy(t0); + Type e1 = getElemTy(t1); + Type ed = getElemTy(td); + if (!e0 || !e1 || !ed) { + emitOpError("failed to get element type for data operands"); + return failure(); + } + if (e0 != e1 || e0 != ed) { + emitOpError("expects src0, src1, and dst to have the same element type"); + return failure(); + } + + // Check index element type consistency. + Type ei0 = getElemTy(ti0); + Type ei1 = getElemTy(ti1); + if (!ei0 || !ei1) { + emitOpError("failed to get element type for index operands"); + return failure(); + } + if (ei0 != ei1) { + emitOpError("expects src0Idx and src1Idx to have the same element type"); + return failure(); + } + + // All five tiles must be rank-2. 
+ auto v0 = getValidShapeVec(getSrc0()); + auto v1 = getValidShapeVec(getSrc1()); + auto vi0 = getValidShapeVec(getSrc0Idx()); + auto vi1 = getValidShapeVec(getSrc1Idx()); + auto vd = getValidShapeVec(getDst()); + if (v0.size() != 2 || v1.size() != 2 || vi0.size() != 2 || + vi1.size() != 2 || vd.size() != 2) + return emitOpError("expects all operands to have rank-2 valid_shape"); + + // validRow must match dst (when known). + auto checkValidRow = [&](const auto &v, StringRef name) -> LogicalResult { + if (v[0] != ShapedType::kDynamic && vd[0] != ShapedType::kDynamic && + v[0] != vd[0]) + return emitOpError("expects ") << name << " valid row to match dst valid row"; + return success(); + }; + if (failed(checkValidRow(v0, "src0")) || + failed(checkValidRow(v1, "src1")) || + failed(checkValidRow(vi0, "src0Idx")) || + failed(checkValidRow(vi1, "src1Idx"))) + return failure(); + + // Index tile must have cols >= 1 (when known). + if (vi0[1] != ShapedType::kDynamic && vi0[1] < 1) + return emitOpError("expects src0Idx valid_col >= 1"); + if (vi1[1] != ShapedType::kDynamic && vi1[1] < 1) + return emitOpError("expects src1Idx valid_col >= 1"); + + return std::make_pair(e0, ei0); + }; + + auto verifyElementTypes = [&](Type dataElem, Type idxElem) -> LogicalResult { + // Data element type: f16, f32, bf16, i8, i16, i32 (signless). + if (!dataElem.isF16() && !dataElem.isF32() && !dataElem.isBF16()) { + auto it = mlir::dyn_cast(dataElem); + if (!it || !it.isSignless() || + (it.getWidth() != 8 && it.getWidth() != 16 && it.getWidth() != 32)) + return emitOpError() + << "expects data element type to be i8, i16, i32, f16, f32, or bf16"; + } + + // Index element type: i8, i16, i32 (signless). 
+ auto it = mlir::dyn_cast(idxElem); + if (!it || !it.isSignless() || + (it.getWidth() != 8 && it.getWidth() != 16 && it.getWidth() != 32)) + return emitOpError() + << "expects index element type to be i8, i16, or i32"; + return success(); + }; + + auto verifyLocVec = [&](Type ty, StringRef name) -> LogicalResult { + auto as = getPTOMemorySpaceEnum(ty); + if (!as || *as != pto::AddressSpace::VEC) + return emitOpError() << "expects " << name << " to use loc=vec"; + return success(); + }; + + auto verifyA2A3 = [&]() -> LogicalResult { + auto elemOr = verifyCommon(); + if (failed(elemOr)) + return failure(); + if (failed(verifyLocVec(getSrc0().getType(), "src0")) || + failed(verifyLocVec(getSrc1().getType(), "src1")) || + failed(verifyLocVec(getSrc0Idx().getType(), "src0Idx")) || + failed(verifyLocVec(getSrc1Idx().getType(), "src1Idx")) || + failed(verifyLocVec(getDst().getType(), "dst"))) + return failure(); + return verifyElementTypes(elemOr->first, elemOr->second); + }; + + auto verifyA5 = [&]() -> LogicalResult { + auto elemOr = verifyCommon(); + if (failed(elemOr)) + return failure(); + if (failed(verifyLocVec(getSrc0().getType(), "src0")) || + failed(verifyLocVec(getSrc1().getType(), "src1")) || + failed(verifyLocVec(getSrc0Idx().getType(), "src0Idx")) || + failed(verifyLocVec(getSrc1Idx().getType(), "src1Idx")) || + failed(verifyLocVec(getDst().getType(), "dst"))) + return failure(); + if (!isRowMajorTileBuf(getSrc0().getType()) || + !isRowMajorTileBuf(getSrc1().getType()) || + !isRowMajorTileBuf(getSrc0Idx().getType()) || + !isRowMajorTileBuf(getSrc1Idx().getType()) || + !isRowMajorTileBuf(getDst().getType())) + return emitOpError( + "expects all operands to use row-major layout"); + return verifyElementTypes(elemOr->first, elemOr->second); + }; + + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +LogicalResult pto::TAndSOp::verify() { + auto verifyCommon = [&]() -> FailureOr { + return 
verifyDistinctRowMajorUnaryTileOpCommon(getOperation(), getSrc(), + getDst(), "src", "dst"); + }; + + auto verifyA2A3 = [&]() -> LogicalResult { + FailureOr elemOr = verifyCommon(); + if (failed(elemOr)) + return failure(); + auto it = mlir::dyn_cast(*elemOr); + if (!it || (it.getWidth() != 8 && it.getWidth() != 16)) + return emitOpError( + "expects A2/A3 tands src, scalar, and dst element type to be i8/i16"); + return success(); + }; + + auto verifyA5 = [&]() -> LogicalResult { + FailureOr elemOr = verifyCommon(); + if (failed(elemOr)) + return failure(); + auto it = mlir::dyn_cast(*elemOr); + if (!it || (it.getWidth() != 8 && it.getWidth() != 16 && + it.getWidth() != 32)) + return emitOpError( + "expects A5 tands src, scalar, and dst element type to be i8/i16/i32"); + return success(); + }; + + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +LogicalResult pto::TCIOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + Type dstTy = getDst().getType(); + if (failed(verifyTileBufCommon(*this, dstTy, "dst"))) + return failure(); + + auto elemTy = mlir::dyn_cast(getElemTy(dstTy)); + if (!elemTy) + return emitOpError("expects dst element type to be integer"); + + unsigned bw = elemTy.getWidth(); + if (bw != 16 && bw != 32) + return emitOpError("expects dst element type to be i16/i32"); + + auto sTy = mlir::dyn_cast(getOperand(0).getType()); + if (!sTy) + return emitOpError("expects S to be integer"); + + if (sTy != elemTy) + return emitOpError("expects S and dst element type to be exactly the same type"); + auto shape = getShapeVec(dstTy); + if (shape.size() != 2) + return emitOpError("expects dst to be rank-2"); + if (shape[1] != ShapedType::kDynamic && shape[1] == 1) + return emitOpError("expects dst cols to be different from 1"); + + return success(); +} + +LogicalResult pto::TTriOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + + Type dstTy = 
getDst().getType(); + if (failed(verifyVecTileCommon(*this, dstTy, "dst"))) + return failure(); + + auto diagonalTy = mlir::dyn_cast(getDiagonal().getType()); + if (!diagonalTy) + return emitOpError("expects diagonal to be an integer operand"); + + int32_t upperOrLower = getUpperOrLower(); + if (upperOrLower != 0 && upperOrLower != 1) + return emitOpError("expects upperOrLower to be 0 (lower) or 1 (upper)"); + + Type elemTy = getElemTy(dstTy); + return dispatchVerifierByArch( + getOperation(), + [&]() -> LogicalResult { + if (!isSupportedVecElemType(elemTy, /*allowBf16=*/false, + /*allowInt8=*/false)) + return emitOpError() + << "expects A2/A3 dst element type to be f16/f32/i16/i32/u16/u32"; + return success(); + }, + [&]() -> LogicalResult { + if (!isSupportedVecElemType(elemTy, /*allowBf16=*/true, + /*allowInt8=*/true)) + return emitOpError() + << "expects A5 dst element type to be f16/f32/bf16/i8/i16/i32/u8/u16/u32"; + return success(); + }); +} + +LogicalResult pto::TCmpOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + Type t0 = getSrc0().getType(); + Type t1 = getSrc1().getType(); + Type td = getDst().getType(); + if (failed(verifyVecTileStorage(*this, t0, "src0")) || + failed(verifyVecTileStorage(*this, t1, "src1")) || + failed(verifyVecTileStorage(*this, td, "dst"))) + return failure(); + + Type e0 = getElemTy(t0); + Type e1 = getElemTy(t1); + Type ed = getElemTy(td); + if (!e0 || !e1 || !ed) + return emitOpError("failed to get element type for src0/src1/dst"); + if (e0 != e1) + return emitOpError("expects src0 and src1 to have the same element type"); + if (!(e0.isInteger(32) || e0.isF16() || e0.isF32())) + return emitOpError("expects A2/A3 tcmp input element type to be i32/f16/f32"); + if (!ed.isInteger(8)) + return emitOpError("expects dst element type to be i8"); + + auto valid0 = getValidShapeVec(t0); + auto valid1 = getValidShapeVec(t1); + auto validd = getValidShapeVec(td); + if (valid0.size() != 2 || valid1.size() != 2 || validd.size() != 
2) + return emitOpError("expects src0, src1, and dst to have rank-2 valid_shape"); + if (!hasCompatibleKnownExtent(valid0[0], valid1[0])) + return emitOpError("expects src0 and src1 to have the same valid row"); + if (!hasCompatibleKnownExtent(valid0[1], valid1[1])) + return emitOpError("expects src0 and src1 to have the same valid column"); + if (!hasCompatibleKnownExtent(valid0[0], validd[0])) + return emitOpError("expects src0 valid row to equal dst valid row"); + return success(); + }; + + auto verifyA5 = [&]() -> LogicalResult { + Type t0 = getSrc0().getType(); + Type t1 = getSrc1().getType(); + Type td = getDst().getType(); + if (failed(verifyTileBufCommon(*this, t0, "src0")) || + failed(verifyTileBufCommon(*this, t1, "src1")) || + failed(verifyTileBufCommon(*this, td, "dst"))) + return failure(); + + Type e0 = getElemTy(t0); + Type e1 = getElemTy(t1); + Type ed = getElemTy(td); + if (!e0 || !e1 || !ed) + return emitOpError("failed to get element type for src0/src1/dst"); + if (e0 != e1) + return emitOpError("expects src0 and src1 to have the same element type"); + bool inputOk = e0.isF16() || e0.isF32() || e0.isBF16() || + e0.isInteger(8) || e0.isInteger(16) || e0.isInteger(32); + if (!inputOk) + return emitOpError("expects A5 tcmp input element type to be i8/i16/i32/f16/bf16/f32"); + if (auto it = dyn_cast(ed)) { + if (it.getWidth() != 8) + return emitOpError("expects dst element type to be i8"); + } else { + return emitOpError("expects dst element type to be i8"); + } + + if (getShapeVec(t0) != getShapeVec(t1) || getShapeVec(t0) != getShapeVec(td)) + return emitOpError("expects src0, src1, and dst to have the same shape"); + return success(); + }; + + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +// ---- TCMPS verify ---- +LogicalResult pto::TCmpSOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyVecTileStorage(*this, srcTy, 
"src")) || + failed(verifyVecTileStorage(*this, dstTy, "dst"))) + return failure(); + + Type elemTy = getElemTy(srcTy); + if (!(elemTy.isInteger(16) || elemTy.isInteger(32) || + elemTy.isF16() || elemTy.isF32())) + return emitOpError("expects A2/A3 tcmps input element type to be i16/i32/f16/f32"); + + auto scalarTy = getScalar().getType(); + if (!(scalarTy.isIntOrIndexOrFloat())) + return emitOpError("expects scalar to be integer, index, or float"); + + auto srcValid = getValidShapeVec(srcTy); + auto dstValid = getValidShapeVec(dstTy); + if (srcValid.size() != 2 || dstValid.size() != 2) + return emitOpError("expects src and dst to have rank-2 valid_shape"); + if (srcValid[0] != ShapedType::kDynamic && dstValid[0] != ShapedType::kDynamic && + srcValid[0] != dstValid[0]) + return emitOpError("expects src and dst to have the same valid_shape[0]"); + return success(); + }; + + auto verifyA5 = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyVecTileStorage(*this, srcTy, "src")) || + failed(verifyVecTileStorage(*this, dstTy, "dst"))) + return failure(); + + Type elemTy = getElemTy(srcTy); + if (!(elemTy.isInteger(8) || elemTy.isInteger(16) || elemTy.isInteger(32) || + elemTy.isF16() || elemTy.isF32())) + return emitOpError("expects A5 tcmps input element type to be i8/i16/i32/f16/f32"); + + auto scalarTy = getScalar().getType(); + if (!(scalarTy.isIntOrIndexOrFloat())) + return emitOpError("expects scalar to be integer, index, or float"); + + auto srcValid = getValidShapeVec(srcTy); + auto dstValid = getValidShapeVec(dstTy); + if (srcValid.size() != 2 || dstValid.size() != 2) + return emitOpError("expects src and dst to have rank-2 valid_shape"); + if (srcValid[0] != ShapedType::kDynamic && dstValid[0] != ShapedType::kDynamic && + srcValid[0] != dstValid[0]) + return emitOpError("expects src and dst to have the same valid_shape[0]"); + return success(); + }; + + return dispatchVerifierByArch(getOperation(), 
verifyA2A3, verifyA5); +} +LogicalResult pto::TColExpandOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyNDStyleVecTile(*this, srcTy, "src")) || + failed(verifyNDStyleVecTile(*this, dstTy, "dst"))) + return failure(); + if (getElemTy(srcTy) != getElemTy(dstTy)) + return emitOpError("expects src and dst to have the same element type"); + if (!isSupportedVecElemType(getElemTy(srcTy), /*allowBf16=*/true, + /*allowInt8=*/true)) + return emitOpError("expects tcolexpand element type to be supported"); + auto srcValid = getValidShapeVec(getSrc()); + auto dstValid = getValidShapeVec(getDst()); + if (srcValid.size() != 2 || dstValid.size() != 2) + return emitOpError("expects src and dst to have rank-2 valid_shape"); + if (srcValid[1] != ShapedType::kDynamic && dstValid[1] != ShapedType::kDynamic && + srcValid[1] != dstValid[1]) + return emitOpError("expects src and dst to have the same valid_shape[1]"); + return success(); +} +static LogicalResult verifyTColExpandBinaryLikeOp(Operation *op, Type t0, Type t1, + Type td, PTOArch targetArch, + StringRef opName, + bool allowIntegerTypes) { + if (!isPTOShapedLike(t0) || !isPTOShapedLike(t1) || !isPTOShapedLike(td)) + return op->emitOpError("expects src0/src1/dst to be PTO shaped-like types"); + + Type e0 = getElemTy(t0); + Type e1 = getElemTy(t1); + Type ed = getElemTy(td); + if (!e0 || !e1 || !ed) + return op->emitOpError("failed to get element type for src0/src1/dst"); + + auto isSupportedElem = [&](Type elemTy) { + if (elemTy.isF16() || elemTy.isF32()) + return true; + if (!allowIntegerTypes) + return false; + if (elemTy.isInteger(16) || elemTy.isInteger(32)) + return true; + return targetArch == PTOArch::A5 && elemTy.isInteger(8); + }; + if (!isSupportedElem(e0) || !isSupportedElem(e1) || !isSupportedElem(ed)) { + if (!allowIntegerTypes) + return op->emitOpError() << "expects " << opName + << " 
element type to be f16 or f32"; + if (targetArch == PTOArch::A5) + return op->emitOpError() << "expects A5 " << opName + << " element type to be i8/i16/i32/f16/f32"; + return op->emitOpError() << "expects A2/A3 " << opName + << " element type to be i16/i32/f16/f32"; + } + + if (getShapeVec(t0) != getShapeVec(td)) + return op->emitOpError("expects src0/dst to have same shape"); + if (failed(verifyTileBufSameValidShape(op, t0, td, "src0", "dst"))) + return failure(); + + if (auto src0TileTy = dyn_cast(t0)) { + if (src0TileTy.getBLayoutValueI32() != 0) + return op->emitOpError("expects src0 to use row-major layout"); + } + + if (auto src1TileTy = dyn_cast(t1)) { + if (src1TileTy.getBLayoutValueI32() != 0) + return op->emitOpError("expects src1 to use row-major layout"); + } + if (auto dstTileTy = dyn_cast(td)) { + if (dstTileTy.getBLayoutValueI32() != 0) + return op->emitOpError("expects dst to use row-major layout"); + } + + auto src1Valid = getValidShapeVec(t1); + auto dstValid = getValidShapeVec(td); + if (src1Valid.size() == 2 && dstValid.size() == 2 && + src1Valid[1] != ShapedType::kDynamic && dstValid[1] != ShapedType::kDynamic && + src1Valid[1] != dstValid[1]) + return op->emitOpError("expects src1 valid_shape[1] to equal dst valid_shape[1]"); + + return success(); +} +LogicalResult pto::TColExpandMulOp::verify() { + PTOArch arch = getTargetArch(getOperation()); + return verifyTColExpandBinaryLikeOp(getOperation(), getSrc0().getType(), + getSrc1().getType(), getDst().getType(), + arch, "tcolexpandmul", + /*allowIntegerTypes=*/true); +} +LogicalResult pto::TColExpandAddOp::verify() { + PTOArch arch = getTargetArch(getOperation()); + return verifyTColExpandBinaryLikeOp(getOperation(), getSrc0().getType(), + getSrc1().getType(), getDst().getType(), + arch, "tcolexpandadd", + /*allowIntegerTypes=*/true); +} +LogicalResult pto::TColExpandDivOp::verify() { + auto verifyByArch = [&](PTOArch targetArch) -> LogicalResult { + bool allowIntegerTypes = (targetArch == 
PTOArch::A5); + return verifyTColExpandBinaryLikeOp(getOperation(), getSrc0().getType(), + getSrc1().getType(), getDst().getType(), + targetArch, "tcolexpanddiv", + /*allowIntegerTypes=*/allowIntegerTypes); + }; + auto verifyA2A3 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A3); }; + auto verifyA5 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A5); }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} +LogicalResult pto::TColExpandSubOp::verify() { + PTOArch arch = getTargetArch(getOperation()); + return verifyTColExpandBinaryLikeOp(getOperation(), getSrc0().getType(), + getSrc1().getType(), getDst().getType(), + arch, "tcolexpandsub", + /*allowIntegerTypes=*/true); +} +LogicalResult pto::TColExpandExpdifOp::verify() { + PTOArch arch = getTargetArch(getOperation()); + return verifyTColExpandBinaryLikeOp(getOperation(), getSrc0().getType(), + getSrc1().getType(), getDst().getType(), + arch, "tcolexpandexpdif", + /*allowIntegerTypes=*/false); +} +LogicalResult pto::TColExpandMaxOp::verify() { + PTOArch arch = getTargetArch(getOperation()); + return verifyTColExpandBinaryLikeOp(getOperation(), getSrc0().getType(), + getSrc1().getType(), getDst().getType(), + arch, "tcolexpandmax", + /*allowIntegerTypes=*/true); +} +LogicalResult pto::TColExpandMinOp::verify() { + PTOArch arch = getTargetArch(getOperation()); + return verifyTColExpandBinaryLikeOp(getOperation(), getSrc0().getType(), + getSrc1().getType(), getDst().getType(), + arch, "tcolexpandmin", + /*allowIntegerTypes=*/true); +} +LogicalResult pto::TColMaxOp::verify() { + return verifyTColReductionOpWithArchDispatch( + getOperation(), getSrc().getType(), getDst().getType(), + /*requireNonZeroSrcOnA2A3=*/false, /*requireNonZeroSrcOnA5=*/true, + /*allowInt8OnA5=*/true, /*allowBf16OnA5=*/true, + "expects A2/A3 tcolmax element type to be f16/f32/i16/i32", + "expects A5 tcolmax element type to be i8/i16/i32/f16/bf16/f32"); +} + +LogicalResult pto::TColArgMaxOp::verify() { + 
auto verifyByArch = [&]() -> LogicalResult { + return verifyTColArgReductionOpCommon(*this, getSrc().getType(), + getTmp().getType(), getDst().getType()); + }; + + return dispatchVerifierByArch(getOperation(), verifyByArch, verifyByArch); +} + +LogicalResult pto::TColMinOp::verify() { + return verifyTColReductionOpWithArchDispatch( + getOperation(), getSrc().getType(), getDst().getType(), + /*requireNonZeroSrcOnA2A3=*/false, /*requireNonZeroSrcOnA5=*/true, + /*allowInt8OnA5=*/true, /*allowBf16OnA5=*/true, + "expects A2/A3 tcolmin element type to be f16/f32/i16/i32", + "expects A5 tcolmin element type to be i8/i16/i32/f16/bf16/f32"); +} + +LogicalResult pto::TColArgMinOp::verify() { + auto verifyByArch = [&]() -> LogicalResult { + return verifyTColArgReductionOpCommon(*this, getSrc().getType(), + getTmp().getType(), getDst().getType()); + }; + + return dispatchVerifierByArch(getOperation(), verifyByArch, verifyByArch); +} + + + +ParseResult mlir::pto::TColSumOp::parse(OpAsmParser &parser, OperationState &result) { + OpAsmParser::UnresolvedOperand src; + OpAsmParser::UnresolvedOperand tmp; + OpAsmParser::UnresolvedOperand dst; + Type srcTy, tmpTy, dstTy; + bool hasTmp = false; + + // Parse: ins(%src : type) or ins(%src, %tmp {isBinary = ...}: type, type) + if (parser.parseKeyword("ins") || parser.parseLParen() || parser.parseOperand(src)) + return failure(); + + // Check for optional tmp operand (format 2) + if (succeeded(parser.parseOptionalComma())) { + // Format 2: ins(%src, %tmp {isBinary = ...}: type, type) + if (parser.parseOperand(tmp)) + return failure(); + hasTmp = true; + + // Parse attributes (isBinary) + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + // Parse types: : type, type + if (parser.parseColonType(srcTy) || parser.parseComma() || parser.parseType(tmpTy)) + return failure(); + } else { + // Format 1: ins(%src : type) + if (parser.parseColonType(srcTy)) + return failure(); + } + + if (parser.parseRParen()) + return 
failure(); + + // Parse: outs(%dst : type) + if (parser.parseKeyword("outs") || parser.parseLParen() || + parser.parseOperand(dst) || parser.parseColonType(dstTy) || + parser.parseRParen()) + return failure(); + + // Parse any remaining attributes (for format 1) + if (!hasTmp) { + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + } + + // Resolve operands + if (parser.resolveOperand(src, srcTy, result.operands)) + return failure(); + + if (hasTmp) { + if (parser.resolveOperand(tmp, tmpTy, result.operands)) + return failure(); + } + + if (parser.resolveOperand(dst, dstTy, result.operands)) + return failure(); + + return success(); +} + +void mlir::pto::TColSumOp::print(OpAsmPrinter &p) { + if (getTmp()) { + // Format 2: ins(%src, %tmp {isBinary = ...}: type, type) outs(%dst : type) + p << " ins(" << getSrc() << ", " << getTmp(); + // Print isBinary attribute if present + SmallVector elidedAttrs; + if (!getIsBinaryAttr() || getIsBinaryAttr().getValue() == false) { + elidedAttrs.push_back("isBinary"); + } + p.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + p << " : " << getSrc().getType() << ", " << getTmp().getType() << ")"; + } else { + // Format 1: ins(%src : type) outs(%dst : type) + p << " ins(" << getSrc() << " : " << getSrc().getType() << ")"; + } + + p << " outs(" << getDst() << " : " << getDst().getType() << ")"; + + // Print remaining attributes for format 1 (excluding isBinary) + if (!getTmp()) { + SmallVector elidedAttrs = {"isBinary"}; + p.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + } +} + +LogicalResult pto::TColSumOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyNDStyleVecTile(*this, srcTy, "src")) || + failed(verifyNDStyleVecTile(*this, dstTy, "dst"))) + return failure(); + bool hasTmp = (bool)getTmp(); + bool hasIsBinary = (bool)getIsBinaryAttr(); + if (hasTmp != hasIsBinary) { + if (hasTmp) + return 
emitOpError("tmp operand requires isBinary attribute"); + return emitOpError("isBinary attribute requires tmp operand"); + } + if (getTmp()) { + Type tmpTy = getTmp().getType(); + if (failed(verifyNDStyleVecTile(*this, tmpTy, "tmp"))) + return failure(); + if (getElemTy(srcTy) != getElemTy(dstTy) || getElemTy(srcTy) != getElemTy(tmpTy)) + return emitOpError("expects src/tmp/dst element types to match"); + } + if (getElemTy(srcTy) != getElemTy(dstTy)) + return emitOpError("expects src/dst element types to match"); + if (failed(verifyColReductionValidRegion(*this, srcTy, dstTy, + /*requireNonZeroSrc=*/false))) + return failure(); + Type elem = getElemTy(srcTy); + if (!(elem.isF16() || elem.isF32() || elem.isInteger(16) || elem.isInteger(32))) + return emitOpError("expects A2/A3 tcolsum element type to be f16/f32/i16/i32"); + return success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyNDStyleVecTile(*this, srcTy, "src")) || + failed(verifyNDStyleVecTile(*this, dstTy, "dst"))) + return failure(); + bool hasTmp = (bool)getTmp(); + bool hasIsBinary = (bool)getIsBinaryAttr(); + if (hasTmp != hasIsBinary) { + if (hasTmp) + return emitOpError("tmp operand requires isBinary attribute"); + return emitOpError("isBinary attribute requires tmp operand"); + } + if (getTmp()) { + Type tmpTy = getTmp().getType(); + if (failed(verifyNDStyleVecTile(*this, tmpTy, "tmp"))) + return failure(); + if (getElemTy(srcTy) != getElemTy(dstTy) || getElemTy(srcTy) != getElemTy(tmpTy)) + return emitOpError("expects src/tmp/dst element types to match"); + } + if (getElemTy(srcTy) != getElemTy(dstTy)) + return emitOpError("expects src/dst element types to match"); + if (failed(verifyColReductionValidRegion(*this, srcTy, dstTy, + /*requireNonZeroSrc=*/true))) + return failure(); + Type elem = getElemTy(srcTy); + if (!(elem.isF16() || elem.isF32() || elem.isBF16() || elem.isInteger(8) || + 
elem.isInteger(16) || elem.isInteger(32))) + return emitOpError("expects A5 tcolsum element type to be i8/i16/i32/f16/bf16/f32"); + return success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +LogicalResult pto::TColProdOp::verify() { + return verifyTColReductionOpWithArchDispatch( + getOperation(), getSrc().getType(), getDst().getType(), + /*requireNonZeroSrcOnA2A3=*/false, /*requireNonZeroSrcOnA5=*/false, + /*allowInt8OnA5=*/false, /*allowBf16OnA5=*/true, + "expects A2/A3 tcolprod element type to be f16/f32/i16/i32", + "expects A5 tcolprod element type to be i16/ui16/i32/ui32/f16/bf16/f32"); +} + +llvm::LogicalResult mlir::pto::TCvtOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyTileBufCommon(*this, srcTy, "src", /*allowLowPrecision=*/true)) || + failed(verifyTileBufCommon(*this, dstTy, "dst", /*allowLowPrecision=*/true))) + return failure(); + if (failed(verifyTileBufSameLogicalExtent(*this, srcTy, dstTy, "src", "dst", + /*compareValidShape=*/false))) + return failure(); + if (failed(verifyTileBufSameLogicalExtent(*this, srcTy, dstTy, "src", "dst", + /*compareValidShape=*/true))) + return failure(); + Type srcElem = getElemTy(srcTy); + Type dstElem = getElemTy(dstTy); + auto verifyA2A3 = [&]() -> LogicalResult { + if (isPTOLowPrecisionType(srcElem) || isPTOLowPrecisionType(dstElem)) + return emitOpError("expects A2/A3 tcvt low-precision element types to be unsupported"); + return success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + if (!isA5SupportedTCvtPair(srcElem, dstElem)) + return emitOpError("expects A5 tcvt low-precision type pairs to match PTO-ISA support"); + return success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +llvm::LogicalResult mlir::pto::TRandomOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + return 
emitOpError("trandom is only supported for A5 targets"); + }; + auto verifyA5 = [&]() -> LogicalResult { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + + Type dstTy = getDst().getType(); + if (failed(verifyTileBufCommon(*this, dstTy, "dst"))) + return failure(); + if (!isRowMajorTileBuf(dstTy)) + return emitOpError("expects dst to use row-major layout"); + + Type elemTy = getElemTy(dstTy); + if (!elemTy.isInteger(32)) + return emitOpError("expects dst element type to be i32 or ui32"); + + auto checkWord = [&](Value v, StringRef name) -> LogicalResult { + auto ty = dyn_cast(v.getType()); + if (!ty || ty.getWidth() != 32) + return emitOpError() << "expects " << name << " to be i32/ui32"; + return success(); + }; + if (failed(checkWord(getKey0(), "key0")) || + failed(checkWord(getKey1(), "key1")) || + failed(checkWord(getCounter0(), "counter0")) || + failed(checkWord(getCounter1(), "counter1")) || + failed(checkWord(getCounter2(), "counter2")) || + failed(checkWord(getCounter3(), "counter3"))) + return failure(); + + int32_t rounds = getRounds(); + if (rounds != 7 && rounds != 10) + return emitOpError("expects rounds to be 7 or 10"); + + return success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +LogicalResult mlir::pto::TDivOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + FailureOr elemOr = verifyMatchingRowMajorBinaryTileOpCommon( + getOperation(), getSrc0().getType(), getSrc1().getType(), + getDst().getType()); + if (failed(elemOr)) + return failure(); + auto elem0 = *elemOr; + if (!(elem0.isF16() || elem0.isF32())) + return emitOpError("expects A2/A3 tdiv element type to be f16 or f32"); + return success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + FailureOr elemOr = verifyMatchingRowMajorBinaryTileOpCommon( + getOperation(), getSrc0().getType(), getSrc1().getType(), + getDst().getType()); + if (failed(elemOr)) + return failure(); + auto elem0 = *elemOr; + if 
(!(elem0.isF16() || elem0.isF32() || elem0.isInteger(16) || elem0.isInteger(32))) + return emitOpError("expects A5 tdiv element type to be i32/i16/f16/f32"); + return success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TDivSOp::verify() { + auto isTileLike = [](Type ty) -> bool { + return isa(ty); + }; + auto isScalarLike = [](Type ty) -> bool { + return mlir::isa(ty); + }; + + auto verifyByArch = [&](PTOArch targetArch) -> LogicalResult { + Type srcTy = getSrc().getType(); + Type rhsTy = getScalar().getType(); + Type dstTy = getDst().getType(); + + bool srcTile = isTileLike(srcTy); + bool rhsTile = isTileLike(rhsTy); + bool srcScalar = isScalarLike(srcTy); + bool rhsScalar = isScalarLike(rhsTy); + + if (!(srcTile && rhsScalar) && !(srcScalar && rhsTile)) + return emitOpError("expects one tile-like operand and one scalar operand in ins(...)"); + + Type tileTy = srcTile ? srcTy : rhsTy; + Type scalarTy = srcTile ? rhsTy : srcTy; + + if (failed(verifyScalarTileOp(*this, tileTy, dstTy, "src", "dst", + /*requireValidRowsEqual=*/true, + /*requireValidColsEqual=*/true))) + return failure(); + if (!mlir::isa(scalarTy)) + return emitOpError("scalar must be a scalar type (integer/float)"); + Type elem = getElemTy(tileTy); + if (targetArch == PTOArch::A3 && + !(elem.isInteger(32) || elem.isInteger(16) || elem.isF16() || + elem.isF32())) + return emitOpError("expects A2/A3 tdivs element type to be i32/i16/f16/f32"); + if (targetArch == PTOArch::A5 && + !(elem.isInteger(32) || elem.isInteger(16) || elem.isInteger(8) || + elem.isF16() || elem.isF32())) + return emitOpError("expects A5 tdivs element type to be i32/i16/i8/f16/f32"); + return success(); + }; + auto verifyA2A3 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A3); }; + auto verifyA5 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A5); }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + 
+mlir::LogicalResult mlir::pto::TExpOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyVecTileUnaryOp(*this, srcTy, dstTy, "src", "dst", + /*allowBf16=*/false, /*allowInt8=*/false))) + return failure(); + if (failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) + return failure(); + Type srcElem = getElemTy(srcTy); + if (!srcElem.isF16() && !srcElem.isF32()) + return emitOpError("expects element type to be f16 or f32"); + return mlir::success(); + }; + auto verifyA5 = [&]() -> LogicalResult { return verifyA2A3(); }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TExpandsOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + Type dstTy = getDst().getType(); + if (failed(verifyTileBufCommon(*this, dstTy, "dst"))) + return failure(); + auto dstSpace = getPTOMemorySpaceEnum(dstTy); + if (!dstSpace || (*dstSpace != pto::AddressSpace::VEC && + *dstSpace != pto::AddressSpace::MAT)) + return emitOpError("expects dst to be in the vec or mat address space"); + Type dstElem = getElemTy(dstTy); + Type scalarTy = getScalar().getType(); + if (scalarTy != dstElem) + return emitOpError("expects scalar type == dst element type"); + if (*dstSpace == pto::AddressSpace::VEC && !isRowMajorTileBuf(dstTy)) + return emitOpError("expects vec dst to use row-major layout on A2/A3"); + if (dstElem.isF16() || dstElem.isBF16() || dstElem.isF32()) + return mlir::success(); + if (auto it = mlir::dyn_cast(dstElem)) { + unsigned w = it.getWidth(); + if (w == 16 || w == 32) + return mlir::success(); + } + return emitOpError("expects A2/A3 texpands dst element type to be i16/i32/f16/bf16/f32"); + }; + auto verifyA5 = [&]() -> LogicalResult { + Type dstTy = getDst().getType(); + if (failed(verifyTileBufCommon(*this, dstTy, "dst"))) + return failure(); + auto dstSpace = getPTOMemorySpaceEnum(dstTy); + if (!dstSpace || 
(*dstSpace != pto::AddressSpace::VEC && + *dstSpace != pto::AddressSpace::MAT)) + return emitOpError("expects dst to be in the vec or mat address space"); + Type dstElem = getElemTy(dstTy); + Type scalarTy = getScalar().getType(); + if (scalarTy != dstElem) + return emitOpError("expects scalar type == dst element type"); + if (dstElem.isF16() || dstElem.isBF16() || dstElem.isF32()) + return mlir::success(); + if (auto it = mlir::dyn_cast(dstElem)) { + unsigned w = it.getWidth(); + if (w == 8 || w == 16 || w == 32) + return mlir::success(); + } + return emitOpError("expects A5 texpands dst element type to be i8/i16/i32/f16/bf16/f32"); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TExtractOp::verify() { + auto hasMatExtractSourceLayoutA2A3 = [&](pto::TileBufType srcTy) -> bool { + int32_t bl = srcTy.getBLayoutValueI32(); + int32_t sl = srcTy.getSLayoutValueI32(); + return bl == static_cast(pto::BLayout::RowMajor) || + (bl != static_cast(pto::BLayout::RowMajor) && + sl == static_cast(pto::SLayout::RowMajor)); + }; + auto hasMatExtractSourceLayoutA5 = [&](pto::TileBufType srcTy, + pto::AddressSpace dstSpace) -> bool { + int32_t bl = srcTy.getBLayoutValueI32(); + int32_t sl = srcTy.getSLayoutValueI32(); + if (dstSpace == pto::AddressSpace::LEFT) { + return (bl == static_cast(pto::BLayout::RowMajor) && + sl == static_cast(pto::SLayout::ColMajor)) || + (bl != static_cast(pto::BLayout::RowMajor) && + sl == static_cast(pto::SLayout::RowMajor)) || + bl == static_cast(pto::BLayout::RowMajor); + } + return (bl == static_cast(pto::BLayout::RowMajor) && + sl == static_cast(pto::SLayout::ColMajor)) || + (bl != static_cast(pto::BLayout::RowMajor) && + sl == static_cast(pto::SLayout::RowMajor)); + }; + auto isA2A3ExtractElemType = [&](Type ty) -> bool { + return ty.isInteger(8) || ty.isF16() || ty.isBF16() || ty.isF32(); + }; + auto isA5ExtractElemType = [&](Type ty) -> bool { + if (auto it = dyn_cast(ty)) + return 
it.getWidth() == 8; + if (auto ft = dyn_cast(ty)) + return ft.getWidth() == 8 || ft.isF16() || ft.isBF16() || ft.isF32(); + return false; + }; + auto isRowMajorNoneBoxND = [&](pto::TileBufType ty) -> bool { + return ty.getBLayoutValueI32() == static_cast(pto::BLayout::RowMajor) && + ty.getSLayoutValueI32() == static_cast(pto::SLayout::NoneBox); + }; + auto verifyCommon = [&]() -> FailureOr, + std::optional>> { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + auto srcTb = dyn_cast(srcTy); + auto dstTb = dyn_cast(dstTy); + if (!srcTb || !dstTb) + return emitOpError("expects src and dst to be !pto.tile_buf"); + if (failed(verifyTileBufCommon(*this, srcTy, "src")) || + failed(verifyTileBufCommon(*this, dstTy, "dst")) || + failed(verifyNonNegativeIndexRowCol( + *getOperation(), getIndexRow(), getIndexCol(), + /*includeIndexAndIntOpsInConstFold=*/false)) || + failed(verifyExtractStaticBoundsCommon( + *getOperation(), getIndexRow(), getIndexCol(), srcTy, dstTy, + /*includeIndexAndIntOpsInConstFold=*/false))) + return failure(); + Type srcElem = getElemTy(srcTy); + Type dstElem = getElemTy(dstTy); + if (!srcElem || !dstElem || srcElem != dstElem) + return emitOpError("expects src and dst to have the same element type"); + auto srcSpace = getPTOMemorySpaceEnum(srcTy); + auto dstSpace = getPTOMemorySpaceEnum(dstTy); + return std::make_tuple(srcTy, dstTy, srcTb, dstTb, srcElem, dstElem, + srcSpace, dstSpace); + }; + auto verifyA2A3 = [&]() -> LogicalResult { + auto common = verifyCommon(); + if (failed(common)) + return failure(); + auto [srcTy, dstTy, srcTb, dstTb, srcElem, dstElem, srcSpace, dstSpace] = + *common; + (void)srcTy; + (void)dstTy; + (void)srcElem; + if (!isA2A3ExtractElemType(dstElem)) + return emitOpError("expects A2/A3 textract element type to be i8/f16/bf16/f32"); + if (srcSpace && dstSpace && *srcSpace == pto::AddressSpace::VEC && + *dstSpace == pto::AddressSpace::VEC) + return mlir::success(); + if (!srcSpace || *srcSpace != 
pto::AddressSpace::MAT) + return emitOpError("expects A2/A3 textract src to use loc=mat or vec"); + if (!dstSpace || (*dstSpace != pto::AddressSpace::LEFT && + *dstSpace != pto::AddressSpace::RIGHT)) + return emitOpError("expects A2/A3 textract dst to use loc=left, loc=right, or loc=vec"); + if (!hasMatExtractSourceLayoutA2A3(srcTb)) + return emitOpError("expects A2/A3 textract src to use a supported mat blayout/slayout combination"); + if (*dstSpace == pto::AddressSpace::LEFT) { + if (dstTb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor) || + dstTb.getSLayoutValueI32() != static_cast(pto::SLayout::RowMajor)) + return emitOpError("expects A2/A3 left dst to use row_major blayout and row_major slayout"); + } else { + if (dstTb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor) || + dstTb.getSLayoutValueI32() != static_cast(pto::SLayout::ColMajor)) + return emitOpError("expects A2/A3 right dst to use row_major blayout and col_major slayout"); + } + return mlir::success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + auto common = verifyCommon(); + if (failed(common)) + return failure(); + auto [srcTy, dstTy, srcTb, dstTb, srcElem, dstElem, srcSpace, dstSpace] = + *common; + (void)srcTy; + (void)dstTy; + (void)srcElem; + if (!isA5ExtractElemType(dstElem)) + return emitOpError("expects A5 textract element type to be an fp8/f16/bf16/f32 or int8 family type"); + if (!srcSpace || !dstSpace) + return emitOpError("expects src and dst to have explicit loc"); + bool okPair = + (*srcSpace == pto::AddressSpace::MAT && + (*dstSpace == pto::AddressSpace::LEFT || + *dstSpace == pto::AddressSpace::RIGHT || + *dstSpace == pto::AddressSpace::SCALING)) || + (*srcSpace == pto::AddressSpace::VEC && + (*dstSpace == pto::AddressSpace::MAT || + *dstSpace == pto::AddressSpace::VEC)); + if (!okPair) + return emitOpError("expects A5 textract to use a supported src/dst loc pair"); + if (*srcSpace == pto::AddressSpace::MAT) { + if (!hasMatExtractSourceLayoutA5(srcTb, 
*dstSpace)) + return emitOpError("expects A5 textract src to use a supported mat blayout/slayout combination"); + if (*dstSpace == pto::AddressSpace::LEFT) { + if (dstTb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor) || + dstTb.getSLayoutValueI32() != static_cast(pto::SLayout::RowMajor)) + return emitOpError("expects A5 left dst to use col_major blayout and row_major slayout"); + } else if (*dstSpace == pto::AddressSpace::RIGHT) { + if (dstTb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor) || + dstTb.getSLayoutValueI32() != static_cast(pto::SLayout::ColMajor)) + return emitOpError("expects A5 right dst to use row_major blayout and col_major slayout"); + } + } else if (*srcSpace == pto::AddressSpace::VEC && + *dstSpace == pto::AddressSpace::VEC) { + if (!isRowMajorNoneBoxND(srcTb) || !isRowMajorNoneBoxND(dstTb)) + return emitOpError( + "expects A5 vec->vec textract src/dst to use ND layout " + "(blayout=row_major, slayout=none_box)"); + } + return success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} +mlir::LogicalResult mlir::pto::TInsertOp::verify() { + auto isColMajorRowMajorNZ = [&](pto::TileBufType ty) -> bool { + return ty.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor) && + ty.getSLayoutValueI32() == static_cast(pto::SLayout::RowMajor); + }; + auto isRowMajorNoneBoxND = [&](pto::TileBufType ty) -> bool { + return ty.getBLayoutValueI32() == static_cast(pto::BLayout::RowMajor) && + ty.getSLayoutValueI32() == static_cast(pto::SLayout::NoneBox); + }; + auto isA5SupportedVecElemType = [&](Type ty) -> bool { + if (auto it = dyn_cast(ty)) + return it.getWidth() == 8 || it.getWidth() == 32; + if (auto ft = dyn_cast(ty)) + return ft.getWidth() == 8 || ft.isF16() || ft.isBF16() || ft.isF32(); + return false; + }; + auto isA2A3VecInsertElemType = [&](Type ty) -> bool { + return ty.isInteger(8) || ty.isF16() || ty.isBF16() || ty.isF32(); + }; + auto verifyCommon = [&]() -> FailureOr, + 
std::optional>> { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + auto srcTb = dyn_cast(srcTy); + auto dstTb = dyn_cast(dstTy); + if (!srcTb || !dstTb) + return emitOpError("expects src and dst to be !pto.tile_buf"); + if (failed(verifyTileBufCommon(*this, srcTy, "src")) || + failed(verifyTileBufCommon(*this, dstTy, "dst")) || + failed(verifyNonNegativeIndexRowCol( + *getOperation(), getIndexRow(), getIndexCol(), + /*includeIndexAndIntOpsInConstFold=*/true)) || + failed(verifyInsertStaticBoundsCommon( + *getOperation(), getIndexRow(), getIndexCol(), srcTy, dstTy, + /*includeIndexAndIntOpsInConstFold=*/true))) + return failure(); + Type srcElem = getElemTy(srcTy); + Type dstElem = getElemTy(dstTy); + auto srcSpace = getPTOMemorySpaceEnum(srcTy); + auto dstSpace = getPTOMemorySpaceEnum(dstTy); + return std::make_tuple(srcTy, dstTy, srcTb, dstTb, srcElem, dstElem, + srcSpace, dstSpace); + }; + auto verifyA2A3 = [&]() -> LogicalResult { + auto common = verifyCommon(); + if (failed(common)) + return failure(); + auto [srcTy, dstTy, srcTb, dstTb, srcElem, dstElem, srcSpace, dstSpace] = + *common; + if (srcSpace && dstSpace && *srcSpace == pto::AddressSpace::VEC && + *dstSpace == pto::AddressSpace::VEC) { + if (srcElem != dstElem || !isA2A3VecInsertElemType(srcElem)) + return emitOpError( + "expects A2/A3 vec->vec tinsert src/dst to have same supported dtype " + "(i8/f16/bf16/f32)"); + return success(); + } + if (!srcSpace || !dstSpace || *srcSpace != pto::AddressSpace::ACC || + *dstSpace != pto::AddressSpace::MAT) + return emitOpError("expects A2/A3 tinsert to use acc->mat or vec->vec"); + + if (!isColMajorRowMajorNZ(srcTb)) + return emitOpError("expects A2/A3 tinsert src to use blayout=col_major and slayout=row_major"); + if (!isColMajorRowMajorNZ(dstTb)) + return emitOpError("expects A2/A3 tinsert dst to use blayout=col_major and slayout=row_major"); + if (dstTb.getSFractalSizeI32() != 512) + return emitOpError("expects A2/A3 tinsert dst fractal 
size to be 512"); + + if (!(srcElem.isF32() && (dstElem.isF16() || dstElem.isBF16()))) + return emitOpError("expects A2/A3 tinsert element types to be src=f32, dst=f16/bf16"); + return success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + auto common = verifyCommon(); + if (failed(common)) + return failure(); + auto [srcTy, dstTy, srcTb, dstTb, srcElem, dstElem, srcSpace, dstSpace] = + *common; + if (!srcSpace || !dstSpace) + return emitOpError("expects A5 tinsert src/dst to have explicit loc"); + + // A5 regular acc->mat path. + if (*srcSpace == pto::AddressSpace::ACC && *dstSpace == pto::AddressSpace::MAT) { + if (!isColMajorRowMajorNZ(srcTb)) + return emitOpError("expects A5 acc->mat tinsert src to use blayout=col_major and slayout=row_major"); + if (!isColMajorRowMajorNZ(dstTb)) + return emitOpError("expects A5 acc->mat tinsert dst to use blayout=col_major and slayout=row_major"); + bool okTypes = (srcElem.isF32() && + (dstElem.isF16() || dstElem.isBF16() || dstElem.isF32())) || + (srcElem.isInteger(32) && dstElem.isInteger(32)); + if (!okTypes) + return emitOpError( + "expects A5 acc->mat tinsert element types to be " + "(src=f32,dst=f16/bf16/f32) or (src=i32,dst=i32)"); + return success(); + } + + // A5 vec->mat path (ND/NZ modes in pto-isa). + if (*srcSpace == pto::AddressSpace::VEC && *dstSpace == pto::AddressSpace::MAT) { + if (!isColMajorRowMajorNZ(dstTb)) + return emitOpError("expects A5 vec->mat tinsert dst to use blayout=col_major and slayout=row_major"); + bool srcIsND = isRowMajorNoneBoxND(srcTb); + bool srcIsNZ = isColMajorRowMajorNZ(srcTb); + if (!srcIsND && !srcIsNZ) + return emitOpError( + "expects A5 vec->mat tinsert src to use ND(row_major/none_box) or NZ(col_major/row_major) layout"); + if (srcElem != dstElem || !isA5SupportedVecElemType(srcElem)) + return emitOpError( + "expects A5 vec->mat tinsert src/dst to have same supported dtype " + "(fp8/f16/bf16/f32/i8/i32)"); + return success(); + } + + // A5 vec->vec path (PR561 ND_VEC). 
+ if (*srcSpace == pto::AddressSpace::VEC && *dstSpace == pto::AddressSpace::VEC) { + if (!isRowMajorNoneBoxND(srcTb) || !isRowMajorNoneBoxND(dstTb)) + return emitOpError( + "expects A5 vec->vec tinsert src/dst to use ND layout " + "(blayout=row_major, slayout=none_box)"); + if (srcElem != dstElem || !isA5SupportedVecElemType(srcElem)) + return emitOpError( + "expects A5 vec->vec tinsert src/dst to have same supported dtype " + "(fp8/f16/bf16/f32/i8/i32)"); + return success(); + } + + return emitOpError( + "expects A5 tinsert to use a supported src/dst loc pair: " + "acc->mat, vec->mat, or vec->vec"); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +static bool isColMajorRowMajorNZTileBuf(pto::TileBufType ty) { + return ty.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor) && + ty.getSLayoutValueI32() == static_cast(pto::SLayout::RowMajor); +} + +static bool isA2A3VectorPreQuantTypePair(Type srcElem, Type dstElem) { + if (srcElem.isF32()) + return dstElem.isInteger(8); + if (srcElem.isInteger(32)) + return dstElem.isInteger(8) || dstElem.isF16() || dstElem.isInteger(16); + return false; +} + +static bool isA5Fp8LikeType(Type ty) { + if (auto ft = dyn_cast(ty)) + return ft.getWidth() == 8; + return false; +} + +static bool isA5MxInputType(Type ty) { + return isA5Fp8LikeType(ty); +} + +static LogicalResult verifyA5MxTypeTriple(Operation *op, Type lhsTy, Type rhsTy, + Type dstTy, StringRef lhsName, + StringRef rhsName, StringRef dstName) { + Type lhsElem = getElemTy(lhsTy); + Type rhsElem = getElemTy(rhsTy); + Type dstElem = getElemTy(dstTy); + + if (!isA5MxInputType(lhsElem) || !isA5MxInputType(rhsElem)) + return op->emitOpError() + << "expects A5 mx operands " << lhsName << " and " << rhsName + << " to use fp8 element types"; + + if (!dstElem.isF32()) + return op->emitOpError() + << "expects A5 mx result " << dstName << " to use f32 element type"; + + return success(); +} + +static bool isA5VectorPreQuantTypePair(Type 
srcElem, Type dstElem) { + if (srcElem.isF32()) + return dstElem.isInteger(8) || isA5Fp8LikeType(dstElem) || dstElem.isF16() || + dstElem.isBF16() || dstElem.isF32(); + if (srcElem.isInteger(32)) + return dstElem.isInteger(8) || dstElem.isF16() || dstElem.isBF16(); + return false; +} + +mlir::LogicalResult mlir::pto::TExtractFPOp::verify() { + auto verifyCommon = [&]() -> FailureOr> { + Type srcTy = getSrc().getType(); + Type fpTy = getFp().getType(); + Type dstTy = getDst().getType(); + auto srcTb = dyn_cast(srcTy); + auto fpTb = dyn_cast(fpTy); + auto dstTb = dyn_cast(dstTy); + if (!srcTb || !fpTb || !dstTb) + return emitOpError("expects src, fp, and dst to be !pto.tile_buf"); + if (failed(verifyTileBufCommon(*this, srcTy, "src")) || + failed(verifyTileBufCommon(*this, fpTy, "fp")) || + failed(verifyTileBufCommon(*this, dstTy, "dst")) || + failed(verifyNonNegativeIndexRowCol( + *getOperation(), getIndexRow(), getIndexCol(), + /*includeIndexAndIntOpsInConstFold=*/true)) || + failed(verifyExtractStaticBoundsCommon( + *getOperation(), getIndexRow(), getIndexCol(), srcTy, dstTy, + /*includeIndexAndIntOpsInConstFold=*/true))) + return failure(); + auto srcSpace = getPTOMemorySpaceEnum(srcTy); + auto fpSpace = getPTOMemorySpaceEnum(fpTy); + auto dstSpace = getPTOMemorySpaceEnum(dstTy); + if (!srcSpace || !fpSpace || !dstSpace) + return emitOpError("expects src, fp, and dst to have explicit loc"); + if (*srcSpace != pto::AddressSpace::ACC) + return emitOpError("expects src to use loc=acc"); + if (*fpSpace != pto::AddressSpace::SCALING) + return emitOpError("expects fp to use loc=scaling"); + if (*dstSpace != pto::AddressSpace::MAT) + return emitOpError("expects dst to use loc=mat"); + if (!isColMajorRowMajorNZTileBuf(srcTb)) + return emitOpError("expects src to use blayout=col_major and slayout=row_major"); + if (!isColMajorRowMajorNZTileBuf(dstTb)) + return emitOpError("expects dst to use blayout=col_major and slayout=row_major"); + return std::make_tuple(srcTy, fpTy, 
dstTy, srcTb, fpTb, dstTb, *srcSpace, + *fpSpace, *dstSpace); + }; + auto verifyA2A3 = [&]() -> LogicalResult { + auto common = verifyCommon(); + if (failed(common)) + return failure(); + auto [srcTy, fpTy, dstTy, srcTb, fpTb, dstTb, srcSpace, fpSpace, dstSpace] = + *common; + (void)fpTy; + (void)srcSpace; + (void)fpSpace; + (void)dstSpace; + if (dstTb.getSFractalSizeI32() != 512) + return emitOpError("expects dst fractal size to be 512"); + Type srcElem = getElemTy(srcTy); + Type dstElem = getElemTy(dstTy); + if (!isA2A3VectorPreQuantTypePair(srcElem, dstElem)) + return emitOpError( + "expects A2/A3 textract_fp element types to be (src=f32,dst=i8) " + "or (src=i32,dst=i8/f16/i16)"); + return success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + auto common = verifyCommon(); + if (failed(common)) + return failure(); + auto [srcTy, fpTy, dstTy, srcTb, fpTb, dstTb, srcSpace, fpSpace, dstSpace] = + *common; + (void)fpTy; + (void)srcTb; + (void)fpTb; + (void)dstTb; + (void)srcSpace; + (void)fpSpace; + (void)dstSpace; + Type srcElem = getElemTy(srcTy); + Type dstElem = getElemTy(dstTy); + if (!isA5VectorPreQuantTypePair(srcElem, dstElem)) + return emitOpError( + "expects A5 textract_fp element types to be (src=f32,dst=i8/fp8/f16/bf16/f32) " + "or (src=i32,dst=i8/f16/bf16)"); + return success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TInsertFPOp::verify() { + auto verifyCommon = [&]() -> FailureOr> { + Type srcTy = getSrc().getType(); + Type fpTy = getFp().getType(); + Type dstTy = getDst().getType(); + auto srcTb = dyn_cast(srcTy); + auto fpTb = dyn_cast(fpTy); + auto dstTb = dyn_cast(dstTy); + if (!srcTb || !fpTb || !dstTb) + return emitOpError("expects src, fp, and dst to be !pto.tile_buf"); + if (failed(verifyTileBufCommon(*this, srcTy, "src")) || + failed(verifyTileBufCommon(*this, fpTy, "fp")) || + failed(verifyTileBufCommon(*this, dstTy, "dst")) || + 
failed(verifyNonNegativeIndexRowCol( + *getOperation(), getIndexRow(), getIndexCol(), + /*includeIndexAndIntOpsInConstFold=*/true)) || + failed(verifyInsertStaticBoundsCommon( + *getOperation(), getIndexRow(), getIndexCol(), srcTy, dstTy, + /*includeIndexAndIntOpsInConstFold=*/true))) + return failure(); + auto srcSpace = getPTOMemorySpaceEnum(srcTy); + auto fpSpace = getPTOMemorySpaceEnum(fpTy); + auto dstSpace = getPTOMemorySpaceEnum(dstTy); + if (!srcSpace || !fpSpace || !dstSpace) + return emitOpError("expects src, fp, and dst to have explicit loc"); + if (*srcSpace != pto::AddressSpace::ACC) + return emitOpError("expects src to use loc=acc"); + if (*fpSpace != pto::AddressSpace::SCALING) + return emitOpError("expects fp to use loc=scaling"); + if (*dstSpace != pto::AddressSpace::MAT) + return emitOpError("expects dst to use loc=mat"); + if (!isColMajorRowMajorNZTileBuf(srcTb)) + return emitOpError("expects src to use blayout=col_major and slayout=row_major"); + if (!isColMajorRowMajorNZTileBuf(dstTb)) + return emitOpError("expects dst to use blayout=col_major and slayout=row_major"); + return std::make_tuple(srcTy, fpTy, dstTy, srcTb, fpTb, dstTb, *srcSpace, + *fpSpace, *dstSpace); + }; + auto verifyA2A3 = [&]() -> LogicalResult { + auto common = verifyCommon(); + if (failed(common)) + return failure(); + auto [srcTy, fpTy, dstTy, srcTb, fpTb, dstTb, srcSpace, fpSpace, dstSpace] = + *common; + (void)fpTy; + (void)srcTb; + (void)fpTb; + (void)srcSpace; + (void)fpSpace; + (void)dstSpace; + if (dstTb.getSFractalSizeI32() != 512) + return emitOpError("expects dst fractal size to be 512"); + Type srcElem = getElemTy(srcTy); + Type dstElem = getElemTy(dstTy); + if (!isA2A3VectorPreQuantTypePair(srcElem, dstElem)) + return emitOpError( + "expects A2/A3 tinsert_fp element types to be (src=f32,dst=i8) " + "or (src=i32,dst=i8/f16/i16)"); + return success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + auto common = verifyCommon(); + if (failed(common)) + return 
failure(); + auto [srcTy, fpTy, dstTy, srcTb, fpTb, dstTb, srcSpace, fpSpace, dstSpace] = + *common; + (void)fpTy; + (void)srcTb; + (void)fpTb; + (void)dstTb; + (void)srcSpace; + (void)fpSpace; + (void)dstSpace; + Type srcElem = getElemTy(srcTy); + Type dstElem = getElemTy(dstTy); + if (!isA5VectorPreQuantTypePair(srcElem, dstElem)) + return emitOpError( + "expects A5 tinsert_fp element types to be (src=f32,dst=i8/fp8/f16/bf16/f32) " + "or (src=i32,dst=i8/f16/bf16)"); + return success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +static mlir::LogicalResult verifyTFillPadLike(Operation *op, Type srcTy, Type dstTy, + bool allowDstExpand, + llvm::StringRef opName) { + if (!isPTOShapedLike(srcTy) || !isPTOShapedLike(dstTy)) + return op->emitError("expects src/dst to be PTO shaped-like types"); + + auto srcShape = getShapeVec(srcTy); + auto dstShape = getShapeVec(dstTy); + if (srcShape.size() != 2 || dstShape.size() != 2) + return op->emitError("expects rank-2 shaped types for src/dst"); + + auto srcElem = getElemTy(srcTy); + auto dstElem = getElemTy(dstTy); + + auto getElemBytes = [](mlir::Type t) -> int64_t { + unsigned elemBytes = getPTOStorageElemByteSize(t); + return elemBytes == 0 ? -1 : static_cast(elemBytes); + }; + + int64_t srcB = getElemBytes(srcElem); + int64_t dstB = getElemBytes(dstElem); + if (srcB < 0 || dstB < 0) + return op->emitError("unsupported element type (expects int/float element types)"); + if (srcB != dstB) + return op->emitError("expects sizeof(src element) == sizeof(dst element)"); + if (!(srcB == 1 || srcB == 2 || srcB == 4)) + return op->emitError("expects element size to be 1, 2, or 4 bytes"); + + // pto.tfillpad lowers to TFILLPAD(dst, src). For loc=mat, pto-isa only + // exposes the homogeneous overload, so src/dst must use the same Tile<...> + // specialization (including valid_shape and pad). 
+ // Note: tfillpad_expand is intentionally not covered here because its + // cross-layer ABI contract for loc=mat heterogeneous shape expansion is not + // finalized yet. + if (opName == "tfillpad") { + auto srcTb = mlir::dyn_cast(srcTy); + auto dstTb = mlir::dyn_cast(dstTy); + auto srcSpace = getPTOMemorySpaceEnum(srcTy); + auto dstSpace = getPTOMemorySpaceEnum(dstTy); + if (srcTb && dstTb && srcSpace && dstSpace && + *srcSpace == mlir::pto::AddressSpace::MAT && + *dstSpace == mlir::pto::AddressSpace::MAT && srcTb != dstTb) { + auto dimToStr = [](int64_t dim) -> std::string { + return dim == ShapedType::kDynamic ? "?" : std::to_string(dim); + }; + SmallVector mismatchFields; + auto srcValid = getValidShapeVec(srcTy); + auto dstValid = getValidShapeVec(dstTy); + if (srcValid.size() == 2 && dstValid.size() == 2) { + if (srcValid[0] != dstValid[0]) + mismatchFields.push_back("v_row (" + dimToStr(srcValid[0]) + " vs " + + dimToStr(dstValid[0]) + ")"); + if (srcValid[1] != dstValid[1]) + mismatchFields.push_back("v_col (" + dimToStr(srcValid[1]) + " vs " + + dimToStr(dstValid[1]) + ")"); + } + if (srcTb.getPadValueI32() != dstTb.getPadValueI32()) + mismatchFields.push_back("pad (" + std::to_string(srcTb.getPadValueI32()) + + " vs " + std::to_string(dstTb.getPadValueI32()) + + ")"); + + auto diag = op->emitError() + << "expects src/dst tile types to be lowerable to TFILLPAD " + "for loc=mat"; + if (!mismatchFields.empty()) + diag << "; mismatching fields: " << llvm::join(mismatchFields, ", "); + diag << "\n src: " << srcTy; + diag << "\n dst: " << dstTy; + diag << "\n note: heterogeneous TFILLPAD overload is only available for loc=vec"; + return failure(); + } + } + + if (auto dstTileTy = mlir::dyn_cast(dstTy)) { + auto padAttr = mlir::dyn_cast(dstTileTy.getPadValueAttr()); + if (!padAttr || padAttr.getValue() == mlir::pto::PadValue::Null) + return op->emitError() << "expects dst PadVal != Null for " << opName; + } + + if (!allowDstExpand) { + if (srcShape != dstShape) 
+ return op->emitError() + << "expects src and dst to have the same static shape for " << opName; + return mlir::success(); + } + + if (srcShape[0] > dstShape[0] || srcShape[1] > dstShape[1]) { + return op->emitError() + << "expects dst static shape to be >= src static shape for " << opName; + } + + return mlir::success(); +} + +mlir::LogicalResult mlir::pto::TFillPadOp::verify() { + return verifyTFillPadLike(getOperation(), getSrc().getType(), getDst().getType(), + /*allowDstExpand=*/false, "tfillpad"); +} + +mlir::LogicalResult mlir::pto::TFillPadExpandOp::verify() { + return verifyTFillPadLike(getOperation(), getSrc().getType(), getDst().getType(), + /*allowDstExpand=*/true, "tfillpad_expand"); +} + +mlir::LogicalResult mlir::pto::TFillPadInplaceOp::verify() { + return verifyTFillPadLike(getOperation(), getSrc().getType(), getDst().getType(), + /*allowDstExpand=*/false, "tfillpad_inplace"); +} + + +llvm::LogicalResult mlir::pto::TGatherOp::verify() { + auto isSupportedGatherElemTypeA5Index = [&](Type ty) -> bool { + if (ty.isF16() || ty.isF32()) + return true; + if (auto it = dyn_cast(ty)) { + unsigned width = it.getWidth(); + return width == 8 || width == 16 || width == 32; + } + return false; + }; + + auto verifyMaskForm = [&](bool allowA5MaskTypes) -> LogicalResult { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyTileBufCommon(*this, srcTy, "src")) || + failed(verifyTileBufCommon(*this, dstTy, "dst"))) + return failure(); + + Type srcElem = getElemTy(srcTy); + Type dstElem = getElemTy(dstTy); + if (!srcElem || !dstElem) + return emitOpError("failed to get element type for src/dst"); + if (!isRowMajorTileBuf(srcTy) || !isRowMajorTileBuf(dstTy)) + return emitOpError("expects src and dst to use row-major layout"); + auto srcSpace = getPTOMemorySpaceEnum(srcTy); + auto dstSpace = getPTOMemorySpaceEnum(dstTy); + if (!srcSpace || !dstSpace || *srcSpace != pto::AddressSpace::VEC || + *dstSpace != pto::AddressSpace::VEC) + 
return emitOpError("expects src and dst to be in the vec address space"); + unsigned srcElemBytes = getPTOStorageElemByteSize(srcElem); + unsigned dstElemBytes = getPTOStorageElemByteSize(dstElem); + if (srcElemBytes == 0 || dstElemBytes == 0) + return emitOpError("failed to get element size for src/dst"); + if (srcElemBytes != dstElemBytes) + return emitOpError("expects src and dst element sizes to match"); + + auto dstValid = getValidShapeVec(dstTy); + auto dstShape = getShapeVec(dstTy); + if (dstValid.size() == 2 && dstShape.size() == 2 && + dstValid[1] != ShapedType::kDynamic && dstShape[1] != ShapedType::kDynamic && + dstValid[1] != dstShape[1]) { + return emitOpError("expects dst valid_shape[1] to equal dst cols"); + } + + if (allowA5MaskTypes) { + if (!(srcElemBytes == 1 || srcElemBytes == 2 || srcElemBytes == 4)) + return emitOpError("expects A5 mask-pattern gather element size to be 1, 2, or 4 bytes"); + if (!isSupportedGatherElemTypeA5(srcElem) || !isSupportedGatherElemTypeA5(dstElem)) + return emitOpError( + "expects A5 mask-pattern gather src/dst element type to be i8/i16/i32/f16/bf16/f32/fp8-like"); + } else { + if (!(srcElemBytes == 2 || srcElemBytes == 4)) + return emitOpError("expects A2/A3 mask-pattern gather element size to be 2 or 4 bytes"); + } + return success(); + }; + + auto verifyIndexForm = [&](bool allow16BitIndices, bool allowA5ElemTypes) -> LogicalResult { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + Type idxTy = getIndices().getType(); + Type tmpTy = getTmp().getType(); + if (failed(verifyTileBufCommon(*this, srcTy, "src")) || + failed(verifyTileBufCommon(*this, dstTy, "dst")) || + failed(verifyTileBufCommon(*this, idxTy, "indices")) || + failed(verifyTileBufCommon(*this, tmpTy, "tmp"))) + return failure(); + + Type srcElem = getElemTy(srcTy); + Type dstElem = getElemTy(dstTy); + if (!srcElem || !dstElem) + return emitOpError("failed to get element type for src/dst"); + if (srcElem != dstElem) + return 
emitOpError("expects src and dst to have the same element type"); + if (allowA5ElemTypes) { + if (!isSupportedGatherElemTypeA5Index(srcElem) || + !isSupportedGatherElemTypeA5Index(dstElem)) + return emitOpError( + "expects A5 gather src/dst element type to be i8/i16/i32/f16/f32"); + } else if (!isSupportedGatherElemTypeA2A3(srcElem) || + !isSupportedGatherElemTypeA2A3(dstElem)) { + return emitOpError("expects gather src/dst element type to be i16/i32/f16/f32"); + } + + auto idxElem = dyn_cast(getElemTy(idxTy)); + if (!idxElem) + return emitOpError("indices element type must be integer"); + unsigned width = idxElem.getWidth(); + if (!(width == 32 || (allow16BitIndices && width == 16))) { + return emitOpError() << "expects indices element type to be i32" + << (allow16BitIndices ? " or i16" : ""); + } + + auto dstValid = getValidShapeVec(dstTy); + auto dstShape = getShapeVec(dstTy); + if (dstValid.size() == 2 && dstShape.size() == 2 && + dstValid[1] != ShapedType::kDynamic && dstShape[1] != ShapedType::kDynamic && + dstValid[1] != dstShape[1]) { + return emitOpError("expects dst valid_shape[1] to equal dst cols"); + } + + auto idxValid = getValidShapeVec(idxTy); + auto idxShape = getShapeVec(idxTy); + if (idxValid.size() == 2 && idxShape.size() == 2 && + idxValid[1] != ShapedType::kDynamic && idxShape[1] != ShapedType::kDynamic && + idxValid[1] != idxShape[1]) { + return emitOpError("expects indices valid_shape[1] to equal indices cols"); + } + + if (!allowA5ElemTypes) { + Type tmpElem = getElemTy(tmpTy); + if (tmpElem != idxElem) + return emitOpError("expects tmp and indices to have the same element type"); + if (failed(verifyTileBufSameValidShape(*this, idxTy, tmpTy, "indices", "tmp"))) + return failure(); + } + return success(); + }; + + auto verifyCompareForm = [&](bool allowA5SrcTypes) -> LogicalResult { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + Type cdstTy = getCdst().getType(); + Type tmpTy = getTmp().getType(); + if 
(failed(verifyTileBufCommon(*this, srcTy, "src")) || + failed(verifyTileBufCommon(*this, dstTy, "dst")) || + failed(verifyTileBufCommon(*this, cdstTy, "cdst")) || + failed(verifyTileBufCommon(*this, tmpTy, "tmp"))) + return failure(); + + Type srcElem = getElemTy(srcTy); + Type dstElem = getElemTy(dstTy); + Type cdstElem = getElemTy(cdstTy); + if (!srcElem || !dstElem || !cdstElem) + return emitOpError("failed to get element type for src/dst/cdst"); + auto dstInt = dyn_cast(dstElem); + if (!dstInt || dstInt.getWidth() != 32) + return emitOpError("expects dst element type to be i32"); + if (cdstElem != dstElem) + return emitOpError("expects cdst to have the same element type as dst"); + if (getKValue().getType() != srcElem) + return emitOpError("expects kValue to have the same type as src element type"); + + auto cmpAttr = getCmpModeAttr(); + auto cmpMode = cmpAttr ? cmpAttr.getValue() : pto::CmpMode::EQ; + if (cmpMode != pto::CmpMode::EQ && cmpMode != pto::CmpMode::GT) + return emitOpError("expects compare-form tgather cmpMode to be eq or gt"); + + if (allowA5SrcTypes) { + if (!(srcElem.isF16() || srcElem.isF32() || srcElem.isInteger(16) || + srcElem.isInteger(32))) { + return emitOpError( + "expects A5 compare-form tgather src element type to be i16/i32/f16/f32"); + } + } else { + if (!(srcElem.isF16() || srcElem.isF32() || + (srcElem.isInteger(32) && cmpMode == pto::CmpMode::EQ))) { + return emitOpError( + "expects A2/A3 compare-form tgather src element type to be f16/f32, or i32 when cmpMode=eq"); + } + } + + if (failed(verifyVecTileCommonA2A3(*this, srcTy, "src")) || + failed(verifyVecTileCommonA2A3(*this, dstTy, "dst")) || + failed(verifyVecTileCommonA2A3(*this, cdstTy, "cdst")) || + failed(verifyVecTileCommonA2A3(*this, tmpTy, "tmp"))) + return failure(); + return success(); + }; + + auto verifyA2A3 = [&]() -> LogicalResult { + if (getMaskPatternAttr()) { + if (getCdst() || getIndices() || getTmp() || getKValue()) + return emitOpError("mask-pattern tgather 
only allows src and dst operands"); + return verifyMaskForm(/*allowA5MaskTypes=*/false); + } + if (getCdst() || getKValue()) { + if (!getCdst() || !getKValue() || !getTmp()) + return emitOpError("compare-form tgather expects dst, cdst, kValue, and tmp"); + if (getIndices()) + return emitOpError("compare-form tgather does not take indices"); + return verifyCompareForm(/*allowA5SrcTypes=*/false); + } + if (!getIndices() || !getTmp()) + return emitOpError("index-form tgather expects both indices and tmp"); + return verifyIndexForm(/*allow16BitIndices=*/false, /*allowA5ElemTypes=*/false); + }; + + auto verifyA5 = [&]() -> LogicalResult { + if (getMaskPatternAttr()) { + if (getCdst() || getIndices() || getTmp() || getKValue()) + return emitOpError("mask-pattern tgather only allows src and dst operands"); + return verifyMaskForm(/*allowA5MaskTypes=*/true); + } + if (getCdst() || getKValue()) { + if (!getCdst() || !getKValue() || !getTmp()) + return emitOpError("compare-form tgather expects dst, cdst, kValue, and tmp"); + if (getIndices()) + return emitOpError("compare-form tgather does not take indices"); + return verifyCompareForm(/*allowA5SrcTypes=*/true); + } + if (!getIndices() || !getTmp()) + return emitOpError("index-form tgather expects both indices and tmp"); + return verifyIndexForm(/*allow16BitIndices=*/true, /*allowA5ElemTypes=*/true); + }; + + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} +mlir::LogicalResult mlir::pto::TGatherBOp::verify() { + auto verifyCommon = [&]() -> FailureOr> { + Type srcTy = getSrc().getType(); + Type offTy = getOffsets().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyTileBufCommon(*this, srcTy, "src")) || + failed(verifyTileBufCommon(*this, offTy, "offsets")) || + failed(verifyTileBufCommon(*this, dstTy, "dst"))) + return failure(); + auto srcElemTy = getElemTy(srcTy); + auto dstElemTy = getElemTy(dstTy); + if (!srcElemTy || !dstElemTy) + return emitOpError() << "failed to get element type 
for src/dst"; + return std::make_pair(srcElemTy, dstElemTy); + }; + + auto getElemBytes = [](Type ty) -> std::optional { + unsigned elemBytes = getPTOStorageElemByteSize(ty); + if (elemBytes == 0) + return std::nullopt; + return elemBytes; + }; + + auto verifyA2A3 = [&]() -> LogicalResult { + FailureOr> elems = verifyCommon(); + if (failed(elems)) + return failure(); + Type dstTy = getDst().getType(); + Type dstElemTy = elems->second; + if (!isRowMajorTileBuf(dstTy)) + return emitOpError() << "expects dst to use row-major layout"; + auto dstBytes = getElemBytes(dstElemTy); + if (!dstBytes || (*dstBytes != 1 && *dstBytes != 2 && *dstBytes != 4)) + return emitOpError() << "expects dst element size to be 1, 2, or 4 bytes"; + return mlir::success(); + }; + + auto verifyA5 = [&]() -> LogicalResult { + FailureOr> elems = verifyCommon(); + if (failed(elems)) + return failure(); + Type dstElemTy = elems->second; + auto dstBytes = getElemBytes(dstElemTy); + if (!dstBytes || (*dstBytes != 1 && *dstBytes != 2 && *dstBytes != 4)) + return emitOpError() << "expects dst element size to be 1, 2, or 4 bytes"; + return mlir::success(); + }; + + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TLogOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyVecTileUnaryOp(*this, srcTy, dstTy, "src", "dst", + /*allowBf16=*/false, /*allowInt8=*/false))) + return failure(); + if (failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) + return failure(); + auto elemTy = getElemTy(srcTy); + if (!(elemTy.isF16() || elemTy.isF32())) + return emitOpError() << "expects element type to be f16 or f32"; + return mlir::success(); +} + +mlir::LogicalResult mlir::pto::TLReluOp::verify() { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + auto verifyA2A3 = [&]() -> LogicalResult { + if 
(failed(verifyVecTileStorage(*this, srcTy, "src")) || + failed(verifyVecTileStorage(*this, dstTy, "dst"))) + return failure(); + if (failed(verifyTileBufSameElemType(*this, srcTy, dstTy, "src", "dst")) || + failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) + return failure(); + auto valid = getValidShapeVec(srcTy); + if (valid.size() != 2) + return emitOpError("expects src to have rank-2 valid_shape"); + if (valid[0] != ShapedType::kDynamic && valid[0] <= 0) + return emitOpError("expects src valid_shape[0] to be positive"); + if (valid[1] != ShapedType::kDynamic && valid[1] <= 0) + return emitOpError("expects src valid_shape[1] to be positive"); + Type elemTy = getElemTy(srcTy); + if (!(elemTy.isF16() || elemTy.isF32())) + return emitOpError() << "expects A2/A3 tlrelu element type to be f16 or f32"; + return success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + if (failed(verifyVecTileStorage(*this, srcTy, "src")) || + failed(verifyVecTileStorage(*this, dstTy, "dst"))) + return failure(); + if (failed(verifyTileBufSameElemType(*this, srcTy, dstTy, "src", "dst")) || + failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) + return failure(); + Type elemTy = getElemTy(srcTy); + if (!(elemTy.isF16() || elemTy.isF32())) + return emitOpError() << "expects A5 tlrelu element type to be f16 or f32"; + if (!getSlope().getType().isF32()) + return emitOpError() << "expects slope to have type f32"; + return success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TMaxOp::verify() { + return verifyArithmeticBinaryTileOpWithArchDispatch( + getOperation(), getSrc0().getType(), getSrc1().getType(), getDst().getType(), + /*allowInt8OnA5=*/true, /*allowBf16OnA5=*/false, + "expects A2/A3 tmax element type to be i32/i16/f16/f32", + "expects A5 tmax element type to be i32/i16/i8/f16/f32"); +} + +mlir::LogicalResult mlir::pto::TMaxSOp::verify() { + return 
verifyArithmeticScalarTileOpWithArchDispatch( + getOperation(), getSrc().getType(), getDst().getType(), getScalar().getType(), + /*allowInt8OnA5=*/true, /*allowBf16OnA5=*/true, + "expects A2/A3 tmaxs element type to be i32/i16/f16/f32", + "expects A5 tmaxs element type to be i32/i16/i8/f16/bf16/f32", + /*requireValidRowsEqualOnA2A3=*/true, + /*requireValidRowsEqualOnA5=*/true); +} + +mlir::LogicalResult mlir::pto::TMinOp::verify() { + return verifyArithmeticBinaryTileOpWithArchDispatch( + getOperation(), getSrc0().getType(), getSrc1().getType(), getDst().getType(), + /*allowInt8OnA5=*/true, /*allowBf16OnA5=*/true, + "expects A2/A3 tmin element type to be i32/i16/f16/f32", + "expects A5 tmin element type to be i32/i16/i8/f16/bf16/f32"); +} + +mlir::LogicalResult mlir::pto::TMinSOp::verify() { + return verifyArithmeticScalarTileOpWithArchDispatch( + getOperation(), getSrc().getType(), getDst().getType(), getScalar().getType(), + /*allowInt8OnA5=*/true, /*allowBf16OnA5=*/true, + "expects A2/A3 tmins element type to be i32/i16/f16/f32", + "expects A5 tmins element type to be i32/i16/i8/f16/bf16/f32", + /*requireValidRowsEqualOnA2A3=*/true, + /*requireValidRowsEqualOnA5=*/false); +} + +mlir::LogicalResult mlir::pto::TMovOp::verify() { + auto verifyImpl = [&](bool isA5) -> LogicalResult { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + Value fp = getFp(); + Value preQuantScalar = getPreQuantScalar(); + auto accToVecModeAttr = getAccToVecModeAttr(); + auto reluMode = getReluPreMode(); + const bool hasFp = static_cast(fp); + const bool hasPreQuantScalar = static_cast(preQuantScalar); + + if (failed(verifyTileBufCommon(*this, srcTy, "src")) || + failed(verifyTileBufCommon(*this, dstTy, "dst"))) + return failure(); + if (hasFp && failed(verifyTileBufCommon(*this, fp.getType(), "fp"))) + return failure(); + if (hasFp && hasPreQuantScalar) + return emitOpError() << "expects fp and preQuantScalar forms to be mutually exclusive"; + + auto srcSpace = 
getPTOMemorySpaceEnum(srcTy); + auto dstSpace = getPTOMemorySpaceEnum(dstTy); + if (!srcSpace || !dstSpace) + return emitOpError() << "expects src and dst to have explicit address spaces"; + + auto srcShape = getShapeVec(srcTy); + auto dstShape = getShapeVec(dstTy); + if (*srcSpace == pto::AddressSpace::MAT && srcShape != dstShape) + return emitOpError() << "expects mat-source tmov to use matching src/dst shapes"; + if (!isA5 && *srcSpace != pto::AddressSpace::MAT && srcShape != dstShape) + return emitOpError() << "expects A2/A3 non-mat tmov to use matching src/dst shapes"; + + const bool isMatToTile = + *srcSpace == pto::AddressSpace::MAT && + (*dstSpace == pto::AddressSpace::LEFT || + *dstSpace == pto::AddressSpace::RIGHT || + *dstSpace == pto::AddressSpace::BIAS || + *dstSpace == pto::AddressSpace::SCALING); + const bool isVecToVec = + *srcSpace == pto::AddressSpace::VEC && + *dstSpace == pto::AddressSpace::VEC; + const bool isVecToMat = + *srcSpace == pto::AddressSpace::VEC && + *dstSpace == pto::AddressSpace::MAT; + const bool isAccToMat = + *srcSpace == pto::AddressSpace::ACC && + *dstSpace == pto::AddressSpace::MAT; + const bool isAccToVec = + *srcSpace == pto::AddressSpace::ACC && + *dstSpace == pto::AddressSpace::VEC; + + bool okPair = isMatToTile || isVecToVec || isAccToMat || isAccToVec; + if (isA5) + okPair = okPair || isVecToMat; + if (!okPair) + return emitOpError() + << "expects a supported tmov address-space pair for this target"; + + if (accToVecModeAttr && !isAccToVec) + return emitOpError() + << "expects accToVecMode to be used only for acc-to-vec tmov"; + + if (reluMode != pto::ReluPreMode::NoRelu && !(isAccToMat || isAccToVec)) + return emitOpError() + << "expects reluPreMode form to use loc=acc src"; + + if (hasPreQuantScalar && !(isAccToMat || isAccToVec)) + return emitOpError() + << "expects preQuantScalar form to use loc=acc src"; + + if (hasFp) { + auto fpTy = fp.getType(); + auto fpSpace = getPTOMemorySpaceEnum(fpTy); + if (!fpSpace || 
*fpSpace != pto::AddressSpace::SCALING) + return emitOpError() << "expects fp to be in the scaling address space"; + auto srcElemTy = getElemTy(srcTy); + auto srcIntTy = dyn_cast(srcElemTy); + if (!(srcElemTy.isF32() || (srcIntTy && srcIntTy.getWidth() == 32))) + return emitOpError() + << "expects fp form src to have element type f32, i32"; + if (!(isAccToMat || isAccToVec)) + return emitOpError() << "expects fp form to use loc=acc src"; + } + + if ((hasFp || hasPreQuantScalar) && accToVecModeAttr) { + switch (accToVecModeAttr.getValue()) { + case pto::AccToVecMode::SingleModeVec0: + case pto::AccToVecMode::SingleModeVec1: + break; + case pto::AccToVecMode::DualModeSplitM: + case pto::AccToVecMode::DualModeSplitN: + return emitOpError() + << "expects fp/preQuantScalar acc-to-vec forms to use single-mode accToVecMode"; + } + } + + auto srcTb = dyn_cast(srcTy); + auto dstTb = dyn_cast(dstTy); + if (srcTb && *srcSpace == pto::AddressSpace::ACC && + (hasFp || reluMode != pto::ReluPreMode::NoRelu)) { + if (srcTb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor) || + srcTb.getSLayoutValueI32() != static_cast(pto::SLayout::RowMajor)) + return emitOpError() + << "expects acc-source fp/relu tmov src to use blayout=col_major and slayout=row_major"; + } + if (srcTb && dstTb && isAccToMat && !isA5 && + dstTb.getSFractalSizeI32() != 512) + return emitOpError() << "expects A2/A3 acc-to-mat tmov destination fractal to be 512"; + + return success(); + }; + auto verifyA2A3 = [&]() -> LogicalResult { return verifyImpl(/*isA5=*/false); }; + auto verifyA5 = [&]() -> LogicalResult { return verifyImpl(/*isA5=*/true); }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TMovFPOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type fpTy = getFp().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyTileBufCommon(*this, srcTy, "src")) || + 
failed(verifyTileBufCommon(*this, fpTy, "fp")) || + failed(verifyTileBufCommon(*this, dstTy, "dst"))) + return failure(); + auto srcElemTy = getElemTy(srcTy); + auto srcIntTy = dyn_cast(srcElemTy); + if (!(srcElemTy.isF32() || + (srcIntTy && srcIntTy.getWidth() == 32))) + return emitOpError() + << "expects src to have element type f32, i32"; + auto fpSpace = getPTOMemorySpaceEnum(fpTy); + if (!fpSpace || *fpSpace != mlir::pto::AddressSpace::SCALING) + return emitOpError() << "expects fp to be in the scaling address space"; + auto srcSpace = getPTOMemorySpaceEnum(srcTy); + if (!srcSpace || *srcSpace != mlir::pto::AddressSpace::ACC) + return emitOpError() << "expects src to be in the acc address space"; + auto dstSpace = getPTOMemorySpaceEnum(dstTy); + if (!dstSpace || *dstSpace != mlir::pto::AddressSpace::MAT) + return emitOpError() << "expects dst to be in the mat address space"; + auto srcTb = dyn_cast(srcTy); + auto dstTb = dyn_cast(dstTy); + if (srcTb && + (srcTb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor) || + srcTb.getSLayoutValueI32() != static_cast(pto::SLayout::RowMajor))) + return emitOpError() + << "expects src to use blayout=col_major and slayout=row_major"; + if (dstTb && + (dstTb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor) || + dstTb.getSLayoutValueI32() != static_cast(pto::SLayout::RowMajor))) + return emitOpError() + << "expects dst to use blayout=col_major and slayout=row_major"; + if (dstTb && dstTb.getSFractalSizeI32() != 512) + return emitOpError() << "expects dst to use fractal size 512"; + return mlir::success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type fpTy = getFp().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyTileBufCommon(*this, srcTy, "src")) || + failed(verifyTileBufCommon(*this, fpTy, "fp")) || + failed(verifyTileBufCommon(*this, dstTy, "dst"))) + return failure(); + auto srcElemTy = getElemTy(srcTy); + auto srcIntTy = 
dyn_cast(srcElemTy); + if (!(srcElemTy.isF32() || + (srcIntTy && srcIntTy.getWidth() == 32))) + return emitOpError() + << "expects src to have element type f32, i32"; + auto fpSpace = getPTOMemorySpaceEnum(fpTy); + if (!fpSpace || *fpSpace != mlir::pto::AddressSpace::SCALING) + return emitOpError() << "expects fp to be in the scaling address space"; + auto srcTb = dyn_cast(srcTy); + if (srcTb && + (srcTb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor) || + srcTb.getSLayoutValueI32() != static_cast(pto::SLayout::RowMajor))) + return emitOpError() + << "expects src to use blayout=col_major and slayout=row_major"; + return success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} +// 辅助函数:获取 Rank,支持 ShapedType 和 PTO TileTypes +static int64_t getRankHelper(Type t) { + if (auto s = dyn_cast(t)) return s.getRank(); + if (auto tile = dyn_cast(t)) return tile.getRank(); + if (auto view = dyn_cast(t)) return view.getRank(); + return -1; +} + +static LogicalResult verifyMatmulLike(Operation *op, Type aTy, Type bTy, Type dstTy, bool checkRank = true) { + // 1. 
检查类型 (ShapedType 或 Tile 类型) + bool aValid = isa(aTy); + bool bValid = isa(bTy); + bool dValid = isa(dstTy); + + if (!aValid || !bValid || !dValid) + return op->emitOpError("expects inputs/outputs to be shaped types or PTO tile types"); + + if (checkRank) { + int64_t aRank = getRankHelper(aTy); + int64_t bRank = getRankHelper(bTy); + int64_t dRank = getRankHelper(dstTy); + + // 检查 Rank 一致性 + if (aRank != -1 && dRank != -1 && aRank != dRank) + return op->emitOpError("expects a and dst to have the same rank"); + if (bRank != -1 && dRank != -1 && bRank != dRank) + return op->emitOpError("expects b and dst to have the same rank"); + } + + return success(); +} + +// ---- LoadScalarOp ---- +LogicalResult LoadScalarOp::verify() { + Type ptrTy = getPtr().getType(); + Type elemTy; + if (auto pty = dyn_cast(ptrTy)) { + elemTy = pty.getElementType(); + } else if (auto memTy = dyn_cast(ptrTy)) { + elemTy = memTy.getElementType(); + if (!isGmAddressSpaceAttr(memTy.getMemorySpace())) + return emitOpError() << "scalar load only supports GM address space pointers"; + } else { + return emitOpError("expects ptr to be !pto.ptr or memref type"); + } + + if (getValue().getType() != elemTy) + return emitOpError("expects result type to match ptr element type"); + + return success(); +} +// ---- StoreScalarOp ---- +LogicalResult StoreScalarOp::verify() { + Type ptrTy = getPtr().getType(); + Type elemTy; + if (auto pty = dyn_cast(ptrTy)) { + elemTy = pty.getElementType(); + } else if (auto memTy = dyn_cast(ptrTy)) { + elemTy = memTy.getElementType(); + if (!isGmAddressSpaceAttr(memTy.getMemorySpace())) + return emitOpError() << "scalar store only supports GM address space pointers"; + } else { + return emitOpError("expects ptr to be !pto.ptr or memref type"); + } + + if (getValue().getType() != elemTy) + return emitOpError("expects value type to match ptr element type"); + + return success(); +} + +// ---- GetBufOp / RlsBufOp ---- +static LogicalResult verifyBufSyncOp(Operation *op, 
Attribute opTypeAttr, + IntegerAttr bufIdAttr, IntegerAttr modeAttr) { + if (!opTypeAttr) + return op->emitOpError("expects 'op_type' attribute"); + + auto opTypeOr = parseSyncOpTypeLikeAttr(opTypeAttr); + if (failed(opTypeOr)) { + auto diag = + op->emitOpError("expects 'op_type' to be pipe_event_type/sync_op_type, got "); + diag << opTypeAttr; + return failure(); + } + pto::PIPE pipe = mapSyncOpTypeToPipe(*opTypeOr); + if (!isConcreteSyncPipe(pipe)) + return op->emitOpError("expects 'op_type' to map to a concrete pipe, not PIPE_ALL/PIPE_UNASSIGNED"); + + if (!bufIdAttr) + return op->emitOpError("expects 'buf_id' attribute"); + int64_t bufId = bufIdAttr.getInt(); + if (bufId < 0 || bufId > 31) + return op->emitOpError("expects 'buf_id' in range [0, 31]"); + + if (modeAttr) { + int64_t mode = modeAttr.getInt(); + if (mode < 0) + return op->emitOpError("expects 'mode' to be non-negative"); + } + + return success(); +} + +LogicalResult GetBufOp::verify() { + return verifyBufSyncOp(getOperation(), getOpTypeAttr(), getBufIdAttr(), + getModeAttr()); +} + +LogicalResult RlsBufOp::verify() { + return verifyBufSyncOp(getOperation(), getOpTypeAttr(), getBufIdAttr(), + getModeAttr()); +} +// ---- TOp ---- +LogicalResult TGemvBiasOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + if (failed(verifyGemvTileOperands(*this, getA().getType(), getB().getType(), + getDst().getType())) || + failed(verifyMatBiasTile(*this, getBias().getType(), getDst().getType()))) + return failure(); + if (failed(verifyMatmulTypeTriple(*this, getElemTy(getA().getType()), + getElemTy(getB().getType()), + getElemTy(getDst().getType())))) + return failure(); + return verifyMatmulLike(*this, getA().getType(), getB().getType(), + getDst().getType()); + }; + auto verifyA5 = [&]() -> LogicalResult { return verifyA2A3(); }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +LogicalResult TGemvMxOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + return 
emitOpError("tgemv.mx is only supported on A5 targets"); + }; + auto verifyA5 = [&]() -> LogicalResult { + if (failed(verifyScaleTileMatchesOperand(*this, getAScale().getType(), + getA().getType(), "a_scale", "a")) || + failed(verifyScaleTileMatchesOperand(*this, getBScale().getType(), + getB().getType(), "b_scale", "b")) || + failed(verifyGemvTileOperands(*this, getA().getType(), getB().getType(), + getDst().getType()))) + return failure(); + if (failed(verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), + getDst().getType(), "a", "b", "dst"))) + return failure(); + return verifyMatmulLike(*this, getA().getType(), getB().getType(), + getDst().getType()); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +LogicalResult TGemvMxAccOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + return emitOpError("tgemv.mx.acc is only supported on A5 targets"); + }; + auto verifyA5 = [&]() -> LogicalResult { + if (failed(verifyAccTileCommon(*this, getCIn().getType(), "c_in")) || + failed(verifyScaleTileMatchesOperand(*this, getAScale().getType(), + getA().getType(), "a_scale", "a")) || + failed(verifyScaleTileMatchesOperand(*this, getBScale().getType(), + getB().getType(), "b_scale", "b")) || + failed(verifyGemvTileOperands(*this, getA().getType(), getB().getType(), + getDst().getType()))) + return failure(); + if (failed(verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), + getDst().getType(), "a", "b", "dst"))) + return failure(); + if (failed(verifyTileBufSameElemType(*this, getCIn().getType(), + getDst().getType(), "c_in", "dst")) || + failed(verifyTileBufSameValidShape(*this, getCIn().getType(), + getDst().getType(), "c_in", "dst"))) + return failure(); + return verifyMatmulLike(*this, getA().getType(), getB().getType(), + getDst().getType()); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +LogicalResult TGemvMxBiasOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + 
return emitOpError("tgemv.mx.bias is only supported on A5 targets"); + }; + auto verifyA5 = [&]() -> LogicalResult { + if (failed(verifyScaleTileMatchesOperand(*this, getAScale().getType(), + getA().getType(), "a_scale", "a")) || + failed(verifyScaleTileMatchesOperand(*this, getBScale().getType(), + getB().getType(), "b_scale", "b")) || + failed(verifyGemvTileOperands(*this, getA().getType(), getB().getType(), + getDst().getType())) || + failed(verifyMatBiasTile(*this, getBias().getType(), getDst().getType(), + /*requireFloatBias=*/true))) + return failure(); + if (failed(verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), + getDst().getType(), "a", "b", "dst"))) + return failure(); + auto biasShape = getShapeVec(getBias().getType()); + auto dstShape = getShapeVec(getDst().getType()); + if (biasShape.size() != 2 || dstShape.size() != 2) + return emitOpError("expects bias and dst to be rank-2 for tgemv.mx.bias"); + if (biasShape[1] != ShapedType::kDynamic && dstShape[1] != ShapedType::kDynamic && + biasShape[1] != dstShape[1]) + return emitOpError("expects bias and dst to have the same column shape"); + if (failed(verifyTileBufSameValidShape(*this, getBias().getType(), + getDst().getType(), "bias", "dst"))) + return failure(); + return verifyMatmulLike(*this, getA().getType(), getB().getType(), + getDst().getType()); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +LogicalResult TMatmulBiasOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + if (failed(verifyMatTileOperands(*this, getA().getType(), getB().getType(), + getDst().getType())) || + failed(verifyMatBiasTile(*this, getBias().getType(), getDst().getType()))) + return failure(); + if (failed(verifyMatmulTypeTriple(*this, getElemTy(getA().getType()), + getElemTy(getB().getType()), + getElemTy(getDst().getType())))) + return failure(); + return verifyMatmulLike(*this, getA().getType(), getB().getType(), + getDst().getType()); + }; + auto verifyA5 = [&]() -> 
LogicalResult { return verifyA2A3(); }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +LogicalResult TMatmulMxOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + if (failed(verifyTileBufCommon(*this, getAScale().getType(), "a_scale")) || + failed(verifyTileBufCommon(*this, getBScale().getType(), "b_scale"))) + return failure(); + return verifyMatmulLike(*this, getA().getType(), getB().getType(), + getDst().getType()); + }; + auto verifyA5 = [&]() -> LogicalResult { + if (failed(verifyA2A3())) + return failure(); + return verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), + getDst().getType(), "a", "b", "dst"); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +LogicalResult TMatmulMxAccOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + if (failed(verifyAccTileCommon(*this, getCIn().getType(), "c_in")) || + failed(verifyTileBufCommon(*this, getAScale().getType(), "a_scale")) || + failed(verifyTileBufCommon(*this, getBScale().getType(), "b_scale"))) + return failure(); + return success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + if (failed(verifyA2A3())) + return failure(); + if (failed(verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), + getDst().getType(), "a", "b", "dst"))) + return failure(); + if (failed(verifyTileBufSameElemType(*this, getCIn().getType(), + getDst().getType(), "c_in", "dst")) || + failed(verifyTileBufSameValidShape(*this, getCIn().getType(), + getDst().getType(), "c_in", "dst"))) + return failure(); + return success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} +LogicalResult TMatmulMxBiasOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + if (failed(verifyTileBufCommon(*this, getAScale().getType(), "a_scale")) || + failed(verifyTileBufCommon(*this, getBScale().getType(), "b_scale")) || + failed(verifyMatTileOperands(*this, getA().getType(), getB().getType(), + getDst().getType())) || 
+ failed(verifyMatBiasTile(*this, getBias().getType(), getDst().getType(), + /*requireFloatBias=*/true))) + return failure(); + return verifyMatmulLike(*this, getA().getType(), getB().getType(), + getDst().getType()); + }; + auto verifyA5 = [&]() -> LogicalResult { + if (failed(verifyA2A3())) + return failure(); + return verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), + getDst().getType(), "a", "b", "dst"); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} +// ---- TSetValOp ---- +LogicalResult TSetValOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + // dst can be tile/tensor/tilebuf (PTODpsType). Keep checks minimal. + if (auto shaped = dyn_cast(getDst().getType())) { + if (shaped.getElementType() != getVal().getType()) + return emitOpError("expects val type to match dst element type"); + } + return success(); +} +// ---- TGetValOp ---- +LogicalResult TGetValOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + Type srcTy = getSrc().getType(); + if (!mlir::isa(srcTy)) + return emitOpError("expects src to be tile_buf or memref type"); + + // Memory space must be vec (Ascend does not support getval from MAT etc.). + Attribute memSpace = + isa(srcTy) + ? 
cast(srcTy).getMemorySpace() + : cast(srcTy).getMemorySpace(); + auto addrSpaceAttr = dyn_cast_or_null(memSpace); + if (!addrSpaceAttr || + addrSpaceAttr.getAddressSpace() != pto::AddressSpace::VEC) { + if (addrSpaceAttr && + addrSpaceAttr.getAddressSpace() == pto::AddressSpace::MAT) + return emitOpError( + "Ascend hardware does not support reading from Mat tile_buf to Scalar unit"); + return emitOpError("expects src memory space to be vec"); + } + + if (getElemTy(srcTy) != getDst().getType()) + return emitOpError("expects dst type to match src element type"); + return success(); +} + +LogicalResult THistogramOp::verify() { + auto isIntegerWidth = [](Type ty, unsigned width) { + auto it = dyn_cast(ty); + return it && it.getWidth() == width; + }; + + auto verifyA2A3 = [&]() -> LogicalResult { + return emitOpError("thistogram is only supported on A5"); + }; + auto verifyA5 = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type idxTy = getIdx().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyTileBufCommon(*this, srcTy, "src")) || + failed(verifyTileBufCommon(*this, idxTy, "idx")) || + failed(verifyTileBufCommon(*this, dstTy, "dst"))) + return failure(); + + auto srcSpace = getPTOMemorySpaceEnum(srcTy); + auto idxSpace = getPTOMemorySpaceEnum(idxTy); + auto dstSpace = getPTOMemorySpaceEnum(dstTy); + if (!srcSpace || *srcSpace != pto::AddressSpace::VEC) + return emitOpError("expects src to be in the vec address space"); + if (!idxSpace || *idxSpace != pto::AddressSpace::VEC) + return emitOpError("expects idx to be in the vec address space"); + if (!dstSpace || *dstSpace != pto::AddressSpace::VEC) + return emitOpError("expects dst to be in the vec address space"); + + auto srcTB = dyn_cast(srcTy); + auto idxTB = dyn_cast(idxTy); + auto dstTB = dyn_cast(dstTy); + if (!srcTB || !idxTB || !dstTB) + return emitOpError("expects src, idx, and dst to be tile_buf types"); + + if (srcTB.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor) || 
+ srcTB.getSLayoutValueI32() != static_cast(pto::SLayout::NoneBox)) + return emitOpError("expects src to use row_major + none_box layout"); + if (dstTB.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor) || + dstTB.getSLayoutValueI32() != static_cast(pto::SLayout::NoneBox)) + return emitOpError("expects dst to use row_major + none_box layout"); + if (idxTB.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor) || + idxTB.getSLayoutValueI32() != static_cast(pto::SLayout::NoneBox)) + return emitOpError( + "expects idx to use DN layout (col_major + none_box)"); + + if (!isIntegerWidth(getElemTy(srcTy), 16)) + return emitOpError("expects src element type to be ui16"); + if (!isIntegerWidth(getElemTy(idxTy), 8)) + return emitOpError("expects idx element type to be ui8"); + if (!isIntegerWidth(getElemTy(dstTy), 32)) + return emitOpError("expects dst element type to be ui32"); + + auto srcShape = getShapeVec(srcTy); + auto idxShape = getShapeVec(idxTy); + auto dstShape = getShapeVec(dstTy); + auto srcValid = getValidShapeVec(srcTy); + auto idxValid = getValidShapeVec(idxTy); + auto dstValid = getValidShapeVec(dstTy); + if (srcShape.size() != 2 || idxShape.size() != 2 || dstShape.size() != 2 || + srcValid.size() != 2 || idxValid.size() != 2 || dstValid.size() != 2) + return emitOpError( + "expects src, idx, and dst to have rank-2 shape and valid_shape"); + + if (!hasCompatibleKnownExtent(srcShape[0], idxShape[0]) || + !hasCompatibleKnownExtent(srcValid[0], idxValid[0])) + return emitOpError("expects idx rows and valid rows to match src"); + if (!hasCompatibleKnownExtent(srcShape[0], dstShape[0]) || + !hasCompatibleKnownExtent(srcValid[0], dstValid[0])) + return emitOpError("expects dst rows and valid rows to match src"); + + if (!isKnownUnitExtent(idxShape[1]) || !isKnownUnitExtent(idxValid[1])) + return emitOpError("expects idx to have exactly one column"); + if (dstShape[1] != ShapedType::kDynamic && dstShape[1] < 256) + return emitOpError("expects dst 
shape[1] to be at least 256"); + if (dstValid[1] != ShapedType::kDynamic && dstValid[1] < 256) + return emitOpError("expects dst valid_shape[1] to be at least 256"); + return success(); + }; + + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +LogicalResult TGetScaleAddrOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + return emitOpError("tget_scale_addr is only supported on A5"); + }; + auto verifyA5 = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyTileBufCommon(*this, srcTy, "src"))) + return failure(); + if (failed(verifyScaleTileMatchesOperand(*this, dstTy, srcTy, "dst", "src"))) + return failure(); + return success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +// ---- MScatterOp ---- +LogicalResult MScatterOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + + if (!isTargetArchA5(getOperation())) + return emitOpError("pto.mscatter is only supported on A5 targets"); + + Type srcTy = getSrc().getType(); + Type idxTy = getIdx().getType(); + Type memTy = getMem().getType(); + + if (getPTOTypeRank(srcTy) == -1 || getPTOTypeRank(idxTy) == -1 || + getPTOTypeRank(memTy) == -1) + return emitOpError("expects src, idx, and mem to use supported PTO shapes"); + + if (failed(verifyNDStyleVecTile(*this, srcTy, "src")) || + failed(verifyMGatherMScatterIdxTile(getOperation(), idxTy, "idx"))) + return failure(); + + Type srcElem = getElemTy(srcTy); + Type idxElem = getElemTy(idxTy); + if (!srcElem || !idxElem) + return emitOpError("failed to resolve element types for src or idx"); + + if (!isSupportedMGatherMScatterPayloadElemType(getOperation(), srcElem)) + return emitOpError( + "expects src element type to be i8/ui8/i16/ui16/i32/ui32/f16/bf16/f32 " + "(and on A5 targets also float8_e4m3/float8_e5m2 family types)"); + + if (!isSupportedMGatherMScatterIndexElemType(idxElem)) + return 
emitOpError("expects idx element type to be signless i32"); + + if (failed(verifyMGatherMScatterMemOperand(getOperation(), getMem(), srcElem, + "src"))) + return failure(); + + if (getScatterAtomicOp() != pto::ScatterAtomicOp::None || + getScatterOob() != pto::ScatterOOB::Undefined) { + if (!isTargetArchA5(getOperation())) + return emitOpError( + "expects non-default scatterAtomicOp/scatterOob only on A5 targets"); + } + + if (!isSupportedMScatterAtomicPayloadElemType(srcElem, getScatterAtomicOp())) + return emitOpError( + "expects scatterAtomicOp-compatible src element type: add supports " + "i32/ui32/f16/f32, max/min support signless i32/f32"); + + if (failed(verifyMGatherMScatterTileShape(getOperation(), srcTy, idxTy, "src"))) + return failure(); + + return success(); +} + +// ---- MGatherOp ---- +LogicalResult MGatherOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + + if (!isTargetArchA5(getOperation())) + return emitOpError("pto.mgather is only supported on A5 targets"); + + Type memTy = getMem().getType(); + Type idxTy = getIdx().getType(); + Type dstTy = getDst().getType(); + + if (getPTOTypeRank(memTy) == -1 || getPTOTypeRank(idxTy) == -1 || + getPTOTypeRank(dstTy) == -1) + return emitOpError("expects mem, idx, and dst to use supported PTO shapes"); + + if (failed(verifyNDStyleVecTile(*this, dstTy, "dst")) || + failed(verifyMGatherMScatterIdxTile(getOperation(), idxTy, "idx"))) + return failure(); + + Type dstElem = getElemTy(dstTy); + Type idxElem = getElemTy(idxTy); + if (!dstElem || !idxElem) + return emitOpError("failed to resolve element types for dst or idx"); + + if (!isSupportedMGatherMScatterPayloadElemType(getOperation(), dstElem)) + return emitOpError( + "expects dst element type to be i8/ui8/i16/ui16/i32/ui32/f16/bf16/f32 " + "(and on A5 targets also float8_e4m3/float8_e5m2 family types)"); + + if (!isSupportedMGatherMScatterIndexElemType(idxElem)) + return emitOpError("expects idx element type to be 
signless i32"); + + if (failed(verifyMGatherMScatterMemOperand(getOperation(), getMem(), dstElem, + "dst"))) + return failure(); + + if (getGatherOob() != pto::GatherOOB::Undefined && + !isTargetArchA5(getOperation())) + return emitOpError( + "expects non-default gatherOob only on A5 targets"); + + if (failed(verifyMGatherMScatterTileShape(getOperation(), dstTy, idxTy, "dst"))) + return failure(); + + return success(); +} + +void mlir::pto::TCvtOp::print(OpAsmPrinter &p) { + p << " ins(" << getSrc(); + Builder builder(getContext()); + NamedAttrList attrs; + for (auto attr : (*this)->getAttrs()) { + if (attr.getName() == "sat_mode") { + attrs.set(builder.getStringAttr("satmode"), attr.getValue()); + continue; + } + attrs.set(attr.getName(), attr.getValue()); + } + p.printOptionalAttrDict(attrs.getAttrs()); + p << " : " << getSrc().getType(); + p << ") outs(" << getDst() << " : " << getDst().getType() << ")"; +} + +ParseResult mlir::pto::TCvtOp::parse(OpAsmParser &parser, OperationState &result) { + OpAsmParser::UnresolvedOperand src, dst; + Type srcTy, dstTy; + + if (parser.parseKeyword("ins") || parser.parseLParen() || parser.parseOperand(src)) + return failure(); + NamedAttrList attrs; + if (parser.parseOptionalAttrDict(attrs) || parser.parseColonType(srcTy)) + return failure(); + if (auto satmode = attrs.get("satmode")) { + attrs.erase("satmode"); + if (attrs.get("sat_mode")) + return parser.emitError(parser.getCurrentLocation(), + "cannot specify both satmode and sat_mode"); + attrs.set("sat_mode", satmode); + } + result.attributes = attrs; + if (parser.parseRParen() || parser.parseKeyword("outs") || parser.parseLParen() || + parser.parseOperand(dst) || parser.parseColonType(dstTy) || parser.parseRParen()) + return failure(); + + if (parser.resolveOperand(src, srcTy, result.operands) || + parser.resolveOperand(dst, dstTy, result.operands)) + return failure(); + return success(); +} + +void mlir::pto::TMrgSortOp::print(OpAsmPrinter &p) { + if (isFormat1()) { + p 
<< " ins(" << getSrc() << ", " << getBlockLen() << " : " << getSrc().getType() + << ", " << getBlockLen().getType() << ") outs(" << getDst() << " : " + << getDst().getType() << ")"; + } else if (isFormat2()) { + p << " ins("; + llvm::interleaveComma(getSrcs(), p, [&](Value src) { p << src; }); + p << ", " << getTmp(); + p << " {exhausted = " << (getExhausted() ? "true" : "false") << "} : "; + llvm::interleaveComma(getSrcs().getTypes(), p, [&](Type ty) { p << ty; }); + p << ", " << getTmp().getType(); + p << ") outs(" << getDst() << ", " << getExcuted() + << " : " << getDst().getType() << ", " << getExcuted().getType() << ")"; + } else { + llvm::report_fatal_error("TMrgSortOp print expects format1 or format2"); + } + p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"operandSegmentSizes", "exhausted"}); +} + +ParseResult mlir::pto::TMrgSortOp::parse(OpAsmParser &parser, OperationState &result) { + if (parser.parseKeyword("ins") || parser.parseLParen()) + return failure(); + OpAsmParser::UnresolvedOperand first, second; + if (parser.parseOperand(first) || parser.parseComma() || parser.parseOperand(second)) + return failure(); + + if (parser.parseOptionalColon().succeeded()) { + Type srcTy, blockLenTy, dstTy; + if (parser.parseType(srcTy) || parser.parseComma() || parser.parseType(blockLenTy) || + parser.parseRParen() || parser.parseKeyword("outs") || parser.parseLParen()) + return failure(); + OpAsmParser::UnresolvedOperand dstOp; + if (parser.parseOperand(dstOp) || parser.parseColon() || parser.parseType(dstTy) || + parser.parseRParen()) + return failure(); + result.addAttribute("operandSegmentSizes", + parser.getBuilder().getDenseI32ArrayAttr({1, 1, 1, 0, 0})); + if (parser.resolveOperand(first, srcTy, result.operands) || + parser.resolveOperand(second, blockLenTy, result.operands) || + parser.resolveOperand(dstOp, dstTy, result.operands)) + return failure(); + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + if 
(!result.attributes.get("exhausted")) + result.addAttribute("exhausted", parser.getBuilder().getBoolAttr(false)); + return success(); + } + + SmallVector srcs = {first, second}; + while (parser.parseOptionalComma().succeeded()) { + OpAsmParser::UnresolvedOperand next; + if (parser.parseOperand(next)) + return failure(); + srcs.push_back(next); + } + if (srcs.size() < 3 || srcs.size() > 5) + return parser.emitError(parser.getCurrentLocation(), + "tmrgsort format2 expects 2 to 4 src operands plus one tmp operand"); + OpAsmParser::UnresolvedOperand tmpOp = srcs.pop_back_val(); + bool exhaustedVal = false; + if (parser.parseOptionalLBrace().succeeded()) { + if (parser.parseKeyword("exhausted") || parser.parseEqual()) + return failure(); + StringRef kw; + if (parser.parseKeyword(&kw) || parser.parseRBrace()) + return failure(); + exhaustedVal = (kw == "true"); + } + SmallVector srcTypes; + srcTypes.reserve(srcs.size()); + if (parser.parseColon()) + return failure(); + Type firstSrcTy; + if (parser.parseType(firstSrcTy)) + return failure(); + srcTypes.push_back(firstSrcTy); + while (parser.parseOptionalComma().succeeded()) { + Type nextTy; + if (parser.parseType(nextTy)) + return failure(); + srcTypes.push_back(nextTy); + } + if (srcTypes.size() != srcs.size() + 1 || parser.parseRParen() || + parser.parseKeyword("outs") || parser.parseLParen()) + return failure(); + Type tmpTy = srcTypes.pop_back_val(); + OpAsmParser::UnresolvedOperand dstOp, excutedOp; + Type dstTy, excutedTy; + if (parser.parseOperand(dstOp) || parser.parseComma() || parser.parseOperand(excutedOp) || + parser.parseColon() || parser.parseType(dstTy) || parser.parseComma() || + parser.parseType(excutedTy) || parser.parseRParen()) + return failure(); + result.addAttribute("operandSegmentSizes", + parser.getBuilder().getDenseI32ArrayAttr( + {static_cast(srcs.size()), 0, 1, 1, 1})); + if (parser.resolveOperands(srcs, srcTypes, parser.getCurrentLocation(), result.operands) || + parser.resolveOperand(dstOp, 
dstTy, result.operands) || + parser.resolveOperand(tmpOp, tmpTy, result.operands) || + parser.resolveOperand(excutedOp, excutedTy, result.operands)) + return failure(); + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + if (!result.attributes.get("exhausted")) + result.addAttribute("exhausted", parser.getBuilder().getBoolAttr(exhaustedVal)); + return success(); +} + +mlir::LogicalResult mlir::pto::TMrgSortOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + if (isFormat1()) { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (!isPTOShapedLike(srcTy) || !isPTOShapedLike(dstTy)) + return emitOpError() << "format1 expects PTO shaped-like types for src/dst"; + if (getElemTy(srcTy) != getElemTy(dstTy)) + return emitOpError() << "expects src/dst to have the same element type"; + if (!getElemTy(srcTy).isF16() && !getElemTy(srcTy).isF32()) + return emitOpError() << "expects element type to be f16 or f32"; + auto ss = getShapeVec(srcTy); + auto ds = getShapeVec(dstTy); + if (ss.size() != 2 || ds.size() != 2) + return emitOpError() << "expects src/dst to be rank-2 tile-shaped"; + if (ss[0] != mlir::ShapedType::kDynamic && ss[0] != 1) + return emitOpError() << "expects src rows == 1"; + if (ds[0] != mlir::ShapedType::kDynamic && ds[0] != 1) + return emitOpError() << "expects dst rows == 1"; + if (ss[1] != mlir::ShapedType::kDynamic && ds[1] != mlir::ShapedType::kDynamic && ss[1] != ds[1]) + return emitOpError() << "expects src/dst cols to match"; + if (getBlockLen()) { + if (auto cstOp = getBlockLen().getDefiningOp()) { + if (auto intAttr = mlir::dyn_cast(cstOp.getValue())) { + int64_t v = intAttr.getValue().getSExtValue(); + if (v <= 0 || (v % 64) != 0) + return emitOpError() << "expects blockLen > 0 and multiple of 64"; + } + } + } + return mlir::success(); + } + if (isFormat2()) { + for (Value v : getSrcs()) + if (!isPTOShapedLike(v.getType())) + return emitOpError() << "format2 
expects PTO shaped-like type for each src"; + if (getSrcs().size() < 2u || getSrcs().size() > 4u) + return emitOpError() << "format2 expects 2 to 4 srcs"; + if (getDsts().size() != 1u || !getTmp() || !getExcuted()) + return emitOpError() << "format2 expects ins(srcs..., tmp), outs(dst), and excuted=vector"; + Type dstTy = getDst().getType(); + Type tmpTy = getTmp().getType(); + if (!isPTOShapedLike(dstTy) || !isPTOShapedLike(tmpTy)) + return emitOpError() << "format2 dst/tmp must be PTO shaped-like"; + auto excutedTy = mlir::dyn_cast(getExcuted().getType()); + if (!excutedTy || excutedTy.getRank() != 1 || excutedTy.getNumElements() != 4 || + !excutedTy.getElementType().isInteger(16)) + return emitOpError() << "format2 excuted must be vector<4xi16>"; + Type elemTy = getElemTy(dstTy); + if (elemTy != getElemTy(tmpTy)) + return emitOpError() << "format2 expects dst/tmp element types to match"; + auto dstShape = getShapeVec(dstTy); + auto tmpShape = getShapeVec(tmpTy); + if (dstShape.size() != 2 || tmpShape.size() != 2) + return emitOpError() << "format2 expects dst/tmp to be rank-2 tile-shaped"; + if ((dstShape[0] != mlir::ShapedType::kDynamic && dstShape[0] != 1) || + (tmpShape[0] != mlir::ShapedType::kDynamic && tmpShape[0] != 1)) + return emitOpError() << "format2 expects dst/tmp rows == 1"; + if (dstShape[1] != mlir::ShapedType::kDynamic && + tmpShape[1] != mlir::ShapedType::kDynamic && + tmpShape[1] < dstShape[1]) + return emitOpError() << "format2 expects tmp.cols >= dst.cols"; + for (Value src : getSrcs()) { + Type srcTy = src.getType(); + auto srcShape = getShapeVec(srcTy); + if (srcShape.size() != 2) + return emitOpError() << "format2 expects src to be rank-2 tile-shaped"; + if (srcShape[0] != mlir::ShapedType::kDynamic && srcShape[0] != 1) + return emitOpError() << "format2 expects src rows == 1"; + if (getElemTy(srcTy) != elemTy) + return emitOpError() << "format2 expects src/dst/tmp element types to match"; + } + return mlir::success(); + } + return 
emitOpError() << "tmrgsort expects format1 (1 src + blockLen + 1 dst) or " + "format2 (2 to 4 srcs + tmp, outs dst, excuted)"; +} + +mlir::LogicalResult mlir::pto::TMulOp::verify() { + return verifyArithmeticBinaryTileOpWithArchDispatch( + getOperation(), getSrc0().getType(), getSrc1().getType(), getDst().getType(), + /*allowInt8OnA5=*/false, /*allowBf16OnA5=*/false, + "expects A2/A3 tmul element type to be i32/i16/f16/f32", + "expects A5 tmul element type to be i32/i16/f16/f32"); +} + +mlir::LogicalResult mlir::pto::TMulSOp::verify() { + return verifyArithmeticScalarTileOpWithArchDispatch( + getOperation(), getSrc0().getType(), getDst().getType(), + getScalar().getType(), /*allowInt8OnA5=*/true, /*allowBf16OnA5=*/true, + "expects A2/A3 tmuls element type to be i32/i16/f16/f32", + "expects A5 tmuls element type to be i32/i16/i8/f16/bf16/f32", + /*requireValidRowsEqualOnA2A3=*/true, + /*requireValidRowsEqualOnA5=*/false); +} + +mlir::LogicalResult mlir::pto::TShlSOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyTileBufCommon(*this, srcTy, "src")) || + failed(verifyTileBufCommon(*this, dstTy, "dst"))) + return failure(); + + Type srcElem = getElemTy(srcTy); + Type dstElem = getElemTy(dstTy); + if (!srcElem || !dstElem) + return emitOpError() << "failed to get element type for src/dst"; + if (srcElem != dstElem) + return emitOpError() << "expects src and dst to have the same element type"; + if (!mlir::isa(srcElem)) + return emitOpError() << "expects integral element types"; + if (auto scalarValue = getConstantIntegerValue(getScalar()); scalarValue && *scalarValue < 0) + return emitOpError("expects tshls scalar to be non-negative"); + return mlir::success(); +} + +mlir::LogicalResult mlir::pto::TShrSOp::verify() { + auto verifyCommon = [&]() -> FailureOr { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if 
(failed(verifyVecTileCommon(*this, srcTy, "src")) || + failed(verifyVecTileCommon(*this, dstTy, "dst"))) + return failure(); + if (failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) + return failure(); + + Type srcElem = getElemTy(srcTy); + Type dstElem = getElemTy(dstTy); + if (!srcElem || !dstElem) { + emitOpError("failed to get element type for src/dst"); + return failure(); + } + if (srcElem != dstElem) { + emitOpError("expects src and dst to have the same element type"); + return failure(); + } + return srcElem; + }; + + auto verifyA2A3 = [&]() -> LogicalResult { + FailureOr elemOr = verifyCommon(); + if (failed(elemOr)) + return failure(); + auto it = mlir::dyn_cast(*elemOr); + if (!it || (it.getWidth() != 16 && it.getWidth() != 32)) + return emitOpError( + "expects A2/A3 tshrs src and dst element type to be i16/i32"); + return success(); + }; + + auto verifyA5 = [&]() -> LogicalResult { + FailureOr elemOr = verifyCommon(); + if (failed(elemOr)) + return failure(); + auto it = mlir::dyn_cast(*elemOr); + if (!it || (it.getWidth() != 8 && it.getWidth() != 16 && + it.getWidth() != 32)) + return emitOpError( + "expects A5 tshrs src and dst element type to be i8/i16/i32"); + return success(); + }; + + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TNegOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyVecTileStorage(*this, srcTy, "src")) || + failed(verifyVecTileStorage(*this, dstTy, "dst"))) + return failure(); + if (failed(verifyTileBufSameElemType(*this, srcTy, dstTy, "src", "dst")) || + failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) + return failure(); + + Type elemTy = getElemTy(srcTy); + if (!(elemTy.isInteger(16) || elemTy.isInteger(32) || elemTy.isF16() || + elemTy.isF32())) + return emitOpError() + << "expects A2/A3 tneg element type to be 
i16/i32/f16/f32"; + return success(); + }; + + auto verifyA5 = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyVecTileStorage(*this, srcTy, "src")) || + failed(verifyVecTileStorage(*this, dstTy, "dst"))) + return failure(); + if (failed(verifyTileBufSameElemType(*this, srcTy, dstTy, "src", "dst"))) + return failure(); + + auto srcValid = getValidShapeVec(srcTy); + auto dstValid = getValidShapeVec(dstTy); + if (srcValid.size() != 2 || dstValid.size() != 2) + return emitOpError() << "expects src and dst to have rank-2 valid_shape"; + if (srcValid[1] != ShapedType::kDynamic && + dstValid[1] != ShapedType::kDynamic && + srcValid[1] != dstValid[1]) + return emitOpError() + << "expects src and dst to have the same valid_shape[1]"; + + Type elemTy = getElemTy(srcTy); + if (!(elemTy.isInteger(8) || elemTy.isInteger(16) || elemTy.isInteger(32) || + elemTy.isF16() || elemTy.isF32() || elemTy.isBF16())) + return emitOpError() + << "expects A5 tneg element type to be i8/i16/i32/f16/f32/bf16"; + return success(); + }; + + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TNotOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyVecTileCommon(*this, srcTy, "src")) || + failed(verifyVecTileCommon(*this, dstTy, "dst"))) + return failure(); + if (failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) + return failure(); + auto elemTy = getElemTy(srcTy); + if (elemTy != getElemTy(dstTy)) + return emitOpError() << "expects src and dst to have the same element type"; + if (!elemTy.isInteger(16)) + return emitOpError() << "expects A2/A3 tnot element type to be i16"; + return success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyVecTileCommon(*this, srcTy, "src")) 
|| + failed(verifyVecTileCommon(*this, dstTy, "dst"))) + return failure(); + if (failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) + return failure(); + auto elemTy = getElemTy(srcTy); + if (elemTy != getElemTy(dstTy)) + return emitOpError() << "expects src and dst to have the same element type"; + if (!(elemTy.isInteger(8) || elemTy.isInteger(16) || elemTy.isInteger(32))) + return emitOpError() << "expects A5 tnot element type to be i8/i16/i32"; + return success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TOrOp::verify() { + auto verifyCommon = [&]() -> FailureOr { + return verifyMatchingRowMajorBinaryTileOpCommon( + getOperation(), getSrc0().getType(), getSrc1().getType(), + getDst().getType()); + }; + + auto verifyA2A3 = [&]() -> LogicalResult { + FailureOr elemOr = verifyCommon(); + if (failed(elemOr)) + return failure(); + auto it = mlir::dyn_cast(*elemOr); + if (!it || (it.getWidth() != 8 && it.getWidth() != 16)) + return emitOpError( + "expects A2/A3 tor src0, src1, and dst element type to be i8/i16"); + return success(); + }; + + auto verifyA5 = [&]() -> LogicalResult { + FailureOr elemOr = verifyCommon(); + if (failed(elemOr)) + return failure(); + auto it = mlir::dyn_cast(*elemOr); + if (!it || (it.getWidth() != 8 && it.getWidth() != 16 && + it.getWidth() != 32)) + return emitOpError( + "expects A5 tor src0, src1, and dst element type to be i8/i16/i32"); + return success(); + }; + + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TOrSOp::verify() { + auto verifyCommon = [&]() -> FailureOr { + return verifyDistinctRowMajorUnaryTileOpCommon(getOperation(), getSrc(), + getDst(), "src", "dst"); + }; + + auto verifyA2A3 = [&]() -> LogicalResult { + FailureOr elemOr = verifyCommon(); + if (failed(elemOr)) + return failure(); + auto it = mlir::dyn_cast(*elemOr); + if (!it || (it.getWidth() != 8 && it.getWidth() 
!= 16)) + return emitOpError( + "expects A2/A3 tors src and dst element type to be i8/i16"); + return success(); + }; + + auto verifyA5 = [&]() -> LogicalResult { + FailureOr elemOr = verifyCommon(); + if (failed(elemOr)) + return failure(); + auto it = mlir::dyn_cast(*elemOr); + if (!it || (it.getWidth() != 8 && it.getWidth() != 16 && + it.getWidth() != 32)) + return emitOpError( + "expects A5 tors src and dst element type to be i8/i16/i32"); + return success(); + }; + + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +static FailureOr verifyPTOShapedBinarySameElemAndShape(Operation *op, + Type src0Ty, + Type src1Ty, + Type dstTy) { + if (!isPTOShapedLike(src0Ty) || !isPTOShapedLike(src1Ty) || + !isPTOShapedLike(dstTy)) + return op->emitOpError( + "expects src0/src1/dst to be memref/tensor/tile_buf/tile_view types"), + failure(); + Type e0 = getElemTy(src0Ty), e1 = getElemTy(src1Ty), ed = getElemTy(dstTy); + if (!e0 || !e1 || !ed) + return op->emitOpError("failed to get element type for operands"), failure(); + if (e0 != e1 || e0 != ed) + return op->emitOpError("expects src0/src1/dst to have the same element type"), + failure(); + auto s0 = getShapeVec(src0Ty), s1 = getShapeVec(src1Ty), sd = getShapeVec(dstTy); + if (s0 != s1 || s0 != sd) + return op->emitOpError("expects src0/src1/dst to have the same shape"), + failure(); + return e0; +} + +mlir::LogicalResult mlir::pto::TPartAddOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + Type src0Ty = getSrc0().getType(); + Type src1Ty = getSrc1().getType(); + Type dstTy = getDst().getType(); + if (!isPTOShapedLike(src0Ty) || !isPTOShapedLike(src1Ty) || + !isPTOShapedLike(dstTy)) + return emitOpError() << "expects PTO shaped-like src0/src1/dst"; + if (getElemTy(src0Ty) != getElemTy(src1Ty) || + getElemTy(src0Ty) != getElemTy(dstTy)) + return emitOpError() << "expects src0/src1/dst to have the same element type"; + auto s0 = getShapeVec(src0Ty); + auto s1 = getShapeVec(src1Ty); + auto 
d = getShapeVec(dstTy); + if (s0.size() != 2 || s1.size() != 2 || d.size() != 2) + return emitOpError() << "expects src0/src1/dst to be rank-2 (tile-shaped)"; + if (failed(verifyPartialValidPattern(*this, src0Ty, src1Ty, dstTy))) + return failure(); + Type elem = getElemTy(src0Ty); + if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isF16() || elem.isF32())) + return emitOpError("expects A2/A3 tpartadd element type to be i32/i16/f16/f32"); + return mlir::success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + Type src0Ty = getSrc0().getType(); + Type src1Ty = getSrc1().getType(); + Type dstTy = getDst().getType(); + if (!isPTOShapedLike(src0Ty) || !isPTOShapedLike(src1Ty) || + !isPTOShapedLike(dstTy)) + return emitOpError() << "expects PTO shaped-like src0/src1/dst"; + if (getElemTy(src0Ty) != getElemTy(src1Ty) || + getElemTy(src0Ty) != getElemTy(dstTy)) + return emitOpError() << "expects src0/src1/dst to have the same element type"; + Type elem = getElemTy(src0Ty); + if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isInteger(8) || + elem.isF16() || elem.isBF16() || elem.isF32())) + return emitOpError("expects A5 tpartadd element type to be i32/i16/i8/f16/bf16/f32"); + auto s0 = getShapeVec(src0Ty); + auto s1 = getShapeVec(src1Ty); + auto d = getShapeVec(dstTy); + if (s0.size() != 2 || s1.size() != 2 || d.size() != 2) + return emitOpError() << "expects src0/src1/dst to be rank-2 (tile-shaped)"; + return mlir::success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TPartMaxOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + Type t0 = getSrc0().getType(); + Type t1 = getSrc1().getType(); + Type td = getDst().getType(); + FailureOr elemOr = + verifyPTOShapedBinarySameElemAndShape(getOperation(), t0, t1, td); + if (failed(elemOr)) + return failure(); + if (failed(verifyPartialValidPattern(*this, t0, t1, td))) + return failure(); + Type e0 = *elemOr; + if (!(e0.isInteger(32) || 
e0.isInteger(16) || e0.isF16() || e0.isF32())) + return emitOpError("expects A2/A3 tpartmax element type to be i32/i16/f16/f32"); + return mlir::success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + Type t0 = getSrc0().getType(); + Type t1 = getSrc1().getType(); + Type td = getDst().getType(); + FailureOr elemOr = + verifyPTOShapedBinarySameElemAndShape(getOperation(), t0, t1, td); + if (failed(elemOr)) + return failure(); + Type e0 = *elemOr; + if (!(e0.isInteger(32) || e0.isInteger(16) || e0.isInteger(8) || + e0.isF16() || e0.isBF16() || e0.isF32())) + return emitOpError("expects A5 tpartmax element type to be i32/i16/i8/f16/bf16/f32"); + return mlir::success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TPartMinOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + Type t0 = getSrc0().getType(); + Type t1 = getSrc1().getType(); + Type td = getDst().getType(); + FailureOr elemOr = + verifyPTOShapedBinarySameElemAndShape(getOperation(), t0, t1, td); + if (failed(elemOr)) + return failure(); + if (failed(verifyPartialValidPattern(*this, t0, t1, td))) + return failure(); + Type e0 = *elemOr; + if (!(e0.isInteger(32) || e0.isInteger(16) || e0.isF16() || e0.isF32())) + return emitOpError("expects A2/A3 tpartmin element type to be i32/i16/f16/f32"); + return mlir::success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + Type t0 = getSrc0().getType(); + Type t1 = getSrc1().getType(); + Type td = getDst().getType(); + FailureOr elemOr = + verifyPTOShapedBinarySameElemAndShape(getOperation(), t0, t1, td); + if (failed(elemOr)) + return failure(); + Type e0 = *elemOr; + if (!(e0.isInteger(32) || e0.isInteger(16) || e0.isInteger(8) || + e0.isF16() || e0.isBF16() || e0.isF32())) + return emitOpError("expects A5 tpartmin element type to be i32/i16/i8/f16/bf16/f32"); + return mlir::success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +static 
LogicalResult verifyTPartArgOpCommon(Operation *op, Type src0Ty, + Type src1Ty, Type src0IdxTy, + Type src1IdxTy, Type dstTy, + Type dstIdxTy, StringRef opName) { + FailureOr dataElemOr = + verifyPTOShapedBinarySameElemAndShape(op, src0Ty, src1Ty, dstTy); + if (failed(dataElemOr)) + return failure(); + if (failed(verifyPartialValidPattern(op, src0Ty, src1Ty, dstTy))) + return failure(); + + if (!isPTOShapedLike(src0IdxTy) || !isPTOShapedLike(src1IdxTy) || + !isPTOShapedLike(dstIdxTy)) + return op->emitOpError("expects PTO shaped-like src0Idx/src1Idx/dstIdx"); + Type idxElem = getElemTy(src0IdxTy); + if (!idxElem || idxElem != getElemTy(src1IdxTy) || + idxElem != getElemTy(dstIdxTy)) + return op->emitOpError( + "expects src0Idx/src1Idx/dstIdx to have the same element type"); + auto idxInt = dyn_cast(idxElem); + if (!idxInt || idxInt.getWidth() != 32) + return op->emitOpError( + "expects src0Idx/src1Idx/dstIdx element type to be i32 or ui32"); + + auto dataShape = getShapeVec(src0Ty); + if (dataShape != getShapeVec(src0IdxTy) || + dataShape != getShapeVec(src1IdxTy) || + dataShape != getShapeVec(dstIdxTy)) + return op->emitOpError( + "expects data and index operands to have the same shape"); + if (getValidShapeVec(src0Ty) != getValidShapeVec(src0IdxTy) || + getValidShapeVec(src1Ty) != getValidShapeVec(src1IdxTy) || + getValidShapeVec(dstTy) != getValidShapeVec(dstIdxTy)) + return op->emitOpError( + "expects each data operand and its index operand to have the same valid_shape"); + + Type elem = *dataElemOr; + PTOArch arch = getTargetArch(op); + if (arch == PTOArch::A5) { + if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isInteger(8) || + elem.isF16() || elem.isBF16() || elem.isF32())) + return op->emitOpError() << "expects A5 " << opName + << " element type to be i32/i16/i8/f16/bf16/f32"; + } else { + if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isF16() || + elem.isF32())) + return op->emitOpError() << "expects A2/A3 " << opName + << " element type to 
be i32/i16/f16/f32"; + } + return success(); +} + +mlir::LogicalResult mlir::pto::TPartArgMaxOp::verify() { + auto verifyByArch = [&]() -> LogicalResult { + return verifyTPartArgOpCommon( + getOperation(), getSrc0().getType(), getSrc1().getType(), + getSrc0Idx().getType(), getSrc1Idx().getType(), getDst().getType(), + getDstIdx().getType(), "tpartargmax"); + }; + return dispatchVerifierByArch(getOperation(), verifyByArch, verifyByArch); +} + +mlir::LogicalResult mlir::pto::TPartArgMinOp::verify() { + auto verifyByArch = [&]() -> LogicalResult { + return verifyTPartArgOpCommon( + getOperation(), getSrc0().getType(), getSrc1().getType(), + getSrc0Idx().getType(), getSrc1Idx().getType(), getDst().getType(), + getDstIdx().getType(), "tpartargmin"); + }; + return dispatchVerifierByArch(getOperation(), verifyByArch, verifyByArch); +} + +mlir::LogicalResult mlir::pto::TPartMulOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + Type src0Ty = getSrc0().getType(); + Type src1Ty = getSrc1().getType(); + Type dstTy = getDst().getType(); + if (!isPTOShapedLike(src0Ty) || !isPTOShapedLike(src1Ty) || + !isPTOShapedLike(dstTy)) + return emitOpError() << "expects PTO shaped-like src0/src1/dst"; + if (getElemTy(src0Ty) != getElemTy(src1Ty) || + getElemTy(src0Ty) != getElemTy(dstTy)) + return emitOpError() + << "expects src0/src1/dst to have the same element type"; + auto s0 = getShapeVec(src0Ty); + auto s1 = getShapeVec(src1Ty); + auto d = getShapeVec(dstTy); + if (s0.size() != 2 || s1.size() != 2 || d.size() != 2) + return emitOpError() + << "expects src0/src1/dst to be rank-2 (tile-shaped)"; + if (failed(verifyPartialValidPattern(*this, src0Ty, src1Ty, dstTy))) + return failure(); + Type elem = getElemTy(src0Ty); + if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isF16() || + elem.isF32())) + return emitOpError( + "expects A2/A3 tpartmul element type to be i32/i16/f16/f32"); + return mlir::success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + Type src0Ty = 
getSrc0().getType(); + Type src1Ty = getSrc1().getType(); + Type dstTy = getDst().getType(); + if (!isPTOShapedLike(src0Ty) || !isPTOShapedLike(src1Ty) || + !isPTOShapedLike(dstTy)) + return emitOpError() << "expects PTO shaped-like src0/src1/dst"; + if (getElemTy(src0Ty) != getElemTy(src1Ty) || + getElemTy(src0Ty) != getElemTy(dstTy)) + return emitOpError() + << "expects src0/src1/dst to have the same element type"; + Type elem = getElemTy(src0Ty); + if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isInteger(8) || + elem.isF16() || elem.isBF16() || elem.isF32())) + return emitOpError( + "expects A5 tpartmul element type to be i32/i16/i8/f16/bf16/f32"); + auto s0 = getShapeVec(src0Ty); + auto s1 = getShapeVec(src1Ty); + auto d = getShapeVec(dstTy); + if (s0.size() != 2 || s1.size() != 2 || d.size() != 2) + return emitOpError() + << "expects src0/src1/dst to be rank-2 (tile-shaped)"; + return mlir::success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TPReluOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + auto verifyCommon = [&]() -> FailureOr> { + Type t0 = getSrc0().getType(); + Type t1 = getSrc1().getType(); + Type tt = getTmp().getType(); + Type td = getDst().getType(); + if (failed(verifyTileBufCommon(*this, t0, "src0")) || + failed(verifyTileBufCommon(*this, t1, "src1")) || + failed(verifyTileBufCommon(*this, tt, "tmp")) || + failed(verifyTileBufCommon(*this, td, "dst"))) + return failure(); + + Type e0 = getElemTy(t0), e1 = getElemTy(t1), et = getElemTy(tt), ed = getElemTy(td); + if (!e0 || !e1 || !et || !ed) { + emitOpError("failed to get element type for operands"); + return failure(); + } + if (e0 != e1 || e0 != ed) { + emitOpError("expects dst/src0/src1 to have the same element type"); + return failure(); + } + if (!(e0.isF16() || e0.isF32())) { + emitOpError("expects dst/src0/src1 element type to be f16 or f32"); + return failure(); + 
} + if (!isRowMajorTileBuf(t0) || !isRowMajorTileBuf(t1) || !isRowMajorTileBuf(td)) { + emitOpError("expects src0, src1, and dst to use row-major layout"); + return failure(); + } + if (failed(verifyTileBufSameValidShape(*this, t0, td, "src0", "dst")) || + failed(verifyTileBufSameValidShape(*this, t1, td, "src1", "dst"))) + return failure(); + + auto s0 = getShapeVec(t0), s1 = getShapeVec(t1), st = getShapeVec(tt), sd = getShapeVec(td); + if (s0 != s1 || s0 != st || s0 != sd) { + emitOpError("expects src0/src1/tmp/dst to have the same shape"); + return failure(); + } + return std::make_tuple(t0, t1, tt, td); + }; + + auto verifyA2A3 = [&]() -> LogicalResult { + auto tysOr = verifyCommon(); + if (failed(tysOr)) + return failure(); + auto [t0, t1, tt, td] = *tysOr; + Type tmpElem = getElemTy(tt); + auto tmpIntTy = mlir::dyn_cast(tmpElem); + if (!tmpIntTy || tmpIntTy.getWidth() != 8) + return emitOpError("expects A2/A3 tmp element type to be u8"); + if (!isRowMajorTileBuf(tt)) + return emitOpError("expects tmp to use row-major layout"); + if (auto arch = getVerifierArchName(getOperation()); + arch && arch->equals_insensitive("a3")) { + if (getSrc0() == getSrc1() || getSrc0() == getTmp() || getSrc0() == getDst() || + getSrc1() == getTmp() || getSrc1() == getDst() || getTmp() == getDst()) + return emitOpError( + "expects A3 src0, src1, tmp, and dst to use different storage"); + } + return success(); + }; + + auto verifyA5 = [&]() -> LogicalResult { + auto tysOr = verifyCommon(); + if (failed(tysOr)) + return failure(); + return success(); + }; + + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TQuantOp::verify() { + // Structural checks: always run regardless of operand representation + // (applies both before and after PTOViewToMemref lowering). + auto verifyStructural = [&]() -> LogicalResult { + // dst elem type and offset presence must be consistent with quant_type. 
+ Type dstTy = getDst().getType(); + Type dstElemTy = getElemTy(dstTy); + auto dstIntTy = dyn_cast(dstElemTy); + if (getQuantType() == mlir::pto::QuantType::INT8_SYM) { + if (!dstIntTy || dstIntTy.getWidth() != 8) + return emitOpError() + << "expects dst element type i8/ui8 for INT8_SYM quantization"; + if (getOffset()) + return emitOpError() + << "INT8_SYM quantization must not have an offset operand"; + } else { + // INT8_ASYM + if (!dstIntTy || dstIntTy.getWidth() != 8) + return emitOpError() + << "expects dst element type i8/ui8 for INT8_ASYM quantization"; + if (!getOffset()) + return emitOpError() + << "INT8_ASYM quantization requires an offset operand"; + } + return success(); + }; + + if (failed(verifyStructural())) + return failure(); + + // Layout/tile-buffer checks: only meaningful for pre-lowering tile types. + // Skip when operands are already plain MemRefs (post PTOViewToMemref). + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + + auto verifyCommon = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type fpTy = getFp().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyTileBufCommon(*this, srcTy, "src")) || + failed(verifyTileBufCommon(*this, fpTy, "fp")) || + failed(verifyTileBufCommon(*this, dstTy, "dst"))) + return failure(); + // src must be f32 (ISA static_assert) + if (!getElemTy(srcTy).isF32()) + return emitOpError() << "expects src to have element type f32"; + if (getOffset()) { + Type offsetTy = getOffset().getType(); + if (failed(verifyTileBufCommon(*this, offsetTy, "offset"))) + return failure(); + if (!getElemTy(offsetTy).isF32()) + return emitOpError() << "expects offset to have element type f32"; + } + return success(); + }; + + auto verifyA2A3 = [&]() -> LogicalResult { + if (failed(verifyCommon())) + return failure(); + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (!isRowMajorTileBuf(srcTy) || !isRowMajorTileBuf(dstTy)) + return emitOpError() << "expects 
A2/A3 src and dst to use row-major layout"; + return success(); + }; + + auto verifyA5 = [&]() -> LogicalResult { + return verifyCommon(); + }; + + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TDequantOp::verify() { + // Structural checks: src must be i8 or i16, dst/scale/offset must be f32. + auto verifyStructural = [&]() -> LogicalResult { + Type srcElemTy = getElemTy(getSrc().getType()); + auto srcIntTy = dyn_cast(srcElemTy); + if (!srcIntTy || !(srcIntTy.getWidth() == 8 || srcIntTy.getWidth() == 16)) + return emitOpError() + << "expects src element type i8 or i16"; + if (!getElemTy(getDst().getType()).isF32()) + return emitOpError() << "expects dst element type f32"; + if (!getElemTy(getScale().getType()).isF32()) + return emitOpError() << "expects scale element type f32"; + if (!getElemTy(getOffset().getType()).isF32()) + return emitOpError() << "expects offset element type f32"; + return success(); + }; + + if (failed(verifyStructural())) + return failure(); + + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + + auto verifyCommon = [&]() -> LogicalResult { + if (failed(verifyTileBufCommon(*this, getSrc().getType(), "src")) || + failed(verifyTileBufCommon(*this, getScale().getType(), "scale")) || + failed(verifyTileBufCommon(*this, getOffset().getType(), "offset")) || + failed(verifyTileBufCommon(*this, getDst().getType(), "dst"))) + return failure(); + return success(); + }; + + auto verifyA2A3 = [&]() -> LogicalResult { + if (failed(verifyCommon())) + return failure(); + if (!isRowMajorTileBuf(getSrc().getType()) || + !isRowMajorTileBuf(getDst().getType())) + return emitOpError() + << "expects A2/A3 src and dst to use row-major layout"; + return success(); + }; + + auto verifyA5 = [&]() -> LogicalResult { return verifyCommon(); }; + + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TRecipOp::verify() { + if 
(shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + Type ts = getSrc().getType(); + Type td = getDst().getType(); + if (failed(verifyVecTileUnaryOp(*this, ts, td, "src", "dst", + /*allowBf16=*/false, /*allowInt8=*/false))) + return failure(); + if (failed(verifyTileBufSameValidShape(*this, ts, td, "src", "dst"))) + return failure(); + Type elemTy = getElemTy(ts); + if (!(elemTy.isF16() || elemTy.isF32())) + return emitOpError() << "expects element type to be f16 or f32"; + if (auto arch = getVerifierArchName(getOperation()); + arch && arch->equals_insensitive("a3") && getSrc() == getDst()) + return emitOpError("expects A3 trecip src and dst to use different storage"); + return mlir::success(); +} + +mlir::LogicalResult mlir::pto::TReluOp::verify() { + auto verifyByArch = [&](StringRef errorMessage) -> LogicalResult { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyVecTileCommon(*this, srcTy, "src")) || + failed(verifyVecTileCommon(*this, dstTy, "dst"))) + return failure(); + if (failed(verifyTileBufSameElemType(*this, srcTy, dstTy, "src", "dst")) || + failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) + return failure(); + Type elemTy = getElemTy(srcTy); + if (!(elemTy.isInteger(32) || elemTy.isF16() || elemTy.isF32())) + return emitOpError() << errorMessage; + return success(); + }; + auto verifyA2A3 = [&]() -> LogicalResult { + return verifyByArch("expects A2/A3 trelu element type to be i32/f16/f32"); + }; + auto verifyA5 = [&]() -> LogicalResult { + return verifyByArch("expects A5 trelu element type to be i32/f16/f32"); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + + +mlir::LogicalResult mlir::pto::TRemOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + + Type src0Ty = getSrc0().getType(); + Type src1Ty = getSrc1().getType(); + Type tmpTy = getTmp().getType(); + Type dstTy = getDst().getType(); + if 
(failed(verifyTileBufCommon(*this, src0Ty, "src0")) || + failed(verifyTileBufCommon(*this, src1Ty, "src1")) || + failed(verifyTileBufCommon(*this, tmpTy, "tmp")) || + failed(verifyTileBufCommon(*this, dstTy, "dst"))) + return failure(); + if (failed(verifyTileBufSameElemType(*this, src0Ty, src1Ty, "src0", "src1")) || + failed(verifyTileBufSameElemType(*this, src0Ty, dstTy, "src0", "dst")) || + failed(verifyTileBufSameValidShape(*this, src0Ty, src1Ty, "src0", "src1")) || + failed(verifyTileBufSameValidShape(*this, src0Ty, dstTy, "src0", "dst"))) + return failure(); + if (getElemTy(tmpTy) != getElemTy(dstTy)) + return emitOpError("expects tmp and dst to have the same element type"); + if (!isRowMajorTileBuf(src0Ty) || !isRowMajorTileBuf(src1Ty) || + !isRowMajorTileBuf(tmpTy) || !isRowMajorTileBuf(dstTy)) + return emitOpError("expects src0, src1, tmp, and dst to use row-major layout"); + auto dstValid = getValidShapeVec(dstTy); + auto tmpValid = getValidShapeVec(tmpTy); + if (dstValid.size() != 2 || tmpValid.size() != 2) + return emitOpError("expects tmp and dst to be rank-2 tiles"); + if (tmpValid[0] != ShapedType::kDynamic && tmpValid[0] < 1) + return emitOpError("expects tmp to have at least 1 valid row"); + if (dstValid[1] != ShapedType::kDynamic && tmpValid[1] != ShapedType::kDynamic && + tmpValid[1] < dstValid[1]) + return emitOpError("expects tmp valid columns to cover dst valid columns"); + + Type elem = getElemTy(src0Ty); + auto verifyA2A3 = [&]() -> LogicalResult { + if (!(elem.isInteger(32) || elem.isF32())) + return emitOpError("expects A2/A3 trem element type to be i32/f32"); + return success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isF16() || elem.isF32())) + return emitOpError("expects A5 trem element type to be i32/i16/f16/f32"); + return success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TFModOp::verify() { + 
return verifyArithmeticBinaryTileOpWithArchDispatch( + getOperation(), getSrc0().getType(), getSrc1().getType(), getDst().getType(), + /*allowInt8OnA5=*/false, /*allowBf16OnA5=*/false, + "expects A2/A3 tfmod element type to be i32/i16/f16/f32", + "expects A5 tfmod element type to be i32/i16/f16/f32"); +} + +mlir::LogicalResult mlir::pto::TRemSOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + Type ts = getSrc().getType(); + Type tt = getTmp().getType(); + Type td = getDst().getType(); + Type scalarTy = getScalar().getType(); + if (failed(verifyTileBufCommon(*this, ts, "src")) || + failed(verifyTileBufCommon(*this, tt, "tmp")) || + failed(verifyTileBufCommon(*this, td, "dst"))) + return failure(); + if (failed(verifyTileBufSameElemType(*this, ts, td, "src", "dst")) || + failed(verifyTileBufSameValidShape(*this, ts, td, "src", "dst"))) + return failure(); + if (getElemTy(tt) != getElemTy(td)) + return emitOpError("expects tmp and dst to have the same element type"); + if (!isRowMajorTileBuf(ts) || !isRowMajorTileBuf(tt) || !isRowMajorTileBuf(td)) + return emitOpError("expects src, tmp, and dst to use row-major layout"); + Type elem = getElemTy(ts); + if (scalarTy != elem) + return emitOpError("expects scalar type to match the tile element type"); + auto dstValid = getValidShapeVec(td); + auto tmpValid = getValidShapeVec(tt); + if (dstValid.size() != 2 || tmpValid.size() != 2) + return emitOpError("expects tmp and dst to be rank-2 tiles"); + if (tmpValid[0] != ShapedType::kDynamic && tmpValid[0] < 1) + return emitOpError("expects tmp to have at least 1 valid row"); + if (dstValid[1] != ShapedType::kDynamic && tmpValid[1] != ShapedType::kDynamic && + tmpValid[1] < dstValid[1]) + return emitOpError("expects tmp valid columns to cover dst valid columns"); + auto verifyA2A3 = [&]() -> LogicalResult { + if (!(elem.isInteger(32) || elem.isF32())) + return emitOpError("expects A2/A3 trems element type to be i32/f32"); + return 
success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isF16() || elem.isF32())) + return emitOpError("expects A5 trems element type to be i32/i16/f16/f32"); + return success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TFModSOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + Type scalarTy = getScalar().getType(); + if (failed(verifyTileBufCommon(*this, srcTy, "src")) || + failed(verifyTileBufCommon(*this, dstTy, "dst"))) + return failure(); + if (failed(verifyTileBufSameElemType(*this, srcTy, dstTy, "src", "dst")) || + failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) + return failure(); + if (!isRowMajorTileBuf(srcTy) || !isRowMajorTileBuf(dstTy)) + return emitOpError("expects src and dst to use row-major layout"); + + Type elem = getElemTy(srcTy); + if (scalarTy != elem) + return emitOpError("expects scalar type to match the tile element type"); + + auto verifyA2A3 = [&]() -> LogicalResult { + if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isF16() || elem.isF32())) + return emitOpError("expects A2/A3 tfmods element type to be i32/i16/f16/f32"); + return success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isF16() || elem.isF32())) + return emitOpError("expects A5 tfmods element type to be i32/i16/f16/f32"); + return success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + + +static std::optional getStaticNumElements(ArrayRef shape) { + int64_t numel = 1; + for (int64_t d : shape) { + if (d == ShapedType::kDynamic) + return std::nullopt; + if (d < 0) + return std::nullopt; + numel *= d; + } + return numel; +} + +static std::optional getElemBytes(Type elemTy) { + if (!elemTy) + return 
std::nullopt; + if (auto ft = dyn_cast(elemTy)) { + if (ft.isF16() || ft.isBF16()) + return 2; + if (ft.isF32()) + return 4; + if (ft.isF64()) + return 8; + return std::nullopt; + } + if (auto it = dyn_cast(elemTy)) { + int64_t bits = it.getWidth(); + if (bits <= 0) + return std::nullopt; + return std::max(1, bits / 8); + } + return std::nullopt; +} + +[[maybe_unused]] static bool isTileBufOrMemref(Type ty) { + return mlir::isa(ty); +} + +static constexpr llvm::StringLiteral kLoweredSetValidShapeAttrName = + "__pto.lowered_set_validshape"; + +static bool isLocallyBoundTileSource(Value value) { + if (!value || isa(value)) + return false; + + if (isa( + value.getDefiningOp())) + return true; + + if (auto bitcast = value.getDefiningOp()) + return isLocallyBoundTileSource(bitcast.getSrc()); + if (auto reshape = value.getDefiningOp()) + return isLocallyBoundTileSource(reshape.getSrc()); + + return false; +} + +static std::optional getConstIndexLike(Value v) { + if (auto cOp = v.getDefiningOp()) + return cOp.value(); + if (auto cInt = v.getDefiningOp()) + return cInt.value(); + if (auto cOp = v.getDefiningOp()) { + if (auto ia = dyn_cast(cOp.getValue())) + return ia.getInt(); + } + if (auto castOp = v.getDefiningOp()) + return getConstIndexLike(castOp.getIn()); + if (auto extOp = v.getDefiningOp()) + return getConstIndexLike(extOp.getIn()); + if (auto extOp = v.getDefiningOp()) + return getConstIndexLike(extOp.getIn()); + if (auto truncOp = v.getDefiningOp()) + return getConstIndexLike(truncOp.getIn()); + return std::nullopt; +} + +mlir::LogicalResult mlir::pto::SetValidShapeOp::verify() { + SmallVector shape; + if (auto srcTy = llvm::dyn_cast(getSource().getType())) { + if (srcTy.getRank() != 2) + return emitOpError("expects rank-2 tile_buf source"); + + ArrayRef validShape = srcTy.getValidShape(); + if (validShape.size() != 2) + return emitOpError("expects source validShape to be rank-2"); + if (!srcTy.hasDynamicValid()) + return emitOpError("expects source tile_buf to 
have dynamic validShape (?, ?)"); + + shape.assign(srcTy.getShape().begin(), srcTy.getShape().end()); + + if (!isLocallyBoundTileSource(getSource())) + return emitOpError( + "requires a locally bound tile source; function arguments/results " + "are unsupported"); + } else if (auto srcTy = llvm::dyn_cast(getSource().getType())) { + if (!(*this)->hasAttr(kLoweredSetValidShapeAttrName)) + return emitOpError( + "expects tile_buf source; memref source is only valid for the internal lowered form"); + if (srcTy.getRank() != 2) + return emitOpError("expects rank-2 memref source after tile lowering"); + shape.assign(srcTy.getShape().begin(), srcTy.getShape().end()); + } else { + return emitOpError("expects tile_buf source (or lowered memref source)"); + } + + auto checkDim = [&](Value operand, unsigned dimIdx, + StringRef dimName) -> LogicalResult { + int64_t maxStatic = shape[dimIdx]; + + auto constVal = getConstIndexLike(operand); + if (!constVal) + return success(); + + if (*constVal < 0) + return emitOpError() << "expects " << dimName << " operand to be non-negative"; + if (maxStatic != ShapedType::kDynamic && *constVal > maxStatic) + return emitOpError() << "expects " << dimName << " operand <= shape dim (" + << maxStatic << ")"; + return success(); + }; + + if (failed(checkDim(getValidRow(), /*dimIdx=*/0, "row"))) + return failure(); + if (failed(checkDim(getValidCol(), /*dimIdx=*/1, "col"))) + return failure(); + + return success(); +} + +mlir::LogicalResult mlir::pto::GetValidShapeOp::verify() { + if (auto srcTy = llvm::dyn_cast(getSource().getType())) { + if (srcTy.getRank() != 2) + return emitOpError("expects rank-2 tile_buf source"); + if (srcTy.getValidShape().size() != 2) + return emitOpError("expects source validShape to be rank-2"); + return success(); + } + if (auto srcTy = llvm::dyn_cast(getSource().getType())) { + if (srcTy.getRank() != 2) + return emitOpError("expects rank-2 memref source after tile lowering"); + return success(); + } + return 
emitOpError("expects tile_buf source (or lowered memref source)"); +} + + +mlir::LogicalResult mlir::pto::TReshapeOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + Type ts = getSrc().getType(); + Type tr = getResult().getType(); + auto srcTb = dyn_cast(ts); + auto dstTb = dyn_cast(tr); + if (!srcTb || !dstTb) + return emitOpError("expects src/result to be !pto.tile_buf types"); + + if (failed(verifyTileBufCommon(*this, ts, "src")) || + failed(verifyTileBufCommon(*this, tr, "dst"))) + return failure(); + + if (srcTb.getMemorySpace() != dstTb.getMemorySpace()) + return emitOpError("expects src and dst to use the same loc"); + + Type srcElem = srcTb.getElementType(); + Type dstElem = dstTb.getElementType(); + auto srcElemBytes = getElemBytes(srcElem); + auto dstElemBytes = getElemBytes(dstElem); + if (!srcElem || !dstElem || !srcElemBytes.has_value() || !dstElemBytes.has_value()) + return emitOpError("failed to get element byte width for src/dst"); + + auto srcNumel = getStaticNumElements(getShapeVec(ts)); + auto dstNumel = getStaticNumElements(getShapeVec(tr)); + if (!srcNumel.has_value() || !dstNumel.has_value()) + return emitOpError("expects static shapes for treshape"); + + if (srcElemBytes.value() * srcNumel.value() != + dstElemBytes.value() * dstNumel.value()) + return emitOpError("expects src and dst to have the same total byte size"); + + bool srcBoxed = + srcTb.getSLayoutValueI32() != static_cast(pto::SLayout::NoneBox); + bool dstBoxed = + dstTb.getSLayoutValueI32() != static_cast(pto::SLayout::NoneBox); + if (srcBoxed != dstBoxed) + return emitOpError("cannot reshape between boxed and non-boxed tile layouts"); + + return success(); +} + +mlir::LogicalResult mlir::pto::BitcastOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + auto srcTy = llvm::dyn_cast(getSrc().getType()); + auto dstTy = llvm::dyn_cast(getResult().getType()); + if (!srcTy || !dstTy) + return 
emitOpError("expects tile_buf src and tile_buf result"); + + if (srcTy.getMemorySpace() != dstTy.getMemorySpace()) + return emitOpError("expects src/result to have the same memorySpace"); + + if (srcTy.getElementType() == dstTy.getElementType()) + return emitOpError( + "expects src/result to have different element types; use " + "pto.treshape for shape/config changes"); + + if (srcTy.getShape() != dstTy.getShape()) + return emitOpError("expects src/result to have the same shape; use pto.treshape for shape changes"); + + if (srcTy.getValidShape() != dstTy.getValidShape()) + return emitOpError("expects src/result to have the same validShape"); + + auto srcCfg = srcTy.getConfigAttr(); + auto dstCfg = dstTy.getConfigAttr(); + if (srcCfg != dstCfg) + return emitOpError("expects src/result to have the same tile config"); + + auto numel = getStaticNumElements(srcTy.getShape()); + if (!numel.has_value()) + return emitOpError("expects static shapes for bitcast"); + + auto srcBytes = getElemBytes(srcTy.getElementType()); + auto dstBytes = getElemBytes(dstTy.getElementType()); + if (!srcBytes.has_value() || !dstBytes.has_value()) + return emitOpError("unsupported element type for bitcast"); + + int64_t srcTotalBytes = numel.value() * srcBytes.value(); + int64_t dstTotalBytes = numel.value() * dstBytes.value(); + if (dstTotalBytes > srcTotalBytes) + return emitOpError("bitcast result requires more bytes than source storage"); + + return success(); +} + + +mlir::LogicalResult mlir::pto::TRowExpandOp::verify() { + auto verifyCommon = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyTileBufCommon(*this, srcTy, "src")) || + failed(verifyNDStyleVecTile(*this, dstTy, "dst"))) + return failure(); + auto srcSpace = getPTOMemorySpaceEnum(srcTy); + if (!srcSpace || *srcSpace != pto::AddressSpace::VEC) + return emitOpError("expects src to be in the vec address space"); + if (auto srcTb = dyn_cast(srcTy)) { + if 
(srcTb.getSLayoutValueI32() != static_cast(pto::SLayout::NoneBox)) + return emitOpError("expects src to use the none_box slayout"); + } + if (getElemTy(srcTy) != getElemTy(dstTy)) + return emitOpError("expects src and dst to have the same element type"); + if (!isSupportedVecElemType(getElemTy(srcTy), /*allowBf16=*/true, + /*allowInt8=*/true)) + return emitOpError("expects trowexpand element type to be supported"); + auto srcValid = getValidShapeVec(getSrc()); + auto dstValid = getValidShapeVec(getDst()); + if (srcValid.size() != 2 || dstValid.size() != 2) + return emitOpError("expects src and dst to have rank-2 valid_shape"); + if (srcValid[0] != ShapedType::kDynamic && dstValid[0] != ShapedType::kDynamic && + srcValid[0] != dstValid[0]) + return emitOpError("expects src and dst to have the same valid_shape[0]"); + if (srcValid[0] != ShapedType::kDynamic && srcValid[0] == 0) + return emitOpError("expects src valid_shape[0] to be non-zero"); + if (srcValid[1] != ShapedType::kDynamic && srcValid[1] == 0) + return emitOpError("expects src valid_shape[1] to be non-zero"); + if (dstValid[0] != ShapedType::kDynamic && dstValid[0] == 0) + return emitOpError("expects dst valid_shape[0] to be non-zero"); + if (dstValid[1] != ShapedType::kDynamic && dstValid[1] == 0) + return emitOpError("expects dst valid_shape[1] to be non-zero"); + return success(); + }; + auto verifyA2A3 = [&]() -> LogicalResult { + return verifyCommon(); + }; + auto verifyA5 = [&]() -> LogicalResult { + return verifyCommon(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + + +ParseResult mlir::pto::TSort32Op::parse(OpAsmParser &parser, OperationState &result) { + OpAsmParser::UnresolvedOperand src, idx, tmp, dst; + Type srcTy, dstTy, idxTy, tmpTy; + bool hasTmp = false; + + if (parser.parseKeyword("ins") || parser.parseLParen() || parser.parseOperand(src)) + return failure(); + if (succeeded(parser.parseOptionalComma())) { + if (parser.parseOperand(idx)) + return 
failure(); + if (succeeded(parser.parseOptionalComma())) { + if (parser.parseOperand(tmp)) + return failure(); + hasTmp = true; + } + } else { + return failure(); + } + if (parser.parseColonType(srcTy) || parser.parseComma() || parser.parseType(idxTy)) + return failure(); + if (hasTmp) { + if (parser.parseComma() || parser.parseType(tmpTy)) + return failure(); + } + if (parser.parseRParen()) + return failure(); + + if (parser.parseKeyword("outs") || parser.parseLParen() || + parser.parseOperand(dst) || parser.parseColonType(dstTy) || + parser.parseRParen()) + return failure(); + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + if (parser.resolveOperand(src, srcTy, result.operands) || + parser.resolveOperand(idx, idxTy, result.operands)) + return failure(); + if (hasTmp) { + if (parser.resolveOperand(tmp, tmpTy, result.operands)) + return failure(); + } + if (parser.resolveOperand(dst, dstTy, result.operands)) + return failure(); + + result.addAttribute( + "operandSegmentSizes", + parser.getBuilder().getDenseI32ArrayAttr({1, 1, hasTmp ? 
1 : 0, 1})); + return success(); +} + +void mlir::pto::TSort32Op::print(OpAsmPrinter &p) { + p << " ins(" << getSrc() << ", " << getIdx(); + if (getTmp()) { + p << ", " << getTmp(); + p << " : " << getSrc().getType() << ", " << getIdx().getType() + << ", " << getTmp().getType() << ")"; + } else { + p << " : " << getSrc().getType() << ", " << getIdx().getType() << ")"; + } + p << " outs(" << getDst() << " : " << getDst().getType() << ")"; + p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"operandSegmentSizes"}); +} + +ParseResult mlir::pto::TRsqrtOp::parse(OpAsmParser &parser, OperationState &result) { + OpAsmParser::UnresolvedOperand src, tmp, dst; + Type srcTy, tmpTy, dstTy; + bool hasTmp = false; + + if (parser.parseKeyword("ins") || parser.parseLParen() || parser.parseOperand(src)) + return failure(); + if (succeeded(parser.parseOptionalComma())) { + if (parser.parseOperand(tmp)) + return failure(); + hasTmp = true; + } + if (parser.parseColonType(srcTy)) + return failure(); + if (hasTmp) { + if (parser.parseComma() || parser.parseType(tmpTy)) + return failure(); + } + if (parser.parseRParen()) + return failure(); + + if (parser.parseKeyword("outs") || parser.parseLParen() || + parser.parseOperand(dst) || parser.parseColonType(dstTy) || + parser.parseRParen()) + return failure(); + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + if (parser.resolveOperand(src, srcTy, result.operands) || + parser.resolveOperand(dst, dstTy, result.operands)) + return failure(); + if (hasTmp && parser.resolveOperand(tmp, tmpTy, result.operands)) + return failure(); + + return success(); +} + +void mlir::pto::TRsqrtOp::print(OpAsmPrinter &p) { + p << " ins(" << getSrc(); + if (getTmp()) + p << ", " << getTmp(); + p << " : " << getSrc().getType(); + if (getTmp()) + p << ", " << getTmp().getType(); + p << ")"; + p << " outs(" << getDst() << " : " << getDst().getType() << ")"; + p.printOptionalAttrDict((*this)->getAttrs()); +} + +static 
ParseResult parseTRowExpandBinaryLikeOp(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::UnresolvedOperand src0, src1, tmp, dst; + Type src0Ty, src1Ty, tmpTy, dstTy; + bool hasTmp = false; + + if (parser.parseKeyword("ins") || parser.parseLParen() || + parser.parseOperand(src0) || parser.parseComma() || parser.parseOperand(src1)) + return failure(); + if (succeeded(parser.parseOptionalComma())) { + if (parser.parseOperand(tmp)) + return failure(); + hasTmp = true; + } + if (parser.parseColon()) + return failure(); + if (parser.parseType(src0Ty) || parser.parseComma() || parser.parseType(src1Ty)) + return failure(); + if (hasTmp) { + if (parser.parseComma() || parser.parseType(tmpTy)) + return failure(); + } + if (parser.parseRParen()) + return failure(); + if (parser.parseKeyword("outs") || parser.parseLParen() || + parser.parseOperand(dst) || parser.parseColonType(dstTy) || + parser.parseRParen()) + return failure(); + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + if (parser.resolveOperand(src0, src0Ty, result.operands) || + parser.resolveOperand(src1, src1Ty, result.operands)) + return failure(); + if (hasTmp) { + if (parser.resolveOperand(tmp, tmpTy, result.operands)) + return failure(); + } + if (parser.resolveOperand(dst, dstTy, result.operands)) + return failure(); + + result.addAttribute( + "operandSegmentSizes", + parser.getBuilder().getDenseI32ArrayAttr({1, 1, hasTmp ? 
1 : 0, 1})); + return success(); +} + +static void printTRowExpandBinaryLikeOp(OpAsmPrinter &p, Operation *op, Value src0, + Value src1, Value tmp, Value dst) { + p << " ins(" << src0 << ", " << src1; + if (tmp) { + p << ", " << tmp; + p << " : " << src0.getType() << ", " << src1.getType() << ", " + << tmp.getType() << ")"; + } else { + p << " : " << src0.getType() << ", " << src1.getType() << ")"; + } + p << " outs(" << dst << " : " << dst.getType() << ")"; + p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"operandSegmentSizes"}); +} + +ParseResult mlir::pto::TRowExpandDivOp::parse(OpAsmParser &parser, OperationState &result) { + return parseTRowExpandBinaryLikeOp(parser, result); +} + +void mlir::pto::TRowExpandDivOp::print(OpAsmPrinter &p) { + printTRowExpandBinaryLikeOp(p, getOperation(), getSrc0(), getSrc1(), getTmp(), + getDst()); +} + +ParseResult mlir::pto::TRowExpandMulOp::parse(OpAsmParser &parser, OperationState &result) { + return parseTRowExpandBinaryLikeOp(parser, result); +} + +void mlir::pto::TRowExpandMulOp::print(OpAsmPrinter &p) { + printTRowExpandBinaryLikeOp(p, getOperation(), getSrc0(), getSrc1(), getTmp(), + getDst()); +} + +ParseResult mlir::pto::TRowExpandSubOp::parse(OpAsmParser &parser, OperationState &result) { + return parseTRowExpandBinaryLikeOp(parser, result); +} + +void mlir::pto::TRowExpandSubOp::print(OpAsmPrinter &p) { + printTRowExpandBinaryLikeOp(p, getOperation(), getSrc0(), getSrc1(), getTmp(), + getDst()); +} + +ParseResult mlir::pto::TRowExpandExpdifOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseTRowExpandBinaryLikeOp(parser, result); +} + +void mlir::pto::TRowExpandExpdifOp::print(OpAsmPrinter &p) { + printTRowExpandBinaryLikeOp(p, getOperation(), getSrc0(), getSrc1(), getTmp(), + getDst()); +} + +ParseResult mlir::pto::TRowExpandMaxOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseTRowExpandBinaryLikeOp(parser, result); +} + +void 
mlir::pto::TRowExpandMaxOp::print(OpAsmPrinter &p) { + printTRowExpandBinaryLikeOp(p, getOperation(), getSrc0(), getSrc1(), getTmp(), + getDst()); +} + +ParseResult mlir::pto::TRowExpandMinOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseTRowExpandBinaryLikeOp(parser, result); +} + +void mlir::pto::TRowExpandMinOp::print(OpAsmPrinter &p) { + printTRowExpandBinaryLikeOp(p, getOperation(), getSrc0(), getSrc1(), getTmp(), + getDst()); +} + +static FailureOr verifyTRowExpandBinaryCore(Operation *op, Type src0Ty, + Type src1Ty, Type dstTy, + Type tmpTy, bool hasTmp) { + if (failed(verifyTileBufCommon(op, src0Ty, "src0")) || + failed(verifyTileBufCommon(op, src1Ty, "src1")) || + failed(verifyTileBufCommon(op, dstTy, "dst"))) + return failure(); + if (hasTmp && failed(verifyTileBufCommon(op, tmpTy, "tmp"))) + return failure(); + if (failed(verifyTileBufSameElemType(op, src0Ty, dstTy, "src0", "dst"))) + return failure(); + if (getElemTy(src0Ty) != getElemTy(src1Ty)) { + op->emitOpError("expects src0 and src1 to have the same element type"); + return failure(); + } + if (!isRowMajorTileBuf(dstTy)) { + op->emitOpError("expects dst to use row-major layout"); + return failure(); + } + return getElemTy(src0Ty); +} + +mlir::LogicalResult mlir::pto::TRowExpandDivOp::verify() { + auto verifyByArch = [&](PTOArch targetArch) -> LogicalResult { + Type src0Ty = getSrc0().getType(); + Type src1Ty = getSrc1().getType(); + Type dstTy = getDst().getType(); + FailureOr elemOr = verifyTRowExpandBinaryCore( + *this, src0Ty, src1Ty, dstTy, getTmp() ? 
getTmp().getType() : Type{}, + static_cast(getTmp())); + if (failed(elemOr)) + return failure(); + Type elem = *elemOr; + bool supported = + elem.isF16() || elem.isF32() || + (targetArch == PTOArch::A5 && + (elem.isInteger(8) || elem.isInteger(16) || elem.isInteger(32))); + if (!supported) { + if (targetArch == PTOArch::A5) + return emitOpError( + "expects A5 trowexpanddiv element type to be i8/i16/i32/f16/f32"); + return emitOpError("expects element type to be f16 or f32"); + } + return mlir::success(); + }; + auto verifyA2A3 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A3); }; + auto verifyA5 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A5); }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + + +mlir::LogicalResult mlir::pto::TRowExpandMulOp::verify() { + auto verifyByArch = [&](PTOArch targetArch) -> LogicalResult { + Type src0Ty = getSrc0().getType(); + Type src1Ty = getSrc1().getType(); + Type dstTy = getDst().getType(); + FailureOr elemOr = verifyTRowExpandBinaryCore( + *this, src0Ty, src1Ty, dstTy, getTmp() ? 
getTmp().getType() : Type{}, + static_cast(getTmp())); + if (failed(elemOr)) + return failure(); + Type elem = *elemOr; + bool supported = elem.isF16() || elem.isF32() || elem.isInteger(16) || + elem.isInteger(32) || + (targetArch == PTOArch::A5 && elem.isInteger(8)); + if (!supported) { + if (targetArch == PTOArch::A5) + return emitOpError( + "expects A5 trowexpandmul element type to be i8/i16/i32/f16/f32"); + return emitOpError( + "expects A2/A3 trowexpandmul element type to be i16/i32/f16/f32"); + } + return mlir::success(); + }; + auto verifyA2A3 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A3); }; + auto verifyA5 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A5); }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + + +mlir::LogicalResult mlir::pto::TRowExpandSubOp::verify() { + auto verifyByArch = [&](PTOArch targetArch) -> LogicalResult { + Type src0Ty = getSrc0().getType(); + Type src1Ty = getSrc1().getType(); + Type dstTy = getDst().getType(); + FailureOr elemOr = verifyTRowExpandBinaryCore( + *this, src0Ty, src1Ty, dstTy, getTmp() ? 
getTmp().getType() : Type{}, + static_cast(getTmp())); + if (failed(elemOr)) + return failure(); + Type elem = *elemOr; + bool supported = elem.isF16() || elem.isF32() || elem.isInteger(16) || + elem.isInteger(32) || + (targetArch == PTOArch::A5 && elem.isInteger(8)); + if (!supported) { + if (targetArch == PTOArch::A5) + return emitOpError( + "expects A5 trowexpandsub element type to be i8/i16/i32/f16/f32"); + return emitOpError( + "expects A2/A3 trowexpandsub element type to be i16/i32/f16/f32"); + } + return mlir::success(); + }; + auto verifyA2A3 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A3); }; + auto verifyA5 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A5); }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TRowExpandAddOp::verify() { + auto verifyByArch = [&](PTOArch targetArch) -> LogicalResult { + Type src0Ty = getSrc0().getType(); + Type src1Ty = getSrc1().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyTileBufCommon(*this, src0Ty, "src0")) || + failed(verifyTileBufCommon(*this, src1Ty, "src1")) || + failed(verifyTileBufCommon(*this, dstTy, "dst"))) + return failure(); + if (failed(verifyTileBufSameElemType(*this, src0Ty, dstTy, "src0", "dst"))) + return failure(); + if (failed(verifyTileBufSameValidShape(*this, src0Ty, dstTy, "src0", "dst"))) + return failure(); + if (getElemTy(src0Ty) != getElemTy(src1Ty)) + return emitOpError("expects src0 and src1 to have the same element type"); + if (!isRowMajorTileBuf(src0Ty)) + return emitOpError("expects src0 to use row-major layout"); + if (!isRowMajorTileBuf(dstTy)) + return emitOpError("expects dst to use row-major layout"); + Type elem = getElemTy(src0Ty); + bool supported = elem.isF16() || elem.isF32() || elem.isInteger(16) || + elem.isInteger(32) || + (targetArch == PTOArch::A5 && elem.isInteger(8)); + if (!supported) { + if (targetArch == PTOArch::A5) + return emitOpError( + "expects A5 trowexpandadd 
element type to be i8/i16/i32/f16/f32"); + return emitOpError( + "expects A2/A3 trowexpandadd element type to be i16/i32/f16/f32"); + } + auto src1Valid = getValidShapeVec(src1Ty); + auto dstValid = getValidShapeVec(dstTy); + if (src1Valid.size() != 2 || dstValid.size() != 2) + return emitOpError("expects src1 and dst to have rank-2 valid_shape"); + if (src1Valid[0] != ShapedType::kDynamic && dstValid[0] != ShapedType::kDynamic && + src1Valid[0] != dstValid[0]) + return emitOpError("expects src1 valid_shape[0] to equal dst valid_shape[0]"); + bool src1IsRowMajor = isRowMajorTileBuf(src1Ty); + int64_t expectedCol = elem.isInteger(8) + ? 32 + : ((elem.isF16() || elem.isInteger(16)) ? 16 : 8); + int64_t src1Col = src1Valid[1]; + if (src1IsRowMajor) { + if (src1Col != ShapedType::kDynamic && src1Col != expectedCol) + return emitOpError("expects row-major src1 valid_shape[1] to be 32/sizeof(dtype)"); + } else { + if (src1Col != ShapedType::kDynamic && src1Col != 1) + return emitOpError("expects non-row-major src1 valid_shape[1] to be 1"); + } + return mlir::success(); + }; + auto verifyA2A3 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A3); }; + auto verifyA5 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A5); }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +static LogicalResult verifyTRowExpandReduceLikeOp(Operation *op, Type src0Ty, + Type src1Ty, Type dstTy, + Type tmpTy, bool hasTmp, + PTOArch targetArch, + StringRef opName, + bool allowIntegerTypes) { + if (failed(verifyTileBufCommon(op, src0Ty, "src0")) || + failed(verifyTileBufCommon(op, src1Ty, "src1")) || + failed(verifyTileBufCommon(op, dstTy, "dst"))) + return failure(); + if (hasTmp) { + if (failed(verifyTileBufCommon(op, tmpTy, "tmp"))) + return failure(); + if (getElemTy(tmpTy) != getElemTy(dstTy)) + return op->emitOpError() << "expects tmp and dst to have the same element type"; + } + + Type elem = getElemTy(dstTy); + if (!elem || getElemTy(src0Ty) != elem 
|| getElemTy(src1Ty) != elem) + return op->emitOpError("expects src0, src1, and dst to have the same element type"); + bool supported = elem.isF16() || elem.isF32() || + (allowIntegerTypes && + (elem.isInteger(16) || elem.isInteger(32) || + (targetArch == PTOArch::A5 && elem.isInteger(8)))); + if (!supported) { + if (!allowIntegerTypes) + return op->emitOpError() << "expects " << opName + << " element type to be f16 or f32"; + if (targetArch == PTOArch::A5) + return op->emitOpError() << "expects A5 " << opName + << " element type to be i8/i16/i32/f16/f32"; + return op->emitOpError() << "expects A2/A3 " << opName + << " element type to be i16/i32/f16/f32"; + } + + if (!isRowMajorTileBuf(dstTy)) + return op->emitOpError("expects dst to use row-major layout"); + + auto src0Valid = getValidShapeVec(src0Ty); + auto src1Valid = getValidShapeVec(src1Ty); + auto dstValid = getValidShapeVec(dstTy); + if (src0Valid.size() != 2 || src1Valid.size() != 2 || dstValid.size() != 2) + return op->emitOpError("expects src0, src1, and dst to have rank-2 valid_shape"); + + if (dstValid[0] != ShapedType::kDynamic && dstValid[0] == 0) + return op->emitOpError("expects dst valid_shape[0] to be non-zero"); + if (dstValid[1] != ShapedType::kDynamic && dstValid[1] == 0) + return op->emitOpError("expects dst valid_shape[1] to be non-zero"); + + auto validShapeMatches = [](ArrayRef lhs, + ArrayRef rhs) -> bool { + if (lhs.size() != rhs.size()) + return false; + for (auto [l, r] : llvm::zip(lhs, rhs)) { + if (l != ShapedType::kDynamic && r != ShapedType::kDynamic && l != r) + return false; + } + return true; + }; + + const bool src0MatchesDst = validShapeMatches(src0Valid, dstValid); + const bool src1MatchesDst = validShapeMatches(src1Valid, dstValid); + + auto checkBroadcastOperand = [&](Type operandTy, ArrayRef operandValid, + StringRef operandName, + bool requireNonRowMajor) -> LogicalResult { + if (operandValid[0] != ShapedType::kDynamic && dstValid[0] != ShapedType::kDynamic && + 
operandValid[0] != dstValid[0]) { + return op->emitOpError() << "expects " << operandName + << " valid_shape[0] to equal dst valid_shape[0]"; + } + int64_t expectedCol = elem.isInteger(8) ? 32 : ((elem.isF16() || elem.isInteger(16)) ? 16 : 8); + int64_t operandCol = operandValid[1]; + bool operandIsRowMajor = isRowMajorTileBuf(operandTy); + if (requireNonRowMajor && operandIsRowMajor) { + return op->emitOpError() << "expects " << operandName + << " to use a non-row-major layout when tmp is present"; + } + if (operandIsRowMajor) { + if (operandCol != ShapedType::kDynamic && operandCol != expectedCol) { + return op->emitOpError() + << "expects row-major " << operandName + << " valid_shape[1] to be 32/sizeof(dtype)"; + } + return success(); + } + if (operandCol != ShapedType::kDynamic && operandCol != 1) { + return op->emitOpError() << "expects non-row-major " << operandName + << " valid_shape[1] to be 1"; + } + return success(); + }; + + auto checkFullAndBroadcast = [&](Type fullTy, ArrayRef fullValid, + StringRef fullName, Type broadcastTy, + ArrayRef broadcastValid, + StringRef broadcastName) -> LogicalResult { + if (!isRowMajorTileBuf(fullTy)) + return op->emitOpError() << "expects " << fullName + << " to use row-major layout when it matches dst"; + if (fullValid[0] != ShapedType::kDynamic && dstValid[0] != ShapedType::kDynamic && + fullValid[0] != dstValid[0]) + return op->emitOpError() << "expects " << fullName + << " valid_shape[0] to equal dst valid_shape[0]"; + if (fullValid[1] != ShapedType::kDynamic && dstValid[1] != ShapedType::kDynamic && + fullValid[1] != dstValid[1]) + return op->emitOpError() << "expects " << fullName + << " valid_shape[1] to equal dst valid_shape[1]"; + return checkBroadcastOperand(broadcastTy, broadcastValid, broadcastName, + /*requireNonRowMajor=*/hasTmp && + targetArch == PTOArch::A3); + }; + + if (hasTmp && targetArch == PTOArch::A5) + return op->emitOpError("expects A5 form to omit tmp"); + + if (src0MatchesDst) { + if 
(succeeded(checkFullAndBroadcast(src0Ty, src0Valid, "src0", src1Ty, + src1Valid, "src1"))) + return success(); + } + if (src1MatchesDst) { + if (succeeded(checkFullAndBroadcast(src1Ty, src1Valid, "src1", src0Ty, + src0Valid, "src0"))) + return success(); + } + + return op->emitOpError() << "expects one of src0/src1 to match dst valid_shape" + << " and the other to be a per-row scalar vector"; +} + +mlir::LogicalResult mlir::pto::TRowExpandExpdifOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + return verifyTRowExpandReduceLikeOp(getOperation(), getSrc0().getType(), + getSrc1().getType(), getDst().getType(), + getTmp() ? getTmp().getType() : Type{}, + (bool)getTmp(), PTOArch::A3, + "trowexpandexpdif", + /*allowIntegerTypes=*/false); + }; + auto verifyA5 = [&]() -> LogicalResult { + return verifyTRowExpandReduceLikeOp(getOperation(), getSrc0().getType(), + getSrc1().getType(), getDst().getType(), + getTmp() ? getTmp().getType() : Type{}, + (bool)getTmp(), PTOArch::A5, + "trowexpandexpdif", + /*allowIntegerTypes=*/false); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TRowExpandMaxOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + return verifyTRowExpandReduceLikeOp(getOperation(), getSrc0().getType(), + getSrc1().getType(), getDst().getType(), + getTmp() ? getTmp().getType() : Type{}, + (bool)getTmp(), PTOArch::A3, + "trowexpandmax", + /*allowIntegerTypes=*/true); + }; + auto verifyA5 = [&]() -> LogicalResult { + return verifyTRowExpandReduceLikeOp(getOperation(), getSrc0().getType(), + getSrc1().getType(), getDst().getType(), + getTmp() ? 
getTmp().getType() : Type{}, + (bool)getTmp(), PTOArch::A5, + "trowexpandmax", + /*allowIntegerTypes=*/true); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TRowExpandMinOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + return verifyTRowExpandReduceLikeOp(getOperation(), getSrc0().getType(), + getSrc1().getType(), getDst().getType(), + getTmp() ? getTmp().getType() : Type{}, + (bool)getTmp(), PTOArch::A3, + "trowexpandmin", + /*allowIntegerTypes=*/true); + }; + auto verifyA5 = [&]() -> LogicalResult { + return verifyTRowExpandReduceLikeOp(getOperation(), getSrc0().getType(), + getSrc1().getType(), getDst().getType(), + getTmp() ? getTmp().getType() : Type{}, + (bool)getTmp(), PTOArch::A5, + "trowexpandmin", + /*allowIntegerTypes=*/true); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + + +mlir::LogicalResult mlir::pto::TRowMaxOp::verify() { + auto verifyByArch = [&]() -> LogicalResult { + return verifyTRowReductionNoTmpCommon(*this, getSrc().getType(), + getDst().getType(), + "expects element type to be i16/i32/f16/f32"); + }; + return dispatchVerifierByArch(getOperation(), verifyByArch, verifyByArch); +} + +mlir::LogicalResult mlir::pto::TRowArgMaxOp::verify() { + auto verifyByArch = [&]() -> LogicalResult { + return verifyTRowArgReductionCommon(*this, getSrc().getType(), + getTmp().getType(), getDst().getType()); + }; + + return dispatchVerifierByArch(getOperation(), verifyByArch, verifyByArch); +} + + +mlir::LogicalResult mlir::pto::TRowMinOp::verify() { + auto verifyByArch = [&]() -> LogicalResult { + return verifyTRowReductionWithTmpCommon( + *this, getSrc().getType(), getTmp().getType(), getDst().getType(), + "expects element type to be i16/i32/f16/f32"); + }; + return dispatchVerifierByArch(getOperation(), verifyByArch, verifyByArch); +} + +mlir::LogicalResult mlir::pto::TRowArgMinOp::verify() { + auto verifyByArch = [&]() -> LogicalResult { + return 
verifyTRowArgReductionCommon(*this, getSrc().getType(), + getTmp().getType(), getDst().getType()); + }; + + return dispatchVerifierByArch(getOperation(), verifyByArch, verifyByArch); +} + + +mlir::LogicalResult mlir::pto::TRowSumOp::verify() { + auto verifyByArch = [&]() -> LogicalResult { + return verifyTRowReductionNoTmpCommon(*this, getSrc().getType(), + getDst().getType(), + "expects element type to be i16/i32/f16/f32"); + }; + return dispatchVerifierByArch(getOperation(), verifyByArch, verifyByArch); +} + +mlir::LogicalResult mlir::pto::TRowProdOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + return verifyTRowReductionWithTmpCommon( + *this, getSrc().getType(), getTmp().getType(), getDst().getType(), + "expects A2/A3 trowprod element type to be i16/i32/f16/f32"); + }; + auto verifyA5 = [&]() -> LogicalResult { + return verifyTRowReductionWithTmpCommon( + *this, getSrc().getType(), getTmp().getType(), getDst().getType(), + "expects A5 trowprod element type to be i16/i32/f16/f32"); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + + +mlir::LogicalResult mlir::pto::TRsqrtOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + Type ts = getSrc().getType(); + Type td = getDst().getType(); + if (failed(verifyVecTileUnaryOp(*this, ts, td, "src", "dst", + /*allowBf16=*/false, /*allowInt8=*/false))) + return failure(); + if (failed(verifyTileBufSameValidShape(*this, ts, td, "src", "dst"))) + return failure(); + auto ft = mlir::dyn_cast(getElemTy(ts)); + if (!ft || (!ft.isF16() && !ft.isF32())) + return emitOpError("expects element type to be f16 or f32"); + if (auto tmp = getTmp()) { + Type tt = tmp.getType(); + if (failed(verifyVecTileCommon(*this, tt, "tmp"))) + return failure(); + + auto tmpElemTy = getElemTy(tt); + auto tmpElemBytes = getElemBytes(tmpElemTy); + auto tmpNumel = getStaticNumElements(getShapeVec(tt)); + if (!tmpElemBytes.has_value() || !tmpNumel.has_value()) + return 
emitOpError("expects tmp to have a static, byte-addressable tile type");
+    if (tmpElemBytes.value() * tmpNumel.value() < 32)
+      return emitOpError("expects tmp to be at least 32 bytes when provided");
+  }
+  return mlir::success();
+}
+
+
+// Verifier for TScatterOp. The op has two mutually exclusive forms:
+//   * indexed form      -- an `indexes` operand is present;
+//   * mask-pattern form -- a `maskPattern` attribute is present.
+// Exactly one of the two must be given. The A2/A3 path accepts both forms; the
+// A5 path rejects the mask-pattern form.
+mlir::LogicalResult mlir::pto::TScatterOp::verify() {
+  const bool hasIndexes = static_cast(getIndexes());
+  const bool hasMaskPattern = static_cast(getMaskPatternAttr());
+  if (hasIndexes == hasMaskPattern) {
+    return emitOpError(
+        "expects exactly one of indexes operand or maskPattern attribute");
+  }
+
+  // Data elements: f16/f32/bf16 or an 8/16/32-bit integer.
+  auto isAllowedDataElem = [&](mlir::Type t) -> bool {
+    if (t.isF16() || t.isF32() || t.isBF16()) return true;
+    if (auto it = mlir::dyn_cast(t))
+      return (it.getWidth() == 8 || it.getWidth() == 16 || it.getWidth() == 32);
+    return false;
+  };
+  // Index elements: a 16- or 32-bit integer.
+  auto isAllowedIndexElem = [&](mlir::Type t) -> bool {
+    if (auto it = mlir::dyn_cast(t))
+      return (it.getWidth() == 16 || it.getWidth() == 32);
+    return false;
+  };
+  // Column expansion factor implied by the mask pattern: 1 for P1111, 2 for
+  // P0101/P1010, and 4 for every other pattern.
+  auto getMaskScatterTimes = [&](mlir::pto::MaskPatternAttr mp) -> unsigned {
+    switch (mp.getValue()) {
+    case mlir::pto::MaskPattern::P1111:
+      return 1;
+    case mlir::pto::MaskPattern::P0101:
+    case mlir::pto::MaskPattern::P1010:
+      return 2;
+    default:
+      return 4;
+    }
+  };
+
+  // Indexed form: src/dst share an allowed element type and the index element
+  // size is tied to the data element size (see expectedIdxBytes below).
+  auto verifyIndexedForm = [&]() -> LogicalResult {
+    Type ts = getSrc().getType();
+    Type ti = getIndexes().getType();
+    Type td = getDst().getType();
+    if (failed(verifyVecTileStorage(*this, ts, "src")) ||
+        failed(verifyVecTileStorage(*this, ti, "indexes")) ||
+        failed(verifyVecTileStorage(*this, td, "dst")))
+      return failure();
+
+    Type srcElem = getElemTy(ts), dstElem = getElemTy(td), idxElem = getElemTy(ti);
+    if (!srcElem || !dstElem || !idxElem)
+      return emitOpError("failed to get element type for operands");
+    if (srcElem != dstElem)
+      return emitOpError("expects src/dst to have the same element type");
+
+    if (!isAllowedDataElem(srcElem))
+      return emitOpError("expects src/dst element type to be i8/i16/i32/f16/bf16/f32");
+    if (!isAllowedIndexElem(idxElem))
+      return emitOpError("expects indexes element type to be i16/i32");
+
+    auto bwData = getPTOStorageElemBitWidth(srcElem);
+    auto bwIdx = getPTOStorageElemBitWidth(idxElem);
+    if (bwData != 8 && bwData != 16 && bwData != 32)
+      return emitOpError("unexpected src/dst element bitwidth");
+
+    // 1-byte data pairs with 2-byte indexes; otherwise index and data element
+    // sizes must be equal.
+    unsigned dataBytes = bwData / 8;
+    unsigned idxBytes = bwIdx / 8;
+    unsigned expectedIdxBytes = (dataBytes == 1) ? 2 : dataBytes;
+    if (idxBytes != expectedIdxBytes)
+      return emitOpError("expects indexes element size to match the documented scatter rule");
+    return mlir::success();
+  };
+
+  // Mask-pattern form: tile_buf src/dst with the same allowed element type,
+  // rank-2 valid shapes, equal valid rows, src valid cols == dst valid cols *
+  // expansion factor (static dimensions only), and row-major blayout.
+  auto verifyMaskForm = [&]() -> LogicalResult {
+    Type ts = getSrc().getType();
+    Type td = getDst().getType();
+    if (failed(verifyVecTileCommon(*this, ts, "src")) ||
+        failed(verifyVecTileCommon(*this, td, "dst")))
+      return failure();
+
+    auto srcTB = dyn_cast(ts);
+    auto dstTB = dyn_cast(td);
+    if (!srcTB || !dstTB)
+      return emitOpError("expects src and dst to be tile_buf types");
+
+    if (getElemTy(ts) != getElemTy(td))
+      return emitOpError("expects src and dst to have the same element type");
+    if (!isAllowedDataElem(getElemTy(ts)))
+      return emitOpError("expects src/dst element type to be i8/i16/i32/f16/bf16/f32");
+
+    auto srcValid = getValidShapeVec(ts);
+    auto dstValid = getValidShapeVec(td);
+    if (srcValid.size() != 2 || dstValid.size() != 2)
+      return emitOpError("expects src and dst to have rank-2 valid_shape");
+
+    auto mp = getMaskPatternAttr();
+    if (!mp)
+      return emitOpError("expects mask-pattern tscatter to provide maskPattern");
+    const unsigned times = getMaskScatterTimes(mp);
+    // Dynamic extents are skipped; only static ones can be compared here.
+    if (srcValid[0] != ShapedType::kDynamic && dstValid[0] != ShapedType::kDynamic &&
+        srcValid[0] != dstValid[0])
+      return emitOpError("expects src and dst to have the same valid rows");
+    if (srcValid[1] != ShapedType::kDynamic && dstValid[1] != ShapedType::kDynamic &&
+        srcValid[1] != static_cast(dstValid[1] * times))
+      return emitOpError("expects src valid cols to equal dst valid cols times the mask expansion factor");
+
+    if (srcTB.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor) ||
+        dstTB.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor))
+      return emitOpError("expects mask-pattern tscatter to use row_major blayout");
+    return mlir::success();
+  };
+
+  auto verifyA2A3 = [&]() -> LogicalResult {
+    if (hasMaskPattern)
+      return verifyMaskForm();
+    return verifyIndexedForm();
+  };
+  auto verifyA5 = [&]() -> LogicalResult {
+    if (hasMaskPattern)
+      return emitOpError("mask-pattern tscatter is not supported on A5 yet");
+    return verifyIndexedForm();
+  };
+  return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5);
+}
+
+
+// Verifier for TSelOp. verifyCommon checks that src0/src1/dst are valid tile
+// buffers with one shared element type and row-major layout, and yields that
+// element type. The architecture paths then restrict the element type:
+// A2/A3 accepts f16/bf16/f32 and 16/32-bit integers; A5 additionally accepts
+// 8-bit integers.
+mlir::LogicalResult mlir::pto::TSelOp::verify() {
+  auto verifyCommon = [&]() -> FailureOr {
+    Type t0 = getSrc0().getType();
+    Type t1 = getSrc1().getType();
+    Type td = getDst().getType();
+    if (failed(verifyTileBufCommon(*this, t0, "src0")) ||
+        failed(verifyTileBufCommon(*this, t1, "src1")) ||
+        failed(verifyTileBufCommon(*this, td, "dst")))
+      return failure();
+
+    Type srcElem = getElemTy(t0);
+    Type src1Elem = getElemTy(t1);
+    Type dstElem = getElemTy(td);
+    if (!srcElem || !src1Elem || !dstElem) {
+      emitOpError("failed to get element type for operands");
+      return failure();
+    }
+    if (srcElem != src1Elem || srcElem != dstElem) {
+      emitOpError("expects src0, src1, and dst to have the same element type");
+      return failure();
+    }
+
+    if (!isRowMajorTileBuf(t0) || !isRowMajorTileBuf(t1) ||
+        !isRowMajorTileBuf(td)) {
+      emitOpError(
+          "expects src0, src1, and dst to use row-major layout");
+      return failure();
+    }
+    return srcElem;
+  };
+
+  auto verifyA2A3 = [&]() -> LogicalResult {
+    FailureOr srcElem = verifyCommon();
+    if (failed(srcElem))
+      return failure();
+    Type elem = *srcElem;
+    // For integer element types the width check overrides the float result.
+    bool ok = elem.isF16() || elem.isBF16() || elem.isF32();
+    if (auto it = dyn_cast(elem))
+      ok = it.getWidth() == 16 || it.getWidth() == 32;
+    if (!ok)
+      return emitOpError(
+          "expects A2/A3 tsel src0, src1, and dst element type to be i16/i32/f16/bf16/f32");
+    return success();
+  };
+
+  auto verifyA5 = [&]() -> LogicalResult {
+    FailureOr srcElem = verifyCommon();
+    if (failed(srcElem))
+      return failure();
+    Type elem = *srcElem;
+    bool ok = elem.isF16() || elem.isBF16() || elem.isF32();
+    if (auto it = dyn_cast(elem))
+      ok = it.getWidth() == 8 || it.getWidth() == 16 || it.getWidth() == 32;
+    if (!ok)
+      return emitOpError(
+          "expects A5 tsel src0, src1, and dst element type to be i8/i16/i32/f16/bf16/f32");
+    return success();
+  };
+
+  return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5);
+}
+
+
+mlir::LogicalResult mlir::pto::TSelSOp::verify() {
+  // Constraints & Verification per PTO_IR_manual.md pto.tsels:
+  // - src and dst same element type; A2A3: i16/i32/f16/f32; A5: i8/i16/i32/f16/f32
+  // - src and dst row-major; src and dst same valid region
+  // verifyCommon validates all four tile buffers (mask, src, tmp, dst), the
+  // src/dst element-type match and the src/dst valid-shape match, and yields
+  // the dst element type for the per-architecture checks.
+  auto verifyCommon = [&]() -> FailureOr {
+    Type tMask = getMask().getType();
+    Type tSrc = getSrc().getType();
+    Type tTmp = getTmp().getType();
+    Type tDst = getDst().getType();
+    if (failed(verifyTileBufCommon(*this, tMask, "mask")) ||
+        failed(verifyTileBufCommon(*this, tSrc, "src")) ||
+        failed(verifyTileBufCommon(*this, tTmp, "tmp")) ||
+        failed(verifyTileBufCommon(*this, tDst, "dst")))
+      return failure();
+    Type eMask = getElemTy(tMask), eSrc = getElemTy(tSrc);
+    Type eTmp = getElemTy(tTmp), eDst = getElemTy(tDst);
+    if (!eMask || !eSrc || !eTmp || !eDst) {
+      emitOpError("failed to get element type for operands");
+      return failure();
+    }
+    if (eSrc != eDst)
+      return emitOpError("expects src and dst to have the same element type");
+    if (failed(verifyTileBufSameValidShape(*this, tSrc, tDst, "src", "dst")))
+      return failure();
+    return eDst;
+  };
+
+  auto verifyA2A3 = [&]() -> LogicalResult {
+    FailureOr elemOr = verifyCommon();
+    if (failed(elemOr))
+      return failure();
+    Type tSrc = getSrc().getType();
+    Type tDst = getDst().getType();
+    if (!isRowMajorTileBuf(tSrc) || !isRowMajorTileBuf(tDst))
+
return emitOpError("expects src and dst to use row-major layout");
+    Type elem = *elemOr;
+    // For integer element types the width check overrides the float result.
+    bool ok = elem.isF16() || elem.isF32();
+    if (auto it = mlir::dyn_cast(elem))
+      ok = (it.getWidth() == 16 || it.getWidth() == 32);
+    if (!ok)
+      return emitOpError(
+          "expects A2/A3 tsels src and dst element type to be i16, i32, f16, or f32");
+    return success();
+  };
+
+  // A5 path: same checks as A2/A3, but 8-bit integers are also accepted.
+  auto verifyA5 = [&]() -> LogicalResult {
+    FailureOr elemOr = verifyCommon();
+    if (failed(elemOr))
+      return failure();
+    Type tSrc = getSrc().getType();
+    Type tDst = getDst().getType();
+    if (!isRowMajorTileBuf(tSrc) || !isRowMajorTileBuf(tDst))
+      return emitOpError("expects src and dst to use row-major layout");
+    Type elem = *elemOr;
+    bool ok = elem.isF16() || elem.isF32();
+    if (auto it = mlir::dyn_cast(elem))
+      ok = (it.getWidth() == 8 || it.getWidth() == 16 || it.getWidth() == 32);
+    if (!ok)
+      return emitOpError(
+          "expects A5 tsels src and dst element type to be i8, i16, i32, f16, or f32");
+    return success();
+  };
+
+  return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5);
+}
+
+
+// Verifier for TShlOp. Runs verifyShiftLikeBinaryTileOpCommon on the
+// src0/src1/dst types and then requires the element type it yields to be an
+// 8/16/32-bit integer. The same check is used for both architecture paths.
+mlir::LogicalResult mlir::pto::TShlOp::verify() {
+  auto verify = [&]() -> LogicalResult {
+    FailureOr elemOr = verifyShiftLikeBinaryTileOpCommon(
+        *this, getSrc0().getType(), getSrc1().getType(), getDst().getType());
+    if (failed(elemOr))
+      return failure();
+    auto it = mlir::dyn_cast(*elemOr);
+    if (!it || (it.getWidth() != 8 && it.getWidth() != 16 &&
+                it.getWidth() != 32))
+      return emitOpError(
+          "expects tshl src0 and src1 element type to be i8/i16/i32");
+    return success();
+  };
+
+  return dispatchVerifierByArch(getOperation(), verify, verify);
+}
+
+
+// Verifier for TShrOp. Identical in structure to TShlOp::verify; only the
+// diagnostic text differs.
+mlir::LogicalResult mlir::pto::TShrOp::verify() {
+  auto verify = [&]() -> LogicalResult {
+    FailureOr elemOr = verifyShiftLikeBinaryTileOpCommon(
+        *this, getSrc0().getType(), getSrc1().getType(), getDst().getType());
+    if (failed(elemOr))
+      return failure();
+    auto it = mlir::dyn_cast(*elemOr);
+    if (!it || (it.getWidth() != 8 && it.getWidth() != 16 &&
+                it.getWidth() != 32))
+      return emitOpError(
+          "expects tshr src0 and src1 element type to be i8/i16/i32");
+    return success();
+  };
+
+  return dispatchVerifierByArch(getOperation(), verify, verify);
+}
+
+
+// Verifier for TSort32Op. Skipped when shouldBypassDecodedMemrefVerifier()
+// says so. Otherwise src/dst/idx (and tmp, when present) must be valid vec
+// tiles, src/dst must share an f16 or f32 element type, and idx must have a
+// 32-bit integer element type.
+mlir::LogicalResult mlir::pto::TSort32Op::verify() {
+  if (shouldBypassDecodedMemrefVerifier(getOperation()))
+    return success();
+  Type srcTy = getSrc().getType();
+  Type dstTy = getDst().getType();
+  Type idxTy = getIdx().getType();
+  if (failed(verifyVecTileCommon(*this, srcTy, "src")) ||
+      failed(verifyVecTileCommon(*this, dstTy, "dst")) ||
+      failed(verifyVecTileCommon(*this, idxTy, "idx")))
+    return failure();
+  if (getTmp() &&
+      failed(verifyVecTileCommon(*this, getTmp().getType(), "tmp")))
+    return failure();
+
+  auto srcElem = getElemTy(srcTy);
+  auto dstElem = getElemTy(dstTy);
+  if (!srcElem || !dstElem || srcElem != dstElem)
+    return emitOpError() << "expects src and dst to have the same element type";
+  if (!(srcElem.isF16() || srcElem.isF32()))
+    return emitOpError() << "expects src and dst element type to be f16 or f32";
+
+  // NOTE(review): unlike srcElem/dstElem above, idxElem is not null-checked
+  // before the cast -- confirm verifyVecTileCommon guarantees a non-null
+  // element type for idx.
+  auto idxElem = getElemTy(idxTy);
+  auto idxInt = dyn_cast(idxElem);
+  if (!idxInt || idxInt.getWidth() != 32)
+    return emitOpError() << "expects idx element type to be i32/u32";
+  return mlir::success();
+}
+
+
+// Verifier for TSqrtOp. Skipped when shouldBypassDecodedMemrefVerifier() says
+// so. Otherwise requires a valid unary vec-tile pair (bf16 and int8
+// disallowed), matching src/dst valid shapes, and an element type accepted by
+// the isa<> checks below.
+mlir::LogicalResult mlir::pto::TSqrtOp::verify() {
+  if (shouldBypassDecodedMemrefVerifier(getOperation()))
+    return success();
+  Type srcTy = getSrc().getType();
+  Type dstTy = getDst().getType();
+  if (failed(verifyVecTileUnaryOp(*this, srcTy, dstTy, "src", "dst",
+                                  /*allowBf16=*/false, /*allowInt8=*/false)))
+    return failure();
+  if (failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst")))
+    return failure();
+
+  auto srcElem = getElemTy(srcTy);
+  if (!(mlir::isa(srcElem) || mlir::isa(srcElem)))
+    return emitOpError() << "expects src and dst element type to be float or half";
+
+  return mlir::success();
+}
+
+
+
+// Verifier for TStoreFPOp. Built from local helpers: a bypass predicate, a dst
+// type check, and one verifier per target architecture.
+mlir::LogicalResult mlir::pto::TStoreFPOp::verify() {
+  auto shouldBypassDecoded =
[&]() -> bool {
+    // True when src or fp matches one of the type / defining-op tests below;
+    // in that case all verification is skipped (see the end of this function).
+    Value src = getSrc();
+    Value fp = getFp();
+    return isa(src.getType()) || isa(fp.getType()) ||
+           src.getDefiningOp() ||
+           fp.getDefiningOp();
+  };
+
+  // dst must be one of the accepted types; for the partition-view case every
+  // static dimension must be strictly positive (dynamic dimensions are skipped).
+  auto verifyDstType = [&]() -> LogicalResult {
+    Type dstTy = getDst().getType();
+    if (!isa(dstTy))
+      return emitOpError()
+             << "expects dst to be a memref or !pto.partition_tensor_view";
+    if (auto dstPart = dyn_cast(dstTy)) {
+      for (auto [idx, dim] : llvm::enumerate(dstPart.getShape())) {
+        if (dim != ShapedType::kDynamic && dim <= 0)
+          return emitOpError()
+                 << "expects dst shape[" << idx << "] to be positive";
+      }
+    }
+    return success();
+  };
+
+  // A2/A3: src and fp are tile buffers, src lives in the ACC address space, has
+  // an f32 or 32-bit integer element type, is rank 2, and both its static
+  // column count and static valid column count lie in [1, 4095].
+  auto verifyA2A3 = [&]() -> LogicalResult {
+    Type srcTy = getSrc().getType();
+    Type fpTy = getFp().getType();
+    if (!isa(srcTy))
+      return emitOpError() << "expects src to be a !pto.tile_buf";
+    if (!isa(fpTy))
+      return emitOpError() << "expects fp to be a !pto.tile_buf";
+    if (failed(verifyTileBufCommon(*this, srcTy, "src")) ||
+        failed(verifyTileBufCommon(*this, fpTy, "fp")))
+      return failure();
+    if (failed(verifyDstType()))
+      return failure();
+    auto srcSpace = getPTOMemorySpaceEnum(srcTy);
+    if (!srcSpace || *srcSpace != pto::AddressSpace::ACC)
+      return emitOpError() << "expects src to be in the acc address space";
+    auto srcElemTy = getElemTy(srcTy);
+    auto srcIntTy = dyn_cast(srcElemTy);
+    if (!(srcElemTy.isF32() ||
+          (srcIntTy && srcIntTy.getWidth() == 32)))
+      return emitOpError()
+             << "expects src to have element type f32, i32";
+    auto srcShape = getShapeVec(srcTy);
+    if (srcShape.size() != 2)
+      return emitOpError() << "expects src to have rank 2";
+    if (srcShape[1] != ShapedType::kDynamic &&
+        (srcShape[1] < 1 || srcShape[1] > 4095))
+      return emitOpError() << "expects src.cols to be in the range [1, 4095]";
+    auto srcValid = getValidShapeVec(srcTy);
+    if (srcValid.size() != 2)
+      return emitOpError() << "expects src to have a rank-2 valid_shape";
+    if (srcValid[1] != ShapedType::kDynamic &&
+        (srcValid[1] < 1 || srcValid[1] > 4095))
+      return emitOpError()
+             << "expects src.valid_shape[1] to be in the range [1, 4095]";
+    return mlir::success();
+  };
+  // A5: only the tile-buffer, dst-type and ACC address-space checks are applied;
+  // the element-type and column-range checks of the A2/A3 path are not.
+  auto verifyA5 = [&]() -> LogicalResult {
+    Type srcTy = getSrc().getType();
+    Type fpTy = getFp().getType();
+    if (!isa(srcTy))
+      return emitOpError() << "expects src to be a !pto.tile_buf";
+    if (!isa(fpTy))
+      return emitOpError() << "expects fp to be a !pto.tile_buf";
+    if (failed(verifyTileBufCommon(*this, srcTy, "src")) ||
+        failed(verifyTileBufCommon(*this, fpTy, "fp")))
+      return failure();
+    if (failed(verifyDstType()))
+      return failure();
+    auto srcSpace = getPTOMemorySpaceEnum(srcTy);
+    if (!srcSpace || *srcSpace != pto::AddressSpace::ACC)
+      return emitOpError() << "expects src to be in the acc address space";
+    return mlir::success();
+  };
+  if (shouldBypassDecoded())
+    return success();
+  // Dispatches by hand on the target architecture instead of going through
+  // dispatchVerifierByArch as the neighbouring verifiers do.
+  switch (getVerifierTargetArch(getOperation())) {
+  case VerifierTargetArch::A2A3:
+    return verifyA2A3();
+  case VerifierTargetArch::A5:
+    return verifyA5();
+  }
+  // Reached only if the switch above matches neither case.
+  return failure();
+}
+
+
+// Verifier for TSubOp. Fully delegated to
+// verifyArithmeticBinaryTileOpWithArchDispatch; int8 is allowed on A5 and bf16
+// is not.
+mlir::LogicalResult mlir::pto::TSubOp::verify() {
+  return verifyArithmeticBinaryTileOpWithArchDispatch(
+      getOperation(), getSrc0().getType(), getSrc1().getType(), getDst().getType(),
+      /*allowInt8OnA5=*/true, /*allowBf16OnA5=*/false,
+      "expects A2/A3 tsub element type to be i32/i16/f16/f32",
+      "expects A5 tsub element type to be i32/i16/i8/f16/f32");
+}
+
+
+// Verifier for TSubCOp. Skipped when shouldBypassDecodedMemrefVerifier() says
+// so. Otherwise only checks that src0/src1/src2/dst are PTO shaped-like types
+// of the same rank; element types are not compared here.
+mlir::LogicalResult mlir::pto::TSubCOp::verify() {
+  if (shouldBypassDecodedMemrefVerifier(getOperation()))
+    return success();
+  Type src0Ty = getSrc0().getType();
+  Type src1Ty = getSrc1().getType();
+  Type src2Ty = getSrc2().getType();
+  Type dstTy = getDst().getType();
+  if (!isPTOShapedLike(src0Ty) || !isPTOShapedLike(src1Ty) || !isPTOShapedLike(src2Ty) || !isPTOShapedLike(dstTy))
+    return emitOpError() << "expects PTO shaped-like src0, src1, src2, and dst";
+
+  auto d = getShapeVec(dstTy);
+  if (getShapeVec(src0Ty).size() != d.size() || getShapeVec(src1Ty).size() != d.size() || getShapeVec(src2Ty).size() !=
d.size()) + return emitOpError() << "expects all tensors to have the same rank"; + return mlir::success(); +} + + +mlir::LogicalResult mlir::pto::TSubSOp::verify() { + return verifyArithmeticScalarTileOpWithArchDispatch( + getOperation(), getSrc().getType(), getDst().getType(), getScalar().getType(), + /*allowInt8OnA5=*/true, /*allowBf16OnA5=*/true, + "expects A2/A3 tsubs element type to be i32/i16/f16/f32", + "expects A5 tsubs element type to be i32/i16/i8/f16/bf16/f32", + /*requireValidRowsEqualOnA2A3=*/true, + /*requireValidRowsEqualOnA5=*/false); +} + + +mlir::LogicalResult mlir::pto::TSubSCOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + Type src0Ty = getSrc0().getType(); + Type src1Ty = getSrc1().getType(); + Type dstTy = getDst().getType(); + if (!isPTOShapedLike(src0Ty) || !isPTOShapedLike(src1Ty) || !isPTOShapedLike(dstTy)) + return emitOpError() << "expects PTO shaped-like src0, src1, and dst"; + + auto d = getShapeVec(dstTy); + if (getShapeVec(src0Ty).size() != d.size() || getShapeVec(src1Ty).size() != d.size()) + return emitOpError() << "expects src0, src1, and dst to have the same rank"; + return mlir::success(); +} +mlir::LogicalResult mlir::pto::TTransOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type tmpTy = getTmp().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyTileBufCommon(*this, srcTy, "src")) || + failed(verifyTileBufCommon(*this, tmpTy, "tmp")) || + failed(verifyTileBufCommon(*this, dstTy, "dst"))) + return failure(); + Type srcElem = getElemTy(srcTy); + Type tmpElem = getElemTy(tmpTy); + Type dstElem = getElemTy(dstTy); + if (!srcElem || !tmpElem || !dstElem || srcElem != dstElem || srcElem != tmpElem) + return emitOpError() << "expects src and dst to have the same element type"; + if (auto srcTb = dyn_cast(srcTy)) { + if (srcTb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor)) + return emitOpError() << "expects A2/A3 
transpose src to use the row_major blayout"; + } + unsigned elemBytes = getPTOStorageElemByteSize(srcElem); + if (elemBytes == 0) + return emitOpError() << "failed to get transpose element size"; + if (elemBytes != 1 && elemBytes != 2 && elemBytes != 4) + return emitOpError() << "expects transpose element size to be 1, 2, or 4 bytes"; + auto isAllowedWidthType = [&](Type ty) { + if (elemBytes == 4) + return ty.isInteger(32) || ty.isF32(); + if (elemBytes == 2) + return ty.isInteger(16) || ty.isF16() || ty.isBF16(); + return ty.isInteger(8); + }; + if (!isAllowedWidthType(srcElem)) + return emitOpError() << "expects transpose element type to match the supported set for its width"; + return mlir::success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type tmpTy = getTmp().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyTileBufCommon(*this, srcTy, "src")) || + failed(verifyTileBufCommon(*this, tmpTy, "tmp")) || + failed(verifyTileBufCommon(*this, dstTy, "dst"))) + return failure(); + Type srcElem = getElemTy(srcTy); + Type tmpElem = getElemTy(tmpTy); + Type dstElem = getElemTy(dstTy); + if (!srcElem || !tmpElem || !dstElem || srcElem != dstElem || srcElem != tmpElem) + return emitOpError() << "expects src, tmp, and dst to have the same element type"; + unsigned elemBytes = getPTOStorageElemByteSize(srcElem); + if (elemBytes == 0) + return emitOpError() << "failed to get transpose element size"; + if (elemBytes != 1 && elemBytes != 2 && elemBytes != 4) + return emitOpError() << "expects transpose element size to be 1, 2, or 4 bytes"; + auto isAllowedWidthType = [&](Type ty) { + if (elemBytes == 4) + return ty.isInteger(32) || ty.isF32(); + if (elemBytes == 2) + return ty.isInteger(16) || ty.isF16() || ty.isBF16(); + return ty.isInteger(8); + }; + if (!isAllowedWidthType(srcElem)) + return emitOpError() << "expects transpose element type to match the supported set for its width"; + auto checkAlignedMajor = 
[&](Type ty, StringRef name) -> LogicalResult { + auto tb = mlir::dyn_cast(ty); + if (!tb) + return success(); + auto shape = getShapeVec(ty); + if (shape.size() != 2) + return success(); + bool rowMajor = tb.getBLayoutValueI32() == static_cast(pto::BLayout::RowMajor); + int64_t major = rowMajor ? shape[1] : shape[0]; + if (major != ShapedType::kDynamic && (major * static_cast(elemBytes)) % 32 != 0) + return emitOpError() << "expects " << name << " major dimension times element size to be 32-byte aligned on A5"; + return success(); + }; + if (failed(checkAlignedMajor(srcTy, "src")) || failed(checkAlignedMajor(dstTy, "dst"))) + return failure(); + return mlir::success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TXorOp::verify() { + auto verifyBase = [&]() -> FailureOr { + return verifyMatchingRowMajorBinaryTileOpCommon( + getOperation(), getSrc0().getType(), getSrc1().getType(), + getDst().getType()); + }; + + auto verifyA2A3 = [&]() -> LogicalResult { + FailureOr elemOr = verifyBase(); + if (failed(elemOr)) + return failure(); + Type tmpTy = getTmp().getType(); + if (failed(verifyTileBufCommon(*this, tmpTy, "tmp"))) + return failure(); + Type elem = *elemOr; + if (getElemTy(tmpTy) != elem) + return emitOpError("expects tmp to have the same element type as src0, src1, and dst"); + if (!isRowMajorTileBuf(tmpTy)) + return emitOpError("expects tmp to use row-major layout"); + if (failed(verifyTileBufSameValidShape(*this, tmpTy, getDst().getType(), "tmp", "dst"))) + return failure(); + auto it = mlir::dyn_cast(elem); + if (!it || (it.getWidth() != 8 && it.getWidth() != 16)) + return emitOpError( + "expects A2/A3 txor src0, src1, tmp, and dst element type to be i8/i16"); + return success(); + }; + + auto verifyA5 = [&]() -> LogicalResult { + FailureOr elemOr = verifyBase(); + if (failed(elemOr)) + return failure(); + auto it = mlir::dyn_cast(*elemOr); + if (!it || (it.getWidth() != 8 && it.getWidth() 
!= 16 &&
+                it.getWidth() != 32))
+      return emitOpError(
+          "expects A5 txor src0, src1, and dst element type to be i8/i16/i32");
+    return success();
+  };
+
+  return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5);
+}
+
+
+// Verifier for TXorSOp. verifyCommon runs
+// verifyDistinctRowMajorUnaryTileOpCommon on src/dst and yields the element
+// type; A2/A3 then accepts 8/16-bit integers and A5 accepts 8/16/32-bit
+// integers.
+mlir::LogicalResult mlir::pto::TXorSOp::verify() {
+  auto verifyCommon = [&]() -> FailureOr {
+    return verifyDistinctRowMajorUnaryTileOpCommon(getOperation(), getSrc(),
+                                                   getDst(), "src", "dst");
+  };
+
+  auto verifyA2A3 = [&]() -> LogicalResult {
+    FailureOr elemOr = verifyCommon();
+    if (failed(elemOr))
+      return failure();
+    auto it = mlir::dyn_cast(*elemOr);
+    if (!it || (it.getWidth() != 8 && it.getWidth() != 16))
+      return emitOpError(
+          "expects A2/A3 txors src and dst element type to be i8/i16");
+    return success();
+  };
+
+  auto verifyA5 = [&]() -> LogicalResult {
+    FailureOr elemOr = verifyCommon();
+    if (failed(elemOr))
+      return failure();
+    auto it = mlir::dyn_cast(*elemOr);
+    if (!it || (it.getWidth() != 8 && it.getWidth() != 16 &&
+                it.getWidth() != 32))
+      return emitOpError(
+          "expects A5 txors src and dst element type to be i8/i16/i32");
+    return success();
+  };
+
+  return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5);
+}
+// Verifier for TPrintOp. Skipped when shouldBypassDecodedMemrefVerifier() says
+// so. A tile-buffer src must have an f16/f32/i8/i16/i32 element type and live
+// in the VEC address space; the two other accepted src types pass without
+// further checks; anything else is rejected.
+mlir::LogicalResult mlir::pto::TPrintOp::verify() {
+  if (shouldBypassDecodedMemrefVerifier(getOperation()))
+    return success();
+  auto srcType = getSrc().getType();
+  if (auto tb = mlir::dyn_cast(srcType)) {
+    auto elem = tb.getElementType();
+    if (!(elem.isF16() || elem.isF32() ||
+          elem.isInteger(8) || elem.isInteger(16) || elem.isInteger(32)))
+      return emitOpError() << "expects printable tile element type";
+    auto space = getPTOMemorySpaceEnum(srcType);
+    if (!space || *space != pto::AddressSpace::VEC)
+      return emitOpError() << "expects printable tile_buf to be in vec address space";
+    return success();
+  }
+  if (mlir::dyn_cast(srcType) ||
+      mlir::dyn_cast(srcType))
+    return mlir::success();
+  return emitOpError() << "expects tile_buf, memref, or partition_tensor_view for src";
+}
+
+
+
+// Shared matmul operand check (marked [[maybe_unused]]; no caller is visible in
+// this part of the file). Two operand families are handled:
+//   case A -- lhs/rhs are ranked shaped types: only element types are compared
+//             (lhs vs rhs, optional bias, optional dst and result element
+//             types);
+//   case B -- lhs/rhs are tiles: element types are compared as in case A and, in
+//             addition, both tiles must be 2-D with lhs dim 1 == rhs dim 0.
+// `biasOpt` may be a null Value and `maybeDstElemTy` / `maybeResultElemTy` may be
+// null Types; null entries are skipped.
+[[maybe_unused]] static LogicalResult verifyMatmulCommon(Operation *op, Value lhs, Value rhs,
+                                        Value biasOpt, Type maybeDstElemTy,
+                                        Type maybeResultElemTy) {
+  // ---- case A: tensor/memref (ShapedType) ----
+  if (auto lhsTy = dyn_cast(lhs.getType())) {
+    auto rhsTy = dyn_cast(rhs.getType());
+    if (!rhsTy || !lhsTy.hasRank() || !rhsTy.hasRank())
+      return op->emitOpError("expects lhs and rhs to be ranked tensors or memrefs");
+
+    if (lhsTy.getElementType() != rhsTy.getElementType())
+      return op->emitOpError()
+             << "expects lhs and rhs to have the same element type, but got lhs="
+             << lhsTy.getElementType() << " rhs=" << rhsTy.getElementType();
+
+    if (biasOpt) {
+      auto biasTy = dyn_cast(biasOpt.getType());
+      if (!biasTy || !biasTy.hasRank())
+        return op->emitOpError("expects bias to be a ranked tensor or memref");
+      if (biasTy.getElementType() != lhsTy.getElementType())
+        return op->emitOpError()
+               << "expects bias to have the same element type as lhs and rhs, but got bias="
+               << biasTy.getElementType() << " vs " << lhsTy.getElementType();
+    }
+
+    if (maybeDstElemTy && maybeDstElemTy != lhsTy.getElementType())
+      return op->emitOpError()
+             << "expects dst to have the same element type as lhs and rhs, but got dst="
+             << maybeDstElemTy << " vs " << lhsTy.getElementType();
+
+    if (maybeResultElemTy && maybeResultElemTy != lhsTy.getElementType())
+      return op->emitOpError()
+             << "expects result to have the same element type as lhs and rhs, but got result="
+             << maybeResultElemTy << " vs " << lhsTy.getElementType();
+
+    return success();
+  }
+
+  // ---- case B: tile ----
+  auto lhsTile = dyn_cast(lhs.getType());
+  auto rhsTile = dyn_cast(rhs.getType());
+  if (!lhsTile || !rhsTile)
+    return op->emitOpError("expects lhs and rhs to be ranked tensors, memrefs, or !pto.tile");
+
+  if (lhsTile.getElementType() != rhsTile.getElementType())
+    return op->emitOpError() << "expects lhs and rhs tiles to have the same element type, but got lhs="
+                             << lhsTile.getElementType() << " rhs=" << rhsTile.getElementType();
+
+  if ((int64_t)lhsTile.getShape().size() != 2 || (int64_t)rhsTile.getShape().size() != 2)
+    return op->emitOpError("expects lhs and rhs tiles to be 2D");
+
+  // NOTE(review): the contraction dimensions are compared directly, with no
+  // special case for dynamic extents -- confirm tile shapes are always static
+  // here.
+  if (lhsTile.getShape()[1] != rhsTile.getShape()[0])
+    return op->emitOpError() << "expects lhs dim1 to equal rhs dim0, but got "
+                             << lhsTile.getShape()[1] << " vs " << rhsTile.getShape()[0];
+
+  if (biasOpt) {
+    auto biasTile = dyn_cast(biasOpt.getType());
+    if (!biasTile)
+      return op->emitOpError("expects bias to be !pto.tile when lhs and rhs are !pto.tile");
+    if (biasTile.getElementType() != lhsTile.getElementType())
+      return op->emitOpError("expects bias to have the same element type as lhs and rhs");
+  }
+
+  if (maybeDstElemTy && maybeDstElemTy != lhsTile.getElementType())
+    return op->emitOpError() << "expects dst to have the same element type as lhs and rhs";
+
+  if (maybeResultElemTy && maybeResultElemTy != lhsTile.getElementType())
+    return op->emitOpError() << "expects result to have the same element type as lhs and rhs";
+
+  return success();
+}
+
+// Verifier for TMatmulOp. Runs three checks in order -- verifyMatTileOperands,
+// verifyMatmulTypeTriple on the lhs/rhs/dst element types, and
+// verifyMatmulLike -- and fails on the first one that fails. The A5 path
+// reuses the A2/A3 checks unchanged.
+LogicalResult mlir::pto::TMatmulOp::verify() {
+  auto verifyA2A3 = [&]() -> LogicalResult {
+    if (failed(verifyMatTileOperands(*this, getLhs().getType(), getRhs().getType(),
+                                     getDst().getType())))
+      return failure();
+    if (failed(verifyMatmulTypeTriple(*this, getElemTy(getLhs().getType()),
+                                      getElemTy(getRhs().getType()),
+                                      getElemTy(getDst().getType()))))
+      return failure();
+    return verifyMatmulLike(*this, getLhs().getType(), getRhs().getType(),
+                            getDst().getType());
+  };
+  auto verifyA5 = [&]() -> LogicalResult { return verifyA2A3(); };
+  return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5);
+}
+
+// Verifier for TGemvOp. Same structure as TMatmulOp::verify, with
+// verifyGemvTileOperands in place of verifyMatTileOperands.
+LogicalResult mlir::pto::TGemvOp::verify() {
+  auto verifyA2A3 = [&]() -> LogicalResult {
+    if (failed(verifyGemvTileOperands(*this, getLhs().getType(), getRhs().getType(),
+                                      getDst().getType())))
+      return failure();
+    if (failed(verifyMatmulTypeTriple(*this, getElemTy(getLhs().getType()),
+                                      getElemTy(getRhs().getType()),
+                                      getElemTy(getDst().getType()))))
+      return failure();
+    return verifyMatmulLike(*this, getLhs().getType(), getRhs().getType(),
+                            getDst().getType());
+  };
+  auto verifyA5 = [&]() -> LogicalResult { return verifyA2A3(); };
+  return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5);
+}
+
+// Verifier for TMatmulAccOp. Skipped when shouldBypassDecodedMemrefVerifier()
+// says so. Otherwise checks the acc_in tile and the lhs/rhs/dst tile operands;
+// unlike TMatmulOp, no element-type triple or matmul-like check is run here.
+LogicalResult mlir::pto::TMatmulAccOp::verify() {
+  if (shouldBypassDecodedMemrefVerifier(getOperation()))
+    return success();
+  if (failed(verifyAccTileCommon(*this, getAccIn().getType(), "acc_in")) ||
+      failed(verifyMatTileOperands(*this, getLhs().getType(), getRhs().getType(),
+                                   getDst().getType())))
+    return failure();
+  return success();
+}
+
+// Verifier for TGemvAccOp. Same structure as TMatmulAccOp::verify, with
+// verifyGemvTileOperands in place of verifyMatTileOperands.
+LogicalResult mlir::pto::TGemvAccOp::verify() {
+  if (shouldBypassDecodedMemrefVerifier(getOperation()))
+    return success();
+  if (failed(verifyAccTileCommon(*this, getAccIn().getType(), "acc_in")) ||
+      failed(verifyGemvTileOperands(*this, getLhs().getType(), getRhs().getType(),
+                                    getDst().getType())))
+    return failure();
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// inferReturnTypes() for matmul ops (keep your existing code)
+//===----------------------------------------------------------------------===
+// Infers the result tile type of a tile matmul from its operands
+// (operands[0] = lhs, operands[1] = rhs, optional operands[2] = bias). Returns
+// a null Type when fewer than two operands are given or lhs/rhs are not tiles.
+// When a bias tile is present its shape is used for the result; otherwise the
+// shape is {lhs dim 0, rhs dim 1}. The element type always comes from lhs.
+[[maybe_unused]] static mlir::Type inferMatmulTileResult2DFromAB(MLIRContext *context, ValueRange operands) {
+  if (operands.size() < 2)
+    return mlir::Type();
+
+  auto lhsTile = dyn_cast(operands[0].getType());
+  auto rhsTile = dyn_cast(operands[1].getType());
+  if (!lhsTile || !rhsTile)
+    return mlir::Type();
+
+  Type elemTy = lhsTile.getElementType();
+
+  if (operands.size() >= 3) {
+    if (auto biasTile = dyn_cast(operands[2].getType())) {
+      return mlir::pto::TileType::get(context, biasTile.getShape(), elemTy);
+    }
+  }
+
+  auto lhsShape = lhsTile.getShape();
+  auto rhsShape = rhsTile.getShape();
+  if (lhsShape.size() >= 2 && rhsShape.size() >= 2) {
+    int64_t M = lhsShape[0];
+    int64_t N = rhsShape[1];
+    llvm::SmallVector outShape
= {M, N};
+    return mlir::pto::TileType::get(context, outShape, elemTy);
+  }
+
+  return mlir::Type();
+}
+
+// Shaped-type counterpart of the tile inference above: infers a
+// RankedTensorType result from operands[0] (lhs), operands[1] (rhs) and an
+// optional operands[2] (bias). Returns a null RankedTensorType when the
+// operands are missing or unranked. A bias shape takes precedence (the second
+// bias branch additionally requires a static shape); otherwise the shape is
+// {lhs dim 0, rhs dim 1}. The element type always comes from lhs.
+[[maybe_unused]] static RankedTensorType inferMatmulResult2DFromAB(ValueRange operands) {
+  if (operands.size() < 2)
+    return RankedTensorType();
+
+  auto lhsTy = dyn_cast(operands[0].getType());
+  auto rhsTy = dyn_cast(operands[1].getType());
+  if (!lhsTy || !rhsTy || !lhsTy.hasRank() || !rhsTy.hasRank())
+    return RankedTensorType();
+
+  Type elemTy = lhsTy.getElementType();
+
+  if (operands.size() >= 3) {
+    if (auto biasRT = dyn_cast(operands[2].getType()))
+      return RankedTensorType::get(biasRT.getShape(), elemTy);
+    if (auto biasMR = dyn_cast(operands[2].getType())) {
+      if (biasMR.hasStaticShape())
+        return RankedTensorType::get(biasMR.getShape(), elemTy);
+    }
+  }
+
+  if (lhsTy.getRank() >= 2 && rhsTy.getRank() >= 2) {
+    int64_t M = lhsTy.getDimSize(0);
+    int64_t N = rhsTy.getDimSize(1);
+    return RankedTensorType::get({M, N}, elemTy);
+  }
+
+  return RankedTensorType();
+}
+
+// Returns the type of operands[0] when the cast below succeeds, and a null
+// RankedTensorType otherwise (including when there are no operands).
+[[maybe_unused]] static RankedTensorType inferAccReturnFromAccIn(ValueRange operands) {
+  if (operands.empty())
+    return RankedTensorType();
+  if (auto accRT = dyn_cast(operands[0].getType()))
+    return accRT;
+  return RankedTensorType();
+}
+
+namespace mlir {
+namespace pto {
+
+// Parses `<` dimension-list element-type `>` into `shape` and `elementType`.
+// `allowDynamic` is forwarded to parseDimensionList. Returns failure as soon as
+// any of the four parsing steps fails.
+static LogicalResult parseShapeAndElem(AsmParser &parser,
+                                       SmallVectorImpl &shape,
+                                       Type &elementType,
+                                       bool allowDynamic) {
+  if (parser.parseLess())
+    return failure();
+
+  if (parser.parseDimensionList(shape, allowDynamic))
+    return failure();
+
+  if (parser.parseType(elementType))
+    return failure();
+
+  if (parser.parseGreater())
+    return failure();
+
+  return success();
+}
+
+// Prints the inverse of parseShapeAndElem: `<`, then every dimension followed
+// by `x` (dynamic dimensions print as `?`), then the element type and `>`.
+static void printShapeAndElem(AsmPrinter &printer,
+                              ArrayRef shape,
+                              Type elementType) {
+  printer << "<";
+  for (auto d : shape) {
+    if (d == ShapedType::kDynamic)
+      printer << "?";
+    else
+      printer << d;
+    printer << "x";
+  }
+  printer.printType(elementType);
+  printer << ">";
+}
+
+// =============================================================================
+// PartitionTensorViewType Implementation
+// =============================================================================
+
+// Parses the shape/element body (dynamic dimensions allowed); returns a null
+// Type on parse failure.
+Type PartitionTensorViewType::parse(AsmParser &parser) {
+  SmallVector shape;
+  Type elemTy;
+  if (failed(parseShapeAndElem(parser, shape, elemTy, /*allowDynamic=*/true)))
+    return Type();
+
+  return PartitionTensorViewType::get(parser.getContext(), shape, elemTy);
+}
+
+void PartitionTensorViewType::print(AsmPrinter &printer) const {
+  printShapeAndElem(printer, getShape(), getElementType());
+}
+
+// ---- TileType ----
+// Same textual body as PartitionTensorViewType: parsed and printed through the
+// shared shape/element helpers, with dynamic dimensions allowed.
+Type TileType::parse(AsmParser &parser) {
+  SmallVector shape;
+  Type elemTy;
+  if (failed(parseShapeAndElem(parser, shape, elemTy, /*allowDynamic=*/true)))
+    return Type();
+  return TileType::get(parser.getContext(), shape, elemTy);
+}
+
+void TileType::print(AsmPrinter &printer) const {
+  printShapeAndElem(printer, getShape(), getElementType());
+}
+
+// ---- LocalArrayType ----
+// Asm form: !pto.local_array
+// Static shape only (no '?'). Element type must be a scalar; this is enforced
+// by the type verifier below.
+Type LocalArrayType::parse(AsmParser &parser) { + SmallVector shape; + Type elemTy; + if (failed(parseShapeAndElem(parser, shape, elemTy, /*allowDynamic=*/false))) + return Type(); + return LocalArrayType::getChecked( + [&]() { return parser.emitError(parser.getNameLoc()); }, + parser.getContext(), shape, elemTy); +} + +void LocalArrayType::print(AsmPrinter &printer) const { + printShapeAndElem(printer, getShape(), getElementType()); +} + +LogicalResult LocalArrayType::verify( + llvm::function_ref emitError, + llvm::ArrayRef shape, Type elementType) { + if (shape.empty()) + return emitError() << "'!pto.local_array' requires at least one dimension"; + for (auto [i, d] : llvm::enumerate(shape)) { + if (d <= 0) + return emitError() + << "'!pto.local_array' dimension " << i + << " must be a positive static size, got " << d; + } + if (!elementType.isIntOrFloat()) + return emitError() + << "'!pto.local_array' element type must be a scalar integer or " + "float, got " + << elementType; + return success(); +} + +// ============================================================================= +// Decompose Helper (Reverse Engineering AffineMap -> Strides) +// ============================================================================= + +// Helper: 递归地将 Add 表达式拆解为单独的项列表 +static void flattenAddExpr(AffineExpr expr, SmallVectorImpl &terms) { + if (auto add = llvm::dyn_cast(expr)) { + if (add.getKind() == AffineExprKind::Add) { + flattenAddExpr(add.getLHS(), terms); + flattenAddExpr(add.getRHS(), terms); + return; + } + } + terms.push_back(expr); +} + +// Helper: 从 AffineMap 中提取 Strides +static void decomposeStridedLayout(AffineMap map, SmallVectorImpl &strides) { + // 1. 初始化 + strides.assign(map.getNumDims(), 0); + + if (map.getNumResults() != 1) return; + + // 2. 摊平表达式 + SmallVector terms; + flattenAddExpr(map.getResult(0), terms); + + // 3. 
分析每一项 + for (auto term : terms) { + // 情况 A: dN * Const 或 Const * dN + if (auto mul = llvm::dyn_cast(term)) { + if (mul.getKind() == AffineExprKind::Mul) { + AffineExpr lhs = mul.getLHS(); + AffineExpr rhs = mul.getRHS(); + + // 尝试匹配 LHS=Dim, RHS=Const + if (auto dim = llvm::dyn_cast(lhs)) { + if (auto cst = llvm::dyn_cast(rhs)) { + strides[dim.getPosition()] = cst.getValue(); + continue; + } + } + + // 尝试匹配 LHS=Const, RHS=Dim (乘法交换律) + if (auto dim = llvm::dyn_cast(rhs)) { + if (auto cst = llvm::dyn_cast(lhs)) { + strides[dim.getPosition()] = cst.getValue(); + continue; + } + } + } + } + // 情况 B: 单独的 dN (隐含 Stride = 1) + else if (auto dim = llvm::dyn_cast(term)) { + strides[dim.getPosition()] = 1; + } + } +} + +// ============================================================================= +// [Critical] Strict Alignment Protocol Helper +// ============================================================================= +// This function is the SINGLE source of truth for building the AffineMap. +// Both the Parser and the Op Inference MUST use this exact function. +// It ensures that the order of AffineExpr addition is: +// 0 + (d0*str0 + d1*str1...) + (s0*str0 + s1*str1...) +// This guarantees bitwise-identical AffineMaps for verification. +static AffineMap buildStrictBitwiseAffineMap(MLIRContext *ctx, + ArrayRef strides, + bool isMultiDimSymbol) { + unsigned rank = strides.size(); + + // Step 1: Initialize with Constant(0) + AffineExpr totalExpr = getAffineConstantExpr(0, ctx); + + // Step 2: Add Dimensions (d0*str0 + d1*str1...) + // Strictly in order: 0, 1, 2... + for (unsigned i = 0; i < rank; ++i) { + auto dim = getAffineDimExpr(i, ctx); + auto str = getAffineConstantExpr(strides[i], ctx); + totalExpr = totalExpr + (dim * str); + } + + // Step 3: Add Symbols (s0*str0 + s1*str1...) + // Strictly in order: 0, 1, 2... 
+ if (isMultiDimSymbol) { + for (unsigned i = 0; i < rank; ++i) { + auto sym = getAffineSymbolExpr(i, ctx); + auto str = getAffineConstantExpr(strides[i], ctx); + totalExpr = totalExpr + (sym * str); + } + } + // (Optional: handle single dynamic offset case if needed, omitted for clarity) + + // numSymbols is rank if multi-dim (for offsets), else 0 + unsigned numSymbols = isMultiDimSymbol ? rank : 0; + return AffineMap::get(rank, numSymbols, totalExpr); +} + + +// ============================================================================= +// Parser Implementation +// ============================================================================= + +// Helper for parsing [64, 1] +static ParseResult parseStrideList(AsmParser &parser, SmallVectorImpl &strides) { + if (parser.parseLSquare()) return failure(); + do { + int64_t stride; + if (parser.parseInteger(stride)) return failure(); + strides.push_back(stride); + } while (succeeded(parser.parseOptionalComma())); + if (parser.parseRSquare()) return failure(); + return success(); +} + +// The custom attribute parser for: strided<[64, 1], offset: [?, ?]> +[[maybe_unused]] static ParseResult parseStridedLayout(AsmParser &parser, Attribute &layout) { + if (parser.parseLess()) return failure(); + + // 1. Parse Strides + SmallVector strides; + if (parseStrideList(parser, strides)) return failure(); + + bool isMultiDim = false; + unsigned numSymbols = 0; + + // 2. Parse Offset + if (succeeded(parser.parseOptionalComma())) { + if (parser.parseKeyword("offset") || parser.parseColon()) return failure(); + + // Check for multi-dim syntax: [?, ?] + if (succeeded(parser.parseOptionalLSquare())) { + isMultiDim = true; + do { + if (parser.parseQuestion()) return failure(); + numSymbols++; + } while (succeeded(parser.parseOptionalComma())); + if (parser.parseRSquare()) return failure(); + } else { + // Fallback for old scalar syntax '?' 
+ if (parser.parseOptionalQuestion()) { /* handle single scalar */ } + } + } + + if (parser.parseGreater()) return failure(); + + // 3. Validation + if (isMultiDim && numSymbols != strides.size()) { + return parser.emitError(parser.getCurrentLocation(), + "Number of offset symbols must match rank"); + } + + // 4. [CALL SHARED BUILDER] + // Delegate to the strict builder + MLIRContext *ctx = parser.getContext(); + AffineMap map = buildStrictBitwiseAffineMap(ctx, strides, isMultiDim); + + layout = AffineMapAttr::get(map); + return success(); +} + +// ============================================================================= +// Printer Implementation +// ============================================================================= + +[[maybe_unused]] static void printLayout(AsmPrinter &printer, Attribute layoutAttr) { + if (!layoutAttr) return; + auto mapAttr = llvm::dyn_cast(layoutAttr); + if (!mapAttr) { printer << ", " << layoutAttr; return; } + + AffineMap map = mapAttr.getValue(); + if (map.isIdentity()) return; + + // 1. [核心修改] 反解 Strides + SmallVector strides; + decomposeStridedLayout(map, strides); + + printer << ", strided<["; + // 2. 打印真实的 strides + llvm::interleaveComma(strides, printer); + printer << "]"; + + // Print Offset: [?, ?] 
+ unsigned numSyms = map.getNumSymbols(); + if (numSyms > 0) { + printer << ", offset: ["; + for (unsigned i = 0; i < numSyms; ++i) { + printer << "?"; + if (i < numSyms - 1) printer << ", "; + } + printer << "]"; + } + printer << ">"; +} + +// ---- TileBuf --- + + +// Tile subview 相关实现 + +// ============================================================================= +// Op Interface Implementation: SubViewOp +// ============================================================================= + +ParseResult mlir::pto::SubViewOp::parse(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::UnresolvedOperand source; + SmallVector offsets; + SmallVector valids; + Type sourceTy; + Type resultTy; + bool hasExplicitResultTy = false; + + if (parser.parseOperand(source) || parser.parseLSquare() || + parser.parseOperandList(offsets) || parser.parseRSquare() || + parser.parseKeyword("sizes")) + return failure(); + + ArrayAttr sizesAttr; + if (parser.parseAttribute(sizesAttr, "sizes", result.attributes)) + return failure(); + + if (succeeded(parser.parseOptionalKeyword("valid"))) { + OpAsmParser::UnresolvedOperand vrow, vcol; + if (parser.parseLSquare() || parser.parseOperand(vrow) || parser.parseComma() || + parser.parseOperand(vcol) || parser.parseRSquare()) + return failure(); + valids.push_back(vrow); + valids.push_back(vcol); + } + + if (parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(sourceTy)) + return failure(); + + if (succeeded(parser.parseOptionalArrow())) { + if (parser.parseType(resultTy)) + return failure(); + hasExplicitResultTy = true; + } + + if (parser.resolveOperand(source, sourceTy, result.operands)) + return failure(); + + Type indexTy = parser.getBuilder().getIndexType(); + if (parser.resolveOperands(offsets, indexTy, result.operands)) + return failure(); + if (!valids.empty() && + parser.resolveOperands(valids, indexTy, result.operands)) + return failure(); + + int32_t hasValid = valids.empty() ? 
0 : 1; + result.addAttribute( + "operandSegmentSizes", + parser.getBuilder().getDenseI32ArrayAttr( + {1, static_cast(offsets.size()), hasValid, hasValid})); + + if (hasExplicitResultTy) { + result.addTypes(resultTy); + return success(); + } + + SmallVector inferredReturnTypes; + DictionaryAttr attrs = result.attributes.getDictionary(parser.getContext()); + if (failed(SubViewOp::inferReturnTypes( + parser.getContext(), std::nullopt, result.operands, attrs, nullptr, + RegionRange(), inferredReturnTypes))) { + return parser.emitError(parser.getCurrentLocation(), + "failed to infer pto.subview result type"); + } + result.addTypes(inferredReturnTypes); + return success(); +} + +void mlir::pto::SubViewOp::print(OpAsmPrinter &printer) { + printer << " " << getSource() << "["; + printer.printOperands(getOffsets()); + printer << "] sizes " << getSizes(); + if (getValidRow()) { + printer << " valid [" << getValidRow() << ", " << getValidCol() << "]"; + } + printer.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"operandSegmentSizes", + "sizes"}); + printer << " : " << getSource().getType() << " -> " << getResult().getType(); +} + +LogicalResult SubViewOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + + // 1. 获取 Source Type + if (operands.empty()) return failure(); + auto sourceType = llvm::dyn_cast(operands[0].getType()); + if (!sourceType) return failure(); + + // 2. 
获取 subview 逻辑窗口(sizes) + ArrayAttr sizeAttr; + if (properties) { + const auto *prop = properties.as(); + if (prop) sizeAttr = prop->sizes; + } + if (!sizeAttr && attributes) { + sizeAttr = attributes.getAs("sizes"); + } + if (!sizeAttr) return failure(); + + SmallVector subviewShape; + for (auto attr : sizeAttr) { + int64_t dim = llvm::cast(attr).getInt(); + subviewShape.push_back(dim); + } + + // Design: subview 的结果 tile 类型显式表达逻辑子窗口 shape(sizes)。 + ArrayRef parentShape = sourceType.getShape(); + if (subviewShape.size() != parentShape.size()) + return failure(); + + // Derive valid shape from explicit valid_row/valid_col when provided. + // Otherwise default to subview shape (no parent valid-shape inheritance). + SmallVector validShape; + constexpr int64_t kDynamicValidDim = -1; + int64_t rank = static_cast(subviewShape.size()); + Value explicitVRow; + Value explicitVCol; + + // Robustly decode optional valid operands using AttrSizedOperandSegments: + // [source, offsets..., valid_row?, valid_col?] + if (attributes) { + if (auto segAttr = + attributes.getAs("operandSegmentSizes")) { + ArrayRef segs = segAttr.asArrayRef(); + if (segs.size() == 4) { + int32_t srcSeg = segs[0]; + int32_t offSeg = segs[1]; + int32_t vRowSeg = segs[2]; + int32_t vColSeg = segs[3]; + if (srcSeg == 1 && offSeg >= 0 && (vRowSeg == 0 || vRowSeg == 1) && + (vColSeg == 0 || vColSeg == 1)) { + size_t idx = static_cast(srcSeg + offSeg); + if (vRowSeg == 1 && idx < operands.size()) + explicitVRow = operands[idx++]; + if (vColSeg == 1 && idx < operands.size()) + explicitVCol = operands[idx]; + } + } + } + } + + // Fallback for legacy callers that may not provide operandSegmentSizes. 
+ if (!explicitVRow && !explicitVCol && rank == 2) { + size_t expectedWithoutValid = static_cast(1 + rank); + if (operands.size() >= expectedWithoutValid + 2) { + explicitVRow = operands[expectedWithoutValid]; + explicitVCol = operands[expectedWithoutValid + 1]; + } + } + + for (size_t i = 0, e = subviewShape.size(); i < e; ++i) { + int64_t vdim = subviewShape[i]; + Value explicitV = (i == 0) ? explicitVRow : (i == 1 ? explicitVCol : Value()); + if (explicitV) { + auto cst = getConstIndexValue(explicitV); + vdim = cst ? std::min(*cst, subviewShape[i]) : kDynamicValidDim; + } + validShape.push_back(vdim); + } + + // 3. 继承 Config (若为空使用默认) + auto cfg = sourceType.getConfigAttr(); + if (!cfg) cfg = TileBufConfigAttr::getDefault(context); + + // 4. 构建 Result Type + auto canonicalValidShape = canonicalizeTileBufValidShape(validShape); + auto resultType = TileBufType::get( + context, subviewShape, sourceType.getElementType(), + sourceType.getMemorySpace(), canonicalValidShape, cfg); + + inferredReturnTypes.push_back(resultType); + return success(); +} + +// ============================================================================= +// SubViewOp verifier +// ============================================================================= +static bool getConstIndex(Value v, int64_t &out) { + if (auto cOp = v.getDefiningOp()) { + out = cOp.value(); + return true; + } + if (auto cInt = v.getDefiningOp()) { + out = cInt.value(); + return true; + } + if (auto cOp = v.getDefiningOp()) { + if (auto ia = dyn_cast(cOp.getValue())) { + out = ia.getInt(); + return true; + } + } + if (auto castOp = v.getDefiningOp()) + return getConstIndex(castOp.getIn(), out); + if (auto extOp = v.getDefiningOp()) + return getConstIndex(extOp.getIn(), out); + if (auto extOp = v.getDefiningOp()) + return getConstIndex(extOp.getIn(), out); + if (auto truncOp = v.getDefiningOp()) + return getConstIndex(truncOp.getIn(), out); + return false; +} + +static LogicalResult computeInnerShape(TileBufConfigAttr 
cfg, Type elemTy, + int64_t &innerRows, int64_t &innerCols, + bool &boxed, int32_t &bl, int32_t &sl) { + auto readBLayoutI32 = [](Attribute attr, int32_t &out) -> bool { + if (auto a = dyn_cast(attr)) { + out = (int32_t)a.getValue(); + return true; + } + if (auto a = dyn_cast(attr)) { + out = (int32_t)a.getInt(); + return true; + } + return false; + }; + auto readSLayoutI32 = [](Attribute attr, int32_t &out) -> bool { + if (auto a = dyn_cast(attr)) { + out = (int32_t)a.getValue(); + return true; + } + if (auto a = dyn_cast(attr)) { + out = (int32_t)a.getInt(); + return true; + } + return false; + }; + bl = 0; + sl = 0; + int32_t fr = 512; + (void)readBLayoutI32(cfg.getBLayout(), bl); + (void)readSLayoutI32(cfg.getSLayout(), sl); + if (auto attr = dyn_cast(cfg.getSFractalSize())) fr = (int32_t)attr.getInt(); + + boxed = (sl != 0); + if (!boxed) { + innerRows = 1; + innerCols = 1; + return success(); + } + + int64_t elemBytes = static_cast(getElemByteSize(elemTy)); + if (elemBytes <= 0) return failure(); + + if (fr == 1024) { + innerRows = 16; + innerCols = 16; + return success(); + } + if (fr == 32) { + innerRows = 16; + innerCols = 2; + return success(); + } + if (fr == 512) { + if (sl == 1) { + innerRows = 16; + innerCols = 32 / elemBytes; + return success(); + } + if (sl == 2) { + innerRows = 32 / elemBytes; + innerCols = 16; + return success(); + } + } + return failure(); +} + +mlir::LogicalResult mlir::pto::SubViewOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + auto srcTy = llvm::dyn_cast(getSource().getType()); + auto dstTy = llvm::dyn_cast(getResult().getType()); + if (!srcTy || !dstTy) + return emitOpError("expects tile_buf src and tile_buf result"); + if (srcTy.getRank() != 2 || dstTy.getRank() != 2) + return emitOpError("expects rank-2 tilebuf for src/dst"); + + auto sizesAttr = getSizes(); + if (!sizesAttr || sizesAttr.size() != 2) + return emitOpError("subview expects 2D sizes"); + int64_t sizeR = 
cast(sizesAttr[0]).getInt(); + int64_t sizeC = cast(sizesAttr[1]).getInt(); + if (sizeR <= 0 || sizeC <= 0) + return emitOpError("subview sizes must be positive"); + if (getOffsets().size() != 2) + return emitOpError("subview expects 2D offsets"); + + int64_t offR = 0, offC = 0; + bool offRConst = getConstIndex(getOffsets()[0], offR); + bool offCConst = getConstIndex(getOffsets()[1], offC); + if (offRConst && offR < 0) + return emitOpError("subview offsets must be non-negative"); + if (offCConst && offC < 0) + return emitOpError("subview offsets must be non-negative"); + + bool hasValidRow = static_cast(getValidRow()); + bool hasValidCol = static_cast(getValidCol()); + if (hasValidRow != hasValidCol) + return emitOpError( + "subview expects valid_row and valid_col to be both present or both absent"); + + if (hasValidRow) { + int64_t vRow = 0, vCol = 0; + if (getConstIndex(getValidRow(), vRow)) { + if (vRow <= 0) + return emitOpError("valid_row must be positive when constant"); + if (vRow > sizeR) + return emitOpError("valid_row must be <= subview row size"); + } + if (getConstIndex(getValidCol(), vCol)) { + if (vCol <= 0) + return emitOpError("valid_col must be positive when constant"); + if (vCol > sizeC) + return emitOpError("valid_col must be <= subview col size"); + } + } + + auto dstShape = dstTy.getShape(); + if (dstShape.size() != 2) + return emitOpError("expects result to be rank-2"); + auto srcShape = srcTy.getShape(); + if (srcShape.size() != 2) + return emitOpError("expects source to be rank-2"); + if (dstShape[0] != sizeR || dstShape[1] != sizeC) + return emitOpError("expects result shape to match subview sizes"); + + if (dstTy.getElementType() != srcTy.getElementType()) + return emitOpError("expects result element type to match source"); + if (dstTy.getMemorySpace() != srcTy.getMemorySpace()) + return emitOpError("expects result address space to match source"); + auto srcCfg = srcTy.getConfigAttr(); + if (!srcCfg) srcCfg = 
TileBufConfigAttr::getDefault(getContext()); + auto dstCfg = dstTy.getConfigAttr(); + if (!dstCfg) dstCfg = TileBufConfigAttr::getDefault(getContext()); + if (dstCfg != srcCfg) + return emitOpError("expects result tile config to match source"); + + // Design choice: when valid[...] is omitted, infer result valid_shape from + // subview sizes directly. We intentionally do not constrain it by source + // valid_shape to allow user-controlled subview semantics. + + auto expectedValidDim = [&](Value explicitValid, int64_t defaultSize) { + if (!explicitValid) + return defaultSize; + int64_t c = 0; + if (getConstIndex(explicitValid, c)) + return std::min(c, defaultSize); + return ShapedType::kDynamic; + }; + int64_t expectedVRow = expectedValidDim(getValidRow(), sizeR); + int64_t expectedVCol = expectedValidDim(getValidCol(), sizeC); + auto dstValid = dstTy.getValidShape(); + if (dstValid.size() != 2) + return emitOpError("expects result to have rank-2 valid_shape"); + if (dstValid[0] != expectedVRow) + return emitOpError("expects result valid_shape[0] to match inferred/explicit valid_row"); + if (dstValid[1] != expectedVCol) + return emitOpError("expects result valid_shape[1] to match inferred/explicit valid_col"); + + auto cfg = srcTy.getConfigAttr(); + if (!cfg) cfg = TileBufConfigAttr::getDefault(getContext()); + + int64_t innerRows = 1, innerCols = 1; + bool boxed = false; + int32_t bl = 0, sl = 0; + if (failed(computeInnerShape(cfg, srcTy.getElementType(), innerRows, innerCols, + boxed, bl, sl))) + return emitOpError("unsupported tile layout for subview"); + + if (!boxed) + return success(); + + // Boxed layout: require static 2D sizes with inner alignment. Offsets may be + // dynamic, but static offsets must be aligned. 
+ if (sizeR % innerRows != 0 || sizeC % innerCols != 0) + return emitOpError("boxed layout subview sizes must be multiples of inner shape"); + + if (offRConst) { + if (offR % innerRows != 0) + return emitOpError("boxed layout subview offsets must be multiples of inner shape"); + } + if (offCConst) { + if (offC % innerCols != 0) + return emitOpError("boxed layout subview offsets must be multiples of inner shape"); + } + + (void)bl; + if (srcShape.size() != 2 || + srcShape[0] == ShapedType::kDynamic || + srcShape[1] == ShapedType::kDynamic) { + return emitOpError("boxed layout subview requires static source shape"); + } + + return success(); +} + +} // namespace pto +} // namespace mlir + +using namespace mlir; +using namespace mlir::pto; + +// ============================================================================= +// Helper Functions +// ============================================================================= + +[[maybe_unused]] static AddressSpace getAddressSpace(Value val) { + auto type = llvm::dyn_cast(val.getType()); + if (!type) return AddressSpace::Zero; // Default + + // 假设你的 AddressSpaceAttr 存储在 MemRef 的 memorySpace 中 + // 需要根据你的 getPTOAddressSpaceAttr 实现来调整 + auto attr = llvm::dyn_cast_or_null(type.getMemorySpace()); + if (attr) return attr.getAddressSpace(); + return AddressSpace::Zero; +} + +// ============================================================================= +// Side Effects Implementation +// ============================================================================= + +// [Fix] 辅助函数:重载以支持 OpOperand* 和 OpResult,避免直接传 Value + +// 针对操作数 (Operand) 的重载 +static void addEffect( + SmallVectorImpl> &effects, + OpOperand *operand, MemoryEffects::Effect *effect) { + if (operand) + effects.emplace_back(effect, operand, SideEffects::DefaultResource::get()); +} + +// 针对结果 (Result) 的重载 +static void addEffect( + SmallVectorImpl> &effects, + OpResult result, MemoryEffects::Effect *effect) { + if (result) + effects.emplace_back(effect, result, 
SideEffects::DefaultResource::get()); +} + +// === TLoadOp === +// Read: src, Write: dst +// 针对 OpOperand* 的重载 +void TLoadOp::getEffects(SmallVectorImpl> &effects) { + // [Fix] 单个操作数,直接取地址 + addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); +} + +void TPrefetchOp::getEffects( + SmallVectorImpl> &effects) { + addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); +} + +// === TAbsOp === +// Read: src, Write: dst +void TAbsOp::getEffects( + SmallVectorImpl> &effects) { + addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); +} + +// === TStoreOp === +// Read: src, Write: dst (GM) +void TStoreOp::getEffects(SmallVectorImpl> &effects) { + addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); + auto preQuantRange = getPreQuantScalarMutable(); + if (!preQuantRange.empty()) + addEffect(effects, &*preQuantRange.begin(), MemoryEffects::Read::get()); + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); +} + +// === TMovOp === +// Read: src, Write: dst +void TMovOp::getEffects(SmallVectorImpl> &effects) { + addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); + auto fpRange = getFpMutable(); + if (!fpRange.empty()) + addEffect(effects, &*fpRange.begin(), MemoryEffects::Read::get()); + auto preQuantRange = getPreQuantScalarMutable(); + if (!preQuantRange.empty()) + addEffect(effects, &*preQuantRange.begin(), MemoryEffects::Read::get()); + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); +} + +#define PTO_ADD_READ(operand) addEffect(effects, &(operand), MemoryEffects::Read::get()) +#define PTO_ADD_WRITE(operand) addEffect(effects, &(operand), MemoryEffects::Write::get()) + +#define PTO_DEFINE_UNARY_EFFECTS(OpClass, srcOperand, dstOperand) \ + void OpClass::getEffects( \ + 
SmallVectorImpl> &effects) { \ + PTO_ADD_READ(srcOperand); \ + PTO_ADD_WRITE(dstOperand); \ + } + +#define PTO_DEFINE_BINARY_EFFECTS(OpClass, lhsOperand, rhsOperand, dstOperand) \ + void OpClass::getEffects( \ + SmallVectorImpl> &effects) { \ + PTO_ADD_READ(lhsOperand); \ + PTO_ADD_READ(rhsOperand); \ + PTO_ADD_WRITE(dstOperand); \ + } + +#define PTO_DEFINE_TERNARY_EFFECTS(OpClass, op0, op1, op2, dstOperand) \ + void OpClass::getEffects( \ + SmallVectorImpl> &effects) { \ + PTO_ADD_READ(op0); \ + PTO_ADD_READ(op1); \ + PTO_ADD_READ(op2); \ + PTO_ADD_WRITE(dstOperand); \ + } + +#define PTO_DEFINE_QUATERNARY_EFFECTS(OpClass, op0, op1, op2, op3, dstOperand) \ + void OpClass::getEffects( \ + SmallVectorImpl> &effects) { \ + PTO_ADD_READ(op0); \ + PTO_ADD_READ(op1); \ + PTO_ADD_READ(op2); \ + PTO_ADD_READ(op3); \ + PTO_ADD_WRITE(dstOperand); \ + } + +void LoadScalarOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getPtrMutable()); +} + +void StoreScalarOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_WRITE(getPtrMutable()); +} + +// === Tile/Device ops added for InsertSync === + +// MGATHER: Read(mem, idx) -> Write(dst) +void MGatherOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getMemMutable()); + PTO_ADD_READ(getIdxMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +// MSCATTER: Read(src, idx) -> Write(mem) +void MScatterOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_READ(getIdxMutable()); + PTO_ADD_WRITE(getMemMutable()); +} + +// TGETVAL: Read(src) -> scalar result +void TGetValOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); +} + +void THistogramOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_READ(getIdxMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +void TGetScaleAddrOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +// TSETVAL: Write(dst) 
(single element update) +void TSetValOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_WRITE(getDstMutable()); +} + +// SET_VALIDSHAPE: update runtime valid row/col metadata on source tile in-place. +void SetValidShapeOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_WRITE(getSourceMutable()); +} + +// GET_VALIDSHAPE: read runtime valid row/col metadata from source tile. +void GetValidShapeOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSourceMutable()); +} + +// Elementwise + reductions: mostly PIPE_V tilebuf ops +PTO_DEFINE_BINARY_EFFECTS(TAddOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_TERNARY_EFFECTS(TAddCOp, getSrc0Mutable(), getSrc1Mutable(), getSrc2Mutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TAddSOp, getSrcMutable(), getDstMutable()) +PTO_DEFINE_BINARY_EFFECTS(TAddSCOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +void TAxpyOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_READ(getScalarMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +PTO_DEFINE_BINARY_EFFECTS(TAndOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_BINARY_EFFECTS(TConcatOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_QUATERNARY_EFFECTS(TConcatidxOp, getSrc0Mutable(), getSrc1Mutable(), getSrc0IdxMutable(), getSrc1IdxMutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TAndSOp, getSrcMutable(), getDstMutable()) + +// TCI: Write(dst) (generates sequence) +void TCIOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_WRITE(getDstMutable()); +} + +// TTRI: Write(dst) (generates triangular mask) +void TTriOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_WRITE(getDstMutable()); +} + +PTO_DEFINE_BINARY_EFFECTS(TCmpOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TCmpSOp, getSrcMutable(), getDstMutable()) + +PTO_DEFINE_UNARY_EFFECTS(TColExpandOp, getSrcMutable(), getDstMutable()) 
+PTO_DEFINE_BINARY_EFFECTS(TColExpandAddOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_BINARY_EFFECTS(TColExpandMulOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_BINARY_EFFECTS(TColExpandDivOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_BINARY_EFFECTS(TColExpandSubOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_BINARY_EFFECTS(TColExpandExpdifOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_BINARY_EFFECTS(TColExpandMaxOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_BINARY_EFFECTS(TColExpandMinOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TColMaxOp, getSrcMutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TColMinOp, getSrcMutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TColProdOp, getSrcMutable(), getDstMutable()) + +void TColArgMaxOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_WRITE(getTmpMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +void TColArgMinOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_WRITE(getTmpMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +void TColSumOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + auto tmp = getTmpMutable(); + if (!tmp.empty()) { + PTO_ADD_WRITE(tmp[0]); + } + PTO_ADD_WRITE(getDstMutable()); +} + +void TCvtOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_WRITE(getDstMutable()); +} +void TRandomOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_WRITE(getDstMutable()); +} +PTO_DEFINE_BINARY_EFFECTS(TDivOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) + +// TDIVS has custom assembly format; conservatively treat first 2 operands as reads. 
+void TDivSOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_READ(getScalarMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +PTO_DEFINE_UNARY_EFFECTS(TExpOp, getSrcMutable(), getDstMutable()) + +// TEXPANDS: Write(dst) (broadcast scalar) +void TExpandsOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_WRITE(getDstMutable()); +} + +// TEXTRACT: Read(src) -> Write(dst) +void TExtractOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +// TINSERT: Read(src) -> Write(dst) +void TInsertOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +// TEXTRACT_FP: Read(src), Read(fp) -> Write(dst) +void TExtractFPOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_READ(getFpMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +// TINSERT_FP: Read(src), Read(fp) -> Write(dst) +void TInsertFPOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_READ(getFpMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +PTO_DEFINE_UNARY_EFFECTS(TFillPadOp, getSrcMutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TFillPadExpandOp, getSrcMutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TFillPadInplaceOp, getSrcMutable(), getDstMutable()) + +void TGatherOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + if (auto cdst = getCdstMutable(); !cdst.empty()) + PTO_ADD_WRITE(cdst[0]); + if (auto indices = getIndicesMutable(); !indices.empty()) + PTO_ADD_READ(indices[0]); + if (auto tmp = getTmpMutable(); !tmp.empty()) + PTO_ADD_READ(tmp[0]); + PTO_ADD_WRITE(getDstMutable()); +} + +PTO_DEFINE_BINARY_EFFECTS(TGatherBOp, getSrcMutable(), getOffsetsMutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TLogOp, getSrcMutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TLReluOp, getSrcMutable(), getDstMutable()) + 
+PTO_DEFINE_BINARY_EFFECTS(TMaxOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TMaxSOp, getSrcMutable(), getDstMutable()) +PTO_DEFINE_BINARY_EFFECTS(TMinOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TMinSOp, getSrcMutable(), getDstMutable()) + +PTO_DEFINE_BINARY_EFFECTS(TMovFPOp, getSrcMutable(), getFpMutable(), getDstMutable()) + +void TMrgSortOp::getEffects( + SmallVectorImpl> &effects) { + for (auto &opnd : getSrcsMutable()) { + PTO_ADD_READ(opnd); + } + auto tmp = getTmpMutable(); + if (!tmp.empty()) + PTO_ADD_WRITE(tmp[0]); + for (auto &opnd : getDstsMutable()) { + PTO_ADD_WRITE(opnd); + } + auto executed = getExcutedMutable(); + if (!executed.empty()) { + PTO_ADD_WRITE(executed[0]); + } +} + +PTO_DEFINE_BINARY_EFFECTS(TMulOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TMulSOp, getSrc0Mutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TNegOp, getSrcMutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TNotOp, getSrcMutable(), getDstMutable()) +PTO_DEFINE_BINARY_EFFECTS(TOrOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TOrSOp, getSrcMutable(), getDstMutable()) + +PTO_DEFINE_BINARY_EFFECTS(TPartAddOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_BINARY_EFFECTS(TPartMaxOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_BINARY_EFFECTS(TPartMinOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +void TPartArgMaxOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrc0Mutable()); + PTO_ADD_READ(getSrc1Mutable()); + PTO_ADD_READ(getSrc0IdxMutable()); + PTO_ADD_READ(getSrc1IdxMutable()); + PTO_ADD_WRITE(getDstMutable()); + PTO_ADD_WRITE(getDstIdxMutable()); +} +void TPartArgMinOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrc0Mutable()); + PTO_ADD_READ(getSrc1Mutable()); + PTO_ADD_READ(getSrc0IdxMutable()); + PTO_ADD_READ(getSrc1IdxMutable()); + 
PTO_ADD_WRITE(getDstMutable()); + PTO_ADD_WRITE(getDstIdxMutable()); +} +PTO_DEFINE_BINARY_EFFECTS(TPartMulOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +// TPRELU: Read(src0, src1) -> Write(tmp, dst) +void TPReluOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrc0Mutable()); + PTO_ADD_READ(getSrc1Mutable()); + // A5 pto-isa TPRELU implementation does not consume tmp; modeling tmp as a + // write-only scratch on A5 incorrectly inflates local-memory planning and + // can trigger false vec-overflow diagnostics. + if (getTargetArch(getOperation()) != PTOArch::A5) + PTO_ADD_WRITE(getTmpMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +void TQuantOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_READ(getFpMutable()); + auto offsetRange = getOffsetMutable(); + if (!offsetRange.empty()) + PTO_ADD_READ(offsetRange[0]); + PTO_ADD_WRITE(getDstMutable()); +} +PTO_DEFINE_TERNARY_EFFECTS(TDequantOp, getSrcMutable(), getScaleMutable(), + getOffsetMutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TRecipOp, getSrcMutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TReluOp, getSrcMutable(), getDstMutable()) +PTO_DEFINE_BINARY_EFFECTS(TFModOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TFModSOp, getSrcMutable(), getDstMutable()) +void TRemOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrc0Mutable()); + PTO_ADD_READ(getSrc1Mutable()); + PTO_ADD_WRITE(getTmpMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +void TRemSOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_WRITE(getTmpMutable()); + PTO_ADD_WRITE(getDstMutable()); +} +PTO_DEFINE_UNARY_EFFECTS(TRowExpandOp, getSrcMutable(), getDstMutable()) + +void TRowExpandDivOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrc0Mutable()); + PTO_ADD_READ(getSrc1Mutable()); + auto tmp = getTmpMutable(); + if (!tmp.empty()) + PTO_ADD_WRITE(tmp[0]); + 
PTO_ADD_WRITE(getDstMutable()); +} + +void TRowExpandMulOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrc0Mutable()); + PTO_ADD_READ(getSrc1Mutable()); + auto tmp = getTmpMutable(); + if (!tmp.empty()) + PTO_ADD_WRITE(tmp[0]); + PTO_ADD_WRITE(getDstMutable()); +} + +void TRowExpandSubOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrc0Mutable()); + PTO_ADD_READ(getSrc1Mutable()); + auto tmp = getTmpMutable(); + if (!tmp.empty()) + PTO_ADD_WRITE(tmp[0]); + PTO_ADD_WRITE(getDstMutable()); +} + +PTO_DEFINE_BINARY_EFFECTS(TRowExpandAddOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) + +void TRowExpandExpdifOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrc0Mutable()); + PTO_ADD_READ(getSrc1Mutable()); + auto tmp = getTmpMutable(); + if (!tmp.empty()) + PTO_ADD_WRITE(tmp[0]); + PTO_ADD_WRITE(getDstMutable()); +} + +void TRowExpandMaxOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrc0Mutable()); + PTO_ADD_READ(getSrc1Mutable()); + auto tmp = getTmpMutable(); + if (!tmp.empty()) + PTO_ADD_WRITE(tmp[0]); + PTO_ADD_WRITE(getDstMutable()); +} + +void TRowExpandMinOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrc0Mutable()); + PTO_ADD_READ(getSrc1Mutable()); + auto tmp = getTmpMutable(); + if (!tmp.empty()) + PTO_ADD_WRITE(tmp[0]); + PTO_ADD_WRITE(getDstMutable()); +} + +// Row reductions use tmp scratch tile. +void TRowMaxOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_WRITE(getTmpMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +void TRowArgMaxOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + // A5 lowering does not consume tmp for TROWARGMAX; modeling tmp as a + // scratch write inflates local-memory planning and can trigger false + // vec-overflow diagnostics, mirroring the fixed A5 TPRELU issue. 
+ if (getTargetArch(getOperation()) != PTOArch::A5) + PTO_ADD_WRITE(getTmpMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +void TRowMinOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_WRITE(getTmpMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +void TRowArgMinOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + // A5 lowering does not consume tmp for TROWARGMIN; modeling tmp as a + // scratch write inflates local-memory planning and can trigger false + // vec-overflow diagnostics, mirroring the fixed A5 TPRELU issue. + if (getTargetArch(getOperation()) != PTOArch::A5) + PTO_ADD_WRITE(getTmpMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +void TRowSumOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_WRITE(getTmpMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +void TRowProdOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_WRITE(getTmpMutable()); + PTO_ADD_WRITE(getDstMutable()); +} +void TRsqrtOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + auto tmp = getTmpMutable(); + if (!tmp.empty()) + PTO_ADD_WRITE(tmp[0]); + PTO_ADD_WRITE(getDstMutable()); +} + +void TScatterOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + if (getIndexes()) { + auto idx = getIndexesMutable(); + if (!idx.empty()) + PTO_ADD_READ(idx[0]); + } + PTO_ADD_WRITE(getDstMutable()); +} + +// Select: Read(mask, src0, src1) -> Write(tmp, dst) +void TSelOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getMaskMutable()); + PTO_ADD_READ(getSrc0Mutable()); + PTO_ADD_READ(getSrc1Mutable()); + PTO_ADD_WRITE(getTmpMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +// TSELS: Read(src0, src1) -> Write(tmp, dst) +void TSelSOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getMaskMutable()); + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_WRITE(getTmpMutable()); + 
PTO_ADD_WRITE(getDstMutable()); +} + +PTO_DEFINE_BINARY_EFFECTS(TShlOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_BINARY_EFFECTS(TShrOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TShlSOp, getSrcMutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TShrSOp, getSrcMutable(), getDstMutable()) + +// TSORT32: Read(src, idx) -> Write(dst [, tmp]) +void TSort32Op::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_READ(getIdxMutable()); + auto tmp = getTmpMutable(); + if (!tmp.empty()) + PTO_ADD_WRITE(tmp[0]); + PTO_ADD_WRITE(getDstMutable()); +} + +PTO_DEFINE_UNARY_EFFECTS(TSqrtOp, getSrcMutable(), getDstMutable()) +PTO_DEFINE_BINARY_EFFECTS(TSubOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_TERNARY_EFFECTS(TSubCOp, getSrc0Mutable(), getSrc1Mutable(), getSrc2Mutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TSubSOp, getSrcMutable(), getDstMutable()) +PTO_DEFINE_BINARY_EFFECTS(TSubSCOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) + +// TXORS: Read(src) -> Write(tmp, dst) +void TXorSOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_WRITE(getTmpMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +// TXOR: Read(src0, src1) -> Write(tmp?, dst) +void TXorOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrc0Mutable()); + PTO_ADD_READ(getSrc1Mutable()); + PTO_ADD_WRITE(getTmpMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +// TTRANS: Read(src) -> Write(tmp, dst) +void TTransOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_WRITE(getTmpMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +void TPrintOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_WRITE(getSrcMutable()); +} + +#undef PTO_DEFINE_TERNARY_EFFECTS +#undef PTO_DEFINE_BINARY_EFFECTS +#undef PTO_DEFINE_UNARY_EFFECTS +#undef PTO_ADD_WRITE +#undef PTO_ADD_READ 
+ +// === TMatmulOp === +// Read: lhs, rhs, (bias), Write: dst +void TMatmulOp::getEffects(SmallVectorImpl> &effects) { + // Singleton -> 直接取地址 + addEffect(effects, &getLhsMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getRhsMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); +} + +// === TMatmulAccOp === +// Read: acc_in, lhs, rhs, Write: dst +void TMatmulAccOp::getEffects(SmallVectorImpl> &effects) { + addEffect(effects, &getAccInMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getLhsMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getRhsMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); +} + +// === TMatmulBiasOp === +// Read: a, b, bias, Write: dst +void TMatmulBiasOp::getEffects(SmallVectorImpl> &effects) { + addEffect(effects, &getAMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getBMutable(), MemoryEffects::Read::get()); + // 这里的 bias 是必选的 AnyType:$bias,所以是 Singleton + addEffect(effects, &getBiasMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); +} + +// === TGemvOp === +// Read: lhs, rhs, Write: dst +void TGemvOp::getEffects(SmallVectorImpl> &effects) { + addEffect(effects, &getLhsMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getRhsMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); +} + +// === TGemvAccOp === +// Read: acc_in, lhs, rhs, Write: dst +void TGemvAccOp::getEffects(SmallVectorImpl> &effects) { + addEffect(effects, &getAccInMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getLhsMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getRhsMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); +} + +// === TGemvBiasOp === +// Read: a, b, bias, Write: dst +void 
TGemvBiasOp::getEffects(SmallVectorImpl> &effects) { + addEffect(effects, &getAMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getBMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getBiasMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); +} + +// === TGemvMxOp === +// Read: a, a_scale, b, b_scale, Write: dst +void TGemvMxOp::getEffects(SmallVectorImpl> &effects) { + addEffect(effects, &getAMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getAScaleMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getBMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getBScaleMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); +} + +// === TGemvMxAccOp === +// Read: c_in, a, a_scale, b, b_scale, Write: dst +void TGemvMxAccOp::getEffects(SmallVectorImpl> &effects) { + addEffect(effects, &getCInMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getAMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getAScaleMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getBMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getBScaleMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); +} + +// === TGemvMxBiasOp === +// Read: a, a_scale, b, b_scale, bias, Write: dst +void TGemvMxBiasOp::getEffects(SmallVectorImpl> &effects) { + addEffect(effects, &getAMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getAScaleMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getBMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getBScaleMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getBiasMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); +} + +// === TMatmulOp === +void TMatmulMxOp::getEffects(SmallVectorImpl> &effects) { + 
addEffect(effects, &getAMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getAScaleMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getBMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getBScaleMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); +} + +// === TMatmulAccMxOp === +// Read: acc_in, lhs, rhs, Write: dst +void TMatmulMxAccOp::getEffects(SmallVectorImpl> &effects) { + addEffect(effects, &getCInMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getAMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getAScaleMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getBMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getBScaleMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); +} + +// === TMatmulBiasMxOp === +// Read: a, b, bias, Write: dst +void TMatmulMxBiasOp::getEffects(SmallVectorImpl> &effects) { + addEffect(effects, &getAMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getAScaleMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getBMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getBScaleMutable(), MemoryEffects::Read::get()); + // 这里的 bias 是必选的 AnyType:$bias,所以是 Singleton + addEffect(effects, &getBiasMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); +} + +static bool isInsideSectionCube(Operation *op) { + return op->getParentOfType() != nullptr; +} + +static bool isInsideSectionVector(Operation *op) { + return op->getParentOfType() != nullptr; +} + +static std::optional +getEnclosingFunctionKernelKind(Operation *op) { + auto funcOp = op->getParentOfType(); + if (!funcOp) + return std::nullopt; + + auto kernelKindAttr = + funcOp->getAttrOfType( + FunctionKernelKindAttr::name); + if (!kernelKindAttr) + return std::nullopt; + + return kernelKindAttr.getKernelKind(); +} 
+ +static bool isInsideSectionOrAttributedKernel(Operation *op) { + return isInsideSectionCube(op) || isInsideSectionVector(op) || + getEnclosingFunctionKernelKind(op).has_value(); +} + +static LogicalResult verifySplitAttr(Operation *op, int64_t split) { + if (split < 0 || split > 2) + return op->emitOpError("expects 'split' to be 0, 1, or 2"); + return success(); +} + +static LogicalResult verifyFrontendKernelKind(Operation *op, + FunctionKernelKind expected, + StringRef kernelName) { + auto kernelKind = getEnclosingFunctionKernelKind(op); + if (!kernelKind || *kernelKind != expected) { + return op->emitOpError("must be inside a ") + << kernelName << " kernel function"; + } + return success(); +} + +static ParseResult parseFrontendInitializePipeOp(OpAsmParser &parser, + OperationState &result) { + NamedAttrList attrs; + bool sawId = false; + bool sawDirMask = false; + bool sawSlotSize = false; + bool sawLocalSlotNum = false; + bool sawNoSplit = false; + + if (parser.parseLBrace()) + return failure(); + + while (failed(parser.parseOptionalRBrace())) { + StringRef keyword; + if (parser.parseKeyword(&keyword) || parser.parseEqual()) + return failure(); + + if (keyword == "id") { + if (sawId) + return parser.emitError(parser.getCurrentLocation(), + "duplicate 'id' clause"); + IntegerAttr idAttr; + if (parser.parseAttribute(idAttr, parser.getBuilder().getI32Type(), "id", + attrs)) + return failure(); + sawId = true; + } else if (keyword == "dir_mask") { + if (sawDirMask) + return parser.emitError(parser.getCurrentLocation(), + "duplicate 'dir_mask' clause"); + IntegerAttr dirMaskAttr; + if (parser.parseAttribute(dirMaskAttr, parser.getBuilder().getI8Type(), + "dir_mask", attrs)) + return failure(); + sawDirMask = true; + } else if (keyword == "slot_size") { + if (sawSlotSize) + return parser.emitError(parser.getCurrentLocation(), + "duplicate 'slot_size' clause"); + IntegerAttr slotSizeAttr; + if (parser.parseAttribute(slotSizeAttr, parser.getBuilder().getI32Type(), + 
"slot_size", attrs)) + return failure(); + sawSlotSize = true; + } else if (keyword == "local_slot_num") { + if (sawLocalSlotNum) + return parser.emitError(parser.getCurrentLocation(), + "duplicate 'local_slot_num' clause"); + IntegerAttr localSlotNumAttr; + if (parser.parseAttribute(localSlotNumAttr, parser.getBuilder().getI32Type(), + "local_slot_num", attrs)) + return failure(); + sawLocalSlotNum = true; + } else if (keyword == "nosplit") { + if (sawNoSplit) + return parser.emitError(parser.getCurrentLocation(), + "duplicate 'nosplit' clause"); + BoolAttr noSplitAttr; + if (parser.parseAttribute(noSplitAttr, "nosplit", attrs)) + return failure(); + sawNoSplit = true; + } else { + return parser.emitError(parser.getCurrentLocation()) + << "unexpected keyword '" << keyword << "'"; + } + + if (succeeded(parser.parseOptionalRBrace())) + break; + if (parser.parseComma()) + return failure(); + } + + if (!sawDirMask) + return parser.emitError(parser.getNameLoc(), "expected 'dir_mask' clause"); + if (!sawSlotSize) + return parser.emitError(parser.getNameLoc(), "expected 'slot_size' clause"); + if (!sawId) + attrs.set("id", parser.getBuilder().getI32IntegerAttr(0)); + + OpAsmParser::UnresolvedOperand gmSlotBuffer; + OpAsmParser::UnresolvedOperand gmSlotTensor; + OpAsmParser::UnresolvedOperand c2vConsumerBuf; + OpAsmParser::UnresolvedOperand v2cConsumerBuf; + Type gmSlotBufferTy; + Type gmSlotTensorTy; + Type c2vConsumerBufTy; + Type v2cConsumerBufTy; + bool hasGmSlotBuffer = false; + bool hasGmSlotTensor = false; + bool hasC2vConsumerBuf = false; + bool hasV2cConsumerBuf = false; + + if (parser.parseLParen()) + return failure(); + while (failed(parser.parseOptionalRParen())) { + StringRef keyword; + if (parser.parseKeyword(&keyword) || parser.parseEqual()) + return failure(); + + if (keyword == "gm_slot_buffer") { + if (hasGmSlotBuffer) + return parser.emitError(parser.getCurrentLocation(), + "duplicate 'gm_slot_buffer' operand"); + if (parser.parseOperand(gmSlotBuffer) 
|| + parser.parseColonType(gmSlotBufferTy)) + return failure(); + hasGmSlotBuffer = true; + } else if (keyword == "gm_slot_tensor") { + if (hasGmSlotTensor) + return parser.emitError(parser.getCurrentLocation(), + "duplicate 'gm_slot_tensor' operand"); + if (parser.parseOperand(gmSlotTensor) || + parser.parseColonType(gmSlotTensorTy)) + return failure(); + hasGmSlotTensor = true; + } else if (keyword == "c2v_consumer_buf") { + if (hasC2vConsumerBuf) + return parser.emitError(parser.getCurrentLocation(), + "duplicate 'c2v_consumer_buf' operand"); + if (parser.parseOperand(c2vConsumerBuf) || + parser.parseColonType(c2vConsumerBufTy)) + return failure(); + hasC2vConsumerBuf = true; + } else if (keyword == "v2c_consumer_buf") { + if (hasV2cConsumerBuf) + return parser.emitError(parser.getCurrentLocation(), + "duplicate 'v2c_consumer_buf' operand"); + if (parser.parseOperand(v2cConsumerBuf) || + parser.parseColonType(v2cConsumerBufTy)) + return failure(); + hasV2cConsumerBuf = true; + } else { + return parser.emitError(parser.getCurrentLocation()) + << "unexpected initialize_pipe operand '" << keyword << "'"; + } + + if (succeeded(parser.parseOptionalRParen())) + break; + if (parser.parseComma()) + return failure(); + } + + if (parser.parseOptionalAttrDict(attrs)) + return failure(); + + result.addAttributes(attrs); + result.addAttribute("operandSegmentSizes", + parser.getBuilder().getDenseI32ArrayAttr( + {hasGmSlotBuffer ? 1 : 0, hasGmSlotTensor ? 1 : 0, + hasC2vConsumerBuf ? 1 : 0, + hasV2cConsumerBuf ? 
1 : 0})); + if (hasGmSlotBuffer && + parser.resolveOperand(gmSlotBuffer, gmSlotBufferTy, result.operands)) + return failure(); + if (hasGmSlotTensor && + parser.resolveOperand(gmSlotTensor, gmSlotTensorTy, result.operands)) + return failure(); + if (hasC2vConsumerBuf && + parser.resolveOperand(c2vConsumerBuf, c2vConsumerBufTy, result.operands)) + return failure(); + if (hasV2cConsumerBuf && + parser.resolveOperand(v2cConsumerBuf, v2cConsumerBufTy, result.operands)) + return failure(); + return success(); +} + +template +static void printFrontendInitializePipeOp(InitOpT op, OpAsmPrinter &p) { + p << " {"; + bool needsComma = false; + auto printClause = [&](StringRef keyword, auto value) { + if (needsComma) + p << ", "; + p << keyword << " = " << value; + needsComma = true; + }; + + if (op.getId() != 0) + printClause("id", op.getId()); + printClause("dir_mask", static_cast(op.getDirMask())); + printClause("slot_size", op.getSlotSize()); + if (auto localSlotNumAttr = op.getLocalSlotNumAttr()) + printClause("local_slot_num", localSlotNumAttr.getInt()); + if (auto noSplitAttr = op.getNosplitAttr()) + printClause("nosplit", noSplitAttr.getValue() ? 
"true" : "false"); + p << "}"; + + p << "("; + bool needsOperandComma = false; + auto printOperandClause = [&](StringRef keyword, Value value) { + if (needsOperandComma) + p << ", "; + p << keyword << " = " << value << " : " << value.getType(); + needsOperandComma = true; + }; + if (op.getGmSlotBuffer()) { + printOperandClause("gm_slot_buffer", op.getGmSlotBuffer()); + } + if (op.getGmSlotTensor()) + printOperandClause("gm_slot_tensor", op.getGmSlotTensor()); + if (op.getC2vConsumerBuf()) + printOperandClause("c2v_consumer_buf", op.getC2vConsumerBuf()); + if (op.getV2cConsumerBuf()) + printOperandClause("v2c_consumer_buf", op.getV2cConsumerBuf()); + p << ")"; + p.printOptionalAttrDict( + op->getAttrs(), + /*elidedAttrs=*/{"id", "dir_mask", "slot_size", "local_slot_num", + "nosplit", "operandSegmentSizes"}); +} + +static std::optional +getStaticElementCount(ArrayRef shape) { + uint64_t count = 1; + for (int64_t dim : shape) { + if (dim == ShapedType::kDynamic || dim < 0) + return std::nullopt; + count *= static_cast(dim); + } + return count; +} + +static bool isSameOrHalfSlotByteSize(uint64_t tensorBytes, uint64_t slotBytes) { + return tensorBytes == slotBytes || tensorBytes * 2 == slotBytes; +} + +static LogicalResult verifyFrontendGlobalSlotTensor(Operation *op, Value tensor, + int8_t dirMask, + int32_t slotSize) { + (void)dirMask; + auto tvTy = dyn_cast(tensor.getType()); + if (!tvTy) + return op->emitOpError("expects 'gm_slot_tensor' to be !pto.tensor_view"); + + ArrayRef shape = tvTy.getShape(); + if (shape.empty()) + return op->emitOpError( + "expects 'gm_slot_tensor' to describe one slot entry tensor"); + + if (auto elemCount = getStaticElementCount(shape)) { + uint64_t elemBytes = getElemByteSize(tvTy.getElementType()); + if (elemBytes != 0) { + uint64_t tensorBytes = *elemCount * elemBytes; + if (!isSameOrHalfSlotByteSize(tensorBytes, + static_cast(slotSize))) { + return op->emitOpError() + << "expects 'slot_size' to equal gm_slot_tensor byte size " + "or 
twice gm_slot_tensor byte size for split GlobalTensor " + "entries (got slot_size = " + << slotSize << ", gm_slot_tensor byte size = " << tensorBytes + << ")"; + } + } + } + + return success(); +} + +template +static LogicalResult verifyFrontendInitCommon(InitOpT op, + FunctionKernelKind expected, + StringRef kernelName) { + if (failed(verifyFrontendKernelKind(op.getOperation(), expected, kernelName))) + return failure(); + + auto funcOp = op->template getParentOfType(); + if (!funcOp) + return op.emitOpError("must be nested under a func.func"); + + if (op.getId() < 0) + return op.emitOpError("expects 'id' to be non-negative"); + + unsigned sameIdInitCount = 0; + funcOp.walk([&](Operation *candidate) { + if (auto aic = dyn_cast(candidate)) { + if (aic.getId() == op.getId()) + ++sameIdInitCount; + return; + } + if (auto aiv = dyn_cast(candidate)) + if (aiv.getId() == op.getId()) + ++sameIdInitCount; + }); + if (sameIdInitCount > 1) { + return op.emitOpError( + "requires 'id' to be unique across frontend initialize_pipe ops in the function"); + } + + int8_t dirMask = op.getDirMask(); + if (dirMask != 1 && dirMask != 2 && dirMask != 3) + return op.emitOpError("expects 'dir_mask' to be 1, 2, or 3"); + if (op.getSlotSize() <= 0) + return op.emitOpError("expects 'slot_size' to be greater than 0"); + + bool hasGlobalSlotTensor = static_cast(op.getGmSlotTensor()); + bool hasC2vConsumerBuf = static_cast(op.getC2vConsumerBuf()); + bool hasV2cConsumerBuf = static_cast(op.getV2cConsumerBuf()); + if (hasGlobalSlotTensor) { + if (op.getGmSlotBuffer() || hasC2vConsumerBuf || hasV2cConsumerBuf) { + return op.emitOpError( + "globaltensor pipe init expects only 'gm_slot_tensor' and no " + "'gm_slot_buffer', 'c2v_consumer_buf', or 'v2c_consumer_buf'"); + } + if (op.getLocalSlotNumAttr()) + return op.emitOpError( + "globaltensor pipe init does not use 'local_slot_num'"); + if (getTargetArch(op.getOperation()) == PTOArch::A5) { + return op.emitOpError( + "globaltensor pipe entries are 
supported for a2/a3 l2g2l pipes"); + } + return verifyFrontendGlobalSlotTensor( + op.getOperation(), op.getGmSlotTensor(), dirMask, op.getSlotSize()); + } + + if (hasC2vConsumerBuf != hasV2cConsumerBuf) { + return op.emitOpError( + "expects 'c2v_consumer_buf' and 'v2c_consumer_buf' to be provided together"); + } + if (!hasC2vConsumerBuf) { + return op.emitOpError( + "expects local pipe init to provide 'c2v_consumer_buf' and " + "'v2c_consumer_buf'; use 'gm_slot_tensor' for globaltensor pipe entries"); + } + + if (auto localSlotNumAttr = op.getLocalSlotNumAttr()) { + int32_t localSlotNum = localSlotNumAttr.getInt(); + if (localSlotNum <= 0) + return op.emitOpError("expects 'local_slot_num' to be greater than 0"); + int32_t loweredSlotNum = dirMask == 3 ? 4 : 8; + if (localSlotNum > loweredSlotNum) { + return op.emitOpError() + << "expects 'local_slot_num' to be less than or equal to " + << loweredSlotNum << " for dir_mask = " << static_cast(dirMask); + } + } + + return success(); +} + +ParseResult AicInitializePipeOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseFrontendInitializePipeOp(parser, result); +} + +void AicInitializePipeOp::print(OpAsmPrinter &p) { + printFrontendInitializePipeOp(*this, p); +} + +ParseResult AivInitializePipeOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseFrontendInitializePipeOp(parser, result); +} + +void AivInitializePipeOp::print(OpAsmPrinter &p) { + printFrontendInitializePipeOp(*this, p); +} + +static ReserveBufferOp findReserveBufferByName(func::FuncOp funcOp, + StringRef name) { + ReserveBufferOp found; + funcOp.walk([&](ReserveBufferOp reserveOp) { + if (reserveOp.getName() != name) + return WalkResult::advance(); + found = reserveOp; + return WalkResult::interrupt(); + }); + return found; +} + +LogicalResult ReserveBufferOp::verify() { + auto funcOp = getOperation()->getParentOfType(); + if (!funcOp) + return emitOpError("must be nested under a func.func"); + + if (getSize() <= 0) 
+ return emitOpError("expects 'size' to be greater than 0"); + + auto location = getLocation().getAddressSpace(); + if (location != AddressSpace::VEC && location != AddressSpace::MAT) + return emitOpError("expects 'location' to be #pto.address_space or #pto.address_space"); + + if (!getAutoAlloc() && !getBaseAttr()) + return emitOpError("expects 'base' when 'auto' is false"); + + if (auto baseAttr = getBaseAttr(); baseAttr && baseAttr.getInt() < 0) + return emitOpError("expects 'base' to be non-negative when present"); + + unsigned sameNameCount = 0; + funcOp.walk([&](ReserveBufferOp reserveOp) { + if (reserveOp.getName() == getName()) + ++sameNameCount; + }); + if (sameNameCount > 1) + return emitOpError("requires 'name' to be unique within the function"); + + return success(); +} + +LogicalResult ImportReservedBufferOp::verify() { + auto funcOp = getOperation()->getParentOfType(); + if (!funcOp) + return emitOpError("must be nested under a func.func"); + + auto peerFunc = SymbolTable::lookupNearestSymbolFrom( + getOperation(), getPeerFuncAttr()); + if (!peerFunc) + return emitOpError("expects 'peer_func' to reference an existing func.func"); + + unsigned sameImportCount = 0; + funcOp.walk([&](ImportReservedBufferOp importOp) { + if (importOp.getName() == getName() && + importOp.getPeerFuncAttr() == getPeerFuncAttr()) { + ++sameImportCount; + } + }); + if (sameImportCount > 1) { + return emitOpError( + "requires (name, peer_func) to be unique within the function"); + } + + if (!findReserveBufferByName(peerFunc, getName())) + return emitOpError("expects matching peer reserve_buffer to exist"); + + return success(); +} + +static FailureOr lookupFrontendInitOpById(Operation *op, + func::FuncOp funcOp, + int32_t id) { + Operation *matchedInit = nullptr; + unsigned matchedInitCount = 0; + funcOp.walk([&](Operation *candidate) { + if (auto aic = dyn_cast(candidate)) { + if (aic.getId() == static_cast(id)) { + matchedInit = candidate; + ++matchedInitCount; + } + return 
WalkResult::advance(); + } + if (auto aiv = dyn_cast(candidate)) { + if (aiv.getId() == static_cast(id)) { + matchedInit = candidate; + ++matchedInitCount; + } + return WalkResult::advance(); + } + return WalkResult::advance(); + }); + + if (matchedInitCount == 0) { + op->emitOpError() << "expects 'id' = " << id + << " to match a frontend initialize_pipe op in the same function"; + return failure(); + } + if (matchedInitCount > 1) { + op->emitOpError() << "expects 'id' = " << id + << " to match exactly one frontend initialize_pipe op in the same function"; + return failure(); + } + return matchedInit; +} + +static LogicalResult verifyFrontendSplitOp(Operation *op, + FunctionKernelKind expected, + StringRef kernelName, + int32_t id, + int64_t split) { + if (failed(verifyFrontendKernelKind(op, expected, kernelName))) + return failure(); + if (id < 0) + return op->emitOpError("expects 'id' to be non-negative"); + return verifySplitAttr(op, split); +} + +static FailureOr lookupFrontendInitDirMaskById(Operation *op, + func::FuncOp funcOp, + int32_t id) { + auto initOr = lookupFrontendInitOpById(op, funcOp, id); + if (failed(initOr)) + return failure(); + if (auto aic = dyn_cast(*initOr)) + return aic.getDirMask(); + return cast(*initOr).getDirMask(); +} + +static LogicalResult verifyFrontendDataOpDirection(Operation *op, int32_t id, + bool expectC2V) { + auto funcOp = op->getParentOfType(); + if (!funcOp) + return op->emitOpError("must be nested under a func.func"); + + auto dirMaskOr = lookupFrontendInitDirMaskById(op, funcOp, id); + if (failed(dirMaskOr)) + return failure(); + + int8_t dirMask = *dirMaskOr; + if (expectC2V && dirMask != 1 && dirMask != 3) { + return op->emitOpError() + << "expects 'id' = " << id + << " to reference initialize_pipe with dir_mask = 1 or 3"; + } + if (!expectC2V && dirMask != 2 && dirMask != 3) { + return op->emitOpError() + << "expects 'id' = " << id + << " to reference initialize_pipe with dir_mask = 2 or 3"; + } + return success(); +} 
+ +static Value getFrontendInitGmSlotTensor(Operation *initOp) { + if (auto aic = dyn_cast(initOp)) + return aic.getGmSlotTensor(); + return cast(initOp).getGmSlotTensor(); +} + +static LogicalResult verifyFrontendTensorEntryMatchesInit(Operation *op, + int32_t id, + Type entryTy) { + auto entryViewTy = dyn_cast(entryTy); + if (!entryViewTy) + return success(); + + auto funcOp = op->getParentOfType(); + if (!funcOp) + return op->emitOpError("must be nested under a func.func"); + + auto initOr = lookupFrontendInitOpById(op, funcOp, id); + if (failed(initOr)) + return failure(); + Value gmSlotTensor = getFrontendInitGmSlotTensor(*initOr); + if (!gmSlotTensor) { + return op->emitOpError() + << "expects 'id' = " << id + << " to reference initialize_pipe with 'gm_slot_tensor' when the " + "pipe entry is !pto.tensor_view"; + } + + auto slotTensorTy = dyn_cast(gmSlotTensor.getType()); + if (!slotTensorTy) + return op->emitOpError("expects 'gm_slot_tensor' to be !pto.tensor_view"); + if (slotTensorTy.getElementType() != entryViewTy.getElementType()) { + return op->emitOpError() + << "expects pipe entry element type to match gm_slot_tensor element type"; + } + if (slotTensorTy.getRank() != entryViewTy.getRank()) { + return op->emitOpError() + << "expects pipe entry rank to match gm_slot_tensor rank"; + } + + ArrayRef slotShape = slotTensorTy.getShape(); + ArrayRef entryShape = entryViewTy.getShape(); + for (auto [idx, entryDim] : llvm::enumerate(entryShape)) { + int64_t slotDim = slotShape[idx]; + if (slotDim == ShapedType::kDynamic || + entryDim == ShapedType::kDynamic || slotDim == entryDim) + continue; + return op->emitOpError() + << "expects pipe entry dimension " << idx + << " to match gm_slot_tensor dimension " << slotDim; + } + return success(); +} + +template +static LogicalResult verifyFrontendPopOp(FrontendPopOpT op, + FunctionKernelKind expected, + StringRef kernelName, + bool expectC2V) { + if (failed(verifyFrontendSplitOp(op.getOperation(), expected, 
kernelName, + op.getId(), + op.getSplit()))) + return failure(); + if (failed(verifyFrontendDataOpDirection(op.getOperation(), op.getId(), + expectC2V))) + return failure(); + if (failed(verifyFrontendTensorEntryMatchesInit(op.getOperation(), op.getId(), + op.getTile().getType()))) + return failure(); + + bool hasValidRow = static_cast(op.getValidRow()); + bool hasValidCol = static_cast(op.getValidCol()); + if (hasValidRow != hasValidCol) + return op.emitOpError( + "expects valid_row and valid_col operands to be provided together"); + if (!hasValidRow) + return success(); + + if (isa(op.getTile().getType())) + return op.emitOpError( + "does not accept valid_row/valid_col when result is !pto.tensor_view"); + + auto tileTy = dyn_cast(op.getTile().getType()); + if (!tileTy) + return op.emitOpError( + "expects tile result to be !pto.tile_buf when valid_row/valid_col operands are provided"); + if (!tileTy.hasDynamicValid()) + return op.emitOpError( + "expects tile result to have dynamic validShape (?, ?) when valid_row/valid_col operands are provided"); + return success(); +} + +static LogicalResult verifyPipeShape(Operation *op, int8_t dirMask, int32_t slotSize, + int32_t slotNum, + std::optional flagBase) { + constexpr int32_t kMaxHardwareFlagIds = 16; + if (dirMask != 1 && dirMask != 2 && dirMask != 3) + return op->emitOpError("expects 'dir_mask' to be 1, 2, or 3"); + if (slotSize <= 0) + return op->emitOpError("expects 'slot_size' to be greater than 0"); + if (slotNum != 4 && slotNum != 8) + return op->emitOpError("expects 'slot_num' to be 4 or 8"); + if (flagBase && *flagBase < 0) + return op->emitOpError("expects 'flag_base' to be non-negative when present"); + if (flagBase) { + int32_t flagWidth = dirMask == 3 ? 
4 : 2; + if (*flagBase + flagWidth > kMaxHardwareFlagIds) { + return op->emitOpError() + << "requires 'flag_base' and dir_mask to fit within " + << kMaxHardwareFlagIds << " hardware flag ids"; + } + } + + return success(); +} + +static LogicalResult verifyPipeHandleProducer(Operation *op, Value pipeHandle) { + if (!isa(pipeHandle.getType())) + return op->emitOpError("expects pipe operand type !pto.pipe"); + if (!pipeHandle.getDefiningOp() && + !pipeHandle.getDefiningOp()) { + return op->emitOpError( + "pipe_handle must be produced by pto.initialize_l2l_pipe or " + "pto.initialize_l2g2l_pipe"); + } + return success(); +} + +static bool getTensorLikeElementAndShape(Type ty, Type &elementType, + ArrayRef &shape) { + if (auto tvTy = dyn_cast(ty)) { + elementType = tvTy.getElementType(); + shape = tvTy.getShape(); + return true; + } + if (auto memrefTy = dyn_cast(ty)) { + elementType = memrefTy.getElementType(); + shape = memrefTy.getShape(); + return true; + } + return false; +} + +static LogicalResult verifyTensorEntryMatchesInternalPipeInit(Operation *op, + Value pipeHandle, + Type entryTy) { + auto entryViewTy = dyn_cast(entryTy); + if (!entryViewTy) + return success(); + + auto initOp = pipeHandle.getDefiningOp(); + if (!initOp) { + return op->emitOpError() + << "expects !pto.tensor_view pipe entry to use a pipe produced by " + "pto.initialize_l2g2l_pipe"; + } + if (initOp.getLocalAddr()) { + return op->emitOpError() + << "expects !pto.tensor_view pipe entry to use global-only " + "pto.initialize_l2g2l_pipe without local_addr"; + } + + Type slotElementType; + ArrayRef slotShape; + if (!getTensorLikeElementAndShape(initOp.getGmAddr().getType(), + slotElementType, slotShape)) { + return op->emitOpError() + << "expects !pto.tensor_view pipe entry to use " + "pto.initialize_l2g2l_pipe gm_addr with tensor/memref slot type"; + } + + if (slotElementType != entryViewTy.getElementType()) { + return op->emitOpError() + << "expects pipe entry element type to match 
initialize_l2g2l_pipe " + "gm_addr element type"; + } + if (slotShape.size() != static_cast(entryViewTy.getRank())) { + return op->emitOpError() + << "expects pipe entry rank to match initialize_l2g2l_pipe gm_addr " + "rank"; + } + + ArrayRef entryShape = entryViewTy.getShape(); + for (auto [idx, entryDim] : llvm::enumerate(entryShape)) { + int64_t slotDim = slotShape[idx]; + if (slotDim == ShapedType::kDynamic || + entryDim == ShapedType::kDynamic || slotDim == entryDim) + continue; + return op->emitOpError() + << "expects pipe entry dimension " << idx + << " to match initialize_l2g2l_pipe gm_addr dimension " + << slotDim; + } + + if (auto entryElemCount = getStaticElementCount(entryShape)) { + uint64_t elemBytes = getElemByteSize(entryViewTy.getElementType()); + uint64_t entryBytes = *entryElemCount * elemBytes; + if (elemBytes != 0) { + int8_t split = 0; + if (auto alloc = dyn_cast(op)) + split = alloc.getSplit(); + else if (auto push = dyn_cast(op)) + split = push.getSplit(); + else if (auto pop = dyn_cast(op)) + split = pop.getSplit(); + else if (auto free = dyn_cast(op)) + split = free.getSplit(); + + uint64_t slotBytes = static_cast(initOp.getSlotSize()); + bool isSplitEntry = split != 0; + bool byteSizeMatches = + entryBytes == slotBytes || (isSplitEntry && entryBytes * 2 == slotBytes); + if (!byteSizeMatches) { + return op->emitOpError() + << "expects pipe entry byte size to match initialize_l2g2l_pipe " + "slot_size" + << (isSplitEntry ? 
" or half slot_size for split entries" : "") + << " (got entry byte size = " << entryBytes + << ", slot_size = " << initOp.getSlotSize() << ")"; + } + } + } + + return success(); +} + +LogicalResult BuildAsyncSessionOp::verify() { + Type scratchTy = getScratch().getType(); + if (!isa(scratchTy)) + return emitOpError("expects scratch to be tile_buf or memref type"); + + auto scratchSpace = getPTOMemorySpaceEnum(scratchTy); + if (!scratchSpace || *scratchSpace != pto::AddressSpace::VEC) + return emitOpError("expects scratch to be in vec address space"); + + auto scratchShape = getShapeVec(scratchTy); + if (scratchShape.empty() || scratchShape.size() > 2) + return emitOpError("expects scratch to be rank-1 or rank-2"); + for (int64_t dim : scratchShape) { + if (dim == ShapedType::kDynamic) + return emitOpError("expects scratch to have a static shape"); + } + + auto scratchBytes = getStaticByteSize(scratchTy); + if (!scratchBytes) + return emitOpError("expects scratch byte size to be statically known"); + if (*scratchBytes < sizeof(uint64_t)) + return emitOpError("expects scratch to provide at least 8 bytes"); + + Type workspaceElemTy; + Type workspaceTy = getWorkspace().getType(); + if (auto ptrTy = dyn_cast(workspaceTy)) { + workspaceElemTy = ptrTy.getElementType(); + } else if (auto memTy = dyn_cast(workspaceTy)) { + workspaceElemTy = memTy.getElementType(); + if (!isGmAddressSpaceAttr(memTy.getMemorySpace())) + return emitOpError("expects workspace to be in GM address space"); + } else { + return emitOpError("expects workspace to be !pto.ptr or memref type"); + } + if (!isByteIntegerType(workspaceElemTy)) + return emitOpError("expects workspace element type to be an 8-bit integer"); + + if (auto syncIdAttr = getSyncIdAttr()) { + int64_t syncId = syncIdAttr.getInt(); + if (syncId < 0 || syncId > 7) + return emitOpError("expects sync_id in range [0, 7]"); + } + if (auto blockBytesAttr = getBlockBytesAttr()) { + if (blockBytesAttr.getInt() <= 0) + return 
emitOpError("expects block_bytes to be greater than 0"); + } + if (auto commBlockOffsetAttr = getCommBlockOffsetAttr()) { + if (commBlockOffsetAttr.getInt() < 0) + return emitOpError("expects comm_block_offset to be non-negative"); + } + if (auto queueNumAttr = getQueueNumAttr()) { + if (queueNumAttr.getInt() <= 0) + return emitOpError("expects queue_num to be greater than 0"); + } + if (auto channelGroupIdxAttr = getChannelGroupIdxAttr()) { + APInt value = channelGroupIdxAttr.getValue(); + if (value.isNegative()) + return emitOpError("expects channel_group_idx to be non-negative"); + if (value.ugt(UINT32_MAX)) + return emitOpError("expects channel_group_idx to fit in uint32"); + } + + return success(); +} + +static LogicalResult verifyAsyncTransferOp(Operation *op, Value dst, Value src) { + Type dstElemTy = getElemTy(dst.getType()); + Type srcElemTy = getElemTy(src.getType()); + if (!dstElemTy || !srcElemTy) + return op->emitOpError("expects src and dst to have element types"); + if (dstElemTy != srcElemTy) + return op->emitOpError("expects src and dst to have the same element type"); + if (failed(verifyAsyncFlatContiguous1DGMViewLike(op, dst, "dst")) || + failed(verifyAsyncFlatContiguous1DGMViewLike(op, src, "src"))) + return failure(); + if (getShapeVec(dst.getType()) != getShapeVec(src.getType())) + return op->emitOpError("expects src and dst to have the same static shape"); + return success(); +} + +LogicalResult TPutAsyncOp::verify() { + return verifyAsyncTransferOp(getOperation(), getDst(), getSrc()); +} + +LogicalResult TGetAsyncOp::verify() { + return verifyAsyncTransferOp(getOperation(), getDst(), getSrc()); +} + +LogicalResult TPutOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + if (failed(verifyCommGlobalLike(*this, getDst(), "dst")) || + failed(verifyCommGlobalLike(*this, getSrc(), "src")) || + failed(verifyCommStagingTileLike(*this, getPing(), "ping")) || + failed(verifyCommPingPongSameType(*this, 
getPing(), getPong(), "ping", + "pong"))) + return failure(); + if (getElemTy(getDst().getType()) != getElemTy(getSrc().getType())) + return emitOpError("expects src and dst to have the same element type"); + if (getShapeVec(getDst().getType()) != getShapeVec(getSrc().getType())) + return emitOpError("expects src and dst to have the same static shape"); + if (getElemTy(getPing().getType()) != getElemTy(getSrc().getType())) + return emitOpError("expects staging tile element type to match src/dst"); + return success(); +} + +LogicalResult TGetOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + if (failed(verifyCommGlobalLike(*this, getDst(), "dst")) || + failed(verifyCommGlobalLike(*this, getSrc(), "src")) || + failed(verifyCommStagingTileLike(*this, getPing(), "ping")) || + failed(verifyCommPingPongSameType(*this, getPing(), getPong(), "ping", + "pong"))) + return failure(); + if (getElemTy(getDst().getType()) != getElemTy(getSrc().getType())) + return emitOpError("expects src and dst to have the same element type"); + if (getShapeVec(getDst().getType()) != getShapeVec(getSrc().getType())) + return emitOpError("expects src and dst to have the same static shape"); + if (getElemTy(getPing().getType()) != getElemTy(getSrc().getType())) + return emitOpError("expects staging tile element type to match src/dst"); + return success(); +} + +LogicalResult TNotifyOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + if (failed(verifyCommSignalLike(*this, getSignal(), "signal"))) + return failure(); + auto valueTy = dyn_cast(getValue().getType()); + if (!valueTy || valueTy.getWidth() != 32) + return emitOpError("expects value to be i32"); + return success(); +} + +LogicalResult TWaitOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + if (failed(verifyCommSignalLike(*this, getSignal(), "signal"))) + return failure(); + auto cmpTy = 
dyn_cast(getCmpValue().getType()); + if (!cmpTy || cmpTy.getWidth() != 32) + return emitOpError("expects cmp_value to be i32"); + return success(); +} + +LogicalResult TTestOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + if (failed(verifyCommSignalLike(*this, getSignal(), "signal"))) + return failure(); + auto cmpTy = dyn_cast(getCmpValue().getType()); + if (!cmpTy || cmpTy.getWidth() != 32) + return emitOpError("expects cmp_value to be i32"); + return success(); +} + +static LogicalResult verifySyncAllGmWorkspace(Operation *op, Value workspace, + StringRef name) { + Type ty = workspace.getType(); + if (!isa(ty)) + return op->emitOpError() << "expects " << name + << " to be a GM memref/tensor_view/partition_view"; + + if (auto memTy = dyn_cast(ty)) { + if (!memTy.hasRank()) + return op->emitOpError() << "expects " << name << " to be ranked"; + if (!isGmAddressSpaceAttr(memTy.getMemorySpace())) + return op->emitOpError() << "expects " << name + << " to be in GM address space"; + } + + auto elemTy = dyn_cast(getElemTy(ty)); + if (!elemTy || elemTy.getWidth() != 32) + return op->emitOpError() << "expects " << name + << " element type to be i32"; + + SmallVector shape = getShapeVec(ty); + if (shape.empty()) + return op->emitOpError() << "expects " << name << " to have rank >= 1"; + for (int64_t dim : shape) { + if (dim != ShapedType::kDynamic && dim <= 0) + return op->emitOpError() << "expects " << name + << " shape to be positive"; + } + return success(); +} + +static LogicalResult verifySyncAllTileWorkspace(Operation *op, Value workspace, + StringRef name, + pto::AddressSpace expectedSpace) { + Type ty = workspace.getType(); + if (!isa(ty)) + return op->emitOpError() << "expects " << name + << " to be tile_buf or memref type"; + + if (isa(ty) && failed(verifyTileBufCommon(op, ty, name))) + return failure(); + + auto as = getPTOMemorySpaceEnum(ty); + if (!as || *as != expectedSpace) + return op->emitOpError() << "expects " 
<< name << " to be in " + << (expectedSpace == pto::AddressSpace::VEC + ? "vec" + : "mat") + << " address space"; + + Type elemTy = getElemTy(ty); + auto intTy = dyn_cast_or_null(elemTy); + if (!intTy || intTy.getWidth() != 32) + return op->emitOpError() << "expects " << name + << " element type to be i32"; + + auto shape = getShapeVec(ty); + if (shape.empty() || shape.size() > 2) + return op->emitOpError() << "expects " << name + << " to be rank-1 or rank-2"; + for (int64_t dim : shape) { + if (dim != ShapedType::kDynamic && dim <= 0) + return op->emitOpError() << "expects " << name + << " shape to be positive"; + } + return success(); +} + +LogicalResult SyncAllOp::verify() { + bool hasGm = static_cast(getGmWorkspace()); + bool hasUb = static_cast(getUbWorkspace()); + bool hasL1 = static_cast(getL1Workspace()); + auto mode = getMode().getValue(); + auto coreType = getCoreType().getValue(); + + if (mode == pto::SyncAllMode::Hard) { + if (hasGm || hasUb || hasL1 || getUsedCores()) + return emitOpError( + "expects hard syncall to have no workspace operands or used_cores"); + return success(); + } + + if (!hasGm) + return emitOpError("expects soft syncall to provide gm_workspace"); + if (failed(verifySyncAllGmWorkspace(getOperation(), getGmWorkspace(), + "gm_workspace"))) + return failure(); + + if (auto used = getUsedCores()) { + auto intTy = dyn_cast(used.getType()); + if (!intTy || intTy.getWidth() != 32) + return emitOpError("expects used_cores to be i32"); + } + + switch (coreType) { + case pto::SyncCoreType::AIVOnly: + if (!hasUb || hasL1) + return emitOpError("expects soft AIV-only syncall to use gm_workspace " + "+ ub_workspace only"); + return verifySyncAllTileWorkspace(getOperation(), getUbWorkspace(), + "ub_workspace", + pto::AddressSpace::VEC); + case pto::SyncCoreType::AICOnly: + if (hasUb || !hasL1) + return emitOpError("expects soft AIC-only syncall to use gm_workspace " + "+ l1_workspace only"); + return verifySyncAllTileWorkspace(getOperation(), 
getL1Workspace(), + "l1_workspace", + pto::AddressSpace::MAT); + case pto::SyncCoreType::Mix: + if (!hasUb || !hasL1) + return emitOpError("expects soft mixed syncall to use gm_workspace + " + "ub_workspace + l1_workspace"); + if (failed(verifySyncAllTileWorkspace(getOperation(), getUbWorkspace(), + "ub_workspace", + pto::AddressSpace::VEC))) + return failure(); + return verifySyncAllTileWorkspace(getOperation(), getL1Workspace(), + "l1_workspace", + pto::AddressSpace::MAT); + } + + llvm_unreachable("unhandled SyncCoreType"); +} + +LogicalResult TBroadcastOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + if (failed(verifyCommGlobalLike(*this, getSrc(), "src")) || + failed(verifyCommStagingTileLike(*this, getPing(), "ping")) || + failed(verifyCommPingPongSameType(*this, getPing(), getPong(), "ping", + "pong")) || + failed(verifyCommGlobalGroup(*this, getGroup(), "group"))) + return failure(); + if (getRoot() >= static_cast(getGroup().size())) + return emitOpError("expects root to index into group operands"); + if (getSrc().getType() != getGroup().front().getType()) + return emitOpError("expects src type to match group member type"); + if (getElemTy(getPing().getType()) != getElemTy(getSrc().getType())) + return emitOpError("expects staging tile element type to match src"); + return success(); +} + +LogicalResult CommTGatherOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + if (failed(verifyCommGlobalLike(*this, getDst(), "dst")) || + failed(verifyCommStagingTileLike(*this, getPing(), "ping")) || + failed(verifyCommPingPongSameType(*this, getPing(), getPong(), "ping", + "pong")) || + failed(verifyCommGlobalGroup(*this, getGroup(), "group"))) + return failure(); + if (getRoot() >= static_cast(getGroup().size())) + return emitOpError("expects root to index into group operands"); + if (getElemTy(getDst().getType()) != getElemTy(getGroup().front().getType())) + return 
emitOpError("expects dst element type to match group member type"); + if (getElemTy(getPing().getType()) != getElemTy(getDst().getType())) + return emitOpError("expects staging tile element type to match dst"); + return success(); +} + +LogicalResult CommTScatterOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + if (failed(verifyCommGlobalLike(*this, getSrc(), "src")) || + failed(verifyCommStagingTileLike(*this, getPing(), "ping")) || + failed(verifyCommPingPongSameType(*this, getPing(), getPong(), "ping", + "pong")) || + failed(verifyCommGlobalGroup(*this, getGroup(), "group"))) + return failure(); + if (getRoot() >= static_cast(getGroup().size())) + return emitOpError("expects root to index into group operands"); + if (getElemTy(getSrc().getType()) != getElemTy(getGroup().front().getType())) + return emitOpError("expects src element type to match group member type"); + if (getElemTy(getPing().getType()) != getElemTy(getSrc().getType())) + return emitOpError("expects staging tile element type to match src"); + return success(); +} + +LogicalResult TReduceOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + if (failed(verifyCommGlobalLike(*this, getDst(), "dst")) || + failed(verifyCommStagingTileLike(*this, getAcc(), "acc")) || + failed(verifyCommStagingTileLike(*this, getRecvPing(), "recv_ping")) || + failed(verifyCommPingPongSameType(*this, getRecvPing(), getRecvPong(), + "recv_ping", "recv_pong")) || + failed(verifyCommGlobalGroup(*this, getGroup(), "group"))) + return failure(); + if (getRoot() >= static_cast(getGroup().size())) + return emitOpError("expects root to index into group operands"); + if (getElemTy(getDst().getType()) != getElemTy(getGroup().front().getType())) + return emitOpError("expects dst element type to match group member type"); + if (getAcc().getType() != getRecvPing().getType()) + return emitOpError("expects acc and recv_ping to have identical types"); + if 
(getElemTy(getAcc().getType()) != getElemTy(getDst().getType())) + return emitOpError("expects accumulator/receive tiles to match dst element type"); + return success(); +} + +LogicalResult AicInitializePipeOp::verify() { + return verifyFrontendInitCommon(*this, FunctionKernelKind::Cube, "cube"); +} + +LogicalResult AivInitializePipeOp::verify() { + return verifyFrontendInitCommon(*this, FunctionKernelKind::Vector, "vector"); +} + +LogicalResult TAllocToAivOp::verify() { + if (failed(verifyFrontendSplitOp(getOperation(), FunctionKernelKind::Cube, + "cube", getId(), getSplit()))) + return failure(); + if (failed(verifyFrontendDataOpDirection(getOperation(), getId(), + /*expectC2V=*/true))) + return failure(); + return verifyFrontendTensorEntryMatchesInit(getOperation(), getId(), + getEntry().getType()); +} + +LogicalResult TAllocToAicOp::verify() { + if (failed(verifyFrontendSplitOp(getOperation(), FunctionKernelKind::Vector, + "vector", getId(), getSplit()))) + return failure(); + if (failed(verifyFrontendDataOpDirection(getOperation(), getId(), + /*expectC2V=*/false))) + return failure(); + return verifyFrontendTensorEntryMatchesInit(getOperation(), getId(), + getEntry().getType()); +} + +LogicalResult TPushToAivOp::verify() { + if (failed(verifyFrontendSplitOp(getOperation(), FunctionKernelKind::Cube, + "cube", getId(), getSplit()))) + return failure(); + if (failed(verifyFrontendDataOpDirection(getOperation(), getId(), + /*expectC2V=*/true))) + return failure(); + return verifyFrontendTensorEntryMatchesInit(getOperation(), getId(), + getTile().getType()); +} + +LogicalResult TPushToAicOp::verify() { + if (failed(verifyFrontendSplitOp(getOperation(), FunctionKernelKind::Vector, + "vector", getId(), getSplit()))) + return failure(); + if (failed(verifyFrontendDataOpDirection(getOperation(), getId(), + /*expectC2V=*/false))) + return failure(); + return verifyFrontendTensorEntryMatchesInit(getOperation(), getId(), + getTile().getType()); +} + +LogicalResult 
TPopFromAicOp::verify() { + return verifyFrontendPopOp(*this, FunctionKernelKind::Vector, "vector", + /*expectC2V=*/true); +} + +LogicalResult TPopFromAivOp::verify() { + return verifyFrontendPopOp(*this, FunctionKernelKind::Cube, "cube", + /*expectC2V=*/false); +} + +LogicalResult TFreeFromAicOp::verify() { + if (failed(verifyFrontendSplitOp(getOperation(), FunctionKernelKind::Vector, + "vector", getId(), getSplit()))) + return failure(); + if (failed(verifyFrontendDataOpDirection(getOperation(), getId(), + /*expectC2V=*/true))) + return failure(); + if (getEntry()) + return verifyFrontendTensorEntryMatchesInit(getOperation(), getId(), + getEntry().getType()); + return success(); +} + +LogicalResult TFreeFromAivOp::verify() { + if (failed(verifyFrontendSplitOp(getOperation(), FunctionKernelKind::Cube, + "cube", getId(), getSplit()))) + return failure(); + if (failed(verifyFrontendDataOpDirection(getOperation(), getId(), + /*expectC2V=*/false))) + return failure(); + if (getEntry()) + return verifyFrontendTensorEntryMatchesInit(getOperation(), getId(), + getEntry().getType()); + return success(); +} + +LogicalResult InitializeL2G2LPipeOp::verify() { + if (failed(verifyPipeShape(getOperation(), getDirMask(), getSlotSize(), + getSlotNum(), + getFlagBaseAttr() + ? 
std::optional(getFlagBaseAttr().getInt()) + : std::nullopt))) + return failure(); + + if (!getLocalAddr()) { + if (getPeerLocalAddr()) + return emitOpError("'peer_local_addr' requires 'local_addr'"); + if (getLocalSlotNumAttr()) + return emitOpError( + "'local_slot_num' is only allowed when 'local_addr' is present"); + return success(); + } + + if (auto localSlotNumAttr = getLocalSlotNumAttr()) { + int32_t localSlotNum = localSlotNumAttr.getInt(); + if (localSlotNum <= 0) + return emitOpError("expects 'local_slot_num' to be greater than 0"); + if (static_cast(localSlotNum) > getSlotNum()) + return emitOpError( + "expects 'local_slot_num' to be less than or equal to slot_num"); + } + + if (getDirMask() == 3 && !getPeerLocalAddr()) + return emitOpError("expects 'peer_local_addr' when dir_mask is 3"); + if (getDirMask() != 3 && getPeerLocalAddr()) + return emitOpError("'peer_local_addr' is only allowed when dir_mask is 3"); + return success(); +} + +LogicalResult InitializeL2LPipeOp::verify() { + if (failed(verifyPipeShape(getOperation(), getDirMask(), getSlotSize(), + getSlotNum(), + getFlagBaseAttr() + ? 
std::optional(getFlagBaseAttr().getInt()) + : std::nullopt))) + return failure(); + + if (getDirMask() == 3 && !getPeerLocalAddr()) + return emitOpError("expects 'peer_local_addr' when dir_mask is 3"); + if (getDirMask() != 3 && getPeerLocalAddr()) + return emitOpError("'peer_local_addr' is only allowed when dir_mask is 3"); + return success(); +} + +LogicalResult TPushOp::verify() { + if (!isInsideSectionOrAttributedKernel(getOperation())) + return emitOpError("must be inside pto.section.cube/vector or a kernel_kind function"); + if (failed(verifyPipeHandleProducer(getOperation(), getPipeHandle()))) + return failure(); + if (failed(verifySplitAttr(getOperation(), getSplit()))) + return failure(); + if (failed(verifyTensorEntryMatchesInternalPipeInit( + getOperation(), getPipeHandle(), getTile().getType()))) + return failure(); + if (!isa(getTile().getType()) && + getPipe() == pto::PIPE::PIPE_UNASSIGNED) + return emitOpError("tile type must map to a supported producer pipe"); + return success(); +} + +LogicalResult TAllocOp::verify() { + if (!isInsideSectionOrAttributedKernel(getOperation())) + return emitOpError("must be inside pto.section.cube/vector or a kernel_kind function"); + if (failed(verifyPipeHandleProducer(getOperation(), getPipeHandle()))) + return failure(); + if (failed(verifyTensorEntryMatchesInternalPipeInit( + getOperation(), getPipeHandle(), getEntry().getType()))) + return failure(); + return verifySplitAttr(getOperation(), getSplit()); +} + +LogicalResult TPopOp::verify() { + if (!isInsideSectionOrAttributedKernel(getOperation())) + return emitOpError("must be inside pto.section.cube/vector or a kernel_kind function"); + if (failed(verifyPipeHandleProducer(getOperation(), getPipeHandle()))) + return failure(); + if (failed(verifySplitAttr(getOperation(), getSplit()))) + return failure(); + if (failed(verifyTensorEntryMatchesInternalPipeInit( + getOperation(), getPipeHandle(), getTile().getType()))) + return failure(); + if 
(!isa(getTile().getType()) && + getPipe() == pto::PIPE::PIPE_UNASSIGNED) + return emitOpError( + "tile type and target arch must map to a supported consumer pipe"); + return success(); +} + +LogicalResult TFreeOp::verify() { + if (!isInsideSectionOrAttributedKernel(getOperation())) + return emitOpError("must be inside pto.section.cube/vector or a kernel_kind function"); + if (failed(verifyPipeHandleProducer(getOperation(), getPipeHandle()))) + return failure(); + if (getEntry() && + failed(verifyTensorEntryMatchesInternalPipeInit( + getOperation(), getPipeHandle(), getEntry().getType()))) + return failure(); + return verifySplitAttr(getOperation(), getSplit()); +} + +ParseResult TFreeOp::parse(OpAsmParser &parser, OperationState &result) { + OpAsmParser::UnresolvedOperand first; + OpAsmParser::UnresolvedOperand pipe; + Type firstTy; + Type pipeTy; + bool hasEntry = false; + + if (parser.parseLParen() || parser.parseOperand(first)) + return failure(); + + if (succeeded(parser.parseOptionalComma())) { + hasEntry = true; + if (parser.parseOperand(pipe) || parser.parseColonType(firstTy) || + parser.parseComma() || parser.parseType(pipeTy) || parser.parseRParen()) + return failure(); + } else { + if (parser.parseColonType(pipeTy) || parser.parseRParen()) + return failure(); + pipe = first; + } + + NamedAttrList attrs; + if (parser.parseLBrace() || parser.parseKeyword("split") || + parser.parseEqual()) + return failure(); + IntegerAttr splitAttr; + if (parser.parseAttribute(splitAttr, parser.getBuilder().getI8Type(), + "split", attrs) || + parser.parseRBrace() || parser.parseOptionalAttrDict(attrs)) + return failure(); + + result.addAttributes(attrs); + if (hasEntry && + parser.resolveOperand(first, firstTy, result.operands)) + return failure(); + if (parser.resolveOperand(pipe, pipeTy, result.operands)) + return failure(); + return success(); +} + +void TFreeOp::print(OpAsmPrinter &p) { + p << "("; + if (getEntry()) { + p << getEntry() << ", " << getPipeHandle() << " : 
" + << getEntry().getType() << ", " << getPipeHandle().getType(); + } else { + p << getPipeHandle() << " : " << getPipeHandle().getType(); + } + p << ") {split = " << static_cast(getSplit()) << "}"; + p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"split"}); +} + +void BuildAsyncSessionOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getScratchMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getWorkspaceMutable(), MemoryEffects::Read::get()); + addEffect(effects, getOperation()->getOpResult(0), MemoryEffects::Write::get()); +} + +void TPutAsyncOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); + addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getSessionMutable(), MemoryEffects::Read::get()); + addEffect(effects, getOperation()->getOpResult(0), MemoryEffects::Write::get()); +} + +void TGetAsyncOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); + addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getSessionMutable(), MemoryEffects::Read::get()); + addEffect(effects, getOperation()->getOpResult(0), MemoryEffects::Write::get()); +} + +void TPutOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); + addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getPingMutable(), MemoryEffects::Read::get()); +} + +void TGetOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); + addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getPingMutable(), MemoryEffects::Read::get()); +} + +void TNotifyOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getSignalMutable(), MemoryEffects::Write::get()); + 
addEffect(effects, &getValueMutable(), MemoryEffects::Read::get()); +} + +void TWaitOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getSignalMutable(), MemoryEffects::Write::get()); + addEffect(effects, &getCmpValueMutable(), MemoryEffects::Read::get()); +} + +void TTestOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getSignalMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getCmpValueMutable(), MemoryEffects::Read::get()); + addEffect(effects, getOperation()->getOpResult(0), MemoryEffects::Write::get()); +} + +void TBroadcastOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getPingMutable(), MemoryEffects::Write::get()); + if (getPong()) { + auto pongRange = getPongMutable(); + if (auto it = pongRange.begin(); it != pongRange.end()) + addEffect(effects, &*it, MemoryEffects::Write::get()); + } +} + +void CommTGatherOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); + addEffect(effects, &getPingMutable(), MemoryEffects::Read::get()); +} + +void CommTScatterOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getPingMutable(), MemoryEffects::Write::get()); + if (getPong()) { + auto pongRange = getPongMutable(); + if (auto it = pongRange.begin(); it != pongRange.end()) + addEffect(effects, &*it, MemoryEffects::Write::get()); + } +} + +void TReduceOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); + addEffect(effects, &getAccMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getAccMutable(), MemoryEffects::Write::get()); + addEffect(effects, &getRecvPingMutable(), MemoryEffects::Write::get()); +} + +void WaitAsyncEventOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, 
&getEventMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getSessionMutable(), MemoryEffects::Read::get()); + addEffect(effects, getOperation()->getOpResult(0), MemoryEffects::Write::get()); +} + +void TestAsyncEventOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getEventMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getSessionMutable(), MemoryEffects::Read::get()); + addEffect(effects, getOperation()->getOpResult(0), MemoryEffects::Write::get()); +} + +void InitializeL2G2LPipeOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getGmAddrMutable(), MemoryEffects::Read::get()); + auto localAddr = getLocalAddrMutable(); + if (!localAddr.empty()) + addEffect(effects, &*localAddr.begin(), MemoryEffects::Read::get()); + auto peerLocalAddr = getPeerLocalAddrMutable(); + if (!peerLocalAddr.empty()) + addEffect(effects, &*peerLocalAddr.begin(), MemoryEffects::Read::get()); + addEffect(effects, getOperation()->getOpResult(0), MemoryEffects::Write::get()); +} + +void InitializeL2LPipeOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getLocalAddrMutable(), MemoryEffects::Read::get()); + addEffect(effects, getOperation()->getOpResult(0), MemoryEffects::Write::get()); +} + +void TPushOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getTileMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getPipeHandleMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getPipeHandleMutable(), MemoryEffects::Write::get()); +} + +void TAllocOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getEntryMutable(), MemoryEffects::Write::get()); + addEffect(effects, &getPipeHandleMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getPipeHandleMutable(), MemoryEffects::Write::get()); +} + +void TPopOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getPipeHandleMutable(), MemoryEffects::Read::get()); + addEffect(effects, 
&getPipeHandleMutable(), MemoryEffects::Write::get()); + addEffect(effects, &getTileMutable(), MemoryEffects::Write::get()); +} + +void TFreeOp::getEffects( + SmallVectorImpl> + &effects) { + auto entry = getEntryMutable(); + if (!entry.empty()) + addEffect(effects, &*entry.begin(), MemoryEffects::Read::get()); + addEffect(effects, &getPipeHandleMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getPipeHandleMutable(), MemoryEffects::Write::get()); +} + +// [Include 必须放在最后] +#include "PTO/IR/PTOInterfaces.cpp.inc" +#define GET_OP_CLASSES +#include "PTO/IR/PTOOps.cpp.inc" diff --git a/lib/PTO/Transforms/GraphSyncSolver/SyncSolver.cpp b/lib/PTO/Transforms/GraphSyncSolver/SyncSolver.cpp index 23a4032a6..0e1e75998 100644 --- a/lib/PTO/Transforms/GraphSyncSolver/SyncSolver.cpp +++ b/lib/PTO/Transforms/GraphSyncSolver/SyncSolver.cpp @@ -6,2571 +6,5 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. 
-//===--------- SyncSolver.cpp ------- Graph Sync Solver -------------------===// -//===----------------------------------------------------------------------===// -#include "PTO/Transforms/GraphSyncSolver/SyncSolver.h" -#include "PTO/Transforms/GraphSyncSolver/GraphSolver.h" -#include "PTO/Transforms/GraphSyncSolver/MemInfo.h" -#include "PTO/Transforms/GraphSyncSolver/SyncSolverIR.h" -#include "PTO/Transforms/GraphSyncSolver/Utility.h" - -#include "PTO/IR/PTO.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/IR/Value.h" -#include "mlir/Interfaces/LoopLikeInterface.h" -#include "llvm/ADT/DenseSet.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringExtras.h" -#include "llvm/Support/Casting.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/ErrorHandling.h" -#include "llvm/Support/LogicalResult.h" -#include -#include -#include -#include -#include -#include -#include - -#define DEBUG_TYPE "PTO-gss-solver" - -using namespace mlir; -using namespace pto::syncsolver; - -// Reset per-pass bookkeeping to start fresh. 
-void Solver::reset(bool resetEventIdRanOutOpts) { - if (resetEventIdRanOutOpts) { - reusePairs.clear(); - disabledMultiEventIdPairs.clear(); - backwardSyncEventsAfterMerge.clear(); - moveBackwardSyncPairsToOutmostLoop = false; - dontMoveBackwardSyncPairsToOutmostLoop = false; - } - skipOcc.clear(); - syncedPairs.clear(); - processedOccPairs.clear(); - chosenConflictedPairs.clear(); - scopeOccChosenConflicts.clear(); - scopeOccPairChosenConflicts.clear(); - backwardSyncEvents.clear(); - replacedWithReusableSyncedPairs.clear(); - reusedPairs.clear(); - barrierAllPairs.clear(); - insertedBarrierAllBefore.clear(); - eventIdSolver.clear(); - resetUnitFlag(); -} - -void Solver::resetUnitFlag() { - for (auto *rwOp : unitFlagFeaturedOps) { - rwOp->mergedUnitFlagInfo.reset(); - for (auto *occ : opAllOccurrences[rwOp]) { - occ->unitFlagInfo.reset(); - } - } -} - -// Helpers to find first/last iteration occurrences relative to parent -// occurrences. -Occurrence *Solver::getFirstIterOcc(Occurrence *occ, Occurrence *parOcc) { - assert(occ != nullptr && parOcc != nullptr); - if (parOcc->depth + 1 < occ->depth) { - auto *newParOcc = getFirstIterOcc( - occ->getNthParent(occ->depth - parOcc->depth - 1), parOcc); - return getFirstIterOcc(occ, newParOcc); - } - auto *it = - std::find_if(parOcc->childOccs.begin(), parOcc->childOccs.end(), - [occ](Occurrence *curOcc) { return occ->op == curOcc->op; }); - assert(it != parOcc->childOccs.end()); - return *it; -} - -Occurrence *Solver::getLastIterOcc(Occurrence *occ, Occurrence *parOcc) { - assert(occ != nullptr && parOcc != nullptr); - if (parOcc->depth + 1 < occ->depth) { - auto *newParOcc = getLastIterOcc( - occ->getNthParent(occ->depth - parOcc->depth - 1), parOcc); - return getLastIterOcc(occ, newParOcc); - } - auto it = - std::find_if(parOcc->childOccs.rbegin(), parOcc->childOccs.rend(), - [occ](Occurrence *curOcc) { return occ->op == curOcc->op; }); - assert(it != parOcc->childOccs.rend()); - return *it; -} - -bool 
Solver::checkSkipCrossCorePair(Occurrence *occ1, Occurrence *occ2) { - if (!options.isCrossCoreMode()) { - return false; - } - auto *rwOp1 = llvm::dyn_cast(occ1->op); - auto *rwOp2 = llvm::dyn_cast(occ2->op); - assert(rwOp1 != nullptr && rwOp2 != nullptr); - assert(rwOp1->coreType != pto::TCoreType::CUBE_OR_VECTOR); - assert(rwOp2->coreType != pto::TCoreType::CUBE_OR_VECTOR); - if (rwOp1->coreType == rwOp2->coreType) { - return true; - } - if (rwOp1->coreType == pto::TCoreType::CUBE_AND_VECTOR) { - return true; - } - return false; -} - -bool Solver::checkSkipParallelLoop(Occurrence *occ1, Occurrence *occ2) { - if (!isBackwardSync(occ1, occ2)) { - return false; - } - auto [parOcc1, parOcc2] = Occurrence::getLCAPair(occ1, occ2); - assert(parOcc1 != nullptr && parOcc2 != nullptr); - auto *parentLCALoopOcc = Occurrence::getParentloop(parOcc1); - assert(parentLCALoopOcc != nullptr); - auto *parentLCALoopOp = llvm::cast(parentLCALoopOcc->op); - return parentLCALoopOp->isParallel; -} - -// Check whether occurrences belong to impossible (if-else) pairing. -bool Solver::checkImpossibleOccPair(Occurrence *occ1, Occurrence *occ2) { - assert(occ1 != nullptr && occ2 != nullptr); - if (occ1->op == occ2->op) { - return false; - } - auto [parOcc1, parOcc2] = Occurrence::getLCAPair(occ1, occ2); - assert(parOcc1 != nullptr && parOcc2 != nullptr); - bool isIfElseSituation = - parOcc1->parentOcc != nullptr && - parOcc1->parentOcc == parOcc2->parentOcc && - llvm::isa_and_present(parOcc1->parentOcc->op); - return isIfElseSituation; -} - -// Detect whether occ1 and occ2 have already been covered by an earlier sync. 
-bool Solver::checkAlreadySynced(Occurrence *occ1, Occurrence *occ2) { - assert(occ1 != nullptr && occ2 != nullptr); - assert(occ1->op != nullptr && occ2->op != nullptr); - - auto [parOcc1, parOcc2] = Occurrence::getLCAPair(occ1, occ2); - assert(parOcc1 != nullptr && parOcc2 != nullptr); - assert(parOcc1->parentOcc != nullptr && parOcc2->parentOcc != nullptr); - - auto [parOp1, parOp2] = OperationBase::getLCAPair(occ1->op, occ2->op); - assert(parOp1 != nullptr && parOp2 != nullptr); - assert(parOp1->parentOp != nullptr && parOp2->parentOp != nullptr); - - auto *parentLoop = OperationBase::getParentloop(parOcc1->op); - auto *curLoop = OperationBase::getParentloop(parOp1); - if (parentLoop == nullptr || parentLoop == curLoop) { - return false; - } - - assert(curLoop != nullptr); - assert(parentLoop->isProperAncestor(curLoop)); - while (curLoop != parentLoop) { - if (!llvm::cast(curLoop)->isParallel) { - return true; - } - curLoop = OperationBase::getParentloop(curLoop); - assert(curLoop != nullptr); - } - return false; -} - -// Unit-flag reuse check between two RWOperations. 
-bool Solver::checkAlreadySyncedWithUnitFlag(Occurrence *occ1, - Occurrence *occ2) { - assert(occ1 != nullptr && occ2 != nullptr); - if (!options.enableUnitFlagFeature) { - return false; - } - if (!occ1->hasUnitFlagFeat || !occ2->hasUnitFlagFeat) { - return false; - } - llvm::DenseSet visited; - DEBUG_WITH_TYPE("gss-sync-solver-check-unit-flag", { - llvm::dbgs() << "unit-flag-step: " << occ1->syncIrIndex << ' ' - << occ1->op->str(0, false) << "\n"; - }); - Occurrence *curOcc = occ1->unitFlagInfo.linkedElementAsSet; - while (curOcc != nullptr) { - DEBUG_WITH_TYPE("gss-sync-solver-check-unit-flag", { - llvm::dbgs() << "unit-flag-step: " << curOcc->syncIrIndex << ' ' - << curOcc->op->str(0, false) << "\n"; - }); - auto [it, isInserted] = visited.insert(curOcc); - if (!isInserted) { - break; - } - if (curOcc == occ2) { - return true; - } - curOcc = curOcc->unitFlagInfo.linkedElementAsSet; - } - return false; -} - -bool Solver::ignoreMemoryConflict(RWOperation *rwOp1, RWOperation *rwOp2, - const MemInfo &memInfo1, - const MemInfo &memInfo2) { - if (options.isIntraCoreMode()) { - if (memInfo1.isWorkSpace && memInfo2.isWorkSpace) { - if (options.intraCoreIgnoreWorkSpaceFunctionArguments) { - return true; - } - } - } - return false; -} - -bool Solver::checkMemInfoConflict(RWOperation *rwOp1, RWOperation *rwOp2, - const MemInfo &memInfo1, - const MemInfo &memInfo2, - std::optional lcmLen, - std::optional eventIdNum) { - if (ignoreMemoryConflict(rwOp1, rwOp2, memInfo1, memInfo2)) { - return false; - } - return MemInfo::checkConflict(memInfo1, memInfo2, lcmLen, eventIdNum); -} - -bool Solver::checkMemInfoConflict( - RWOperation *rwOp1, RWOperation *rwOp2, - const llvm::SmallVector &memInfoList1, - const llvm::SmallVector &memInfoList2, - std::optional lcmLen, std::optional eventIdNum) { - for (auto &memInfo1 : memInfoList1) { - for (auto &memInfo2 : memInfoList2) { - if (checkMemInfoConflict(rwOp1, rwOp2, memInfo1, memInfo2, lcmLen, - eventIdNum)) { - return true; - } - } - } 
- return false; -} - -// High-level wrapper computing pipe pairs that represent memory conflicts -// between two RW ops. -llvm::SmallVector> -Solver::checkMemoryConflicts(RWOperation *rwOp1, RWOperation *rwOp2) { - assert(rwOp1 != nullptr && rwOp2 != nullptr); - auto [it, isInserted] = checkMemoryConflictsMem.insert({{rwOp1, rwOp2}, {}}); - if (!isInserted) { - return it->second; - } - auto coreSrc = rwOp1->coreType; - auto coreDst = rwOp2->coreType; - if (options.isCrossCoreMode()) { - if (coreDst == pto::TCoreType::CUBE_AND_VECTOR) { - coreDst = (coreSrc == pto::TCoreType::VECTOR) ? pto::TCoreType::CUBE - : pto::TCoreType::VECTOR; - } - assert(coreSrc == pto::TCoreType::VECTOR || - coreSrc == pto::TCoreType::CUBE); - assert(coreDst == pto::TCoreType::VECTOR || - coreDst == pto::TCoreType::CUBE); - } - llvm::SetVector> collectedConflictsSet; - if (checkMemInfoConflict(rwOp1, rwOp2, rwOp1->readMemInfo, - rwOp2->writeMemInfo)) { - collectedConflictsSet.insert({CorePipeInfo(coreSrc, rwOp1->pipeRead), - CorePipeInfo(coreDst, rwOp2->pipeWrite)}); - } - if (checkMemInfoConflict(rwOp1, rwOp2, rwOp1->writeMemInfo, - rwOp2->readMemInfo)) { - collectedConflictsSet.insert({CorePipeInfo(coreSrc, rwOp1->pipeWrite), - CorePipeInfo(coreDst, rwOp2->pipeRead)}); - } - if (checkMemInfoConflict(rwOp1, rwOp2, rwOp1->writeMemInfo, - rwOp2->writeMemInfo)) { - collectedConflictsSet.insert({CorePipeInfo(coreSrc, rwOp1->pipeWrite), - CorePipeInfo(coreDst, rwOp2->pipeWrite)}); - } - llvm::SmallVector> collectedConflicts( - collectedConflictsSet.begin(), collectedConflictsSet.end()); - return it->second = collectedConflicts; -} - -bool Solver::checkMemoryConflictBetweenOccExclusive( - Occurrence *occ1, Occurrence *occ2, - std::function filter) { - assert(occ1 != nullptr && occ2 != nullptr); - auto *rwOp1 = llvm::dyn_cast_if_present(occ1->op); - auto *rwOp2 = llvm::dyn_cast_if_present(occ2->op); - assert(rwOp1 != nullptr && rwOp2 != nullptr); - for (int i = occ1->syncIrEndIndex; i < 
occ2->syncIrIndex; i++) { - if (auto *otherOp = llvm::dyn_cast_if_present(syncIr[i]->op)) { - if (!filter(otherOp)) { - continue; - } - if (!checkMemoryConflicts(rwOp1, otherOp).empty()) { - return true; - } - if (!checkMemoryConflicts(rwOp2, otherOp).empty()) { - return true; - } - } - } - return false; -} - -std::optional -Solver::getMultiBufferLoop(RWOperation *rwOp1, RWOperation *rwOp2, - const llvm::SmallVector &memInfoList1, - const llvm::SmallVector &memInfoList2) { - std::optional multibufferLoop; - for (auto &memInfo1 : memInfoList1) { - for (auto &memInfo2 : memInfoList2) { - if (checkMemInfoConflict(rwOp1, rwOp2, memInfo1, memInfo2)) { - if (!memInfo1.pointerLikeInfo.has_value() || - !memInfo2.pointerLikeInfo.has_value()) { - return {}; - } - auto multibufferLoop1 = memInfo1.pointerLikeInfo->parentLoop; - auto multibufferLoop2 = memInfo2.pointerLikeInfo->parentLoop; - if (multibufferLoop1 == nullptr || - multibufferLoop1 != multibufferLoop2) { - return {}; - } - if (multibufferLoop.has_value() && - multibufferLoop.value() != multibufferLoop1) { - return {}; - } - multibufferLoop = multibufferLoop1; - } - } - } - return multibufferLoop; -} - -std::optional -Solver::getMultiBufferLoop(RWOperation *rwOp1, RWOperation *rwOp2) { - std::optional multibufferLoop; - if (checkMemInfoConflict(rwOp1, rwOp2, rwOp1->readMemInfo, - rwOp2->writeMemInfo)) { - auto curMultibufferLoop = getMultiBufferLoop( - rwOp1, rwOp2, rwOp1->readMemInfo, rwOp2->writeMemInfo); - if (multibufferLoop.has_value() && - multibufferLoop.value() != curMultibufferLoop) { - return {}; - } - multibufferLoop = curMultibufferLoop; - } - if (checkMemInfoConflict(rwOp1, rwOp2, rwOp1->writeMemInfo, - rwOp2->readMemInfo)) { - auto curMultibufferLoop = getMultiBufferLoop( - rwOp1, rwOp2, rwOp1->writeMemInfo, rwOp2->readMemInfo); - if (multibufferLoop.has_value() && - multibufferLoop.value() != curMultibufferLoop) { - return {}; - } - multibufferLoop = curMultibufferLoop; - } - if 
(checkMemInfoConflict(rwOp1, rwOp2, rwOp1->writeMemInfo, - rwOp2->writeMemInfo)) { - auto curMultibufferLoop = getMultiBufferLoop( - rwOp1, rwOp2, rwOp1->writeMemInfo, rwOp2->writeMemInfo); - if (multibufferLoop.has_value() && - multibufferLoop.value() != curMultibufferLoop) { - return {}; - } - multibufferLoop = curMultibufferLoop; - } - return multibufferLoop; -} - -std::optional -Solver::getMultiBufferEventIdInfo(Occurrence *occ1, Occurrence *occ2, - RWOperation *rwOp1, RWOperation *rwOp2) { - assert(rwOp1 != nullptr && rwOp2 != nullptr); - - int64_t lcm = 1; - int64_t minWriteSize = LONG_MAX; - LoopLikeOpInterface multibufferLoop{nullptr}; - - if (options.isTestMode()) { - auto *parLoop1 = occ1->getParentOfType(); - auto *parLoop2 = occ2->getParentOfType(); - if (!parLoop1 || parLoop1 != parLoop2) { - return {}; - } - auto [setOcc, waitOcc] = getSetWaitOcc(occ1, occ2); - if (!parLoop1->isProperAncestor(setOcc) || - !parLoop1->isProperAncestor(waitOcc)) { - return {}; - } - } else { - auto multibufferLoopOpt = getMultiBufferLoop(rwOp1, rwOp2); - if (!multibufferLoopOpt.has_value() || !multibufferLoopOpt.value()) { - return {}; - } - multibufferLoop = multibufferLoopOpt.value(); - assert(multibufferLoop != nullptr); - auto [setOcc, waitOcc] = getSetWaitOcc(occ1, occ2); - if (!setOcc->getParentWithOp(multibufferLoop, - /*assertExists=*/false) || - !waitOcc->getParentWithOp(multibufferLoop, - /*assertExists=*/false)) { - return {}; - } - } - - for (auto &memInfo1 : rwOp1->readMemInfo) { - for (auto &memInfo2 : rwOp2->writeMemInfo) { - if (checkMemInfoConflict(rwOp1, rwOp2, memInfo1, memInfo2)) { - int64_t curLcm = std::lcm(memInfo1.getSz(), memInfo2.getSz()); - lcm = std::lcm(lcm, curLcm); - minWriteSize = std::min(minWriteSize, memInfo2.getSz()); - } - } - } - for (auto &memInfo1 : rwOp1->writeMemInfo) { - for (auto &memInfo2 : rwOp2->readMemInfo) { - if (checkMemInfoConflict(rwOp1, rwOp2, memInfo1, memInfo2)) { - int64_t curLcm = std::lcm(memInfo1.getSz(), 
memInfo2.getSz()); - lcm = std::lcm(lcm, curLcm); - minWriteSize = std::min(minWriteSize, memInfo1.getSz()); - } - } - } - for (auto &memInfo1 : rwOp1->writeMemInfo) { - for (auto &memInfo2 : rwOp2->writeMemInfo) { - if (checkMemInfoConflict(rwOp1, rwOp2, memInfo1, memInfo2)) { - int64_t curLcm = std::lcm(memInfo1.getSz(), memInfo2.getSz()); - lcm = std::lcm(lcm, curLcm); - minWriteSize = std::min(minWriteSize, memInfo1.getSz()); - minWriteSize = std::min(minWriteSize, memInfo2.getSz()); - } - } - } - - // In case no write sizes were positive. - if (minWriteSize == LONG_MAX) { - minWriteSize = 1; - return {}; - } - - int64_t eventIdNum = minWriteSize; - for (; eventIdNum >= 1; eventIdNum--) { - // llvm::dbgs() << "checking event-id-num: " << eventIdNum << '\n'; - int64_t curLcm = std::lcm(lcm, eventIdNum); - bool okRW = !checkMemInfoConflict(rwOp1, rwOp2, rwOp1->readMemInfo, - rwOp2->writeMemInfo, curLcm, eventIdNum); - bool okWR = !checkMemInfoConflict(rwOp1, rwOp2, rwOp1->writeMemInfo, - rwOp2->readMemInfo, curLcm, eventIdNum); - bool okWW = !checkMemInfoConflict(rwOp1, rwOp2, rwOp1->writeMemInfo, - rwOp2->writeMemInfo, curLcm, eventIdNum); - if (okRW && okWR && okWW) { - break; - } - } - if (eventIdNum <= 1) { - return {}; - } - EventIdInfo eventIdInfo(eventIdNum); - eventIdInfo.multibufferLoop = multibufferLoop; - return eventIdInfo; -} - -std::optional -Solver::checkMultiBufferEventIdInfo(Occurrence *occ1, Occurrence *occ2, - RWOperation *rwOp1, RWOperation *rwOp2) { - assert(rwOp1 != nullptr && rwOp2 != nullptr); - if (!options.isTestMode()) { - if (!checkAllParentLoopsAreForLoops(rwOp1->op) || - !checkAllParentLoopsAreForLoops(rwOp2->op)) { - return {}; - } - } - if (auto eventIdInfo = getMultiBufferEventIdInfo(occ1, occ2, rwOp1, rwOp2)) { - return eventIdInfo; - } - return {}; -} - -std::optional -Solver::checkCVMultiBufferUnrollEventIdInfo(RWOperation *rwOp1, - RWOperation *rwOp2) { - assert(rwOp1 != nullptr && rwOp2 != nullptr); - if 
(!options.isCrossCoreMode()) { - return {}; - } - auto *parentLoop1 = rwOp1->getParentOfType(); - auto *parentLoop2 = rwOp2->getParentOfType(); - while (parentLoop1 != nullptr && !parentLoop1->multibufferUnrollNum) { - parentLoop1 = parentLoop1->getParentOfType(); - } - while (parentLoop2 != nullptr && !parentLoop2->multibufferUnrollNum) { - parentLoop2 = parentLoop2->getParentOfType(); - } - if (!parentLoop1 || !parentLoop2) { - return {}; - } - if (auto *parCond1 = rwOp1->getParentOfType()) { - if (!parCond1->isProperAncestor(rwOp2)) { - return {}; - } - } - if (auto *parCond2 = rwOp2->getParentOfType()) { - if (!parCond2->isProperAncestor(rwOp1)) { - return {}; - } - } - assert(parentLoop1->multibufferUnrollNum.value() == - parentLoop2->multibufferUnrollNum.value()); - EventIdInfo eventIdInfo; - eventIdInfo.eventIdNum = parentLoop1->multibufferUnrollNum.value(); - eventIdInfo.multibufferUnrollLoop1 = - cast(parentLoop1->op); - eventIdInfo.multibufferUnrollLoop2 = - cast(parentLoop2->op); - return eventIdInfo; -} - -std::optional -Solver::checkCVMultiBufferPreloadEventIdInfo(RWOperation *rwOp1, - RWOperation *rwOp2) { - assert(rwOp1 != nullptr && rwOp2 != nullptr); - if (!options.isCrossCoreMode()) { - return {}; - } - auto *parentScope1 = rwOp1->getParentOfType(); - auto *parentScope2 = rwOp2->getParentOfType(); - while (parentScope1 != nullptr && !parentScope1->maxPreloadNum.has_value()) { - parentScope1 = parentScope1->getParentOfType(); - } - while (parentScope2 != nullptr && !parentScope2->maxPreloadNum.has_value()) { - parentScope2 = parentScope2->getParentOfType(); - } - if (!parentScope1 || !parentScope2) { - return {}; - } - if (auto *parCond1 = rwOp1->getParentOfType()) { - if (!parCond1->isProperAncestor(rwOp2)) { - return {}; - } - } - if (auto *parCond2 = rwOp2->getParentOfType()) { - if (!parCond2->isProperAncestor(rwOp1)) { - return {}; - } - } - - auto *parentLoop1 = parentScope1->getParentOfType(); - auto *parentLoop2 = 
parentScope2->getParentOfType(); - if (parentLoop1 == nullptr || parentLoop1 != parentLoop2) { - return {}; - } - - assert(parentScope1->preloadNum.has_value()); - assert(parentScope2->preloadNum.has_value()); - assert(parentScope1->maxPreloadNum.value() == - parentScope2->maxPreloadNum.value()); - - auto parentForLoop = llvm::dyn_cast_if_present(parentLoop1->op); - assert(parentForLoop != nullptr); - - EventIdInfo eventIdInfo; - eventIdInfo.eventIdNum = parentScope1->maxPreloadNum.value(); - eventIdInfo.preloadOffset1 = parentScope1->maxPreloadNum.value() - - parentScope1->preloadNum.value() - 1; - eventIdInfo.preloadOffset2 = parentScope2->maxPreloadNum.value() - - parentScope2->preloadNum.value() - 1; - eventIdInfo.multibufferLoop = parentForLoop; - return eventIdInfo; -} - -// Determine required event id count and optional multibuffer loop parent for -// occurrences. -EventIdInfo Solver::getEventIdInfo(Occurrence *occ1, Occurrence *occ2, - RWOperation *rwOp1, RWOperation *rwOp2, - CorePipeInfo corePipeSrc, - CorePipeInfo corePipeDst) { - assert(occ1 != nullptr && occ2 != nullptr); - assert(rwOp1 != nullptr && rwOp2 != nullptr); - EventIdInfo singleEventId(1); - if (!isBackwardSync(occ1, occ2)) { - return singleEventId; - } - if (auto eventIdInfo = checkCVMultiBufferUnrollEventIdInfo(rwOp1, rwOp2)) { - return eventIdInfo.value(); - } - if (auto eventIdInfo = checkCVMultiBufferPreloadEventIdInfo(rwOp1, rwOp2)) { - return eventIdInfo.value(); - } - if (auto eventIdInfo = - checkMultiBufferEventIdInfo(occ1, occ2, rwOp1, rwOp2)) { - return eventIdInfo.value(); - } - return singleEventId; -} - -// Graph-based check to determine if adding a sync between occ1 and occ2 would -// block progress. Uses GraphSolver (Dijkstra) to estimate minimal reachable -// index. 
-bool Solver::checkGraphConflict( - Occurrence *occ1, Occurrence *occ2, CorePipeInfo corePipeSrc, - CorePipeInfo corePipeDst, EventIdInfo eventIdInfo, - std::optional startIndex, std::optional endIndex, - const llvm::SmallVector &extraConflictPairs, - const llvm::SmallVector &ignoreConflictPairs) { - assert(occ1 != nullptr && occ2 != nullptr); - if (!startIndex.has_value()) { - startIndex = occ1->endIndex; - } - if (!endIndex.has_value()) { - endIndex = occ2->startIndex; - } - GraphSolver graphSolver(options); - llvm::DenseSet visited; - auto handleConflictPair = [&](ConflictPair *conflictPair) { - if (conflictPair->couldNotRun) { - return; - } - if (conflictPair->endIndex < startIndex.value() || - conflictPair->startIndex > endIndex.value()) { - return; - } - if (conflictPair->isInnerBackward) { - if ((eventIdInfo.eventIdNum * eventIdInfo.eventIdRepeatNum) < - (conflictPair->eventIdInfo.eventIdNum * - conflictPair->eventIdInfo.eventIdRepeatNum)) { - return; - } - } - if (llvm::find(ignoreConflictPairs, conflictPair) != - ignoreConflictPairs.end()) { - return; - } - auto [it, isInserted] = visited.insert(conflictPair); - if (!isInserted) { - return; - } - DEBUG_WITH_TYPE("gss-sync-solver-check-graph-conflict", { - llvm::dbgs() << "add-conflict-pair: " << conflictPair->str() << '\n'; - }); - graphSolver.addConflictPair(conflictPair); - }; - - for (auto *parOcc : occ1->getAllParents()) { - if (scopeOccChosenConflicts.contains(parOcc)) { - for (auto *conflictPair : scopeOccChosenConflicts[parOcc]) { - handleConflictPair(conflictPair); - } - } - } - for (auto *parOcc : occ2->getAllParents()) { - if (scopeOccChosenConflicts.contains(parOcc)) { - for (auto *conflictPair : scopeOccChosenConflicts[parOcc]) { - handleConflictPair(conflictPair); - } - } - } - for (auto &[scopeOccPair, chosenConflicts] : scopeOccPairChosenConflicts) { - auto [scopeOcc1, scopeOcc2] = scopeOccPair; - if (scopeOcc1->isProperAncestor(occ1) && - scopeOcc2->isProperAncestor(occ2)) { - for (auto 
*conflictPair : chosenConflicts) { - handleConflictPair(conflictPair); - } - } - } - for (auto *parOcc : occ1->getAllParents()) { - if (persistentScopeOccChosenConflicts.contains(parOcc)) { - for (auto *conflictPair : persistentScopeOccChosenConflicts[parOcc]) { - handleConflictPair(conflictPair); - } - } - } - for (auto *parOcc : occ2->getAllParents()) { - if (persistentScopeOccChosenConflicts.contains(parOcc)) { - for (auto *conflictPair : persistentScopeOccChosenConflicts[parOcc]) { - handleConflictPair(conflictPair); - } - } - } - for (auto *conflictPair : extraConflictPairs) { - handleConflictPair(conflictPair); - } - std::optional mnDistance; - if (options.enableUnitFlagFeature) { - mnDistance = graphSolver.runDijkstraUnitFlagEnabled( - occ1, occ2, corePipeSrc, corePipeDst, startIndex.value(), - endIndex.value()); - } else { - mnDistance = graphSolver.runDijkstra(corePipeSrc, corePipeDst, - startIndex.value(), endIndex.value()); - } - return !mnDistance.has_value() || mnDistance.value() > endIndex.value(); -} - -bool Solver::checkSyncOpsConflicts(ConflictPair *conflictPair1, - ConflictPair *conflictPair2) { - if (conflictPair1->isBarrier() || conflictPair2->isBarrier()) { - return false; - } - if (conflictPair1->startIndex > conflictPair2->startIndex) { - std::swap(conflictPair1, conflictPair2); - } - if (conflictPair1->startIndex >= conflictPair2->startIndex || - conflictPair1->endIndex >= conflictPair2->endIndex) { - return true; - } - bool result = false; - if (conflictPair1->setCorePipeInfo != conflictPair2->setCorePipeInfo) { - auto corePipeSrc = conflictPair1->setCorePipeInfo; - auto corePipeDst = conflictPair2->setCorePipeInfo; - Occurrence *occ1 = conflictPair1->setOcc; - Occurrence *occ2 = conflictPair2->setOcc; - auto startIndex = conflictPair1->startIndex + 1; - auto endIndex = conflictPair2->startIndex; - conflictPair1->startIndex += 1; - assert(occ1 != nullptr && occ2 != nullptr); - result = result || - checkGraphConflict(occ1, occ2, corePipeSrc, 
corePipeDst, - conflictPair1->eventIdInfo, startIndex, - endIndex, {conflictPair1}, {conflictPair2}); - conflictPair1->startIndex -= 1; - } - if (conflictPair1->waitCorePipeInfo != conflictPair2->waitCorePipeInfo) { - auto corePipeSrc = conflictPair1->waitCorePipeInfo; - auto corePipeDst = conflictPair2->waitCorePipeInfo; - Occurrence *occ1 = conflictPair1->waitOcc; - Occurrence *occ2 = conflictPair2->waitOcc; - auto startIndex = conflictPair1->endIndex; - auto endIndex = conflictPair2->endIndex - 1; - conflictPair2->endIndex -= 1; - assert(occ1 != nullptr && occ2 != nullptr); - result = result || - checkGraphConflict(occ1, occ2, corePipeSrc, corePipeDst, - conflictPair1->eventIdInfo, startIndex, - endIndex, {conflictPair1}, {conflictPair2}); - conflictPair2->endIndex += 1; - } - DEBUG_WITH_TYPE("gss-check-sync-ops-conflicts", { - if (result) { - llvm::dbgs() << "sync-ops-conflict-found: " << "\n"; - llvm::dbgs() << " " << conflictPair1->str() << '\n'; - llvm::dbgs() << " " << conflictPair2->str() << '\n'; - } - }); - return result; -} - -// Check whether two ConflictPair entries conflict in pipe and time ranges. 
-bool Solver::checkIntersect(ConflictPair *conflictPair1, - ConflictPair *conflictPair2) { - assert(conflictPair1 != nullptr && conflictPair2 != nullptr); - if (conflictPair1 == conflictPair2) { - return false; - } - if (conflictPair1->isBarrier() || conflictPair2->isBarrier()) { - return false; - } - if (conflictPair1->dontCheckForConflict || - conflictPair2->dontCheckForConflict) { - return false; - } - if (options.isCrossCoreMode()) { - return checkSyncOpsConflicts(conflictPair1, conflictPair2); - } - if (conflictPair1->setCorePipeInfo != conflictPair2->setCorePipeInfo || - conflictPair1->waitCorePipeInfo != conflictPair2->waitCorePipeInfo) { - return false; - } - for (auto [l1, r1] : getRanges(conflictPair1)) { - for (auto [l2, r2] : getRanges(conflictPair2)) { - if (checkRangesIntersect(l1, r1 + 1, l2, r2 + 1)) { - return true; - } - } - } - return false; -} - -// Obtain available event ids while accounting for already chosen conflicts. -std::vector -Solver::getIntersectingConflictPairs(ConflictPair *conflictPair) { - assert(conflictPair != nullptr); - if (conflictPair->isBarrier()) { - return {}; - } - if (conflictPair->dontCheckForConflict) { - return {}; - } - std::vector intersectingConflictPairs; - for (auto &curConflictPair : chosenConflictedPairs) { - if (checkIntersect(conflictPair, curConflictPair.get())) { - intersectingConflictPairs.push_back(curConflictPair.get()); - } - } - for (auto &curConflictPair : persistentChosenConflictedPairs) { - if (checkIntersect(conflictPair, curConflictPair.get())) { - intersectingConflictPairs.push_back(curConflictPair.get()); - } - } - return intersectingConflictPairs; -} - -// Processed-pair tracking helpers. 
-bool Solver::checkVisited(Occurrence *occ1, Occurrence *occ2) { - auto [it, isInserted] = processedOccPairs.insert(std::make_pair(occ1, occ2)); - return !isInserted; -} - -bool Solver::checkSkippable(bool reverseOrder, Occurrence *occ) { - return skipOcc[reverseOrder].contains(occ); -} - -// Synced-pair memoization helpers. -EventIdNode *Solver::getOldEventIdNodeIfExists(ConflictPair *conflictPair) { - assert(conflictPair != nullptr); - auto oldConflictPairs = getMemorizedSyncedPairs(conflictPair); - if (oldConflictPairs.empty()) { - return {}; - } - ConflictPair *oldConflictPair = *oldConflictPairs.begin(); - assert(oldConflictPair != nullptr && oldConflictPair->eventIdNode != nullptr); - return oldConflictPair->eventIdNode; -} - -llvm::DenseSet -Solver::getMemorizedSyncedPairs(ConflictPair *conflictPair) { - auto key = std::make_tuple( - conflictPair->backwardSyncLoopOp, conflictPair->op1, conflictPair->op2, - conflictPair->setCorePipeInfo, conflictPair->waitCorePipeInfo); - return syncedPairs[key]; -} - -void Solver::memorizeSyncedPair(ConflictPair *conflictPair) { - auto key = std::make_tuple( - conflictPair->backwardSyncLoopOp, conflictPair->op1, conflictPair->op2, - conflictPair->setCorePipeInfo, conflictPair->waitCorePipeInfo); - syncedPairs[key].insert(conflictPair); -#ifndef NDEBUG - for (auto *oldConflictPair : syncedPairs[key]) { - assert(oldConflictPair->eventIdNode == conflictPair->eventIdNode); - } -#endif -} - -void Solver::forgetSyncedPair(ConflictPair *conflictPair) { - assert(conflictPair != nullptr); - auto key = std::make_tuple( - conflictPair->backwardSyncLoopOp, conflictPair->op1, conflictPair->op2, - conflictPair->setCorePipeInfo, conflictPair->waitCorePipeInfo); - syncedPairs[key].erase(conflictPair); -} - -void Solver::memorizeReusedSyncedPair(ConflictPair *conflictPair, - ConflictPair *reusedConflictPair) { - assert(conflictPair != nullptr); - replacedWithReusableSyncedPairs[{ - conflictPair->backwardSyncLoopOp, conflictPair->op1, 
conflictPair->op2, - conflictPair->setCorePipeInfo, conflictPair->waitCorePipeInfo}] = - reusedConflictPair; -} - -bool Solver::skipMMad1DecomposedLoopOpt(Occurrence *occ1, Occurrence *occ2) { - auto *parentLoopOp1 = OperationBase::getParentloop(occ1->op); - auto *parentLoopOp2 = OperationBase::getParentloop(occ2->op); - if (parentLoopOp1 != nullptr && parentLoopOp2 != nullptr) { - if (parentLoopOp1 != parentLoopOp2) { - if (isa(parentLoopOp1) && - isa(parentLoopOp2)) { - return true; - } - } - } - return false; -} - -std::optional> -Solver::checkAndApplyMmadl0LoopOpt(ConflictPair *conflictPair, Occurrence *occ1, - Occurrence *occ2, Occurrence *parOcc1, - Occurrence *parOcc2) { - if (!options.decomposeMmadl1Op) { - return {}; - } - if (occ1->parentOcc != nullptr && occ1->parentOcc->parentOcc != nullptr && - occ1->parentOcc->parentOcc->parentOcc == parOcc1 && - llvm::isa_and_present( - occ1->op) && - llvm::isa_and_present( - occ1->parentOcc->parentOcc->op)) { - conflictPair->setOnLastIterOnly = true; - return std::make_pair(occ1, parOcc2); - } - if (!conflictPair->isInnerBackward && occ2->parentOcc != nullptr && - occ2->parentOcc->parentOcc != nullptr && - occ2->parentOcc->parentOcc->parentOcc == parOcc2 && - llvm::isa_and_present( - occ2->op) && - llvm::isa_and_present( - occ2->parentOcc->parentOcc->op)) { - conflictPair->waitOnFirstIterOnly = true; - return std::make_pair(parOcc1, occ2); - } - return {}; -} - -std::optional Solver::checkUnitFlagPatterns(Occurrence *occ1, - Occurrence *occ2) { - return {}; -} - -Occurrence *Solver::getBeforePlaceHolderOcc(Occurrence *occ) { - assert(occ != nullptr); - assert(llvm::isa_and_present(occ->op)); - int index = occ->syncIrIndex - 1; - assert(0 <= index && index < static_cast(syncIr.size())); - auto *placeHolderOcc = syncIr[index].get(); -#ifndef NDEBUG - auto *placeHolderOp = llvm::dyn_cast(placeHolderOcc->op); - assert(placeHolderOp != nullptr); - assert(placeHolderOp->beforeOp == occ->op); -#endif - return 
placeHolderOcc; -} - -Occurrence *Solver::getAfterPlaceHolderOcc(Occurrence *occ) { - assert(occ != nullptr); - assert(llvm::isa_and_present(occ->op)); - int index = occ->syncIrEndIndex; - assert(0 <= index && index < static_cast(syncIr.size())); - auto *placeHolderOcc = syncIr[index].get(); -#ifndef NDEBUG - auto *placeHolderOp = llvm::dyn_cast(placeHolderOcc->op); - assert(placeHolderOp != nullptr); - assert(placeHolderOp->afterOp == occ->op); -#endif - return placeHolderOcc; -} - -Occurrence *Solver::getScopeBeginPlaceHolderOcc(Occurrence *occ) { - assert(occ != nullptr); - assert(llvm::isa_and_present(occ->op)); - int index = occ->syncIrIndex + 1; - assert(0 <= index && index < static_cast(syncIr.size())); - auto *placeHolderOcc = syncIr[index].get(); -#ifndef NDEBUG - auto *placeHolderOp = llvm::dyn_cast(placeHolderOcc->op); - assert(placeHolderOp != nullptr); - assert(placeHolderOp->scopeBegin == occ->op); -#endif - return placeHolderOcc; -} - -Occurrence *Solver::getScopeEndPlaceHolderOcc(Occurrence *occ) { - assert(occ != nullptr); - assert(llvm::isa_and_present(occ->op)); - int index = occ->syncIrEndIndex - 1; - assert(0 <= index && index < static_cast(syncIr.size())); - auto *placeHolderOcc = syncIr[index].get(); -#ifndef NDEBUG - auto *placeHolderOp = llvm::dyn_cast(placeHolderOcc->op); - assert(placeHolderOp != nullptr); - assert(placeHolderOp->scopeEnd == occ->op); -#endif - return placeHolderOcc; -} - -std::pair -Solver::getSetWaitLCAPairOcc(Occurrence *occ1, Occurrence *occ2) { - assert(occ1 != nullptr && occ2 != nullptr); - - auto [grandParOcc1, grandParOcc2] = Occurrence::getLCAPair(occ1, occ2); - assert(grandParOcc1 != nullptr && grandParOcc2 != nullptr); - assert(grandParOcc1->parentOcc != nullptr && - grandParOcc2->parentOcc != nullptr); - - auto [parOp1, parOp2] = OperationBase::getLCAPair(occ1->op, occ2->op); - assert(parOp1 != nullptr && parOp2 != nullptr); - assert(parOp1->parentOp != nullptr && parOp2->parentOp != nullptr); - 
assert(parOp1->parentOp == parOp2->parentOp); - - auto *parOcc1 = occ1->getParentWithOp(parOp1->parentOp); - auto *parOcc2 = occ2->getParentWithOp(parOp2->parentOp); - assert(parOcc1 != nullptr && parOcc2 != nullptr); - assert(parOcc1 != occ1 && parOcc2 != occ2); - - auto *setOcc = occ1->getNthParent(occ1->depth - parOcc1->depth - 1); - auto *waitOcc = occ2->getNthParent(occ2->depth - parOcc2->depth - 1); - assert(setOcc != nullptr && waitOcc != nullptr); - assert(parOcc1->isProperAncestor(setOcc)); - assert(parOcc2->isProperAncestor(waitOcc)); - - auto *parLoop = Occurrence::getParentloop(setOcc); - while (parLoop != nullptr && grandParOcc1->isProperAncestor(parLoop)) { - setOcc = parLoop; - waitOcc = Occurrence::getParentloop(waitOcc); - parLoop = Occurrence::getParentloop(setOcc); - } - return std::make_pair(setOcc, waitOcc); -} - -std::pair -Solver::getFixedSetWaitOcc(Occurrence *occ1, Occurrence *occ2) { - // - get setOcc waitOcc where: - // setOcc->op->parent = waitOcc->op->parent = lca(occ1, occ2)->op - auto [setOcc, waitOcc] = getSetWaitLCAPairOcc(occ1, occ2); - - // - check if it's the case of while loop: - // while{ - // before{ - // occ1 - // } - // setOcc; - // waitOcc; - // after{ - // occ2 - // } - // } - // - and fix it to be: - // while{ - // before{ - // occ1 - // setOcc; - // ... 
- // waitOcc; - // placeHolder - // } - // after{ - // occ2 - // } - // } - if (setOcc->op != waitOcc->op) { - if (auto *parLoopOp = - llvm::dyn_cast_if_present(setOcc->parentOcc->op)) { - if (parLoopOp->body.size() > 1 && !isa(waitOcc->op)) { - auto *placeHolderOcc = getScopeEndPlaceHolderOcc(setOcc); - std::tie(setOcc, waitOcc) = getSetWaitLCAPairOcc(occ1, placeHolderOcc); - } - } - } - - // - check if it's the case of: - // loop(iter-1){ - // condition{ - // true-scope{} - // setOcc() - // false-scope{} - // } - // } - // loop(iter-2){ - // condition{ - // true-scope{} - // waitOcc() - // false-scope{} - // } - // } - // - and fix it to be: - // loop(iter-1){ - // condition{ - // true-scope{} - // false-scope{} - // } - // setOcc() - // } - // loop(iter-2){ - // waitOcc() - // condition{ - // true-scope{} - // false-scope{} - // } - // } - if (isBackwardSync(occ1, occ2)) { - if (setOcc->parentOcc != nullptr) { - if (llvm::isa_and_present(setOcc->parentOcc->op)) { - setOcc = setOcc->parentOcc; - } - } - if (waitOcc->parentOcc != nullptr) { - if (llvm::isa_and_present(waitOcc->parentOcc->op)) { - waitOcc = waitOcc->parentOcc; - } - } - } - - // - for the case of cv-pipelining: - // loop(){ - // op1 - // } {unroll=x} - // setOcc - // waitOcc - // loop(){ - // op2 - // } {unroll=x} - // - and fix it to be: - // loop(){ - // op1 - // setOcc - // } {unroll=x} - // loop(){ - // waitOcc - // op2 - // } {unroll=x} - if (options.isCrossCoreMode()) { - assert(setOcc->op != nullptr && waitOcc->op != nullptr); - auto *forOp1 = llvm::dyn_cast_if_present(setOcc->op); - auto *forOp2 = llvm::dyn_cast_if_present(waitOcc->op); - if (forOp1 != nullptr && forOp2 != nullptr) { - if (forOp1->multibufferUnrollNum && forOp2->multibufferUnrollNum) { - assert(forOp1->multibufferUnrollNum == forOp2->multibufferUnrollNum); - setOcc = occ1->getNthParent(occ1->depth - setOcc->depth - 2); - waitOcc = occ2->getNthParent(occ2->depth - waitOcc->depth - 2); - } - } - } - - // - for the case of 
cv-pipelining: - // scope(){ - // op1 - // } {preload=x} - // setOcc - // waitOcc - // scope(){ - // op2 - // } {preload=x} - // - and fix it to be: - // scope(){ - // op1 - // setOcc - // } {preload=x} - // scope(){ - // waitOcc - // op2 - // } {preload=x} - if (options.isCrossCoreMode()) { - assert(setOcc->op != nullptr && waitOcc->op != nullptr); - auto *scopeOp1 = llvm::dyn_cast_if_present(setOcc->op); - auto *scopeOp2 = llvm::dyn_cast_if_present(waitOcc->op); - if (scopeOp1 != nullptr && scopeOp2 != nullptr) { - if (scopeOp1->maxPreloadNum && scopeOp2->maxPreloadNum) { - assert(scopeOp1->maxPreloadNum == scopeOp2->maxPreloadNum); - setOcc = occ1->getNthParent(occ1->depth - setOcc->depth - 2); - waitOcc = occ2->getNthParent(occ2->depth - waitOcc->depth - 2); - } - } - } - - // - check if it's the case of: - // { - // op1 - // setOcc - // ... - // waitOcc - // loop(){} - // setOcc - // ... - // waitOcc - // op2 - // } - // - and fix it to be: - // { - // op1 - // setOcc - // ... - // waitOcc - // placeHolder - // loop(){} - // placeHolder - // setOcc - // ... 
- // waitOcc - // op2 - // } - if (llvm::isa_and_present(setOcc->op)) { - setOcc = getAfterPlaceHolderOcc(setOcc); - } - if (llvm::isa_and_present(waitOcc->op)) { - waitOcc = getBeforePlaceHolderOcc(waitOcc); - } - - return std::make_pair(setOcc, waitOcc); -} - -std::optional> -Solver::getFunctionBlockSetWaitOcc(Occurrence *occ1, Occurrence *occ2) { - assert(occ1 != nullptr && occ2 != nullptr); - auto *parFunctionBlock1 = occ1->getParentOfType(); - auto *parFunctionBlock2 = occ2->getParentOfType(); - if (parFunctionBlock1 == parFunctionBlock2) { - return {}; - } - auto *placeHolderOcc = getScopeBeginPlaceHolderOcc(parFunctionBlock2); - return std::make_pair(placeHolderOcc, occ2); -} - -std::optional> -Solver::getUnlikelyCondSetWaitOcc(Occurrence *occ1, Occurrence *occ2) { - assert(occ1 != nullptr && occ2 != nullptr); - if (options.isCrossCoreMode() && isBackwardSync(occ1, occ2)) { - return {}; - } - if (auto *unlikelyParCondOcc1 = - Occurrence::getUnlikelyParentCondition(occ1)) { - if (!unlikelyParCondOcc1->isProperAncestor(occ2)) { - auto *parentLoopOcc = Occurrence::getParentloop(unlikelyParCondOcc1); - if (parentLoopOcc == nullptr || parentLoopOcc->isProperAncestor(occ2)) { - auto *placeHolderOcc = getScopeEndPlaceHolderOcc( - occ1->getNthParent(occ1->depth - unlikelyParCondOcc1->depth - 1)); - return std::make_pair(occ1, placeHolderOcc); - } - } - } - if (auto *unlikelyParCondOcc2 = - Occurrence::getUnlikelyParentCondition(occ2)) { - if (!unlikelyParCondOcc2->isProperAncestor(occ1)) { - auto *parentLoopOcc = Occurrence::getParentloop(unlikelyParCondOcc2); - if (parentLoopOcc == nullptr || parentLoopOcc->isProperAncestor(occ1)) { - auto *placeHolderOcc = getScopeBeginPlaceHolderOcc( - occ2->getNthParent(occ2->depth - unlikelyParCondOcc2->depth - 1)); - return std::make_pair(placeHolderOcc, occ2); - } - } - } - return {}; -} - -std::pair Solver::getSetWaitOcc(Occurrence *occ1, - Occurrence *occ2) { - if (auto functionBlockOpt = getFunctionBlockSetWaitOcc(occ1, 
occ2)) { - std::tie(occ1, occ2) = functionBlockOpt.value(); - } - if (auto unlikelyOpt = getUnlikelyCondSetWaitOcc(occ1, occ2)) { - std::tie(occ1, occ2) = unlikelyOpt.value(); - } - return getFixedSetWaitOcc(occ1, occ2); -} - -Occurrence *Solver::getBarrierWaitOcc(Occurrence *occ1, Occurrence *occ2) { - auto [setOcc, waitOcc] = getSetWaitOcc(occ1, occ2); - if (!waitOcc->isProperAncestor(occ2)) { - return waitOcc; - } - auto allParents = occ2->getAllParents(); - while (!allParents.empty() && allParents.back()->isProperAncestor(waitOcc)) { - allParents.pop_back(); - } - while (allParents.size() >= 2 && - llvm::isa_and_present(allParents.back()->op)) { - allParents.pop_back(); - assert(llvm::isa_and_present(allParents.back()->op)); - allParents.pop_back(); - } - waitOcc = !allParents.empty() ? allParents.back() : occ2; - return waitOcc; -} - -void Solver::insertBarrierAllBeforeOcc(Occurrence *occ, bool isUseless, - bool isPersistent) { - assert(occ != nullptr); - auto *rwOp = llvm::dyn_cast_if_present(occ->op); - assert(rwOp != nullptr); - auto conflictPair = std::make_unique( - nullptr, nullptr, rwOp, rwOp, occ, occ, - CorePipeInfo(pto::TCoreType::CUBE_OR_VECTOR, pto::PIPE::PIPE_ALL), - CorePipeInfo(pto::TCoreType::CUBE_OR_VECTOR, pto::PIPE::PIPE_ALL), - occ->startIndex, occ->startIndex); - conflictPair->isUseless = isUseless; - auto *normScopeOcc = occ->parentOcc; - assert(normScopeOcc != nullptr); - LLVM_DEBUG(llvm::dbgs() << (isPersistent ? 
"is-persistent " : "") - << occ->op->str(0, false) << ' ' - << conflictPair->str() << '\n';); - if (isPersistent) { - persistentScopeOccChosenConflicts[normScopeOcc].insert(conflictPair.get()); - persistentChosenConflictedPairs.push_back(std::move(conflictPair)); - } else { - insertedBarrierAllBefore[occ->op].insert({occ, isUseless}); - scopeOccChosenConflicts[normScopeOcc].insert(conflictPair.get()); - chosenConflictedPairs.push_back(std::move(conflictPair)); - } -} - -void Solver::insertBarrierAllBeforeOp(OperationBase *op, bool isUseless, - bool isPersistent) { - assert(op != nullptr); - for (auto *occ : opAllOccurrences[op]) { - insertBarrierAllBeforeOcc(occ, isUseless, isPersistent); - isUseless = true; - } -} - -// When barrier-all markers need to be chosen, insert them before all -// occurrences for the chosen op. -void Solver::pickAndInsertABarrierAll() { - assert(!insertedBarrierAllBefore.empty()); - OperationBase *chosenOp = nullptr; - for (auto &[op, vec] : insertedBarrierAllBefore) { - if (vec.empty()) { - continue; - } - if (chosenOp == nullptr || chosenOp->id > op->id) { - chosenOp = op; - } - } - assert(chosenOp != nullptr); - insertBarrierAllBeforeOp(chosenOp, /*isUseless=*/false, - /*isPersistent=*/true); -} - -bool Solver::isBackwardSync(Occurrence *occ1, Occurrence *occ2) { - if (occ1->op->id >= occ2->op->id) { - return true; - } - assert(occ1 != nullptr && occ2 != nullptr); - assert(occ1->op != nullptr && occ2->op != nullptr); - auto [parOcc1, parOcc2] = Occurrence::getLCAPair(occ1, occ2); - auto [parOp1, parOp2] = OperationBase::getLCAPair(occ1->op, occ2->op); - return parOcc1->parentOcc->op != parOp1->parentOp; -} - -bool Solver::reuseCmp(ConflictPair *conflictPair1, - ConflictPair *conflictPair2) { - assert(conflictPair1 != nullptr && conflictPair2 != nullptr); - assert(conflictPair1->op1 != nullptr && conflictPair1->op2 != nullptr); - assert(conflictPair2->op1 != nullptr && conflictPair2->op2 != nullptr); - if (conflictPair1->startIndex != 
conflictPair2->startIndex) { - return conflictPair1->startIndex < conflictPair2->startIndex; - } - if (conflictPair1->endIndex != conflictPair2->endIndex) { - return conflictPair1->endIndex > conflictPair2->endIndex; - } - if (conflictPair1->op1 != conflictPair2->op1) { - return conflictPair1->op1->id > conflictPair2->op1->id; - } - if (conflictPair1->op2 != conflictPair2->op2) { - return conflictPair1->op2->id > conflictPair2->op2->id; - } - return false; -} - -ConflictPair *Solver::getReusableConflictPair( - ConflictPair *conflictPair, - const llvm::DenseSet &conflictPairsSet) { - assert(conflictPair != nullptr); - ConflictPair *ret = nullptr; - for (auto *curConflictPair : conflictPairsSet) { - if (curConflictPair->isBarrier() || curConflictPair->dontReuse) { - continue; - } - if (curConflictPair->op1 != conflictPair->op1 || - curConflictPair->op2 != conflictPair->op2 || - curConflictPair->setCorePipeInfo != conflictPair->setCorePipeInfo || - curConflictPair->waitCorePipeInfo != conflictPair->waitCorePipeInfo) { - continue; - } - if (!checkIntersect(conflictPair, curConflictPair)) { - continue; - } - if (curConflictPair->startIndex >= conflictPair->startIndex) { - continue; - } - if (conflictPair->eventIdNode->eventIdNum < - curConflictPair->eventIdNode->eventIdNum) { - continue; - } - assert(conflictPair->eventIdNode != nullptr); - assert(curConflictPair->eventIdNode != nullptr); - if (conflictPair->eventIdNode->eventIdNum > - curConflictPair->eventIdNode->eventIdNum) { - if (conflictPair->eventIdNode->eventIdNum % - curConflictPair->eventIdNode->eventIdNum) { - continue; - } - } - assert(conflictPair->startIndex <= curConflictPair->endIndex); - assert(curConflictPair->endIndex <= conflictPair->endIndex); - if (ret == nullptr || reuseCmp(ret, curConflictPair)) { - ret = curConflictPair; - } - } - return ret; -} - -bool Solver::reuseConflictPair(ConflictPair *conflictPair, - Occurrence *scopeOcc1, Occurrence *scopeOcc2) { - if (conflictPair->isBarrier()) { - 
return false; - } - if (scopeOcc1->op != scopeOcc2->op) { - return false; - } - if (!barrierAllPairs.empty()) { - return false; - } - - ConflictPair *oldReusedConflictPair = nullptr; - if (conflictPair->isUseless) { - auto it = replacedWithReusableSyncedPairs.find( - {conflictPair->backwardSyncLoopOp, conflictPair->op1, conflictPair->op2, - conflictPair->setCorePipeInfo, conflictPair->waitCorePipeInfo}); - if (it != replacedWithReusableSyncedPairs.end()) { - oldReusedConflictPair = it->second; - } - } - -#ifndef NDEBUG - if (!conflictPair->isUseless) { - auto it = replacedWithReusableSyncedPairs.find( - {conflictPair->backwardSyncLoopOp, conflictPair->op1, conflictPair->op2, - conflictPair->setCorePipeInfo, conflictPair->waitCorePipeInfo}); - assert(it == replacedWithReusableSyncedPairs.end()); - } -#endif - - if (conflictPair->isUseless && oldReusedConflictPair == nullptr) { - return false; - } - - auto corePipeSrc = conflictPair->setCorePipeInfo; - auto corePipeDst = conflictPair->waitCorePipeInfo; - - if (oldReusedConflictPair == nullptr) { - if (!reusePairs.contains({corePipeSrc, corePipeDst}) || - reusePairs[{corePipeSrc, corePipeDst}] <= - reusedPairs[{corePipeSrc, corePipeDst}]) { - return false; - } - } - - assert(reusePairs.contains(std::make_tuple(corePipeSrc, corePipeDst))); - assert(reusePairs[std::make_tuple(corePipeSrc, corePipeDst)] >= - reusedPairs[std::make_tuple(corePipeSrc, corePipeDst)]); - - ConflictPair *opt1 = nullptr; - ConflictPair *opt2 = nullptr; - ConflictPair *opt3 = nullptr; - ConflictPair *opt4 = nullptr; - ConflictPair *opt5 = nullptr; - - auto it1 = scopeOccChosenConflicts.find(scopeOcc1); - auto it2 = scopeOccChosenConflicts.find(scopeOcc2); - auto it3 = scopeOccPairChosenConflicts.find({scopeOcc1, scopeOcc2}); - auto it4 = persistentScopeOccChosenConflicts.find(scopeOcc1); - auto it5 = persistentScopeOccChosenConflicts.find(scopeOcc2); - - if (it1 != scopeOccChosenConflicts.end()) { - opt1 = getReusableConflictPair(conflictPair, 
it1->second); - } - if (it2 != scopeOccChosenConflicts.end()) { - opt2 = getReusableConflictPair(conflictPair, it2->second); - } - if (it3 != scopeOccPairChosenConflicts.end()) { - opt3 = getReusableConflictPair(conflictPair, it3->second); - } - if (it4 != persistentScopeOccChosenConflicts.end()) { - opt4 = getReusableConflictPair(conflictPair, it4->second); - } - if (it5 != persistentScopeOccChosenConflicts.end()) { - opt5 = getReusableConflictPair(conflictPair, it5->second); - } - - ConflictPair *reusableConflictPair = nullptr; - for (auto *opt : {opt1, opt2, opt3, opt4, opt5}) { - if (opt != nullptr) { - if (reusableConflictPair == nullptr || - reuseCmp(reusableConflictPair, opt)) { - reusableConflictPair = opt; - } - } - } - - if (reusableConflictPair == nullptr) { - return false; - } - - DEBUG_WITH_TYPE("gss-sync-solver-reuse", { - llvm::dbgs() << "reuse: " << conflictPair->str() << '\n'; - llvm::dbgs() << "with: " << reusableConflictPair->str() << '\n'; - }); - - assert(reusableConflictPair->startIndex < conflictPair->startIndex); - assert(reusableConflictPair->endIndex <= conflictPair->endIndex); - reusableConflictPair->setOp = conflictPair->setOp; - reusableConflictPair->setOcc = conflictPair->setOcc; - reusableConflictPair->startIndex = conflictPair->startIndex; - - if (!conflictPair->isUseless) { - memorizeReusedSyncedPair(conflictPair, reusableConflictPair); - } - - DEBUG_WITH_TYPE("gss-sync-solver-reuse", { - if (oldReusedConflictPair != nullptr) { - llvm::dbgs() << "old-reuse: " << oldReusedConflictPair->str() << '\n'; - } - }); - - if (oldReusedConflictPair != nullptr) { - assert(oldReusedConflictPair->op1 == reusableConflictPair->op1); - assert(oldReusedConflictPair->op2 == reusableConflictPair->op2); - assert(oldReusedConflictPair->waitOp == reusableConflictPair->waitOp); - } - - if (!conflictPair->isUseless) { - reusedPairs[{corePipeSrc, corePipeDst}] += 1; - } - - return true; -} - -std::unique_ptr & -Solver::getEventIdSolverRef(pto::PIPE pipeSrc, 
pto::PIPE pipeDst) { - if (options.isCrossCoreMode()) { - pipeSrc = pto::PIPE::PIPE_UNASSIGNED; - pipeDst = pto::PIPE::PIPE_UNASSIGNED; - } - auto key = std::make_tuple(pipeSrc, pipeDst); - if (!eventIdSolver.contains(key)) { - int64_t eventIdNumMax = - getHWAvailableEventIdNum(options.syncMode, pipeSrc, pipeDst); - if (options.eventIdNumMax.has_value()) { - eventIdNumMax = std::min(eventIdNumMax, options.eventIdNumMax.value()); - eventIdNumMax = std::max(eventIdNumMax, 1); - } - eventIdSolver[key] = std::make_unique(eventIdNumMax); - } - return eventIdSolver[key]; -} - -bool Solver::checkReuseMultiBufferFlagId(ConflictPair *conflictPair) { - if (options.useDifferentMultiBufferFlagIds) { - return false; - } - if (!conflictPair->isInnerBackward || - conflictPair->eventIdInfo.eventIdNum <= 1 || - conflictPair->movedToOuterLoop) { - return false; - } - auto [setOcc, waitOcc] = - std::tie(conflictPair->setOcc, conflictPair->waitOcc); - auto *backwardSyncLoopOcc = conflictPair->backwardSyncLoopOcc; - assert(backwardSyncLoopOcc != nullptr); - if (auto *parCondOcc1 = setOcc->getParentOfType()) { - if (!parCondOcc1->isProperAncestor(backwardSyncLoopOcc)) { - return false; - } - } - if (auto *parCondOcc2 = waitOcc->getParentOfType()) { - if (!parCondOcc2->isProperAncestor(backwardSyncLoopOcc)) { - return false; - } - } - return true; -} - -void Solver::handleSetWaitConflict(Occurrence *occ1, Occurrence *occ2, - CorePipeInfo corePipeSrc, - CorePipeInfo corePipeDst, - EventIdInfo eventIdInfo, bool isUseless) { - assert(occ1 != nullptr && occ2 != nullptr); - auto *rwOp1 = llvm::dyn_cast_if_present(occ1->op); - auto *rwOp2 = llvm::dyn_cast_if_present(occ2->op); - assert(rwOp1 != nullptr && rwOp2 != nullptr); - assert(corePipeSrc != corePipeDst); - - Loop *parentLCALoopOp{nullptr}; - Occurrence *parentLCALoopOcc{nullptr}; - Occurrence *parentLCALoopBeforePHOcc{nullptr}; - Occurrence *parentLCALoopAfterPHOcc{nullptr}; - auto [setOcc, waitOcc] = getSetWaitOcc(occ1, occ2); - - auto 
[lcaSetOp, lcaWaitOp] = - OperationBase::getLCAPair(setOcc->op, waitOcc->op); - auto *normScopeOcc1 = setOcc->getParentWithOp(lcaSetOp->parentOp); - auto *normScopeOcc2 = waitOcc->getParentWithOp(lcaWaitOp->parentOp); - assert(normScopeOcc1->op == normScopeOcc2->op); - auto *normScopeOp = normScopeOcc1->op; - assert(normScopeOp != nullptr); - assert(normScopeOp->parentOp != nullptr); - - auto conflictPair = std::make_unique( - rwOp1, rwOp2, setOcc->op, waitOcc->op, setOcc, waitOcc, corePipeSrc, - corePipeDst, setOcc->endIndex, waitOcc->startIndex); - assert(conflictPair->startIndex <= conflictPair->endIndex); - - conflictPair->isUseless = isUseless; - conflictPair->isInnerBackward = isBackwardSync(setOcc, waitOcc); - conflictPair->eventIdInfo = eventIdInfo; - - if (conflictPair->isInnerBackward) { - auto [parOcc1, parOcc2] = Occurrence::getLCAPair(occ1, occ2); - assert(parOcc1 != nullptr && parOcc2 != nullptr); - - parentLCALoopOcc = parOcc1->getParentOfType(); - if (moveBackwardSyncPairsToOutmostLoop) { - while (auto *grandParentLoopOcc = - parentLCALoopOcc->getParentOfType()) { - conflictPair->movedToOuterLoop = true; - parentLCALoopOcc = grandParentLoopOcc; - } - } - assert(parentLCALoopOcc != nullptr); - conflictPair->backwardSyncLoopOcc = parentLCALoopOcc; - - parentLCALoopOp = llvm::dyn_cast(parentLCALoopOcc->op); - assert(parentLCALoopOp != nullptr); - conflictPair->backwardSyncLoopOp = parentLCALoopOp; - - parentLCALoopBeforePHOcc = getBeforePlaceHolderOcc(parentLCALoopOcc); - assert(parentLCALoopBeforePHOcc != nullptr); - parentLCALoopAfterPHOcc = getAfterPlaceHolderOcc(parentLCALoopOcc); - assert(parentLCALoopAfterPHOcc != nullptr); - } - - if (auto setWaitOccs = checkAndApplyMmadl0LoopOpt(conflictPair.get(), occ1, - occ2, setOcc, waitOcc)) { - std::tie(setOcc, waitOcc) = setWaitOccs.value(); - conflictPair->updateSetWaitOccs(setOcc, waitOcc); - } - - if (!conflictPair->isInnerBackward || - disabledMultiEventIdPairs.contains({corePipeSrc, corePipeDst})) { 
- conflictPair->eventIdInfo = EventIdInfo(1); - } - if (checkReuseMultiBufferFlagId(conflictPair.get())) { - conflictPair->eventIdInfo.eventIdRepeatNum = - conflictPair->eventIdInfo.eventIdNum; - conflictPair->eventIdInfo.eventIdNum = 1; - } - - auto &curEventIdSolver = getEventIdSolverRef( - conflictPair->setCorePipeInfo.pipe, conflictPair->waitCorePipeInfo.pipe); - curEventIdSolver->pushActionNone(); - - auto checkColorable = [&]() -> bool { - if (curEventIdSolver->isColorable()) { - return true; - } - LLVM_DEBUG(llvm::dbgs() << "will-be-converted-to-barrier-all " - << conflictPair->str() << '\n';); - insertBarrierAllBeforeOp(occ2->op, conflictPair->isUseless, - /*isPersistent=*/false); - barrierAllPairs.insert({corePipeSrc, corePipeDst}); - curEventIdSolver->undoActions(); - return false; - }; - - if (auto *oldEventIdNode = getOldEventIdNodeIfExists(conflictPair.get())) { - conflictPair->eventIdNode = oldEventIdNode; - curEventIdSolver->insertConflictPair(oldEventIdNode, conflictPair.get()); - } else { - bool reversedPriority = false; - if (conflictPair->isInnerBackward) { - if (OperationBase::getParentloop(occ1->op) == normScopeOp->parentOp && - OperationBase::getParentloop(occ2->op) == normScopeOp->parentOp) { - reversedPriority = true; - } - } - conflictPair->eventIdNode = curEventIdSolver->createNode( - conflictPair.get(), conflictPair->eventIdInfo.eventIdNum, - reversedPriority); - } - - if (options.reuseSyncPairToSaveEventIds) { - if (reuseConflictPair(conflictPair.get(), normScopeOcc1, normScopeOcc2)) { - curEventIdSolver->undoActions(); - return; - } - } - - auto intersectingConflictPairs = - getIntersectingConflictPairs(conflictPair.get()); - curEventIdSolver->addConflicts(conflictPair.get(), intersectingConflictPairs); - if (!checkColorable()) { - return; - } - - LLVM_DEBUG({ - llvm::dbgs() << conflictPair->str() << '\n'; - if (parentLCALoopOcc != nullptr) { - llvm::dbgs() << parentLCALoopOcc->op->str(0, false) << '\n'; - } - }); - - llvm::SmallVector, 
Occurrence *>> - extraConflictPairs; - - auto insertExtraConflictPair = [&](Occurrence *setOcc, Occurrence *waitOcc, - Occurrence *parentScope, - bool couldNotRun = false) -> bool { - assert(setOcc != nullptr && waitOcc != nullptr && parentScope != nullptr); - auto extraConflictPair = conflictPair->clone(setOcc, waitOcc); - extraConflictPair->isUseless = true; - extraConflictPair->dontReuse = true; - if (couldNotRun || options.moveOutAndMergeBackwardSyncPairs) { - extraConflictPair->couldNotRun = true; - } - LLVM_DEBUG({ - llvm::dbgs() << "extra-conflict-pair: " << extraConflictPair->str() - << "\n"; - }); - curEventIdSolver->insertConflictPair(conflictPair->eventIdNode, - extraConflictPair.get()); - auto intersectingConflictPairs = - getIntersectingConflictPairs(extraConflictPair.get()); - curEventIdSolver->addConflicts(extraConflictPair.get(), - intersectingConflictPairs); - if (!checkColorable()) { - return false; - } - extraConflictPairs.push_back( - std::make_pair(std::move(extraConflictPair), parentScope)); - return true; - }; - - if (conflictPair->isInnerBackward && conflictPair->eventIdNode != nullptr) { - bool insertOuterBwdConflictPair = false; - if ((conflictPair->eventIdInfo.eventIdNum * - conflictPair->eventIdInfo.eventIdRepeatNum) > 1) { - insertOuterBwdConflictPair = true; - } else if (options.isCrossCoreMode()) { - if (setOcc->parentOcc == nullptr || - setOcc->parentOcc->parentOcc == nullptr || - setOcc->parentOcc->parentOcc->op != parentLCALoopOp) { - insertOuterBwdConflictPair = true; - } else if (waitOcc->parentOcc == nullptr || - waitOcc->parentOcc->parentOcc == nullptr || - waitOcc->parentOcc->parentOcc->op != parentLCALoopOp) { - insertOuterBwdConflictPair = true; - } - } - if (insertOuterBwdConflictPair) { - // insert useless conflictPair to cover the whole loop when having - // multi-eventid backward sync to reserve the eventIds. 
- if (!insertExtraConflictPair(parentLCALoopBeforePHOcc, - parentLCALoopAfterPHOcc, - parentLCALoopOcc->parentOcc)) { - return; - } - } - } - - if (conflictPair->isInnerBackward && conflictPair->eventIdNode != nullptr) { - // insert header/footer useless conflictPairs to reserve the eventIds. - auto *loopOpOcc1 = getFirstIterOcc(waitOcc, normScopeOcc1); - auto *loopOpOcc2 = getLastIterOcc(setOcc, normScopeOcc2); - if (!insertExtraConflictPair(parentLCALoopBeforePHOcc, loopOpOcc1, - parentLCALoopOcc, /*couldNotRun=*/true)) { - return; - } - if (!insertExtraConflictPair(loopOpOcc2, parentLCALoopAfterPHOcc, - parentLCALoopOcc, /*couldNotRun=*/true)) { - return; - } - } - - bool dontInsert = false; - if (conflictPair->isInnerBackward && normScopeOcc1 != normScopeOcc2) { - auto *parCond = OperationBase::getParentCondition(conflictPair->setOp); - if (auto *conditionOp = llvm::dyn_cast_if_present(parCond)) { - if (parentLCALoopOcc->op->isProperAncestor(conditionOp)) { - scopeOccPairChosenConflicts[{normScopeOcc1, normScopeOcc2}].insert( - conflictPair.get()); - dontInsert = true; - } - } - } - if (!dontInsert) { - assert(parentLCALoopOcc != nullptr || normScopeOcc1 == normScopeOcc2); - scopeOccChosenConflicts[normScopeOcc1].insert(conflictPair.get()); - scopeOccChosenConflicts[normScopeOcc2].insert(conflictPair.get()); - } - - memorizeSyncedPair(conflictPair.get()); - chosenConflictedPairs.push_back(std::move(conflictPair)); - - for (auto &[extraConflictPair, parentScope] : extraConflictPairs) { - scopeOccChosenConflicts[parentScope].insert(extraConflictPair.get()); - chosenConflictedPairs.push_back(std::move(extraConflictPair)); - } - - curEventIdSolver->clearActionStack(); -} - -void Solver::handleBarrierConflict(Occurrence *occ1, Occurrence *occ2, - CorePipeInfo corePipeSrc, - CorePipeInfo corePipeDst, bool isUseless) { - assert(occ1 != nullptr && occ2 != nullptr); - auto *rwOp1 = llvm::dyn_cast_if_present(occ1->op); - auto *rwOp2 = llvm::dyn_cast_if_present(occ2->op); 
- assert(rwOp1 != nullptr && rwOp2 != nullptr); - - assert(corePipeSrc == corePipeDst); - if (corePipeSrc.pipe == pto::PIPE::PIPE_S) { - return; - } - if (options.isRegBasedArch) { - if (corePipeSrc.pipe == pto::PIPE::PIPE_V || - corePipeSrc.pipe == pto::PIPE::PIPE_M) { - return; - } - } - auto *waitOcc = getBarrierWaitOcc(occ1, occ2); - - auto conflictPair = std::make_unique( - rwOp1, rwOp2, waitOcc->op, waitOcc->op, waitOcc, waitOcc, corePipeSrc, - corePipeDst, waitOcc->startIndex, waitOcc->startIndex); - conflictPair->isUseless = isUseless; - assert(conflictPair->startIndex <= conflictPair->endIndex); - - LLVM_DEBUG({ llvm::dbgs() << conflictPair->str() << '\n'; }); - - auto *normScopeOcc = waitOcc->parentOcc; - scopeOccChosenConflicts[normScopeOcc].insert(conflictPair.get()); - chosenConflictedPairs.push_back(std::move(conflictPair)); -} - -void Solver::handleUnitFlagConflict(Occurrence *occ1, Occurrence *occ2, - CorePipeInfo corePipeSrc, - CorePipeInfo corePipeDst, - UnitFlagInfo unitFlagInfo, bool isUseless) { - assert(occ1 != nullptr && occ2 != nullptr); - auto *rwOp1 = llvm::dyn_cast_if_present(occ1->op); - auto *rwOp2 = llvm::dyn_cast_if_present(occ2->op); - assert(rwOp1 != nullptr && rwOp2 != nullptr); - assert(corePipeSrc != corePipeDst); - - auto *setOcc = occ1; - auto *waitOcc = occ2; - auto *normScopeOcc1 = setOcc->parentOcc; - auto *normScopeOcc2 = waitOcc->parentOcc; - - auto conflictPair = std::make_unique( - rwOp1, rwOp2, setOcc->op, waitOcc->op, setOcc, waitOcc, corePipeSrc, - corePipeDst, setOcc->endIndex, waitOcc->startIndex); - assert(conflictPair->startIndex <= conflictPair->endIndex); - - conflictPair->isUseless = true; - conflictPair->dontReuse = true; - conflictPair->replacedWithUnitFlag = true; - conflictPair->dontCheckForConflict = true; - conflictPair->isInnerBackward = isBackwardSync(setOcc, waitOcc); - -#ifndef NDEBUG - Occurrence *parentLCALoopOcc{nullptr}; - if (conflictPair->isInnerBackward) { - auto [parOcc1, parOcc2] = 
Occurrence::getLCAPair(occ1, occ2); - assert(parOcc1 != nullptr && parOcc2 != nullptr); - parentLCALoopOcc = Occurrence::getParentloop(parOcc1); - assert(parentLCALoopOcc != nullptr); - } - - LLVM_DEBUG({ - llvm::dbgs() << conflictPair->str() << '\n'; - if (parentLCALoopOcc != nullptr) { - llvm::dbgs() << parentLCALoopOcc->op->str(0, false) << '\n'; - } - }); -#endif - - occ1->unitFlagInfo.merge(unitFlagInfo, occ1, occ2, - /*asSet=*/true, /*asWait=*/false); - occ2->unitFlagInfo.merge(unitFlagInfo, occ1, occ2, - /*asSet=*/false, /*asWait=*/true); - if (!isUseless) { - rwOp1->mergedUnitFlagInfo.merge(unitFlagInfo, /*asSet=*/true, - /*asWait=*/false); - rwOp2->mergedUnitFlagInfo.merge(unitFlagInfo, /*asSet=*/false, - /*asWait=*/true); - } - - scopeOccPairChosenConflicts[{normScopeOcc1, normScopeOcc2}].insert( - conflictPair.get()); - chosenConflictedPairs.push_back(std::move(conflictPair)); -} - -void Solver::handleConflict(Occurrence *occ1, Occurrence *occ2, - RWOperation *rwOp1, RWOperation *rwOp2, - CorePipeInfo corePipeSrc, CorePipeInfo corePipeDst, - EventIdInfo eventIdInfo, bool isUseless) { - if (!checkGraphConflict(occ1, occ2, corePipeSrc, corePipeDst, eventIdInfo)) { - return; - } - LLVM_DEBUG({ - llvm::dbgs() << "conflict found: " << "eventIdNum(" - << eventIdInfo.eventIdNum << ")\n"; - llvm::dbgs() << occ1->syncIrIndex << ' ' << occ1->startIndex << ' ' - << occ1->endIndex << ' ' << rwOp1->str(0, false) << '\n'; - llvm::dbgs() << occ2->syncIrIndex << ' ' << occ2->startIndex << ' ' - << occ2->endIndex << ' ' << rwOp2->str(0, false) << '\n'; - }); - if (corePipeSrc == corePipeDst) { - handleBarrierConflict(occ1, occ2, corePipeSrc, corePipeDst, isUseless); - } else if (auto unitFlagInfo = checkUnitFlagPatterns(occ1, occ2)) { - handleUnitFlagConflict(occ1, occ2, corePipeSrc, corePipeDst, - unitFlagInfo.value(), isUseless); - } else { - handleSetWaitConflict(occ1, occ2, corePipeSrc, corePipeDst, eventIdInfo, - isUseless); - } -} - -void Solver::calcAllEventIds() 
{ - for (auto &[pipes, eventIdSolver] : eventIdSolver) { - assert(eventIdSolver != nullptr); - - [[maybe_unused]] auto result = - eventIdSolver->shrinkEventIdMaxToEventIdNum(); - assert(llvm::succeeded(result)); - assert(eventIdSolver->isColorable()); - } -} - -void Solver::collectBackwardSyncEventIds() { - LLVM_DEBUG(llvm::dbgs() << "collectBackwardSyncEventIds\n";); - for (auto &conflictPair : chosenConflictedPairs) { - if (!conflictPair->isUseless && conflictPair->isInnerBackward && - conflictPair->eventIdNode != nullptr) { - LLVM_DEBUG(llvm::dbgs() << " " << conflictPair->str() << "\n";); - for (auto eventId : conflictPair->eventIdNode->getEventIds()) { - auto &e = backwardSyncEvents[conflictPair->backwardSyncLoopOp] - [{conflictPair->setCorePipeInfo, - conflictPair->waitCorePipeInfo}][eventId]; - e = std::max(e, conflictPair->eventIdInfo.eventIdRepeatNum); - } - } - } -} - -void Solver::resetAndBuildSetWaitOpIndex(const SyncMap &syncMapBefore, - const SyncMap &syncMapAfter) { - globalSetWaitIndex = 0; - setWaitStartIndex.clear(); - setWaitEndIndex.clear(); - setWaitStartIndexInclusive.clear(); - setWaitEndIndexInclusive.clear(); - setWaitFlagOpsIndex.clear(); - collectSetWaitOpsIndexes(funcIr.get(), syncMapBefore, syncMapAfter); -} - -std::set> & -Solver::getSetWaitOpsIndexRef(pto::PIPE pipeSrc, pto::PIPE pipeDst, - int64_t eventId) { - auto key = std::make_tuple(pipeSrc, pipeDst, eventId); - return setWaitFlagOpsIndex[key]; -} - -// Collect indices for all Set/Wait ops to facilitate merging decisions. 
-void Solver::collectSetWaitOpsIndexes(OperationBase *op, - const SyncMap &syncMapBefore, - const SyncMap &syncMapAfter) { - assert(op != nullptr); - setWaitStartIndexInclusive[op] = globalSetWaitIndex++; - if (syncMapBefore.count(op)) { - auto *it = syncMapBefore.find(op); - assert(it != syncMapBefore.end()); - for (auto &syncOp : it->second) { - if (auto *setWaitOp = llvm::dyn_cast(syncOp.get())) { - for (auto eventId : setWaitOp->eventIds) { - auto &index = getSetWaitOpsIndexRef(setWaitOp->pipeSrc, - setWaitOp->pipeDst, eventId); - index.insert({globalSetWaitIndex++, setWaitOp}); - } - } - } - } - setWaitStartIndex[op] = globalSetWaitIndex++; - if (auto *scopeOp = llvm::dyn_cast(op)) { - for (auto &childOp : scopeOp->body) { - collectSetWaitOpsIndexes(childOp.get(), syncMapBefore, syncMapAfter); - } - } - setWaitEndIndex[op] = globalSetWaitIndex++; - if (syncMapAfter.count(op)) { - auto *it = syncMapAfter.find(op); - assert(it != syncMapAfter.end()); - for (auto &syncOp : it->second) { - if (auto *setWaitOp = llvm::dyn_cast(syncOp.get())) { - for (auto eventId : setWaitOp->eventIds) { - auto &index = getSetWaitOpsIndexRef(setWaitOp->pipeSrc, - setWaitOp->pipeDst, eventId); - index.insert({globalSetWaitIndex++, setWaitOp}); - } - } - } - } - setWaitEndIndexInclusive[op] = globalSetWaitIndex++; -} - -bool Solver::checkBackwardSyncEventsContains(OperationBase *op, - CorePipeInfo corePipeSrc, - CorePipeInfo corePipeDst, - int64_t eventId) { - auto *it1 = backwardSyncEvents.find(op); - if (it1 == backwardSyncEvents.end()) { - return false; - } - auto it2 = it1->second.find({corePipeSrc, corePipeDst}); - if (it2 == it1->second.end()) { - return false; - } - return it2->second.contains(eventId); -} - -bool Solver::checkBackwardSyncEventsContainsAfterMerge( - OperationBase *op, CorePipeInfo corePipeSrc, CorePipeInfo corePipeDst) { - auto *it1 = backwardSyncEventsAfterMerge.find(op); - if (it1 == backwardSyncEventsAfterMerge.end()) { - return false; - } - return 
it1->second.contains({corePipeSrc, corePipeDst}); -} - -// Check whether a backward-sync event id can be merged at scope level. -bool Solver::checkMergeable(Scope *scopeOp, CorePipeInfo corePipeSrc, - CorePipeInfo corePipeDst, int64_t eventId, - bool shouldBeUsedAtleastOnce) { - auto &index = - getSetWaitOpsIndexRef(corePipeSrc.pipe, corePipeDst.pipe, eventId); - if (shouldBeUsedAtleastOnce) { - auto it = index.lower_bound({setWaitStartIndexInclusive[scopeOp], nullptr}); - bool usedAtleastOnce = - it != index.end() && it->first < setWaitEndIndexInclusive[scopeOp]; - if (!usedAtleastOnce) { - return false; - } - } - { - auto it1 = - index.lower_bound({setWaitStartIndexInclusive[scopeOp], nullptr}); - auto it2 = index.lower_bound({setWaitEndIndex[scopeOp], nullptr}); - bool usedBefore = - it1 != index.end() && it1->first < setWaitStartIndex[scopeOp]; - bool usedAfter = - it2 != index.end() && it2->first < setWaitEndIndexInclusive[scopeOp]; - if (usedBefore || usedAfter) { - return false; - } - } - if (auto *conditionOp = llvm::dyn_cast(scopeOp)) { - if (!conditionOp->hasFalseScope()) { - return false; - } - return checkMergeable(conditionOp->getTrueScope(), corePipeSrc, corePipeDst, - eventId, true) && - checkMergeable(conditionOp->getFalseScope(), corePipeSrc, - corePipeDst, eventId, true); - } - if (auto *loopOp = llvm::dyn_cast(scopeOp)) { - for (auto &childOp : loopOp->body) { - if (auto *childScopeOp = llvm::dyn_cast(childOp.get())) { - if (!checkMergeable(childScopeOp, corePipeSrc, corePipeDst, eventId, - false)) { - return false; - } - } - } - for (auto &childOp : loopOp->body) { - if (auto *childScopeOp = llvm::dyn_cast(childOp.get())) { - if (checkMergeable(childScopeOp, corePipeSrc, corePipeDst, eventId, - true)) { - return true; - } - } - } - return false; - } - for (auto &childOp : scopeOp->body) { - auto it1 = - index.lower_bound({setWaitStartIndexInclusive[childOp.get()], nullptr}); - auto it2 = index.lower_bound({setWaitEndIndex[childOp.get()], 
nullptr}); - bool usedAtleastOnce = it1 != index.end() && - it1->first < setWaitEndIndexInclusive[childOp.get()]; - if (!usedAtleastOnce) { - continue; - } - bool before = - it1 != index.end() && it1->first < setWaitStartIndex[childOp.get()]; - bool after = it2 != index.end() && - it2->first < setWaitEndIndexInclusive[childOp.get()]; - if (before || after) { - return false; - } - if (!checkBackwardSyncEventsContains(childOp.get(), corePipeSrc, - corePipeDst, eventId)) { - return false; - } - if (checkBackwardSyncEventsContainsAfterMerge(childOp.get(), corePipeSrc, - corePipeDst)) { - return false; - } - } - return true; -} - -// Attempt to merge backward sync events across children and prune duplicates. -void Solver::mergeBackwardSyncEventIds(OperationBase *op) { - auto *scopeOp = llvm::dyn_cast_if_present(op); - if (scopeOp == nullptr) { - return; - } - for (auto &op : scopeOp->body) { - mergeBackwardSyncEventIds(op.get()); - } - - if (llvm::isa_and_present(op)) { - return; - } - if (llvm::isa_and_present(op->parentOp)) { - return; - } - - auto *conditionOp = llvm::dyn_cast(op); - if (conditionOp != nullptr) { - if (!conditionOp->hasFalseScope()) { - return; - } - } - - llvm::DenseSet> toBeErased; - - llvm::SmallVector coreTypes; - if (options.isCrossCoreMode()) { - coreTypes = {pto::TCoreType::VECTOR, pto::TCoreType::CUBE}; - } else { - coreTypes = {pto::TCoreType::CUBE_OR_VECTOR}; - } - size_t pipeNumMax = static_cast(pto::PIPE::PIPE_NUM); - const int64_t eventIdMax = getHWAvailableEventIdNum(options.syncMode); - - for (int64_t eventId = 0; eventId < eventIdMax; ++eventId) { - for (auto coreSrc : coreTypes) { - for (auto coreDst : coreTypes) { - for (size_t pipeSrcInt = 0; pipeSrcInt < pipeNumMax; pipeSrcInt++) { - for (size_t pipeDstInt = 0; pipeDstInt < pipeNumMax; pipeDstInt++) { - auto pipeSrc = static_cast(pipeSrcInt); - auto pipeDst = static_cast(pipeDstInt); - auto corePipeSrc = CorePipeInfo(coreSrc, pipeSrc); - auto corePipeDst = CorePipeInfo(coreDst, 
pipeDst); - if (checkBackwardSyncEventsContains(scopeOp, corePipeSrc, - corePipeDst, eventId)) { - continue; - } - if (checkMergeable(scopeOp, corePipeSrc, corePipeDst, eventId)) { - toBeErased.insert({corePipeSrc, corePipeDst, eventId}); - backwardSyncEvents[scopeOp][{corePipeSrc, corePipeDst}].insert( - {eventId, 1}); - } - } - } - } - } - } - - if (isa(scopeOp)) { - for (auto &op : scopeOp->body) { - if (auto *block = llvm::dyn_cast(op.get())) { - for (auto &childOp : block->body) { - if (auto *childScopeOp = llvm::dyn_cast(childOp.get())) { - for (auto [corePipeSrc, corePipeDst, eventId] : toBeErased) { - if (checkBackwardSyncEventsContains(childScopeOp, corePipeSrc, - corePipeDst, eventId)) { - auto key = std::make_tuple(corePipeSrc, corePipeDst); - backwardSyncEvents[childScopeOp][key].erase(eventId); - if (backwardSyncEvents[childScopeOp][key].empty()) { - backwardSyncEvents[childScopeOp].erase(key); - } - } - } - } - } - } - } - } else { - for (auto &childOp : scopeOp->body) { - if (auto *childScopeOp = llvm::dyn_cast(childOp.get())) { - for (auto [corePipeSrc, corePipeDst, eventId] : toBeErased) { - if (checkBackwardSyncEventsContains(childScopeOp, corePipeSrc, - corePipeDst, eventId)) { - auto key = std::make_tuple(corePipeSrc, corePipeDst); - backwardSyncEvents[childScopeOp][key].erase(eventId); - if (backwardSyncEvents[childScopeOp][key].empty()) { - backwardSyncEvents[childScopeOp].erase(key); - } - } - } - } - } - } -} - -void Solver::mergeBackwardSyncPairs(SyncMap &syncMapBefore, - SyncMap &syncMapAfter) { - if (!options.moveOutAndMergeBackwardSyncPairs) { - return; - } - if (options.isIntraCoreMode()) { - resetAndBuildSetWaitOpIndex(syncMapBefore, syncMapAfter); - auto *scopeOp = llvm::dyn_cast(funcIr.get()); - assert(scopeOp != nullptr && scopeOp->body.front() != nullptr); - mergeBackwardSyncEventIds(scopeOp->body.front().get()); - } -} - -SyncBeforeAfterMap Solver::getBeforeAfterSyncMaps() { - calcAllEventIds(); - SyncMap syncMapBefore, 
syncMapAfter; - std::vector conflictPairs; - for (auto &conflictPair : chosenConflictedPairs) { - conflictPairs.push_back(conflictPair.get()); - } - for (auto &conflictPair : persistentChosenConflictedPairs) { - conflictPairs.push_back(conflictPair.get()); - } - - for (auto *conflictPair : conflictPairs) { - if (conflictPair->isUseless) { - continue; - } - if (conflictPair->replacedWithUnitFlag) { - continue; - } - assert(conflictPair->setOp != nullptr && conflictPair->waitOp != nullptr); - if (conflictPair->isBarrier()) { - auto barrierOp = std::make_unique( - conflictPair->waitOp->op, conflictPair->waitOp->parentOp, - conflictPair->waitCorePipeInfo.pipe); - LLVM_DEBUG(barrierOp->debugId = conflictPair->id); - syncMapBefore[conflictPair->waitOp].push_back(std::move(barrierOp)); - } else { - assert(conflictPair->eventIdNode != nullptr); - auto setOp = std::make_unique( - conflictPair->setOp->op, conflictPair->setOp->parentOp, - conflictPair->eventIdNode->getEventIds(), - conflictPair->setCorePipeInfo.pipe, - conflictPair->waitCorePipeInfo.pipe); - auto waitOp = std::make_unique( - conflictPair->waitOp->op, conflictPair->waitOp->parentOp, - conflictPair->eventIdNode->getEventIds(), - conflictPair->setCorePipeInfo.pipe, - conflictPair->waitCorePipeInfo.pipe); - if (options.isCrossCoreMode()) { - setOp->coreType = conflictPair->setCorePipeInfo.coreType; - waitOp->coreType = conflictPair->waitCorePipeInfo.coreType; - } - setOp->eventIdInfo = conflictPair->eventIdInfo; - waitOp->eventIdInfo = conflictPair->eventIdInfo; - setOp->checkLastIter = conflictPair->setOnLastIterOnly; - waitOp->checkFirstIter = conflictPair->waitOnFirstIterOnly; - LLVM_DEBUG({ - setOp->debugId = conflictPair->id; - waitOp->debugId = conflictPair->id; - }); - assert(setOp != nullptr && waitOp != nullptr); - syncMapAfter[conflictPair->setOp].push_back(std::move(setOp)); - syncMapBefore[conflictPair->waitOp].push_front(std::move(waitOp)); - } - } - - collectBackwardSyncEventIds(); - 
mergeBackwardSyncPairs(syncMapBefore, syncMapAfter); - - for (auto &[op, mp] : backwardSyncEvents) { - if (mp.empty()) { - continue; - } - auto *scopeOp = llvm::dyn_cast(op); - assert(scopeOp != nullptr); - for (auto [setWaitCorePipes, eventIdsMp] : mp) { - if (eventIdsMp.empty()) { - continue; - } - llvm::SmallVector eventIds; - for (auto [eventId, repeatNum] : eventIdsMp) { - llvm::SmallVector curEventIds(repeatNum, eventId); - llvm::append_range(eventIds, curEventIds); - } - llvm::sort(eventIds); - auto [corePipeSrc, corePipeDst] = setWaitCorePipes; - auto setOp = - std::make_unique(scopeOp->op, scopeOp->parentOp, eventIds, - corePipeSrc.pipe, corePipeDst.pipe); - auto waitOp = - std::make_unique(scopeOp->op, scopeOp->parentOp, eventIds, - corePipeSrc.pipe, corePipeDst.pipe); - setOp->allAtOnce = true; - waitOp->allAtOnce = true; - if (options.isCrossCoreMode()) { - setOp->coreType = corePipeSrc.coreType; - waitOp->coreType = corePipeDst.coreType; - } - assert(setOp != nullptr && waitOp != nullptr); - syncMapBefore[scopeOp].push_back(std::move(setOp)); - syncMapAfter[scopeOp].push_front(std::move(waitOp)); - } - } - return std::make_pair(std::move(syncMapBefore), std::move(syncMapAfter)); -} - -void Solver::processConflict(Occurrence *occ1, Occurrence *occ2, - RWOperation *rwOp1, RWOperation *rwOp2, - bool isUseless) { - for (auto [corePipeSrc, corePipeDst] : checkMemoryConflicts(rwOp1, rwOp2)) { - if (options.alwaysUsePipeSAsWaitingPipe) { - corePipeDst.pipe = pto::PIPE::PIPE_S; - } - auto eventIdInfo = - getEventIdInfo(occ1, occ2, rwOp1, rwOp2, corePipeSrc, corePipeDst); - handleConflict(occ1, occ2, rwOp1, rwOp2, corePipeSrc, corePipeDst, - eventIdInfo, isUseless); - } -} - -// Main processing loop that iterates processingOrders and attempts to -// discover and record conflicts. 
-void Solver::processOrders() { - for (auto &[occ1, occ2, rwOp1, rwOp2, isUseless] : processingOrders) { - assert(occ1 != occ2); - assert(occ1->syncIrIndex < occ2->syncIrIndex); - if (checkVisited(occ1, occ2)) { - assert(false && "expected to not check a pair more than once."); - continue; - } - if (checkImpossibleOccPair(occ1, occ2) || checkAlreadySynced(occ1, occ2) || - skipMMad1DecomposedLoopOpt(occ1, occ2) || - checkSkipParallelLoop(occ1, occ2) || - checkSkipCrossCorePair(occ1, occ2)) { - continue; - } - DEBUG_WITH_TYPE("gss-sync-solver-checking", { - llvm::dbgs() << "checking: " << (isUseless ? "is-useless\n" : "\n"); - llvm::dbgs() << occ1->syncIrIndex << ' ' << occ1->startIndex << ' ' - << occ1->endIndex << ' ' << occ1->op->str(0, false) << '\n'; - llvm::dbgs() << occ2->syncIrIndex << ' ' << occ2->startIndex << ' ' - << occ2->endIndex << ' ' << occ2->op->str(0, false) << '\n'; - }); - if (checkAlreadySyncedWithUnitFlag(occ1, occ2)) { - continue; - } - processConflict(occ1, occ2, rwOp1, rwOp2, isUseless); - } -} - -void Solver::insertMergedBackwardSyncPairs() { - for (auto &[scopeOp, st] : backwardSyncEventsAfterMerge) { - for (auto &corePipeInfoPair : st) { - auto [corePipeSrc, corePipeDst] = corePipeInfoPair; - for (auto *scopeOcc : opAllOccurrences[scopeOp]) { - auto *parentScopeOcc = scopeOcc->parentOcc; - assert(parentScopeOcc != nullptr); - Occurrence *setOcc = nullptr; - Occurrence *waitOcc = nullptr; - auto startIndex = scopeOcc->startIndex; - auto endIndex = scopeOcc->endIndex; - if (isa(scopeOp)) { - setOcc = getBeforePlaceHolderOcc(scopeOcc); - waitOcc = getAfterPlaceHolderOcc(scopeOcc); - startIndex = setOcc->endIndex; - endIndex = waitOcc->startIndex; - } - auto conflictPair = std::make_unique( - nullptr, nullptr, nullptr, nullptr, setOcc, waitOcc, corePipeSrc, - corePipeDst, startIndex, endIndex); - assert(conflictPair->startIndex <= conflictPair->endIndex); - conflictPair->isUseless = true; - conflictPair->dontReuse = true; - 
conflictPair->dontCheckForConflict = true; - conflictPair->couldNotRun = false; // notice this - LLVM_DEBUG({ - llvm::dbgs() << "consider-merged-backward-pair: " - << scopeOp->str(0, false) << ' ' << conflictPair->str() - << "\n"; - }); - scopeOccChosenConflicts[parentScopeOcc].insert(conflictPair.get()); - chosenConflictedPairs.push_back(std::move(conflictPair)); - } - } - } -} - -llvm::LogicalResult Solver::considerOuterBackwardSyncPairs() { - if (!options.considerOuterBackwardSyncPairs) { - return llvm::failure(); - } - bool backwardPairsPositionChanged = false; - for (auto &[scopeOp, st] : backwardSyncEventsAfterMerge) { - SmallVector> toBeErased; - for (auto &corePipeInfoPair : st) { - if (!backwardSyncEvents.contains(scopeOp) || - !backwardSyncEvents[scopeOp].contains(corePipeInfoPair)) { - toBeErased.push_back(corePipeInfoPair); - } - } - if (!toBeErased.empty()) { - backwardPairsPositionChanged = true; - for (auto &corePipeInfoPair : toBeErased) { - st.erase(corePipeInfoPair); - } - } - } - int chosenOpsDepth = -1; - SmallVector chosenOps; - for (auto &[scopeOp, mp] : backwardSyncEvents) { - if (backwardSyncEventsAfterMerge.contains(scopeOp)) { - continue; - } - int scopeOpDepth = scopeOp->getDepth(); - if (chosenOpsDepth == scopeOpDepth) { - chosenOps.push_back(scopeOp); - } else if (chosenOpsDepth == -1 || chosenOpsDepth < scopeOpDepth) { - chosenOps.clear(); - chosenOps.push_back(scopeOp); - chosenOpsDepth = scopeOpDepth; - } - } - if (chosenOps.empty()) { - return llvm::failure(); - } - bool newPairIsInserted = false; - for (auto *chosenOp : chosenOps) { - for (auto &[corePipeInfoPair, eventIdsMp] : backwardSyncEvents[chosenOp]) { - assert(!eventIdsMp.empty()); - if (!eventIdsMp.empty()) { - auto [it, isInserted] = - backwardSyncEventsAfterMerge[chosenOp].insert(corePipeInfoPair); - newPairIsInserted |= isInserted; - } - } - } - return llvm::success(backwardPairsPositionChanged || newPairIsInserted); -} - -llvm::LogicalResult 
Solver::reuseSyncPairToSaveEventIds() { - if (!options.reuseSyncPairToSaveEventIds || barrierAllPairs.empty()) { - return llvm::failure(); - } - bool limitReached = true; - for (auto [corePipeSrc, corePipeDst] : barrierAllPairs) { - if (reusePairs[{corePipeSrc, corePipeDst}] < maxReuseNum) { - if (reusePairs[{corePipeSrc, corePipeDst}] <= - reusedPairs[{corePipeSrc, corePipeDst}]) { - reusePairs[{corePipeSrc, corePipeDst}] += 1; - limitReached = false; - } - } - } - DEBUG_WITH_TYPE("gss-sync-solver-reuse", { - llvm::dbgs() << "reusePairs: \n"; - for (auto [pipeCorePairs, cnt] : reusePairs) { - llvm::dbgs() << get<0>(pipeCorePairs).pipe << ' ' - << get<1>(pipeCorePairs).pipe << ' ' << cnt << '\n'; - } - }); - return llvm::success(!limitReached); -} - -llvm::LogicalResult Solver::disableMultiEventIdForBarrierAllPairs() { - if (!options.disableMultiEventIdForBarrierAllPairs || - barrierAllPairs.empty()) { - return llvm::failure(); - } - bool newPairIsInserted = false; - for (auto corePipeInfoPair : barrierAllPairs) { - auto [it, isInserted] = disabledMultiEventIdPairs.insert(corePipeInfoPair); - newPairIsInserted |= isInserted; - } - LLVM_DEBUG({ - if (newPairIsInserted) { - llvm::dbgs() << "disabled-multi-event-id-pairs: \n"; - for (auto &[corePipeSrc, corePipeDst] : disabledMultiEventIdPairs) { - llvm::dbgs() << corePipeSrc.coreType << ' ' << corePipeSrc.pipe << ' ' - << corePipeDst.coreType << ' ' << corePipeDst.pipe << '\n'; - } - } - }); - return llvm::success(newPairIsInserted); -} - -llvm::LogicalResult Solver::tryMovingOutBackwardSyncPairsToOuterLoops() { - if (!options.moveOutAndMergeBackwardSyncPairs || !options.isCrossCoreMode() || - dontMoveBackwardSyncPairsToOutmostLoop) { - return llvm::failure(); - } - if (!moveBackwardSyncPairsToOutmostLoop) { - moveBackwardSyncPairsToOutmostLoop = true; - return llvm::success(); - } - if (!barrierAllPairs.empty()) { - moveBackwardSyncPairsToOutmostLoop = false; - dontMoveBackwardSyncPairsToOutmostLoop = true; - return 
llvm::success(); - } - return llvm::failure(); -} - -// High-level solve orchestration with multiple passes and optional merging -// iterations. -llvm::LogicalResult Solver::runSolver(bool enableOpts1, bool enableOpts2) { - reset(/*resetEventIdRanOutOpts=*/true); - - int64_t runNum = 0; - while (runNum++ < maxRunNum) { - LLVM_DEBUG(llvm::dbgs() << "runNum: " << runNum << '\n'); - - reset(); - insertMergedBackwardSyncPairs(); - processOrders(); - - if (llvm::succeeded(tryMovingOutBackwardSyncPairsToOuterLoops())) { - continue; - } - - if (enableOpts1) { - if (options.considerOuterBackwardSyncPairs) { - getBeforeAfterSyncMaps(); - if (llvm::succeeded(considerOuterBackwardSyncPairs())) { - continue; - } - if (!barrierAllPairs.empty()) { - backwardSyncEventsAfterMerge.clear(); - } - } - } - - if (enableOpts2) { - if (!barrierAllPairs.empty()) { - if (llvm::succeeded(reuseSyncPairToSaveEventIds())) { - continue; - } - if (llvm::succeeded(disableMultiEventIdForBarrierAllPairs())) { - continue; - } - } - } - - if (!barrierAllPairs.empty()) { - pickAndInsertABarrierAll(); - reset(/*resetEventIdRanOutOpts=*/true); - continue; - } - break; - } - - reset(); - insertMergedBackwardSyncPairs(); - processOrders(); - - return llvm::success(runNum < maxRunNum); -} - -void Solver::solve() { - if (llvm::succeeded(runSolver())) { - return; - } - if (!options.isTestMode()) { - if (llvm::succeeded(runSolver(/*enableOpts1=*/false))) { - return; - } - if (llvm::succeeded( - runSolver(/*enableOpts1=*/false, /*enableOpts2=*/false))) { - return; - } - } - llvm_unreachable("GSS: runSolver() failed."); -} +#include "SyncSolver.def" diff --git a/lib/PTO/Transforms/GraphSyncSolver/SyncSolver.def b/lib/PTO/Transforms/GraphSyncSolver/SyncSolver.def new file mode 100644 index 000000000..23a4032a6 --- /dev/null +++ b/lib/PTO/Transforms/GraphSyncSolver/SyncSolver.def @@ -0,0 +1,2576 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===--------- SyncSolver.cpp ------- Graph Sync Solver -------------------===// +//===----------------------------------------------------------------------===// + +#include "PTO/Transforms/GraphSyncSolver/SyncSolver.h" +#include "PTO/Transforms/GraphSyncSolver/GraphSolver.h" +#include "PTO/Transforms/GraphSyncSolver/MemInfo.h" +#include "PTO/Transforms/GraphSyncSolver/SyncSolverIR.h" +#include "PTO/Transforms/GraphSyncSolver/Utility.h" + +#include "PTO/IR/PTO.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Value.h" +#include "mlir/Interfaces/LoopLikeInterface.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/LogicalResult.h" +#include +#include +#include +#include +#include +#include +#include + +#define DEBUG_TYPE "PTO-gss-solver" + +using namespace mlir; +using namespace pto::syncsolver; + +// Reset per-pass bookkeeping to start fresh. 
+void Solver::reset(bool resetEventIdRanOutOpts) { + if (resetEventIdRanOutOpts) { + reusePairs.clear(); + disabledMultiEventIdPairs.clear(); + backwardSyncEventsAfterMerge.clear(); + moveBackwardSyncPairsToOutmostLoop = false; + dontMoveBackwardSyncPairsToOutmostLoop = false; + } + skipOcc.clear(); + syncedPairs.clear(); + processedOccPairs.clear(); + chosenConflictedPairs.clear(); + scopeOccChosenConflicts.clear(); + scopeOccPairChosenConflicts.clear(); + backwardSyncEvents.clear(); + replacedWithReusableSyncedPairs.clear(); + reusedPairs.clear(); + barrierAllPairs.clear(); + insertedBarrierAllBefore.clear(); + eventIdSolver.clear(); + resetUnitFlag(); +} + +void Solver::resetUnitFlag() { + for (auto *rwOp : unitFlagFeaturedOps) { + rwOp->mergedUnitFlagInfo.reset(); + for (auto *occ : opAllOccurrences[rwOp]) { + occ->unitFlagInfo.reset(); + } + } +} + +// Helpers to find first/last iteration occurrences relative to parent +// occurrences. +Occurrence *Solver::getFirstIterOcc(Occurrence *occ, Occurrence *parOcc) { + assert(occ != nullptr && parOcc != nullptr); + if (parOcc->depth + 1 < occ->depth) { + auto *newParOcc = getFirstIterOcc( + occ->getNthParent(occ->depth - parOcc->depth - 1), parOcc); + return getFirstIterOcc(occ, newParOcc); + } + auto *it = + std::find_if(parOcc->childOccs.begin(), parOcc->childOccs.end(), + [occ](Occurrence *curOcc) { return occ->op == curOcc->op; }); + assert(it != parOcc->childOccs.end()); + return *it; +} + +Occurrence *Solver::getLastIterOcc(Occurrence *occ, Occurrence *parOcc) { + assert(occ != nullptr && parOcc != nullptr); + if (parOcc->depth + 1 < occ->depth) { + auto *newParOcc = getLastIterOcc( + occ->getNthParent(occ->depth - parOcc->depth - 1), parOcc); + return getLastIterOcc(occ, newParOcc); + } + auto it = + std::find_if(parOcc->childOccs.rbegin(), parOcc->childOccs.rend(), + [occ](Occurrence *curOcc) { return occ->op == curOcc->op; }); + assert(it != parOcc->childOccs.rend()); + return *it; +} + +bool 
Solver::checkSkipCrossCorePair(Occurrence *occ1, Occurrence *occ2) { + if (!options.isCrossCoreMode()) { + return false; + } + auto *rwOp1 = llvm::dyn_cast(occ1->op); + auto *rwOp2 = llvm::dyn_cast(occ2->op); + assert(rwOp1 != nullptr && rwOp2 != nullptr); + assert(rwOp1->coreType != pto::TCoreType::CUBE_OR_VECTOR); + assert(rwOp2->coreType != pto::TCoreType::CUBE_OR_VECTOR); + if (rwOp1->coreType == rwOp2->coreType) { + return true; + } + if (rwOp1->coreType == pto::TCoreType::CUBE_AND_VECTOR) { + return true; + } + return false; +} + +bool Solver::checkSkipParallelLoop(Occurrence *occ1, Occurrence *occ2) { + if (!isBackwardSync(occ1, occ2)) { + return false; + } + auto [parOcc1, parOcc2] = Occurrence::getLCAPair(occ1, occ2); + assert(parOcc1 != nullptr && parOcc2 != nullptr); + auto *parentLCALoopOcc = Occurrence::getParentloop(parOcc1); + assert(parentLCALoopOcc != nullptr); + auto *parentLCALoopOp = llvm::cast(parentLCALoopOcc->op); + return parentLCALoopOp->isParallel; +} + +// Check whether occurrences belong to impossible (if-else) pairing. +bool Solver::checkImpossibleOccPair(Occurrence *occ1, Occurrence *occ2) { + assert(occ1 != nullptr && occ2 != nullptr); + if (occ1->op == occ2->op) { + return false; + } + auto [parOcc1, parOcc2] = Occurrence::getLCAPair(occ1, occ2); + assert(parOcc1 != nullptr && parOcc2 != nullptr); + bool isIfElseSituation = + parOcc1->parentOcc != nullptr && + parOcc1->parentOcc == parOcc2->parentOcc && + llvm::isa_and_present(parOcc1->parentOcc->op); + return isIfElseSituation; +} + +// Detect whether occ1 and occ2 have already been covered by an earlier sync. 
+bool Solver::checkAlreadySynced(Occurrence *occ1, Occurrence *occ2) { + assert(occ1 != nullptr && occ2 != nullptr); + assert(occ1->op != nullptr && occ2->op != nullptr); + + auto [parOcc1, parOcc2] = Occurrence::getLCAPair(occ1, occ2); + assert(parOcc1 != nullptr && parOcc2 != nullptr); + assert(parOcc1->parentOcc != nullptr && parOcc2->parentOcc != nullptr); + + auto [parOp1, parOp2] = OperationBase::getLCAPair(occ1->op, occ2->op); + assert(parOp1 != nullptr && parOp2 != nullptr); + assert(parOp1->parentOp != nullptr && parOp2->parentOp != nullptr); + + auto *parentLoop = OperationBase::getParentloop(parOcc1->op); + auto *curLoop = OperationBase::getParentloop(parOp1); + if (parentLoop == nullptr || parentLoop == curLoop) { + return false; + } + + assert(curLoop != nullptr); + assert(parentLoop->isProperAncestor(curLoop)); + while (curLoop != parentLoop) { + if (!llvm::cast(curLoop)->isParallel) { + return true; + } + curLoop = OperationBase::getParentloop(curLoop); + assert(curLoop != nullptr); + } + return false; +} + +// Unit-flag reuse check between two RWOperations. 
+bool Solver::checkAlreadySyncedWithUnitFlag(Occurrence *occ1, + Occurrence *occ2) { + assert(occ1 != nullptr && occ2 != nullptr); + if (!options.enableUnitFlagFeature) { + return false; + } + if (!occ1->hasUnitFlagFeat || !occ2->hasUnitFlagFeat) { + return false; + } + llvm::DenseSet visited; + DEBUG_WITH_TYPE("gss-sync-solver-check-unit-flag", { + llvm::dbgs() << "unit-flag-step: " << occ1->syncIrIndex << ' ' + << occ1->op->str(0, false) << "\n"; + }); + Occurrence *curOcc = occ1->unitFlagInfo.linkedElementAsSet; + while (curOcc != nullptr) { + DEBUG_WITH_TYPE("gss-sync-solver-check-unit-flag", { + llvm::dbgs() << "unit-flag-step: " << curOcc->syncIrIndex << ' ' + << curOcc->op->str(0, false) << "\n"; + }); + auto [it, isInserted] = visited.insert(curOcc); + if (!isInserted) { + break; + } + if (curOcc == occ2) { + return true; + } + curOcc = curOcc->unitFlagInfo.linkedElementAsSet; + } + return false; +} + +bool Solver::ignoreMemoryConflict(RWOperation *rwOp1, RWOperation *rwOp2, + const MemInfo &memInfo1, + const MemInfo &memInfo2) { + if (options.isIntraCoreMode()) { + if (memInfo1.isWorkSpace && memInfo2.isWorkSpace) { + if (options.intraCoreIgnoreWorkSpaceFunctionArguments) { + return true; + } + } + } + return false; +} + +bool Solver::checkMemInfoConflict(RWOperation *rwOp1, RWOperation *rwOp2, + const MemInfo &memInfo1, + const MemInfo &memInfo2, + std::optional lcmLen, + std::optional eventIdNum) { + if (ignoreMemoryConflict(rwOp1, rwOp2, memInfo1, memInfo2)) { + return false; + } + return MemInfo::checkConflict(memInfo1, memInfo2, lcmLen, eventIdNum); +} + +bool Solver::checkMemInfoConflict( + RWOperation *rwOp1, RWOperation *rwOp2, + const llvm::SmallVector &memInfoList1, + const llvm::SmallVector &memInfoList2, + std::optional lcmLen, std::optional eventIdNum) { + for (auto &memInfo1 : memInfoList1) { + for (auto &memInfo2 : memInfoList2) { + if (checkMemInfoConflict(rwOp1, rwOp2, memInfo1, memInfo2, lcmLen, + eventIdNum)) { + return true; + } + } + } 
+ return false; +} + +// High-level wrapper computing pipe pairs that represent memory conflicts +// between two RW ops. +llvm::SmallVector> +Solver::checkMemoryConflicts(RWOperation *rwOp1, RWOperation *rwOp2) { + assert(rwOp1 != nullptr && rwOp2 != nullptr); + auto [it, isInserted] = checkMemoryConflictsMem.insert({{rwOp1, rwOp2}, {}}); + if (!isInserted) { + return it->second; + } + auto coreSrc = rwOp1->coreType; + auto coreDst = rwOp2->coreType; + if (options.isCrossCoreMode()) { + if (coreDst == pto::TCoreType::CUBE_AND_VECTOR) { + coreDst = (coreSrc == pto::TCoreType::VECTOR) ? pto::TCoreType::CUBE + : pto::TCoreType::VECTOR; + } + assert(coreSrc == pto::TCoreType::VECTOR || + coreSrc == pto::TCoreType::CUBE); + assert(coreDst == pto::TCoreType::VECTOR || + coreDst == pto::TCoreType::CUBE); + } + llvm::SetVector> collectedConflictsSet; + if (checkMemInfoConflict(rwOp1, rwOp2, rwOp1->readMemInfo, + rwOp2->writeMemInfo)) { + collectedConflictsSet.insert({CorePipeInfo(coreSrc, rwOp1->pipeRead), + CorePipeInfo(coreDst, rwOp2->pipeWrite)}); + } + if (checkMemInfoConflict(rwOp1, rwOp2, rwOp1->writeMemInfo, + rwOp2->readMemInfo)) { + collectedConflictsSet.insert({CorePipeInfo(coreSrc, rwOp1->pipeWrite), + CorePipeInfo(coreDst, rwOp2->pipeRead)}); + } + if (checkMemInfoConflict(rwOp1, rwOp2, rwOp1->writeMemInfo, + rwOp2->writeMemInfo)) { + collectedConflictsSet.insert({CorePipeInfo(coreSrc, rwOp1->pipeWrite), + CorePipeInfo(coreDst, rwOp2->pipeWrite)}); + } + llvm::SmallVector> collectedConflicts( + collectedConflictsSet.begin(), collectedConflictsSet.end()); + return it->second = collectedConflicts; +} + +bool Solver::checkMemoryConflictBetweenOccExclusive( + Occurrence *occ1, Occurrence *occ2, + std::function filter) { + assert(occ1 != nullptr && occ2 != nullptr); + auto *rwOp1 = llvm::dyn_cast_if_present(occ1->op); + auto *rwOp2 = llvm::dyn_cast_if_present(occ2->op); + assert(rwOp1 != nullptr && rwOp2 != nullptr); + for (int i = occ1->syncIrEndIndex; i < 
occ2->syncIrIndex; i++) {
+    if (auto *otherOp = llvm::dyn_cast_if_present(syncIr[i]->op)) {
+      if (!filter(otherOp)) {
+        continue;
+      }
+      if (!checkMemoryConflicts(rwOp1, otherOp).empty()) {
+        return true;
+      }
+      if (!checkMemoryConflicts(rwOp2, otherOp).empty()) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+/*
+ * Scans every (memInfo1, memInfo2) combination of the two lists. For each
+ * combination that checkMemInfoConflict reports as conflicting, both sides
+ * must carry pointerLikeInfo with the same non-null parentLoop, and that loop
+ * must be identical across all conflicting combinations.
+ * Returns that loop; returns an empty optional when a requirement fails or
+ * when no combination conflicts at all.
+ */
+std::optional
+Solver::getMultiBufferLoop(RWOperation *rwOp1, RWOperation *rwOp2,
+                           const llvm::SmallVector &memInfoList1,
+                           const llvm::SmallVector &memInfoList2) {
+  std::optional multibufferLoop;
+  for (auto &memInfo1 : memInfoList1) {
+    for (auto &memInfo2 : memInfoList2) {
+      if (checkMemInfoConflict(rwOp1, rwOp2, memInfo1, memInfo2)) {
+        if (!memInfo1.pointerLikeInfo.has_value() ||
+            !memInfo2.pointerLikeInfo.has_value()) {
+          return {};
+        }
+        auto multibufferLoop1 = memInfo1.pointerLikeInfo->parentLoop;
+        auto multibufferLoop2 = memInfo2.pointerLikeInfo->parentLoop;
+        if (multibufferLoop1 == nullptr ||
+            multibufferLoop1 != multibufferLoop2) {
+          return {};
+        }
+        if (multibufferLoop.has_value() && // a loop was already recorded ...
+            multibufferLoop.value() != multibufferLoop1) { // ... and this one differs
+          return {};
+        }
+        multibufferLoop = multibufferLoop1;
+      }
+    }
+  }
+  return multibufferLoop;
+}
+/*
+ * Runs the list-based overload for the read/write, write/read and write/write
+ * memory-info lists of (rwOp1, rwOp2), skipping list pairs that
+ * checkMemInfoConflict reports as non-conflicting. Returns an empty optional
+ * as soon as a list pair disagrees with a loop recorded earlier; otherwise
+ * returns the value produced by the last conflicting list pair (empty when no
+ * list pair conflicted).
+ */
+std::optional
+Solver::getMultiBufferLoop(RWOperation *rwOp1, RWOperation *rwOp2) {
+  std::optional multibufferLoop;
+  if (checkMemInfoConflict(rwOp1, rwOp2, rwOp1->readMemInfo,
+                           rwOp2->writeMemInfo)) {
+    auto curMultibufferLoop = getMultiBufferLoop(
+        rwOp1, rwOp2, rwOp1->readMemInfo, rwOp2->writeMemInfo);
+    if (multibufferLoop.has_value() && // NOTE(review): still empty here, so this first check cannot fire
+        multibufferLoop.value() != curMultibufferLoop) {
+      return {};
+    }
+    multibufferLoop = curMultibufferLoop;
+  }
+  if (checkMemInfoConflict(rwOp1, rwOp2, rwOp1->writeMemInfo,
+                           rwOp2->readMemInfo)) {
+    auto curMultibufferLoop = getMultiBufferLoop(
+        rwOp1, rwOp2, rwOp1->writeMemInfo, rwOp2->readMemInfo);
+    if (multibufferLoop.has_value() &&
+        multibufferLoop.value() != curMultibufferLoop) {
+      return {};
+    }
+    multibufferLoop = curMultibufferLoop;
+  }
+  if (checkMemInfoConflict(rwOp1, rwOp2, rwOp1->writeMemInfo,
+                           rwOp2->writeMemInfo)) {
+    auto curMultibufferLoop = getMultiBufferLoop(
+        rwOp1, rwOp2, rwOp1->writeMemInfo, rwOp2->writeMemInfo);
+    if (multibufferLoop.has_value() &&
+        multibufferLoop.value() != curMultibufferLoop) {
+      return {};
+    }
+    multibufferLoop = curMultibufferLoop;
+  }
+  return multibufferLoop;
+}
+/*
+ * Tries to derive a multi-buffer EventIdInfo for the pair (rwOp1, rwOp2).
+ * Steps, as implemented below:
+ *  1. Test mode: occ1 and occ2 must share a parent (found with
+ *     getParentOfType) that is a proper ancestor of both the set and the wait
+ *     occurrence. Otherwise: getMultiBufferLoop must yield a non-null loop for
+ *     which getParentWithOp finds a parent of both the set and the wait
+ *     occurrence.
+ *  2. Over all conflicting read/write, write/read and write/write memory-info
+ *     combinations, accumulate the lcm of the sizes and the smallest size seen
+ *     on a write side.
+ *  3. Count down from that smallest write size to the first candidate for
+ *     which checkMemInfoConflict reports no conflict on any of the three list
+ *     pairs.
+ * Returns an empty optional when a step fails or the candidate ends up <= 1.
+ */
+std::optional
+Solver::getMultiBufferEventIdInfo(Occurrence *occ1, Occurrence *occ2,
+                                  RWOperation *rwOp1, RWOperation *rwOp2) {
+  assert(rwOp1 != nullptr && rwOp2 != nullptr);
+
+  int64_t lcm = 1;
+  int64_t minWriteSize = LONG_MAX; // sentinel for "no conflicting pair seen"; NOTE(review): LONG_MAX is the limit of long, not of int64_t
+  LoopLikeOpInterface multibufferLoop{nullptr}; // only assigned in the non-test branch below
+
+  if (options.isTestMode()) {
+    auto *parLoop1 = occ1->getParentOfType();
+    auto *parLoop2 = occ2->getParentOfType();
+    if (!parLoop1 || parLoop1 != parLoop2) {
+      return {};
+    }
+    auto [setOcc, waitOcc] = getSetWaitOcc(occ1, occ2);
+    if (!parLoop1->isProperAncestor(setOcc) ||
+        !parLoop1->isProperAncestor(waitOcc)) {
+      return {};
+    }
+  } else {
+    auto multibufferLoopOpt = getMultiBufferLoop(rwOp1, rwOp2);
+    if (!multibufferLoopOpt.has_value() || !multibufferLoopOpt.value()) {
+      return {};
+    }
+    multibufferLoop = multibufferLoopOpt.value();
+    assert(multibufferLoop != nullptr);
+    auto [setOcc, waitOcc] = getSetWaitOcc(occ1, occ2);
+    if (!setOcc->getParentWithOp(multibufferLoop,
+                                 /*assertExists=*/false) ||
+        !waitOcc->getParentWithOp(multibufferLoop,
+                                  /*assertExists=*/false)) {
+      return {};
+    }
+  }
+
+  for (auto &memInfo1 : rwOp1->readMemInfo) { // read (rwOp1) vs. write (rwOp2)
+    for (auto &memInfo2 : rwOp2->writeMemInfo) {
+      if (checkMemInfoConflict(rwOp1, rwOp2, memInfo1, memInfo2)) {
+        int64_t curLcm = std::lcm(memInfo1.getSz(), memInfo2.getSz());
+        lcm = std::lcm(lcm, curLcm);
+        minWriteSize = std::min(minWriteSize, memInfo2.getSz()); // memInfo2 is the write side
+      }
+    }
+  }
+  for (auto &memInfo1 : rwOp1->writeMemInfo) { // write (rwOp1) vs. read (rwOp2)
+    for (auto &memInfo2 : rwOp2->readMemInfo) {
+      if (checkMemInfoConflict(rwOp1, rwOp2, memInfo1, memInfo2)) {
+        int64_t curLcm = std::lcm(memInfo1.getSz(), memInfo2.getSz());
+        lcm = std::lcm(lcm, curLcm);
+        minWriteSize = std::min(minWriteSize, memInfo1.getSz()); // memInfo1 is the write side
+      }
+    }
+  }
+  for (auto &memInfo1 : rwOp1->writeMemInfo) { // write (rwOp1) vs. write (rwOp2)
+    for (auto &memInfo2 : rwOp2->writeMemInfo) {
+      if (checkMemInfoConflict(rwOp1, rwOp2, memInfo1, memInfo2)) {
+        int64_t curLcm = std::lcm(memInfo1.getSz(), memInfo2.getSz());
+        lcm = std::lcm(lcm, curLcm);
+        minWriteSize = std::min(minWriteSize, memInfo1.getSz());
+        minWriteSize = std::min(minWriteSize, memInfo2.getSz());
+      }
+    }
+  }
+
+  // The sentinel is untouched when none of the loops above found a conflict.
+  if (minWriteSize == LONG_MAX) {
+    minWriteSize = 1; // NOTE(review): dead store, the function returns on the next line
+    return {};
+  }
+
+  int64_t eventIdNum = minWriteSize;
+  for (; eventIdNum >= 1; eventIdNum--) { // largest candidate first
+    // llvm::dbgs() << "checking event-id-num: " << eventIdNum << '\n';
+    int64_t curLcm = std::lcm(lcm, eventIdNum);
+    bool okRW = !checkMemInfoConflict(rwOp1, rwOp2, rwOp1->readMemInfo,
+                                      rwOp2->writeMemInfo, curLcm, eventIdNum);
+    bool okWR = !checkMemInfoConflict(rwOp1, rwOp2, rwOp1->writeMemInfo,
+                                      rwOp2->readMemInfo, curLcm, eventIdNum);
+    bool okWW = !checkMemInfoConflict(rwOp1, rwOp2, rwOp1->writeMemInfo,
+                                      rwOp2->writeMemInfo, curLcm, eventIdNum);
+    if (okRW && okWR && okWW) {
+      break;
+    }
+  }
+  if (eventIdNum <= 1) { // covers both "only 1 works" and "nothing worked" (loop ran down to 0)
+    return {};
+  }
+  EventIdInfo eventIdInfo(eventIdNum);
+  eventIdInfo.multibufferLoop = multibufferLoop;
+  return eventIdInfo;
+}
+/*
+ * Guarded wrapper around getMultiBufferEventIdInfo: outside test mode both
+ * ops must pass checkAllParentLoopsAreForLoops, otherwise an empty optional is
+ * returned without running the analysis.
+ */
+std::optional
+Solver::checkMultiBufferEventIdInfo(Occurrence *occ1, Occurrence *occ2,
+                                    RWOperation *rwOp1, RWOperation *rwOp2) {
+  assert(rwOp1 != nullptr && rwOp2 != nullptr);
+  if (!options.isTestMode()) {
+    if (!checkAllParentLoopsAreForLoops(rwOp1->op) ||
+        !checkAllParentLoopsAreForLoops(rwOp2->op)) {
+      return {};
+    }
+  }
+  if (auto eventIdInfo = getMultiBufferEventIdInfo(occ1, occ2, rwOp1, rwOp2)) {
+    return eventIdInfo;
+  }
+  return {};
+}
+/*
+ * Cross-core mode only. Walks up from each op to the nearest parent whose
+ * multibufferUnrollNum is set; both parents must exist. If either op has a
+ * parent (parCond1 / parCond2 below) that is not a proper ancestor of the
+ * other op, nothing is returned. On success the event id count is the unroll
+ * number and both unroll loops are recorded in the result.
+ */
+std::optional
+Solver::checkCVMultiBufferUnrollEventIdInfo(RWOperation *rwOp1,
+                                            RWOperation *rwOp2) {
+  assert(rwOp1 != nullptr && rwOp2 != nullptr);
+  if (!options.isCrossCoreMode()) {
+    return {};
+  }
+  auto *parentLoop1 = rwOp1->getParentOfType();
+  auto *parentLoop2 = rwOp2->getParentOfType();
+  while (parentLoop1 != nullptr && !parentLoop1->multibufferUnrollNum) {
+    parentLoop1 = parentLoop1->getParentOfType();
+  }
+  while (parentLoop2 != nullptr && !parentLoop2->multibufferUnrollNum) {
+    parentLoop2 = parentLoop2->getParentOfType();
+  }
+  if (!parentLoop1 || !parentLoop2) {
+    return {};
+  }
+  if (auto *parCond1 = rwOp1->getParentOfType()) {
+    if (!parCond1->isProperAncestor(rwOp2)) {
+      return {};
+    }
+  }
+  if (auto *parCond2 = rwOp2->getParentOfType()) {
+    if (!parCond2->isProperAncestor(rwOp1)) {
+      return {};
+    }
+  }
+  assert(parentLoop1->multibufferUnrollNum.value() == // both sides are expected to use the same unroll number
+         parentLoop2->multibufferUnrollNum.value());
+  EventIdInfo eventIdInfo;
+  eventIdInfo.eventIdNum = parentLoop1->multibufferUnrollNum.value();
+  eventIdInfo.multibufferUnrollLoop1 =
+      cast(parentLoop1->op);
+  eventIdInfo.multibufferUnrollLoop2 =
+      cast(parentLoop2->op);
+  return eventIdInfo;
+}
+/*
+ * Cross-core mode only. Walks up from each op to the nearest parent whose
+ * maxPreloadNum is set; both must exist and must share the same non-null
+ * parent loop. The same parCond1 / parCond2 ancestor checks as in
+ * checkCVMultiBufferUnrollEventIdInfo apply. On success the event id count is
+ * maxPreloadNum and each preload offset is maxPreloadNum - preloadNum - 1.
+ */
+std::optional
+Solver::checkCVMultiBufferPreloadEventIdInfo(RWOperation *rwOp1,
+                                             RWOperation *rwOp2) {
+  assert(rwOp1 != nullptr && rwOp2 != nullptr);
+  if (!options.isCrossCoreMode()) {
+    return {};
+  }
+  auto *parentScope1 = rwOp1->getParentOfType();
+  auto *parentScope2 = rwOp2->getParentOfType();
+  while (parentScope1 != nullptr && !parentScope1->maxPreloadNum.has_value()) {
+    parentScope1 = parentScope1->getParentOfType();
+  }
+  while (parentScope2 != nullptr && !parentScope2->maxPreloadNum.has_value()) {
+    parentScope2 = parentScope2->getParentOfType();
+  }
+  if (!parentScope1 || !parentScope2) {
+    return {};
+  }
+  if (auto *parCond1 = rwOp1->getParentOfType()) {
+    if (!parCond1->isProperAncestor(rwOp2)) {
+      return {};
+    }
+  }
+  if (auto *parCond2 = rwOp2->getParentOfType()) {
+    if (!parCond2->isProperAncestor(rwOp1)) {
+      return {};
+    }
+  }
+
+  auto *parentLoop1 = parentScope1->getParentOfType();
+  auto *parentLoop2 = parentScope2->getParentOfType();
+  if (parentLoop1 == nullptr || parentLoop1 != parentLoop2) {
+    return {};
+  }
+
+  assert(parentScope1->preloadNum.has_value());
+  assert(parentScope2->preloadNum.has_value());
+  assert(parentScope1->maxPreloadNum.value() ==
+         parentScope2->maxPreloadNum.value());
+
+  auto parentForLoop = llvm::dyn_cast_if_present(parentLoop1->op);
+  assert(parentForLoop != nullptr); // NOTE(review): only checked in assert-enabled builds
+
+  EventIdInfo eventIdInfo;
+  eventIdInfo.eventIdNum = parentScope1->maxPreloadNum.value();
+  eventIdInfo.preloadOffset1 = parentScope1->maxPreloadNum.value() -
+                               parentScope1->preloadNum.value() - 1;
+  eventIdInfo.preloadOffset2 = parentScope2->maxPreloadNum.value() -
+                               parentScope2->preloadNum.value() - 1;
+  eventIdInfo.multibufferLoop = parentForLoop;
+  return eventIdInfo;
+}
+
+// Pick the EventIdInfo for a sync between occ1 and occ2: a single event id
+// unless the sync is backward and one of the multi-buffer checks succeeds.
+EventIdInfo Solver::getEventIdInfo(Occurrence *occ1, Occurrence *occ2,
+                                   RWOperation *rwOp1, RWOperation *rwOp2,
+                                   CorePipeInfo corePipeSrc, // NOTE(review): not referenced in this body
+                                   CorePipeInfo corePipeDst) { // NOTE(review): not referenced in this body
+  assert(occ1 != nullptr && occ2 != nullptr);
+  assert(rwOp1 != nullptr && rwOp2 != nullptr);
+  EventIdInfo singleEventId(1);
+  if (!isBackwardSync(occ1, occ2)) {
+    return singleEventId;
+  }
+  if (auto eventIdInfo = checkCVMultiBufferUnrollEventIdInfo(rwOp1, rwOp2)) { // tried first
+    return eventIdInfo.value();
+  }
+  if (auto eventIdInfo = checkCVMultiBufferPreloadEventIdInfo(rwOp1, rwOp2)) { // tried second
+    return eventIdInfo.value();
+  }
+  if (auto eventIdInfo =
+          checkMultiBufferEventIdInfo(occ1, occ2, rwOp1, rwOp2)) { // tried last
+    return eventIdInfo.value();
+  }
+  return singleEventId;
+}
+
+// Graph-based check to determine if adding a sync between occ1 and occ2 would
+// block progress. Uses GraphSolver (Dijkstra) to estimate minimal reachable
+// index.
+bool Solver::checkGraphConflict(
+    Occurrence *occ1, Occurrence *occ2, CorePipeInfo corePipeSrc,
+    CorePipeInfo corePipeDst, EventIdInfo eventIdInfo,
+    std::optional startIndex, std::optional endIndex, // when unset: occ1->endIndex / occ2->startIndex
+    const llvm::SmallVector &extraConflictPairs, // offered to the graph in addition to the chosen ones
+    const llvm::SmallVector &ignoreConflictPairs) { // never added to the graph
+  assert(occ1 != nullptr && occ2 != nullptr);
+  if (!startIndex.has_value()) {
+    startIndex = occ1->endIndex;
+  }
+  if (!endIndex.has_value()) {
+    endIndex = occ2->startIndex;
+  }
+  GraphSolver graphSolver(options);
+  llvm::DenseSet visited; // de-duplicates pairs reachable through several containers below
+  auto handleConflictPair = [&](ConflictPair *conflictPair) { // filter, then add to graphSolver
+    if (conflictPair->couldNotRun) {
+      return;
+    }
+    if (conflictPair->endIndex < startIndex.value() || // entirely before the window
+        conflictPair->startIndex > endIndex.value()) { // entirely after the window
+      return;
+    }
+    if (conflictPair->isInnerBackward) {
+      if ((eventIdInfo.eventIdNum * eventIdInfo.eventIdRepeatNum) < // skip inner-backward pairs that use more event ids than we do
+          (conflictPair->eventIdInfo.eventIdNum *
+           conflictPair->eventIdInfo.eventIdRepeatNum)) {
+        return;
+      }
+    }
+    if (llvm::find(ignoreConflictPairs, conflictPair) !=
+        ignoreConflictPairs.end()) {
+      return;
+    }
+    auto [it, isInserted] = visited.insert(conflictPair);
+    if (!isInserted) {
+      return;
+    }
+    DEBUG_WITH_TYPE("gss-sync-solver-check-graph-conflict", {
+      llvm::dbgs() << "add-conflict-pair: " << conflictPair->str() << '\n';
+    });
+    graphSolver.addConflictPair(conflictPair);
+  };
+
+  for (auto *parOcc : occ1->getAllParents()) { // chosen conflicts of every parent of occ1
+    if (scopeOccChosenConflicts.contains(parOcc)) {
+      for (auto *conflictPair : scopeOccChosenConflicts[parOcc]) {
+        handleConflictPair(conflictPair);
+      }
+    }
+  }
+  for (auto *parOcc : occ2->getAllParents()) { // chosen conflicts of every parent of occ2
+    if (scopeOccChosenConflicts.contains(parOcc)) {
+      for (auto *conflictPair : scopeOccChosenConflicts[parOcc]) {
+        handleConflictPair(conflictPair);
+      }
+    }
+  }
+  for (auto &[scopeOccPair, chosenConflicts] : scopeOccPairChosenConflicts) { // pair-keyed conflicts enclosing (occ1, occ2)
+    auto [scopeOcc1, scopeOcc2] = scopeOccPair;
+    if (scopeOcc1->isProperAncestor(occ1) &&
+        scopeOcc2->isProperAncestor(occ2)) {
+      for (auto *conflictPair : chosenConflicts) {
+        handleConflictPair(conflictPair);
+      }
+    }
+  }
+  for (auto *parOcc : occ1->getAllParents()) { // persistent conflicts of every parent of occ1
+    if (persistentScopeOccChosenConflicts.contains(parOcc)) {
+      for (auto *conflictPair : persistentScopeOccChosenConflicts[parOcc]) {
+        handleConflictPair(conflictPair);
+      }
+    }
+  }
+  for (auto *parOcc : occ2->getAllParents()) { // persistent conflicts of every parent of occ2
+    if (persistentScopeOccChosenConflicts.contains(parOcc)) {
+      for (auto *conflictPair : persistentScopeOccChosenConflicts[parOcc]) {
+        handleConflictPair(conflictPair);
+      }
+    }
+  }
+  for (auto *conflictPair : extraConflictPairs) { // caller-supplied pairs go through the same filter
+    handleConflictPair(conflictPair);
+  }
+  std::optional mnDistance;
+  if (options.enableUnitFlagFeature) {
+    mnDistance = graphSolver.runDijkstraUnitFlagEnabled(
+        occ1, occ2, corePipeSrc, corePipeDst, startIndex.value(),
+        endIndex.value());
+  } else {
+    mnDistance = graphSolver.runDijkstra(corePipeSrc, corePipeDst,
+                                         startIndex.value(), endIndex.value());
+  }
+  return !mnDistance.has_value() || mnDistance.value() > endIndex.value(); // true: solver gave no distance, or one beyond endIndex
+}
+/*
+ * Decides whether two non-barrier conflict pairs conflict as sync operations.
+ * The pairs are first ordered by startIndex (swapping the local pointers).
+ * They conflict outright when the starts are equal or when the earlier pair
+ * does not end strictly before the later one. Otherwise checkGraphConflict is
+ * consulted for the set side and for the wait side, each only when the
+ * corresponding core/pipe infos differ.
+ * Returns false if either pair is a barrier.
+ * NOTE(review): no null check on the arguments here; the call from
+ * checkIntersect asserts non-null first.
+ */
+bool Solver::checkSyncOpsConflicts(ConflictPair *conflictPair1,
+                                   ConflictPair *conflictPair2) {
+  if (conflictPair1->isBarrier() || conflictPair2->isBarrier()) {
+    return false;
+  }
+  if (conflictPair1->startIndex > conflictPair2->startIndex) {
+    std::swap(conflictPair1, conflictPair2); // from here on conflictPair1 starts first
+  }
+  if (conflictPair1->startIndex >= conflictPair2->startIndex || // after the swap this means "equal starts"
+      conflictPair1->endIndex >= conflictPair2->endIndex) {
+    return true;
+  }
+  bool result = false;
+  if (conflictPair1->setCorePipeInfo != conflictPair2->setCorePipeInfo) { // set side
+    auto corePipeSrc = conflictPair1->setCorePipeInfo;
+    auto corePipeDst = conflictPair2->setCorePipeInfo;
+    Occurrence *occ1 = conflictPair1->setOcc;
+    Occurrence *occ2 = conflictPair2->setOcc;
+    auto startIndex = conflictPair1->startIndex + 1;
+    auto endIndex = conflictPair2->startIndex;
+    conflictPair1->startIndex += 1; // temporary shift for the graph check; restored right after the call
+    assert(occ1 != nullptr && occ2 != nullptr);
+    result = result ||
+             checkGraphConflict(occ1, occ2, corePipeSrc, corePipeDst,
+                                conflictPair1->eventIdInfo, startIndex,
+                                endIndex, {conflictPair1}, {conflictPair2});
+    conflictPair1->startIndex -= 1;
+  }
+  if (conflictPair1->waitCorePipeInfo != conflictPair2->waitCorePipeInfo) { // wait side
+    auto corePipeSrc = conflictPair1->waitCorePipeInfo;
+    auto corePipeDst = conflictPair2->waitCorePipeInfo;
+    Occurrence *occ1 = conflictPair1->waitOcc;
+    Occurrence *occ2 = conflictPair2->waitOcc;
+    auto startIndex = conflictPair1->endIndex;
+    auto endIndex = conflictPair2->endIndex - 1;
+    conflictPair2->endIndex -= 1; // temporary shift for the graph check; restored right after the call
+    assert(occ1 != nullptr && occ2 != nullptr);
+    result = result || // short-circuits: the wait-side graph check is skipped once the set side conflicted
+             checkGraphConflict(occ1, occ2, corePipeSrc, corePipeDst,
+                                conflictPair1->eventIdInfo, startIndex,
+                                endIndex, {conflictPair1}, {conflictPair2});
+    conflictPair2->endIndex += 1;
+  }
+  DEBUG_WITH_TYPE("gss-check-sync-ops-conflicts", {
+    if (result) {
+      llvm::dbgs() << "sync-ops-conflict-found: " << "\n";
+      llvm::dbgs() << " " << conflictPair1->str() << '\n';
+      llvm::dbgs() << " " << conflictPair2->str() << '\n';
+    }
+  });
+  return result;
+}
+
+// Check whether two ConflictPair entries conflict in pipe and time ranges.
+bool Solver::checkIntersect(ConflictPair *conflictPair1, // returns true when the two pairs must not share resources
+                            ConflictPair *conflictPair2) {
+  assert(conflictPair1 != nullptr && conflictPair2 != nullptr);
+  if (conflictPair1 == conflictPair2) { // a pair never intersects with itself
+    return false;
+  }
+  if (conflictPair1->isBarrier() || conflictPair2->isBarrier()) {
+    return false;
+  }
+  if (conflictPair1->dontCheckForConflict ||
+      conflictPair2->dontCheckForConflict) {
+    return false;
+  }
+  if (options.isCrossCoreMode()) { // cross-core mode delegates entirely to the sync-op based check
+    return checkSyncOpsConflicts(conflictPair1, conflictPair2);
+  }
+  if (conflictPair1->setCorePipeInfo != conflictPair2->setCorePipeInfo || // different set or wait core/pipe: no intersection
+      conflictPair1->waitCorePipeInfo != conflictPair2->waitCorePipeInfo) {
+    return false;
+  }
+  for (auto [l1, r1] : getRanges(conflictPair1)) {
+    for (auto [l2, r2] : getRanges(conflictPair2)) {
+      if (checkRangesIntersect(l1, r1 + 1, l2, r2 + 1)) { // upper bounds are passed as r + 1
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+// Collect every chosen and persistent chosen pair that checkIntersect flags.
+std::vector
+Solver::getIntersectingConflictPairs(ConflictPair *conflictPair) {
+  assert(conflictPair != nullptr);
+  if (conflictPair->isBarrier()) { // same early-outs as checkIntersect, applied once up front
+    return {};
+  }
+  if (conflictPair->dontCheckForConflict) {
+    return {};
+  }
+  std::vector intersectingConflictPairs;
+  for (auto &curConflictPair : chosenConflictedPairs) {
+    if (checkIntersect(conflictPair, curConflictPair.get())) {
+      intersectingConflictPairs.push_back(curConflictPair.get()); // non-owning pointer; the container keeps ownership
+    }
+  }
+  for (auto &curConflictPair : persistentChosenConflictedPairs) {
+    if (checkIntersect(conflictPair, curConflictPair.get())) {
+      intersectingConflictPairs.push_back(curConflictPair.get());
+    }
+  }
+  return intersectingConflictPairs;
+}
+
+// Processed-pair tracking helpers.
+bool Solver::checkVisited(Occurrence *occ1, Occurrence *occ2) { + auto [it, isInserted] = processedOccPairs.insert(std::make_pair(occ1, occ2)); + return !isInserted; +} + +bool Solver::checkSkippable(bool reverseOrder, Occurrence *occ) { + return skipOcc[reverseOrder].contains(occ); +} + +// Synced-pair memoization helpers. +EventIdNode *Solver::getOldEventIdNodeIfExists(ConflictPair *conflictPair) { + assert(conflictPair != nullptr); + auto oldConflictPairs = getMemorizedSyncedPairs(conflictPair); + if (oldConflictPairs.empty()) { + return {}; + } + ConflictPair *oldConflictPair = *oldConflictPairs.begin(); + assert(oldConflictPair != nullptr && oldConflictPair->eventIdNode != nullptr); + return oldConflictPair->eventIdNode; +} + +llvm::DenseSet +Solver::getMemorizedSyncedPairs(ConflictPair *conflictPair) { + auto key = std::make_tuple( + conflictPair->backwardSyncLoopOp, conflictPair->op1, conflictPair->op2, + conflictPair->setCorePipeInfo, conflictPair->waitCorePipeInfo); + return syncedPairs[key]; +} + +void Solver::memorizeSyncedPair(ConflictPair *conflictPair) { + auto key = std::make_tuple( + conflictPair->backwardSyncLoopOp, conflictPair->op1, conflictPair->op2, + conflictPair->setCorePipeInfo, conflictPair->waitCorePipeInfo); + syncedPairs[key].insert(conflictPair); +#ifndef NDEBUG + for (auto *oldConflictPair : syncedPairs[key]) { + assert(oldConflictPair->eventIdNode == conflictPair->eventIdNode); + } +#endif +} + +void Solver::forgetSyncedPair(ConflictPair *conflictPair) { + assert(conflictPair != nullptr); + auto key = std::make_tuple( + conflictPair->backwardSyncLoopOp, conflictPair->op1, conflictPair->op2, + conflictPair->setCorePipeInfo, conflictPair->waitCorePipeInfo); + syncedPairs[key].erase(conflictPair); +} + +void Solver::memorizeReusedSyncedPair(ConflictPair *conflictPair, + ConflictPair *reusedConflictPair) { + assert(conflictPair != nullptr); + replacedWithReusableSyncedPairs[{ + conflictPair->backwardSyncLoopOp, conflictPair->op1, 
conflictPair->op2, + conflictPair->setCorePipeInfo, conflictPair->waitCorePipeInfo}] = + reusedConflictPair; +} + +bool Solver::skipMMad1DecomposedLoopOpt(Occurrence *occ1, Occurrence *occ2) { + auto *parentLoopOp1 = OperationBase::getParentloop(occ1->op); + auto *parentLoopOp2 = OperationBase::getParentloop(occ2->op); + if (parentLoopOp1 != nullptr && parentLoopOp2 != nullptr) { + if (parentLoopOp1 != parentLoopOp2) { + if (isa(parentLoopOp1) && + isa(parentLoopOp2)) { + return true; + } + } + } + return false; +} + +std::optional> +Solver::checkAndApplyMmadl0LoopOpt(ConflictPair *conflictPair, Occurrence *occ1, + Occurrence *occ2, Occurrence *parOcc1, + Occurrence *parOcc2) { + if (!options.decomposeMmadl1Op) { + return {}; + } + if (occ1->parentOcc != nullptr && occ1->parentOcc->parentOcc != nullptr && + occ1->parentOcc->parentOcc->parentOcc == parOcc1 && + llvm::isa_and_present( + occ1->op) && + llvm::isa_and_present( + occ1->parentOcc->parentOcc->op)) { + conflictPair->setOnLastIterOnly = true; + return std::make_pair(occ1, parOcc2); + } + if (!conflictPair->isInnerBackward && occ2->parentOcc != nullptr && + occ2->parentOcc->parentOcc != nullptr && + occ2->parentOcc->parentOcc->parentOcc == parOcc2 && + llvm::isa_and_present( + occ2->op) && + llvm::isa_and_present( + occ2->parentOcc->parentOcc->op)) { + conflictPair->waitOnFirstIterOnly = true; + return std::make_pair(parOcc1, occ2); + } + return {}; +} + +std::optional Solver::checkUnitFlagPatterns(Occurrence *occ1, + Occurrence *occ2) { + return {}; +} + +Occurrence *Solver::getBeforePlaceHolderOcc(Occurrence *occ) { + assert(occ != nullptr); + assert(llvm::isa_and_present(occ->op)); + int index = occ->syncIrIndex - 1; + assert(0 <= index && index < static_cast(syncIr.size())); + auto *placeHolderOcc = syncIr[index].get(); +#ifndef NDEBUG + auto *placeHolderOp = llvm::dyn_cast(placeHolderOcc->op); + assert(placeHolderOp != nullptr); + assert(placeHolderOp->beforeOp == occ->op); +#endif + return 
placeHolderOcc; +} + +Occurrence *Solver::getAfterPlaceHolderOcc(Occurrence *occ) { + assert(occ != nullptr); + assert(llvm::isa_and_present(occ->op)); + int index = occ->syncIrEndIndex; + assert(0 <= index && index < static_cast(syncIr.size())); + auto *placeHolderOcc = syncIr[index].get(); +#ifndef NDEBUG + auto *placeHolderOp = llvm::dyn_cast(placeHolderOcc->op); + assert(placeHolderOp != nullptr); + assert(placeHolderOp->afterOp == occ->op); +#endif + return placeHolderOcc; +} + +Occurrence *Solver::getScopeBeginPlaceHolderOcc(Occurrence *occ) { + assert(occ != nullptr); + assert(llvm::isa_and_present(occ->op)); + int index = occ->syncIrIndex + 1; + assert(0 <= index && index < static_cast(syncIr.size())); + auto *placeHolderOcc = syncIr[index].get(); +#ifndef NDEBUG + auto *placeHolderOp = llvm::dyn_cast(placeHolderOcc->op); + assert(placeHolderOp != nullptr); + assert(placeHolderOp->scopeBegin == occ->op); +#endif + return placeHolderOcc; +} + +Occurrence *Solver::getScopeEndPlaceHolderOcc(Occurrence *occ) { + assert(occ != nullptr); + assert(llvm::isa_and_present(occ->op)); + int index = occ->syncIrEndIndex - 1; + assert(0 <= index && index < static_cast(syncIr.size())); + auto *placeHolderOcc = syncIr[index].get(); +#ifndef NDEBUG + auto *placeHolderOp = llvm::dyn_cast(placeHolderOcc->op); + assert(placeHolderOp != nullptr); + assert(placeHolderOp->scopeEnd == occ->op); +#endif + return placeHolderOcc; +} + +std::pair +Solver::getSetWaitLCAPairOcc(Occurrence *occ1, Occurrence *occ2) { + assert(occ1 != nullptr && occ2 != nullptr); + + auto [grandParOcc1, grandParOcc2] = Occurrence::getLCAPair(occ1, occ2); + assert(grandParOcc1 != nullptr && grandParOcc2 != nullptr); + assert(grandParOcc1->parentOcc != nullptr && + grandParOcc2->parentOcc != nullptr); + + auto [parOp1, parOp2] = OperationBase::getLCAPair(occ1->op, occ2->op); + assert(parOp1 != nullptr && parOp2 != nullptr); + assert(parOp1->parentOp != nullptr && parOp2->parentOp != nullptr); + 
assert(parOp1->parentOp == parOp2->parentOp); + + auto *parOcc1 = occ1->getParentWithOp(parOp1->parentOp); + auto *parOcc2 = occ2->getParentWithOp(parOp2->parentOp); + assert(parOcc1 != nullptr && parOcc2 != nullptr); + assert(parOcc1 != occ1 && parOcc2 != occ2); + + auto *setOcc = occ1->getNthParent(occ1->depth - parOcc1->depth - 1); + auto *waitOcc = occ2->getNthParent(occ2->depth - parOcc2->depth - 1); + assert(setOcc != nullptr && waitOcc != nullptr); + assert(parOcc1->isProperAncestor(setOcc)); + assert(parOcc2->isProperAncestor(waitOcc)); + + auto *parLoop = Occurrence::getParentloop(setOcc); + while (parLoop != nullptr && grandParOcc1->isProperAncestor(parLoop)) { + setOcc = parLoop; + waitOcc = Occurrence::getParentloop(waitOcc); + parLoop = Occurrence::getParentloop(setOcc); + } + return std::make_pair(setOcc, waitOcc); +} + +std::pair +Solver::getFixedSetWaitOcc(Occurrence *occ1, Occurrence *occ2) { + // - get setOcc waitOcc where: + // setOcc->op->parent = waitOcc->op->parent = lca(occ1, occ2)->op + auto [setOcc, waitOcc] = getSetWaitLCAPairOcc(occ1, occ2); + + // - check if it's the case of while loop: + // while{ + // before{ + // occ1 + // } + // setOcc; + // waitOcc; + // after{ + // occ2 + // } + // } + // - and fix it to be: + // while{ + // before{ + // occ1 + // setOcc; + // ... 
+ // waitOcc; + // placeHolder + // } + // after{ + // occ2 + // } + // } + if (setOcc->op != waitOcc->op) { + if (auto *parLoopOp = + llvm::dyn_cast_if_present(setOcc->parentOcc->op)) { + if (parLoopOp->body.size() > 1 && !isa(waitOcc->op)) { + auto *placeHolderOcc = getScopeEndPlaceHolderOcc(setOcc); + std::tie(setOcc, waitOcc) = getSetWaitLCAPairOcc(occ1, placeHolderOcc); + } + } + } + + // - check if it's the case of: + // loop(iter-1){ + // condition{ + // true-scope{} + // setOcc() + // false-scope{} + // } + // } + // loop(iter-2){ + // condition{ + // true-scope{} + // waitOcc() + // false-scope{} + // } + // } + // - and fix it to be: + // loop(iter-1){ + // condition{ + // true-scope{} + // false-scope{} + // } + // setOcc() + // } + // loop(iter-2){ + // waitOcc() + // condition{ + // true-scope{} + // false-scope{} + // } + // } + if (isBackwardSync(occ1, occ2)) { + if (setOcc->parentOcc != nullptr) { + if (llvm::isa_and_present(setOcc->parentOcc->op)) { + setOcc = setOcc->parentOcc; + } + } + if (waitOcc->parentOcc != nullptr) { + if (llvm::isa_and_present(waitOcc->parentOcc->op)) { + waitOcc = waitOcc->parentOcc; + } + } + } + + // - for the case of cv-pipelining: + // loop(){ + // op1 + // } {unroll=x} + // setOcc + // waitOcc + // loop(){ + // op2 + // } {unroll=x} + // - and fix it to be: + // loop(){ + // op1 + // setOcc + // } {unroll=x} + // loop(){ + // waitOcc + // op2 + // } {unroll=x} + if (options.isCrossCoreMode()) { + assert(setOcc->op != nullptr && waitOcc->op != nullptr); + auto *forOp1 = llvm::dyn_cast_if_present(setOcc->op); + auto *forOp2 = llvm::dyn_cast_if_present(waitOcc->op); + if (forOp1 != nullptr && forOp2 != nullptr) { + if (forOp1->multibufferUnrollNum && forOp2->multibufferUnrollNum) { + assert(forOp1->multibufferUnrollNum == forOp2->multibufferUnrollNum); + setOcc = occ1->getNthParent(occ1->depth - setOcc->depth - 2); + waitOcc = occ2->getNthParent(occ2->depth - waitOcc->depth - 2); + } + } + } + + // - for the case of 
cv-pipelining: + // scope(){ + // op1 + // } {preload=x} + // setOcc + // waitOcc + // scope(){ + // op2 + // } {preload=x} + // - and fix it to be: + // scope(){ + // op1 + // setOcc + // } {preload=x} + // scope(){ + // waitOcc + // op2 + // } {preload=x} + if (options.isCrossCoreMode()) { + assert(setOcc->op != nullptr && waitOcc->op != nullptr); + auto *scopeOp1 = llvm::dyn_cast_if_present(setOcc->op); + auto *scopeOp2 = llvm::dyn_cast_if_present(waitOcc->op); + if (scopeOp1 != nullptr && scopeOp2 != nullptr) { + if (scopeOp1->maxPreloadNum && scopeOp2->maxPreloadNum) { + assert(scopeOp1->maxPreloadNum == scopeOp2->maxPreloadNum); + setOcc = occ1->getNthParent(occ1->depth - setOcc->depth - 2); + waitOcc = occ2->getNthParent(occ2->depth - waitOcc->depth - 2); + } + } + } + + // - check if it's the case of: + // { + // op1 + // setOcc + // ... + // waitOcc + // loop(){} + // setOcc + // ... + // waitOcc + // op2 + // } + // - and fix it to be: + // { + // op1 + // setOcc + // ... + // waitOcc + // placeHolder + // loop(){} + // placeHolder + // setOcc + // ... 
+ // waitOcc + // op2 + // } + if (llvm::isa_and_present(setOcc->op)) { + setOcc = getAfterPlaceHolderOcc(setOcc); + } + if (llvm::isa_and_present(waitOcc->op)) { + waitOcc = getBeforePlaceHolderOcc(waitOcc); + } + + return std::make_pair(setOcc, waitOcc); +} + +std::optional> +Solver::getFunctionBlockSetWaitOcc(Occurrence *occ1, Occurrence *occ2) { + assert(occ1 != nullptr && occ2 != nullptr); + auto *parFunctionBlock1 = occ1->getParentOfType(); + auto *parFunctionBlock2 = occ2->getParentOfType(); + if (parFunctionBlock1 == parFunctionBlock2) { + return {}; + } + auto *placeHolderOcc = getScopeBeginPlaceHolderOcc(parFunctionBlock2); + return std::make_pair(placeHolderOcc, occ2); +} + +std::optional> +Solver::getUnlikelyCondSetWaitOcc(Occurrence *occ1, Occurrence *occ2) { + assert(occ1 != nullptr && occ2 != nullptr); + if (options.isCrossCoreMode() && isBackwardSync(occ1, occ2)) { + return {}; + } + if (auto *unlikelyParCondOcc1 = + Occurrence::getUnlikelyParentCondition(occ1)) { + if (!unlikelyParCondOcc1->isProperAncestor(occ2)) { + auto *parentLoopOcc = Occurrence::getParentloop(unlikelyParCondOcc1); + if (parentLoopOcc == nullptr || parentLoopOcc->isProperAncestor(occ2)) { + auto *placeHolderOcc = getScopeEndPlaceHolderOcc( + occ1->getNthParent(occ1->depth - unlikelyParCondOcc1->depth - 1)); + return std::make_pair(occ1, placeHolderOcc); + } + } + } + if (auto *unlikelyParCondOcc2 = + Occurrence::getUnlikelyParentCondition(occ2)) { + if (!unlikelyParCondOcc2->isProperAncestor(occ1)) { + auto *parentLoopOcc = Occurrence::getParentloop(unlikelyParCondOcc2); + if (parentLoopOcc == nullptr || parentLoopOcc->isProperAncestor(occ1)) { + auto *placeHolderOcc = getScopeBeginPlaceHolderOcc( + occ2->getNthParent(occ2->depth - unlikelyParCondOcc2->depth - 1)); + return std::make_pair(placeHolderOcc, occ2); + } + } + } + return {}; +} + +std::pair Solver::getSetWaitOcc(Occurrence *occ1, + Occurrence *occ2) { + if (auto functionBlockOpt = getFunctionBlockSetWaitOcc(occ1, 
occ2)) { + std::tie(occ1, occ2) = functionBlockOpt.value(); + } + if (auto unlikelyOpt = getUnlikelyCondSetWaitOcc(occ1, occ2)) { + std::tie(occ1, occ2) = unlikelyOpt.value(); + } + return getFixedSetWaitOcc(occ1, occ2); +} + +Occurrence *Solver::getBarrierWaitOcc(Occurrence *occ1, Occurrence *occ2) { + auto [setOcc, waitOcc] = getSetWaitOcc(occ1, occ2); + if (!waitOcc->isProperAncestor(occ2)) { + return waitOcc; + } + auto allParents = occ2->getAllParents(); + while (!allParents.empty() && allParents.back()->isProperAncestor(waitOcc)) { + allParents.pop_back(); + } + while (allParents.size() >= 2 && + llvm::isa_and_present(allParents.back()->op)) { + allParents.pop_back(); + assert(llvm::isa_and_present(allParents.back()->op)); + allParents.pop_back(); + } + waitOcc = !allParents.empty() ? allParents.back() : occ2; + return waitOcc; +} + +void Solver::insertBarrierAllBeforeOcc(Occurrence *occ, bool isUseless, + bool isPersistent) { + assert(occ != nullptr); + auto *rwOp = llvm::dyn_cast_if_present(occ->op); + assert(rwOp != nullptr); + auto conflictPair = std::make_unique( + nullptr, nullptr, rwOp, rwOp, occ, occ, + CorePipeInfo(pto::TCoreType::CUBE_OR_VECTOR, pto::PIPE::PIPE_ALL), + CorePipeInfo(pto::TCoreType::CUBE_OR_VECTOR, pto::PIPE::PIPE_ALL), + occ->startIndex, occ->startIndex); + conflictPair->isUseless = isUseless; + auto *normScopeOcc = occ->parentOcc; + assert(normScopeOcc != nullptr); + LLVM_DEBUG(llvm::dbgs() << (isPersistent ? 
"is-persistent " : "") + << occ->op->str(0, false) << ' ' + << conflictPair->str() << '\n';); + if (isPersistent) { + persistentScopeOccChosenConflicts[normScopeOcc].insert(conflictPair.get()); + persistentChosenConflictedPairs.push_back(std::move(conflictPair)); + } else { + insertedBarrierAllBefore[occ->op].insert({occ, isUseless}); + scopeOccChosenConflicts[normScopeOcc].insert(conflictPair.get()); + chosenConflictedPairs.push_back(std::move(conflictPair)); + } +} + +void Solver::insertBarrierAllBeforeOp(OperationBase *op, bool isUseless, + bool isPersistent) { + assert(op != nullptr); + for (auto *occ : opAllOccurrences[op]) { + insertBarrierAllBeforeOcc(occ, isUseless, isPersistent); + isUseless = true; + } +} + +// When barrier-all markers need to be chosen, insert them before all +// occurrences for the chosen op. +void Solver::pickAndInsertABarrierAll() { + assert(!insertedBarrierAllBefore.empty()); + OperationBase *chosenOp = nullptr; + for (auto &[op, vec] : insertedBarrierAllBefore) { + if (vec.empty()) { + continue; + } + if (chosenOp == nullptr || chosenOp->id > op->id) { + chosenOp = op; + } + } + assert(chosenOp != nullptr); + insertBarrierAllBeforeOp(chosenOp, /*isUseless=*/false, + /*isPersistent=*/true); +} + +bool Solver::isBackwardSync(Occurrence *occ1, Occurrence *occ2) { + if (occ1->op->id >= occ2->op->id) { + return true; + } + assert(occ1 != nullptr && occ2 != nullptr); + assert(occ1->op != nullptr && occ2->op != nullptr); + auto [parOcc1, parOcc2] = Occurrence::getLCAPair(occ1, occ2); + auto [parOp1, parOp2] = OperationBase::getLCAPair(occ1->op, occ2->op); + return parOcc1->parentOcc->op != parOp1->parentOp; +} + +bool Solver::reuseCmp(ConflictPair *conflictPair1, + ConflictPair *conflictPair2) { + assert(conflictPair1 != nullptr && conflictPair2 != nullptr); + assert(conflictPair1->op1 != nullptr && conflictPair1->op2 != nullptr); + assert(conflictPair2->op1 != nullptr && conflictPair2->op2 != nullptr); + if (conflictPair1->startIndex != 
conflictPair2->startIndex) { + return conflictPair1->startIndex < conflictPair2->startIndex; + } + if (conflictPair1->endIndex != conflictPair2->endIndex) { + return conflictPair1->endIndex > conflictPair2->endIndex; + } + if (conflictPair1->op1 != conflictPair2->op1) { + return conflictPair1->op1->id > conflictPair2->op1->id; + } + if (conflictPair1->op2 != conflictPair2->op2) { + return conflictPair1->op2->id > conflictPair2->op2->id; + } + return false; +} + +ConflictPair *Solver::getReusableConflictPair( + ConflictPair *conflictPair, + const llvm::DenseSet &conflictPairsSet) { + assert(conflictPair != nullptr); + ConflictPair *ret = nullptr; + for (auto *curConflictPair : conflictPairsSet) { + if (curConflictPair->isBarrier() || curConflictPair->dontReuse) { + continue; + } + if (curConflictPair->op1 != conflictPair->op1 || + curConflictPair->op2 != conflictPair->op2 || + curConflictPair->setCorePipeInfo != conflictPair->setCorePipeInfo || + curConflictPair->waitCorePipeInfo != conflictPair->waitCorePipeInfo) { + continue; + } + if (!checkIntersect(conflictPair, curConflictPair)) { + continue; + } + if (curConflictPair->startIndex >= conflictPair->startIndex) { + continue; + } + if (conflictPair->eventIdNode->eventIdNum < + curConflictPair->eventIdNode->eventIdNum) { + continue; + } + assert(conflictPair->eventIdNode != nullptr); + assert(curConflictPair->eventIdNode != nullptr); + if (conflictPair->eventIdNode->eventIdNum > + curConflictPair->eventIdNode->eventIdNum) { + if (conflictPair->eventIdNode->eventIdNum % + curConflictPair->eventIdNode->eventIdNum) { + continue; + } + } + assert(conflictPair->startIndex <= curConflictPair->endIndex); + assert(curConflictPair->endIndex <= conflictPair->endIndex); + if (ret == nullptr || reuseCmp(ret, curConflictPair)) { + ret = curConflictPair; + } + } + return ret; +} + +bool Solver::reuseConflictPair(ConflictPair *conflictPair, + Occurrence *scopeOcc1, Occurrence *scopeOcc2) { + if (conflictPair->isBarrier()) { + 
return false; + } + if (scopeOcc1->op != scopeOcc2->op) { + return false; + } + if (!barrierAllPairs.empty()) { + return false; + } + + /* A useless pair can only be reused when replacedWithReusableSyncedPairs already holds an entry for its key. */ + ConflictPair *oldReusedConflictPair = nullptr; + if (conflictPair->isUseless) { + auto it = replacedWithReusableSyncedPairs.find( + {conflictPair->backwardSyncLoopOp, conflictPair->op1, conflictPair->op2, + conflictPair->setCorePipeInfo, conflictPair->waitCorePipeInfo}); + if (it != replacedWithReusableSyncedPairs.end()) { + oldReusedConflictPair = it->second; + } + } + +#ifndef NDEBUG + if (!conflictPair->isUseless) { + auto it = replacedWithReusableSyncedPairs.find( + {conflictPair->backwardSyncLoopOp, conflictPair->op1, conflictPair->op2, + conflictPair->setCorePipeInfo, conflictPair->waitCorePipeInfo}); + assert(it == replacedWithReusableSyncedPairs.end()); + } +#endif + + if (conflictPair->isUseless && oldReusedConflictPair == nullptr) { + return false; + } + + auto corePipeSrc = conflictPair->setCorePipeInfo; + auto corePipeDst = conflictPair->waitCorePipeInfo; + + /* Without a previous reuse record, the pipe pair needs remaining reuse budget (reusePairs > reusedPairs). */ + if (oldReusedConflictPair == nullptr) { + if (!reusePairs.contains({corePipeSrc, corePipeDst}) || + reusePairs[{corePipeSrc, corePipeDst}] <= + reusedPairs[{corePipeSrc, corePipeDst}]) { + return false; + } + } + + assert(reusePairs.contains(std::make_tuple(corePipeSrc, corePipeDst))); + assert(reusePairs[std::make_tuple(corePipeSrc, corePipeDst)] >= + reusedPairs[std::make_tuple(corePipeSrc, corePipeDst)]); + + /* One candidate per lookup table: both scope occurrences, the occurrence pair, and both persistent per-occurrence tables. */ + ConflictPair *opt1 = nullptr; + ConflictPair *opt2 = nullptr; + ConflictPair *opt3 = nullptr; + ConflictPair *opt4 = nullptr; + ConflictPair *opt5 = nullptr; + + auto it1 = scopeOccChosenConflicts.find(scopeOcc1); + auto it2 = scopeOccChosenConflicts.find(scopeOcc2); + auto it3 = scopeOccPairChosenConflicts.find({scopeOcc1, scopeOcc2}); + auto it4 = persistentScopeOccChosenConflicts.find(scopeOcc1); + auto it5 = persistentScopeOccChosenConflicts.find(scopeOcc2); + + if (it1 != scopeOccChosenConflicts.end()) { + opt1 = getReusableConflictPair(conflictPair, 
it1->second); + } + if (it2 != scopeOccChosenConflicts.end()) { + opt2 = getReusableConflictPair(conflictPair, it2->second); + } + if (it3 != scopeOccPairChosenConflicts.end()) { + opt3 = getReusableConflictPair(conflictPair, it3->second); + } + if (it4 != persistentScopeOccChosenConflicts.end()) { + opt4 = getReusableConflictPair(conflictPair, it4->second); + } + if (it5 != persistentScopeOccChosenConflicts.end()) { + opt5 = getReusableConflictPair(conflictPair, it5->second); + } + + /* Keep the candidate that reuseCmp ranks last among the non-null ones. */ + ConflictPair *reusableConflictPair = nullptr; + for (auto *opt : {opt1, opt2, opt3, opt4, opt5}) { + if (opt != nullptr) { + if (reusableConflictPair == nullptr || + reuseCmp(reusableConflictPair, opt)) { + reusableConflictPair = opt; + } + } + } + + if (reusableConflictPair == nullptr) { + return false; + } + + DEBUG_WITH_TYPE("gss-sync-solver-reuse", { + llvm::dbgs() << "reuse: " << conflictPair->str() << '\n'; + llvm::dbgs() << "with: " << reusableConflictPair->str() << '\n'; + }); + + assert(reusableConflictPair->startIndex < conflictPair->startIndex); + assert(reusableConflictPair->endIndex <= conflictPair->endIndex); + /* Move the reused pair's set side to the new pair's set position. */ + reusableConflictPair->setOp = conflictPair->setOp; + reusableConflictPair->setOcc = conflictPair->setOcc; + reusableConflictPair->startIndex = conflictPair->startIndex; + + if (!conflictPair->isUseless) { + memorizeReusedSyncedPair(conflictPair, reusableConflictPair); + } + + DEBUG_WITH_TYPE("gss-sync-solver-reuse", { + if (oldReusedConflictPair != nullptr) { + llvm::dbgs() << "old-reuse: " << oldReusedConflictPair->str() << '\n'; + } + }); + + if (oldReusedConflictPair != nullptr) { + assert(oldReusedConflictPair->op1 == reusableConflictPair->op1); + assert(oldReusedConflictPair->op2 == reusableConflictPair->op2); + assert(oldReusedConflictPair->waitOp == reusableConflictPair->waitOp); + } + + /* Only non-useless pairs consume reuse budget. */ + if (!conflictPair->isUseless) { + reusedPairs[{corePipeSrc, corePipeDst}] += 1; + } + + return true; +} + +/* Returns the event-id solver stored for (pipeSrc, pipeDst), creating it on first use. In cross-core mode both pipes are replaced by PIPE_UNASSIGNED, so all pipe pairs share one entry. A new solver starts with the count from getHWAvailableEventIdNum; when options.eventIdNumMax is set, that count is capped by it but kept at 1 or more. */ +std::unique_ptr & +Solver::getEventIdSolverRef(pto::PIPE pipeSrc, 
pto::PIPE pipeDst) { + if (options.isCrossCoreMode()) { + pipeSrc = pto::PIPE::PIPE_UNASSIGNED; + pipeDst = pto::PIPE::PIPE_UNASSIGNED; + } + auto key = std::make_tuple(pipeSrc, pipeDst); + if (!eventIdSolver.contains(key)) { + int64_t eventIdNumMax = + getHWAvailableEventIdNum(options.syncMode, pipeSrc, pipeDst); + if (options.eventIdNumMax.has_value()) { + eventIdNumMax = std::min(eventIdNumMax, options.eventIdNumMax.value()); + eventIdNumMax = std::max(eventIdNumMax, 1); + } + eventIdSolver[key] = std::make_unique(eventIdNumMax); + } + return eventIdSolver[key]; +} + +/* Returns true when conflictPair satisfies all conditions checked here: options.useDifferentMultiBufferFlagIds is off, the pair is an inner backward sync with eventIdNum greater than 1 that was not moved to an outer loop, and every parent occurrence found through getParentOfType for the set and wait side is a proper ancestor of the backward-sync loop occurrence. */ +bool Solver::checkReuseMultiBufferFlagId(ConflictPair *conflictPair) { + if (options.useDifferentMultiBufferFlagIds) { + return false; + } + if (!conflictPair->isInnerBackward || + conflictPair->eventIdInfo.eventIdNum <= 1 || + conflictPair->movedToOuterLoop) { + return false; + } + auto [setOcc, waitOcc] = + std::tie(conflictPair->setOcc, conflictPair->waitOcc); + auto *backwardSyncLoopOcc = conflictPair->backwardSyncLoopOcc; + assert(backwardSyncLoopOcc != nullptr); + if (auto *parCondOcc1 = setOcc->getParentOfType()) { + if (!parCondOcc1->isProperAncestor(backwardSyncLoopOcc)) { + return false; + } + } + if (auto *parCondOcc2 = waitOcc->getParentOfType()) { + if (!parCondOcc2->isProperAncestor(backwardSyncLoopOcc)) { + return false; + } + } + return true; +} + +/* Records a set/wait synchronization between occ1 and occ2, which must be on two different core-pipes. Builds a ConflictPair, attaches an existing or new event-id node, optionally reuses an already chosen pair, and converts the conflict into a barrier-all when the event-id graph is no longer colorable. For inner backward pairs it also adds useless placeholder pairs that reserve the event ids, then files everything in the chosen-conflict tables. */ +void Solver::handleSetWaitConflict(Occurrence *occ1, Occurrence *occ2, + CorePipeInfo corePipeSrc, + CorePipeInfo corePipeDst, + EventIdInfo eventIdInfo, bool isUseless) { + assert(occ1 != nullptr && occ2 != nullptr); + auto *rwOp1 = llvm::dyn_cast_if_present(occ1->op); + auto *rwOp2 = llvm::dyn_cast_if_present(occ2->op); + assert(rwOp1 != nullptr && rwOp2 != nullptr); + assert(corePipeSrc != corePipeDst); + + Loop *parentLCALoopOp{nullptr}; + Occurrence *parentLCALoopOcc{nullptr}; + Occurrence *parentLCALoopBeforePHOcc{nullptr}; + Occurrence *parentLCALoopAfterPHOcc{nullptr}; + auto [setOcc, waitOcc] = getSetWaitOcc(occ1, occ2); + + auto 
[lcaSetOp, lcaWaitOp] = + OperationBase::getLCAPair(setOcc->op, waitOcc->op); + auto *normScopeOcc1 = setOcc->getParentWithOp(lcaSetOp->parentOp); + auto *normScopeOcc2 = waitOcc->getParentWithOp(lcaWaitOp->parentOp); + assert(normScopeOcc1->op == normScopeOcc2->op); + auto *normScopeOp = normScopeOcc1->op; + assert(normScopeOp != nullptr); + assert(normScopeOp->parentOp != nullptr); + + auto conflictPair = std::make_unique( + rwOp1, rwOp2, setOcc->op, waitOcc->op, setOcc, waitOcc, corePipeSrc, + corePipeDst, setOcc->endIndex, waitOcc->startIndex); + assert(conflictPair->startIndex <= conflictPair->endIndex); + + conflictPair->isUseless = isUseless; + conflictPair->isInnerBackward = isBackwardSync(setOcc, waitOcc); + conflictPair->eventIdInfo = eventIdInfo; + + if (conflictPair->isInnerBackward) { + auto [parOcc1, parOcc2] = Occurrence::getLCAPair(occ1, occ2); + assert(parOcc1 != nullptr && parOcc2 != nullptr); + + parentLCALoopOcc = parOcc1->getParentOfType(); + /* Checked here because the hoisting loop below dereferences it; the loop itself only ever assigns non-null parents. */ + assert(parentLCALoopOcc != nullptr); + if (moveBackwardSyncPairsToOutmostLoop) { + while (auto *grandParentLoopOcc = + parentLCALoopOcc->getParentOfType()) { + conflictPair->movedToOuterLoop = true; + parentLCALoopOcc = grandParentLoopOcc; + } + } + conflictPair->backwardSyncLoopOcc = parentLCALoopOcc; + + parentLCALoopOp = llvm::dyn_cast(parentLCALoopOcc->op); + assert(parentLCALoopOp != nullptr); + conflictPair->backwardSyncLoopOp = parentLCALoopOp; + + parentLCALoopBeforePHOcc = getBeforePlaceHolderOcc(parentLCALoopOcc); + assert(parentLCALoopBeforePHOcc != nullptr); + parentLCALoopAfterPHOcc = getAfterPlaceHolderOcc(parentLCALoopOcc); + assert(parentLCALoopAfterPHOcc != nullptr); + } + + if (auto setWaitOccs = checkAndApplyMmadl0LoopOpt(conflictPair.get(), occ1, + occ2, setOcc, waitOcc)) { + std::tie(setOcc, waitOcc) = setWaitOccs.value(); + conflictPair->updateSetWaitOccs(setOcc, waitOcc); + } + + if (!conflictPair->isInnerBackward || + disabledMultiEventIdPairs.contains({corePipeSrc, corePipeDst})) { 
+ conflictPair->eventIdInfo = EventIdInfo(1); + } + /* When allowed, use one event id and record the former count as its repeat number. */ + if (checkReuseMultiBufferFlagId(conflictPair.get())) { + conflictPair->eventIdInfo.eventIdRepeatNum = + conflictPair->eventIdInfo.eventIdNum; + conflictPair->eventIdInfo.eventIdNum = 1; + } + + auto &curEventIdSolver = getEventIdSolverRef( + conflictPair->setCorePipeInfo.pipe, conflictPair->waitCorePipeInfo.pipe); + curEventIdSolver->pushActionNone(); + + /* Returns true while the event-id graph stays colorable. Otherwise inserts a barrier-all before occ2->op, records the pipe pair in barrierAllPairs, rolls back the solver actions and returns false. */ + auto checkColorable = [&]() -> bool { + if (curEventIdSolver->isColorable()) { + return true; + } + LLVM_DEBUG(llvm::dbgs() << "will-be-converted-to-barrier-all " + << conflictPair->str() << '\n';); + insertBarrierAllBeforeOp(occ2->op, conflictPair->isUseless, + /*isPersistent=*/false); + barrierAllPairs.insert({corePipeSrc, corePipeDst}); + curEventIdSolver->undoActions(); + return false; + }; + + if (auto *oldEventIdNode = getOldEventIdNodeIfExists(conflictPair.get())) { + conflictPair->eventIdNode = oldEventIdNode; + curEventIdSolver->insertConflictPair(oldEventIdNode, conflictPair.get()); + } else { + bool reversedPriority = false; + if (conflictPair->isInnerBackward) { + if (OperationBase::getParentloop(occ1->op) == normScopeOp->parentOp && + OperationBase::getParentloop(occ2->op) == normScopeOp->parentOp) { + reversedPriority = true; + } + } + conflictPair->eventIdNode = curEventIdSolver->createNode( + conflictPair.get(), conflictPair->eventIdInfo.eventIdNum, + reversedPriority); + } + + /* A successful reuse makes this pair redundant, so the solver actions taken for it are rolled back. */ + if (options.reuseSyncPairToSaveEventIds) { + if (reuseConflictPair(conflictPair.get(), normScopeOcc1, normScopeOcc2)) { + curEventIdSolver->undoActions(); + return; + } + } + + auto intersectingConflictPairs = + getIntersectingConflictPairs(conflictPair.get()); + curEventIdSolver->addConflicts(conflictPair.get(), intersectingConflictPairs); + if (!checkColorable()) { + return; + } + + LLVM_DEBUG({ + llvm::dbgs() << conflictPair->str() << '\n'; + if (parentLCALoopOcc != nullptr) { + llvm::dbgs() << parentLCALoopOcc->op->str(0, false) << '\n'; + } + }); + + llvm::SmallVector, 
Occurrence *>> + extraConflictPairs; + + /* Clones conflictPair onto (setOcc, waitOcc) as a useless, non-reusable pair on the same event-id node and queues it together with parentScope. Returns false when the graph stops being colorable. */ + auto insertExtraConflictPair = [&](Occurrence *setOcc, Occurrence *waitOcc, + Occurrence *parentScope, + bool couldNotRun = false) -> bool { + assert(setOcc != nullptr && waitOcc != nullptr && parentScope != nullptr); + auto extraConflictPair = conflictPair->clone(setOcc, waitOcc); + extraConflictPair->isUseless = true; + extraConflictPair->dontReuse = true; + if (couldNotRun || options.moveOutAndMergeBackwardSyncPairs) { + extraConflictPair->couldNotRun = true; + } + LLVM_DEBUG({ + llvm::dbgs() << "extra-conflict-pair: " << extraConflictPair->str() + << "\n"; + }); + curEventIdSolver->insertConflictPair(conflictPair->eventIdNode, + extraConflictPair.get()); + auto intersectingConflictPairs = + getIntersectingConflictPairs(extraConflictPair.get()); + curEventIdSolver->addConflicts(extraConflictPair.get(), + intersectingConflictPairs); + if (!checkColorable()) { + return false; + } + extraConflictPairs.push_back( + std::make_pair(std::move(extraConflictPair), parentScope)); + return true; + }; + + if (conflictPair->isInnerBackward && conflictPair->eventIdNode != nullptr) { + bool insertOuterBwdConflictPair = false; + if ((conflictPair->eventIdInfo.eventIdNum * + conflictPair->eventIdInfo.eventIdRepeatNum) > 1) { + insertOuterBwdConflictPair = true; + } else if (options.isCrossCoreMode()) { + if (setOcc->parentOcc == nullptr || + setOcc->parentOcc->parentOcc == nullptr || + setOcc->parentOcc->parentOcc->op != parentLCALoopOp) { + insertOuterBwdConflictPair = true; + } else if (waitOcc->parentOcc == nullptr || + waitOcc->parentOcc->parentOcc == nullptr || + waitOcc->parentOcc->parentOcc->op != parentLCALoopOp) { + insertOuterBwdConflictPair = true; + } + } + if (insertOuterBwdConflictPair) { + // insert useless conflictPair to cover the whole loop when having + // multi-eventid backward sync to reserve the eventIds. 
+ if (!insertExtraConflictPair(parentLCALoopBeforePHOcc, + parentLCALoopAfterPHOcc, + parentLCALoopOcc->parentOcc)) { + return; + } + } + } + + if (conflictPair->isInnerBackward && conflictPair->eventIdNode != nullptr) { + // insert header/footer useless conflictPairs to reserve the eventIds. + auto *loopOpOcc1 = getFirstIterOcc(waitOcc, normScopeOcc1); + auto *loopOpOcc2 = getLastIterOcc(setOcc, normScopeOcc2); + if (!insertExtraConflictPair(parentLCALoopBeforePHOcc, loopOpOcc1, + parentLCALoopOcc, /*couldNotRun=*/true)) { + return; + } + if (!insertExtraConflictPair(loopOpOcc2, parentLCALoopAfterPHOcc, + parentLCALoopOcc, /*couldNotRun=*/true)) { + return; + } + } + + /* An inner backward pair whose set op has a parent condition nested inside the backward-sync loop is filed per occurrence pair instead of per scope occurrence. */ + bool dontInsert = false; + if (conflictPair->isInnerBackward && normScopeOcc1 != normScopeOcc2) { + auto *parCond = OperationBase::getParentCondition(conflictPair->setOp); + if (auto *conditionOp = llvm::dyn_cast_if_present(parCond)) { + if (parentLCALoopOcc->op->isProperAncestor(conditionOp)) { + scopeOccPairChosenConflicts[{normScopeOcc1, normScopeOcc2}].insert( + conflictPair.get()); + dontInsert = true; + } + } + } + if (!dontInsert) { + assert(parentLCALoopOcc != nullptr || normScopeOcc1 == normScopeOcc2); + scopeOccChosenConflicts[normScopeOcc1].insert(conflictPair.get()); + scopeOccChosenConflicts[normScopeOcc2].insert(conflictPair.get()); + } + + memorizeSyncedPair(conflictPair.get()); + chosenConflictedPairs.push_back(std::move(conflictPair)); + + for (auto &[extraConflictPair, parentScope] : extraConflictPairs) { + scopeOccChosenConflicts[parentScope].insert(extraConflictPair.get()); + chosenConflictedPairs.push_back(std::move(extraConflictPair)); + } + + curEventIdSolver->clearActionStack(); +} + +/* Records a barrier for a conflict inside one core-pipe (corePipeSrc == corePipeDst). Nothing is recorded for PIPE_S, nor for PIPE_V or PIPE_M when options.isRegBasedArch is set. The pair is placed at the occurrence returned by getBarrierWaitOcc and filed under that occurrence's parent. */ +void Solver::handleBarrierConflict(Occurrence *occ1, Occurrence *occ2, + CorePipeInfo corePipeSrc, + CorePipeInfo corePipeDst, bool isUseless) { + assert(occ1 != nullptr && occ2 != nullptr); + auto *rwOp1 = llvm::dyn_cast_if_present(occ1->op); + auto *rwOp2 = llvm::dyn_cast_if_present(occ2->op); 
+ assert(rwOp1 != nullptr && rwOp2 != nullptr); + + assert(corePipeSrc == corePipeDst); + if (corePipeSrc.pipe == pto::PIPE::PIPE_S) { + return; + } + if (options.isRegBasedArch) { + if (corePipeSrc.pipe == pto::PIPE::PIPE_V || + corePipeSrc.pipe == pto::PIPE::PIPE_M) { + return; + } + } + auto *waitOcc = getBarrierWaitOcc(occ1, occ2); + + /* A barrier has no separate set side: set and wait use the same op and occurrence, and start and end index are both waitOcc->startIndex. */ + auto conflictPair = std::make_unique( + rwOp1, rwOp2, waitOcc->op, waitOcc->op, waitOcc, waitOcc, corePipeSrc, + corePipeDst, waitOcc->startIndex, waitOcc->startIndex); + conflictPair->isUseless = isUseless; + assert(conflictPair->startIndex <= conflictPair->endIndex); + + LLVM_DEBUG({ llvm::dbgs() << conflictPair->str() << '\n'; }); + + auto *normScopeOcc = waitOcc->parentOcc; + scopeOccChosenConflicts[normScopeOcc].insert(conflictPair.get()); + chosenConflictedPairs.push_back(std::move(conflictPair)); +} + +/* Records a conflict between two different core-pipes that is handled through unit flags. unitFlagInfo is merged into both occurrences (occ1 as set, occ2 as wait) and, unless isUseless, into both RW operations. The stored pair is always marked useless, dontReuse, replacedWithUnitFlag and dontCheckForConflict, and is filed per occurrence pair of the two parent occurrences. */ +void Solver::handleUnitFlagConflict(Occurrence *occ1, Occurrence *occ2, + CorePipeInfo corePipeSrc, + CorePipeInfo corePipeDst, + UnitFlagInfo unitFlagInfo, bool isUseless) { + assert(occ1 != nullptr && occ2 != nullptr); + auto *rwOp1 = llvm::dyn_cast_if_present(occ1->op); + auto *rwOp2 = llvm::dyn_cast_if_present(occ2->op); + assert(rwOp1 != nullptr && rwOp2 != nullptr); + assert(corePipeSrc != corePipeDst); + + auto *setOcc = occ1; + auto *waitOcc = occ2; + auto *normScopeOcc1 = setOcc->parentOcc; + auto *normScopeOcc2 = waitOcc->parentOcc; + + auto conflictPair = std::make_unique( + rwOp1, rwOp2, setOcc->op, waitOcc->op, setOcc, waitOcc, corePipeSrc, + corePipeDst, setOcc->endIndex, waitOcc->startIndex); + assert(conflictPair->startIndex <= conflictPair->endIndex); + + conflictPair->isUseless = true; + conflictPair->dontReuse = true; + conflictPair->replacedWithUnitFlag = true; + conflictPair->dontCheckForConflict = true; + conflictPair->isInnerBackward = isBackwardSync(setOcc, waitOcc); + + /* Debug-only block: looks up the enclosing loop of a backward pair for the debug output. */ +#ifndef NDEBUG + Occurrence *parentLCALoopOcc{nullptr}; + if (conflictPair->isInnerBackward) { + auto [parOcc1, parOcc2] = 
Occurrence::getLCAPair(occ1, occ2); + assert(parOcc1 != nullptr && parOcc2 != nullptr); + parentLCALoopOcc = Occurrence::getParentloop(parOcc1); + assert(parentLCALoopOcc != nullptr); + } + + LLVM_DEBUG({ + llvm::dbgs() << conflictPair->str() << '\n'; + if (parentLCALoopOcc != nullptr) { + llvm::dbgs() << parentLCALoopOcc->op->str(0, false) << '\n'; + } + }); +#endif + + /* occ1 takes the set role and occ2 the wait role. */ + occ1->unitFlagInfo.merge(unitFlagInfo, occ1, occ2, + /*asSet=*/true, /*asWait=*/false); + occ2->unitFlagInfo.merge(unitFlagInfo, occ1, occ2, + /*asSet=*/false, /*asWait=*/true); + if (!isUseless) { + rwOp1->mergedUnitFlagInfo.merge(unitFlagInfo, /*asSet=*/true, + /*asWait=*/false); + rwOp2->mergedUnitFlagInfo.merge(unitFlagInfo, /*asSet=*/false, + /*asWait=*/true); + } + + scopeOccPairChosenConflicts[{normScopeOcc1, normScopeOcc2}].insert( + conflictPair.get()); + chosenConflictedPairs.push_back(std::move(conflictPair)); +} + +/* Dispatches one candidate conflict once checkGraphConflict confirms it: equal core-pipes go to handleBarrierConflict, a match from checkUnitFlagPatterns goes to handleUnitFlagConflict, and everything else goes to handleSetWaitConflict. Returns without effect when checkGraphConflict reports no conflict. */ +void Solver::handleConflict(Occurrence *occ1, Occurrence *occ2, + RWOperation *rwOp1, RWOperation *rwOp2, + CorePipeInfo corePipeSrc, CorePipeInfo corePipeDst, + EventIdInfo eventIdInfo, bool isUseless) { + if (!checkGraphConflict(occ1, occ2, corePipeSrc, corePipeDst, eventIdInfo)) { + return; + } + LLVM_DEBUG({ + llvm::dbgs() << "conflict found: " << "eventIdNum(" + << eventIdInfo.eventIdNum << ")\n"; + llvm::dbgs() << occ1->syncIrIndex << ' ' << occ1->startIndex << ' ' + << occ1->endIndex << ' ' << rwOp1->str(0, false) << '\n'; + llvm::dbgs() << occ2->syncIrIndex << ' ' << occ2->startIndex << ' ' + << occ2->endIndex << ' ' << rwOp2->str(0, false) << '\n'; + }); + if (corePipeSrc == corePipeDst) { + handleBarrierConflict(occ1, occ2, corePipeSrc, corePipeDst, isUseless); + } else if (auto unitFlagInfo = checkUnitFlagPatterns(occ1, occ2)) { + handleUnitFlagConflict(occ1, occ2, corePipeSrc, corePipeDst, + unitFlagInfo.value(), isUseless); + } else { + handleSetWaitConflict(occ1, occ2, corePipeSrc, corePipeDst, eventIdInfo, + isUseless); + } +} + +/* Calls shrinkEventIdMaxToEventIdNum on every stored event-id solver and asserts that the call succeeded and that the solver is still colorable. */ +void Solver::calcAllEventIds() 
{ + /* The structured binding shadows the member of the same name; the range expression still refers to the member. */ + for (auto &[pipes, eventIdSolver] : eventIdSolver) { + assert(eventIdSolver != nullptr); + + [[maybe_unused]] auto result = + eventIdSolver->shrinkEventIdMaxToEventIdNum(); + assert(llvm::succeeded(result)); + assert(eventIdSolver->isColorable()); + } +} + +/* For every chosen pair that is not useless, is an inner backward sync and has an event-id node, records each of its event ids in backwardSyncEvents under the pair's backward-sync loop op and core-pipe pair, keeping the largest eventIdRepeatNum seen per event id. */ +void Solver::collectBackwardSyncEventIds() { + LLVM_DEBUG(llvm::dbgs() << "collectBackwardSyncEventIds\n";); + for (auto &conflictPair : chosenConflictedPairs) { + if (!conflictPair->isUseless && conflictPair->isInnerBackward && + conflictPair->eventIdNode != nullptr) { + LLVM_DEBUG(llvm::dbgs() << " " << conflictPair->str() << "\n";); + for (auto eventId : conflictPair->eventIdNode->getEventIds()) { + auto &e = backwardSyncEvents[conflictPair->backwardSyncLoopOp] + [{conflictPair->setCorePipeInfo, + conflictPair->waitCorePipeInfo}][eventId]; + e = std::max(e, conflictPair->eventIdInfo.eventIdRepeatNum); + } + } + } +} + +/* Resets globalSetWaitIndex, clears all set/wait index tables and rebuilds them by walking funcIr with the two sync maps. */ +void Solver::resetAndBuildSetWaitOpIndex(const SyncMap &syncMapBefore, + const SyncMap &syncMapAfter) { + globalSetWaitIndex = 0; + setWaitStartIndex.clear(); + setWaitEndIndex.clear(); + setWaitStartIndexInclusive.clear(); + setWaitEndIndexInclusive.clear(); + setWaitFlagOpsIndex.clear(); + collectSetWaitOpsIndexes(funcIr.get(), syncMapBefore, syncMapAfter); +} + +/* Returns a reference to the index set kept in setWaitFlagOpsIndex for (pipeSrc, pipeDst, eventId). The lookup goes through operator[]. */ +std::set> & +Solver::getSetWaitOpsIndexRef(pto::PIPE pipeSrc, pto::PIPE pipeDst, + int64_t eventId) { + auto key = std::make_tuple(pipeSrc, pipeDst, eventId); + return setWaitFlagOpsIndex[key]; +} + +// Collect indices for all Set/Wait ops to facilitate merging decisions. 
+/* Walks op and, for scope ops, its body in order while handing out increasing globalSetWaitIndex values: one on entry (setWaitStartIndexInclusive), one per event id of every set/wait op attached before op, one when the body starts (setWaitStartIndex), one when it ends (setWaitEndIndex), one per event id of every set/wait op attached after op, and one on exit (setWaitEndIndexInclusive). */ +void Solver::collectSetWaitOpsIndexes(OperationBase *op, + const SyncMap &syncMapBefore, + const SyncMap &syncMapAfter) { + assert(op != nullptr); + setWaitStartIndexInclusive[op] = globalSetWaitIndex++; + /* One find() replaces the former count() + find() double lookup. */ + if (auto *it = syncMapBefore.find(op); it != syncMapBefore.end()) { + for (auto &syncOp : it->second) { + if (auto *setWaitOp = llvm::dyn_cast(syncOp.get())) { + for (auto eventId : setWaitOp->eventIds) { + auto &index = getSetWaitOpsIndexRef(setWaitOp->pipeSrc, + setWaitOp->pipeDst, eventId); + index.insert({globalSetWaitIndex++, setWaitOp}); + } + } + } + } + setWaitStartIndex[op] = globalSetWaitIndex++; + if (auto *scopeOp = llvm::dyn_cast(op)) { + for (auto &childOp : scopeOp->body) { + collectSetWaitOpsIndexes(childOp.get(), syncMapBefore, syncMapAfter); + } + } + setWaitEndIndex[op] = globalSetWaitIndex++; + if (auto *it = syncMapAfter.find(op); it != syncMapAfter.end()) { + for (auto &syncOp : it->second) { + if (auto *setWaitOp = llvm::dyn_cast(syncOp.get())) { + for (auto eventId : setWaitOp->eventIds) { + auto &index = getSetWaitOpsIndexRef(setWaitOp->pipeSrc, + setWaitOp->pipeDst, eventId); + index.insert({globalSetWaitIndex++, setWaitOp}); + } + } + } + } + setWaitEndIndexInclusive[op] = globalSetWaitIndex++; +} + +/* Returns true when backwardSyncEvents holds eventId for op under the (corePipeSrc, corePipeDst) key. Both levels are searched with find(). */ +bool Solver::checkBackwardSyncEventsContains(OperationBase *op, + CorePipeInfo corePipeSrc, + CorePipeInfo corePipeDst, + int64_t eventId) { + auto *it1 = backwardSyncEvents.find(op); + if (it1 == backwardSyncEvents.end()) { + return false; + } + auto it2 = it1->second.find({corePipeSrc, corePipeDst}); + if (it2 == it1->second.end()) { + return false; + } + return it2->second.contains(eventId); +} + +/* Returns true when backwardSyncEventsAfterMerge lists the (corePipeSrc, corePipeDst) pair for op. */ +bool Solver::checkBackwardSyncEventsContainsAfterMerge( + OperationBase *op, CorePipeInfo corePipeSrc, CorePipeInfo corePipeDst) { + auto *it1 = backwardSyncEventsAfterMerge.find(op); + if (it1 == backwardSyncEventsAfterMerge.end()) { + return false; + } + return 
it1->second.contains({corePipeSrc, corePipeDst}); +} + +// Check whether a backward-sync event id can be merged at scope level. +/* With shouldBeUsedAtleastOnce set, the event id must also appear at least once inside scopeOp's inclusive index range. */ +bool Solver::checkMergeable(Scope *scopeOp, CorePipeInfo corePipeSrc, + CorePipeInfo corePipeDst, int64_t eventId, + bool shouldBeUsedAtleastOnce) { + auto &index = + getSetWaitOpsIndexRef(corePipeSrc.pipe, corePipeDst.pipe, eventId); + if (shouldBeUsedAtleastOnce) { + auto it = index.lower_bound({setWaitStartIndexInclusive[scopeOp], nullptr}); + bool usedAtleastOnce = + it != index.end() && it->first < setWaitEndIndexInclusive[scopeOp]; + if (!usedAtleastOnce) { + return false; + } + } + /* Reject when the event id is used by a set/wait op attached directly before or after scopeOp itself, i.e. between its inclusive and non-inclusive indices. */ + { + auto it1 = + index.lower_bound({setWaitStartIndexInclusive[scopeOp], nullptr}); + auto it2 = index.lower_bound({setWaitEndIndex[scopeOp], nullptr}); + bool usedBefore = + it1 != index.end() && it1->first < setWaitStartIndex[scopeOp]; + bool usedAfter = + it2 != index.end() && it2->first < setWaitEndIndexInclusive[scopeOp]; + if (usedBefore || usedAfter) { + return false; + } + } + /* A false scope must exist, and both the true and the false scope must be mergeable and use the event id. */ + if (auto *conditionOp = llvm::dyn_cast(scopeOp)) { + if (!conditionOp->hasFalseScope()) { + return false; + } + return checkMergeable(conditionOp->getTrueScope(), corePipeSrc, corePipeDst, + eventId, true) && + checkMergeable(conditionOp->getFalseScope(), corePipeSrc, + corePipeDst, eventId, true); + } + /* First pass: every child scope must be mergeable. Second pass: at least one child scope must also use the event id. */ + if (auto *loopOp = llvm::dyn_cast(scopeOp)) { + for (auto &childOp : loopOp->body) { + if (auto *childScopeOp = llvm::dyn_cast(childOp.get())) { + if (!checkMergeable(childScopeOp, corePipeSrc, corePipeDst, eventId, + false)) { + return false; + } + } + } + for (auto &childOp : loopOp->body) { + if (auto *childScopeOp = llvm::dyn_cast(childOp.get())) { + if (checkMergeable(childScopeOp, corePipeSrc, corePipeDst, eventId, + true)) { + return true; + } + } + } + return false; + } + for (auto &childOp : scopeOp->body) { + auto it1 = + index.lower_bound({setWaitStartIndexInclusive[childOp.get()], nullptr}); + auto it2 = index.lower_bound({setWaitEndIndex[childOp.get()], 
nullptr}); + bool usedAtleastOnce = it1 != index.end() && + it1->first < setWaitEndIndexInclusive[childOp.get()]; + if (!usedAtleastOnce) { + continue; + } + bool before = + it1 != index.end() && it1->first < setWaitStartIndex[childOp.get()]; + bool after = it2 != index.end() && + it2->first < setWaitEndIndexInclusive[childOp.get()]; + if (before || after) { + return false; + } + /* A child that uses the event id must already own it as a backward-sync event and must not be listed in backwardSyncEventsAfterMerge for this pipe pair. */ + if (!checkBackwardSyncEventsContains(childOp.get(), corePipeSrc, + corePipeDst, eventId)) { + return false; + } + if (checkBackwardSyncEventsContainsAfterMerge(childOp.get(), corePipeSrc, + corePipeDst)) { + return false; + } + } + return true; +} + +// Attempt to merge backward sync events across children and prune duplicates. +void Solver::mergeBackwardSyncEventIds(OperationBase *op) { + auto *scopeOp = llvm::dyn_cast_if_present(op); + if (scopeOp == nullptr) { + return; + } + /* Children are processed first; the loop variable shadows the parameter only inside this loop. */ + for (auto &op : scopeOp->body) { + mergeBackwardSyncEventIds(op.get()); + } + + if (llvm::isa_and_present(op)) { + return; + } + if (llvm::isa_and_present(op->parentOp)) { + return; + } + + auto *conditionOp = llvm::dyn_cast(op); + if (conditionOp != nullptr) { + if (!conditionOp->hasFalseScope()) { + return; + } + } + + llvm::DenseSet> toBeErased; + + llvm::SmallVector coreTypes; + if (options.isCrossCoreMode()) { + coreTypes = {pto::TCoreType::VECTOR, pto::TCoreType::CUBE}; + } else { + coreTypes = {pto::TCoreType::CUBE_OR_VECTOR}; + } + size_t pipeNumMax = static_cast(pto::PIPE::PIPE_NUM); + const int64_t eventIdMax = getHWAvailableEventIdNum(options.syncMode); + + /* Try every combination of event id, core pair and pipe pair that scopeOp does not own yet. */ + for (int64_t eventId = 0; eventId < eventIdMax; ++eventId) { + for (auto coreSrc : coreTypes) { + for (auto coreDst : coreTypes) { + for (size_t pipeSrcInt = 0; pipeSrcInt < pipeNumMax; pipeSrcInt++) { + for (size_t pipeDstInt = 0; pipeDstInt < pipeNumMax; pipeDstInt++) { + auto pipeSrc = static_cast(pipeSrcInt); + auto pipeDst = static_cast(pipeDstInt); + auto corePipeSrc = CorePipeInfo(coreSrc, pipeSrc); + auto corePipeDst = CorePipeInfo(coreDst, 
pipeDst); + if (checkBackwardSyncEventsContains(scopeOp, corePipeSrc, + corePipeDst, eventId)) { + continue; + } + /* A mergeable event id is registered on scopeOp with repeat count 1 and remembered for removal from the children. */ + if (checkMergeable(scopeOp, corePipeSrc, corePipeDst, eventId)) { + toBeErased.insert({corePipeSrc, corePipeDst, eventId}); + backwardSyncEvents[scopeOp][{corePipeSrc, corePipeDst}].insert( + {eventId, 1}); + } + } + } + } + } + } + + /* Remove the merged event ids from the child scopes. In the first branch the child scopes are reached through one extra nesting level (block->body); in the second they are direct children. */ + if (isa(scopeOp)) { + for (auto &op : scopeOp->body) { + if (auto *block = llvm::dyn_cast(op.get())) { + for (auto &childOp : block->body) { + if (auto *childScopeOp = llvm::dyn_cast(childOp.get())) { + for (auto [corePipeSrc, corePipeDst, eventId] : toBeErased) { + if (checkBackwardSyncEventsContains(childScopeOp, corePipeSrc, + corePipeDst, eventId)) { + auto key = std::make_tuple(corePipeSrc, corePipeDst); + backwardSyncEvents[childScopeOp][key].erase(eventId); + if (backwardSyncEvents[childScopeOp][key].empty()) { + backwardSyncEvents[childScopeOp].erase(key); + } + } + } + } + } + } + } + } else { + for (auto &childOp : scopeOp->body) { + if (auto *childScopeOp = llvm::dyn_cast(childOp.get())) { + for (auto [corePipeSrc, corePipeDst, eventId] : toBeErased) { + if (checkBackwardSyncEventsContains(childScopeOp, corePipeSrc, + corePipeDst, eventId)) { + auto key = std::make_tuple(corePipeSrc, corePipeDst); + backwardSyncEvents[childScopeOp][key].erase(eventId); + if (backwardSyncEvents[childScopeOp][key].empty()) { + backwardSyncEvents[childScopeOp].erase(key); + } + } + } + } + } + } +} + +/* When options.moveOutAndMergeBackwardSyncPairs is set and the solver runs in intra-core mode, rebuilds the set/wait index from the two sync maps and merges backward-sync event ids starting at the first child of funcIr. Does nothing otherwise. */ +void Solver::mergeBackwardSyncPairs(SyncMap &syncMapBefore, + SyncMap &syncMapAfter) { + if (!options.moveOutAndMergeBackwardSyncPairs) { + return; + } + if (options.isIntraCoreMode()) { + resetAndBuildSetWaitOpIndex(syncMapBefore, syncMapAfter); + auto *scopeOp = llvm::dyn_cast(funcIr.get()); + assert(scopeOp != nullptr && scopeOp->body.front() != nullptr); + mergeBackwardSyncEventIds(scopeOp->body.front().get()); + } +} + +/* Finalizes the event ids and builds the two sync maps holding the ops to insert before and after each operation. Barriers and wait ops go into the "before" map of the wait op, set ops into the "after" map of the set op. Every collected backward-sync event additionally yields an all-at-once set op before its scope op and a matching wait op after it. */ +SyncBeforeAfterMap Solver::getBeforeAfterSyncMaps() { + calcAllEventIds(); + SyncMap syncMapBefore, 
syncMapAfter; + /* Gather the regular and the persistent chosen pairs into one list. */ + std::vector conflictPairs; + for (auto &conflictPair : chosenConflictedPairs) { + conflictPairs.push_back(conflictPair.get()); + } + for (auto &conflictPair : persistentChosenConflictedPairs) { + conflictPairs.push_back(conflictPair.get()); + } + + /* Useless pairs and pairs replaced by unit flags produce no sync ops. */ + for (auto *conflictPair : conflictPairs) { + if (conflictPair->isUseless) { + continue; + } + if (conflictPair->replacedWithUnitFlag) { + continue; + } + assert(conflictPair->setOp != nullptr && conflictPair->waitOp != nullptr); + if (conflictPair->isBarrier()) { + auto barrierOp = std::make_unique( + conflictPair->waitOp->op, conflictPair->waitOp->parentOp, + conflictPair->waitCorePipeInfo.pipe); + LLVM_DEBUG(barrierOp->debugId = conflictPair->id); + syncMapBefore[conflictPair->waitOp].push_back(std::move(barrierOp)); + } else { + assert(conflictPair->eventIdNode != nullptr); + auto setOp = std::make_unique( + conflictPair->setOp->op, conflictPair->setOp->parentOp, + conflictPair->eventIdNode->getEventIds(), + conflictPair->setCorePipeInfo.pipe, + conflictPair->waitCorePipeInfo.pipe); + auto waitOp = std::make_unique( + conflictPair->waitOp->op, conflictPair->waitOp->parentOp, + conflictPair->eventIdNode->getEventIds(), + conflictPair->setCorePipeInfo.pipe, + conflictPair->waitCorePipeInfo.pipe); + /* Core types are only filled in for cross-core mode. */ + if (options.isCrossCoreMode()) { + setOp->coreType = conflictPair->setCorePipeInfo.coreType; + waitOp->coreType = conflictPair->waitCorePipeInfo.coreType; + } + setOp->eventIdInfo = conflictPair->eventIdInfo; + waitOp->eventIdInfo = conflictPair->eventIdInfo; + setOp->checkLastIter = conflictPair->setOnLastIterOnly; + waitOp->checkFirstIter = conflictPair->waitOnFirstIterOnly; + LLVM_DEBUG({ + setOp->debugId = conflictPair->id; + waitOp->debugId = conflictPair->id; + }); + assert(setOp != nullptr && waitOp != nullptr); + /* The set op is appended to the list after its op; the wait op is prepended to the list before its op. */ + syncMapAfter[conflictPair->setOp].push_back(std::move(setOp)); + syncMapBefore[conflictPair->waitOp].push_front(std::move(waitOp)); + } + } + + collectBackwardSyncEventIds(); + 
mergeBackwardSyncPairs(syncMapBefore, syncMapAfter); + + /* Emit one all-at-once set/wait pair per scope op and core-pipe pair that still owns backward-sync event ids. */ + for (auto &[op, mp] : backwardSyncEvents) { + if (mp.empty()) { + continue; + } + auto *scopeOp = llvm::dyn_cast(op); + assert(scopeOp != nullptr); + for (auto [setWaitCorePipes, eventIdsMp] : mp) { + if (eventIdsMp.empty()) { + continue; + } + /* Each event id is listed repeatNum times; the list is sorted afterwards. */ + llvm::SmallVector eventIds; + for (auto [eventId, repeatNum] : eventIdsMp) { + llvm::SmallVector curEventIds(repeatNum, eventId); + llvm::append_range(eventIds, curEventIds); + } + llvm::sort(eventIds); + auto [corePipeSrc, corePipeDst] = setWaitCorePipes; + auto setOp = + std::make_unique(scopeOp->op, scopeOp->parentOp, eventIds, + corePipeSrc.pipe, corePipeDst.pipe); + auto waitOp = + std::make_unique(scopeOp->op, scopeOp->parentOp, eventIds, + corePipeSrc.pipe, corePipeDst.pipe); + setOp->allAtOnce = true; + waitOp->allAtOnce = true; + if (options.isCrossCoreMode()) { + setOp->coreType = corePipeSrc.coreType; + waitOp->coreType = corePipeDst.coreType; + } + assert(setOp != nullptr && waitOp != nullptr); + /* Here the set op goes before the scope op and the wait op after it. */ + syncMapBefore[scopeOp].push_back(std::move(setOp)); + syncMapAfter[scopeOp].push_front(std::move(waitOp)); + } + } + return std::make_pair(std::move(syncMapBefore), std::move(syncMapAfter)); +} + +/* Handles every core-pipe pair that checkMemoryConflicts reports for rwOp1 and rwOp2. With options.alwaysUsePipeSAsWaitingPipe the waiting pipe is forced to PIPE_S before the event-id info is computed and the conflict is passed to handleConflict. */ +void Solver::processConflict(Occurrence *occ1, Occurrence *occ2, + RWOperation *rwOp1, RWOperation *rwOp2, + bool isUseless) { + for (auto [corePipeSrc, corePipeDst] : checkMemoryConflicts(rwOp1, rwOp2)) { + if (options.alwaysUsePipeSAsWaitingPipe) { + corePipeDst.pipe = pto::PIPE::PIPE_S; + } + auto eventIdInfo = + getEventIdInfo(occ1, occ2, rwOp1, rwOp2, corePipeSrc, corePipeDst); + handleConflict(occ1, occ2, rwOp1, rwOp2, corePipeSrc, corePipeDst, + eventIdInfo, isUseless); + } +} + +// Main processing loop that iterates processingOrders and attempts to +// discover and record conflicts. 
+void Solver::processOrders() { + for (auto &[occ1, occ2, rwOp1, rwOp2, isUseless] : processingOrders) { + assert(occ1 != occ2); + assert(occ1->syncIrIndex < occ2->syncIrIndex); + /* Seeing a pair twice is treated as a logic error in debug builds and skipped in release builds. */ + if (checkVisited(occ1, occ2)) { + assert(false && "expected to not check a pair more than once."); + continue; + } + /* Skip pairs rejected by any of the pre-checks below. */ + if (checkImpossibleOccPair(occ1, occ2) || checkAlreadySynced(occ1, occ2) || + skipMMad1DecomposedLoopOpt(occ1, occ2) || + checkSkipParallelLoop(occ1, occ2) || + checkSkipCrossCorePair(occ1, occ2)) { + continue; + } + DEBUG_WITH_TYPE("gss-sync-solver-checking", { + llvm::dbgs() << "checking: " << (isUseless ? "is-useless\n" : "\n"); + llvm::dbgs() << occ1->syncIrIndex << ' ' << occ1->startIndex << ' ' + << occ1->endIndex << ' ' << occ1->op->str(0, false) << '\n'; + llvm::dbgs() << occ2->syncIrIndex << ' ' << occ2->startIndex << ' ' + << occ2->endIndex << ' ' << occ2->op->str(0, false) << '\n'; + }); + if (checkAlreadySyncedWithUnitFlag(occ1, occ2)) { + continue; + } + processConflict(occ1, occ2, rwOp1, rwOp2, isUseless); + } +} + +/* For every (scope op, core-pipe pair) in backwardSyncEventsAfterMerge, adds one useless, non-reusable ConflictPair per occurrence of the scope op to the chosen conflicts of the occurrence's parent. The pair spans the occurrence's index range, or the range between its before/after placeholder occurrences when the isa check on scopeOp holds. */ +void Solver::insertMergedBackwardSyncPairs() { + for (auto &[scopeOp, st] : backwardSyncEventsAfterMerge) { + for (auto &corePipeInfoPair : st) { + auto [corePipeSrc, corePipeDst] = corePipeInfoPair; + for (auto *scopeOcc : opAllOccurrences[scopeOp]) { + auto *parentScopeOcc = scopeOcc->parentOcc; + assert(parentScopeOcc != nullptr); + Occurrence *setOcc = nullptr; + Occurrence *waitOcc = nullptr; + auto startIndex = scopeOcc->startIndex; + auto endIndex = scopeOcc->endIndex; + if (isa(scopeOp)) { + setOcc = getBeforePlaceHolderOcc(scopeOcc); + waitOcc = getAfterPlaceHolderOcc(scopeOcc); + startIndex = setOcc->endIndex; + endIndex = waitOcc->startIndex; + } + /* The pair carries no ops; only occurrences, pipes and the index range are set. */ + auto conflictPair = std::make_unique( + nullptr, nullptr, nullptr, nullptr, setOcc, waitOcc, corePipeSrc, + corePipeDst, startIndex, endIndex); + assert(conflictPair->startIndex <= conflictPair->endIndex); + conflictPair->isUseless = true; + conflictPair->dontReuse = true; + 
conflictPair->dontCheckForConflict = true; + conflictPair->couldNotRun = false; // notice this + LLVM_DEBUG({ + llvm::dbgs() << "consider-merged-backward-pair: " + << scopeOp->str(0, false) << ' ' << conflictPair->str() + << "\n"; + }); + scopeOccChosenConflicts[parentScopeOcc].insert(conflictPair.get()); + chosenConflictedPairs.push_back(std::move(conflictPair)); + } + } + } +} + +/* Brings backwardSyncEventsAfterMerge in line with backwardSyncEvents. First removes the pipe pairs that backwardSyncEvents no longer holds for their scope op. Then selects the scope ops with the greatest getDepth() that have no entry in backwardSyncEventsAfterMerge yet and registers their pipe pairs. Returns failure when the option is off or no such scope op exists (even if pairs were removed); otherwise success when anything was removed or inserted. */ +llvm::LogicalResult Solver::considerOuterBackwardSyncPairs() { + if (!options.considerOuterBackwardSyncPairs) { + return llvm::failure(); + } + bool backwardPairsPositionChanged = false; + for (auto &[scopeOp, st] : backwardSyncEventsAfterMerge) { + /* Erasure is deferred so that st is not modified while it is iterated. */ + SmallVector> toBeErased; + for (auto &corePipeInfoPair : st) { + if (!backwardSyncEvents.contains(scopeOp) || + !backwardSyncEvents[scopeOp].contains(corePipeInfoPair)) { + toBeErased.push_back(corePipeInfoPair); + } + } + if (!toBeErased.empty()) { + backwardPairsPositionChanged = true; + for (auto &corePipeInfoPair : toBeErased) { + st.erase(corePipeInfoPair); + } + } + } + /* Collect all not-yet-registered scope ops that share the maximum depth. */ + int chosenOpsDepth = -1; + SmallVector chosenOps; + for (auto &[scopeOp, mp] : backwardSyncEvents) { + if (backwardSyncEventsAfterMerge.contains(scopeOp)) { + continue; + } + int scopeOpDepth = scopeOp->getDepth(); + if (chosenOpsDepth == scopeOpDepth) { + chosenOps.push_back(scopeOp); + } else if (chosenOpsDepth == -1 || chosenOpsDepth < scopeOpDepth) { + chosenOps.clear(); + chosenOps.push_back(scopeOp); + chosenOpsDepth = scopeOpDepth; + } + } + if (chosenOps.empty()) { + return llvm::failure(); + } + bool newPairIsInserted = false; + for (auto *chosenOp : chosenOps) { + for (auto &[corePipeInfoPair, eventIdsMp] : backwardSyncEvents[chosenOp]) { + assert(!eventIdsMp.empty()); + if (!eventIdsMp.empty()) { + auto [it, isInserted] = + backwardSyncEventsAfterMerge[chosenOp].insert(corePipeInfoPair); + newPairIsInserted |= isInserted; + } + } + } + return llvm::success(backwardPairsPositionChanged || newPairIsInserted); +} + +llvm::LogicalResult 
Solver::reuseSyncPairToSaveEventIds() { + if (!options.reuseSyncPairToSaveEventIds || barrierAllPairs.empty()) { + return llvm::failure(); + } + bool limitReached = true; + for (auto [corePipeSrc, corePipeDst] : barrierAllPairs) { + if (reusePairs[{corePipeSrc, corePipeDst}] < maxReuseNum) { + if (reusePairs[{corePipeSrc, corePipeDst}] <= + reusedPairs[{corePipeSrc, corePipeDst}]) { + reusePairs[{corePipeSrc, corePipeDst}] += 1; + limitReached = false; + } + } + } + DEBUG_WITH_TYPE("gss-sync-solver-reuse", { + llvm::dbgs() << "reusePairs: \n"; + for (auto [pipeCorePairs, cnt] : reusePairs) { + llvm::dbgs() << get<0>(pipeCorePairs).pipe << ' ' + << get<1>(pipeCorePairs).pipe << ' ' << cnt << '\n'; + } + }); + return llvm::success(!limitReached); +} + +llvm::LogicalResult Solver::disableMultiEventIdForBarrierAllPairs() { + if (!options.disableMultiEventIdForBarrierAllPairs || + barrierAllPairs.empty()) { + return llvm::failure(); + } + bool newPairIsInserted = false; + for (auto corePipeInfoPair : barrierAllPairs) { + auto [it, isInserted] = disabledMultiEventIdPairs.insert(corePipeInfoPair); + newPairIsInserted |= isInserted; + } + LLVM_DEBUG({ + if (newPairIsInserted) { + llvm::dbgs() << "disabled-multi-event-id-pairs: \n"; + for (auto &[corePipeSrc, corePipeDst] : disabledMultiEventIdPairs) { + llvm::dbgs() << corePipeSrc.coreType << ' ' << corePipeSrc.pipe << ' ' + << corePipeDst.coreType << ' ' << corePipeDst.pipe << '\n'; + } + } + }); + return llvm::success(newPairIsInserted); +} + +llvm::LogicalResult Solver::tryMovingOutBackwardSyncPairsToOuterLoops() { + if (!options.moveOutAndMergeBackwardSyncPairs || !options.isCrossCoreMode() || + dontMoveBackwardSyncPairsToOutmostLoop) { + return llvm::failure(); + } + if (!moveBackwardSyncPairsToOutmostLoop) { + moveBackwardSyncPairsToOutmostLoop = true; + return llvm::success(); + } + if (!barrierAllPairs.empty()) { + moveBackwardSyncPairsToOutmostLoop = false; + dontMoveBackwardSyncPairsToOutmostLoop = true; + return 
llvm::success(); + } + return llvm::failure(); +} + +// High-level solve orchestration with multiple passes and optional merging +// iterations. +llvm::LogicalResult Solver::runSolver(bool enableOpts1, bool enableOpts2) { + reset(/*resetEventIdRanOutOpts=*/true); + + int64_t runNum = 0; + while (runNum++ < maxRunNum) { + LLVM_DEBUG(llvm::dbgs() << "runNum: " << runNum << '\n'); + + reset(); + insertMergedBackwardSyncPairs(); + processOrders(); + + if (llvm::succeeded(tryMovingOutBackwardSyncPairsToOuterLoops())) { + continue; + } + + if (enableOpts1) { + if (options.considerOuterBackwardSyncPairs) { + getBeforeAfterSyncMaps(); + if (llvm::succeeded(considerOuterBackwardSyncPairs())) { + continue; + } + if (!barrierAllPairs.empty()) { + backwardSyncEventsAfterMerge.clear(); + } + } + } + + if (enableOpts2) { + if (!barrierAllPairs.empty()) { + if (llvm::succeeded(reuseSyncPairToSaveEventIds())) { + continue; + } + if (llvm::succeeded(disableMultiEventIdForBarrierAllPairs())) { + continue; + } + } + } + + if (!barrierAllPairs.empty()) { + pickAndInsertABarrierAll(); + reset(/*resetEventIdRanOutOpts=*/true); + continue; + } + break; + } + + reset(); + insertMergedBackwardSyncPairs(); + processOrders(); + + return llvm::success(runNum < maxRunNum); +} + +void Solver::solve() { + if (llvm::succeeded(runSolver())) { + return; + } + if (!options.isTestMode()) { + if (llvm::succeeded(runSolver(/*enableOpts1=*/false))) { + return; + } + if (llvm::succeeded( + runSolver(/*enableOpts1=*/false, /*enableOpts2=*/false))) { + return; + } + } + llvm_unreachable("GSS: runSolver() failed."); +} diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index ea9466da1..3a7a2e5a4 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -6,12898 +6,5 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
// See LICENSE in the root of the software repository for the full text of the License. -//===- PTOToEmitC.cpp - PTO to EmitC conversion pass ----------------------===// -//===----------------------------------------------------------------------===// -#pragma GCC diagnostic ignored "-Woverloaded-virtual" -// https://discourse.llvm.org/t/matchandrewrite-hiding-virtual-functions/84933/8 - -#include -#include - -#include "PTO/IR/PTO.h" -#include "PTO/IR/PTOTypeUtils.h" -#include "PTO/IR/PTOSyncUtils.h" -#include "PTO/Transforms/Passes.h" - -#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" -#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" -#include "mlir/Analysis/DataFlowFramework.h" - -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/EmitC/IR/EmitC.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" - -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/IRMapping.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/TypeRange.h" - -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Target/Cpp/CppEmitter.h" -#include "mlir/Transforms/DialectConversion.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/raw_ostream.h" -#include "mlir/Dialect/Func/Transforms/FuncConversions.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/SCF/Transforms/Patterns.h" -#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h" -#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" - -#include -#include -#include -#include - -#define DEBUG_TYPE "pto-emitc" - -namespace mlir { -#define GEN_PASS_DEF_EMITPTOMANUAL -#include 
"PTO/Transforms/Passes.h.inc" -} // namespace mlir - -using namespace mlir; -using namespace mlir::pto; - -static std::string getElemTypeStringForGT(Type elemTy); -static bool getStaticMemrefLayout(MemRefType mrTy, - SmallVectorImpl &strides, - int64_t &offset); -static int64_t multiplyOrDynamic(int64_t lhs, int64_t rhs); -static void buildGlobalTensorShapeAndStride(ArrayRef shape, - ArrayRef strides, - SmallVectorImpl &shape5D, - SmallVectorImpl &stride5D); -static std::string joinIntTemplateParams(ArrayRef values); -static SmallVector buildRowMajorStrides(ArrayRef shape); -static std::string getGlobalTensorTypeStringFromShape(Type elemTy, - ArrayRef shape, - StringRef layoutEnum = - "pto::Layout::ND"); -static std::string getGlobalTensorTypeStringFromShapeAndStrides( - Type elemTy, ArrayRef shape, ArrayRef strides, - StringRef layoutEnum = "pto::Layout::ND"); -static emitc::OpaqueType getGlobalTensorOpaqueTypeFromShape( - MLIRContext *ctx, Type elemTy, ArrayRef shape, - StringRef layoutEnum = "pto::Layout::ND"); - -static const char *addrSpaceQualifier(pto::AddressSpace as) { - switch (as) { - case pto::AddressSpace::Zero: - return "__gm__"; - case pto::AddressSpace::VEC: - return "__ubuf__"; - case pto::AddressSpace::GM: - return "__gm__"; - case pto::AddressSpace::MAT: - return "__cbuf__"; - case pto::AddressSpace::LEFT: - return "__ca__"; - case pto::AddressSpace::RIGHT: - return "__cb__"; - case pto::AddressSpace::ACC: - return "__cc__"; - case pto::AddressSpace::BIAS: - // Bias tiles are special in pto-isa; keep a safe fallback qualifier. - return "__gm__"; - case pto::AddressSpace::SCALING: - // pto-isa TileType::Scaling maps to __fbuf__ (see pto/common/memory.hpp). 
- return "__fbuf__"; - } - return "__gm__"; -} - -[[maybe_unused]] static constexpr llvm::StringLiteral kLoweredSetValidShapeAttrName = - "__pto.lowered_set_validshape"; -[[maybe_unused]] static constexpr llvm::StringLiteral kLoweredSetValidShapeConfigAttrName = - "__pto.lowered_set_validshape_config"; -static constexpr llvm::StringLiteral kForceDynamicValidShapeAttrName = - "__pto.force_dynamic_valid_shape"; -static constexpr llvm::StringLiteral kGlobalTensorStridesAttrName = - "__pto.globaltensor_strides"; - -static Value peelUnrealized(Value v) { - if (auto castOp = v.getDefiningOp()) - return castOp.getOperand(0); - return v; -} - -static Value buildGlobalTensorFromMemref(ConversionPatternRewriter &rewriter, - Location loc, Value basePtr, - MemRefType mrTy, Operation *anchor); - -static Value maybeWrapGlobalMemrefAsGlobalTensor( - ConversionPatternRewriter &rewriter, Location loc, Value loweredValue, - Type originalType, Operation *anchor); - -static bool hasCompatibleKnownExtentForMGather(int64_t lhs, int64_t rhs) { - return lhs == ShapedType::kDynamic || rhs == ShapedType::kDynamic || - lhs == rhs; -} - -static bool isKnownUnitExtentForMGather(int64_t value) { - return value == ShapedType::kDynamic || value == 1; -} - -struct GatherScatterShapeLayoutInfo { - SmallVector shape; - bool rowMajor = false; - bool colMajor = false; -}; - -static std::optional -getGatherScatterShapeLayoutInfo(Type ty) { - if (auto tileTy = dyn_cast(ty)) { - ArrayRef validShape = tileTy.getValidShape(); - if (validShape.size() != 2) - return std::nullopt; - - GatherScatterShapeLayoutInfo info; - info.shape.assign(validShape.begin(), validShape.end()); - int32_t blayout = tileTy.getBLayoutValueI32(); - info.rowMajor = blayout == static_cast(pto::BLayout::RowMajor); - info.colMajor = blayout == static_cast(pto::BLayout::ColMajor); - return info; - } - - auto memRefTy = dyn_cast(ty); - if (!memRefTy || memRefTy.getRank() != 2) - return std::nullopt; - - SmallVector strides; - int64_t 
offset = ShapedType::kDynamic; - if (failed(getStridesAndOffset(memRefTy, strides, offset)) || - strides.size() != 2) - return std::nullopt; - - GatherScatterShapeLayoutInfo info; - info.shape.assign(memRefTy.getShape().begin(), memRefTy.getShape().end()); - info.rowMajor = strides[1] == 1; - info.colMajor = strides[0] == 1; - return info; -} - -static bool isRowCoalescedMGatherIndexType(Type dataTy, Type idxTy) { - auto dataInfo = getGatherScatterShapeLayoutInfo(dataTy); - auto idxInfo = getGatherScatterShapeLayoutInfo(idxTy); - if (!dataInfo || !idxInfo) - return false; - - const bool rowCoalesce1xR = - idxInfo->rowMajor && isKnownUnitExtentForMGather(idxInfo->shape[0]) && - hasCompatibleKnownExtentForMGather(idxInfo->shape[1], dataInfo->shape[0]); - const bool rowCoalesceRx1 = - idxInfo->colMajor && - hasCompatibleKnownExtentForMGather(idxInfo->shape[0], dataInfo->shape[0]) && - isKnownUnitExtentForMGather(idxInfo->shape[1]); - return rowCoalesce1xR || rowCoalesceRx1; -} - -static std::optional getLayoutAttrFromOp(Operation *op) { - if (!op) - return std::nullopt; - if (auto attr = op->getAttrOfType("layout")) - return attr.getLayout(); - return std::nullopt; -} - -static std::optional resolveLayoutFromValueChain(Value v) { - v = peelUnrealized(v); - while (Operation *def = v.getDefiningOp()) { - if (auto layout = getLayoutAttrFromOp(def)) - return layout; - if (auto subview = dyn_cast(def)) { - v = peelUnrealized(subview.getSource()); - continue; - } - if (auto reinterpret = dyn_cast(def)) { - v = peelUnrealized(reinterpret.getSource()); - continue; - } - if (auto cast = dyn_cast(def)) { - v = peelUnrealized(cast.getSource()); - continue; - } - if (auto unrealized = dyn_cast(def)) { - if (unrealized->getNumOperands() == 0) - break; - v = peelUnrealized(unrealized.getOperand(0)); - continue; - } - break; - } - return std::nullopt; -} - -static std::optional -resolveLayoutForGlobalTensor(Operation *anchor, Value basePtr) { - if (auto layout = 
getLayoutAttrFromOp(anchor)) - return layout; - return resolveLayoutFromValueChain(basePtr); -} - -static std::string layoutToEmitCString(mlir::pto::Layout layout) { - switch (layout) { - case mlir::pto::Layout::ND: - return "pto::Layout::ND"; - case mlir::pto::Layout::DN: - return "pto::Layout::DN"; - case mlir::pto::Layout::NZ: - return "pto::Layout::NZ"; - } - return "pto::Layout::ND"; -} - -static bool isEmitCGlobalTensorLikeType(Type ty) { - auto opaqueTy = dyn_cast(ty); - return opaqueTy && opaqueTy.getValue().contains("GlobalTensor<"); -} - -static std::string getEmitCScalarTypeToken(Type elemTy) { - if (pto::isPTOFloat8Type(elemTy) && - (elemTy.isFloat8E4M3() || elemTy.isFloat8E4M3FN() || - elemTy.isFloat8E4M3FNUZ() || elemTy.isFloat8E4M3B11FNUZ())) - return "float8_e4m3_t"; - if (pto::isPTOFloat8Type(elemTy) && - (elemTy.isFloat8E5M2() || elemTy.isFloat8E5M2FNUZ())) - return "float8_e5m2_t"; - if (isa(elemTy)) - return "hifloat8_t"; - if (isa(elemTy)) - return "float4_e1m2x2_t"; - if (isa(elemTy)) - return "float4_e2m1x2_t"; - if (elemTy.isF16()) - return "half"; - if (elemTy.isBF16()) - return "bfloat16_t"; - if (elemTy.isF32()) - return "float"; - if (elemTy.isF64()) - return "double"; - if (elemTy.isInteger(8)) - return (elemTy.isSignlessInteger(8) || elemTy.isSignedInteger(8)) ? "int8_t" - : "uint8_t"; - if (elemTy.isInteger(16)) - return (elemTy.isSignlessInteger(16) || elemTy.isSignedInteger(16)) - ? "int16_t" - : "uint16_t"; - if (elemTy.isInteger(32)) - return (elemTy.isSignlessInteger(32) || elemTy.isSignedInteger(32)) - ? "int32_t" - : "uint32_t"; - if (elemTy.isInteger(64)) - return cast(elemTy).isUnsigned() ? 
"uint64_t" : "int64_t"; - return "float"; -} - -static emitc::PointerType getEmitCPointerType(MLIRContext *ctx, - StringRef pointeeTypeStr) { - return emitc::PointerType::get(emitc::OpaqueType::get(ctx, pointeeTypeStr)); -} - -static emitc::PointerType getEmitCPointerType(MLIRContext *ctx, - StringRef qualifier, - StringRef elemTypeStr) { - return getEmitCPointerType(ctx, (qualifier + " " + elemTypeStr).str()); -} - -static bool isEmitCPointerLikeType(Type ty) { - if (isa(ty)) - return true; - if (auto opaqueTy = dyn_cast(ty)) - return opaqueTy.getValue().ends_with("*"); - return false; -} - -static int64_t getEmitCScalarByteWidth(Type elemTy) { - if (pto::getPTOStorageElemByteSize(elemTy) == 1) - return 1; - if (elemTy.isF16() || elemTy.isBF16() || elemTy.isInteger(16)) - return 2; - if (elemTy.isF32() || elemTy.isInteger(32)) - return 4; - if (elemTy.isF64() || elemTy.isInteger(64)) - return 8; - return 4; -} - -static std::string tileBufBLayoutToken(pto::TileBufConfigAttr configAttr); -static std::string tileBufSLayoutToken(pto::TileBufConfigAttr configAttr); -static std::string tileBufPadToken(pto::TileBufConfigAttr configAttr); -static pto::BLayout getTileBufBLayoutValue(pto::TileBufConfigAttr configAttr); -static int64_t renderTileTemplateDim(int64_t rawDim, Type elemTy, - pto::BLayout blayout, int dimIdx); - -static const char *tileRoleToken(Attribute memorySpace) { - if (auto asAttr = dyn_cast_or_null(memorySpace)) { - switch (asAttr.getAddressSpace()) { - case pto::AddressSpace::VEC: - return "TileType::Vec"; - case pto::AddressSpace::MAT: - return "TileType::Mat"; - case pto::AddressSpace::LEFT: - return "TileType::Left"; - case pto::AddressSpace::RIGHT: - return "TileType::Right"; - case pto::AddressSpace::ACC: - return "TileType::Acc"; - case pto::AddressSpace::BIAS: - return "TileType::Bias"; - case pto::AddressSpace::SCALING: - return "TileType::Scaling"; - case pto::AddressSpace::GM: - case pto::AddressSpace::Zero: - return "TileType::Vec"; - } - } - 
return "TileType::Vec"; -} - -static std::string tileBufCompactToken(pto::TileBufConfigAttr configAttr) { - std::string compactTok = "CompactMode::Null"; - if (auto compactAttr = dyn_cast(configAttr.getCompactMode())) { - switch (static_cast(compactAttr.getValue())) { - case 1: - compactTok = "CompactMode::Normal"; - break; - case 2: - compactTok = "CompactMode::RowPlusOne"; - break; - default: - compactTok = "CompactMode::Null"; - break; - } - } - return compactTok; -} - -static std::optional getEmitCTileTypeString(pto::TileBufType type) { - if (type.getRank() != 2) - return std::nullopt; - auto validShape = type.getValidShape(); - if (validShape.size() != 2) - return std::nullopt; - - Type elemTy = type.getElementType(); - auto configAttr = type.getConfigAttr(); - pto::BLayout blayout = getTileBufBLayoutValue(configAttr); - ArrayRef shape = type.getShape(); - int64_t rows = shape[0]; - int64_t cols = shape[1]; - - auto render = [&](int64_t dim, int dimIdx) { - return renderTileTemplateDim(dim, elemTy, blayout, dimIdx); - }; - - std::string vrowTok = - validShape[0] == ShapedType::kDynamic - ? "-1" - : std::to_string(render(validShape[0], 0)); - std::string vcolTok = - validShape[1] == ShapedType::kDynamic - ? 
"-1" - : std::to_string(render(validShape[1], 1)); - - int32_t fractal = 512; - if (auto frAttr = dyn_cast(configAttr.getSFractalSize())) - fractal = frAttr.getInt(); - - return std::string("Tile<") + tileRoleToken(type.getMemorySpace()) + ", " + - getEmitCScalarTypeToken(elemTy) + ", " + - std::to_string(render(rows, 0)) + ", " + - std::to_string(render(cols, 1)) + ", " + - tileBufBLayoutToken(configAttr) + ", " + vrowTok + ", " + vcolTok + - ", " + tileBufSLayoutToken(configAttr) + ", " + - std::to_string(fractal) + ", " + tileBufPadToken(configAttr) + ", " + - tileBufCompactToken(configAttr) + ">"; -} - -//===----------------------------------------------------------------------===// -// Type Converter -//===----------------------------------------------------------------------===// - -class PTOToEmitCTypeConverter : public TypeConverter { -public: - PTOToEmitCTypeConverter(MLIRContext *Ctx, PTOArch targetArch) { - // --------------------------------------------------------- - // 1. 基本类型 (f32, i32, index) - // --------------------------------------------------------- - addConversion([Ctx](FloatType type) -> Type { - if (type.isFloat8E4M3() || type.isFloat8E4M3FN() || - type.isFloat8E4M3FNUZ() || type.isFloat8E4M3B11FNUZ()) - return emitc::OpaqueType::get(Ctx, "float8_e4m3_t"); - if (type.isFloat8E5M2() || type.isFloat8E5M2FNUZ()) - return emitc::OpaqueType::get(Ctx, "float8_e5m2_t"); - if (type.isF32()) return emitc::OpaqueType::get(Ctx, "float"); - if (type.isF16()) return emitc::OpaqueType::get(Ctx, "half"); - if (type.isBF16()) return emitc::OpaqueType::get(Ctx, "bfloat16_t"); - if (type.isF64()) return emitc::OpaqueType::get(Ctx, "double"); - llvm::errs() << "[Debug] Unsupported FloatType: " << type << "\n"; - return Type{}; - }); - - addConversion([Ctx](pto::HiF8Type) -> Type { - return emitc::OpaqueType::get(Ctx, "hifloat8_t"); - }); - addConversion([Ctx](pto::F4E1M2x2Type) -> Type { - return emitc::OpaqueType::get(Ctx, "float4_e1m2x2_t"); - }); - 
addConversion([Ctx](pto::F4E2M1x2Type) -> Type { - return emitc::OpaqueType::get(Ctx, "float4_e2m1x2_t"); - }); - - addConversion([Ctx](IntegerType type) -> Type { - if (type.getWidth() == 1) - return type; - - // Prefer fixed-width C types. Preserve signedness if the MLIR integer is - // explicitly signed/unsigned; treat signless as signed by default. - const bool isUnsigned = type.isUnsignedInteger(); - switch (type.getWidth()) { - case 8: - return emitc::OpaqueType::get(Ctx, isUnsigned ? "uint8_t" : "int8_t"); - case 16: - return emitc::OpaqueType::get(Ctx, - isUnsigned ? "uint16_t" : "int16_t"); - case 32: - return emitc::OpaqueType::get(Ctx, - isUnsigned ? "uint32_t" : "int32_t"); - case 64: - return emitc::OpaqueType::get(Ctx, - isUnsigned ? "uint64_t" : "int64_t"); - default: - llvm::errs() << "[Debug] Unsupported IntegerType width: " - << type.getWidth() << "\n"; - return emitc::OpaqueType::get(Ctx, "int32_t"); // Fallback - } - }); - - addConversion([Ctx](IndexType type) -> Type { - return emitc::OpaqueType::get(Ctx, "int32_t"); - }); - - // vector<4xi16> (e.g. TMRGSORT executedNumList) -> pto::MrgSortExecutedNumList - addConversion([Ctx](VectorType type) -> Type { - if (type.getRank() == 1 && type.getNumElements() == 4 && - type.getElementType().isInteger(16)) - return emitc::OpaqueType::get(Ctx, "pto::MrgSortExecutedNumList"); - return Type{}; - }); - - // --------------------------------------------------------- - // 2. 
PTO 特殊类型 (透传或转换) - // --------------------------------------------------------- - addConversion([](emitc::OpaqueType type) { return type; }); - addConversion([](emitc::PointerType type) { return type; }); - - // --------------------------------------------------------- - // 2.5 PtrType 转换 (指针类型) - // --------------------------------------------------------- - addConversion([this, Ctx](pto::PtrType type) -> std::optional { - Type elemType = type.getElementType(); - Type newElemType = convertType(elemType); - if (!newElemType) - return std::nullopt; - - std::string elemTypeStr; - if (auto opq = dyn_cast(newElemType)) { - elemTypeStr = opq.getValue().str(); - } else { - llvm::errs() << " [Error] PtrType elem type is not OpaqueType: " - << newElemType << "\n"; - return std::nullopt; - } - - std::string qualifier = "__gm__"; - - std::string finalTypeStr = qualifier + " " + elemTypeStr; - return getEmitCPointerType(Ctx, finalTypeStr); - }); - - addConversion([Ctx](pto::PipeType type) -> Type { - (void)type; - return emitc::OpaqueType::get(Ctx, "auto"); - }); - - addConversion([Ctx](pto::EventIdArrayType type) -> Type { - std::string tok = "PTOAS_EventIdArray<" + std::to_string(type.getSize()) + ">"; - return emitc::OpaqueType::get(Ctx, tok); - }); - - // !pto.local_array -> !emitc.array. - // Variables of this type render as `T a[D1][D2]...;` in the emitted C++. 
- addConversion([this](pto::LocalArrayType type) -> std::optional { - Type convertedElem = convertType(type.getElementType()); - if (!convertedElem) - return std::nullopt; - return emitc::ArrayType::get(type.getShape(), convertedElem); - }); - - addConversion([Ctx](pto::AsyncSessionType type) -> Type { - (void)type; - return emitc::OpaqueType::get(Ctx, "pto::comm::AsyncSession"); - }); - - addConversion([Ctx](pto::AsyncEventType type) -> Type { - (void)type; - return emitc::OpaqueType::get(Ctx, "pto::comm::AsyncEvent"); - }); - - addConversion([Ctx](pto::PrefetchAsyncContextType type) -> Type { - (void)type; - return emitc::OpaqueType::get(Ctx, "pto::PrefetchAsyncContext"); - }); - - addConversion([Ctx](pto::TensorViewType type) -> Type { - return getGlobalTensorOpaqueTypeFromShape( - Ctx, type.getElementType(), type.getShape()); - }); - - addConversion([Ctx](pto::PartitionTensorViewType type) -> Type { - return getGlobalTensorOpaqueTypeFromShape( - Ctx, type.getElementType(), type.getShape()); - }); - - addConversion([Ctx](pto::TileBufType type) -> std::optional { - auto typeString = getEmitCTileTypeString(type); - if (!typeString) - return std::nullopt; - return emitc::OpaqueType::get(Ctx, *typeString); - }); - - // --------------------------------------------------------- - // 3. MemRef 转换 (Debug 重点) - // --------------------------------------------------------- - addConversion([this, Ctx](MemRefType type) -> std::optional { - LLVM_DEBUG(llvm::dbgs() << "Converting MemRef: " << type << "\n"); - - // A. 
转换元素类型 - Type elemType = type.getElementType(); - Type newElemType = convertType(elemType); - if (!newElemType) { - llvm::errs() << " [Error] Failed to convert element type: " << elemType << "\n"; - return std::nullopt; - } - - // 获取元素类型的字符串 - std::string elemTypeStr; - if (auto opq = dyn_cast(newElemType)) { - elemTypeStr = opq.getValue().str(); - } else { - llvm::errs() << " [Error] Converted element type is not OpaqueType: " << newElemType << "\n"; - return std::nullopt; - } - - // B. 处理 Memory Space - std::string qualifier = ""; - Attribute memorySpace = type.getMemorySpace(); - - if (!memorySpace) { - qualifier = "__gm__"; - } else if (auto ptoAttr = dyn_cast(memorySpace)) { - qualifier = addrSpaceQualifier(ptoAttr.getAddressSpace()); - } else { - llvm::errs() << " [Warning] Unknown MemorySpace Attribute type: " << memorySpace << "\n"; - qualifier = "__gm__"; // Fallback - } - - std::string finalTypeStr = qualifier + " " + elemTypeStr; - LLVM_DEBUG(llvm::dbgs() << " [Success] -> " << finalTypeStr << "*\n"); - - return getEmitCPointerType(Ctx, finalTypeStr); - }); - - // --------------------------------------------------------- - // 4. Function & Materialization - // --------------------------------------------------------- - addConversion([this](FunctionType type) -> Type { - SmallVector inputs; - if (failed(convertTypes(type.getInputs(), inputs))) return Type{}; - SmallVector results; - if (failed(convertTypes(type.getResults(), results))) return Type{}; - return FunctionType::get(type.getContext(), inputs, results); - }); - - auto materializeCast = [](OpBuilder &Builder, Type ResultType, - ValueRange Inputs, Location Loc) -> Value { - if (Inputs.size() != 1) return Value(); - return Builder.create(Loc, ResultType, Inputs[0]).getResult(0); - }; - - addSourceMaterialization(materializeCast); - addTargetMaterialization(materializeCast); - // Needed for region/block signature conversions (e.g. CFG block args). 
- addArgumentMaterialization(materializeCast); - } -}; - -static constexpr unsigned kPTOIndexBitWidth = - 32; // keep consistent with IndexType conversion - -// Forward declarations (definitions below). -static inline std::string pipeTokFromPipeAttr(mlir::pto::PipeAttr a); -static emitc::OpaqueType getSignedIntOpaqueType(MLIRContext *ctx, - unsigned bitWidth); -static emitc::OpaqueType getUnsignedIntOpaqueType(MLIRContext *ctx, - unsigned bitWidth); -static emitc::OpaqueType getWiderSignedIntOpaqueType(MLIRContext *ctx, - unsigned bitWidth); -static emitc::OpaqueType getWiderUnsignedIntOpaqueType(MLIRContext *ctx, - unsigned bitWidth); -static Value makeEmitCOpaqueConstant(ConversionPatternRewriter &rewriter, - Location loc, Type type, - llvm::StringRef literal); -static Value makeEmitCIntConstant(ConversionPatternRewriter &rewriter, - Location loc, Type type, int64_t value); -static Value emitCCast(ConversionPatternRewriter &rewriter, Location loc, - Type dstType, Value src); -static FailureOr buildEmitCOpaqueConstantLiteral(Type targetType, - Attribute valueAttr); -static Value castSignlessIntToUnsignedSameWidth(ConversionPatternRewriter &rewriter, - Location loc, Value v, - unsigned bitWidth); -static bool needsA5NoSplitVectorGuard(Operation *op); - -static FailureOr getTileSplitToken(int64_t split) { - switch (split) { - case 0: - return std::string("TileSplitAxis::TILE_NO_SPLIT"); - case 1: - return std::string("TileSplitAxis::TILE_UP_DOWN"); - case 2: - return std::string("TileSplitAxis::TILE_LEFT_RIGHT"); - default: - return failure(); - } -} - -static FailureOr -getTPipeDirectionToken(bool isL2G2L, int8_t dirMask, PTOArch targetArch) { - if (dirMask == 1) { - if (isL2G2L && targetArch == PTOArch::A5) - return std::string("Direction::DIR_C2V_GM"); - return std::string("Direction::DIR_C2V"); - } - if (dirMask == 2) { - if (isL2G2L && targetArch == PTOArch::A5) - return std::string("Direction::DIR_V2C_GM"); - return std::string("Direction::DIR_V2C"); - } - if 
(dirMask == 3) - return std::string("Direction::DIR_BOTH"); - return failure(); -} - -static std::string buildTPipeToken(int32_t flagBase, llvm::StringRef dirTok, - int32_t slotSize, int32_t slotNum, - int32_t localSlotNum, bool nosplit) { - std::string token = "TPipe<" + std::to_string(flagBase) + ", " + dirTok.str() + - ", " + std::to_string(slotSize) + ", " + - std::to_string(slotNum); - token += ", " + std::to_string(localSlotNum); - token += nosplit ? ", true" : ", false"; - token += ">"; - return token; -} - -static FailureOr buildTPipeTokenFromInitOp(Operation *op, - PTOArch targetArch) { - if (auto initOp = dyn_cast(op)) { - if (!initOp.getFlagBaseAttr()) - return failure(); - auto dirTok = - getTPipeDirectionToken(/*isL2G2L=*/true, initOp.getDirMask(), targetArch); - if (failed(dirTok)) - return failure(); - int32_t localSlotNum = initOp.getLocalSlotNumAttr() - ? initOp.getLocalSlotNumAttr().getInt() - : initOp.getSlotNum(); - return buildTPipeToken(initOp.getFlagBaseAttr().getInt(), *dirTok, - initOp.getSlotSize(), initOp.getSlotNum(), - localSlotNum, - initOp.getNosplitAttr() && - initOp.getNosplitAttr().getValue()); - } - - if (auto initOp = dyn_cast(op)) { - if (!initOp.getFlagBaseAttr()) - return failure(); - auto dirTok = - getTPipeDirectionToken(/*isL2G2L=*/false, initOp.getDirMask(), targetArch); - if (failed(dirTok)) - return failure(); - return buildTPipeToken(initOp.getFlagBaseAttr().getInt(), *dirTok, - initOp.getSlotSize(), initOp.getSlotNum(), 2, - initOp.getNosplitAttr() && - initOp.getNosplitAttr().getValue()); - } - - return failure(); -} - -static FailureOr getTPipeTokenFromValue(Value pipeHandle, - PTOArch targetArch) { - pipeHandle = peelUnrealized(pipeHandle); - Operation *def = pipeHandle.getDefiningOp(); - if (!def) - return failure(); - return buildTPipeTokenFromInitOp(def, targetArch); -} - -static bool isSetFFTsPointerLikeType(Type ty) { - return isEmitCPointerLikeType(ty); -} - -static bool 
tileDataReturnsIntegralAddress(pto::AddressSpace as) { - return as == pto::AddressSpace::BIAS; -} - -static Type getTileDataResultType(MLIRContext *ctx, pto::AddressSpace as, - StringRef elemTok) { - if (tileDataReturnsIntegralAddress(as)) - return emitc::OpaqueType::get(ctx, "uint64_t"); - return getEmitCPointerType(ctx, addrSpaceQualifier(as), elemTok); -} - -static Value materializeTileDataValue(ConversionPatternRewriter &rewriter, - Location loc, Value tile, - pto::AddressSpace as, - StringRef elemTok) { - auto rawTy = getTileDataResultType(rewriter.getContext(), as, elemTok); - return rewriter - .create(loc, rawTy, "PTOAS__TILE_DATA", - ArrayAttr{}, ArrayAttr{}, - ValueRange{tile}) - .getResult(0); -} - -static Value materializeAddressAsPointer(ConversionPatternRewriter &rewriter, - Location loc, Value addr, - pto::AddressSpace as, - StringRef elemTok) { - auto *ctx = rewriter.getContext(); - std::string ptrTyStr = - std::string(addrSpaceQualifier(as)) + " " + elemTok.str() + "*"; - auto ptrTy = getEmitCPointerType(ctx, addrSpaceQualifier(as), elemTok); - if (isSetFFTsPointerLikeType(addr.getType())) { - if (addr.getType() == ptrTy) - return addr; - return rewriter.create(loc, ptrTy, addr).getResult(); - } - auto castTyAttr = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, ptrTyStr)}); - return rewriter - .create(loc, ptrTy, "reinterpret_cast", - ArrayAttr{}, castTyAttr, - ValueRange{addr}) - .getResult(0); -} - -struct InterCoreSyncCallDesc { - const char *callee = nullptr; - ArrayAttr args; - SmallVector operands; -}; - -static Value castInterCoreEventIdToI32(ConversionPatternRewriter &rewriter, - Location loc, Value eventId) { - auto i32Ty = emitc::OpaqueType::get(rewriter.getContext(), "int32_t"); - if (eventId.getType() == i32Ty) - return eventId; - return emitCCast(rewriter, loc, i32Ty, eventId); -} - -static Attribute getFFTSModeCodegenArg(ConversionPatternRewriter &rewriter, - int64_t fftsMode) { - auto *ctx = rewriter.getContext(); - if (fftsMode 
== 2) - return emitc::OpaqueAttr::get(ctx, "FFTS_MODE_VAL"); - return emitc::OpaqueAttr::get(ctx, std::to_string(fftsMode)); -} - -static Value createFFTSMsg(ConversionPatternRewriter &rewriter, Location loc, - Value eventI32, int64_t fftsMode) { - auto *ctx = rewriter.getContext(); - auto msgTy = emitc::OpaqueType::get(ctx, "uint16_t"); - auto msgArgs = rewriter.getArrayAttr({ - getFFTSModeCodegenArg(rewriter, fftsMode), - IntegerAttr::get(IndexType::get(ctx), 0), - }); - return rewriter - .create(loc, msgTy, "getFFTSMsg", - /*args=*/msgArgs, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{eventI32}) - .getResult(0); -} - -static InterCoreSyncCallDesc buildInterCoreSyncSetCall( - ConversionPatternRewriter &rewriter, Location loc, PTOArch targetArch, - pto::PipeAttr pipeAttr, IntegerAttr eventIdAttr, int64_t fftsMode) { - auto *ctx = rewriter.getContext(); - std::string pipeTok = pipeTokFromPipeAttr(pipeAttr); - - if (targetArch == PTOArch::A3) { - auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); - Value eventVal = - makeEmitCIntConstant(rewriter, loc, i32Ty, eventIdAttr.getInt()); - Value msgVal = createFFTSMsg(rewriter, loc, eventVal, fftsMode); - - InterCoreSyncCallDesc desc; - desc.callee = "ffts_cross_core_sync"; - desc.args = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, pipeTok), - IntegerAttr::get(IndexType::get(ctx), 0), - }); - desc.operands.push_back(msgVal); - return desc; - } - - InterCoreSyncCallDesc desc; - desc.callee = "set_intra_block"; - desc.args = rewriter.getArrayAttr( - {emitc::OpaqueAttr::get(ctx, pipeTok), eventIdAttr}); - return desc; -} - -static InterCoreSyncCallDesc buildInterCoreSyncSetCallDyn( - ConversionPatternRewriter &rewriter, Location loc, PTOArch targetArch, - pto::PipeAttr pipeAttr, Value eventIdVal, int64_t fftsMode) { - auto *ctx = rewriter.getContext(); - std::string pipeTok = pipeTokFromPipeAttr(pipeAttr); - Value eventI32 = castInterCoreEventIdToI32(rewriter, loc, eventIdVal); - - if (targetArch == 
PTOArch::A3) { - Value msgVal = createFFTSMsg(rewriter, loc, eventI32, fftsMode); - - InterCoreSyncCallDesc desc; - desc.callee = "ffts_cross_core_sync"; - desc.args = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, pipeTok), - IntegerAttr::get(IndexType::get(ctx), 0), - }); - desc.operands.push_back(msgVal); - return desc; - } - - InterCoreSyncCallDesc desc; - desc.callee = "set_intra_block"; - desc.args = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, pipeTok), - IntegerAttr::get(IndexType::get(ctx), 0), - }); - desc.operands.push_back(eventI32); - return desc; -} - -static InterCoreSyncCallDesc buildInterCoreSyncWaitCall( - ConversionPatternRewriter &rewriter, PTOArch targetArch, - pto::PipeAttr pipeAttr, IntegerAttr eventIdAttr) { - auto *ctx = rewriter.getContext(); - std::string pipeTok = pipeTokFromPipeAttr(pipeAttr); - - InterCoreSyncCallDesc desc; - if (targetArch == PTOArch::A3) { - desc.callee = "wait_flag_dev"; - desc.args = rewriter.getArrayAttr({eventIdAttr}); - return desc; - } - - desc.callee = "wait_intra_block"; - desc.args = rewriter.getArrayAttr( - {emitc::OpaqueAttr::get(ctx, pipeTok), eventIdAttr}); - return desc; -} - -static InterCoreSyncCallDesc buildInterCoreSyncWaitCallDyn( - ConversionPatternRewriter &rewriter, Location loc, PTOArch targetArch, - pto::PipeAttr pipeAttr, Value eventIdVal) { - auto *ctx = rewriter.getContext(); - std::string pipeTok = pipeTokFromPipeAttr(pipeAttr); - Value eventI32 = castInterCoreEventIdToI32(rewriter, loc, eventIdVal); - - InterCoreSyncCallDesc desc; - if (targetArch == PTOArch::A3) { - desc.callee = "wait_flag_dev"; - desc.args = rewriter.getArrayAttr({IntegerAttr::get(IndexType::get(ctx), 0)}); - desc.operands.push_back(eventI32); - return desc; - } - - desc.callee = "wait_intra_block"; - desc.args = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, pipeTok), - IntegerAttr::get(IndexType::get(ctx), 0), - }); - desc.operands.push_back(eventI32); - return desc; -} - -static bool 
hasInterCoreSyncOp(func::FuncOp func) { - bool found = false; - func.walk([&](Operation *op) { - if (isa(op)) { - found = true; - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - return found; -} - -static bool hasSetFFTsOp(func::FuncOp func) { - bool found = false; - func.walk([&](Operation *op) { - if (isa(op)) { - found = true; - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - return found; -} - -//===----------------------------------------------------------------------===// -// Arith -> EmitC (full dialect coverage for scalar ops) -//===----------------------------------------------------------------------===// - -template -struct ArithSimpleBinaryToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type dstTy = this->getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getOperands()); - return success(); - } -}; - -// Integer bitwise ops (andi/ori/xori) on signless integers: perform in unsigned -// to avoid signedness pitfalls, then cast back. -template -struct ArithUnsignedBitwiseBinaryToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - const unsigned bitWidth = - intTy ? 
intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - Type dstTy = this->getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - if (bitWidth == 1) { - rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getLhs(), - adaptor.getRhs()); - return success(); - } - - auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value resU = rewriter.create(loc, uTy, lhsU, rhsU); - Value result = emitCCast(rewriter, loc, dstTy, resU); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithDivUIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::DivUIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - const unsigned bitWidth = - intTy ? 
intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value divU = rewriter.create(loc, uTy, lhsU, rhsU); - Value result = emitCCast(rewriter, loc, dstTy, divU); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithRemUIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::RemUIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - const unsigned bitWidth = - intTy ? 
intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value remU = rewriter.create(loc, uTy, lhsU, rhsU); - Value result = emitCCast(rewriter, loc, dstTy, remU); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithCeilDivUIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::CeilDivUIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - const unsigned bitWidth = - intTy ? 
intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value one = makeEmitCIntConstant(rewriter, loc, uTy, 1); - Value rhsMinusOne = rewriter.create(loc, uTy, rhsU, one); - Value num = rewriter.create(loc, uTy, lhsU, rhsMinusOne); - Value divU = rewriter.create(loc, uTy, num, rhsU); - Value result = emitCCast(rewriter, loc, dstTy, divU); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithCeilDivSIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::CeilDivSIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - Value zero = makeEmitCIntConstant(rewriter, loc, dstTy, 0); - Value one = makeEmitCIntConstant(rewriter, loc, dstTy, 1); - - Value q0 = rewriter.create(loc, dstTy, adaptor.getLhs(), - adaptor.getRhs()); - Value r = rewriter.create(loc, dstTy, adaptor.getLhs(), - adaptor.getRhs()); - - Value rNeZero = rewriter.create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::ne, r, - zero); - Value lhsLt0 = - rewriter.create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, adaptor.getLhs(), - zero); - Value rhsLt0 = - rewriter.create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, adaptor.getRhs(), - zero); - Value signsSame = - rewriter.create(loc, 
rewriter.getI1Type(), - emitc::CmpPredicate::eq, lhsLt0, rhsLt0); - Value adjust = - rewriter.create(loc, rewriter.getI1Type(), - rNeZero, signsSame); - - Value qPlusOne = rewriter.create(loc, dstTy, q0, one); - Value result = rewriter.create(loc, dstTy, adjust, - qPlusOne, q0); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithFloorDivSIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::FloorDivSIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - Value zero = makeEmitCIntConstant(rewriter, loc, dstTy, 0); - Value one = makeEmitCIntConstant(rewriter, loc, dstTy, 1); - - Value q0 = rewriter.create(loc, dstTy, adaptor.getLhs(), - adaptor.getRhs()); - Value r = rewriter.create(loc, dstTy, adaptor.getLhs(), - adaptor.getRhs()); - - Value rNeZero = rewriter.create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::ne, r, - zero); - Value lhsLt0 = - rewriter.create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, adaptor.getLhs(), - zero); - Value rhsLt0 = - rewriter.create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, adaptor.getRhs(), - zero); - Value signsDifferent = - rewriter.create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::ne, lhsLt0, rhsLt0); - Value adjust = - rewriter.create(loc, rewriter.getI1Type(), - rNeZero, signsDifferent); - - Value qMinusOne = rewriter.create(loc, dstTy, q0, one); - Value result = rewriter.create(loc, dstTy, adjust, - qMinusOne, q0); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithShiftLeftToEmitC : public OpConversionPattern { - using 
OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::ShLIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - const unsigned bitWidth = - intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - if (bitWidth == 1) { - // Compute on u8 and truncate to i1. - auto u8Ty = getUnsignedIntOpaqueType(rewriter.getContext(), 8); - Value lhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getLhs()); - Value rhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getRhs()); - Value sh = rewriter.create(loc, u8Ty, lhsU8, - rhsU8); - Value masked = - rewriter.create(loc, u8Ty, sh, - makeEmitCIntConstant(rewriter, loc, - u8Ty, 1)); - rewriter.replaceOp(op, emitCCast(rewriter, loc, dstTy, masked)); - return success(); - } - - auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value shU = - rewriter.create(loc, uTy, lhsU, rhsU); - Value result = emitCCast(rewriter, loc, dstTy, shU); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithShiftRightUIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::ShRUIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - 
const unsigned bitWidth = - intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - if (bitWidth == 1) { - // (x >> y) on i1 is either x (y==0) or 0 (y!=0); approximate in u8. - auto u8Ty = getUnsignedIntOpaqueType(rewriter.getContext(), 8); - Value lhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getLhs()); - Value rhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getRhs()); - Value sh = rewriter.create(loc, u8Ty, lhsU8, - rhsU8); - Value masked = - rewriter.create(loc, u8Ty, sh, - makeEmitCIntConstant(rewriter, loc, - u8Ty, 1)); - rewriter.replaceOp(op, emitCCast(rewriter, loc, dstTy, masked)); - return success(); - } - - auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value shU = - rewriter.create(loc, uTy, lhsU, rhsU); - Value result = emitCCast(rewriter, loc, dstTy, shU); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithShiftRightSIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::ShRSIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - const unsigned bitWidth = - intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - if (bitWidth == 1) { - // (x >> y) on i1 is either x (y==0) or 0 (y!=0); approximate in u8. 
- auto u8Ty = getUnsignedIntOpaqueType(rewriter.getContext(), 8); - Value lhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getLhs()); - Value rhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getRhs()); - Value sh = rewriter.create(loc, u8Ty, lhsU8, - rhsU8); - Value masked = - rewriter.create(loc, u8Ty, sh, - makeEmitCIntConstant(rewriter, loc, - u8Ty, 1)); - rewriter.replaceOp(op, emitCCast(rewriter, loc, dstTy, masked)); - return success(); - } - - // Signed arithmetic shift; cast RHS to unsigned to interpret shift amount. - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value sh = - rewriter.create(loc, dstTy, adaptor.getLhs(), - rhsU); - rewriter.replaceOp(op, sh); - return success(); - } -}; - -struct ArithNegFToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::NegFOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type dstTy = getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getOperand()); - return success(); - } -}; - -struct ArithRemFToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::RemFOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type dstTy = getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - - // Use builtin `fmod` when possible. For f16, compute in float and cast back. 
- Type callTy = dstTy; - Value lhs = adaptor.getLhs(); - Value rhs = adaptor.getRhs(); - - if (auto opFloatTy = dyn_cast(op.getType())) { - if (opFloatTy.isF16()) { - auto f32Ty = emitc::OpaqueType::get(rewriter.getContext(), "float"); - lhs = emitCCast(rewriter, loc, f32Ty, lhs); - rhs = emitCCast(rewriter, loc, f32Ty, rhs); - callTy = f32Ty; - } - } - - // Prefer `__builtin_fmod*` to avoid relying on extra headers. - llvm::StringRef callee = "__builtin_fmod"; - if (auto opFloatTy = dyn_cast(op.getType())) { - if (opFloatTy.isF32() || opFloatTy.isF16()) - callee = "__builtin_fmodf"; - else if (opFloatTy.isF64()) - callee = "__builtin_fmod"; - } - - auto call = rewriter.create( - loc, TypeRange{callTy}, callee, ValueRange{lhs, rhs}, - /*args=*/ArrayAttr{}, /*template_args=*/ArrayAttr{}); - Value result = call.getResult(0); - if (callTy != dstTy) - result = emitCCast(rewriter, loc, dstTy, result); - - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithSelectToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!op.getCondition().getType().isInteger(1)) - return rewriter.notifyMatchFailure( - op, "only scalar i1 conditions supported for arith.select"); - - Type dstTy = getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - - auto cond = - rewriter.create(op.getLoc(), dstTy, - adaptor.getCondition(), - adaptor.getTrueValue(), - adaptor.getFalseValue()); - rewriter.replaceOp(op, cond.getResult()); - return success(); - } -}; - -struct ArithExtUIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto dstIntTy = dyn_cast(op.getType()); - auto srcIntTy = 
dyn_cast(op.getIn().getType()); - if (!dstIntTy || !srcIntTy) - return rewriter.notifyMatchFailure(op, "expected scalar integer types"); - - Type dstTy = getTypeConverter()->convertType(dstIntTy); - if (!dstTy) - return failure(); - - // i1 -> iN: bool to integer already behaves as 0/1. - if (srcIntTy.getWidth() == 1) { - rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); - return success(); - } - - auto uDstTy = - getUnsignedIntOpaqueType(rewriter.getContext(), dstIntTy.getWidth()); - Value srcU = - castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getIn(), - srcIntTy.getWidth()); - Value extU = emitCCast(rewriter, loc, uDstTy, srcU); - Value result = emitCCast(rewriter, loc, dstTy, extU); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithExtSIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto dstIntTy = dyn_cast(op.getType()); - auto srcIntTy = dyn_cast(op.getIn().getType()); - if (!dstIntTy || !srcIntTy) - return rewriter.notifyMatchFailure(op, "expected scalar integer types"); - - Type dstTy = getTypeConverter()->convertType(dstIntTy); - if (!dstTy) - return failure(); - - // i1 sign-extension: 0 -> 0, 1 -> -1. 
- if (srcIntTy.getWidth() == 1) { - Value zero = makeEmitCIntConstant(rewriter, loc, dstTy, 0); - Value asInt = emitCCast(rewriter, loc, dstTy, adaptor.getIn()); - Value neg = rewriter.create(loc, dstTy, zero, asInt).getResult(); - rewriter.replaceOp(op, neg); - return success(); - } - - rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); - return success(); - } -}; - -template -struct ArithCastToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type dstTy = this->getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); - return success(); - } -}; - -struct ArithIndexCastUIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::IndexCastUIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type dstTy = getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - - // MemRef casts are handled elsewhere; for safety, fall back to emitc.cast. 
- if (isa(op.getIn().getType()) || isa(op.getType())) { - rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); - return success(); - } - - auto getBW = [](Type t) -> std::optional { - if (auto i = dyn_cast(t)) - return i.getWidth(); - if (isa(t)) - return kPTOIndexBitWidth; - return std::nullopt; - }; - - auto srcBW = getBW(op.getIn().getType()); - auto dstBW = getBW(op.getType()); - if (!srcBW || !dstBW) - return rewriter.notifyMatchFailure(op, "unsupported index_castui types"); - - if (*dstBW <= *srcBW) { - rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); - return success(); - } - - auto uSrcTy = getUnsignedIntOpaqueType(rewriter.getContext(), *srcBW); - auto uDstTy = getUnsignedIntOpaqueType(rewriter.getContext(), *dstBW); - Value srcU = emitCCast(rewriter, loc, uSrcTy, adaptor.getIn()); - Value extU = emitCCast(rewriter, loc, uDstTy, srcU); - Value result = emitCCast(rewriter, loc, dstTy, extU); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithUIToFPToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto srcIntTy = dyn_cast(op.getIn().getType()); - if (!srcIntTy) - return rewriter.notifyMatchFailure(op, "expected scalar integer input"); - - Type dstTy = getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - - // Convert via an unsigned integer type of the same width. 
- if (srcIntTy.getWidth() == 1) { - rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); - return success(); - } - Value srcU = - castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getIn(), - srcIntTy.getWidth()); - Value fp = rewriter.create(loc, dstTy, srcU).getResult(); - rewriter.replaceOp(op, fp); - return success(); - } -}; - -struct ArithFPToUIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::FPToUIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto dstIntTy = dyn_cast(op.getType()); - if (!dstIntTy) - return rewriter.notifyMatchFailure(op, "expected scalar integer result"); - - Type dstTy = getTypeConverter()->convertType(dstIntTy); - if (!dstTy) - return failure(); - - auto uDstTy = - getUnsignedIntOpaqueType(rewriter.getContext(), dstIntTy.getWidth()); - Value asU = rewriter.create(loc, uDstTy, adaptor.getIn()).getResult(); - Value result = emitCCast(rewriter, loc, dstTy, asU); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithBitcastToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type dstTy = getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - - // For pointer-like types, a regular cast is fine. - if (isa(dstTy)) { - rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); - return success(); - } - - // Only support scalar int/float/index bitcasts here. 
- auto srcTy = op.getIn().getType(); - auto dstOrigTy = op.getType(); - - auto getBitWidth = [](Type t) -> std::optional { - if (auto it = dyn_cast(t)) - return it.getWidth(); - if (auto ft = dyn_cast(t)) - return ft.getWidth(); - if (isa(t)) - return kPTOIndexBitWidth; - return std::nullopt; - }; - auto srcBW = getBitWidth(srcTy); - auto dstBW = getBitWidth(dstOrigTy); - if (!srcBW || !dstBW || *srcBW != *dstBW) - return rewriter.notifyMatchFailure(op, "bitcast requires equal bitwidth"); - - // Determine the template argument from the destination type string. - auto dstOpaque = dyn_cast(dstTy); - if (!dstOpaque) - return rewriter.notifyMatchFailure(op, "expected emitc opaque dest type"); - - auto templateArgs = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(rewriter.getContext(), - dstOpaque.getValue())}); - auto call = rewriter.create( - loc, TypeRange{dstTy}, "ptoas_bitcast", /*operands=*/ValueRange{adaptor.getIn()}, - /*args=*/ArrayAttr{}, /*template_args=*/templateArgs); - rewriter.replaceOp(op, call.getResult(0)); - return success(); - } -}; - -// arith.cmpf lowering with ordered/unordered semantics. 
-struct ArithCmpFToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - struct CmpFConfig { - bool unordered = false; - emitc::CmpPredicate predicate = emitc::CmpPredicate::eq; - }; - - static Value isNaN(ConversionPatternRewriter &rewriter, Location loc, - Value v) { - return rewriter - .create(loc, rewriter.getI1Type(), emitc::CmpPredicate::ne, - v, v) - .getResult(); - } - - static Value isNotNaN(ConversionPatternRewriter &rewriter, Location loc, - Value v) { - return rewriter - .create(loc, rewriter.getI1Type(), emitc::CmpPredicate::eq, - v, v) - .getResult(); - } - - static std::optional buildSpecialCmpFResult( - arith::CmpFPredicate predicate, ConversionPatternRewriter &rewriter, - Location loc, Type i1Ty, Value lhs, Value rhs) { - switch (predicate) { - case arith::CmpFPredicate::AlwaysFalse: - return makeEmitCOpaqueConstant(rewriter, loc, i1Ty, "false"); - case arith::CmpFPredicate::AlwaysTrue: - return makeEmitCOpaqueConstant(rewriter, loc, i1Ty, "true"); - case arith::CmpFPredicate::ORD: - return rewriter.create( - loc, i1Ty, isNotNaN(rewriter, loc, lhs), - isNotNaN(rewriter, loc, rhs)) - .getResult(); - case arith::CmpFPredicate::UNO: - return rewriter.create( - loc, i1Ty, isNaN(rewriter, loc, lhs), - isNaN(rewriter, loc, rhs)) - .getResult(); - default: - return std::nullopt; - } - } - - static std::optional - getCmpFConfig(arith::CmpFPredicate predicate) { - switch (predicate) { - case arith::CmpFPredicate::OEQ: - return CmpFConfig{false, emitc::CmpPredicate::eq}; - case arith::CmpFPredicate::OGT: - return CmpFConfig{false, emitc::CmpPredicate::gt}; - case arith::CmpFPredicate::OGE: - return CmpFConfig{false, emitc::CmpPredicate::ge}; - case arith::CmpFPredicate::OLT: - return CmpFConfig{false, emitc::CmpPredicate::lt}; - case arith::CmpFPredicate::OLE: - return CmpFConfig{false, emitc::CmpPredicate::le}; - case arith::CmpFPredicate::ONE: - return CmpFConfig{false, emitc::CmpPredicate::ne}; - case 
arith::CmpFPredicate::UEQ: - return CmpFConfig{true, emitc::CmpPredicate::eq}; - case arith::CmpFPredicate::UGT: - return CmpFConfig{true, emitc::CmpPredicate::gt}; - case arith::CmpFPredicate::UGE: - return CmpFConfig{true, emitc::CmpPredicate::ge}; - case arith::CmpFPredicate::ULT: - return CmpFConfig{true, emitc::CmpPredicate::lt}; - case arith::CmpFPredicate::ULE: - return CmpFConfig{true, emitc::CmpPredicate::le}; - case arith::CmpFPredicate::UNE: - return CmpFConfig{true, emitc::CmpPredicate::ne}; - default: - return std::nullopt; - } - } - - static Value buildCmpFResult(const CmpFConfig &config, - ConversionPatternRewriter &rewriter, - Location loc, Type i1Ty, Value lhs, Value rhs) { - Value cmp = rewriter - .create(loc, i1Ty, config.predicate, lhs, rhs) - .getResult(); - Value unord = rewriter.create( - loc, i1Ty, isNaN(rewriter, loc, lhs), isNaN(rewriter, loc, rhs)); - if (config.unordered) - return rewriter - .create(loc, i1Ty, unord, cmp) - .getResult(); - Value ord = rewriter.create( - loc, i1Ty, isNotNaN(rewriter, loc, lhs), isNotNaN(rewriter, loc, rhs)); - return rewriter - .create(loc, i1Ty, ord, cmp) - .getResult(); - } - - LogicalResult matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!isa(op.getLhs().getType())) - return rewriter.notifyMatchFailure(op, "cmpf only supported on scalar floats"); - - auto loc = op.getLoc(); - auto i1Ty = rewriter.getI1Type(); - if (auto special = buildSpecialCmpFResult(op.getPredicate(), rewriter, loc, - i1Ty, adaptor.getLhs(), - adaptor.getRhs())) { - rewriter.replaceOp(op, *special); - return success(); - } - - auto config = getCmpFConfig(op.getPredicate()); - if (!config) - return rewriter.notifyMatchFailure(op, "unsupported cmpf predicate"); - rewriter.replaceOp(op, buildCmpFResult(*config, rewriter, loc, i1Ty, - adaptor.getLhs(), adaptor.getRhs())); - return success(); - } -}; - -struct ArithAddUIExtendedToEmitC - : public OpConversionPattern { - 
using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getSum().getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, - "expected scalar integer or index operands"); - - const unsigned bitWidth = - intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - SmallVector newResultTypes; - if (failed(getTypeConverter()->convertTypes(op->getResultTypes(), - newResultTypes))) - return failure(); - if (newResultTypes.size() != 2) - return failure(); - - Type sumDstTy = newResultTypes[0]; - Type overflowDstTy = newResultTypes[1]; - - auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - auto wideTy = getWiderUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - - Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value lhsWide = emitCCast(rewriter, loc, wideTy, lhsU); - Value rhsWide = emitCCast(rewriter, loc, wideTy, rhsU); - Value sumWide = - rewriter.create(loc, wideTy, lhsWide, rhsWide).getResult(); - - Value sumN = emitCCast(rewriter, loc, uTy, sumWide); - Value sum = emitCCast(rewriter, loc, sumDstTy, sumN); - - Value shiftAmt = makeEmitCIntConstant(rewriter, loc, wideTy, bitWidth); - Value high = rewriter - .create(loc, wideTy, sumWide, - shiftAmt) - .getResult(); - Value zeroWide = makeEmitCIntConstant(rewriter, loc, wideTy, 0); - Value overflow = - rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::ne, high, zeroWide) - .getResult(); - overflow = emitCCast(rewriter, loc, overflowDstTy, overflow); - - rewriter.replaceOp(op, {sum, overflow}); - return success(); - } -}; - -template -struct ArithMulExtendedToEmitC : public 
OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getResult(0).getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, - "expected scalar integer or index operands"); - - const unsigned bitWidth = - intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - SmallVector newResultTypes; - if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), - newResultTypes))) - return failure(); - if (newResultTypes.size() != 2) - return failure(); - - Type lowDstTy = newResultTypes[0]; - Type highDstTy = newResultTypes[1]; - - Type wideTy = isUnsigned ? (Type)getWiderUnsignedIntOpaqueType(rewriter.getContext(), - bitWidth) - : (Type)getWiderSignedIntOpaqueType(rewriter.getContext(), - bitWidth); - - Value lhsWide; - Value rhsWide; - if constexpr (isUnsigned) { - Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - lhsWide = emitCCast(rewriter, loc, wideTy, lhsU); - rhsWide = emitCCast(rewriter, loc, wideTy, rhsU); - } else { - lhsWide = emitCCast(rewriter, loc, wideTy, adaptor.getLhs()); - rhsWide = emitCCast(rewriter, loc, wideTy, adaptor.getRhs()); - } - - Value prodWide = - rewriter.create(loc, wideTy, lhsWide, rhsWide).getResult(); - Value low = emitCCast(rewriter, loc, lowDstTy, prodWide); - - Value shiftAmt = makeEmitCIntConstant(rewriter, loc, wideTy, bitWidth); - Value highWide = rewriter - .create(loc, wideTy, prodWide, - shiftAmt) - .getResult(); - Value high = emitCCast(rewriter, loc, highDstTy, highWide); - - rewriter.replaceOp(op, {low, high}); - return success(); - } -}; - -using ArithMulSIExtendedToEmitC = - 
ArithMulExtendedToEmitC; -using ArithMulUIExtendedToEmitC = - ArithMulExtendedToEmitC; - -struct ArithMinMaxIToEmitCBase { - static Value makeSelect(ConversionPatternRewriter &rewriter, Location loc, - Type dstTy, Value cond, Value trueV, Value falseV) { - return rewriter - .create(loc, dstTy, cond, trueV, falseV) - .getResult(); - } -}; - -struct ArithMaxSIToEmitC : public OpConversionPattern, - ArithMinMaxIToEmitCBase { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::MaxSIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type dstTy = getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - Value cond = rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, - adaptor.getLhs(), adaptor.getRhs()) - .getResult(); - Value res = makeSelect(rewriter, loc, dstTy, cond, adaptor.getRhs(), - adaptor.getLhs()); - rewriter.replaceOp(op, res); - return success(); - } -}; - -struct ArithMinSIToEmitC : public OpConversionPattern, - ArithMinMaxIToEmitCBase { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::MinSIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type dstTy = getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - Value cond = rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, - adaptor.getLhs(), adaptor.getRhs()) - .getResult(); - Value res = makeSelect(rewriter, loc, dstTy, cond, adaptor.getLhs(), - adaptor.getRhs()); - rewriter.replaceOp(op, res); - return success(); - } -}; - -struct ArithMaxUIToEmitC : public OpConversionPattern, - ArithMinMaxIToEmitCBase { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::MaxUIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = 
op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - const unsigned bitWidth = - intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - Value lhsU = - castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = - castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value cond = rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, lhsU, rhsU) - .getResult(); - Value res = makeSelect(rewriter, loc, dstTy, cond, adaptor.getRhs(), - adaptor.getLhs()); - rewriter.replaceOp(op, res); - return success(); - } -}; - -struct ArithMinUIToEmitC : public OpConversionPattern, - ArithMinMaxIToEmitCBase { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::MinUIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - const unsigned bitWidth = - intTy ? 
intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - Value lhsU = - castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = - castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value cond = rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, lhsU, rhsU) - .getResult(); - Value res = makeSelect(rewriter, loc, dstTy, cond, adaptor.getLhs(), - adaptor.getRhs()); - rewriter.replaceOp(op, res); - return success(); - } -}; - -// Floating-point max/min variants. -struct ArithFloatMinMaxToEmitCBase { - static Value isNaN(ConversionPatternRewriter &rewriter, Location loc, - Value v) { - return rewriter - .create(loc, rewriter.getI1Type(), emitc::CmpPredicate::ne, - v, v) - .getResult(); - } - - static Value makeFZero(ConversionPatternRewriter &rewriter, Location loc, - Type ty) { - return makeEmitCOpaqueConstant(rewriter, loc, ty, "0.0f"); - } -}; - -struct ArithMaxNumFToEmitC : public OpConversionPattern, - ArithFloatMinMaxToEmitCBase { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::MaxNumFOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type dstTy = getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - - Value lhsNaN = isNaN(rewriter, loc, adaptor.getLhs()); - Value rhsNaN = isNaN(rewriter, loc, adaptor.getRhs()); - - Value cmpLt = rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, - adaptor.getLhs(), adaptor.getRhs()) - .getResult(); - Value maxNoNaN = - rewriter - .create(loc, dstTy, cmpLt, adaptor.getRhs(), - adaptor.getLhs()) - .getResult(); - - Value rhsOrMax = - rewriter - .create(loc, dstTy, rhsNaN, adaptor.getLhs(), - maxNoNaN) - .getResult(); - Value res = - rewriter - .create(loc, dstTy, lhsNaN, adaptor.getRhs(), - rhsOrMax) - 
.getResult(); - rewriter.replaceOp(op, res); - return success(); - } -}; - -struct ArithMinNumFToEmitC : public OpConversionPattern, - ArithFloatMinMaxToEmitCBase { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::MinNumFOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type dstTy = getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - - Value lhsNaN = isNaN(rewriter, loc, adaptor.getLhs()); - Value rhsNaN = isNaN(rewriter, loc, adaptor.getRhs()); - - Value cmpLt = rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, - adaptor.getLhs(), adaptor.getRhs()) - .getResult(); - Value minNoNaN = - rewriter - .create(loc, dstTy, cmpLt, adaptor.getLhs(), - adaptor.getRhs()) - .getResult(); - - Value rhsOrMin = - rewriter - .create(loc, dstTy, rhsNaN, adaptor.getLhs(), - minNoNaN) - .getResult(); - Value res = - rewriter - .create(loc, dstTy, lhsNaN, adaptor.getRhs(), - rhsOrMin) - .getResult(); - rewriter.replaceOp(op, res); - return success(); - } -}; - -template -struct ArithMinMaxFPropagateNaNToEmitC : public OpConversionPattern, - ArithFloatMinMaxToEmitCBase { - using OpConversionPattern::OpConversionPattern; - - static Value buildPrimaryCandidate(ConversionPatternRewriter &rewriter, - Location loc, Type dstTy, Value lhs, - Value rhs) { - Value cmpLt = - rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, lhs, rhs) - .getResult(); - return rewriter - .create( - loc, dstTy, cmpLt, isMaximum ? rhs : lhs, isMaximum ? 
lhs : rhs) - .getResult(); - } - - static Value buildSignBitValue(ConversionPatternRewriter &rewriter, - Location loc, Value lhs, FloatType floatTy) { - auto bitsTy = - getUnsignedIntOpaqueType(rewriter.getContext(), floatTy.getWidth()); - auto templateArgs = rewriter.getArrayAttr({emitc::OpaqueAttr::get( - rewriter.getContext(), cast(bitsTy).getValue())}); - Value lhsBits = - rewriter - .create(loc, TypeRange{bitsTy}, "ptoas_bitcast", - ValueRange{lhs}, ArrayAttr{}, - templateArgs) - .getResult(0); - Value oneBits = makeEmitCIntConstant(rewriter, loc, bitsTy, 1); - Value shiftAmount = - makeEmitCIntConstant(rewriter, loc, bitsTy, floatTy.getWidth() - 1); - Value signMask = rewriter - .create(loc, bitsTy, oneBits, - shiftAmount) - .getResult(); - return rewriter - .create(loc, bitsTy, lhsBits, signMask) - .getResult(); - } - - static Value buildSignedZeroCandidate(ConversionPatternRewriter &rewriter, - Location loc, Type dstTy, Value lhs, - Value rhs, FloatType floatTy) { - Value zero = makeFZero(rewriter, loc, dstTy); - Value equal = rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::eq, lhs, rhs) - .getResult(); - Value lhsZero = rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::eq, lhs, - zero) - .getResult(); - Value bothZero = rewriter - .create(loc, rewriter.getI1Type(), - equal, lhsZero) - .getResult(); - auto bitsTy = - getUnsignedIntOpaqueType(rewriter.getContext(), floatTy.getWidth()); - Value zeroBits = makeEmitCIntConstant(rewriter, loc, bitsTy, 0); - Value lhsIsNegZero = - rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::ne, - buildSignBitValue(rewriter, loc, lhs, floatTy), - zeroBits) - .getResult(); - Value tie = rewriter - .create( - loc, dstTy, lhsIsNegZero, isMaximum ? rhs : lhs, - isMaximum ? 
lhs : rhs) - .getResult(); - return rewriter - .create(loc, dstTy, bothZero, tie, - buildPrimaryCandidate(rewriter, loc, dstTy, - lhs, rhs)) - .getResult(); - } - - static Value buildNaNPropagatingResult(ConversionPatternRewriter &rewriter, - Location loc, Type dstTy, Value lhs, - Value rhs, FloatType floatTy) { - Value lhsNaN = isNaN(rewriter, loc, lhs); - Value rhsNaN = isNaN(rewriter, loc, rhs); - Value noNaN = - buildSignedZeroCandidate(rewriter, loc, dstTy, lhs, rhs, floatTy); - Value rhsOrNoNaN = rewriter - .create(loc, dstTy, rhsNaN, rhs, - noNaN) - .getResult(); - return rewriter - .create(loc, dstTy, lhsNaN, lhs, rhsOrNoNaN) - .getResult(); - } - - LogicalResult - matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!isa(op.getType())) - return rewriter.notifyMatchFailure(op, "expected scalar float type"); - - auto loc = op.getLoc(); - Type dstTy = this->getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - - auto floatTy = cast(op.getType()); - rewriter.replaceOp(op, buildNaNPropagatingResult( - rewriter, loc, dstTy, adaptor.getLhs(), - adaptor.getRhs(), floatTy)); - return success(); - } -}; - -using ArithMaximumFToEmitC = - ArithMinMaxFPropagateNaNToEmitC; -using ArithMinimumFToEmitC = - ArithMinMaxFPropagateNaNToEmitC; - -//===----------------------------------------------------------------------===// -// Arith -> EmitC helpers -//===----------------------------------------------------------------------===// - -static emitc::OpaqueType getSignedIntOpaqueType(MLIRContext *ctx, - unsigned bitWidth) { - switch (bitWidth) { - case 1: - return emitc::OpaqueType::get(ctx, "int8_t"); - case 8: - return emitc::OpaqueType::get(ctx, "int8_t"); - case 16: - return emitc::OpaqueType::get(ctx, "int16_t"); - case 32: - return emitc::OpaqueType::get(ctx, "int32_t"); - case 64: - return emitc::OpaqueType::get(ctx, "int64_t"); - case 128: - return 
emitc::OpaqueType::get(ctx, "__int128"); - default: - llvm::errs() << "[Debug] Unsupported signed integer bitwidth: " << bitWidth - << "\n"; - return emitc::OpaqueType::get(ctx, "int64_t"); - } -} - -static emitc::OpaqueType getUnsignedIntOpaqueType(MLIRContext *ctx, - unsigned bitWidth) { - switch (bitWidth) { - case 1: - return emitc::OpaqueType::get(ctx, "uint8_t"); - case 8: - return emitc::OpaqueType::get(ctx, "uint8_t"); - case 16: - return emitc::OpaqueType::get(ctx, "uint16_t"); - case 32: - return emitc::OpaqueType::get(ctx, "uint32_t"); - case 64: - return emitc::OpaqueType::get(ctx, "uint64_t"); - case 128: - return emitc::OpaqueType::get(ctx, "unsigned __int128"); - default: - llvm::errs() << "[Debug] Unsupported unsigned integer bitwidth: " - << bitWidth << "\n"; - return emitc::OpaqueType::get(ctx, "uint64_t"); - } -} - -static emitc::OpaqueType getWiderSignedIntOpaqueType(MLIRContext *ctx, - unsigned bitWidth) { - switch (bitWidth) { - case 1: - case 8: - return getSignedIntOpaqueType(ctx, 16); - case 16: - return getSignedIntOpaqueType(ctx, 32); - case 32: - return getSignedIntOpaqueType(ctx, 64); - case 64: - return getSignedIntOpaqueType(ctx, 128); - default: - return getSignedIntOpaqueType(ctx, 128); - } -} - -static emitc::OpaqueType getWiderUnsignedIntOpaqueType(MLIRContext *ctx, - unsigned bitWidth) { - switch (bitWidth) { - case 1: - case 8: - return getUnsignedIntOpaqueType(ctx, 16); - case 16: - return getUnsignedIntOpaqueType(ctx, 32); - case 32: - return getUnsignedIntOpaqueType(ctx, 64); - case 64: - return getUnsignedIntOpaqueType(ctx, 128); - default: - return getUnsignedIntOpaqueType(ctx, 128); - } -} - -static Value makeEmitCOpaqueConstant(ConversionPatternRewriter &rewriter, - Location loc, Type type, - llvm::StringRef literal) { - auto attr = emitc::OpaqueAttr::get(rewriter.getContext(), literal); - return rewriter.create(loc, type, attr); -} - -static Value makeEmitCIntConstant(ConversionPatternRewriter &rewriter, - Location loc, 
Type type, int64_t value) { - return makeEmitCOpaqueConstant(rewriter, loc, type, std::to_string(value)); -} - -static FailureOr buildEmitCOpaqueConstantLiteral(Type targetType, - Attribute valueAttr) { - auto opaqueTy = dyn_cast(targetType); - if (!opaqueTy) - return failure(); - - if (opaqueTy.getValue() == "pto::MrgSortExecutedNumList") { - auto dense = dyn_cast_or_null(valueAttr); - if (!dense) - return failure(); - - auto vecTy = dyn_cast(dense.getType()); - if (!vecTy || vecTy.getRank() != 1 || vecTy.getNumElements() != 4 || - !vecTy.getElementType().isInteger(16)) - return failure(); - - std::string literal; - llvm::raw_string_ostream os(literal); - os << "pto::MrgSortExecutedNumList{"; - bool first = true; - for (APInt elem : dense.getValues()) { - if (!first) - os << ", "; - first = false; - os << elem.getZExtValue(); - } - os << "}"; - os.flush(); - return literal; - } - - return failure(); -} - -static Value emitCCast(ConversionPatternRewriter &rewriter, Location loc, - Type dstType, Value src) { - if (src.getType() == dstType) - return src; - return rewriter.createOrFold(loc, dstType, src); -} - -// For signless iN integers lowered to signed C++ types, this creates a value -// representing the same N-bit pattern in an unsigned C++ type of the same -// width. This avoids incorrect sign-extension when later widening to a larger -// unsigned type. 
-static Value castSignlessIntToUnsignedSameWidth(ConversionPatternRewriter &rewriter, - Location loc, Value v, - unsigned bitWidth) { - auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - return emitCCast(rewriter, loc, uTy, v); -} - -struct ArithMulIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(arith::MulIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - const unsigned bitWidth = - intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - // i1 mul is equivalent to bitwise AND (mod 2 arithmetic). - if (bitWidth == 1) { - rewriter.replaceOpWithNewOp(op, opTy, adaptor.getLhs(), - adaptor.getRhs()); - return success(); - } - - auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value mulU = rewriter.create(loc, uTy, lhsU, rhsU); - Value result = emitCCast(rewriter, loc, dstTy, mulU); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithAddIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(arith::AddIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - const 
unsigned bitWidth = - intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - // i1 add is equivalent to XOR (mod 2 arithmetic). - if (bitWidth == 1) { - rewriter.replaceOpWithNewOp(op, opTy, adaptor.getLhs(), - adaptor.getRhs()); - return success(); - } - - auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value addU = rewriter.create(loc, uTy, lhsU, rhsU); - Value result = emitCCast(rewriter, loc, dstTy, addU); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithCastOPToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type newTy = getTypeConverter()->convertType(op.getType()); - if (!newTy) - return failure(); - if (adaptor.getIn().getType() == newTy) { - rewriter.replaceOp(op, adaptor.getIn()); - return success(); - } - rewriter.replaceOpWithNewOp(op, newTy, adaptor.getIn()); - return success(); - } -}; - -struct ArithSubIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(arith::SubIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - const unsigned bitWidth = - intTy ? 
intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - // i1 sub is equivalent to XOR (mod 2 arithmetic). - if (bitWidth == 1) { - rewriter.replaceOpWithNewOp(op, opTy, adaptor.getLhs(), - adaptor.getRhs()); - return success(); - } - - auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value subU = rewriter.create(loc, uTy, lhsU, rhsU); - Value result = emitCCast(rewriter, loc, dstTy, subU); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithDivSIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(arith::DivSIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type newTy = getTypeConverter()->convertType(op.getType()); - if (!newTy) - return failure(); - rewriter.replaceOpWithNewOp(op, newTy, adaptor.getLhs(), - adaptor.getRhs()); - return success(); - } -}; - -struct ArithRemSIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type newTy = getTypeConverter()->convertType(op.getType()); - if (!newTy) - return failure(); - rewriter.replaceOpWithNewOp(op, newTy, adaptor.getLhs(), - adaptor.getRhs()); - return success(); - } -}; - -struct ArithTruncIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - auto dstIntTy = dyn_cast(op.getType()); - auto srcIntTy = 
dyn_cast(op.getIn().getType()); - if (!dstIntTy || !srcIntTy) - return rewriter.notifyMatchFailure(op, "expected scalar integer types"); - - Type dstTy = getTypeConverter()->convertType(dstIntTy); - if (!dstTy) - return failure(); - - // to-i1 conversions: Arith wants truncation to the low bit, while C/C++ - // casts to bool are equivalent to `v != 0`. Implement as `(bool)(v & 1)`. - if (dstIntTy.getWidth() == 1) { - if (srcIntTy.getWidth() == 1) { - rewriter.replaceOp(op, adaptor.getIn()); - return success(); - } - - auto uSrcTy = - getUnsignedIntOpaqueType(rewriter.getContext(), srcIntTy.getWidth()); - Value inU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getIn(), - srcIntTy.getWidth()); - Value one = makeEmitCIntConstant(rewriter, loc, uSrcTy, 1); - Value masked = - rewriter.create(loc, uSrcTy, inU, one); - Value asBool = emitCCast(rewriter, loc, dstTy, masked); - rewriter.replaceOp(op, asBool); - return success(); - } - - rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); - return success(); - } -}; - -struct ArithConstantToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type newType = getTypeConverter()->convertType(op.getType()); - if (!newType) - return failure(); - - // `adaptor.getValue()` may be null if attribute conversion isn't defined. - // Use the original attribute as fallback and always cast null-safely. 
- Attribute valueAttr = adaptor.getValue(); - if (!valueAttr) - valueAttr = op.getValue(); - - if (auto opaqueLiteral = buildEmitCOpaqueConstantLiteral(newType, valueAttr); - succeeded(opaqueLiteral)) { - auto constAttr = emitc::OpaqueAttr::get(rewriter.getContext(), *opaqueLiteral); - rewriter.replaceOpWithNewOp(op, newType, constAttr); - return success(); - } - - if (auto floatAttr = dyn_cast_or_null(valueAttr)) { - SmallString<32> valStr; - floatAttr.getValue().toString(valStr); - llvm::StringRef s(valStr); - // Ensure the literal parses as a floating-point constant in C/C++. - // `APFloat::toString` may emit "1" for integral values; make it "1.0". - const bool hasFloatMarker = - s.contains('.') || s.contains('e') || s.contains('E') || - s.contains('p') || s.contains('P') || s.starts_with("0x") || - s.starts_with("0X") || s.starts_with("nan") || - s.starts_with("-nan") || s.starts_with("inf") || - s.starts_with("-inf"); - if (!hasFloatMarker) - valStr.append(".0"); - // Suffix: keep `f` for f16/f32; omit for f64. 
- if (!floatAttr.getType().isF64()) - valStr.append("f"); - auto constAttr = emitc::OpaqueAttr::get(rewriter.getContext(), valStr); - rewriter.replaceOpWithNewOp(op, newType, constAttr); - return success(); - } - - if (auto intAttr = dyn_cast_or_null(valueAttr)) { - std::string valStr = std::to_string(intAttr.getValue().getSExtValue()); - auto constAttr = emitc::OpaqueAttr::get(rewriter.getContext(), valStr); - rewriter.replaceOpWithNewOp(op, newType, constAttr); - return success(); - } - - return failure(); - } -}; -//===----------------------------------------------------------------------===// -// pto.mgather lowering -> MGATHER(dst, src, indexes) (pto-isa) -//===----------------------------------------------------------------------===// - -struct PTOMGatherToMGATHER : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::MGatherOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto *ctx = rewriter.getContext(); - Value mem = peelUnrealized(adaptor.getMem()); - Value idx = peelUnrealized(adaptor.getIdx()); - Value dst = peelUnrealized(adaptor.getDst()); - - Value memArg = maybeWrapGlobalMemrefAsGlobalTensor( - rewriter, op.getLoc(), mem, op.getMem().getType(), op.getOperation()); - - auto gatherOobTok = [&](pto::GatherOOB mode) -> StringRef { - switch (mode) { - case pto::GatherOOB::Undefined: - return "pto::GatherOOB::Undefined"; - case pto::GatherOOB::Clamp: - return "pto::GatherOOB::Clamp"; - case pto::GatherOOB::Wrap: - return "pto::GatherOOB::Wrap"; - case pto::GatherOOB::Zero: - return "pto::GatherOOB::Zero"; - } - llvm_unreachable("unknown GatherOOB"); - }; - - SmallVector templateArgVec; - const bool rowCoalesce = - isRowCoalescedMGatherIndexType(op.getDst().getType(), op.getIdx().getType()); - templateArgVec.push_back(emitc::OpaqueAttr::get( - ctx, rowCoalesce ? 
"pto::Coalesce::Row" : "pto::Coalesce::Elem")); - if (op.getGatherOob() != pto::GatherOOB::Undefined) { - templateArgVec.push_back( - emitc::OpaqueAttr::get(ctx, gatherOobTok(op.getGatherOob()))); - } - ArrayAttr templateArgs = rewriter.getArrayAttr(templateArgVec); - - rewriter.create( - op.getLoc(), TypeRange{}, "MGATHER", - ArrayAttr{}, templateArgs, - ValueRange{dst, memArg, idx}); - - if (op->getNumResults() == 0) { - rewriter.eraseOp(op); - } else { - rewriter.replaceOp(op, dst); - } - return success(); - } -}; - -struct AffineApplyMulConstToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(affine::AffineApplyOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto map = op.getAffineMap(); - - if (map.getNumDims() != 0 || map.getNumSymbols() != 1) - return failure(); - - auto expr = map.getResult(0); - auto bin = dyn_cast(expr); - if (!bin || bin.getKind() != AffineExprKind::Mul) - return failure(); - - auto lhs = bin.getLHS(); - auto rhs = bin.getRHS(); - - auto symExpr = dyn_cast(lhs); - auto constExpr = dyn_cast(rhs); - if (!symExpr || !constExpr) - return failure(); - - Value inputVal = adaptor.getMapOperands()[0]; - - std::string valStr = std::to_string(constExpr.getValue()); - auto cstAttr = emitc::OpaqueAttr::get(rewriter.getContext(), valStr); - auto cstOp = rewriter.create( - op.getLoc(), inputVal.getType(), cstAttr); - - rewriter.replaceOpWithNewOp( - op, inputVal.getType(), inputVal, cstOp); - - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// Kernel inference helpers -//===----------------------------------------------------------------------===// - -enum class KernelKind { VecAdd, Matmul, Unknown }; - -[[maybe_unused]] static KernelKind inferKernelKind(func::FuncOp f) { - bool hasAdd = false; - bool hasMM = false; - f.walk([&](Operation *op) { - if (isa(op)) hasAdd = true; - if 
(isa(op)) hasMM = true; - if (isa(op)) hasMM = true; - }); - if (hasMM) return KernelKind::Matmul; - if (hasAdd) return KernelKind::VecAdd; - return KernelKind::Unknown; -} - -[[maybe_unused]] static void inferTileMNK(func::FuncOp f, int &M, int &N, int &K) { - M = 32; N = 32; K = 32; - SmallVector subs; - f.walk([&](memref::SubViewOp sv) { subs.push_back(sv); }); - - auto readShape2D = [&](memref::SubViewOp sv, int &d0, int &d1) { - auto resTy = mlir::cast(sv.getResult().getType()); - if (resTy.getRank() == 2 && resTy.hasStaticShape()) { - d0 = (int)resTy.getDimSize(0); - d1 = (int)resTy.getDimSize(1); - } - }; - - if (subs.empty()) return; - - int a0=32, a1=32; - readShape2D(subs[0], a0, a1); - M = a0; N = a1; - - if (subs.size() >= 2) { - int b0=32, b1=32; - readShape2D(subs[0], a0, a1); - readShape2D(subs[1], b0, b1); - M = a0; K = a1; N = b1; - } -} - -static std::optional getKernelKindMacro(func::FuncOp funcOp) { - auto kernelKindAttr = - funcOp->getAttrOfType(FunctionKernelKindAttr::name); - if (!kernelKindAttr) - return std::nullopt; - - switch (kernelKindAttr.getKernelKind()) { - case FunctionKernelKind::Cube: - return StringRef("__DAV_CUBE__"); - case FunctionKernelKind::Vector: - return StringRef("__DAV_VEC__"); - } - - llvm_unreachable("unexpected kernel kind"); -} - -struct FuncToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(func::FuncOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // Convert the function signature with the type converter. - Type convertedTy = getTypeConverter()->convertType(op.getFunctionType()); - auto funcType = dyn_cast_or_null(convertedTy); - if (!funcType) - return rewriter.notifyMatchFailure(op, "failed to convert function type"); - if (funcType.getNumResults() > 1) - return rewriter.notifyMatchFailure( - op, "EmitC cannot return multiple values"); - - // Create the EmitC function with the converted signature. 
- auto emitcFunc = - rewriter.create(op.getLoc(), op.getName(), funcType); - - for (const auto &namedAttr : op->getAttrs()) { - StringRef name = namedAttr.getName().strref(); - if (name == op.getFunctionTypeAttrName() || - name == SymbolTable::getSymbolAttrName() || - name == pto::kPTOEntryAttrName || - name == pto::kLegacyHACCEntryAttrName || - name == "pto.internal.entry") - continue; - emitcFunc->setAttr(namedAttr.getName(), namedAttr.getValue()); - } - - if (op.isDeclaration()) { - emitcFunc.setSpecifiersAttr(rewriter.getStrArrayAttr({"extern"})); - rewriter.eraseOp(op); - return success(); - } - - if (pto::isPTOEntryFunction(op)) { - emitcFunc.setSpecifiersAttr( - rewriter.getStrArrayAttr({"__global__ AICORE"})); - } else if (op.isPrivate()) { - emitcFunc.setSpecifiersAttr( - rewriter.getStrArrayAttr({"static", "AICORE"})); - } else { - emitcFunc.setSpecifiersAttr(rewriter.getStrArrayAttr({"AICORE"})); - } - - std::optional kernelKindMacro = getKernelKindMacro(op); - bool needsNoSplitGuard = needsA5NoSplitVectorGuard(op.getOperation()); - - // Inline the original body, then convert region/block argument types to - // match the converted signature (also covers CFG blocks introduced by - // pre-lowering, e.g. scf.while -> cf.br/cf.cond_br). - rewriter.inlineRegionBefore(op.getBody(), emitcFunc.getBody(), - emitcFunc.end()); - - TypeConverter::SignatureConversion entryConv(op.getNumArguments()); - for (unsigned i = 0; i < op.getNumArguments(); ++i) - entryConv.addInputs(i, funcType.getInput(i)); - - if (failed(rewriter.convertRegionTypes(&emitcFunc.getBody(), - *getTypeConverter(), &entryConv))) - return failure(); - - // Preserve the existing function prologue shape. `kernel_kind` functions are - // emitted with the same macro guard/reset sequence that used to come from - // early pto.section wrapping, but only after SCF pre-lowering has finished. 
- { - Block &entryBlock = emitcFunc.getBody().front(); - rewriter.setInsertionPointToStart(&entryBlock); - rewriter.create(op.getLoc(), "using T = float;"); - if (kernelKindMacro) { - std::string startMacro = "\n#if defined(" + kernelKindMacro->str() + ")"; - rewriter.create(op.getLoc(), startMacro); - if (*kernelKindMacro == "__DAV_VEC__") { - rewriter.create(op.getLoc(), "set_mask_norm();"); - rewriter.create(op.getLoc(), - "set_vector_mask(-1, -1);"); - if (needsNoSplitGuard) - rewriter.create( - op.getLoc(), "if (get_subblockid() == 0) {"); - } - } - } - - if (kernelKindMacro) { - Block &lastBlock = emitcFunc.getBody().back(); - rewriter.setInsertionPoint(lastBlock.getTerminator()); - if (*kernelKindMacro == "__DAV_VEC__" && needsNoSplitGuard) - rewriter.create(op.getLoc(), "}"); - std::string endMacro = "#endif // " + kernelKindMacro->str() + "\n"; - rewriter.create(op.getLoc(), endMacro); - } - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// SubView lowering to GlobalTensor (keep your existing code) -//===----------------------------------------------------------------------=== - -enum class Role { A, B, C, Unknown }; - -template -static std::optional inferMatmulLikeSubviewRole(MatmulLikeOp op, - Value buffer) { - if (op.getLhs() == buffer) - return Role::A; - if (op.getRhs() == buffer) - return Role::B; - return std::nullopt; -} - -static std::optional inferSubviewRoleFromLoadUser(mlir::pto::TLoadOp load) { - Value buffer = load.getDst(); - if (!buffer) - return std::nullopt; - for (Operation *user : buffer.getUsers()) { - if (auto matmul = dyn_cast(user)) { - if (auto role = inferMatmulLikeSubviewRole(matmul, buffer)) - return role; - continue; - } - if (auto matmulAcc = dyn_cast(user)) { - if (auto role = inferMatmulLikeSubviewRole(matmulAcc, buffer)) - return role; - } - } - return std::nullopt; -} - -static std::optional inferSubviewRoleFromUser(Operation *user, 
Value result) { - if (auto load = dyn_cast(user)) - return inferSubviewRoleFromLoadUser(load); - if (auto store = dyn_cast(user)) { - if (store.getDst() == result) - return Role::C; - } - return std::nullopt; -} - -[[maybe_unused]] static Role inferSubviewRole(memref::SubViewOp sv) { - Value result = sv.getResult(); - for (Operation *user : result.getUsers()) { - if (auto role = inferSubviewRoleFromUser(user, result)) - return *role; - } - return Role::Unknown; -} - -// ============================================================================= -// 4. MemRef SubView -> Explicit Shape/Stride Construction (Full Implementation) -// ============================================================================= -struct SubviewToEmitCPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - // 辅助函数:尝试从 OpFoldResult 中提取静态整数值 - std::optional extractStaticInt(OpFoldResult ofr) const { - if (auto attr = ofr.dyn_cast()) { - if (auto intAttr = dyn_cast(attr)) - return intAttr.getInt(); - } else { - Value v = ofr.get(); - if (auto cOp = v.getDefiningOp()) { - if (auto iAttr = dyn_cast(cOp.getValue())) - return iAttr.getInt(); - } else if (auto idxOp = v.getDefiningOp()) { - return idxOp.value(); - } - } - return std::nullopt; - } - - LogicalResult matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - // 获取源 MemRef 类型信息 - auto srcType = mlir::cast(op.getSource().getType()); - int64_t rank = srcType.getRank(); - - auto elemTypeToString = [&](Type elemTy) -> std::string { - if (elemTy.isF16()) - return "half"; - if (elemTy.isBF16()) - return "bfloat16_t"; - if (elemTy.isF32()) - return "float"; - if (elemTy.isF64()) - return "double"; - if (elemTy.isInteger(8)) { - if (elemTy.isSignlessInteger(8) || elemTy.isSignedInteger(8)) - return "int8_t"; - return "uint8_t"; - } - if (elemTy.isInteger(16)) { - if 
(elemTy.isSignlessInteger(16) || elemTy.isSignedInteger(16)) - return "int16_t"; - return "uint16_t"; - } - if (elemTy.isInteger(32)) { - if (elemTy.isSignlessInteger(32) || elemTy.isSignedInteger(32)) - return "int32_t"; - return "uint32_t"; - } - if (elemTy.isInteger(64)) { - return cast(elemTy).isUnsigned() ? "uint64_t" : "int64_t"; - } - return "float"; - }; - - // ------------------------------------------------------------------------- - // Part 1: 指针偏移计算 (Runtime Pointer Arithmetic) - // ------------------------------------------------------------------------- - - // 准备类型: unsigned - Type u32Ty = emitc::OpaqueType::get(ctx, "unsigned"); - - // Helper: 创建 unsigned 常量 - auto mkU32 = [&](int64_t v) -> Value { - return rewriter.create( - loc, u32Ty, emitc::OpaqueAttr::get(ctx, std::to_string(v))); - }; - - // Helper: 将 OpFoldResult 转为 EmitC Value (用于计算) - auto ofrToEmitCValue = [&](OpFoldResult ofr) -> Value { - if (auto v = ofr.dyn_cast()) { - Value rv = rewriter.getRemappedValue(v); - // 如果类型不匹配,插入 Cast - if (rv.getType() != u32Ty) - return rewriter.create(loc, u32Ty, rv).getResult(); - return rv; - } - if (auto attr = ofr.dyn_cast()) { - if (auto ia = dyn_cast(attr)) - return mkU32(ia.getValue().getSExtValue()); - } - return mkU32(0); - }; - - // 1. 
获取 Source 的 Strides (支持动态 Stride 收集) - SmallVector sourceStrides; - - if (auto rc = op.getSource().getDefiningOp()) { - sourceStrides = rc.getMixedStrides(); - } else { - SmallVector strideInts; - int64_t offset = ShapedType::kDynamic; - bool useTypeStrides = succeeded(getStridesAndOffset(srcType, strideInts, offset)); - (void)offset; - if (useTypeStrides) { - for (int64_t s : strideInts) { - if (s == ShapedType::kDynamic) - useTypeStrides = false; - } - } - if (useTypeStrides) { - for (int64_t s : strideInts) { - sourceStrides.push_back(rewriter.getIndexAttr(s)); - } - } else { - // Fallback: Compact Layout - auto shape = srcType.getShape(); - int64_t current = 1; - sourceStrides.resize(rank); - for (int i = rank - 1; i >= 0; --i) { - sourceStrides[i] = rewriter.getIndexAttr(current); - if (shape[i] != ShapedType::kDynamic) current *= shape[i]; - } - } - } - - // 2. 计算运行时 Offset - auto staticOffsets = op.getStaticOffsets(); - auto dynamicOffsets = adaptor.getOffsets(); - int dynOffIdx = 0; - Value totalOffset = mkU32(0); - - for (int i = 0; i < rank; ++i) { - // A. 获取 Offset - Value offVal; - if (staticOffsets[i] == ShapedType::kDynamic) { - Value rawDyn = dynamicOffsets[dynOffIdx++]; - offVal = rewriter.create(loc, u32Ty, rawDyn); - } else { - offVal = mkU32(staticOffsets[i]); - } - - // B. 获取 Stride (用于指针计算) - Value strideVal = mkU32(1); - if (i < (int)sourceStrides.size()) { - strideVal = ofrToEmitCValue(sourceStrides[i]); - } - - // C. 累加 - Value term = rewriter.create(loc, u32Ty, offVal, strideVal); - totalOffset = rewriter.create(loc, u32Ty, totalOffset, term); - } - - // 3. 生成新指针 - // - // NOTE: Some toolchains may materialize kernel pointer params as `void*` even - // when the underlying element type is i16. Pointer arithmetic on `void*` - // is ill-formed in C++, so we explicitly cast to a typed pointer for i16. 
- Value sourcePtr = adaptor.getSource(); - Value tileCandidate = sourcePtr; - if (auto castOp = sourcePtr.getDefiningOp()) { - tileCandidate = castOp.getOperand(); - } else if (auto uc = - sourcePtr.getDefiningOp()) { - tileCandidate = uc.getOperand(0); - } - if (auto ot = dyn_cast(tileCandidate.getType())) { - auto tyStr = ot.getValue(); - if (tyStr.find("Tile<") != std::string::npos || - tyStr.find("ConvTile<") != std::string::npos) { - std::string elemTok = elemTypeToString(srcType.getElementType()); - pto::AddressSpace as = pto::AddressSpace::GM; - if (auto asAttr = - dyn_cast_or_null(srcType.getMemorySpace())) - as = asAttr.getAddressSpace(); - sourcePtr = - materializeTileDataValue(rewriter, loc, tileCandidate, as, elemTok); - if (tileDataReturnsIntegralAddress(as)) - sourcePtr = - materializeAddressAsPointer(rewriter, loc, sourcePtr, as, elemTok); - } - } - Value newPtr; - { - auto resTy = mlir::cast(op.getResult().getType()); - Type elemTy = resTy.getElementType(); - if (elemTy.isInteger(16)) { - std::string castElemTypeStr = "int16_t"; - if (cast(elemTy).isUnsigned()) - castElemTypeStr = "uint16_t"; - - std::string qualifier = "__gm__"; - if (Attribute ms = srcType.getMemorySpace()) { - if (auto ptoAttr = dyn_cast(ms)) { - qualifier = addrSpaceQualifier(ptoAttr.getAddressSpace()); - } - } - - auto typedPtrTy = emitc::PointerType::get( - emitc::OpaqueType::get(ctx, qualifier + " " + castElemTypeStr)); - Value typedSourcePtr = rewriter.create(loc, typedPtrTy, sourcePtr); - newPtr = rewriter.create(loc, typedPtrTy, typedSourcePtr, totalOffset); - } else { - newPtr = rewriter.create(loc, sourcePtr.getType(), sourcePtr, totalOffset); - } - } - - - // ------------------------------------------------------------------------- - // Part 2: For non-GM memrefs, keep pointer (no GlobalTensor). 
- // ------------------------------------------------------------------------- - bool isGlobal = true; - if (auto asAttr = dyn_cast_or_null(srcType.getMemorySpace())) { - auto as = asAttr.getAddressSpace(); - isGlobal = (as == pto::AddressSpace::GM || as == pto::AddressSpace::Zero); - } - if (!isGlobal) { - Type dstTy = getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - if (newPtr.getType() != dstTy) - newPtr = rewriter.create(loc, dstTy, newPtr); - rewriter.replaceOp(op, newPtr); - return success(); - } - - // ------------------------------------------------------------------------- - // Part 3: 生成 GlobalTensor 类型 (Shape/Stride Template Generation) - // ------------------------------------------------------------------------- - - // When emitting C++ with `declareVariablesAtTop`, value declarations are - // hoisted before body statements. Avoid introducing local `using` aliases - // for templated types (Shape/Stride/GlobalTensor) because those aliases - // would appear after the hoisted declarations and break compilation - // (`unknown type name`). - // - // Instead, use the fully spelled template types as EmitC opaque types. - - auto resTy = mlir::cast(op.getResult().getType()); - - // 1. 解析具体元素类型 - std::string elemTypeStr = getElemTypeStringForGT(resTy.getElementType()); - - // 2. 生成 Shape 模板参数,之后会右对齐有效维度并补齐到 5 维(高维填 1) - SmallVector shapeParamsVec; - SmallVector sizeValues; // 每个维度对应的运行时 size(统一为 unsigned) - auto resShape = resTy.getShape(); - auto mixedSizes = op.getMixedSizes(); - sizeValues.reserve(rank); - for (int i = 0; i < resTy.getRank(); ++i) { - if (resShape[i] == ShapedType::kDynamic) { - shapeParamsVec.push_back(-1); - } else { - shapeParamsVec.push_back(resShape[i]); - } - // size 值:优先从 op.getMixedSizes() 取(可动态/静态),否则退化为类型里的静态 shape。 - if (i < (int)mixedSizes.size()) - sizeValues.push_back(ofrToEmitCValue(mixedSizes[i])); - else - sizeValues.push_back( - mkU32(resShape[i] == ShapedType::kDynamic ? 
1 : resShape[i])); - } - - // 3. 生成 Stride 模板参数 + 运行时 stride 值(考虑 subview step) - SmallVector strideTemplateVec; - SmallVector strideValues; // 每个维度对应的运行时 stride(统一为 unsigned) - strideTemplateVec.reserve(rank); - strideValues.reserve(rank); - auto subViewSteps = op.getMixedStrides(); - for (int i = 0; i < rank; ++i) { - OpFoldResult srcStrideOfr = - (i < (int)sourceStrides.size()) ? sourceStrides[i] - : rewriter.getIndexAttr(1); - OpFoldResult stepOfr = (i < (int)subViewSteps.size()) - ? subViewSteps[i] - : rewriter.getIndexAttr(1); - - auto srcStatic = extractStaticInt(srcStrideOfr); - auto stepStatic = extractStaticInt(stepOfr); - if (srcStatic && stepStatic) { - int64_t finalStride = (*srcStatic) * (*stepStatic); - strideTemplateVec.push_back(finalStride); - strideValues.push_back(mkU32(finalStride)); - continue; - } - - strideTemplateVec.push_back(-1); - Value srcV = ofrToEmitCValue(srcStrideOfr); - Value stepV = ofrToEmitCValue(stepOfr); - // 尽量避免乘以 1 生成冗余指令 - if (stepStatic && *stepStatic == 1) - strideValues.push_back(srcV); - else if (srcStatic && *srcStatic == 1) - strideValues.push_back(stepV); - else - strideValues.push_back( - rewriter.create(loc, u32Ty, srcV, stepV)); - } - - // 3.1 右对齐到 5 维:shape 补 1;已有维度继承原 stride; - // 被补出来的高维按“紧密升维”规则连续推导:stride[i] = shape[i+1] * stride[i+1] - SmallVector finalShape; - SmallVector finalStride; - buildGlobalTensorShapeAndStride(shapeParamsVec, strideTemplateVec, - finalShape, finalStride); - Value oneU32 = mkU32(1); - SmallVector finalShapeValues(5, oneU32); - SmallVector finalStrideValues(5, oneU32); - int shift = 5 - rank; - - // 先放入原始 shape/stride(保持用户提供的值) - for (int i = 0; i < rank && i < 5; ++i) { - finalShapeValues[shift + i] = sizeValues[i]; - finalStrideValues[shift + i] = strideValues[i]; - } - - // 从低维到高维倒推补齐 stride(仅对补出来的前置维度生效) - for (int i = 3; i >= 0; --i) { - // 如果该维已由原始 rank 覆盖,则保持原值 - if (i >= shift) - continue; - if (finalStride[i] != -1) { - finalStrideValues[i] = mkU32(finalStride[i]); - 
continue; - } - // 动态推导:stride[i] = shape[i+1] * stride[i+1] - if (finalShape[i + 1] == 1) { - finalStrideValues[i] = finalStrideValues[i + 1]; - } else { - finalStrideValues[i] = rewriter.create( - loc, u32Ty, finalShapeValues[i + 1], finalStrideValues[i + 1]); - } - } - - std::string shapeParams = joinIntTemplateParams(finalShape); - std::string strideParams = joinIntTemplateParams(finalStride); - - // Spelled-out C++ types. - std::string shapeCppType = "pto::Shape<" + shapeParams + ">"; - std::string strideCppType = "pto::Stride<" + strideParams + ">"; - - // 3.0 Layout: prefer the attribute from InferPTOLayout; only fall back to - // local inference when the pass is disabled. - std::string layoutEnum = "pto::Layout::ND"; - if (auto layout = resolveLayoutForGlobalTensor(op, op.getSource())) { - layoutEnum = layoutToEmitCString(*layout); - } else { - bool allStatic = - llvm::all_of(finalShape, [](int64_t value) { return value != -1; }) && - llvm::all_of(finalStride, [](int64_t value) { return value != -1; }); - - int layoutTag = 0; // ND - auto elemBytes = 4; // default float - if (elemTypeStr.find("half") != std::string::npos || - elemTypeStr.find("f16") != std::string::npos || - elemTypeStr.find("bf16") != std::string::npos) - elemBytes = 2; - else if (elemTypeStr.find("double") != std::string::npos || - elemTypeStr.find("f64") != std::string::npos) - elemBytes = 8; - - if (allStatic) { - if (finalShape[2] == 16 && - finalShape[2] * finalShape[3] * elemBytes == 512 && - finalStride[4] == 1 && finalStride[3] == finalShape[4]) { - layoutTag = 2; // NZ - } else { - bool isRow = finalStride[4] == 1; - for (int i = 3; i >= 0; --i) - isRow &= (finalStride[i] == - multiplyOrDynamic(finalStride[i + 1], finalShape[i + 1])); - bool isCol = finalStride[0] == 1; - for (int i = 0; i < 4; ++i) - isCol &= (finalStride[i + 1] == - multiplyOrDynamic(finalStride[i], finalShape[i])); - if (isCol) - layoutTag = 1; // DN - else - layoutTag = isRow ? 
0 : 0; // fallback ND - } - } - - if (layoutTag == 1) - layoutEnum = "pto::Layout::DN"; - else if (layoutTag == 2) - layoutEnum = "pto::Layout::NZ"; - } - // GlobalTensor takes a Layout non-type template parameter; directly use the - // enum constant. - - - // ------------------------------------------------------------------------- - // Part 3: 显式对象实例化 (Explicit Object Instantiation) - // ------------------------------------------------------------------------- - - // A. Instantiate Shape object. - auto shapeTypeOpaque = emitc::OpaqueType::get(ctx, shapeCppType); - SmallVector shapeArgs; - // 从 adaptor.getSizes() 获取 subview 的所有 dynamic sizes - for (Value dynSize : adaptor.getSizes()) { - shapeArgs.push_back(dynSize); - } - - auto shapeInstOp = rewriter.create( - loc, - shapeTypeOpaque, // 返回类型 - shapeCppType, // 调用的“函数名”即类名构造函数 - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange(shapeArgs) - ); - - // B. Instantiate Stride object. - auto strideTypeOpaque = emitc::OpaqueType::get(ctx, strideCppType); - // 仅传入动态 stride 维度对应的值,匹配 pto::Stride 的 N-parameter ctor(并满足其 static_assert)。 - SmallVector strideCtorArgs; - strideCtorArgs.reserve(5); - for (int i = 0; i < 5; ++i) { - if (finalStride[i] == -1) - strideCtorArgs.push_back(finalStrideValues[i]); - } - auto strideInstOp = rewriter.create( - loc, strideTypeOpaque, strideCppType, - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange(strideCtorArgs)); - - // C. Instantiate GlobalTensor object (ptr + shape + stride). 
- std::string gtCppType = "GlobalTensor<" + elemTypeStr + ", " + shapeCppType + - ", " + strideCppType + ", " + layoutEnum + ">"; - auto gtType = emitc::OpaqueType::get(ctx, gtCppType); - - // 准备构造参数: [ptr, shape_instance, stride_instance] - SmallVector gtConstructorArgs; - gtConstructorArgs.push_back(newPtr); - gtConstructorArgs.push_back(shapeInstOp.getResult(0)); // 拿到 shape_inst 的 SSA Value - gtConstructorArgs.push_back(strideInstOp.getResult(0)); // 拿到 stride_inst 的 SSA Value - - rewriter.replaceOpWithNewOp( - op, - gtType, - gtCppType, - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange(gtConstructorArgs) - ); - - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// Helper: build GlobalTensor from a static MemRef (for TLOAD/TSTORE) -//===----------------------------------------------------------------------===// - -static std::string getElemTypeStringForGT(Type elemTy) { - return getEmitCScalarTypeToken(elemTy); -} - -static bool hasStaticShape(MemRefType mrTy) { - return llvm::none_of(mrTy.getShape(), [](int64_t dim) { - return dim == ShapedType::kDynamic; - }); -} - -static bool getStaticMemrefLayout(MemRefType mrTy, SmallVectorImpl &strides, - int64_t &offset) { - if (failed(getStridesAndOffset(mrTy, strides, offset))) { - strides.clear(); - int64_t stride = 1; - ArrayRef shape = mrTy.getShape(); - for (int i = static_cast(shape.size()) - 1; i >= 0; --i) { - strides.push_back(stride); - stride *= shape[i]; - } - std::reverse(strides.begin(), strides.end()); - offset = 0; - } - return offset != ShapedType::kDynamic && - llvm::none_of(strides, [](int64_t strideValue) { - return strideValue == ShapedType::kDynamic; - }); -} - -static Value applyStaticMemrefOffset(ConversionPatternRewriter &rewriter, - Location loc, Value basePtr, - int64_t offset) { - if (offset == 0) - return basePtr; - auto *ctx = rewriter.getContext(); - Type u32Ty = emitc::OpaqueType::get(ctx, 
"unsigned"); - auto offVal = rewriter.create( - loc, u32Ty, emitc::OpaqueAttr::get(ctx, std::to_string(offset))); - return rewriter.create(loc, basePtr.getType(), basePtr, offVal); -} - -static int getGlobalTensorElementBytes(Type elemTy) { - return static_cast(getPTOStorageElemByteSize(elemTy)); -} - -static int64_t multiplyOrDynamic(int64_t lhs, int64_t rhs) { - if (lhs < 0 || rhs < 0) - return -1; - return lhs * rhs; -} - -static void buildGlobalTensorShapeAndStride(ArrayRef shape, - ArrayRef strides, - SmallVectorImpl &shape5D, - SmallVectorImpl &stride5D) { - shape5D.assign(5, 1); - stride5D.assign(5, 1); - int rank = static_cast(shape.size()); - int shift = 5 - rank; - for (int i = 0; i < rank && i < 5; ++i) { - shape5D[shift + i] = shape[i]; - stride5D[shift + i] = strides[i]; - } - for (int i = 3; i >= 0; --i) { - if (i >= shift) - continue; - stride5D[i] = multiplyOrDynamic(shape5D[i + 1], stride5D[i + 1]); - } -} - -static std::string joinIntTemplateParams(ArrayRef values) { - std::string result; - for (size_t i = 0; i < values.size(); ++i) { - if (i != 0) - result += ", "; - result += std::to_string(values[i]); - } - return result; -} - -static SmallVector buildRowMajorStrides(ArrayRef shape) { - SmallVector strides(shape.size(), 1); - int64_t running = 1; - for (int i = static_cast(shape.size()) - 1; i >= 0; --i) { - strides[i] = running; - running = multiplyOrDynamic(running, shape[i]); - } - return strides; -} - -static std::string getGlobalTensorTypeStringFromShape(Type elemTy, - ArrayRef shape, - StringRef layoutEnum) { - SmallVector strides = buildRowMajorStrides(shape); - return getGlobalTensorTypeStringFromShapeAndStrides(elemTy, shape, strides, - layoutEnum); -} - -static std::string getGlobalTensorTypeStringFromShapeAndStrides( - Type elemTy, ArrayRef shape, ArrayRef strides, - StringRef layoutEnum) { - SmallVector shape5D; - SmallVector stride5D; - buildGlobalTensorShapeAndStride(shape, strides, shape5D, stride5D); - - std::string elemTypeStr 
= getElemTypeStringForGT(elemTy); - std::string shapeType = "pto::Shape<" + joinIntTemplateParams(shape5D) + ">"; - std::string strideType = - "pto::Stride<" + joinIntTemplateParams(stride5D) + ">"; - return "GlobalTensor<" + elemTypeStr + ", " + shapeType + ", " + - strideType + ", " + layoutEnum.str() + ">"; -} - -static emitc::OpaqueType getGlobalTensorOpaqueTypeFromShape( - MLIRContext *ctx, Type elemTy, ArrayRef shape, - StringRef layoutEnum) { - return emitc::OpaqueType::get( - ctx, getGlobalTensorTypeStringFromShape(elemTy, shape, layoutEnum)); -} - -static std::string inferFallbackGlobalTensorLayout(ArrayRef shape5D, - ArrayRef stride5D, - Type elemTy) { - int elemBytes = getGlobalTensorElementBytes(elemTy); - if (elemBytes == 0) - return "pto::Layout::ND"; - if (shape5D[2] == 16 && multiplyOrDynamic(shape5D[2], shape5D[3]) * elemBytes == 512 && - stride5D[4] == 1 && stride5D[3] == shape5D[4]) { - return "pto::Layout::NZ"; - } - - bool isRowMajor = stride5D[4] == 1; - for (int i = 3; i >= 0 && isRowMajor; --i) - isRowMajor = stride5D[i] == multiplyOrDynamic(stride5D[i + 1], shape5D[i + 1]); - - bool isColMajor = stride5D[0] == 1; - for (int i = 0; i < 4 && isColMajor; ++i) - isColMajor = stride5D[i + 1] == multiplyOrDynamic(stride5D[i], shape5D[i]); - - if (isColMajor) - return "pto::Layout::DN"; - return isRowMajor ? 
"pto::Layout::ND" : "pto::Layout::ND"; -} - -static std::string resolveGlobalTensorLayout(Operation *anchor, Value basePtr, - ArrayRef shape5D, - ArrayRef stride5D, - Type elemTy) { - if (auto layout = resolveLayoutForGlobalTensor(anchor, basePtr)) - return layoutToEmitCString(*layout); - return inferFallbackGlobalTensorLayout(shape5D, stride5D, elemTy); -} - -struct GlobalTensorTypeNames { - std::string shapeTypeName; - std::string strideTypeName; - std::string tensorTypeName; - std::string layoutConstName; -}; - -static GlobalTensorTypeNames getGlobalTensorTypeNames(Operation *anchor) { - std::string suffix = "_" + std::to_string(reinterpret_cast(anchor)); - return { - "GTShape" + suffix, - "GTStride" + suffix, - "GT" + suffix, - "GT" + suffix + "_layout", - }; -} -static Value buildGlobalTensorFromMemref(ConversionPatternRewriter &rewriter, - Location loc, Value basePtr, - MemRefType mrTy, - Operation *anchor) { - auto *ctx = rewriter.getContext(); - - ArrayRef shape = mrTy.getShape(); - if (!hasStaticShape(mrTy)) - return Value(); - - SmallVector strides; - int64_t offset = 0; - if (!getStaticMemrefLayout(mrTy, strides, offset)) - return Value(); - - Value ptr = applyStaticMemrefOffset(rewriter, loc, basePtr, offset); - GlobalTensorTypeNames names = getGlobalTensorTypeNames(anchor); - std::string elemTypeStr = getElemTypeStringForGT(mrTy.getElementType()); - SmallVector shape5D; - SmallVector stride5D; - buildGlobalTensorShapeAndStride(shape, strides, shape5D, stride5D); - - rewriter.create( - loc, "using " + names.shapeTypeName + " = pto::Shape<" + - joinIntTemplateParams(shape5D) + ">;"); - rewriter.create( - loc, "using " + names.strideTypeName + " = pto::Stride<" + - joinIntTemplateParams(stride5D) + ">;"); - - std::string layoutEnum = resolveGlobalTensorLayout( - anchor, basePtr, shape5D, stride5D, mrTy.getElementType()); - rewriter.create(loc, "constexpr pto::Layout " + - names.layoutConstName + " = " + - layoutEnum + ";"); - - auto shapeTypeOpaque = 
emitc::OpaqueType::get(ctx, names.shapeTypeName); - auto strideTypeOpaque = emitc::OpaqueType::get(ctx, names.strideTypeName); - auto shapeInstOp = rewriter.create( - loc, shapeTypeOpaque, names.shapeTypeName, ArrayAttr{}, ArrayAttr{}, - ValueRange{}); - auto strideInstOp = rewriter.create( - loc, strideTypeOpaque, names.strideTypeName, ArrayAttr{}, ArrayAttr{}, - ValueRange{}); - - rewriter.create( - loc, "using " + names.tensorTypeName + " = GlobalTensor<" + elemTypeStr + - ", " + names.shapeTypeName + ", " + names.strideTypeName + - ", " + names.layoutConstName + ">;"); - auto gtType = emitc::OpaqueType::get(ctx, names.tensorTypeName); - - SmallVector gtArgs; - gtArgs.push_back(ptr); - gtArgs.push_back(shapeInstOp.getResult(0)); - gtArgs.push_back(strideInstOp.getResult(0)); - - auto gtInst = rewriter.create( - loc, gtType, names.tensorTypeName, ArrayAttr{}, ArrayAttr{}, - ValueRange(gtArgs)); - - return gtInst.getResult(0); -} - -static Value maybeWrapGlobalMemrefAsGlobalTensor( - ConversionPatternRewriter &rewriter, Location loc, Value loweredValue, - Type originalType, Operation *anchor) { - auto mrTy = dyn_cast(originalType); - if (!mrTy) - return loweredValue; - - bool isGlobal = true; - if (auto asAttr = - dyn_cast_or_null(mrTy.getMemorySpace())) { - auto as = asAttr.getAddressSpace(); - isGlobal = (as == pto::AddressSpace::GM || as == pto::AddressSpace::Zero); - } - if (!isGlobal) - return loweredValue; - - if (Value gt = - buildGlobalTensorFromMemref(rewriter, loc, loweredValue, mrTy, anchor)) - return gt; - return loweredValue; -} - -static Value castToGMBytePointer(ConversionPatternRewriter &rewriter, - Location loc, Value value) { - auto *ctx = rewriter.getContext(); - auto targetTy = - emitc::PointerType::get(emitc::OpaqueType::get(ctx, "__gm__ uint8_t")); - if (value.getType() == targetTy) - return value; - - auto castTyAttr = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "__gm__ uint8_t*")}); - if (isSetFFTsPointerLikeType(value.getType())) 
{ - return rewriter - .create(loc, targetTy, "reinterpret_cast", - ArrayAttr{}, castTyAttr, - ValueRange{value}) - .getResult(0); - } - return rewriter.create(loc, targetTy, value).getResult(); -} - -static Value materializeTensorViewDataPointer( - ConversionPatternRewriter &rewriter, Location loc, Value value, - Type sourceType) { - auto tvTy = dyn_cast(sourceType); - if (!tvTy) - return value; - - auto *ctx = rewriter.getContext(); - std::string elemTypeStr = getElemTypeStringForGT(tvTy.getElementType()); - auto ptrTy = emitc::PointerType::get( - emitc::OpaqueType::get(ctx, "__gm__ " + elemTypeStr)); - return rewriter - .create(loc, ptrTy, "PTOAS__GLOBAL_TENSOR_DATA", - ArrayAttr{}, ArrayAttr{}, ValueRange{value}) - .getResult(0); -} - -static std::string tileBufBLayoutToken(pto::TileBufConfigAttr configAttr) { - std::string blTok = "BLayout::RowMajor"; - if (auto blAttr = dyn_cast(configAttr.getBLayout())) { - if (static_cast(blAttr.getValue()) == 1) - blTok = "BLayout::ColMajor"; - } - return blTok; -} - -static std::string tileBufSLayoutToken(pto::TileBufConfigAttr configAttr) { - std::string slTok = "SLayout::NoneBox"; - if (auto slAttr = dyn_cast(configAttr.getSLayout())) { - int32_t slVal = static_cast(slAttr.getValue()); - slTok = (slVal == 1) ? "SLayout::RowMajor" - : (slVal == 2) ? 
"SLayout::ColMajor" - : "SLayout::NoneBox"; - } - return slTok; -} - -static std::string tileBufPadToken(pto::TileBufConfigAttr configAttr) { - std::string padTok = "PadValue::Null"; - if (auto padAttr = dyn_cast(configAttr.getPad())) { - switch (static_cast(padAttr.getValue())) { - case 1: - padTok = "PadValue::Zero"; - break; - case 2: - padTok = "PadValue::Max"; - break; - case 3: - padTok = "PadValue::Min"; - break; - default: - padTok = "PadValue::Null"; - break; - } - } - return padTok; -} - -static pto::BLayout getTileBufBLayoutValue(pto::TileBufConfigAttr configAttr) { - if (auto blAttr = dyn_cast(configAttr.getBLayout())) - return blAttr.getValue(); - return pto::BLayout::RowMajor; -} - -static int64_t renderTileTemplateDim(int64_t rawDim, Type elemTy, - pto::BLayout blayout, int dimIdx) { - assert(dimIdx >= 0 && dimIdx < 2 && - "renderTileTemplateDim expects a rank-2 rows/cols dimension index"); - if (rawDim == ShapedType::kDynamic) - return rawDim; - if (!pto::isPTOFloat4PackedType(elemTy)) - return rawDim; - int packedDim = blayout == pto::BLayout::ColMajor ? 0 : 1; - return dimIdx == packedDim ? rawDim * 2 : rawDim; -} - -static FailureOr buildAsyncScratchTileValue( - ConversionPatternRewriter &rewriter, Location loc, Value originalScratch, - Value emittedScratch) { - Value scratch = peelUnrealized(emittedScratch); - if (auto opaqueTy = dyn_cast(scratch.getType())) { - StringRef typeStr = opaqueTy.getValue(); - if (typeStr.contains("Tile<") || typeStr.contains("ConvTile<")) - return scratch; - } - - auto memTy = dyn_cast(originalScratch.getType()); - if (!memTy) - return failure(); - - ArrayRef shape = memTy.getShape(); - if (!memTy.hasStaticShape() || shape.empty() || shape.size() > 2) - return failure(); - - int64_t rows = shape.size() == 1 ? 1 : shape[0]; - int64_t cols = shape.size() == 1 ? 
shape[0] : shape[1]; - - auto *ctx = rewriter.getContext(); - pto::TileBufConfigAttr configAttr = pto::TileBufConfigAttr::getDefault(ctx); - if (auto bind = originalScratch.getDefiningOp()) { - configAttr = bind.getConfig(); - } else if (auto cast = originalScratch.getDefiningOp()) { - if (auto config = cast.getConfig()) - configAttr = *config; - } - - int32_t fractal = 512; - if (auto frAttr = dyn_cast(configAttr.getSFractalSize())) - fractal = frAttr.getInt(); - - Type elemTy = memTy.getElementType(); - pto::BLayout blayout = getTileBufBLayoutValue(configAttr); - int64_t templateRows = renderTileTemplateDim(rows, elemTy, blayout, 0); - int64_t templateCols = renderTileTemplateDim(cols, elemTy, blayout, 1); - std::string elemTypeStr = getEmitCScalarTypeToken(elemTy); - std::string tileTypeStr = - "Tile"; - - Value tile = rewriter - .create( - loc, emitc::OpaqueType::get(ctx, tileTypeStr), - emitc::OpaqueAttr::get(ctx, "")) - .getResult(); - auto addr = rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); - Value scratchAddr = - rewriter - .create(loc, emitc::OpaqueType::get(ctx, "uint64_t"), - "reinterpret_cast", ArrayAttr{}, addr, - ValueRange{scratch}) - .getResult(0); - rewriter.create(loc, TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{tile, scratchAddr}); - return tile; -} - -static FailureOr buildSyncAllWorkspaceTileValue( - ConversionPatternRewriter &rewriter, Location loc, Value originalWorkspace, - Value emittedWorkspace) { - Value workspace = peelUnrealized(emittedWorkspace); - if (auto opaqueTy = dyn_cast(workspace.getType())) { - StringRef typeStr = opaqueTy.getValue(); - if (typeStr.contains("Tile<") || typeStr.contains("ConvTile<")) - return workspace; - } - - auto memTy = dyn_cast(originalWorkspace.getType()); - if (!memTy) - return failure(); - if (!memTy.hasStaticShape()) - return failure(); - - ArrayRef rawShape = memTy.getShape(); - if (rawShape.empty() || rawShape.size() > 2) - return failure(); - - int64_t rows = 
rawShape.size() == 1 ? 1 : rawShape[0]; - int64_t cols = rawShape.size() == 1 ? rawShape[0] : rawShape[1]; - SmallVector shape{rows, cols}; - SmallVector validShape{rows, cols}; - - auto *ctx = rewriter.getContext(); - pto::TileBufConfigAttr configAttr = pto::TileBufConfigAttr::getDefault(ctx); - if (auto bind = originalWorkspace.getDefiningOp()) { - configAttr = bind.getConfig(); - } else if (auto cast = originalWorkspace.getDefiningOp()) { - if (auto config = cast.getConfig()) - configAttr = *config; - } - - Attribute memorySpace = memTy.getMemorySpace(); - if (!memorySpace) - return failure(); - - auto tileTy = pto::TileBufType::get(ctx, shape, memTy.getElementType(), - memorySpace, validShape, configAttr); - auto tileTypeString = getEmitCTileTypeString(tileTy); - if (!tileTypeString) - return failure(); - - auto tileEmitTy = emitc::OpaqueType::get(ctx, *tileTypeString); - Value tile = rewriter - .create(loc, tileEmitTy, - emitc::OpaqueAttr::get(ctx, "")) - .getResult(); - - Value rawPtr = workspace; - auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); - if (isSetFFTsPointerLikeType(rawPtr.getType())) { - auto rcU64 = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); - rawPtr = rewriter - .create(loc, u64Ty, "reinterpret_cast", - ArrayAttr{}, rcU64, - ValueRange{rawPtr}) - .getResult(0); - } else if (rawPtr.getType() != u64Ty) { - rawPtr = rewriter.create(loc, u64Ty, rawPtr).getResult(); - } - - rewriter.create(loc, TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{tile, rawPtr}); - return tile; -} - -//===----------------------------------------------------------------------===// -// pto.pointer_cast lowering -//===----------------------------------------------------------------------=== -struct PointerCastConversion : public OpConversionPattern { - static bool getIndexConst(Value v, int64_t &out) { - if (auto cst = v.getDefiningOp()) { - if (auto ia = dyn_cast(cst.getValue())) { - out = ia.getValue().getSExtValue(); - return 
true; - } - } - return false; - } - - using OpConversionPattern::OpConversionPattern; - - enum class TileRole { Vec, Mat, Left, Right, Acc, Bias, Scaling }; - - static void collectUserOpsThroughCasts(Value v, SmallVectorImpl &out) { - for (Operation *u : v.getUsers()) { - if (auto castOp = dyn_cast(u)) { - for (Value r : castOp.getResults()) - collectUserOpsThroughCasts(r, out); - continue; - } - out.push_back(u); - } - } - - static Value peelUnrealized(Value v) { - while (auto castOp = v.getDefiningOp()) { - v = castOp.getOperand(0); - } - return v; - } - - static TileRole inferRole(pto::PointerCastOp op) { - // 1. 优先检查 AddressSpace - if (auto memRefTy = dyn_cast(op.getType())) { - Attribute memorySpace = memRefTy.getMemorySpace(); - if (auto ptoAttr = dyn_cast_or_null(memorySpace)) { - switch (ptoAttr.getAddressSpace()) { - case pto::AddressSpace::LEFT: return TileRole::Left; - case pto::AddressSpace::RIGHT: return TileRole::Right; - case pto::AddressSpace::ACC: return TileRole::Acc; - case pto::AddressSpace::BIAS: return TileRole::Bias; - case pto::AddressSpace::MAT: return TileRole::Mat; - case pto::AddressSpace::SCALING: return TileRole::Scaling; - default: break; - } - } - } - - // 2. 
通过 Usage 推导 (Fallback) - SmallVector users; - collectUserOpsThroughCasts(op.getResult(), users); - - for (Operation *user : users) { - if (auto mm = dyn_cast(user)) { - if (mm.getDst() && peelUnrealized(mm.getDst()) == op.getResult()) return TileRole::Acc; - if (peelUnrealized(mm.getLhs()) == op.getResult()) return TileRole::Left; - if (peelUnrealized(mm.getRhs()) == op.getResult()) return TileRole::Right; - } - if (auto mmacc = dyn_cast(user)) { - if (mmacc.getDst() && peelUnrealized(mmacc.getDst()) == op.getResult()) return TileRole::Acc; - if (peelUnrealized(mmacc.getAccIn()) == op.getResult()) return TileRole::Acc; - if (peelUnrealized(mmacc.getLhs()) == op.getResult()) return TileRole::Left; - if (peelUnrealized(mmacc.getRhs()) == op.getResult()) return TileRole::Right; - } - } - - return TileRole::Vec; - } - - // [新增] 辅助函数:判断 Value 是否源自 arith.constant - static bool isConstant(Value v, int64_t &outVal) { - if (!v) return false; - if (auto cst = v.getDefiningOp()) { - if (auto attr = dyn_cast(cst.getValue())) { - outVal = attr.getInt(); - return true; - } - } - return false; - } - - LogicalResult matchAndRewrite(pto::PointerCastOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - auto selfType = mlir::cast(op.getType()); - ArrayRef shape = selfType.getShape(); - Type elemType = selfType.getElementType(); - - // 1. 推导 Tile Role - TileRole role = inferRole(op); - - // 2. 类型字符串生成 (elemTypeStr, dimStr) - std::string elemTypeStr = getEmitCScalarTypeToken(elemType); - - std::string dimStr; - pto::BLayout blayout = pto::BLayout::RowMajor; - auto dimToString = [&](int64_t dim, const char *symbol, - int dimIdx) -> std::string { - if (dim == ShapedType::kDynamic) - return std::string(symbol); - return std::to_string(renderTileTemplateDim(dim, elemType, blayout, - dimIdx)); - }; - - // 3. 
Role Token - const char *roleTok = "TileType::Vec"; - switch (role) { - case TileRole::Left: roleTok = "TileType::Left"; break; - case TileRole::Right: roleTok = "TileType::Right"; break; - case TileRole::Acc: roleTok = "TileType::Acc"; break; - case TileRole::Bias: roleTok = "TileType::Bias"; break; - case TileRole::Mat: roleTok = "TileType::Mat"; break; - case TileRole::Vec: roleTok = "TileType::Vec"; break; - case TileRole::Scaling: roleTok = "TileType::Scaling"; break; - } - - // 4. Config & Layout (support BLayoutAttr/SLayoutAttr/PadValueAttr after namespace change) - std::string layoutParams = "BLayout::RowMajor"; - std::string extraParams = ""; - if (auto configOpt = op.getConfig()) { - auto config = *configOpt; - int32_t blVal = 0; - if (auto attr = dyn_cast(config.getBLayout())) - blVal = static_cast(attr.getValue()); - - if (blVal == 1) layoutParams = "BLayout::ColMajor"; - blayout = blVal == 1 ? pto::BLayout::ColMajor : pto::BLayout::RowMajor; - - int32_t slVal = 0; - if (auto attr = dyn_cast(config.getSLayout())) - slVal = static_cast(attr.getValue()); - - std::string slStr = (slVal == 1) ? "SLayout::RowMajor" : (slVal == 2) ? 
"SLayout::ColMajor" : "SLayout::NoneBox"; - - int32_t frVal = 0; - if (auto attr = dyn_cast(config.getSFractalSize())) frVal = attr.getInt(); - - int32_t padVal = 0; - if (auto attr = dyn_cast(config.getPad())) - padVal = static_cast(attr.getValue()); - - std::string padStr = "PadValue::Null"; - switch (padVal) { - case 1: padStr = "PadValue::Zero"; break; - case 2: padStr = "PadValue::Max"; break; - case 3: padStr = "PadValue::Min"; break; - } - - int32_t compactVal = 0; - if (auto attr = dyn_cast(config.getCompactMode())) - compactVal = static_cast(attr.getValue()); - - std::string compactStr = "CompactMode::Null"; - switch (compactVal) { - case 1: compactStr = "CompactMode::Normal"; break; - case 2: compactStr = "CompactMode::RowPlusOne"; break; - } - - if (!slStr.empty()) { - extraParams += ", " + slStr + ", " + std::to_string(frVal) + ", " + - padStr + ", " + compactStr; - } - } else { - extraParams = ", SLayout::NoneBox, 512, PadValue::Null, CompactMode::Null"; - } - - if (role == TileRole::Left) - dimStr = dimToString(shape[0], "M", 0) + ", " + - dimToString(shape[1], "K", 1); - else if (role == TileRole::Right) - dimStr = dimToString(shape[0], "K", 0) + ", " + - dimToString(shape[1], "N", 1); - else if (role == TileRole::Bias) - dimStr = "1, " + dimToString(shape[1], "N", 1); - else - dimStr = dimToString(shape[0], "M", 0) + ", " + - dimToString(shape[1], "N", 1); - - // [核心修改] Valid Dims 处理逻辑 (支持混合静态/动态) - std::string vrowTok, vcolTok; - bool useConstructor = false; - - bool rowIsDynamic = false; - bool colIsDynamic = false; - - SmallVector constructorArgs; - - Value vRow = op.getValidRow(); - Value vCol = op.getValidCol(); - Value vRowEmitC = adaptor.getValidRow(); - Value vColEmitC = adaptor.getValidCol(); - bool forceDynamicValid = op->hasAttr(kForceDynamicValidShapeAttrName); - - int64_t cRow = 0, cCol = 0; - bool rowIsConst = vRow && isConstant(vRow, cRow); - bool colIsConst = vCol && isConstant(vCol, cCol); - - auto makeCtorDimValue = [&](Value 
emitted, int64_t fallback) -> Value { - if (emitted) - return emitted; - return makeEmitCIntConstant( - rewriter, loc, emitc::OpaqueType::get(ctx, "int32_t"), fallback); - }; - auto maybeScaleDynamicValid = [&](Value emitted, int dimIdx) -> Value { - if (!emitted || !pto::isPTOFloat4PackedType(elemType)) - return emitted; - int packedDim = blayout == pto::BLayout::ColMajor ? 0 : 1; - if (dimIdx != packedDim) - return emitted; - auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); - Value two = makeEmitCIntConstant(rewriter, loc, i32Ty, 2); - return rewriter.create(loc, i32Ty, emitted, two).getResult(); - }; - - if (forceDynamicValid) { - vrowTok = "-1"; - vcolTok = "-1"; - useConstructor = true; - constructorArgs.push_back( - makeCtorDimValue(maybeScaleDynamicValid(vRowEmitC, 0), - renderTileTemplateDim(rowIsConst ? cRow : shape[0], - elemType, blayout, 0))); - constructorArgs.push_back( - makeCtorDimValue(maybeScaleDynamicValid(vColEmitC, 1), - renderTileTemplateDim(colIsConst ? cCol : shape[1], - elemType, blayout, 1))); - } else { - if (rowIsConst) { - vrowTok = std::to_string( - renderTileTemplateDim(cRow, elemType, blayout, 0)); - } else if (vRow) { - vrowTok = "-1"; - rowIsDynamic = true; - useConstructor = true; - } else { - vrowTok = std::to_string( - renderTileTemplateDim(shape[0], elemType, blayout, 0)); - } - - if (colIsConst) { - vcolTok = std::to_string( - renderTileTemplateDim(cCol, elemType, blayout, 1)); - } else if (vCol) { - vcolTok = "-1"; - colIsDynamic = true; - useConstructor = true; - } else { - vcolTok = std::to_string( - renderTileTemplateDim(shape[1], elemType, blayout, 1)); - } - - if (useConstructor) { - if (rowIsDynamic && vRowEmitC) - constructorArgs.push_back(maybeScaleDynamicValid(vRowEmitC, 0)); - if (colIsDynamic && vColEmitC) - constructorArgs.push_back(maybeScaleDynamicValid(vColEmitC, 1)); - } - } - - // 5. 
生成 Tile 类型字符串 - std::string tileTypeStr = - std::string("Tile<") + roleTok + ", " + elemTypeStr + ", " + dimStr + ", " + - layoutParams + ", " + vrowTok + ", " + vcolTok + extraParams + ">"; - - auto tileType = emitc::OpaqueType::get(ctx, tileTypeStr); - Value resultValue; - - if (useConstructor) { - // 使用 CallOpaqueOp 生成构造函数调用 (Tile v = Tile(...)) - auto ctorOp = rewriter.create( - loc, - tileType, // Result Type - tileTypeStr, // Callee Name (类名) - ArrayAttr{}, // args - ArrayAttr{}, // template_args - ValueRange(constructorArgs) // operands - ); - resultValue = ctorOp.getResult(0); - } else { - // 静态情况 (Tile v;) - auto varOp = rewriter.create( - loc, - tileType, - emitc::OpaqueAttr::get(ctx, "") - ); - resultValue = varOp.getResult(); - } - - // TASSIGN: pto-isa expects an integral address. - Value addr = adaptor.getAddrs()[0]; - if (isa(addr.getType()) || - (isa(addr.getType()) && - cast(addr.getType()).getValue().ends_with("*"))) { - auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); - auto rcU64 = rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); - addr = rewriter.create( - loc, u64Ty, "reinterpret_cast", - /*args=*/ArrayAttr{}, /*templateArgs=*/rcU64, - /*operands=*/ValueRange{addr}) - .getResult(0); - } - - rewriter.create( - loc, TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{resultValue, addr}); - - rewriter.replaceOp(op, resultValue); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// pto.load_dps / pto.store_dps lowering (FIX: keep optional result) -//===----------------------------------------------------------------------=== - -struct PTOTLoadToTLOAD : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TLoadOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!op.getDst()) - return rewriter.notifyMatchFailure(op, "expected outs(dst) on pto.tload"); - 
- Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value srcArg = src; - if (auto srcMrTy = dyn_cast(op.getSrc().getType())) { - bool isGlobal = true; - if (auto asAttr = dyn_cast_or_null(srcMrTy.getMemorySpace())) { - auto as = asAttr.getAddressSpace(); - isGlobal = (as == pto::AddressSpace::GM || as == pto::AddressSpace::Zero); - } - if (isGlobal) { - if (Value gt = buildGlobalTensorFromMemref(rewriter, op.getLoc(), src, srcMrTy, - op.getOperation())) - srcArg = gt; - } - } - - rewriter.create( - op.getLoc(), TypeRange{}, "TLOAD", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, srcArg}); - - if (op->getNumResults() == 1) { - rewriter.replaceOp(op, dst); - } else { - rewriter.eraseOp(op); - } - return success(); - } -}; - -struct PTOTPrefetchToTPREFETCH : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPrefetchOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!op.getDst()) - return rewriter.notifyMatchFailure(op, "expected outs(dst) on pto.tprefetch"); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value srcArg = src; - if (auto srcMrTy = dyn_cast(op.getSrc().getType())) { - bool isGlobal = true; - if (auto asAttr = dyn_cast_or_null(srcMrTy.getMemorySpace())) { - auto as = asAttr.getAddressSpace(); - isGlobal = (as == pto::AddressSpace::GM || as == pto::AddressSpace::Zero); - } - if (isGlobal) { - if (Value gt = buildGlobalTensorFromMemref(rewriter, op.getLoc(), src, srcMrTy, - op.getOperation())) - srcArg = gt; - } - } - - rewriter.create( - op.getLoc(), TypeRange{}, "TPREFETCH", - ArrayAttr{}, ArrayAttr{}, ValueRange{dst, srcArg}); - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOTPrefetchAsyncToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPrefetchAsyncOp op, 
OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src = peelUnrealized(adaptor.getSrc()); - Value srcArg = src; - if (!isEmitCGlobalTensorLikeType(srcArg.getType())) { - auto srcMrTy = dyn_cast(op.getSrc().getType()); - if (!srcMrTy) - return rewriter.notifyMatchFailure( - op, "expected src to lower to GlobalTensor or memref"); - srcArg = buildGlobalTensorFromMemref(rewriter, op.getLoc(), src, srcMrTy, - op.getSrc().getDefiningOp() - ? op.getSrc().getDefiningOp() - : op.getOperation()); - } - if (!srcArg) - return rewriter.notifyMatchFailure(op, - "failed to build GlobalTensor src"); - - Value prefetchCtx = peelUnrealized(adaptor.getCtx()); - - Type eventTy = getTypeConverter()->convertType(op.getEvent().getType()); - if (!eventTy) - return rewriter.notifyMatchFailure( - op, "failed to convert tprefetch_async result type"); - - Value event = rewriter - .create( - op.getLoc(), TypeRange{eventTy}, "TPREFETCH_ASYNC", - ArrayAttr{}, ArrayAttr{}, - ValueRange{srcArg, prefetchCtx}) - .getResult(0); - - rewriter.replaceOp(op, ValueRange{event}); - return success(); - } -}; - -struct PTOMakePrefetchAsyncContextToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::MakePrefetchAsyncContextOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type ctxTy = getTypeConverter()->convertType(op.getCtx().getType()); - if (!ctxTy) - return rewriter.notifyMatchFailure( - op, "failed to convert make_prefetch_async_context result type"); - - Value workspace = peelUnrealized(adaptor.getWorkspace()); - workspace = castToGMBytePointer(rewriter, op.getLoc(), workspace); - - Value ctx = rewriter - .create( - op.getLoc(), TypeRange{ctxTy}, "pto::PrefetchAsyncContext", - ArrayAttr{}, ArrayAttr{}, ValueRange{workspace}) - .getResult(0); - - rewriter.replaceOp(op, ValueRange{ctx}); - return success(); - } -}; - -struct PTOGetPrefetchAsyncSessionToEmitC - : 
public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::GetPrefetchAsyncSessionOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type sessionTy = getTypeConverter()->convertType(op.getSession().getType()); - if (!sessionTy) - return rewriter.notifyMatchFailure( - op, "failed to convert get_prefetch_async_session result type"); - - Value ctx = peelUnrealized(adaptor.getCtx()); - Value session = rewriter - .create( - op.getLoc(), TypeRange{sessionTy}, - "PTOAS__PREFETCH_CTX_SESSION", ArrayAttr{}, - ArrayAttr{}, ValueRange{ctx}) - .getResult(0); - - rewriter.replaceOp(op, ValueRange{session}); - return success(); - } -}; - -struct PTOTStoreToTSTORE : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - static std::string stPhaseTok(pto::STPhase phase) { - switch (phase) { - case pto::STPhase::Unspecified: return "STPhase::Unspecified"; - case pto::STPhase::Partial: return "STPhase::Partial"; - case pto::STPhase::Final: return "STPhase::Final"; - } - return "STPhase::Unspecified"; - } - - static std::string atomicTypeTok(pto::AtomicType atomicType) { - switch (atomicType) { - case pto::AtomicType::AtomicNone: return "AtomicType::AtomicNone"; - case pto::AtomicType::AtomicAdd: return "AtomicType::AtomicAdd"; - } - return "AtomicType::AtomicNone"; - } - - static std::string reluPreModeTok(pto::ReluPreMode reluPreMode) { - switch (reluPreMode) { - case pto::ReluPreMode::NoRelu: return "ReluPreMode::NoRelu"; - case pto::ReluPreMode::NormalRelu: return "ReluPreMode::NormalRelu"; - } - return "ReluPreMode::NoRelu"; - } - - LogicalResult matchAndRewrite(pto::TStoreOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!op.getDst()) - return rewriter.notifyMatchFailure(op, "expected outs(dst) on pto.tstore"); - - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - Value src = 
peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value preQuantScalar; - if (op.getPreQuantScalar()) - preQuantScalar = peelUnrealized(adaptor.getPreQuantScalar()); - Value dstArg = dst; - if (auto dstMrTy = dyn_cast(op.getDst().getType())) { - bool isGlobal = true; - if (auto asAttr = dyn_cast_or_null(dstMrTy.getMemorySpace())) { - auto as = asAttr.getAddressSpace(); - isGlobal = (as == pto::AddressSpace::GM || as == pto::AddressSpace::Zero); - } - if (isGlobal) { - if (Value gt = buildGlobalTensorFromMemref(rewriter, op.getLoc(), dst, dstMrTy, - op.getOperation())) - dstArg = gt; - } - } - - const auto phase = op.getStPhase(); - const auto atomicType = op.getAtomicType(); - const auto reluPreMode = op.getReluPreMode(); - const bool hasPreQuantScalar = static_cast(preQuantScalar); - const bool phaseNonDefault = phase != pto::STPhase::Unspecified; - const bool atomicNonDefault = atomicType != pto::AtomicType::AtomicNone; - const bool reluNonDefault = reluPreMode != pto::ReluPreMode::NoRelu; - - auto getOpaqueTok = [&](Value v, StringRef name) -> FailureOr { - if (auto ot = mlir::dyn_cast(v.getType())) - return ot.getValue().str(); - return rewriter.notifyMatchFailure(op, (name + " must be emitc::OpaqueType").str()); - }; - - ArrayAttr targs; - // Map op attributes/operands to the exact TSTORE overload family: - // 1) TSTORE(dst, src) - // 2) TSTORE(dst, src) - // 3) TSTORE(dst, src) - // 4) TSTORE(dst, src) - // 5) TSTORE(dst, src) - // 6) TSTORE(dst, src) - // 7) TSTORE(dst, src, preQuant) - // 8) TSTORE(dst, src, preQuant) - if (!hasPreQuantScalar && !reluNonDefault && !atomicNonDefault) { - if (phaseNonDefault) { - targs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, stPhaseTok(phase)), - }); - } else { - targs = ArrayAttr{}; - } - } else { - auto srcTokOr = getOpaqueTok(src, "src"); - auto dstTokOr = getOpaqueTok(dstArg, "dst"); - if (failed(srcTokOr) || failed(dstTokOr)) - return failure(); - - // If there is no 
preQuant and relu stays default, emit the atomic-only - // overloads (#3/#4) without ReluPreMode template argument. - if (!hasPreQuantScalar && !reluNonDefault) { - if (phaseNonDefault) { - targs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, stPhaseTok(phase)), - emitc::OpaqueAttr::get(ctx, *srcTokOr), - emitc::OpaqueAttr::get(ctx, *dstTokOr), - emitc::OpaqueAttr::get(ctx, atomicTypeTok(atomicType)), - }); - } else { - targs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, *srcTokOr), - emitc::OpaqueAttr::get(ctx, *dstTokOr), - emitc::OpaqueAttr::get(ctx, atomicTypeTok(atomicType)), - }); - } - } else { - // Relu/preQuant families (#5/#6/#7/#8): keep AtomicType + ReluPreMode. - if (phaseNonDefault) { - targs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, stPhaseTok(phase)), - emitc::OpaqueAttr::get(ctx, *srcTokOr), - emitc::OpaqueAttr::get(ctx, *dstTokOr), - emitc::OpaqueAttr::get(ctx, atomicTypeTok(atomicType)), - emitc::OpaqueAttr::get(ctx, reluPreModeTok(reluPreMode)), - }); - } else { - targs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, *srcTokOr), - emitc::OpaqueAttr::get(ctx, *dstTokOr), - emitc::OpaqueAttr::get(ctx, atomicTypeTok(atomicType)), - emitc::OpaqueAttr::get(ctx, reluPreModeTok(reluPreMode)), - }); - } - } - } - - SmallVector operands{dstArg, src}; - if (hasPreQuantScalar) - operands.push_back(preQuantScalar); - - rewriter.create( - loc, TypeRange{}, "TSTORE", - /*args=*/ArrayAttr{}, /*templateArgs=*/targs, - /*operands=*/operands); - - if (op->getNumResults() == 1) { - rewriter.replaceOp(op, dst); - } else { - rewriter.eraseOp(op); - } - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// pto.matmul_dps lowering (Simplified: No internal copy/sync) -//===----------------------------------------------------------------------===// -// -// Render `pto.tmatmul` as one of three forms depending on the optional -// `acc_phase` attribute: -// * absent / 
Unspecified -> `TMATMUL(dst, lhs, rhs)` -// * Partial -> `TMATMUL(dst, lhs, rhs)` -// * Final -> `TMATMUL(dst, lhs, rhs)` -// The Unspecified default keeps backward compatibility with all upstream IR -// that does not yet emit an explicit phase attribute. -static ArrayAttr buildAccPhaseTemplateArgs(ConversionPatternRewriter &rewriter, - pto::AccPhase phase) { - StringRef tmpl; - switch (phase) { - case pto::AccPhase::Unspecified: - return ArrayAttr{}; - case pto::AccPhase::Partial: - tmpl = "AccPhase::Partial"; - break; - case pto::AccPhase::Final: - tmpl = "AccPhase::Final"; - break; - } - if (tmpl.empty()) - return ArrayAttr{}; - return rewriter.getArrayAttr( - {emitc::OpaqueAttr::get(rewriter.getContext(), tmpl)}); -} - -struct PTOTMatmulToTMATMUL : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMatmulOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // 1. 获取操作数 (剥离 Cast) - Value lhs = peelUnrealized(adaptor.getLhs()); // A (Left) - Value rhs = peelUnrealized(adaptor.getRhs()); // B (Right) - Value dst = peelUnrealized(adaptor.getDst()); // C (Acc) - - // 2. 根据 acc_phase 属性决定是否生成 TMATMUL(...) - ArrayAttr templateArgs = - buildAccPhaseTemplateArgs(rewriter, op.getAccPhase()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TMATMUL", - /*args=*/ArrayAttr{}, /*template_args=*/templateArgs, - ValueRange{dst, lhs, rhs}); - - // 3. 
处理 Op 替换/删除 - if (op->getNumResults() == 1) { - rewriter.replaceOp(op, dst); - } else { - rewriter.eraseOp(op); - } - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// pto.tgemv lowering -//===----------------------------------------------------------------------===// -struct PTOTGemvToTGEMV : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGemvOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // 1. 获取操作数 (剥离 Cast) - Value lhs = peelUnrealized(adaptor.getLhs()); // A (Matrix) - Value rhs = peelUnrealized(adaptor.getRhs()); // B (Vector) - Value dst = peelUnrealized(adaptor.getDst()); // C (Result) - - // 2. 直接生成函数调用 TGEMV(dst, lhs, rhs) - rewriter.create( - op.getLoc(), TypeRange{}, "TGEMV", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, lhs, rhs}); - - // 3. 处理 Op 替换/删除 - if (op->getNumResults() == 1) { - rewriter.replaceOp(op, dst); - } else { - rewriter.eraseOp(op); - } - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// pto.tgemv.acc lowering -//===----------------------------------------------------------------------===// -struct PTOTGemvAccToTGEMVACC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGemvAccOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!op.getDst()) - return rewriter.notifyMatchFailure(op, "expected outs(dst) for pto.tgemv.acc"); - - // 1. 获取操作数 - Value accIn = peelUnrealized(adaptor.getAccIn()); // AccOld - Value lhs = peelUnrealized(adaptor.getLhs()); // A (Matrix) - Value rhs = peelUnrealized(adaptor.getRhs()); // B (Vector) - Value dst = peelUnrealized(adaptor.getDst()); // AccNew - - // 2. 
直接生成函数调用 TGEMV_ACC(dst, accIn, lhs, rhs) - rewriter.create( - op.getLoc(), TypeRange{}, "TGEMV_ACC", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, accIn, lhs, rhs}); - - // 3. 处理 Op 替换/删除 - if (op->getNumResults() == 1) { - rewriter.replaceOp(op, dst); - } else { - rewriter.eraseOp(op); - } - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// pto.matmul_acc_dps lowering (Simplified: No internal copy/sync) -//===----------------------------------------------------------------------===// -struct PTOTMatmulAccToTMATMULACC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMatmulAccOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!op.getDst()) - return rewriter.notifyMatchFailure(op, "expected outs(dst) for pto.tmatmul.acc"); - - // 1. 获取操作数 - Value accIn = peelUnrealized(adaptor.getAccIn()); // AccOld - Value lhs = peelUnrealized(adaptor.getLhs()); // A (Left) - Value rhs = peelUnrealized(adaptor.getRhs()); // B (Right) - Value dst = peelUnrealized(adaptor.getDst()); // AccNew - - // 2. 根据 acc_phase 属性决定是否生成 TMATMUL_ACC(...) - ArrayAttr templateArgs = - buildAccPhaseTemplateArgs(rewriter, op.getAccPhase()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TMATMUL_ACC", - /*args=*/ArrayAttr{}, /*template_args=*/templateArgs, - ValueRange{dst, accIn, lhs, rhs}); - - // 3. 
处理 Op 替换/删除 - if (op->getNumResults() == 1) { - rewriter.replaceOp(op, dst); - } else { - rewriter.eraseOp(op); - } - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// Return lowering -//===----------------------------------------------------------------------=== - -static constexpr llvm::StringLiteral kAutoSyncTailPendingModeAttr = - "__pto.auto_sync_tail_mode"; - -struct ReturnToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (auto emitcFunc = op->getParentOfType()) { - if (auto modeAttr = - emitcFunc->getAttrOfType(kAutoSyncTailPendingModeAttr)) { - auto *ctx = rewriter.getContext(); - rewriter.setInsertionPoint(op); - auto args = rewriter.getArrayAttr( - {emitc::OpaqueAttr::get(ctx, modeAttr.getValue())}); - rewriter.create( - op.getLoc(), TypeRange{}, "ptoas_auto_sync_tail", - args, ArrayAttr{}, ValueRange{}); - } - } - - auto vals = adaptor.getOperands(); - if (vals.empty()) { - rewriter.replaceOpWithNewOp(op, Value{}); - return success(); - } - if (vals.size() == 1) { - rewriter.replaceOpWithNewOp(op, vals[0]); - return success(); - } - return rewriter.notifyMatchFailure(op, "EmitC cannot return multiple values"); - } -}; - -struct CallToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(func::CallOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (op.getNumResults() > 1) - return rewriter.notifyMatchFailure( - op, "EmitC cannot lower calls with multiple results"); - - SmallVector resultTypes; - if (failed( - getTypeConverter()->convertTypes(op.getResultTypes(), resultTypes))) - return rewriter.notifyMatchFailure(op, - "failed to convert call result types"); - - rewriter.replaceOpWithNewOp(op, op.getCalleeAttr(), - 
resultTypes, - adaptor.getOperands()); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// Sync lowering -//===----------------------------------------------------------------------=== - -static constexpr llvm::StringLiteral kAutoSyncTailBarrierAttr = - "pto.auto_sync_tail_barrier"; -static constexpr llvm::StringLiteral kAutoSyncTailHintAttr = - "pto.auto_sync_tail_hint"; -static constexpr llvm::StringLiteral kAutoSyncTailPolicyBarrierAll = - "barrier_all"; -static constexpr llvm::StringLiteral kAutoSyncTailPolicyMte3ToSEvent0 = - "setwait_mte3_to_s_event0"; -static constexpr llvm::StringLiteral kAutoSyncTailModeBarrierAllToken = - "PTOAutoSyncTailMode::kBarrierAll"; -static constexpr llvm::StringLiteral kAutoSyncTailModeMte3ToSEvent0Token = - "PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0"; - -static std::string getAutoSyncTailModeToken(Operation *op) { - if (op) { - if (auto hintAttr = op->getAttrOfType(kAutoSyncTailHintAttr)) { - if (hintAttr.getValue() == kAutoSyncTailPolicyBarrierAll) - return kAutoSyncTailModeBarrierAllToken.str(); - if (hintAttr.getValue() == kAutoSyncTailPolicyMte3ToSEvent0) - return kAutoSyncTailModeMte3ToSEvent0Token.str(); - } - } - - auto func = op ? op->getParentOfType() : func::FuncOp(); - if (!func) - return kAutoSyncTailModeBarrierAllToken.str(); - - auto hintAttr = func->getAttrOfType(kAutoSyncTailHintAttr); - if (!hintAttr) - return kAutoSyncTailModeBarrierAllToken.str(); - - if (hintAttr.getValue() == kAutoSyncTailPolicyBarrierAll) - return kAutoSyncTailModeBarrierAllToken.str(); - if (hintAttr.getValue() == kAutoSyncTailPolicyMte3ToSEvent0) - return kAutoSyncTailModeMte3ToSEvent0Token.str(); - - // Fallback to the conservative behavior when seeing unknown policies. 
- return kAutoSyncTailModeBarrierAllToken.str(); -} - -[[maybe_unused]] static std::string getPipeName(pto::PIPE pipe) { - switch (pipe) { - case pto::PIPE::PIPE_S: return "PIPE_S"; - case pto::PIPE::PIPE_V: return "PIPE_V"; - case pto::PIPE::PIPE_M: return "PIPE_M"; - case pto::PIPE::PIPE_MTE1: return "PIPE_MTE1"; - case pto::PIPE::PIPE_MTE2: return "PIPE_MTE2"; - case pto::PIPE::PIPE_MTE3: return "PIPE_MTE3"; - case pto::PIPE::PIPE_ALL: return "PIPE_ALL"; - case pto::PIPE::PIPE_MTE4: return "PIPE_MTE4"; - case pto::PIPE::PIPE_MTE5: return "PIPE_MTE5"; - case pto::PIPE::PIPE_V2: return "PIPE_V2"; - case pto::PIPE::PIPE_FIX: return "PIPE_FIX"; - case pto::PIPE::VIRTUAL_PIPE_MTE2_L1A: return "VIRTUAL_PIPE_MTE2_L1A"; - case pto::PIPE::VIRTUAL_PIPE_MTE2_L1B: return "VIRTUAL_PIPE_MTE2_L1B"; - // 默认回退 - default: return "PIPE_ALL"; - } -} - -//===----------------------------------------------------------------------===// -// pto.barrier lowering -> pipe_barrier(...) -//===----------------------------------------------------------------------===// -struct PTOBarrierToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::BarrierOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (op->hasAttr(kAutoSyncTailBarrierAttr)) { - auto modeAttr = rewriter.getStringAttr(getAutoSyncTailModeToken(op)); - if (auto emitcFunc = op->getParentOfType()) { - emitcFunc->setAttr(kAutoSyncTailPendingModeAttr, modeAttr); - } else if (auto funcOp = op->getParentOfType()) { - funcOp->setAttr(kAutoSyncTailPendingModeAttr, modeAttr); - } - rewriter.eraseOp(op); - return success(); - } - - // [FIX] op.getPipe() returns PipeAttr. - // We must call .getPipe() on the attribute to get the actual Enum value. 
- pto::PIPE pipeEnum = op.getPipe().getPipe(); - - // Convert Enum to String (e.g., PIPE_ALL -> "PIPE_ALL") - std::string pipeStr = pto::stringifyPIPE(pipeEnum).str(); - auto *ctx = rewriter.getContext(); - - auto args = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, pipeStr) - }); - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, // void return - "pipe_barrier", // function name - args, // arguments - ArrayAttr{}, // template args - ValueRange{} // operands - ); - - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// Sync lowering (robust for bracket form pto.set_flag[...] / pto.wait_flag[...]) -// Replace your PTOSyncToRuntimeCall with the code below. -//===----------------------------------------------------------------------===// - -static bool tryConvertPipeAttrToToken(Attribute attr, std::string &token) { - if (!attr) - return false; - if (auto pipe = dyn_cast(attr)) { - token = mlir::pto::stringifyPIPE(pipe.getPipe()).str(); - return true; - } - if (auto stringAttr = dyn_cast(attr)) { - token = stringAttr.getValue().str(); - return true; - } - return false; -} - -static bool tryConvertEventAttrToToken(Attribute attr, std::string &token) { - if (!attr) - return false; - if (auto event = dyn_cast(attr)) { - token = mlir::pto::stringifyEVENT(event.getEvent()).str(); - return true; - } - if (auto stringAttr = dyn_cast(attr)) { - token = stringAttr.getValue().str(); - return true; - } - return false; -} - -static bool tryAssignSyncTokens(Attribute srcAttr, Attribute dstAttr, - Attribute evtAttr, std::string &srcTok, - std::string &dstTok, std::string &evtTok) { - std::string localSrc; - std::string localDst; - std::string localEvt; - if (!tryConvertPipeAttrToToken(srcAttr, localSrc) || - !tryConvertPipeAttrToToken(dstAttr, localDst) || - !tryConvertEventAttrToToken(evtAttr, localEvt)) { - return false; - } - srcTok = std::move(localSrc); - dstTok = std::move(localDst); - evtTok = 
std::move(localEvt); - return true; -} - -static bool tryExtractSyncTokensFromNamedAttrs(Operation *op, - StringRef srcName, - StringRef dstName, - StringRef evtName, - std::string &srcTok, - std::string &dstTok, - std::string &evtTok) { - return tryAssignSyncTokens(op->getAttr(srcName), op->getAttr(dstName), - op->getAttr(evtName), srcTok, dstTok, evtTok); -} - -static bool tryExtractSyncTokensFromArrayAttr(Operation *op, StringRef attrName, - std::string &srcTok, - std::string &dstTok, - std::string &evtTok) { - auto arrayAttr = op->getAttrOfType(attrName); - if (!arrayAttr || arrayAttr.size() < 3) - return false; - return tryAssignSyncTokens(arrayAttr[0], arrayAttr[1], arrayAttr[2], srcTok, - dstTok, evtTok); -} - -static bool tryExtractFallbackSyncTokens(Operation *op, std::string &srcTok, - std::string &dstTok, - std::string &evtTok) { - SmallVector pipes; - std::string event; - for (NamedAttribute namedAttr : op->getAttrs()) { - std::string token; - if (tryConvertPipeAttrToToken(namedAttr.getValue(), token)) { - pipes.push_back(std::move(token)); - continue; - } - if (event.empty() && - tryConvertEventAttrToToken(namedAttr.getValue(), token)) { - event = std::move(token); - } - } - if (pipes.size() < 2 || event.empty()) - return false; - srcTok = pipes[0]; - dstTok = pipes[1]; - evtTok = event; - return true; -} - -static LogicalResult extractSyncTripletTokens(Operation *op, - std::string &srcTok, - std::string &dstTok, - std::string &evtTok, - ConversionPatternRewriter &rewriter) { - if (tryExtractSyncTokensFromNamedAttrs(op, "src_pipe", "dst_pipe", "event_id", - srcTok, dstTok, evtTok) || - tryExtractSyncTokensFromNamedAttrs(op, "srcPipe", "dstPipe", "eventId", - srcTok, dstTok, evtTok) || - tryExtractSyncTokensFromNamedAttrs(op, "src", "dst", "event", srcTok, - dstTok, evtTok)) { - return success(); - } - - for (StringRef attrName : {"args", "pipes", "sync", "triplet", "attrs"}) { - if (tryExtractSyncTokensFromArrayAttr(op, attrName, srcTok, dstTok, - 
evtTok)) { - return success(); - } - } - - if (tryExtractFallbackSyncTokens(op, srcTok, dstTok, evtTok)) - return success(); - return rewriter.notifyMatchFailure( - op, "cannot extract PIPE/PIPE/EVENT tokens from pto.{set,wait}_flag"); -} -static inline std::string pipeTokFromPipeEnum(mlir::pto::PIPE p) { - return mlir::pto::stringifyPIPE(p).str(); -} -[[maybe_unused]] static inline std::string evtTokFromEventEnum(mlir::pto::EVENT e) { - return mlir::pto::stringifyEVENT(e).str(); -} -static inline std::string pipeTokFromPipeAttr(mlir::pto::PipeAttr a) { - return mlir::pto::stringifyPIPE(a.getPipe()).str(); -} -static inline std::string evtTokFromEventAttr(mlir::pto::EventAttr a) { - return mlir::pto::stringifyEVENT(a.getEvent()).str(); -} - -template -struct HasGetSrcPipe : std::false_type {}; -template -struct HasGetSrcPipe().getSrcPipe())>> : std::true_type {}; - -template -struct HasGetDstPipe : std::false_type {}; -template -struct HasGetDstPipe().getDstPipe())>> : std::true_type {}; - -template -struct HasGetEventId : std::false_type {}; -template -struct HasGetEventId().getEventId())>> : std::true_type {}; - -template -struct HasGetSrcPipeAttr : std::false_type {}; -template -struct HasGetSrcPipeAttr().getSrcPipeAttr())>> : std::true_type {}; - -template -struct HasGetDstPipeAttr : std::false_type {}; -template -struct HasGetDstPipeAttr().getDstPipeAttr())>> : std::true_type {}; - -template -struct HasGetEventIdAttr : std::false_type {}; -template -struct HasGetEventIdAttr().getEventIdAttr())>> : std::true_type {}; - -template -static LogicalResult extractSyncTokens(SyncOpT op, - std::string &srcTok, - std::string &dstTok, - std::string &evtTok, - ConversionPatternRewriter &rewriter) { - if constexpr (HasGetSrcPipe::value && - HasGetDstPipe::value && - HasGetEventId::value) { - auto s = op.getSrcPipe(); - auto d = op.getDstPipe(); - auto e = op.getEventId(); - - if constexpr (std::is_same::value) srcTok = pipeTokFromPipeEnum(s); - else srcTok = 
pipeTokFromPipeAttr(s); - - if constexpr (std::is_same::value) dstTok = pipeTokFromPipeEnum(d); - else dstTok = pipeTokFromPipeAttr(d); - - if constexpr (std::is_same::value) evtTok = evtTokFromEventEnum(e); - else evtTok = evtTokFromEventAttr(e); - - return success(); - } - - if constexpr (HasGetSrcPipeAttr::value && - HasGetDstPipeAttr::value && - HasGetEventIdAttr::value) { - auto s = op.getSrcPipeAttr(); - auto d = op.getDstPipeAttr(); - auto e = op.getEventIdAttr(); - srcTok = pipeTokFromPipeAttr(s); - dstTok = pipeTokFromPipeAttr(d); - evtTok = evtTokFromEventAttr(e); - return success(); - } - - return extractSyncTripletTokens(op.getOperation(), srcTok, dstTok, evtTok, rewriter); -} -struct PTOSetFlagToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::SetFlagOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - (void)adaptor; - auto *ctx = rewriter.getContext(); - - std::string srcTok, dstTok, evtTok; - if (failed(extractSyncTokens(op, srcTok, dstTok, evtTok, rewriter))) - return failure(); - - auto argsAttr = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, srcTok), - emitc::OpaqueAttr::get(ctx, dstTok), - emitc::OpaqueAttr::get(ctx, evtTok), - }); - - rewriter.replaceOpWithNewOp( - op, TypeRange{}, "set_flag", - /*args=*/argsAttr, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{}); - return success(); - } -}; - -struct PTOWaitFlagToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::WaitFlagOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - (void)adaptor; - auto *ctx = rewriter.getContext(); - - std::string srcTok, dstTok, evtTok; - if (failed(extractSyncTokens(op, srcTok, dstTok, evtTok, rewriter))) - return failure(); - - auto argsAttr = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, srcTok), - 
emitc::OpaqueAttr::get(ctx, dstTok), - emitc::OpaqueAttr::get(ctx, evtTok), - }); - - rewriter.replaceOpWithNewOp( - op, TypeRange{}, "wait_flag", - /*args=*/argsAttr, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{}); - return success(); - } -}; - -struct PTOSyncToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::TSyncOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - SmallVector operands; - operands.reserve(adaptor.getEvents().size()); - for (Value event : adaptor.getEvents()) - operands.push_back(peelUnrealized(event)); - - rewriter.create( - op.getLoc(), TypeRange{}, "TSYNC", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange(operands)); - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOSyncAllToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - static StringRef coreTypeTok(pto::SyncCoreType coreType) { - switch (coreType) { - case pto::SyncCoreType::AIVOnly: - return "SyncCoreType::AIVOnly"; - case pto::SyncCoreType::AICOnly: - return "SyncCoreType::AICOnly"; - case pto::SyncCoreType::Mix: - return "SyncCoreType::Mix"; - } - llvm_unreachable("unhandled SyncCoreType"); - } - - LogicalResult matchAndRewrite(mlir::pto::SyncAllOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto mode = op.getMode().getValue(); - auto coreType = op.getCoreType().getValue(); - - auto buildGmWorkspace = [&]() -> FailureOr { - Value gm = peelUnrealized(adaptor.getGmWorkspace()); - if (isEmitCGlobalTensorLikeType(gm.getType())) - return gm; - - auto memTy = dyn_cast(op.getGmWorkspace().getType()); - if (!memTy) - return failure(); - - Value gt = buildGlobalTensorFromMemref(rewriter, op.getLoc(), gm, memTy, - op.getGmWorkspace().getDefiningOp() - ? 
op.getGmWorkspace().getDefiningOp() - : op.getOperation()); - if (!gt) - return failure(); - return gt; - }; - - if (mode == pto::SyncAllMode::Hard) { - std::string callee = "SYNCALL<" + coreTypeTok(coreType).str() + ">"; - rewriter.create(op.getLoc(), TypeRange{}, callee, - ArrayAttr{}, ArrayAttr{}, - ValueRange{}); - rewriter.eraseOp(op); - return success(); - } - - FailureOr gmWorkspace = buildGmWorkspace(); - if (failed(gmWorkspace)) - return rewriter.notifyMatchFailure(op, - "failed to build gm_workspace GlobalTensor"); - - auto i32Ty = emitc::OpaqueType::get(rewriter.getContext(), "int32_t"); - Value usedCores = adaptor.getUsedCores() - ? peelUnrealized(adaptor.getUsedCores()) - : makeEmitCIntConstant(rewriter, op.getLoc(), i32Ty, 0); - if (usedCores.getType() != i32Ty) - usedCores = rewriter.create(op.getLoc(), i32Ty, usedCores) - .getResult(); - - std::string callee = - "SYNCALL"; - - SmallVector operands{*gmWorkspace}; - switch (coreType) { - case pto::SyncCoreType::AIVOnly: { - FailureOr ubWorkspace = - buildSyncAllWorkspaceTileValue(rewriter, op.getLoc(), - op.getUbWorkspace(), - adaptor.getUbWorkspace()); - if (failed(ubWorkspace)) - return rewriter.notifyMatchFailure( - op, "failed to materialize ub_workspace tile"); - operands.push_back(*ubWorkspace); - break; - } - case pto::SyncCoreType::AICOnly: { - FailureOr l1Workspace = - buildSyncAllWorkspaceTileValue(rewriter, op.getLoc(), - op.getL1Workspace(), - adaptor.getL1Workspace()); - if (failed(l1Workspace)) - return rewriter.notifyMatchFailure( - op, "failed to materialize l1_workspace tile"); - operands.push_back(*l1Workspace); - break; - } - case pto::SyncCoreType::Mix: { - FailureOr ubWorkspace = - buildSyncAllWorkspaceTileValue(rewriter, op.getLoc(), - op.getUbWorkspace(), - adaptor.getUbWorkspace()); - FailureOr l1Workspace = - buildSyncAllWorkspaceTileValue(rewriter, op.getLoc(), - op.getL1Workspace(), - adaptor.getL1Workspace()); - if (failed(ubWorkspace) || failed(l1Workspace)) - return 
rewriter.notifyMatchFailure( - op, "failed to materialize mixed syncall workspace tiles"); - operands.push_back(*ubWorkspace); - operands.push_back(*l1Workspace); - break; - } - } - - operands.push_back(usedCores); - rewriter.create(op.getLoc(), TypeRange{}, callee, - ArrayAttr{}, ArrayAttr{}, - ValueRange(operands)); - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOSyncFlagDynToEmitC : public ConversionPattern { - PTOSyncFlagDynToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - StringRef opName, StringRef callee) - : ConversionPattern(typeConverter, opName, /*benefit=*/1, ctx), - callee(callee.str()) {} - - LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - if (operands.size() != 1) - return rewriter.notifyMatchFailure(op, "expected exactly one dynamic event-id operand"); - - auto srcAttr = op->getAttrOfType("src_pipe"); - auto dstAttr = op->getAttrOfType("dst_pipe"); - if (!srcAttr || !dstAttr) - return rewriter.notifyMatchFailure(op, "missing PipeAttr src_pipe/dst_pipe attrs"); - - auto *ctx = rewriter.getContext(); - std::string srcTok = pipeTokFromPipeAttr(srcAttr); - std::string dstTok = pipeTokFromPipeAttr(dstAttr); - - Value eventVal = operands.front(); - eventVal = - emitCCast(rewriter, op->getLoc(), emitc::OpaqueType::get(ctx, "event_t"), eventVal); - - auto argsAttr = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, srcTok), - emitc::OpaqueAttr::get(ctx, dstTok), - IntegerAttr::get(IndexType::get(ctx), 0), - }); - - rewriter.replaceOpWithNewOp( - op, TypeRange{}, callee, - /*args=*/argsAttr, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{eventVal}); - return success(); - } - -private: - std::string callee; -}; - -struct PTOGetBufToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::GetBufOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { 
- (void)adaptor; - auto *ctx = rewriter.getContext(); - - auto opTypeOr = parseSyncOpTypeLikeAttr(op.getOpTypeAttr()); - if (failed(opTypeOr)) - return rewriter.notifyMatchFailure(op, "get_buf expects pipe_event_type/sync_op_type attr"); - auto pipe = mapSyncOpTypeToPipe(*opTypeOr); - if (!isConcreteSyncPipe(pipe)) - return rewriter.notifyMatchFailure(op, "get_buf op_type cannot map to a concrete pipe"); - std::string pipeTok = pipeTokFromPipeEnum(pipe); - auto argsAttr = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, pipeTok), - op.getBufIdAttr(), - op.getModeAttr(), - }); - - rewriter.replaceOpWithNewOp( - op, TypeRange{}, "get_buf", - /*args=*/argsAttr, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{}); - return success(); - } -}; - -struct PTORlsBufToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::RlsBufOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - (void)adaptor; - auto *ctx = rewriter.getContext(); - - auto opTypeOr = parseSyncOpTypeLikeAttr(op.getOpTypeAttr()); - if (failed(opTypeOr)) - return rewriter.notifyMatchFailure(op, "rls_buf expects pipe_event_type/sync_op_type attr"); - auto pipe = mapSyncOpTypeToPipe(*opTypeOr); - if (!isConcreteSyncPipe(pipe)) - return rewriter.notifyMatchFailure(op, "rls_buf op_type cannot map to a concrete pipe"); - std::string pipeTok = pipeTokFromPipeEnum(pipe); - auto argsAttr = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, pipeTok), - op.getBufIdAttr(), - op.getModeAttr(), - }); - - rewriter.replaceOpWithNewOp( - op, TypeRange{}, "rls_buf", - /*args=*/argsAttr, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{}); - return success(); - } -}; - -struct PTOSetFFTsToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::SetFFTsOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const 
override { - auto *ctx = rewriter.getContext(); - auto loc = op.getLoc(); - - Value fftsAddr = peelUnrealized(adaptor.getFfts()); - auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); - - if (isSetFFTsPointerLikeType(fftsAddr.getType())) { - auto castTyAttr = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); - fftsAddr = - rewriter - .create(loc, u64Ty, "reinterpret_cast", - /*args=*/ArrayAttr{}, - /*templateArgs=*/castTyAttr, - /*operands=*/ValueRange{fftsAddr}) - .getResult(0); - } else if (fftsAddr.getType() != u64Ty) { - fftsAddr = - rewriter.create(loc, u64Ty, fftsAddr).getResult(); - } - - rewriter.replaceOpWithNewOp( - op, TypeRange{}, "set_ffts_base_addr", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{fftsAddr}); - return success(); - } -}; - -struct PTOSyncSetToEmitC : public OpConversionPattern { - PTOSyncSetToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - PTOArch targetArch) - : OpConversionPattern(typeConverter, ctx), - targetArch(targetArch) {} - - LogicalResult - matchAndRewrite(mlir::pto::SyncSetOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op->getLoc(); - auto *ctx = rewriter.getContext(); - IntegerAttr eventIdAttr = op.getEventIdAttr(); - Value eventIdDyn = adaptor.getEventIdDyn(); - int64_t fftsMode = 2; - if (IntegerAttr fftsModeAttr = op.getFftsModeAttr()) - fftsMode = fftsModeAttr.getInt(); - - if ((eventIdAttr != nullptr) == static_cast(eventIdDyn)) - return rewriter.notifyMatchFailure( - op, "expects exactly one of static event_id attr or dynamic event_id operand"); - - // A5 inter-core sync mirrors +16 only for cube-side producer (PIPE_FIX). - // Vec-side producer (PIPE_MTE3) emits a single set; hardware handles the - // subblock mapping in PTO-ISA custom flow. 
- if (targetArch == PTOArch::A5) { - pto::PIPE pipe = op.getPipe().getPipe(); - bool needsMirrorPlus16 = (pipe == pto::PIPE::PIPE_FIX); - std::string pipeTok = pipeTokFromPipeAttr(op.getPipe()); - auto emitSet = [&](Value eventOperand, IntegerAttr eventLiteral, - bool isDynamic) { - if (isDynamic) { - auto args = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, pipeTok), - IntegerAttr::get(IndexType::get(ctx), 0), - }); - rewriter.create(loc, TypeRange{}, "set_intra_block", - /*args=*/args, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{eventOperand}); - return; - } - auto args = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, pipeTok), - eventLiteral, - }); - rewriter.create(loc, TypeRange{}, "set_intra_block", - /*args=*/args, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{}); - }; - - if (eventIdAttr) { - emitSet(Value{}, eventIdAttr, /*isDynamic=*/false); - if (needsMirrorPlus16) { - auto plus16 = IntegerAttr::get(eventIdAttr.getType(), - eventIdAttr.getInt() + 16); - emitSet(Value{}, plus16, /*isDynamic=*/false); - } - } else { - Value eventI32 = castInterCoreEventIdToI32(rewriter, loc, eventIdDyn); - emitSet(eventI32, IntegerAttr{}, /*isDynamic=*/true); - if (needsMirrorPlus16) { - auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); - Value c16 = makeEmitCIntConstant(rewriter, loc, i32Ty, 16); - Value eventI32Plus16 = - rewriter.create(loc, i32Ty, eventI32, c16).getResult(); - emitSet(eventI32Plus16, IntegerAttr{}, /*isDynamic=*/true); - } - } - - rewriter.eraseOp(op); - return success(); - } - - InterCoreSyncCallDesc desc; - if (eventIdAttr) { - desc = buildInterCoreSyncSetCall(rewriter, loc, targetArch, op.getPipe(), - eventIdAttr, fftsMode); - } else { - desc = buildInterCoreSyncSetCallDyn(rewriter, loc, targetArch, op.getPipe(), - eventIdDyn, fftsMode); - } - rewriter.create(loc, TypeRange{}, desc.callee, - /*args=*/desc.args, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/desc.operands); - - rewriter.eraseOp(op); - 
return success(); - } - - PTOArch targetArch; -}; - -struct PTOSyncWaitToEmitC : public OpConversionPattern { - PTOSyncWaitToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - PTOArch targetArch) - : OpConversionPattern(typeConverter, ctx), - targetArch(targetArch) {} - - LogicalResult - matchAndRewrite(mlir::pto::SyncWaitOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op->getLoc(); - IntegerAttr eventIdAttr = op.getEventIdAttr(); - Value eventIdDyn = adaptor.getEventIdDyn(); - - if ((eventIdAttr != nullptr) == static_cast(eventIdDyn)) - return rewriter.notifyMatchFailure( - op, "expects exactly one of static event_id attr or dynamic event_id operand"); - - InterCoreSyncCallDesc desc; - if (eventIdAttr) { - desc = buildInterCoreSyncWaitCall(rewriter, targetArch, op.getPipe(), - eventIdAttr); - } else { - desc = buildInterCoreSyncWaitCallDyn(rewriter, loc, targetArch, op.getPipe(), - eventIdDyn); - } - rewriter.create(loc, TypeRange{}, desc.callee, - desc.args, ArrayAttr{}, desc.operands); - - rewriter.eraseOp(op); - return success(); - } - - PTOArch targetArch; -}; - -// GetBlockIdxOp Lowering (pto.get_block_idx -> get_block_idx()) -struct PTOGetBlockIdxToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::pto::GetBlockIdxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - - rewriter.replaceOpWithNewOp( - op, op.getType(), "get_block_idx", ValueRange{}, ArrayAttr{}, - ArrayAttr{}); - - return success(); - } -}; - -// GetBlockNumOp Lowering (pto.get_block_num -> get_block_num()) -struct PTOGetBlockNumToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::pto::GetBlockNumOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - - rewriter.replaceOpWithNewOp( - op, op.getType(), "get_block_num", 
ValueRange{}, ArrayAttr{}, - ArrayAttr{}); - - return success(); - } -}; - -// GetSubBlockIdxOp Lowering (pto.get_block_idx -> get_subblockid()) -struct PTOGetSubBlockIdxToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::pto::GetSubBlockIdxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - - rewriter.replaceOpWithNewOp( - op, op.getType(), "get_subblockid", ValueRange{}, ArrayAttr{}, - ArrayAttr{}); - - return success(); - } -}; - -// GetSubBlockNumOp Lowering. -struct PTOGetSubBlockNumToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::pto::GetSubBlockNumOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - - rewriter.replaceOpWithNewOp( - op, op.getType(), "get_subblockdim", ValueRange{}, ArrayAttr{}, - ArrayAttr{}); - - return success(); - } -}; - - -struct PTOMScatterToMSCATTER : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::MScatterOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto *ctx = rewriter.getContext(); - Value src = peelUnrealized(adaptor.getSrc()); - Value idx = peelUnrealized(adaptor.getIdx()); - Value mem = peelUnrealized(adaptor.getMem()); - - Value memArg = maybeWrapGlobalMemrefAsGlobalTensor( - rewriter, op.getLoc(), mem, op.getMem().getType(), op.getOperation()); - - auto scatterAtomicTok = [&](pto::ScatterAtomicOp atomic) -> StringRef { - switch (atomic) { - case pto::ScatterAtomicOp::None: - return "pto::ScatterAtomicOp::None"; - case pto::ScatterAtomicOp::Add: - return "pto::ScatterAtomicOp::Add"; - case pto::ScatterAtomicOp::Max: - return "pto::ScatterAtomicOp::Max"; - case pto::ScatterAtomicOp::Min: - return "pto::ScatterAtomicOp::Min"; - } - llvm_unreachable("unknown ScatterAtomicOp"); - }; - auto 
scatterOobTok = [&](pto::ScatterOOB mode) -> StringRef { - switch (mode) { - case pto::ScatterOOB::Undefined: - return "pto::ScatterOOB::Undefined"; - case pto::ScatterOOB::Skip: - return "pto::ScatterOOB::Skip"; - case pto::ScatterOOB::Clamp: - return "pto::ScatterOOB::Clamp"; - case pto::ScatterOOB::Wrap: - return "pto::ScatterOOB::Wrap"; - } - llvm_unreachable("unknown ScatterOOB"); - }; - - SmallVector templateArgVec; - const bool rowCoalesce = - isRowCoalescedMGatherIndexType(op.getSrc().getType(), op.getIdx().getType()); - templateArgVec.push_back(emitc::OpaqueAttr::get( - ctx, rowCoalesce ? "pto::Coalesce::Row" : "pto::Coalesce::Elem")); - if (op.getScatterAtomicOp() != pto::ScatterAtomicOp::None || - op.getScatterOob() != pto::ScatterOOB::Undefined) { - templateArgVec.push_back(emitc::OpaqueAttr::get( - ctx, scatterAtomicTok(op.getScatterAtomicOp()))); - if (op.getScatterOob() != pto::ScatterOOB::Undefined) - templateArgVec.push_back( - emitc::OpaqueAttr::get(ctx, scatterOobTok(op.getScatterOob()))); - } - ArrayAttr templateArgs = rewriter.getArrayAttr(templateArgVec); - - rewriter.create( - op.getLoc(), TypeRange{}, "MSCATTER", - ArrayAttr{}, templateArgs, - ValueRange{memArg, src, idx}); - - rewriter.eraseOp(op); - return success(); - } -}; -struct PTOSetValToSETVAL : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TSetValOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value dst = peelUnrealized(adaptor.getDst()); - Value val = peelUnrealized(adaptor.getVal()); - - // ---- offset: SSA index operand ---- - Value offset = peelUnrealized(adaptor.getOffset()); - - // Emit a marker call and let the ptoas post-processing step lower it to - // the corresponding tile setter. 
- rewriter.create( - op.getLoc(), TypeRange{}, "PTOAS__TILE_SET_VALUE", - ArrayAttr{}, ArrayAttr{}, ValueRange{dst, offset, val}); - - rewriter.eraseOp(op); - return success(); - } -}; -struct PTOGetValToGETVAL : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGetValOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src = peelUnrealized(adaptor.getSrc()); - - // ---- offset: SSA index operand ---- - Value offset = peelUnrealized(adaptor.getOffset()); - - // Emit a marker call and let the ptoas post-processing step lower it to - // the corresponding tile getter. - Type dstTy = getTypeConverter()->convertType(op.getDst().getType()); - if (!dstTy) - return failure(); - auto call = rewriter.create( - op.getLoc(), - TypeRange{dstTy}, - "PTOAS__TILE_GET_VALUE", - ArrayAttr{}, ArrayAttr{}, - ValueRange{src, offset}); - - rewriter.replaceOp(op, call.getResults()); - return success(); - } -}; - -struct PTOTAxpyToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAxpyOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value scalar = peelUnrealized(adaptor.getScalar()); - - rewriter.create( - loc, TypeRange{}, "TAXPY", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, scalar}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOHistogramToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::THistogramOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value idx = 
peelUnrealized(adaptor.getIdx()); - Value dst = peelUnrealized(adaptor.getDst()); - - auto templateArgs = rewriter.getArrayAttr({emitc::OpaqueAttr::get( - ctx, op.getIsMSB() ? "HistByte::BYTE_1" : "HistByte::BYTE_0")}); - rewriter.create( - loc, TypeRange{}, "THISTOGRAM", - /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, - /*operands=*/ValueRange{dst, src, idx}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOGetScaleAddrToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGetScaleAddrOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TGET_SCALE_ADDR", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOSetValidShapeToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::SetValidShapeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto peelAllCasts = [](Value v) { - while (auto castOp = v.getDefiningOp()) - v = castOp.getOperand(0); - if (auto castOp = v.getDefiningOp()) - v = castOp.getOperand(); - return v; - }; - auto isTileLike = [](Value v) -> bool { - auto ot = dyn_cast(v.getType()); - if (!ot) - return false; - StringRef s = ot.getValue(); - return s.contains("Tile<") || s.contains("ConvTile<"); - }; - - Value src = peelAllCasts(peelUnrealized(adaptor.getSource())); - Value row = peelUnrealized(adaptor.getValidRow()); - Value col = peelUnrealized(adaptor.getValidCol()); - - if (!isTileLike(src)) - return rewriter.notifyMatchFailure( - op, "set_validshape source must lower to a tile-like value"); - - rewriter.create( - op.getLoc(), TypeRange{}, 
"PTOAS__TILE_SET_VALIDSHAPE", ArrayAttr{}, - ArrayAttr{}, ValueRange{src, row, col}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOGetValidShapeToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::GetValidShapeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto peelAllCasts = [](Value v) { - while (auto castOp = v.getDefiningOp()) - v = castOp.getOperand(0); - if (auto castOp = v.getDefiningOp()) - v = castOp.getOperand(); - return v; - }; - auto isTileLike = [](Value v) -> bool { - auto ot = dyn_cast(v.getType()); - if (!ot) - return false; - StringRef s = ot.getValue(); - return s.contains("Tile<") || s.contains("ConvTile<"); - }; - - Value src = peelAllCasts(peelUnrealized(adaptor.getSource())); - if (!isTileLike(src)) - return rewriter.notifyMatchFailure( - op, "get_validshape source must lower to a tile-like value"); - - auto resultTy = getTypeConverter()->convertType(rewriter.getIndexType()); - if (!resultTy) - return failure(); - - Value row = rewriter - .create( - op.getLoc(), resultTy, - "PTOAS__TILE_GET_VALID_ROW", ArrayAttr{}, - ArrayAttr{}, ValueRange{src}) - .getResult(0); - Value col = rewriter - .create( - op.getLoc(), resultTy, - "PTOAS__TILE_GET_VALID_COL", ArrayAttr{}, - ArrayAttr{}, ValueRange{src}) - .getResult(0); - rewriter.replaceOp(op, ValueRange{row, col}); - return success(); - } -}; - -struct PTOTAssignToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAssignOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto peelAllCasts = [](Value v) { - while (auto castOp = v.getDefiningOp()) - v = castOp.getOperand(0); - if (auto castOp = v.getDefiningOp()) - v = castOp.getOperand(); - return v; - }; - auto isTileLike = [](Value v) -> bool { - auto ot = dyn_cast(v.getType()); - if (!ot) - return false; - 
StringRef s = ot.getValue(); - return s.contains("Tile<") || s.contains("ConvTile<"); - }; - - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value tile = peelAllCasts(peelUnrealized(adaptor.getTile())); - if (!isTileLike(tile)) - return rewriter.notifyMatchFailure( - op, "tassign tile must lower to a tile-like value"); - - Value addr = peelUnrealized(adaptor.getAddr()); - auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); - if (isa(addr.getType()) || - (isa(addr.getType()) && - cast(addr.getType()).getValue().ends_with("*"))) { - auto rcU64 = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); - addr = rewriter - .create(loc, u64Ty, "reinterpret_cast", - ArrayAttr{}, rcU64, - ValueRange{addr}) - .getResult(0); - } else if (addr.getType() != u64Ty) { - addr = rewriter.create(loc, u64Ty, addr).getResult(); - } - - rewriter.create(loc, TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{tile, addr}); - rewriter.replaceOp(op, tile); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// pto.load_scalar / pto.store_scalar lowering -> ptr[offset] -//===----------------------------------------------------------------------===// - -static Type getPointerLikeElementType(Type type) { - if (auto ptrTy = dyn_cast(type)) - return ptrTy.getElementType(); - if (auto memTy = dyn_cast(type)) - return memTy.getElementType(); - return Type(); -} - -struct PTOPtrToIntToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::PtrToIntOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value ptr = peelUnrealized(adaptor.getPtr()); - Type dstTy = getTypeConverter()->convertType(op.getResult().getType()); - if (!dstTy) - return failure(); - - auto dstOpaque = dyn_cast(dstTy); - if (!dstOpaque) - return failure(); - - auto templateArgs = - 
rewriter.getArrayAttr({emitc::OpaqueAttr::get(rewriter.getContext(), - dstOpaque.getValue())}); - auto cast = rewriter.create( - op.getLoc(), dstTy, "reinterpret_cast", ArrayAttr{}, templateArgs, - ValueRange{ptr}); - rewriter.replaceOp(op, cast.getResult(0)); - return success(); - } -}; - -struct PTOIntToPtrToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::IntToPtrOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value addr = peelUnrealized(adaptor.getAddr()); - Type dstTy = getTypeConverter()->convertType(op.getResult().getType()); - if (!dstTy) - return failure(); - - Type dstElemTy = getPointerLikeElementType(op.getResult().getType()); - if (!dstElemTy) - return failure(); - - std::string castType = - std::string("__gm__ ") + getEmitCScalarTypeToken(dstElemTy) + "*"; - auto templateArgs = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(rewriter.getContext(), - castType)}); - auto cast = rewriter.create( - op.getLoc(), dstTy, "reinterpret_cast", ArrayAttr{}, templateArgs, - ValueRange{addr}); - rewriter.replaceOp(op, cast.getResult(0)); - return success(); - } -}; - -struct PTOLoadScalarToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::LoadScalarOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value ptr = peelUnrealized(adaptor.getPtr()); - Value offset = peelUnrealized(adaptor.getOffset()); - - Type dstTy = getTypeConverter()->convertType(op.getValue().getType()); - if (!dstTy) - return failure(); - - auto call = rewriter.create( - op.getLoc(), TypeRange{dstTy}, "PTOAS__PTR_LOAD", - ArrayAttr{}, ArrayAttr{}, ValueRange{ptr, offset}); - - rewriter.replaceOp(op, call.getResults()); - return success(); - } -}; - -struct PTOStoreScalarToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult 
matchAndRewrite(pto::StoreScalarOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value ptr = peelUnrealized(adaptor.getPtr()); - Value offset = peelUnrealized(adaptor.getOffset()); - Value val = peelUnrealized(adaptor.getValue()); - - rewriter.create( - op.getLoc(), TypeRange{}, "PTOAS__PTR_STORE", - ArrayAttr{}, ArrayAttr{}, ValueRange{ptr, offset, val}); - rewriter.create( - op.getLoc(), TypeRange{}, "PTOAS__SCALAR_GM_STORE_FLUSH", - ArrayAttr{}, ArrayAttr{}, ValueRange{ptr}); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// pto.tabs lowering -> TABS(dst, src) -//===----------------------------------------------------------------------===// - -struct PTOTAbsToTABS : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAbsOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - // intrinsic: TABS(dst, src) - rewriter.create( - op.getLoc(), TypeRange{}, "TABS", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tadd lowering -> TADD(dst, src0, src1) -//===----------------------------------------------------------------------===// - -struct PTOTAddToTADD : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAddOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TADD", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src0, 
src1}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOInitializeL2G2LPipeToEmitC - : public OpConversionPattern { - PTOInitializeL2G2LPipeToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - PTOArch targetArch) - : OpConversionPattern(typeConverter, ctx), - targetArch(targetArch) {} - - LogicalResult matchAndRewrite(mlir::pto::InitializeL2G2LPipeOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto tpipeTok = buildTPipeTokenFromInitOp(op.getOperation(), targetArch); - if (failed(tpipeTok)) - return rewriter.notifyMatchFailure(op, "failed to build TPipe token"); - - auto *ctx = rewriter.getContext(); - auto emitPipeTy = - cast(getTypeConverter()->convertType(op.getPipe().getType())); - - Value gmAddr = peelUnrealized(adaptor.getGmAddr()); - gmAddr = materializeTensorViewDataPointer( - rewriter, op.getLoc(), gmAddr, op.getGmAddr().getType()); - Value localAddr = - op.getLocalAddr() ? peelUnrealized(adaptor.getLocalAddr()) : Value(); - auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); - Value zero = makeEmitCIntConstant(rewriter, op.getLoc(), i32Ty, 0); - - Value c2vBuf = zero; - Value v2cBuf = zero; - if (op.getDirMask() == 1) - c2vBuf = localAddr ? localAddr : zero; - else if (op.getDirMask() == 2) - v2cBuf = localAddr ? 
localAddr : zero; - else if (op.getDirMask() == 3) { - if (localAddr) { - if (!op.getPeerLocalAddr()) - return rewriter.notifyMatchFailure( - op, "bidirectional l2g2l pipe requires peer local buffer"); - c2vBuf = localAddr; - v2cBuf = peelUnrealized(adaptor.getPeerLocalAddr()); - } - } else - return rewriter.notifyMatchFailure(op, "unsupported dir_mask"); - - rewriter.replaceOpWithNewOp( - op, TypeRange{emitPipeTy}, *tpipeTok, ArrayAttr{}, ArrayAttr{}, - ValueRange{gmAddr, c2vBuf, v2cBuf}); - return success(); - } - - PTOArch targetArch; -}; - -struct PTOInitializeL2LPipeToEmitC - : public OpConversionPattern { - PTOInitializeL2LPipeToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - PTOArch targetArch) - : OpConversionPattern(typeConverter, ctx), - targetArch(targetArch) {} - - LogicalResult matchAndRewrite(mlir::pto::InitializeL2LPipeOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto tpipeTok = buildTPipeTokenFromInitOp(op.getOperation(), targetArch); - if (failed(tpipeTok)) - return rewriter.notifyMatchFailure(op, "failed to build TPipe token"); - - auto *ctx = rewriter.getContext(); - auto emitPipeTy = - cast(getTypeConverter()->convertType(op.getPipe().getType())); - - auto gmPtrTy = - emitc::PointerType::get(emitc::OpaqueType::get(ctx, "__gm__ void")); - Value nullGm = - makeEmitCOpaqueConstant(rewriter, op.getLoc(), gmPtrTy, "nullptr"); - auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); - Value zero = makeEmitCIntConstant(rewriter, op.getLoc(), i32Ty, 0); - Value localAddr = peelUnrealized(adaptor.getLocalAddr()); - - Value c2vBuf = zero; - Value v2cBuf = zero; - if (op.getDirMask() == 1) - c2vBuf = localAddr; - else if (op.getDirMask() == 2) - v2cBuf = localAddr; - else if (op.getDirMask() == 3) { - c2vBuf = localAddr; - v2cBuf = peelUnrealized(adaptor.getPeerLocalAddr()); - } else - return rewriter.notifyMatchFailure(op, "unsupported dir_mask"); - - rewriter.replaceOpWithNewOp( - op, TypeRange{emitPipeTy}, 
*tpipeTok, ArrayAttr{}, ArrayAttr{}, - ValueRange{nullGm, c2vBuf, v2cBuf}); - return success(); - } - - PTOArch targetArch; -}; - -struct PTOBuildAsyncSessionToEmitC - : public OpConversionPattern { - PTOBuildAsyncSessionToEmitC(TypeConverter &typeConverter, MLIRContext *ctx) - : OpConversionPattern(typeConverter, ctx) {} - - LogicalResult matchAndRewrite(mlir::pto::BuildAsyncSessionOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto *ctx = rewriter.getContext(); - Location loc = op.getLoc(); - - auto sessionTy = - dyn_cast(getTypeConverter()->convertType(op.getSession().getType())); - if (!sessionTy) - return rewriter.notifyMatchFailure(op, "failed to convert async session type"); - - FailureOr scratchTile = - buildAsyncScratchTileValue(rewriter, loc, op.getScratch(), - adaptor.getScratch()); - if (failed(scratchTile)) - return rewriter.notifyMatchFailure(op, "failed to materialize async scratch tile"); - - Value workspace = - castToGMBytePointer(rewriter, loc, peelUnrealized(adaptor.getWorkspace())); - - Value session = rewriter - .create( - loc, sessionTy, emitc::OpaqueAttr::get(ctx, "")) - .getResult(); - - auto u32Ty = emitc::OpaqueType::get(ctx, "uint32_t"); - - auto makeU32Const = [&](uint64_t value) -> Value { - return makeEmitCOpaqueConstant(rewriter, loc, u32Ty, - std::to_string(value) + "u"); - }; - uint64_t syncId = op.getSyncIdAttr() ? op.getSyncIdAttr().getInt() : 0; - uint64_t blockBytes = - op.getBlockBytesAttr() ? op.getBlockBytesAttr().getInt() : 32 * 1024; - uint64_t commBlockOffset = - op.getCommBlockOffsetAttr() ? op.getCommBlockOffsetAttr().getInt() : 0; - uint64_t queueNum = op.getQueueNumAttr() ? op.getQueueNumAttr().getInt() : 1; - uint64_t channelGroupIdx = op.getChannelGroupIdxAttr() - ? op.getChannelGroupIdxAttr().getInt() - : UINT32_MAX; - - Value syncIdVal = makeU32Const(syncId); - Value channelGroupIdxVal = - channelGroupIdx == UINT32_MAX - ? 
makeEmitCOpaqueConstant(rewriter, loc, u32Ty, "UINT32_MAX") - : makeU32Const(channelGroupIdx); - - auto baseConfigTy = - emitc::OpaqueType::get(ctx, "pto::comm::sdma::SdmaBaseConfig"); - Value baseConfig = - rewriter - .create( - loc, baseConfigTy, - emitc::OpaqueAttr::get( - ctx, "{" + std::to_string(blockBytes) + "ULL, " + - std::to_string(commBlockOffset) + "ULL, " + - std::to_string(queueNum) + "u}")) - .getResult(); - - rewriter.create( - loc, TypeRange{}, "pto::comm::BuildAsyncSession", - ArrayAttr{}, ArrayAttr{}, - ValueRange{*scratchTile, workspace, session, syncIdVal, baseConfig, - channelGroupIdxVal}); - - rewriter.replaceOp(op, session); - return success(); - } -}; - -template -struct PTOAsyncTransferToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - explicit PTOAsyncTransferToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - StringRef callee) - : OpConversionPattern(typeConverter, ctx), callee(callee.str()) {} - - LogicalResult matchAndRewrite(AsyncOp op, typename AsyncOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value dst = peelUnrealized(adaptor.getDst()); - Value src = peelUnrealized(adaptor.getSrc()); - Value dstGT = dst; - Value srcGT = src; - if (!isEmitCGlobalTensorLikeType(dstGT.getType())) { - auto dstMrTy = dyn_cast(op.getDst().getType()); - if (!dstMrTy) - return rewriter.notifyMatchFailure(op, "expected dst to lower to GlobalTensor or memref"); - dstGT = buildGlobalTensorFromMemref(rewriter, op.getLoc(), dst, dstMrTy, - op.getDst().getDefiningOp() - ? op.getDst().getDefiningOp() - : op.getOperation()); - } - if (!isEmitCGlobalTensorLikeType(srcGT.getType())) { - auto srcMrTy = dyn_cast(op.getSrc().getType()); - if (!srcMrTy) - return rewriter.notifyMatchFailure(op, "expected src to lower to GlobalTensor or memref"); - srcGT = buildGlobalTensorFromMemref(rewriter, op.getLoc(), src, srcMrTy, - op.getSrc().getDefiningOp() - ? 
op.getSrc().getDefiningOp() - : op.getOperation()); - } - if (!dstGT || !srcGT) - return rewriter.notifyMatchFailure(op, "failed to build GlobalTensor operands"); - - Type eventTy = this->getTypeConverter()->convertType(op.getEvent().getType()); - if (!eventTy) - return rewriter.notifyMatchFailure(op, "failed to convert async event type"); - - rewriter.replaceOpWithNewOp( - op, TypeRange{eventTy}, callee, ArrayAttr{}, ArrayAttr{}, - ValueRange{dstGT, srcGT, peelUnrealized(adaptor.getSession())}); - return success(); - } - - std::string callee; -}; - -template -struct PTOAsyncEventToEmitC : public OpConversionPattern { - explicit PTOAsyncEventToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - StringRef callee) - : OpConversionPattern(typeConverter, ctx), - callee(callee.str()) {} - - LogicalResult matchAndRewrite(AsyncEventOp op, - typename AsyncEventOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type resultTy = - this->getTypeConverter()->convertType(op.getCompleted().getType()); - if (!resultTy) - return rewriter.notifyMatchFailure(op, "failed to convert async event result type"); - - rewriter.replaceOpWithNewOp( - op, TypeRange{resultTy}, callee, ArrayAttr{}, ArrayAttr{}, - ValueRange{peelUnrealized(adaptor.getEvent()), - peelUnrealized(adaptor.getSession())}); - return success(); - } - - std::string callee; -}; - -static FailureOr buildCommGlobalTensorValue( - ConversionPatternRewriter &rewriter, Location loc, Value originalValue, - Value emittedValue, Operation *anchor) { - Value value = peelUnrealized(emittedValue); - if (isEmitCGlobalTensorLikeType(value.getType())) - return value; - - auto memTy = dyn_cast(originalValue.getType()); - if (!memTy) - return failure(); - - Value gt = buildGlobalTensorFromMemref(rewriter, loc, value, memTy, anchor); - if (!gt) - return failure(); - return gt; -} - -static FailureOr buildCommTileValue(ConversionPatternRewriter &rewriter, - Location loc, Value originalValue, - Value 
emittedValue) { - Value value = peelUnrealized(emittedValue); - if (auto opaqueTy = dyn_cast(value.getType())) { - StringRef typeStr = opaqueTy.getValue(); - if (typeStr.contains("Tile<") || typeStr.contains("ConvTile<")) - return value; - } - return buildAsyncScratchTileValue(rewriter, loc, originalValue, emittedValue); -} - -static FailureOr buildCollectiveParallelGroup( - ConversionPatternRewriter &rewriter, Location loc, - ArrayRef groupGTs, int64_t root) { - if (groupGTs.empty()) - return failure(); - - auto firstTy = dyn_cast(groupGTs.front().getType()); - if (!firstTy) - return failure(); - - auto *ctx = rewriter.getContext(); - auto arrayTy = emitc::ArrayType::get({static_cast(groupGTs.size())}, - firstTy); - auto groupArray = cast>( - rewriter - .create(loc, arrayTy, - emitc::OpaqueAttr::get(ctx, "{}")) - .getResult()); - - auto indexTy = emitc::OpaqueType::get(ctx, "int"); - for (auto [idx, groupVal] : llvm::enumerate(groupGTs)) { - Value idxVal = - makeEmitCIntConstant(rewriter, loc, indexTy, static_cast(idx)); - Value slot = - rewriter.create(loc, groupArray, ValueRange{idxVal}) - .getResult(); - rewriter.create(loc, slot, groupVal); - } - - std::string pgTypeStr = - (Twine("pto::comm::ParallelGroup<") + firstTy.getValue() + ">").str(); - auto pgTy = emitc::OpaqueType::get(ctx, pgTypeStr); - Value sizeVal = makeEmitCIntConstant(rewriter, loc, indexTy, - static_cast(groupGTs.size())); - Value rootVal = makeEmitCIntConstant(rewriter, loc, indexTy, root); - return rewriter - .create( - loc, TypeRange{pgTy}, (Twine(pgTypeStr) + "::Create").str(), - ArrayAttr{}, ArrayAttr{}, ValueRange{groupArray, sizeVal, rootVal}) - .getResult(0); -} - -static std::string notifyOpTok(pto::NotifyOp op) { - switch (op) { - case pto::NotifyOp::AtomicAdd: - return "pto::comm::NotifyOp::AtomicAdd"; - case pto::NotifyOp::Set: - return "pto::comm::NotifyOp::Set"; - } - return "pto::comm::NotifyOp::Set"; -} - -static std::string waitCmpTok(pto::WaitCmp cmp) { - switch (cmp) { - 
case pto::WaitCmp::EQ: - return "pto::comm::WaitCmp::EQ"; - case pto::WaitCmp::NE: - return "pto::comm::WaitCmp::NE"; - case pto::WaitCmp::GT: - return "pto::comm::WaitCmp::GT"; - case pto::WaitCmp::GE: - return "pto::comm::WaitCmp::GE"; - case pto::WaitCmp::LT: - return "pto::comm::WaitCmp::LT"; - case pto::WaitCmp::LE: - return "pto::comm::WaitCmp::LE"; - } - return "pto::comm::WaitCmp::EQ"; -} - -static std::string reduceOpTok(pto::ReduceOp op) { - switch (op) { - case pto::ReduceOp::Sum: - return "pto::comm::ReduceOp::Sum"; - case pto::ReduceOp::Max: - return "pto::comm::ReduceOp::Max"; - case pto::ReduceOp::Min: - return "pto::comm::ReduceOp::Min"; - } - return "pto::comm::ReduceOp::Sum"; -} - -template -static FailureOr> buildCommGroupGlobalTensors( - ConversionPatternRewriter &rewriter, Location loc, OpTy op, - ValueRange originalGroup, ValueRange emittedGroup) { - SmallVector groupGTs; - groupGTs.reserve(originalGroup.size()); - for (auto [orig, emitted] : llvm::zip(originalGroup, emittedGroup)) { - FailureOr gt = - buildCommGlobalTensorValue(rewriter, loc, orig, emitted, op.getOperation()); - if (failed(gt)) - return failure(); - groupGTs.push_back(*gt); - } - return groupGTs; -} - -template -struct PTOCommCollectiveToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - explicit PTOCommCollectiveToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - StringRef apiName) - : OpConversionPattern(typeConverter, ctx), - apiName(apiName.str()) {} - - LogicalResult matchAndRewrite(CollectiveOp op, typename CollectiveOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - - auto buildPong = [&](Value original, Value emitted, StringRef name) -> FailureOr { - if (!original) - return failure(); - return buildCommTileValue(rewriter, loc, original, emitted); - }; - - if constexpr (std::is_same_v) { - FailureOr srcGT = - buildCommGlobalTensorValue(rewriter, loc, op.getSrc(), 
adaptor.getSrc(), - op.getOperation()); - FailureOr pingTile = - buildCommTileValue(rewriter, loc, op.getPing(), adaptor.getPing()); - auto groupGTs = - buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); - if (failed(srcGT) || failed(pingTile) || failed(groupGTs)) - return rewriter.notifyMatchFailure(op, "failed to materialize broadcast operands"); - FailureOr pg = buildCollectiveParallelGroup(rewriter, loc, *groupGTs, op.getRoot()); - if (failed(pg)) - return rewriter.notifyMatchFailure(op, "failed to materialize broadcast group"); - if (op.getPong()) { - FailureOr pongTile = - buildPong(op.getPong(), adaptor.getPong(), "__pong"); - if (failed(pongTile)) - return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); - rewriter.create( - loc, TypeRange{}, "pto::comm::TBROADCAST", ArrayAttr{}, ArrayAttr{}, - ValueRange{*pg, *srcGT, *pingTile, *pongTile}); - } else { - rewriter.create( - loc, TypeRange{}, "pto::comm::TBROADCAST", ArrayAttr{}, ArrayAttr{}, - ValueRange{*pg, *srcGT, *pingTile}); - } - } else if constexpr (std::is_same_v) { - FailureOr dstGT = - buildCommGlobalTensorValue(rewriter, loc, op.getDst(), adaptor.getDst(), - op.getOperation()); - FailureOr pingTile = - buildCommTileValue(rewriter, loc, op.getPing(), adaptor.getPing()); - auto groupGTs = - buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); - if (failed(dstGT) || failed(pingTile) || failed(groupGTs)) - return rewriter.notifyMatchFailure(op, "failed to materialize gather operands"); - FailureOr pg = buildCollectiveParallelGroup(rewriter, loc, *groupGTs, op.getRoot()); - if (failed(pg)) - return rewriter.notifyMatchFailure(op, "failed to materialize gather group"); - if (op.getPong()) { - FailureOr pongTile = - buildPong(op.getPong(), adaptor.getPong(), "__pong"); - if (failed(pongTile)) - return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); - rewriter.create( - loc, TypeRange{}, 
"pto::comm::TGATHER", ArrayAttr{}, ArrayAttr{}, - ValueRange{*pg, *dstGT, *pingTile, *pongTile}); - } else { - rewriter.create( - loc, TypeRange{}, "pto::comm::TGATHER", ArrayAttr{}, ArrayAttr{}, - ValueRange{*pg, *dstGT, *pingTile}); - } - } else if constexpr (std::is_same_v) { - FailureOr srcGT = - buildCommGlobalTensorValue(rewriter, loc, op.getSrc(), adaptor.getSrc(), - op.getOperation()); - FailureOr pingTile = - buildCommTileValue(rewriter, loc, op.getPing(), adaptor.getPing()); - auto groupGTs = - buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); - if (failed(srcGT) || failed(pingTile) || failed(groupGTs)) - return rewriter.notifyMatchFailure(op, "failed to materialize scatter operands"); - FailureOr pg = buildCollectiveParallelGroup(rewriter, loc, *groupGTs, op.getRoot()); - if (failed(pg)) - return rewriter.notifyMatchFailure(op, "failed to materialize scatter group"); - if (op.getPong()) { - FailureOr pongTile = - buildPong(op.getPong(), adaptor.getPong(), "__pong"); - if (failed(pongTile)) - return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); - rewriter.create( - loc, TypeRange{}, "pto::comm::TSCATTER", ArrayAttr{}, ArrayAttr{}, - ValueRange{*pg, *srcGT, *pingTile, *pongTile}); - } else { - rewriter.create( - loc, TypeRange{}, "pto::comm::TSCATTER", ArrayAttr{}, ArrayAttr{}, - ValueRange{*pg, *srcGT, *pingTile}); - } - } else { - FailureOr dstGT = - buildCommGlobalTensorValue(rewriter, loc, op.getDst(), adaptor.getDst(), - op.getOperation()); - FailureOr accTile = - buildCommTileValue(rewriter, loc, op.getAcc(), adaptor.getAcc()); - FailureOr recvPing = - buildCommTileValue(rewriter, loc, op.getRecvPing(), adaptor.getRecvPing()); - auto groupGTs = - buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); - if (failed(dstGT) || failed(accTile) || failed(recvPing) || failed(groupGTs)) - return rewriter.notifyMatchFailure(op, "failed to materialize reduce operands"); - 
FailureOr pg = buildCollectiveParallelGroup(rewriter, loc, *groupGTs, op.getRoot()); - if (failed(pg)) - return rewriter.notifyMatchFailure(op, "failed to materialize reduce group"); - if (op.getRecvPong()) { - FailureOr recvPong = - buildPong(op.getRecvPong(), adaptor.getRecvPong(), "__recv_pong"); - if (failed(recvPong)) - return rewriter.notifyMatchFailure(op, "failed to materialize recv_pong"); - auto reduceTy = - emitc::OpaqueType::get(rewriter.getContext(), "pto::comm::ReduceOp"); - Value reduceOp = makeEmitCOpaqueConstant(rewriter, loc, reduceTy, - reduceOpTok(op.getReduceOp())); - rewriter.create( - loc, TypeRange{}, "pto::comm::TREDUCE", ArrayAttr{}, ArrayAttr{}, - ValueRange{*pg, *dstGT, *accTile, *recvPing, *recvPong, reduceOp}); - } else { - auto reduceTy = - emitc::OpaqueType::get(rewriter.getContext(), "pto::comm::ReduceOp"); - Value reduceOp = makeEmitCOpaqueConstant(rewriter, loc, reduceTy, - reduceOpTok(op.getReduceOp())); - rewriter.create( - loc, TypeRange{}, "pto::comm::TREDUCE", ArrayAttr{}, ArrayAttr{}, - ValueRange{*pg, *dstGT, *accTile, *recvPing, reduceOp}); - } - } - rewriter.eraseOp(op); - return success(); - } - - std::string apiName; -}; - -template -struct PTOP2PCommToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - explicit PTOP2PCommToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - StringRef callee) - : OpConversionPattern(typeConverter, ctx), callee(callee.str()) {} - - LogicalResult matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - FailureOr dstGT = - buildCommGlobalTensorValue(rewriter, op.getLoc(), op.getDst(), adaptor.getDst(), - op.getOperation()); - FailureOr srcGT = - buildCommGlobalTensorValue(rewriter, op.getLoc(), op.getSrc(), adaptor.getSrc(), - op.getOperation()); - FailureOr pingTile = - buildCommTileValue(rewriter, op.getLoc(), op.getPing(), adaptor.getPing()); - if (failed(dstGT) || failed(srcGT) || 
failed(pingTile)) - return rewriter.notifyMatchFailure(op, "failed to materialize p2p operands"); - - SmallVector operands{*dstGT, *srcGT, *pingTile}; - std::string actualCallee = callee; - if constexpr (std::is_same_v) { - if (op.getAtomicType() == pto::AtomicType::AtomicAdd) - actualCallee = "pto::comm::TPUT"; - } - if (op.getPong()) { - FailureOr pongTile = - buildCommTileValue(rewriter, op.getLoc(), op.getPong(), adaptor.getPong()); - if (failed(pongTile)) - return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); - operands.push_back(*pongTile); - } - - rewriter.create(op.getLoc(), TypeRange{}, actualCallee, - ArrayAttr{}, ArrayAttr{}, operands); - rewriter.eraseOp(op); - return success(); - } - - std::string callee; -}; - -template -struct PTOSignalCommToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - explicit PTOSignalCommToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - StringRef callee) - : OpConversionPattern(typeConverter, ctx), - callee(callee.str()) {} - - LogicalResult matchAndRewrite(SignalOp op, typename SignalOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - FailureOr signalGT = buildCommGlobalTensorValue( - rewriter, op.getLoc(), op.getSignal(), adaptor.getSignal(), op.getOperation()); - if (failed(signalGT)) - return rewriter.notifyMatchFailure(op, "failed to materialize signal operand"); - - if constexpr (std::is_same_v) { - auto notifyTy = - emitc::OpaqueType::get(rewriter.getContext(), "pto::comm::NotifyOp"); - Value notifyOp = makeEmitCOpaqueConstant( - rewriter, op.getLoc(), notifyTy, notifyOpTok(op.getNotifyOp())); - SmallVector operands{*signalGT, peelUnrealized(adaptor.getValue()), - notifyOp}; - rewriter.create(op.getLoc(), TypeRange{}, callee, - ArrayAttr{}, ArrayAttr{}, operands); - rewriter.eraseOp(op); - } else { - auto waitCmpTy = - emitc::OpaqueType::get(rewriter.getContext(), "pto::comm::WaitCmp"); - Value waitCmp = 
makeEmitCOpaqueConstant( - rewriter, op.getLoc(), waitCmpTy, waitCmpTok(op.getCmp())); - SmallVector operands{*signalGT, peelUnrealized(adaptor.getCmpValue()), - waitCmp}; - if constexpr (std::is_same_v) { - Type resultTy = this->getTypeConverter()->convertType(op.getResult().getType()); - if (!resultTy) - return rewriter.notifyMatchFailure(op, "failed to convert ttest result type"); - rewriter.replaceOpWithNewOp( - op, TypeRange{resultTy}, callee, ArrayAttr{}, ArrayAttr{}, operands); - } else { - rewriter.create(op.getLoc(), TypeRange{}, callee, - ArrayAttr{}, ArrayAttr{}, operands); - rewriter.eraseOp(op); - } - } - return success(); - } - - std::string callee; -}; - -struct PTODeclareTileMemRefToEmitC - : public OpConversionPattern { - using OpConversionPattern< - mlir::pto::DeclareTileMemRefOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::DeclareTileMemRefOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - (void)adaptor; - Type convertedType = getTypeConverter()->convertType(op.getResult().getType()); - if (!convertedType) - return rewriter.notifyMatchFailure( - op, "failed to convert declare_tile_memref result type"); - rewriter.replaceOp(op, makeEmitCOpaqueConstant(rewriter, op.getLoc(), - convertedType, "nullptr")); - return success(); - } -}; - -struct PTODeclareGlobalToEmitC - : public OpConversionPattern { - using OpConversionPattern< - mlir::pto::DeclareGlobalOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::DeclareGlobalOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - (void)adaptor; - Type convertedType = getTypeConverter()->convertType(op.getEntry().getType()); - if (!convertedType) - return rewriter.notifyMatchFailure( - op, "failed to convert declare_global result type"); - if (auto tvTy = dyn_cast(op.getEntry().getType())) { - if (auto stridesAttr = - op->getAttrOfType(kGlobalTensorStridesAttrName)) { - auto strides = 
stridesAttr.asArrayRef(); - if (strides.size() == static_cast(tvTy.getRank())) { - convertedType = emitc::OpaqueType::get( - rewriter.getContext(), - getGlobalTensorTypeStringFromShapeAndStrides( - tvTy.getElementType(), tvTy.getShape(), strides)); - } - } - } - auto var = rewriter.create( - op.getLoc(), convertedType, - emitc::OpaqueAttr::get(rewriter.getContext(), "")); - rewriter.replaceOp(op, var.getResult()); - return success(); - } -}; - -struct PTODeclareEventIdArrayToEmitC - : public OpConversionPattern { - using OpConversionPattern< - mlir::pto::DeclareEventIdArrayOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::DeclareEventIdArrayOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - (void)adaptor; - Type arrayTy = getTypeConverter()->convertType(op.getArray().getType()); - if (!arrayTy) - return rewriter.notifyMatchFailure(op, - "failed to map declared eventid_array type"); - - auto array = rewriter - .create( - op.getLoc(), arrayTy, - emitc::OpaqueAttr::get(rewriter.getContext(), "")) - .getResult(); - rewriter.replaceOp(op, array); - return success(); - } -}; - -struct PTOEventIdArrayGetToEmitC - : public OpConversionPattern { - using OpConversionPattern< - mlir::pto::EventIdArrayGetOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::EventIdArrayGetOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value array = peelUnrealized(adaptor.getArray()); - Value index = peelUnrealized(adaptor.getIndex()); - - Type resultTy = getTypeConverter()->convertType(op.getResult().getType()); - if (!resultTy) - return rewriter.notifyMatchFailure(op, - "failed to map eventid_array get result type"); - - auto load = - rewriter.create(op.getLoc(), resultTy, array, index); - rewriter.replaceOp(op, load.getResult()); - return success(); - } -}; - -struct PTOEventIdArraySetToEmitC - : public OpConversionPattern { - using OpConversionPattern< - 
mlir::pto::EventIdArraySetOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::EventIdArraySetOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value array = peelUnrealized(adaptor.getArray()); - Value index = peelUnrealized(adaptor.getIndex()); - Value value = peelUnrealized(adaptor.getValue()); - - rewriter.create( - op.getLoc(), TypeRange{}, "PTOAS__EVENTID_ARRAY_STORE", - ArrayAttr{}, ArrayAttr{}, ValueRange{array, index, value}); - rewriter.eraseOp(op); - return success(); - } -}; - -// pto.declare_local_array -> emitc.variable of !emitc.array<...>. -// Renders as `T a[D1][D2]...;` in the emitted C++. -struct PTODeclareLocalArrayToEmitC - : public OpConversionPattern { - using OpConversionPattern< - mlir::pto::DeclareLocalArrayOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::DeclareLocalArrayOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - (void)adaptor; - Type arrayTy = getTypeConverter()->convertType(op.getArray().getType()); - if (!arrayTy) - return rewriter.notifyMatchFailure(op, - "failed to map !pto.local_array type"); - - auto var = rewriter - .create( - op.getLoc(), arrayTy, - emitc::OpaqueAttr::get(rewriter.getContext(), "")) - .getResult(); - rewriter.replaceOp(op, var); - return success(); - } -}; - -// pto.local_array_get %a[%i0, %i1, ...] -> rvalue. -// Lowers to a single emitc.subscript with the full index pack; the C++ emitter -// prints it as `a[i0][i1]...`. The adaptor already exposes target-typed values -// (the type converter has remapped !pto.local_array -> !emitc.array and -// index/integer indices), so they're forwarded directly to the builder. 
-struct PTOLocalArrayGetToEmitC - : public OpConversionPattern { - using OpConversionPattern< - mlir::pto::LocalArrayGetOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::LocalArrayGetOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type resultTy = - getTypeConverter()->convertType(op.getResult().getType()); - if (!resultTy) - return rewriter.notifyMatchFailure( - op, "failed to map local_array element type"); - - auto sub = rewriter.create( - op.getLoc(), resultTy, adaptor.getArray(), adaptor.getIndices()); - rewriter.replaceOp(op, sub.getResult()); - return success(); - } -}; - -// pto.local_array_set %a[%i0, %i1, ...], %v -> emitc.assign to subscript slot. -// The C++ emitter prints this as `a[i0][i1]... = v;`. As above, adaptor values -// are already target-typed; pass them through directly. -struct PTOLocalArraySetToEmitC - : public OpConversionPattern { - using OpConversionPattern< - mlir::pto::LocalArraySetOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::LocalArraySetOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value value = adaptor.getValue(); - Type elemTy = value.getType(); - - Value slot = rewriter - .create(op.getLoc(), elemTy, - adaptor.getArray(), - adaptor.getIndices()) - .getResult(); - rewriter.create(op.getLoc(), slot, value); - rewriter.eraseOp(op); - return success(); - } -}; - -static std::optional getStaticIndexLikeValue(Value value) { - if (!value) - return std::nullopt; - if (auto cst = value.getDefiningOp()) - return cst.value(); - if (auto cst = value.getDefiningOp()) - return cst.value(); - if (auto cst = value.getDefiningOp()) { - if (auto intAttr = dyn_cast(cst.getValue())) - return intAttr.getInt(); - } - return std::nullopt; -} - -static FailureOr buildGlobalTensorViewFromPointer( - ConversionPatternRewriter &rewriter, Location loc, Value ptr, Type elemTy, - ArrayRef shape, ArrayRef strides = {}, - StringRef 
layoutEnum = "pto::Layout::ND") { - if (llvm::any_of(shape, [](int64_t dim) { - return dim == ShapedType::kDynamic; - })) - return failure(); - - auto *ctx = rewriter.getContext(); - SmallVector rowMajorStrides; - ArrayRef effectiveStrides = strides; - if (effectiveStrides.empty()) { - rowMajorStrides = buildRowMajorStrides(shape); - effectiveStrides = rowMajorStrides; - } - SmallVector shape5D; - SmallVector stride5D; - buildGlobalTensorShapeAndStride(shape, effectiveStrides, shape5D, stride5D); - - std::string shapeType = "pto::Shape<" + joinIntTemplateParams(shape5D) + ">"; - std::string strideType = - "pto::Stride<" + joinIntTemplateParams(stride5D) + ">"; - auto shapeVal = rewriter - .create( - loc, emitc::OpaqueType::get(ctx, shapeType), - shapeType, ArrayAttr{}, ArrayAttr{}, ValueRange{}) - .getResult(0); - auto strideVal = rewriter - .create( - loc, emitc::OpaqueType::get(ctx, strideType), - strideType, ArrayAttr{}, ArrayAttr{}, ValueRange{}) - .getResult(0); - - std::string gtTypeStr = - getGlobalTensorTypeStringFromShapeAndStrides(elemTy, shape, - effectiveStrides, - layoutEnum); - auto gtType = emitc::OpaqueType::get(ctx, gtTypeStr); - auto gt = rewriter.create( - loc, gtType, gtTypeStr, ArrayAttr{}, ArrayAttr{}, - ValueRange{ptr, shapeVal, strideVal}); - return gt.getResult(0); -} - -static bool parseIntegerTemplateList(StringRef token, StringRef marker, - SmallVectorImpl &values) { - size_t pos = token.find(marker); - if (pos == StringRef::npos) - return false; - pos += marker.size(); - size_t end = token.find('>', pos); - if (end == StringRef::npos) - return false; - - SmallVector parts; - token.slice(pos, end).split(parts, ','); - values.clear(); - for (StringRef part : parts) { - int64_t value = 0; - if (part.trim().getAsInteger(10, value)) - return false; - values.push_back(value); - } - return true; -} - -static LogicalResult getStaticTensorViewStrides( - Value source, Value convertedSource, pto::TensorViewType sourceType, - SmallVectorImpl 
&strides) { - int64_t rank = sourceType.getRank(); - strides.clear(); - - if (auto makeView = source.getDefiningOp()) { - if ((int64_t)makeView.getStrides().size() != rank) - return failure(); - for (Value strideValue : makeView.getStrides()) { - auto cst = getStaticIndexLikeValue(strideValue); - if (!cst) - return failure(); - strides.push_back(*cst); - } - return success(); - } - - Value src = peelUnrealized(convertedSource); - if (auto opaqueTy = dyn_cast(src.getType())) { - SmallVector stride5D; - StringRef token = opaqueTy.getValue(); - if ((parseIntegerTemplateList(token, "pto::Stride<", stride5D) || - parseIntegerTemplateList(token, "Stride<", stride5D)) && - (int64_t)stride5D.size() >= rank) { - strides.append(stride5D.end() - rank, stride5D.end()); - return success(); - } - } - - auto fallback = buildRowMajorStrides(sourceType.getShape()); - strides.append(fallback.begin(), fallback.end()); - return success(); -} - -struct PTOPartitionViewToEmitC - : public OpConversionPattern { - using OpConversionPattern< - mlir::pto::PartitionViewOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::PartitionViewOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto srcTy = dyn_cast(op.getSource().getType()); - auto resTy = dyn_cast(op.getResult().getType()); - if (!srcTy || !resTy) - return rewriter.notifyMatchFailure( - op, "expected tensor_view source and partition_tensor_view result"); - - if (op.getOffsets().size() != static_cast(srcTy.getRank()) || - op.getSizes().size() != static_cast(srcTy.getRank())) - return rewriter.notifyMatchFailure(op, "rank mismatch"); - - for (auto [idx, value] : llvm::enumerate(op.getSizes())) { - auto cst = getStaticIndexLikeValue(value); - if (!cst) - return rewriter.notifyMatchFailure( - op, "globaltensor partition_view requires static sizes"); - int64_t resultDim = resTy.getShape()[idx]; - if (resultDim != ShapedType::kDynamic && resultDim != *cst) - return 
rewriter.notifyMatchFailure( - op, "partition_view static size does not match result type"); - } - - SmallVector srcStrides; - if (failed(getStaticTensorViewStrides(op.getSource(), adaptor.getSource(), - srcTy, srcStrides))) - return rewriter.notifyMatchFailure( - op, "partition_view requires static source strides"); - int64_t staticLinearOffset = 0; - SmallVector> dynamicOffsetTerms; - for (auto [idx, values] : - llvm::enumerate(llvm::zip(op.getOffsets(), adaptor.getOffsets()))) { - Value originalOffset = std::get<0>(values); - Value convertedOffset = std::get<1>(values); - int64_t stride = srcStrides[idx]; - if (stride == ShapedType::kDynamic) - return rewriter.notifyMatchFailure( - op, "dynamic source stride is not supported"); - - if (auto cst = getStaticIndexLikeValue(originalOffset)) { - if (*cst != 0) - staticLinearOffset += (*cst) * stride; - continue; - } - dynamicOffsetTerms.push_back({convertedOffset, stride}); - } - - auto *ctx = rewriter.getContext(); - std::string elemTypeStr = getElemTypeStringForGT(srcTy.getElementType()); - auto ptrTy = emitc::PointerType::get( - emitc::OpaqueType::get(ctx, "__gm__ " + elemTypeStr)); - Value src = peelUnrealized(adaptor.getSource()); - auto data = rewriter - .create( - op.getLoc(), ptrTy, "PTOAS__GLOBAL_TENSOR_DATA", - ArrayAttr{}, ArrayAttr{}, ValueRange{src}) - .getResult(0); - Value ptr = data; - if (!dynamicOffsetTerms.empty()) { - Type u32Ty = emitc::OpaqueType::get(ctx, "unsigned"); - auto makeU32 = [&](int64_t value) { - return makeEmitCIntConstant(rewriter, op.getLoc(), u32Ty, value); - }; - auto asU32 = [&](Value value) -> Value { - if (value.getType() == u32Ty) - return value; - return rewriter.create(op.getLoc(), u32Ty, value) - .getResult(); - }; - - Value totalOffset = makeU32(staticLinearOffset); - for (auto [offsetValue, stride] : dynamicOffsetTerms) { - Value term = asU32(offsetValue); - if (stride != 1) { - Value strideValue = makeU32(stride); - term = rewriter - .create(op.getLoc(), u32Ty, term, - 
strideValue) - .getResult(); - } - totalOffset = rewriter - .create(op.getLoc(), u32Ty, - totalOffset, term) - .getResult(); - } - ptr = rewriter - .create(op.getLoc(), data.getType(), data, - totalOffset) - .getResult(); - } else { - ptr = applyStaticMemrefOffset(rewriter, op.getLoc(), data, - staticLinearOffset); - } - - auto resultOr = buildGlobalTensorViewFromPointer( - rewriter, op.getLoc(), ptr, resTy.getElementType(), resTy.getShape(), - srcStrides); - if (failed(resultOr)) - return rewriter.notifyMatchFailure( - op, "failed to materialize partition GlobalTensor"); - - rewriter.replaceOp(op, *resultOr); - return success(); - } -}; - -static FailureOr getPipeDataTypeToken(Value value) { - auto opaqueTy = dyn_cast(value.getType()); - if (!opaqueTy) - return failure(); - StringRef token = opaqueTy.getValue(); - if (!token.contains("Tile<") && !token.contains("GlobalTensor<")) - return failure(); - return token.str(); -} - -struct PTOTAllocToEmitC : public OpConversionPattern { - PTOTAllocToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - PTOArch targetArch) - : OpConversionPattern(typeConverter, ctx), - targetArch(targetArch) {} - - LogicalResult matchAndRewrite(mlir::pto::TAllocOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto pipeTok = getTPipeTokenFromValue(op.getPipeHandle(), targetArch); - if (failed(pipeTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve pipe token"); - Value entry = peelUnrealized(adaptor.getEntry()); - auto entryTok = getPipeDataTypeToken(entry); - if (failed(entryTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve entry token"); - auto splitTok = getTileSplitToken(op.getSplit()); - if (failed(splitTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve split token"); - - std::string callee = - "TALLOC<" + *pipeTok + ", " + *entryTok + ", " + *splitTok + ">"; - rewriter.replaceOpWithNewOp( - op, TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, - 
ValueRange{peelUnrealized(adaptor.getPipeHandle()), entry}); - return success(); - } - - PTOArch targetArch; -}; - -struct PTOTPushToEmitC : public OpConversionPattern { - PTOTPushToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - PTOArch targetArch) - : OpConversionPattern(typeConverter, ctx), - targetArch(targetArch) {} - - LogicalResult matchAndRewrite(mlir::pto::TPushOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto pipeTok = getTPipeTokenFromValue(op.getPipeHandle(), targetArch); - if (failed(pipeTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve pipe token"); - // Read the tile type token from the already-converted OpaqueType, which - // preserves the exact layout produced by BindTileOp / PointerCastOp EmitC. - Value convertedTile = peelUnrealized(adaptor.getTile()); - auto tileTok = getPipeDataTypeToken(convertedTile); - if (failed(tileTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve tile token"); - auto splitTok = getTileSplitToken(op.getSplit()); - if (failed(splitTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve split token"); - - std::string callee = - "TPUSH<" + *pipeTok + ", " + *tileTok + ", " + *splitTok + ">"; - rewriter.replaceOpWithNewOp( - op, TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, - ValueRange{peelUnrealized(adaptor.getPipeHandle()), convertedTile}); - return success(); - } - - PTOArch targetArch; -}; - -struct PTOTPopToEmitC : public OpConversionPattern { - PTOTPopToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - PTOArch targetArch) - : OpConversionPattern(typeConverter, ctx), - targetArch(targetArch) {} - - LogicalResult matchAndRewrite(mlir::pto::TPopOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto pipeTok = getTPipeTokenFromValue(op.getPipeHandle(), targetArch); - if (failed(pipeTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve pipe token"); - Value convertedTile = 
peelUnrealized(adaptor.getTile()); - auto tileTok = getPipeDataTypeToken(convertedTile); - if (failed(tileTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve tile token"); - auto splitTok = getTileSplitToken(op.getSplit()); - if (failed(splitTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve split token"); - - std::string callee = - "TPOP<" + *pipeTok + ", " + *tileTok + ", " + *splitTok + ">"; - rewriter.replaceOpWithNewOp( - op, TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, - ValueRange{peelUnrealized(adaptor.getPipeHandle()), convertedTile}); - return success(); - } - - PTOArch targetArch; -}; - -struct PTOTFreeToEmitC : public OpConversionPattern { - PTOTFreeToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - PTOArch targetArch) - : OpConversionPattern(typeConverter, ctx), - targetArch(targetArch) {} - - LogicalResult matchAndRewrite(mlir::pto::TFreeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto pipeTok = getTPipeTokenFromValue(op.getPipeHandle(), targetArch); - if (failed(pipeTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve pipe token"); - auto splitTok = getTileSplitToken(op.getSplit()); - if (failed(splitTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve split token"); - - SmallVector operands{peelUnrealized(adaptor.getPipeHandle())}; - std::string callee; - if (op.getEntry()) { - Value entry = peelUnrealized(adaptor.getEntry()); - auto entryTok = getPipeDataTypeToken(entry); - if (failed(entryTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve entry token"); - callee = "TFREE<" + *pipeTok + ", " + *entryTok + ", " + *splitTok + ">"; - operands.push_back(entry); - } else { - callee = "TFREE<" + *pipeTok + ", " + *splitTok + ">"; - } - rewriter.replaceOpWithNewOp( - op, TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, operands); - return success(); - } - - PTOArch targetArch; -}; - 
-//===----------------------------------------------------------------------===// -// populate patterns -//===----------------------------------------------------------------------=== -struct ReinterpretCastToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - auto resMrTy = dyn_cast(op.getType()); - if (!resMrTy) - return failure(); - - auto asAttr = dyn_cast_or_null(resMrTy.getMemorySpace()); - const bool isGm = (!asAttr || asAttr.getAddressSpace() == pto::AddressSpace::GM); - - bool emitAddPtrTrace = op->hasAttr("pto.addptr_trace"); - Value source = peelUnrealized(adaptor.getSource()); - auto offsets = adaptor.getOffsets(); - Value offsetVal = offsets.empty() ? Value() : offsets[0]; - - // GM: keep pointer arithmetic. - if (isGm) { - if (!offsetVal) { - rewriter.replaceOp(op, source); - return success(); - } - - Type resultType = getTypeConverter()->convertType(op.getType()); - if (!resultType) - return failure(); - - auto addOp = rewriter.create(loc, resultType, source, offsetVal); - if (emitAddPtrTrace) { - rewriter.setInsertionPointAfter(addOp); - rewriter.create( - loc, TypeRange{}, "PTOAS__ADDPTR_TRACE", - ArrayAttr{}, ArrayAttr{}, - ValueRange{addOp.getResult(), source, offsetVal}); - } - rewriter.replaceOp(op, addOp.getResult()); - return success(); - } - - // UB/L1/L0 tiles: materialize a new Tile view by assigning an adjusted - // underlying pointer (in elements). - pto::AddressSpace as = asAttr.getAddressSpace(); - - // Element type token. - Type elemTy = resMrTy.getElementType(); - std::string elemTok = getEmitCScalarTypeToken(elemTy); - int64_t elemBytes = getEmitCScalarByteWidth(elemTy); - - // Tile role. 
- const char *roleTok = "TileType::Vec"; - switch (as) { - case pto::AddressSpace::VEC: - roleTok = "TileType::Vec"; - break; - case pto::AddressSpace::MAT: - roleTok = "TileType::Mat"; - break; - case pto::AddressSpace::LEFT: - roleTok = "TileType::Left"; - break; - case pto::AddressSpace::RIGHT: - roleTok = "TileType::Right"; - break; - case pto::AddressSpace::ACC: - roleTok = "TileType::Acc"; - break; - case pto::AddressSpace::BIAS: - roleTok = "TileType::Bias"; - break; - case pto::AddressSpace::GM: - roleTok = "TileType::Vec"; - break; - case pto::AddressSpace::Zero: - roleTok = "TileType::Vec"; - break; - case pto::AddressSpace::SCALING: - roleTok = "TileType::Scaling"; - break; - } - - // Shape (fallback to 32x32). - int64_t rows = 32, cols = 32; - if (resMrTy.getRank() >= 2 && resMrTy.hasStaticShape()) { - rows = resMrTy.getDimSize(0); - cols = resMrTy.getDimSize(1); - } - int64_t templateRows = - renderTileTemplateDim(rows, elemTy, pto::BLayout::RowMajor, 0); - int64_t templateCols = - renderTileTemplateDim(cols, elemTy, pto::BLayout::RowMajor, 1); - - // Keep a conservative default config for now. - std::string tileTypeStr = - std::string("Tile<") + roleTok + ", " + elemTok + ", " + - std::to_string(templateRows) + ", " + std::to_string(templateCols) + - ", BLayout::RowMajor, " + std::to_string(templateRows) + ", " + - std::to_string(templateCols) + - ", SLayout::NoneBox, 512, PadValue::Null, CompactMode::Null>"; - - auto tileType = emitc::OpaqueType::get(ctx, tileTypeStr); - Value tile = rewriter - .create(loc, tileType, - emitc::OpaqueAttr::get(ctx, "")) - .getResult(); - - // Compute an integer address and assign it to the new tile. - // NOTE: pto-isa TASSIGN requires an integral address (not a pointer). - auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); - auto rcU64 = rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); - - // Non-GM reinterpret_cast operands come from UB/L1/L0 tiles. 
- // We need the underlying address, but `__cce_get_tile_ptr()` is only valid - // inside `__tf__` functions. Use `tile.data()` (via a post-processed marker) - // and compute the adjusted address in bytes. - Value rawPtr = source; - if (auto ot = dyn_cast(source.getType())) { - // Only Tiles have a `.data()` member. For plain address-space pointers - // (e.g. `__ubuf__ float*`), use the pointer value directly. - if (ot.getValue().starts_with("Tile<")) { - rawPtr = materializeTileDataValue(rewriter, loc, source, as, elemTok); - } - } - - Value baseAddr = rawPtr; - if (isSetFFTsPointerLikeType(rawPtr.getType())) { - baseAddr = rewriter - .create(loc, u64Ty, "reinterpret_cast", - /*args=*/ArrayAttr{}, - /*templateArgs=*/rcU64, - /*operands=*/ValueRange{rawPtr}) - .getResult(0); - } else if (rawPtr.getType() != u64Ty) { - baseAddr = rewriter.create(loc, u64Ty, rawPtr).getResult(); - } - - Value addr = baseAddr; - if (offsetVal) { - Value offU64 = offsetVal; - if (offU64.getType() != u64Ty) - offU64 = rewriter.create(loc, u64Ty, offU64).getResult(); - - auto bytesAttr = emitc::OpaqueAttr::get(ctx, std::to_string(elemBytes)); - Value bytesVal = rewriter.create(loc, u64Ty, bytesAttr); - Value byteOff = rewriter.create(loc, u64Ty, offU64, bytesVal); - addr = rewriter.create(loc, u64Ty, baseAddr, byteOff); - } - - rewriter.create(loc, TypeRange{}, "TASSIGN", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{tile, addr}); - - rewriter.replaceOp(op, tile); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.taddc lowering -> TADDC(dst, src0, src1, src2) -//===----------------------------------------------------------------------===// - -struct PTOTAddCToTADDC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAddCOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = 
op.getLoc(); - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value src2 = peelUnrealized(adaptor.getSrc2()); - Value dst = peelUnrealized(adaptor.getDst()); - - // pto-isa does not provide NPU implementation for TADDC yet. - // Decompose: dst = src0 + src1 + src2 - rewriter.create( - loc, TypeRange{}, "TADD", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src0, src1}); - rewriter.create( - loc, TypeRange{}, "TADD", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, dst, src2}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tadds lowering -> TADDS(dst, src, scalar) -//===----------------------------------------------------------------------===// - -struct PTOAddSToTADDS : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAddSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value scalar = peelUnrealized(adaptor.getScalar()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TADDS", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src, scalar}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.taddsc lowering -> TADDSC(dst, src0, scalar, src1) -//===----------------------------------------------------------------------===// - -struct PTOAddSCToTADDSC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAddSCOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value 
dst = peelUnrealized(adaptor.getDst()); - - // pto-isa does not provide NPU implementation for TADDSC yet. - // Decompose: dst = src0 + scalar + src1 - rewriter.create( - loc, TypeRange{}, "TADDS", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src0, scalar}); - rewriter.create( - loc, TypeRange{}, "TADD", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, dst, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; -struct PTOTAndToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAndOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value a = peelUnrealized(adaptor.getSrc0()); - Value b = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TAND", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, a, b}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOConcatToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TConcatOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TCONCAT", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src0, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; -struct PTOConcatidxToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TConcatidxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value src0Idx = peelUnrealized(adaptor.getSrc0Idx()); - Value src1Idx = peelUnrealized(adaptor.getSrc1Idx()); - Value dst = 
peelUnrealized(adaptor.getDst()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TCONCAT", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src0, src1, src0Idx, src1Idx}); - - rewriter.eraseOp(op); - return success(); - } -}; -struct PTOAndSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAndSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TANDS", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src, scalar}); - - rewriter.eraseOp(op); - return success(); - } -}; - - -struct PTOTCIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TCIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value S = peelUnrealized(adaptor.getOperands()[0]); - - // The TCI scalar template parameter should follow the original PTO IR - // scalar type, not the converted EmitC value type. - std::string scalarTok = "int32_t"; - if (auto it = dyn_cast(op->getOperand(0).getType())) { - bool isUnsigned = it.isUnsigned(); - if (it.getWidth() == 16) - scalarTok = isUnsigned ? "uint16_t" : "int16_t"; - else - scalarTok = isUnsigned ? "uint32_t" : "int32_t"; - } - - // descending -> "0"/"1" - std::string descTok = op.getDescending() ? 
"1" : "0"; - - ArrayAttr targs; - if (auto ot = mlir::dyn_cast(dst.getType())) { - std::string tileTok = ot.getValue().str(); // "Tile<...>" - targs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, tileTok), - emitc::OpaqueAttr::get(ctx, scalarTok), - emitc::OpaqueAttr::get(ctx, descTok), - }); - } else { - targs = rewriter.getArrayAttr({}); - } - - rewriter.create( - loc, TypeRange{}, "TCI", - /*args=*/ArrayAttr{}, - /*templateArgs=*/targs, - /*operands=*/ValueRange{dst, S}); - - rewriter.eraseOp(op); - return success(); - } -}; -static std::string cmpModeTok(pto::CmpModeAttr a) { - // 生成 "CmpMode::GT" 这种 token - auto m = a.getValue(); // 取 enum - switch (m) { - case pto::CmpMode::EQ: return "CmpMode::EQ"; - case pto::CmpMode::NE: return "CmpMode::NE"; - case pto::CmpMode::LT: return "CmpMode::LT"; - case pto::CmpMode::LE: return "CmpMode::LE"; - case pto::CmpMode::GT: return "CmpMode::GT"; - case pto::CmpMode::GE: return "CmpMode::GE"; - } - return "CmpMode::EQ"; -} -struct PTOColExpandToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColExpandOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value src = peelUnrealized(adaptor.getSrc()); - - rewriter.create( - loc, TypeRange{}, "TCOLEXPAND", - /*args=*/ArrayAttr(), - /*templateArgs=*/ArrayAttr(), - /*operands=*/ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColExpandMulToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColExpandMulOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - 
rewriter.create( - loc, TypeRange{}, "TCOLEXPANDMUL", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColExpandAddToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColExpandAddOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLEXPANDADD", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColExpandDivToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColExpandDivOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLEXPANDDIV", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColExpandExpdifToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColExpandExpdifOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLEXPANDEXPDIF", - /*args=*/ArrayAttr{}, 
- /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColExpandSubToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColExpandSubOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLEXPANDSUB", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColExpandMaxToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColExpandMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLEXPANDMAX", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColExpandMinToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColExpandMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLEXPANDMIN", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1}); - - 
rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOTTriToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TTriOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value diagonal = peelUnrealized(adaptor.getDiagonal()); - - ArrayAttr templateArgs; - if (auto dstOT = mlir::dyn_cast(dst.getType())) { - templateArgs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), - emitc::OpaqueAttr::get(ctx, std::to_string(op.getUpperOrLower())), - }); - } else { - templateArgs = ArrayAttr{}; - } - - SmallVector operands{dst, diagonal}; - rewriter.create( - loc, TypeRange{}, "TTRI", - /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOCmpToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TCmpOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - - std::string tok = "CmpMode::EQ"; - if (auto a = op.getCmpModeAttr()) - tok = cmpModeTok(a); - - auto modeTy = emitc::OpaqueType::get(ctx, "CmpMode"); - Value modeVal = rewriter.create( - loc, modeTy, emitc::OpaqueAttr::get(ctx, tok)); - - rewriter.create( - loc, - TypeRange{}, - "TCMP", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1, modeVal}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOCmpSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult 
matchAndRewrite(pto::TCmpSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - - // cmpMode -> token - auto cmpAttr = op.getCmpModeAttr(); // PTO_CmpModeAttr - std::string tok = cmpModeTok(cmpAttr); - - auto modeTy = emitc::OpaqueType::get(ctx, "CmpMode"); - Value modeVal = rewriter.create( - loc, modeTy, emitc::OpaqueAttr::get(ctx, tok)); - - rewriter.create( - loc, - TypeRange{}, - "TCMPS", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, scalar, modeVal}); - - rewriter.eraseOp(op); - return success(); - } -}; - - -struct PTOColMaxToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - // intrinsic: TCOLMAX(dst, src) - rewriter.create( - loc, TypeRange{}, "TCOLMAX", - /*args=*/ArrayAttr{}, // default: print all operands - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColArgMaxToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColArgMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLARGMAX", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, 
tmp}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColMinToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - // intrinsic: TCOLMIN(dst, src) - rewriter.create( - loc, TypeRange{}, "TCOLMIN", - /*args=*/ArrayAttr{}, // default: print all operands - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColArgMinToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColArgMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLARGMIN", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, tmp}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColSumToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColSumOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - // Check if tmp exists before accessing it - if (op.getTmp()) { - // Format 2: with tmp and isBinary - Value tmp = peelUnrealized(adaptor.getTmp()); - bool isBinary = false; - if (auto a = op.getIsBinaryAttr()) - isBinary = a.getValue(); - - auto boolTy = emitc::OpaqueType::get(ctx, 
"bool"); - auto tok = isBinary ? "true" : "false"; - Value isBinaryVal = rewriter.create( - loc, boolTy, emitc::OpaqueAttr::get(ctx, tok)); - - rewriter.create( - loc, TypeRange{}, "TCOLSUM", - /*args=*/ArrayAttr(), - /*templateArgs=*/ArrayAttr(), - /*operands=*/ValueRange{dst, src, tmp, isBinaryVal}); - } else { - // Format 1: without tmp and isBinary - rewriter.create( - loc, TypeRange{}, "TCOLSUM", - /*args=*/ArrayAttr(), - /*templateArgs=*/ArrayAttr(), - /*operands=*/ValueRange{dst, src}); - } - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColProdToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColProdOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLPROD", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; -static std::string roundModeTok(mlir::pto::RoundModeAttr attr) { - using RM = mlir::pto::RoundMode; - switch (attr.getValue()) { - case RM::NONE: return "RoundMode::CAST_NONE"; - case RM::RINT: return "RoundMode::CAST_RINT"; - case RM::ROUND: return "RoundMode::CAST_ROUND"; - case RM::FLOOR: return "RoundMode::CAST_FLOOR"; - case RM::CEIL: return "RoundMode::CAST_CEIL"; - case RM::TRUNC: return "RoundMode::CAST_TRUNC"; - case RM::ODD: return "RoundMode::CAST_ODD"; - case RM::CAST_RINT: return "RoundMode::CAST_RINT"; - } - return "RoundMode::CAST_RINT"; -} -static std::string saturationModeTok(mlir::pto::SaturationModeAttr attr) { - using SM = mlir::pto::SaturationMode; - switch (attr.getValue()) { - case SM::ON: return "SaturationMode::ON"; - case SM::OFF: return "SaturationMode::OFF"; - } - return "SaturationMode::OFF"; -} -struct PTOCvtToEmitC : public 
OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TCvtOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - pto::RoundModeAttr rmAttr = op.getRmodeAttr(); - std::string rmTok = rmAttr ? roundModeTok(rmAttr) - : std::string("RoundMode::CAST_RINT"); - auto rmodeTy = emitc::OpaqueType::get(ctx, "RoundMode"); - Value rmodeVal = rewriter.create( - loc, rmodeTy, emitc::OpaqueAttr::get(ctx, rmTok)); - - auto satModeTy = emitc::OpaqueType::get(ctx, "SaturationMode"); - auto satAttr = op.getSatModeAttr(); - std::string satTok = satAttr ? saturationModeTok(satAttr) - : std::string("SaturationMode::OFF"); - Value satModeVal = rewriter.create( - loc, satModeTy, emitc::OpaqueAttr::get(ctx, satTok)); - - SmallVector operands{dst, src, rmodeVal, satModeVal}; - - rewriter.create( - loc, TypeRange{}, "TCVT", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -struct PTORandomToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRandomOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - SmallVector operands{ - dst, - peelUnrealized(adaptor.getKey0()), - peelUnrealized(adaptor.getKey1()), - peelUnrealized(adaptor.getCounter0()), - peelUnrealized(adaptor.getCounter1()), - peelUnrealized(adaptor.getCounter2()), - peelUnrealized(adaptor.getCounter3()), - }; - ArrayAttr templateArgs = rewriter.getArrayAttr( - {emitc::OpaqueAttr::get(ctx, std::to_string(op.getRounds()))}); - - rewriter.create( - loc, TypeRange{}, "PTOAS__TRANDOM", - 
/*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, operands); - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tdiv lowering -> TDIV(dst, src0, src1) -//===----------------------------------------------------------------------===// - -struct PTODivToTDIV : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TDivOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TDIV", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src0, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tdivs lowering -> TDIVS(dst, src, scalar) or TDIVS(dst, scalar, src) -// Order is determined by operand types: if src is tile_buf, order is (tile, scalar) -// Otherwise, order is (scalar, tile) -//===----------------------------------------------------------------------===// - -struct PTODivSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TDivSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value dst = peelUnrealized(adaptor.getDst()); - // Preserve source order from textual parse: - // ins(tile, scalar) -> TDIVS(dst, tile, scalar) - // ins(scalar, tile) -> TDIVS(dst, scalar, tile) - rewriter.create( - loc, TypeRange{}, "TDIVS", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src, scalar}); - - rewriter.eraseOp(op); - return success(); - } -}; - 
-//===----------------------------------------------------------------------===// -// pto.tdivs (TDivSOp) lowering -> TDIVS(dst, src, scalar) or TDIVS(dst, scalar, src) -// Order is determined by operand types: if src is tile_buf, order is (tile, scalar) -// Otherwise, order is (scalar, tile) -//===----------------------------------------------------------------------===// - -struct PTOTDivSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TDivSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value dst = peelUnrealized(adaptor.getDst()); - rewriter.create( - loc, TypeRange{}, "TDIVS", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src, scalar}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.texp lowering -> TEXP(dst, src) -//===----------------------------------------------------------------------===// - -struct PTOExpToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TExpOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TEXP", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.texpands lowering -> TEXPANDS(dst, scalar) -//===----------------------------------------------------------------------===// - -struct PTOExpandsToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - 
LogicalResult matchAndRewrite(pto::TExpandsOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value scalar = peelUnrealized(adaptor.getScalar()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TEXPANDS", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, scalar}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.textract lowering -> TEXTRACT(dst, src, indexRow, indexCol) -//===----------------------------------------------------------------------===// - -struct PTOExtractToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TExtractOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value r0 = peelUnrealized(adaptor.getIndexRow()); - Value c0 = peelUnrealized(adaptor.getIndexCol()); - - rewriter.create( - loc, TypeRange{}, "TEXTRACT", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, r0, c0}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.textract_fp lowering -> TEXTRACT_FP(dst, src, fp, indexRow, indexCol) -//===----------------------------------------------------------------------===// - -struct PTOExtractFPToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TExtractFPOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value fp = peelUnrealized(adaptor.getFp()); - Value dst = peelUnrealized(adaptor.getDst()); - Value r0 = 
peelUnrealized(adaptor.getIndexRow()); - Value c0 = peelUnrealized(adaptor.getIndexCol()); - - rewriter.create( - loc, TypeRange{}, "TEXTRACT_FP", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, fp, r0, c0}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tinsert lowering -> TINSERT(dst, src, indexRow, indexCol) -// Keep lowering arch-agnostic and let PTO-ISA infer proper A5 path. -//===----------------------------------------------------------------------===// - -struct PTOInsertToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TInsertOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value r0 = peelUnrealized(adaptor.getIndexRow()); - Value c0 = peelUnrealized(adaptor.getIndexCol()); - - rewriter.create( - loc, TypeRange{}, "TINSERT", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, r0, c0}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tinsert_fp lowering -> TINSERT_FP(dst, src, fp, indexRow, indexCol) -//===----------------------------------------------------------------------===// - -struct PTOInsertFPToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TInsertFPOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value fp = peelUnrealized(adaptor.getFp()); - Value dst = peelUnrealized(adaptor.getDst()); - Value r0 = peelUnrealized(adaptor.getIndexRow()); - Value c0 = 
peelUnrealized(adaptor.getIndexCol()); - - rewriter.create( - loc, TypeRange{}, "TINSERT_FP", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, fp, r0, c0}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tfillpad lowering -> TFILLPAD(dst, src) -//===----------------------------------------------------------------------===// - -struct PTOFillPadToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TFillPadOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TFILLPAD", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tfillpad_inplace lowering -> TFILLPAD_INPLACE(dst, src) -//===----------------------------------------------------------------------===// - -struct PTOFillPadInplaceToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TFillPadInplaceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TFILLPAD_INPLACE", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tfillpad_expand lowering -> TFILLPAD_EXPAND(dst, src) 
-//===----------------------------------------------------------------------===// - -struct PTOFillPadExpandToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TFillPadExpandOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TFILLPAD_EXPAND", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tgather lowering -// - Index form : TGATHER(dst, src0, indices, tmp) -// - Compare form: TGATHER(dst, src0, kValue, cdst, tmp) -// - Mask form : TGATHER(dst, src0) -//===----------------------------------------------------------------------===// - -[[maybe_unused]] static std::string maskPatternTok(mlir::pto::MaskPatternAttr a) { - - auto v = a.getValue(); // enum - return (std::string("pto::MaskPattern::") + mlir::pto::stringifyMaskPattern(v).str()); -} - -struct PTOGatherToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGatherOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value src0 = peelUnrealized(adaptor.getSrc()); - - auto getOpaqueTok = [&](Value v, StringRef name) -> FailureOr { - if (auto ot = mlir::dyn_cast(v.getType())) - return ot.getValue().str(); - return rewriter.notifyMatchFailure(op, (name + " must be emitc::OpaqueType (tile)").str()); - }; - - // Case 1: index-based TGATHER(dst, src0, indices, tmp) - if (Value idx = adaptor.getIndices()) { - idx = peelUnrealized(idx); - Value tmp = 
peelUnrealized(adaptor.getTmp()); - - rewriter.create( - loc, TypeRange{}, "TGATHER", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, idx, tmp}); - - rewriter.eraseOp(op); - return success(); - } - - // Case 2: compare-based TGATHER( - // dst, src0, kValue, tmp, cdst, offset) - if (Value cdst = adaptor.getCdst()) { - cdst = peelUnrealized(cdst); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value kValue = peelUnrealized(adaptor.getKValue()); - - auto dstTokOr = getOpaqueTok(dst, "dst"); - auto srcTokOr = getOpaqueTok(src0, "src0"); - auto cdstTokOr = getOpaqueTok(cdst, "cdst"); - auto tmpTokOr = getOpaqueTok(tmp, "tmp"); - if (failed(dstTokOr) || failed(srcTokOr) || failed(cdstTokOr) || failed(tmpTokOr)) - return failure(); - - auto cmpAttr = op.getCmpModeAttr(); - std::string cmpTok = cmpAttr ? cmpModeTok(cmpAttr) : "CmpMode::EQ"; - int64_t offset = 0; - if (auto offsetAttr = op.getOffsetAttr()) - offset = offsetAttr.getInt(); - auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); - Value offsetVal = makeEmitCIntConstant(rewriter, loc, i32Ty, offset); - - auto targs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, *dstTokOr), - emitc::OpaqueAttr::get(ctx, *srcTokOr), - emitc::OpaqueAttr::get(ctx, *tmpTokOr), - emitc::OpaqueAttr::get(ctx, *cdstTokOr), - emitc::OpaqueAttr::get(ctx, cmpTok), - }); - - rewriter.create( - loc, TypeRange{}, "TGATHER", - /*args=*/ArrayAttr{}, /*templateArgs=*/targs, - /*operands=*/ValueRange{dst, src0, kValue, tmp, cdst, offsetVal}); - - rewriter.eraseOp(op); - return success(); - } - - // Case 3: mask-pattern TGATHER(dst, src0) - auto mp = op.getMaskPatternAttr(); - if (!mp) - return rewriter.notifyMatchFailure(op, "expected maskPattern, indices, or cdst on tgather"); - - auto dstTokOr = getOpaqueTok(dst, "dst"); - auto srcTokOr = getOpaqueTok(src0, "src0"); - if (failed(dstTokOr) || failed(srcTokOr)) - return failure(); - - // mp is an EnumAttr; stringify name is "P0101" etc. 
- // We emit MaskPattern::P0101 (because generated C++ has `using namespace pto;`) - std::string mpTok = std::string("MaskPattern::") + - mlir::pto::stringifyMaskPattern(mp.getValue()).str(); - - auto targs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, *dstTokOr), - emitc::OpaqueAttr::get(ctx, *srcTokOr), - emitc::OpaqueAttr::get(ctx, mpTok), - }); - - rewriter.create( - loc, TypeRange{}, "TGATHER", - /*args=*/ArrayAttr{}, - /*templateArgs=*/targs, - /*operands=*/ValueRange{dst, src0}); - - rewriter.eraseOp(op); - return success(); - } -}; - - -struct PTOGatherbToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGatherBOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value offsets = peelUnrealized(adaptor.getOffsets()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TGATHERB", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, offsets}); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// TLOG lowering to EmitC (PTOConvert.cpp) -//===----------------------------------------------------------------------===// - -struct PTOLogToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TLogOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src}; - rewriter.create( - loc, TypeRange{}, "TLOG", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - - - 
-//===----------------------------------------------------------------------===// -// TLRELU lowering to EmitC (PTOConvert.cpp) -//===----------------------------------------------------------------------===// - - struct PTOLReluToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TLReluOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value slope = peelUnrealized(adaptor.getSlope()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, slope}; - - rewriter.create( - loc, TypeRange{}, "TLRELU", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// TMAX lowering to EmitC (PTOConvert.cpp) -//===----------------------------------------------------------------------===// - -struct PTOMaxToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TMAX", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// TMAXS lowering to EmitC (PTOConvert.cpp) -//===----------------------------------------------------------------------===// - - struct PTOMaxSToEmitC : public OpConversionPattern { - using 
OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMaxSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, scalar}; - rewriter.create( - loc, TypeRange{}, "TMAXS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - - -//===----------------------------------------------------------------------===// -// TMIN lowering to EmitC (PTOConvert.cpp) -//===----------------------------------------------------------------------===// - -struct PTOMinToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TMIN", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// TMINS lowering to EmitC (PTOConvert.cpp) -//===----------------------------------------------------------------------===// - -//===----------------------------------------------------------------------===// -// TMINS lowering to EmitC (fix APFloat -> FloatAttr) (PTOConvert.cpp) -//===----------------------------------------------------------------------===// - -struct PTOMinsToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult 
matchAndRewrite(pto::TMinSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value scalar = peelUnrealized(adaptor.getScalar()); - - SmallVector operands{dst, src, scalar}; - rewriter.create( - loc, TypeRange{}, "TMINS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering for TMOV op -> EmitC) -//===----------------------------------------------------------------------===// - -struct PTOMovToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMovOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value fp; - if (op.getFp()) - fp = peelUnrealized(adaptor.getFp()); - Value preQuantScalar; - if (op.getPreQuantScalar()) - preQuantScalar = peelUnrealized(adaptor.getPreQuantScalar()); - - auto dstOT = mlir::dyn_cast(dst.getType()); - auto srcOT = mlir::dyn_cast(src.getType()); - if (!dstOT || !srcOT) - return rewriter.notifyMatchFailure( - op, "tmov lowering expects opaque dst/src types"); - - auto modeTok = [&](pto::AccToVecMode mode) -> StringRef { - switch (mode) { - case pto::AccToVecMode::SingleModeVec0: - return "pto::AccToVecMode::SingleModeVec0"; - case pto::AccToVecMode::SingleModeVec1: - return "pto::AccToVecMode::SingleModeVec1"; - case pto::AccToVecMode::DualModeSplitM: - return "pto::AccToVecMode::DualModeSplitM"; - case pto::AccToVecMode::DualModeSplitN: - return "pto::AccToVecMode::DualModeSplitN"; - } - llvm_unreachable("unknown AccToVecMode"); 
- }; - - auto modeAttr = op.getAccToVecModeAttr(); - auto reluTok = [&](pto::ReluPreMode mode) -> StringRef { - switch (mode) { - case pto::ReluPreMode::NoRelu: - return "ReluPreMode::NoRelu"; - case pto::ReluPreMode::NormalRelu: - return "ReluPreMode::NormalRelu"; - } - llvm_unreachable("unknown ReluPreMode"); - }; - - const bool hasFp = static_cast(fp); - const bool hasPreQuantScalar = static_cast(preQuantScalar); - const bool hasMode = static_cast(modeAttr); - const bool reluNonDefault = op.getReluPreMode() != pto::ReluPreMode::NoRelu; - - SmallVector operands{dst, src}; - SmallVector templateArgVec{ - emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), - emitc::OpaqueAttr::get(ctx, srcOT.getValue().str()), - }; - StringRef callee = "TMOV"; - - if (hasFp) { - auto fpOT = mlir::dyn_cast(fp.getType()); - if (!fpOT) - return rewriter.notifyMatchFailure( - op, "tmov fp lowering expects opaque fp type"); - operands.push_back(fp); - templateArgVec.push_back(emitc::OpaqueAttr::get(ctx, fpOT.getValue().str())); - if (hasMode) - templateArgVec.push_back( - emitc::OpaqueAttr::get(ctx, modeTok(modeAttr.getValue()))); - if (hasMode || reluNonDefault) - templateArgVec.push_back( - emitc::OpaqueAttr::get(ctx, reluTok(op.getReluPreMode()))); - callee = hasMode ? 
"TMOV" : "TMOV_FP"; - } else if (hasPreQuantScalar) { - operands.push_back(preQuantScalar); - if (hasMode) - templateArgVec.push_back( - emitc::OpaqueAttr::get(ctx, modeTok(modeAttr.getValue()))); - if (hasMode || reluNonDefault) - templateArgVec.push_back( - emitc::OpaqueAttr::get(ctx, reluTok(op.getReluPreMode()))); - } else if (hasMode) { - templateArgVec.push_back( - emitc::OpaqueAttr::get(ctx, modeTok(modeAttr.getValue()))); - templateArgVec.push_back( - emitc::OpaqueAttr::get(ctx, reluTok(op.getReluPreMode()))); - } else if (reluNonDefault) { - templateArgVec.push_back( - emitc::OpaqueAttr::get(ctx, reluTok(op.getReluPreMode()))); - } - - ArrayAttr templateArgs = - templateArgVec.size() == 2 && !hasFp && !hasPreQuantScalar && - !hasMode && !reluNonDefault - ? ArrayAttr{} - : rewriter.getArrayAttr(templateArgVec); - - rewriter.create( - loc, TypeRange{}, callee, - /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TMOV_FP DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOMovFPToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMovFPOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value src = peelUnrealized(adaptor.getSrc()); - Value fp = peelUnrealized(adaptor.getFp()); - - // TMOV_FP(dstTileData, cTile, fbTile) - ArrayAttr templateArgs; - auto dstOT = mlir::dyn_cast(dst.getType()); - auto srcOT = mlir::dyn_cast(src.getType()); - auto fpOT = mlir::dyn_cast(fp.getType()); - if (dstOT && srcOT && fpOT) { - templateArgs = rewriter.getArrayAttr({ - 
emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), - emitc::OpaqueAttr::get(ctx, srcOT.getValue().str()), - emitc::OpaqueAttr::get(ctx, fpOT.getValue().str()), - }); - } else { - templateArgs = ArrayAttr{}; - } - - SmallVector operands{dst, src, fp}; - rewriter.create( - loc, TypeRange{}, "TMOV_FP", - /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOQuantToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TQuantOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value src = peelUnrealized(adaptor.getSrc()); - Value fp = peelUnrealized(adaptor.getFp()); - - // Optional offset (INT8_ASYM only): passed as pointer (&offset) - Value offsetPtr; - if (op.getOffset()) { - Value offset = peelUnrealized(adaptor.getOffset()); - auto offsetOT = mlir::dyn_cast(offset.getType()); - if (offsetOT) { - offsetPtr = rewriter - .create( - loc, emitc::PointerType::get(offsetOT), "&", offset) - .getResult(); - } - } - - // TQUANT(dst, src, fp[, &offset]) - std::string quantTypeStr = - op.getQuantType() == pto::QuantType::INT8_SYM - ? 
"pto::QuantType::INT8_SYM" - : "pto::QuantType::INT8_ASYM"; - ArrayAttr templateArgs; - auto dstOT = mlir::dyn_cast(dst.getType()); - auto srcOT = mlir::dyn_cast(src.getType()); - auto fpOT = mlir::dyn_cast(fp.getType()); - if (dstOT && srcOT && fpOT) { - templateArgs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, quantTypeStr), - emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), - emitc::OpaqueAttr::get(ctx, srcOT.getValue().str()), - emitc::OpaqueAttr::get(ctx, fpOT.getValue().str()), - }); - } else { - templateArgs = ArrayAttr{}; - } - - SmallVector operands{dst, src, fp}; - if (offsetPtr) - operands.push_back(offsetPtr); - - rewriter.create( - loc, TypeRange{}, "TQUANT", - /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -struct PTODequantToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TDequantOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value src = peelUnrealized(adaptor.getSrc()); - Value scale = peelUnrealized(adaptor.getScale()); - Value offset = peelUnrealized(adaptor.getOffset()); - - // TDEQUANT(dst, src, scale, offset) - ArrayAttr templateArgs; - auto dstOT = mlir::dyn_cast(dst.getType()); - auto srcOT = mlir::dyn_cast(src.getType()); - auto scaleOT = mlir::dyn_cast(scale.getType()); - if (dstOT && srcOT && scaleOT) { - templateArgs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), - emitc::OpaqueAttr::get(ctx, srcOT.getValue().str()), - emitc::OpaqueAttr::get(ctx, scaleOT.getValue().str()), - }); - } else { - templateArgs = ArrayAttr{}; - } - - rewriter.create( - loc, TypeRange{}, "TDEQUANT", - /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, - /*operands=*/SmallVector{dst, src, scale, offset}); 
- - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TMRGSORT DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOMrgSortToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMrgSortOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - if (op.isFormat1()) { - Value src = peelUnrealized(adaptor.getSrcs().front()); - Value dst = peelUnrealized(adaptor.getDsts().front()); - Value blockLen = peelUnrealized(adaptor.getBlockLen()); - - SmallVector operands{dst, src, blockLen}; - rewriter.create( - loc, TypeRange{}, "TMRGSORT", - ArrayAttr{}, ArrayAttr{}, operands); - } else if (op.isFormat2()) { - // pto-isa API: - // TMRGSORT( - // dst, executedNumList, tmp, src0, src1[, src2[, src3]]); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDsts()[0]); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value excuted = peelUnrealized(adaptor.getExcuted()); - - SmallVector srcs; - srcs.reserve(adaptor.getSrcs().size()); - for (Value v : adaptor.getSrcs()) - srcs.push_back(peelUnrealized(v)); - - auto dstOT = mlir::dyn_cast(dst.getType()); - auto tmpOT = mlir::dyn_cast(tmp.getType()); - if (!dstOT || !tmpOT || srcs.size() < 2 || srcs.size() > 4) - return op.emitOpError("format2 expects dst/tmp tilebufs and 2 to 4 srcs"); - - SmallVector targs; - targs.reserve(2 + srcs.size() + 1); - targs.push_back(emitc::OpaqueAttr::get(ctx, dstOT.getValue().str())); - targs.push_back(emitc::OpaqueAttr::get(ctx, tmpOT.getValue().str())); - for (Value v : srcs) { - auto ot = mlir::dyn_cast(v.getType()); - if (!ot) - return op.emitOpError("format2 expects tilebuf srcs"); - targs.push_back(emitc::OpaqueAttr::get(ctx, ot.getValue().str())); - } - 
targs.push_back(emitc::OpaqueAttr::get(ctx, op.getExhausted() ? "true" : "false")); - ArrayAttr templateArgs = rewriter.getArrayAttr(targs); - - SmallVector operands{dst, excuted, tmp}; - operands.append(srcs.begin(), srcs.end()); - - rewriter.create( - loc, TypeRange{}, "TMRGSORT", - /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, operands); - } else { - return op.emitOpError("unsupported mrgsort_dps format"); - } - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TMUL DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOMulToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMulOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TMUL", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TMULS DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOMulsToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMulSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc0()); - Value dst = peelUnrealized(adaptor.getDst()); - Value scalar = peelUnrealized(adaptor.getScalar()); - - 
SmallVector operands{dst, src, scalar}; - rewriter.create( - loc, TypeRange{}, "TMULS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TNEG DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTONegToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TNegOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src}; - rewriter.create( - loc, TypeRange{}, "TNEG", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TNOT DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTONotToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TNotOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src}; - rewriter.create( - loc, TypeRange{}, "TNOT", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TOR DPS/memref op) 
-//===----------------------------------------------------------------------===// - -struct PTOOrToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TOrOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TOR", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TORS DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOOrsToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TOrSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - // NOTE: The conversion type system may materialize integers as emitc.opaque - // (e.g. "int32_t"). For EmitC call emission we can pass the scalar through - // directly without arith casts here. 
- Value s = adaptor.getScalar(); - - SmallVector operands{dst, src0, s}; - rewriter.create( - loc, TypeRange{}, "TORS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TPARTADD DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOPartAddToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPartAddOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TPARTADD", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TPARTMAX DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOPartMaxToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPartMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TPARTMAX", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return 
success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TPARTMIN DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOPartMinToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPartMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TPARTMIN", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOPartArgMaxToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPartArgMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value src0Idx = peelUnrealized(adaptor.getSrc0Idx()); - Value src1Idx = peelUnrealized(adaptor.getSrc1Idx()); - Value dst = peelUnrealized(adaptor.getDst()); - Value dstIdx = peelUnrealized(adaptor.getDstIdx()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TPARTARGMAX", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1, dstIdx, src0Idx, src1Idx}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOPartArgMinToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPartArgMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src0 = 
peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value src0Idx = peelUnrealized(adaptor.getSrc0Idx()); - Value src1Idx = peelUnrealized(adaptor.getSrc1Idx()); - Value dst = peelUnrealized(adaptor.getDst()); - Value dstIdx = peelUnrealized(adaptor.getDstIdx()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TPARTARGMIN", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1, dstIdx, src0Idx, src1Idx}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TPARTMUL DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOPartMulToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPartMulOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TPARTMUL", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TPRELU DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOPreluToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPReluOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = 
peelUnrealized(adaptor.getSrc1()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - // C++ interface: TPRELU(dst, src0, src1, tmp) — last parameter is tmp. - SmallVector operands{dst, src0, src1, tmp}; - rewriter.create( - loc, TypeRange{}, "TPRELU", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TRECIP DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORecipToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRecipOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src}; - rewriter.create( - loc, TypeRange{}, "TRECIP", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TRELU DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOReluToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TReluOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src}; - rewriter.create( - loc, TypeRange{}, "TRELU", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - 
/*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TREM DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORemToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRemOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - SmallVector operands{dst, src0, src1, tmp}; - rewriter.create( - loc, TypeRange{}, "TREM", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOFModToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TFModOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TFMOD", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TREMS DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORemSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult 
matchAndRewrite(pto::TRemSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - Value scalar = peelUnrealized(adaptor.getScalar()); - SmallVector operands{dst, src, scalar, tmp}; - rewriter.create( - loc, TypeRange{}, "TREMS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOFModSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TFModSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value scalar = peelUnrealized(adaptor.getScalar()); - - SmallVector operands{dst, src, scalar}; - rewriter.create( - loc, TypeRange{}, "TFMODS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TROWEXPAND DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORowExpandToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowExpandOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src}; - rewriter.create( - loc, TypeRange{}, "TROWEXPAND", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - 
rewriter.eraseOp(op); - return success(); - } -}; - -struct PTORowExpandAddToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowExpandAddOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TROWEXPANDADD", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTORowExpandExpdifToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowExpandExpdifOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value(); - - SmallVector operands; - if (tmp) - operands.assign({dst, src0, src1, tmp}); - else - operands.assign({dst, src0, src1}); - rewriter.create( - loc, TypeRange{}, "TROWEXPANDEXPDIF", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TROWEXPANDDIV DPS/memref op) -//===----------------------------------------------------------------------===// -// Helper: replace or erase based on whether op has results. 
-static void replaceOrEraseWithOpaqueCall(Operation *op, - StringRef callee, - ArrayRef args, - ConversionPatternRewriter &rewriter) { - TypeRange resultTypes = op->getResultTypes(); - auto call = rewriter.create( - op->getLoc(), resultTypes, callee, ArrayAttr{}, ArrayAttr{}, ValueRange(args)); - if (resultTypes.empty()) - rewriter.eraseOp(op); - else - rewriter.replaceOp(op, call.getResults()); -} - -static void replaceOrEraseWithOpaqueCallAndReturnDst(Operation *op, Value dst, - StringRef callee, - ArrayRef args, - ConversionPatternRewriter &rewriter) { - rewriter.create( - op->getLoc(), TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, ValueRange(args)); - if (op->getNumResults() == 1) - rewriter.replaceOp(op, dst); - else - rewriter.eraseOp(op); -} - -// ---------- TOp ---------- -struct PTOTGemvBiasToTGEMV_BIAS - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGemvBiasOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value a = peelUnrealized(adaptor.getA()); - Value b = peelUnrealized(adaptor.getB()); - Value bias = peelUnrealized(adaptor.getBias()); - Value dst = peelUnrealized(adaptor.getDst()); - - replaceOrEraseWithOpaqueCall(op.getOperation(), "TGEMV_BIAS", - {dst, a, b, bias}, rewriter); - return success(); - } -}; - -struct PTOTGemvMXToTGEMV_MX - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGemvMxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value a = peelUnrealized(adaptor.getA()); - Value aScale = peelUnrealized(adaptor.getAScale()); - Value b = peelUnrealized(adaptor.getB()); - Value bScale = peelUnrealized(adaptor.getBScale()); - Value dst = peelUnrealized(adaptor.getDst()); - - replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TGEMV_MX", - {dst, a, aScale, b, bScale}, rewriter); - return success(); - } -}; - -struct 
PTOTGemvMXAccToTGEMV_MX - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGemvMxAccOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value cIn = peelUnrealized(adaptor.getCIn()); - Value a = peelUnrealized(adaptor.getA()); - Value aScale = peelUnrealized(adaptor.getAScale()); - Value b = peelUnrealized(adaptor.getB()); - Value bScale = peelUnrealized(adaptor.getBScale()); - Value dst = peelUnrealized(adaptor.getDst()); - - replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TGEMV_MX", - {dst, cIn, a, aScale, b, bScale}, rewriter); - return success(); - } -}; - -struct PTOTGemvMXBiasToTGEMV_MX - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGemvMxBiasOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value a = peelUnrealized(adaptor.getA()); - Value aScale = peelUnrealized(adaptor.getAScale()); - Value b = peelUnrealized(adaptor.getB()); - Value bScale = peelUnrealized(adaptor.getBScale()); - Value bias = peelUnrealized(adaptor.getBias()); - Value dst = peelUnrealized(adaptor.getDst()); - - replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TGEMV_MX", - {dst, a, aScale, b, bScale, bias}, rewriter); - return success(); - } -}; - -struct PTOTMatmulBiasToTMATMUL_BIAS - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMatmulBiasOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value a = peelUnrealized(adaptor.getA()); - Value b = peelUnrealized(adaptor.getB()); - Value bias = peelUnrealized(adaptor.getBias()); - Value dst = peelUnrealized(adaptor.getDst()); - - replaceOrEraseWithOpaqueCall(op.getOperation(), "TMATMUL_BIAS", - {dst, a, b, bias}, rewriter); - return success(); - } -}; - -struct PTOTMatmulMXToTMATMUL_MX - : 
public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMatmulMxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value a = peelUnrealized(adaptor.getA()); - Value aScale = peelUnrealized(adaptor.getAScale()); - Value b = peelUnrealized(adaptor.getB()); - Value bScale = peelUnrealized(adaptor.getBScale()); - Value dst = peelUnrealized(adaptor.getDst()); - - replaceOrEraseWithOpaqueCall(op.getOperation(), "TMATMUL_MX", - {dst, a, aScale, b, bScale}, rewriter); - return success(); - } -}; - -struct PTOTMatmulMXAccToTMATMUL_MX_ACC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMatmulMxAccOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value cIn = peelUnrealized(adaptor.getCIn()); - Value a = peelUnrealized(adaptor.getA()); - Value aScale = peelUnrealized(adaptor.getAScale()); - Value b = peelUnrealized(adaptor.getB()); - Value bScale = peelUnrealized(adaptor.getBScale()); - Value dst = peelUnrealized(adaptor.getDst()); - - replaceOrEraseWithOpaqueCall(op.getOperation(), "TMATMUL_MX", - {dst, cIn, a, aScale, b, bScale}, rewriter); - return success(); - } -}; - -struct PTOTMatmulMXBiasToTMATMUL_MX_BIAS - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMatmulMxBiasOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value a = peelUnrealized(adaptor.getA()); - Value aScale = peelUnrealized(adaptor.getAScale()); - Value b = peelUnrealized(adaptor.getB()); - Value bScale = peelUnrealized(adaptor.getBScale()); - Value bias = peelUnrealized(adaptor.getBias()); - Value dst = peelUnrealized(adaptor.getDst()); - - replaceOrEraseWithOpaqueCall(op.getOperation(), "TMATMUL_MX", - {dst, a, aScale, b, bScale, bias}, rewriter); - return success(); - } -}; - -struct 
PTORowExpandDivToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowExpandDivOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value(); - - SmallVector operands; - if (tmp) - operands.assign({dst, src0, src1, tmp}); - else - operands.assign({dst, src0, src1}); - rewriter.create( - loc, TypeRange{}, "TROWEXPANDDIV", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TROWEXPANDMUL DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORowExpandMulToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowExpandMulOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - Value tmp = op.getTmp() ? 
peelUnrealized(adaptor.getTmp()) : Value(); - - SmallVector operands; - if (tmp) - operands.assign({dst, src0, src1, tmp}); - else - operands.assign({dst, src0, src1}); - rewriter.create( - loc, TypeRange{}, "TROWEXPANDMUL", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TROWEXPANDSUB DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORowExpandSubToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowExpandSubOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value(); - - SmallVector operands; - if (tmp) - operands.assign({dst, src0, src1, tmp}); - else - operands.assign({dst, src0, src1}); - rewriter.create( - loc, TypeRange{}, "TROWEXPANDSUB", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTORowExpandMaxToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowExpandMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - Value tmp = op.getTmp() ? 
peelUnrealized(adaptor.getTmp()) : Value(); - - SmallVector operands; - if (tmp) - operands.assign({dst, src0, src1, tmp}); - else - operands.assign({dst, src0, src1}); - rewriter.create( - loc, TypeRange{}, "TROWEXPANDMAX", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTORowExpandMinToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowExpandMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value(); - - SmallVector operands; - if (tmp) - operands.assign({dst, src0, src1, tmp}); - else - operands.assign({dst, src0, src1}); - rewriter.create( - loc, TypeRange{}, "TROWEXPANDMIN", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TROWMAX DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORowMaxToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, tmp}; - rewriter.create( - loc, TypeRange{}, "TROWMAX", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - 
rewriter.eraseOp(op); - return success(); - } -}; - -struct PTORowArgMaxToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowArgMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TROWARGMAX", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, tmp}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TROWMIN DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORowMinToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, tmp}; - rewriter.create( - loc, TypeRange{}, "TROWMIN", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTORowArgMinToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowArgMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, 
TypeRange{}, "TROWARGMIN", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, tmp}); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TROWSUM DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORowSumToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowSumOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, tmp}; - rewriter.create( - loc, TypeRange{}, "TROWSUM", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTORowProdToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowProdOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, tmp}; - rewriter.create( - loc, TypeRange{}, "TROWPROD", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TRSQRT DPS/memref op) -// - no-tmp form : TRSQRT(dst, src) -// - tmp form : TRSQRT(dst, src, tmp) 
-//===----------------------------------------------------------------------===// - -struct PTORsqrtToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRsqrtOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - SmallVector operands{dst, src}; - if (Value tmp = adaptor.getTmp()) - operands.push_back(peelUnrealized(tmp)); - rewriter.create( - loc, TypeRange{}, "TRSQRT", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSCATTER DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOScatterToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TScatterOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - const bool hasMaskPattern = static_cast(op.getMaskPatternAttr()); - const bool hasIndexes = static_cast(op.getIndexes()); - if (hasMaskPattern == hasIndexes) { - return rewriter.notifyMatchFailure( - op, "expected exactly one of indexes operand or maskPattern attribute"); - } - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - if (auto mp = op.getMaskPatternAttr()) { - auto *ctx = rewriter.getContext(); - auto targs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, maskPatternTok(mp)), - }); - rewriter.create( - loc, TypeRange{}, "TSCATTER", - /*args=*/ArrayAttr{}, /*templateArgs=*/targs, - /*operands=*/ValueRange{dst, src}); - } else { - Value idx = peelUnrealized(adaptor.getIndexes()); - 
rewriter.create( - loc, TypeRange{}, "TSCATTER", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, idx}); - } - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSEL DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOSelToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TSelOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value mask = peelUnrealized(adaptor.getMask()); - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, mask, src0, src1, tmp}; - rewriter.create( - loc, TypeRange{}, "TSEL", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSELS DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOSelSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TSelSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value mask = peelUnrealized(adaptor.getMask()); - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, mask, src, tmp, scalar}; - rewriter.create( 
- loc, TypeRange{}, "TSELS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSHL DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOShlSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TShlOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TSHL", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSHR DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOShrSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TShrOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TSHR", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp 
(add lowering for TSHLS/TSHRS DPS: shift by scalar) -//===----------------------------------------------------------------------===// - -struct PTOShlSConstToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TShlSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value dst = peelUnrealized(adaptor.getDst()); - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - SmallVector operands{dst, src, scalar}; - rewriter.create( - loc, TypeRange{}, "TSHLS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOShrSConstToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TShrSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value dst = peelUnrealized(adaptor.getDst()); - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - SmallVector operands{dst, src, scalar}; - rewriter.create( - loc, TypeRange{}, "TSHRS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (TSORT32 DPS/memref op: ins(src, idx[, tmp]) outs(dst)) -//===----------------------------------------------------------------------===// - -struct PTOSORT32SToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TSort32Op op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = 
peelUnrealized(adaptor.getDst()); - Value idx = peelUnrealized(adaptor.getIdx()); - Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value(); - - SmallVector operands; - if (tmp) - operands.assign({dst, src, idx, tmp}); - else - operands.assign({dst, src, idx}); - rewriter.create( - loc, TypeRange{}, "TSORT32", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSQRT DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOSqrtSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TSqrtOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src}; - rewriter.create( - loc, TypeRange{}, "TSQRT", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSTORE_FP DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOStoreFPSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TStoreFPOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value fp = peelUnrealized(adaptor.getFp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, fp}; - rewriter.create( - loc, 
TypeRange{}, "TSTORE_FP", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSUB DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOSubSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TSubOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TSUB", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSUBC DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOSubCSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TSubCOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value src2 = peelUnrealized(adaptor.getSrc2()); - Value dst = peelUnrealized(adaptor.getDst()); - - // pto-isa does not provide NPU implementation for TSUBC yet. 
- // Decompose: dst = src0 - src1 + src2 - rewriter.create( - loc, TypeRange{}, "TSUB", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1}); - rewriter.create( - loc, TypeRange{}, "TADD", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, dst, src2}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSUBS DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOSubSSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TSubSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, scalar}; - rewriter.create( - loc, TypeRange{}, "TSUBS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSUBSC DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOSubSCToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TSubSCOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - // pto-isa does not provide NPU 
implementation for TSUBSC yet. - // Decompose: dst = src0 - scalar + src1 - rewriter.create( - loc, TypeRange{}, "TSUBS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, scalar}); - rewriter.create( - loc, TypeRange{}, "TADD", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, dst, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; - - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TXOR DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOXORToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TXorOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - Value tmp = peelUnrealized(adaptor.getTmp()); - SmallVector operands{dst, src0, src1, tmp}; - rewriter.create( - loc, TypeRange{}, "TXOR", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -struct PTOTTransToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TTransOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, tmp}; - rewriter.create( - loc, TypeRange{}, "TTRANS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; 
-//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TXORS DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOXORSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TXorSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, scalar, tmp}; - rewriter.create( - loc, TypeRange{}, "TXORS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - struct PTOPrintToTPRINT : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPrintOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - - SmallVector operands{src}; - rewriter.create( - loc, TypeRange{}, "TPRINT", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -// pto.print "format", %scalar -> PRINTF("format", scalar) -struct PTOPrintOpToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::PrintOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - std::string fmt = op.getFormat().str(); - if (fmt.empty()) - fmt = "%f"; - std::string quoted = "\""; - for (char c : fmt) { - if (c == '"' || c == '\\') - quoted += '\\'; - else if (c == 
'\n') - quoted += "\\n"; - else if (c == '\t') - quoted += "\\t"; - else - quoted += c; - } - quoted += "\""; - - Value scalar = peelUnrealized(adaptor.getScalar()); - auto argsAttr = rewriter.getArrayAttr( - {emitc::OpaqueAttr::get(ctx, quoted), - IntegerAttr::get(IndexType::get(ctx), 0)}); - rewriter.create( - loc, TypeRange{}, "cce::printf", - /*args=*/argsAttr, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{scalar}); - - rewriter.eraseOp(op); - return success(); - } -}; - -// pto.trap -> TRAP() -struct PTOTrapOpToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TrapOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - rewriter.create( - loc, TypeRange{}, "trap", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{}); - - rewriter.eraseOp(op); - return success(); - } -}; - -// ============================================================================= -// 2. 
BindTileOp Lowering (FIX: Trace back to physical address) -// ============================================================================= -struct PTOBindTileToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - struct TileBuildSpec { - std::string tileTypeStr; - bool useConstructor = false; - SmallVector constructorArgs; - }; - - static bool getIndexConst(Value v, int64_t &out) { - if (!v) - return false; - if (auto cst = v.getDefiningOp()) { - if (auto ia = dyn_cast(cst.getValue())) { - out = ia.getValue().getSExtValue(); - return true; - } - } - return false; - } - - static bool getTilePointerStrides(pto::TileBufConfigAttr configAttr, - Type elemTy, int64_t rows, int64_t cols, - int64_t &rowStride, - int64_t &colStride) { - if (rows == ShapedType::kDynamic || cols == ShapedType::kDynamic) - return false; - - int32_t blVal = 0; - if (auto blAttr = dyn_cast(configAttr.getBLayout())) - blVal = static_cast(blAttr.getValue()); - else if (auto intAttr = dyn_cast(configAttr.getBLayout())) - blVal = static_cast(intAttr.getInt()); - - int32_t slVal = 0; - if (auto slAttr = dyn_cast(configAttr.getSLayout())) - slVal = static_cast(slAttr.getValue()); - else if (auto intAttr = dyn_cast(configAttr.getSLayout())) - slVal = static_cast(intAttr.getInt()); - - bool boxed = slVal != 0; - int64_t innerRows = 1; - int64_t innerCols = 1; - if (boxed) { - int32_t fractal = 512; - if (auto frAttr = dyn_cast(configAttr.getSFractalSize())) - fractal = static_cast(frAttr.getInt()); - - unsigned elemBytes = pto::getPTOStorageElemByteSize(elemTy); - if (elemBytes == 0) - return false; - - switch (fractal) { - case 1024: - innerRows = 16; - innerCols = 16; - break; - case 32: - innerRows = 16; - innerCols = 2; - break; - case 512: - if (slVal == 1) { - innerRows = 16; - innerCols = 32 / elemBytes; - } else if (slVal == 2) { - innerRows = 32 / elemBytes; - innerCols = 16; - } else { - return false; - } - break; - default: - return false; - } - if 
(innerRows <= 0 || innerCols <= 0) - return false; - } - - if (!boxed) { - if (blVal == 1) { - rowStride = 1; - colStride = rows; - } else { - rowStride = cols; - colStride = 1; - } - return true; - } - - if (blVal == 1) { - if (slVal != 1) - return false; - rowStride = innerCols; - colStride = rows; - return true; - } - - rowStride = cols; - colStride = innerRows; - return true; - } - - LogicalResult matchAndRewrite(pto::BindTileOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - auto configAttr = op.getConfigAttr(); - auto viewSemantics = op->getAttrOfType("pto.view_semantics"); - bool isSubView = viewSemantics && viewSemantics.getValue() == "subview"; - - auto peelAllCasts = [](Value v) { - while (auto castOp = v.getDefiningOp()) - v = castOp.getOperand(0); - if (auto castOp = v.getDefiningOp()) - v = castOp.getOperand(); - return v; - }; - auto isTileLike = [](Value v) -> bool { - auto ot = dyn_cast(v.getType()); - if (!ot) - return false; - StringRef s = ot.getValue(); - return s.contains("Tile<") || s.contains("ConvTile<"); - }; - auto buildTileSpec = [&]() -> FailureOr { - auto resMrTy = dyn_cast(op.getType()); - if (!resMrTy) - return failure(); - - const char *roleTok = "TileType::Vec"; - if (auto asAttr = - dyn_cast_or_null(resMrTy.getMemorySpace())) { - switch (asAttr.getAddressSpace()) { - case pto::AddressSpace::VEC: - roleTok = "TileType::Vec"; - break; - case pto::AddressSpace::MAT: - roleTok = "TileType::Mat"; - break; - case pto::AddressSpace::LEFT: - roleTok = "TileType::Left"; - break; - case pto::AddressSpace::RIGHT: - roleTok = "TileType::Right"; - break; - case pto::AddressSpace::ACC: - roleTok = "TileType::Acc"; - break; - case pto::AddressSpace::BIAS: - roleTok = "TileType::Bias"; - break; - case pto::AddressSpace::SCALING: - roleTok = "TileType::Scaling"; - break; - case pto::AddressSpace::GM: - case pto::AddressSpace::Zero: - roleTok = 
"TileType::Vec"; - break; - } - } - - Type elemTy = resMrTy.getElementType(); - Type emitElemTy = getTypeConverter()->convertType(elemTy); - if (!emitElemTy) - return failure(); - auto emitElemOpaque = dyn_cast(emitElemTy); - if (!emitElemOpaque) - return failure(); - std::string elemTypeStr = emitElemOpaque.getValue().str(); - - if (resMrTy.getRank() < 2) - return failure(); - int64_t rows = resMrTy.getDimSize(0); - int64_t cols = resMrTy.getDimSize(1); - if (rows == ShapedType::kDynamic || cols == ShapedType::kDynamic) - return failure(); - - std::string blTok = "BLayout::RowMajor"; - if (auto blAttr = dyn_cast(configAttr.getBLayout())) { - if (static_cast(blAttr.getValue()) == 1) - blTok = "BLayout::ColMajor"; - } - pto::BLayout blayout = getTileBufBLayoutValue(configAttr); - - if (isSubView) { - auto subMrTy = dyn_cast(op.getSource().getType()); - auto subViewOp = op.getSource().getDefiningOp(); - if (subMrTy && subMrTy.getRank() >= 2 && subViewOp) { - int64_t subRows = subMrTy.getDimSize(0); - int64_t subCols = subMrTy.getDimSize(1); - SmallVector inheritedStrides; - int64_t inheritedOffset = ShapedType::kDynamic; - - if (!pto::isPTOFloat4PackedType(elemTy) && - subRows != ShapedType::kDynamic && - subCols != ShapedType::kDynamic && - succeeded(getStridesAndOffset(subMrTy, inheritedStrides, - inheritedOffset)) && - inheritedStrides.size() >= 2) { - int64_t childRowStride = 0; - int64_t childColStride = 0; - bool sameStrides = getTilePointerStrides( - configAttr, elemTy, subRows, subCols, childRowStride, - childColStride); - sameStrides = sameStrides && - inheritedStrides[0] == childRowStride && - inheritedStrides[1] == childColStride; - if (sameStrides) { - rows = subRows; - cols = subCols; - } - } - } - } - - std::string slTok = "SLayout::NoneBox"; - if (auto slAttr = dyn_cast(configAttr.getSLayout())) { - int32_t slVal = static_cast(slAttr.getValue()); - slTok = (slVal == 1) ? "SLayout::RowMajor" - : (slVal == 2) ? 
"SLayout::ColMajor" - : "SLayout::NoneBox"; - } - - int32_t fractal = 512; - if (auto frAttr = dyn_cast(configAttr.getSFractalSize())) - fractal = frAttr.getInt(); - - std::string padTok = "PadValue::Null"; - if (auto padAttr = dyn_cast(configAttr.getPad())) { - switch (static_cast(padAttr.getValue())) { - case 1: - padTok = "PadValue::Zero"; - break; - case 2: - padTok = "PadValue::Max"; - break; - case 3: - padTok = "PadValue::Min"; - break; - default: - padTok = "PadValue::Null"; - break; - } - } - - std::string compactTok = "CompactMode::Null"; - if (auto compactAttr = dyn_cast(configAttr.getCompactMode())) { - switch (static_cast(compactAttr.getValue())) { - case 1: - compactTok = "CompactMode::Normal"; - break; - case 2: - compactTok = "CompactMode::RowPlusOne"; - break; - default: - compactTok = "CompactMode::Null"; - break; - } - } - - std::string vrowTok, vcolTok; - bool useConstructor = false; - bool rowIsDynamic = false; - bool colIsDynamic = false; - SmallVector constructorArgs; - - Value vRow = op.getValidRow(); - Value vCol = op.getValidCol(); - Value vRowEmitC = adaptor.getValidRow(); - Value vColEmitC = adaptor.getValidCol(); - bool forceDynamicValid = op->hasAttr(kForceDynamicValidShapeAttrName); - int64_t cRow = 0, cCol = 0; - bool rowIsConst = vRow && getIndexConst(vRow, cRow); - bool colIsConst = vCol && getIndexConst(vCol, cCol); - - auto makeCtorDimValue = [&](Value emitted, int64_t fallback) -> Value { - if (emitted) - return emitted; - return makeEmitCIntConstant( - rewriter, loc, emitc::OpaqueType::get(ctx, "int32_t"), fallback); - }; - auto maybeScaleDynamicValid = [&](Value emitted, int dimIdx) -> Value { - if (!emitted || !pto::isPTOFloat4PackedType(elemTy)) - return emitted; - int packedDim = blayout == pto::BLayout::ColMajor ? 
0 : 1; - if (dimIdx != packedDim) - return emitted; - auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); - Value two = makeEmitCIntConstant(rewriter, loc, i32Ty, 2); - return rewriter.create(loc, i32Ty, emitted, two).getResult(); - }; - - if (forceDynamicValid) { - vrowTok = "-1"; - vcolTok = "-1"; - useConstructor = true; - constructorArgs.push_back( - makeCtorDimValue(maybeScaleDynamicValid(vRowEmitC, 0), - renderTileTemplateDim(rowIsConst ? cRow : rows, - elemTy, blayout, 0))); - constructorArgs.push_back( - makeCtorDimValue(maybeScaleDynamicValid(vColEmitC, 1), - renderTileTemplateDim(colIsConst ? cCol : cols, - elemTy, blayout, 1))); - } else { - if (rowIsConst) { - vrowTok = std::to_string( - renderTileTemplateDim(cRow, elemTy, blayout, 0)); - } else if (vRow) { - vrowTok = "-1"; - rowIsDynamic = true; - useConstructor = true; - } else { - vrowTok = std::to_string( - renderTileTemplateDim(rows, elemTy, blayout, 0)); - } - - if (colIsConst) { - vcolTok = std::to_string( - renderTileTemplateDim(cCol, elemTy, blayout, 1)); - } else if (vCol) { - vcolTok = "-1"; - colIsDynamic = true; - useConstructor = true; - } else { - vcolTok = std::to_string( - renderTileTemplateDim(cols, elemTy, blayout, 1)); - } - - if (useConstructor) { - if (rowIsDynamic && vRowEmitC) - constructorArgs.push_back(maybeScaleDynamicValid(vRowEmitC, 0)); - if (colIsDynamic && vColEmitC) - constructorArgs.push_back(maybeScaleDynamicValid(vColEmitC, 1)); - } - } - - std::string tileTypeStr = std::string("Tile<") + roleTok + ", " + - elemTypeStr + ", " + - std::to_string(renderTileTemplateDim( - rows, elemTy, blayout, 0)) + - ", " + - std::to_string(renderTileTemplateDim( - cols, elemTy, blayout, 1)) + - ", " + blTok + - ", " + vrowTok + ", " + vcolTok + ", " + slTok + - ", " + std::to_string(fractal) + ", " + padTok + - ", " + compactTok + - ">"; - return TileBuildSpec{tileTypeStr, useConstructor, constructorArgs}; - }; - - auto buildTileValue = [&](const TileBuildSpec &spec, - bool 
forceDeclaration = false) -> Value { - auto tileType = emitc::OpaqueType::get(ctx, spec.tileTypeStr); - if (spec.useConstructor && !forceDeclaration) { - return rewriter - .create(loc, tileType, spec.tileTypeStr, - ArrayAttr{}, ArrayAttr{}, - ValueRange(spec.constructorArgs)) - .getResult(0); - } - - return rewriter - .create(loc, tileType, emitc::OpaqueAttr::get(ctx, "")) - .getResult(); - }; - - auto emitElemTypeToString = [&](Type elemTy) -> std::string { - return getEmitCScalarTypeToken(elemTy); - }; - - auto buildIntegralAddress = [&](Value sourceValue) -> FailureOr { - auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); - auto rcU64 = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); - - Value rawPtr = sourceValue; - if (auto ot = dyn_cast(sourceValue.getType())) { - StringRef tyStr = ot.getValue(); - if (tyStr.contains("Tile<") || tyStr.contains("ConvTile<")) { - auto srcMrTy = dyn_cast(op.getSource().getType()); - if (!srcMrTy) - return failure(); - std::string elemTok = emitElemTypeToString(srcMrTy.getElementType()); - pto::AddressSpace as = pto::AddressSpace::GM; - if (auto asAttr = - dyn_cast_or_null(srcMrTy.getMemorySpace())) - as = asAttr.getAddressSpace(); - rawPtr = materializeTileDataValue(rewriter, loc, sourceValue, as, - elemTok); - } - } - - if (isSetFFTsPointerLikeType(rawPtr.getType())) { - return rewriter - .create(loc, u64Ty, "reinterpret_cast", - ArrayAttr{}, rcU64, ValueRange{rawPtr}) - .getResult(0); - } - - if (rawPtr.getType() == u64Ty) - return rawPtr; - return rewriter.create(loc, u64Ty, rawPtr).getResult(); - }; - - if (op.getSource().getDefiningOp()) { - FailureOr tileSpec = buildTileSpec(); - if (failed(tileSpec)) - return failure(); - rewriter.replaceOp(op, buildTileValue(*tileSpec)); - return success(); - } - - Value tileCandidate = peelAllCasts(adaptor.getSource()); - if (viewSemantics && viewSemantics.getValue() == "bitcast" && - isTileLike(tileCandidate)) { - FailureOr tileSpec = buildTileSpec(); - if 
(failed(tileSpec)) - return failure(); - Value dstTile = buildTileValue(*tileSpec); - FailureOr addr = buildIntegralAddress(tileCandidate); - if (failed(addr)) - return failure(); - - rewriter.create(loc, TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dstTile, *addr}); - rewriter.replaceOp(op, dstTile); - return success(); - } - - if (viewSemantics && viewSemantics.getValue() == "treshape" && - isTileLike(tileCandidate)) { - FailureOr tileSpec = buildTileSpec(); - if (failed(tileSpec)) - return failure(); - Value dstTile = buildTileValue(*tileSpec, /*forceDeclaration=*/true); - - rewriter.create(loc, TypeRange{}, "TRESHAPE", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dstTile, tileCandidate}); - rewriter.replaceOp(op, dstTile); - return success(); - } - - // Subview origins are kept distinct from generic tile rebinding: - // even when source/destination C++ tile types match, subview may carry - // shifted base address semantics and should materialize a fresh handle. - if (isSubView) { - FailureOr tileSpec = buildTileSpec(); - if (failed(tileSpec)) - return failure(); - Value dstTile = buildTileValue(*tileSpec); - FailureOr addr = buildIntegralAddress(tileCandidate); - if (failed(addr)) - return failure(); - - rewriter.create(loc, TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dstTile, *addr}); - rewriter.replaceOp(op, dstTile); - return success(); - } - - // Generic tile-to-tile rebind path: preserve the same backing storage and - // rebuild a sibling tile with updated metadata/valid dims. 
- if (isTileLike(tileCandidate)) { - FailureOr tileSpec = buildTileSpec(); - if (failed(tileSpec)) - return failure(); - - if (!tileSpec->useConstructor) { - if (auto srcTy = dyn_cast(tileCandidate.getType())) { - if (srcTy.getValue() == tileSpec->tileTypeStr) { - rewriter.replaceOp(op, tileCandidate); - return success(); - } - } - } - - Value dstTile = buildTileValue(*tileSpec); - FailureOr addr = buildIntegralAddress(tileCandidate); - if (failed(addr)) - return failure(); - - rewriter.create(loc, TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dstTile, *addr}); - rewriter.replaceOp(op, dstTile); - return success(); - } - - SmallVector physAddrs; - Value source = op.getSource(); - - while (auto castOp = source.getDefiningOp()) - source = castOp.getOperand(0); - - if (auto upstreamCast = source.getDefiningOp()) { - auto upstreamOperands = upstreamCast.getAddrs(); - physAddrs.append(upstreamOperands.begin(), upstreamOperands.end()); - } else { - physAddrs.push_back(adaptor.getSource()); - } - - Value vRow = op.getValidRow(); - Value vCol = op.getValidCol(); - - auto newCast = rewriter.create( - loc, op.getType(), physAddrs, vRow ? vRow : Value(), - vCol ? 
vCol : Value(), configAttr); - if (viewSemantics) - newCast->setAttr("pto.view_semantics", viewSemantics); - if (op->hasAttr(kForceDynamicValidShapeAttrName)) - newCast->setAttr(kForceDynamicValidShapeAttrName, - op->getAttr(kForceDynamicValidShapeAttrName)); - rewriter.replaceOp(op, newCast.getResult()); - - return success(); - } -}; - -struct PTOAllocTileToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::AllocTileOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - MLIRContext *ctx = rewriter.getContext(); - auto tileTy = cast(op.getResult().getType()); - auto tileTypeString = getEmitCTileTypeString(tileTy); - if (!tileTypeString) - return rewriter.notifyMatchFailure( - op, "only rank-2 alloc_tile handles can be converted to EmitC"); - - Type convertedTy = getTypeConverter()->convertType(tileTy); - if (!convertedTy) - convertedTy = emitc::OpaqueType::get(ctx, *tileTypeString); - - auto validShape = tileTy.getValidShape(); - bool hasDynamicValidDim = - llvm::any_of(validShape, [](int64_t dim) { return dim < 0; }); - bool useConstructor = hasDynamicValidDim; - - SmallVector constructorArgs; - if (useConstructor) { - Type elemTy = tileTy.getElementType(); - pto::BLayout blayout = getTileBufBLayoutValue(tileTy.getConfigAttr()); - auto maybeScaleDynamicValid = [&](Value emitted, int dimIdx) -> Value { - if (!emitted || !pto::isPTOFloat4PackedType(elemTy)) - return emitted; - int packedDim = blayout == pto::BLayout::ColMajor ? 
0 : 1; - if (dimIdx != packedDim) - return emitted; - auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); - Value two = makeEmitCIntConstant(rewriter, loc, i32Ty, 2); - return rewriter.create(loc, i32Ty, emitted, two) - .getResult(); - }; - - if (validShape.size() > 0 && validShape[0] < 0) { - Value validRow = adaptor.getValidRow(); - if (!validRow) - return rewriter.notifyMatchFailure( - op, "dynamic alloc_tile valid row must have an operand"); - if (validRow) - validRow = peelUnrealized(validRow); - constructorArgs.push_back(maybeScaleDynamicValid(validRow, 0)); - } - if (validShape.size() > 1 && validShape[1] < 0) { - Value validCol = adaptor.getValidCol(); - if (!validCol) - return rewriter.notifyMatchFailure( - op, "dynamic alloc_tile valid col must have an operand"); - if (validCol) - validCol = peelUnrealized(validCol); - constructorArgs.push_back(maybeScaleDynamicValid(validCol, 1)); - } - } - - Value tile; - if (useConstructor) { - tile = rewriter - .create( - loc, convertedTy, *tileTypeString, ArrayAttr{}, - ArrayAttr{}, ValueRange(constructorArgs)) - .getResult(0); - } else { - tile = - rewriter - .create( - loc, convertedTy, emitc::OpaqueAttr::get(ctx, "")) - .getResult(); - } - - Value addr = adaptor.getAddr(); - if (addr) { - addr = peelUnrealized(addr); - auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); - if (isa(addr.getType()) || - (isa(addr.getType()) && - cast(addr.getType()).getValue().ends_with("*"))) { - auto rcU64 = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); - addr = rewriter - .create(loc, u64Ty, "reinterpret_cast", - ArrayAttr{}, rcU64, - ValueRange{addr}) - .getResult(0); - } else if (addr.getType() != u64Ty) { - addr = rewriter.create(loc, u64Ty, addr).getResult(); - } - - rewriter.create(loc, TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{tile, addr}); - } - - rewriter.replaceOp(op, tile); - return success(); - } -}; - -static FailureOr -createEmitCTileVariable(ConversionPatternRewriter 
&rewriter, Location loc, - const TypeConverter *typeConverter, - pto::TileBufType tileTy) { - auto tileTypeString = getEmitCTileTypeString(tileTy); - if (!tileTypeString) - return failure(); - - Type convertedTy = typeConverter->convertType(tileTy); - if (!convertedTy) - convertedTy = emitc::OpaqueType::get(rewriter.getContext(), *tileTypeString); - - return rewriter - .create( - loc, convertedTy, emitc::OpaqueAttr::get(rewriter.getContext(), "")) - .getResult(); -} - -struct PTOTReshapeToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TReshapeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto tileTy = dyn_cast(op.getResult().getType()); - if (!tileTy) - return failure(); - - FailureOr dst = - createEmitCTileVariable(rewriter, op.getLoc(), getTypeConverter(), tileTy); - if (failed(dst)) - return failure(); - - Value src = peelUnrealized(adaptor.getSrc()); - if (auto castOp = src.getDefiningOp()) - src = castOp.getOperand(); - - rewriter.create(op.getLoc(), TypeRange{}, "TRESHAPE", - ArrayAttr{}, ArrayAttr{}, - ValueRange{*dst, src}); - rewriter.replaceOp(op, *dst); - return success(); - } -}; - -struct PTOBitcastToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::BitcastOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto dstTy = dyn_cast(op.getResult().getType()); - auto srcTy = dyn_cast(op.getSrc().getType()); - if (!dstTy || !srcTy) - return failure(); - - FailureOr dst = - createEmitCTileVariable(rewriter, op.getLoc(), getTypeConverter(), dstTy); - if (failed(dst)) - return failure(); - - Value src = peelUnrealized(adaptor.getSrc()); - if (auto castOp = src.getDefiningOp()) - src = castOp.getOperand(); - - pto::AddressSpace as = pto::AddressSpace::GM; - if (auto asAttr = - dyn_cast_or_null(srcTy.getMemorySpace())) - as = 
asAttr.getAddressSpace(); - std::string elemTok = getEmitCScalarTypeToken(srcTy.getElementType()); - - Value rawPtr = materializeTileDataValue(rewriter, op.getLoc(), src, as, elemTok); - auto u64Ty = emitc::OpaqueType::get(rewriter.getContext(), "uint64_t"); - Value addr = rawPtr; - if (isSetFFTsPointerLikeType(rawPtr.getType())) { - auto rcU64 = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(rewriter.getContext(), - "uint64_t")}); - addr = rewriter - .create(op.getLoc(), u64Ty, - "reinterpret_cast", ArrayAttr{}, - rcU64, ValueRange{rawPtr}) - .getResult(0); - } else if (addr.getType() != u64Ty) { - addr = rewriter.create(op.getLoc(), u64Ty, addr).getResult(); - } - - rewriter.create(op.getLoc(), TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{*dst, addr}); - rewriter.replaceOp(op, *dst); - return success(); - } -}; - -struct PTOMaterializeTileToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - static bool isTileLike(Value v) { - auto ot = dyn_cast(v.getType()); - if (!ot) - return false; - StringRef s = ot.getValue(); - return s.contains("Tile<") || s.contains("ConvTile<"); - } - - LogicalResult matchAndRewrite(pto::MaterializeTileOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - MLIRContext *ctx = rewriter.getContext(); - auto tileTy = cast(op.getResult().getType()); - auto tileTypeString = getEmitCTileTypeString(tileTy); - if (!tileTypeString) - return rewriter.notifyMatchFailure( - op, "only rank-2 tile_buf handles can be materialized to EmitC"); - - Type convertedTy = getTypeConverter()->convertType(tileTy); - if (!convertedTy) - convertedTy = emitc::OpaqueType::get(ctx, *tileTypeString); - - Value source = peelUnrealized(adaptor.getSource()); - if (auto castOp = source.getDefiningOp()) - source = castOp.getOperand(); - - auto viewSemantics = op->getAttrOfType("pto.view_semantics"); - bool forceDynamicValid = 
op->hasAttr(kForceDynamicValidShapeAttrName); - bool isReshape = viewSemantics && viewSemantics.getValue() == "treshape"; - bool isSubview = viewSemantics && viewSemantics.getValue() == "subview"; - bool sourceIsDeclaredTile = - op.getSource().getDefiningOp(); - - auto createTileValue = [&]() -> Value { - SmallVector constructorArgs; - bool useConstructor = false; - pto::BLayout blayout = getTileBufBLayoutValue(tileTy.getConfigAttr()); - Type elemTy = tileTy.getElementType(); - auto shape = tileTy.getShape(); - auto validShape = tileTy.getValidShape(); - - auto makeCtorDimValue = [&](Value emitted, int64_t fallback) -> Value { - if (emitted) - return emitted; - return makeEmitCIntConstant( - rewriter, loc, emitc::OpaqueType::get(ctx, "int32_t"), fallback); - }; - auto maybeScaleDynamicValid = [&](Value emitted, int dimIdx) -> Value { - if (!emitted || !pto::isPTOFloat4PackedType(elemTy)) - return emitted; - int packedDim = blayout == pto::BLayout::ColMajor ? 0 : 1; - if (dimIdx != packedDim) - return emitted; - auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); - Value two = makeEmitCIntConstant(rewriter, loc, i32Ty, 2); - return rewriter.create(loc, i32Ty, emitted, two).getResult(); - }; - auto fallbackDim = [&](int dimIdx) { - return renderTileTemplateDim(shape[dimIdx], elemTy, blayout, dimIdx); - }; - - if (forceDynamicValid) { - useConstructor = true; - constructorArgs.push_back(makeCtorDimValue( - maybeScaleDynamicValid(adaptor.getValidRow(), 0), fallbackDim(0))); - constructorArgs.push_back(makeCtorDimValue( - maybeScaleDynamicValid(adaptor.getValidCol(), 1), fallbackDim(1))); - } else { - if (validShape[0] == ShapedType::kDynamic) { - useConstructor = true; - constructorArgs.push_back(makeCtorDimValue( - maybeScaleDynamicValid(adaptor.getValidRow(), 0), fallbackDim(0))); - } - if (validShape[1] == ShapedType::kDynamic) { - useConstructor = true; - constructorArgs.push_back(makeCtorDimValue( - maybeScaleDynamicValid(adaptor.getValidCol(), 1), 
fallbackDim(1))); - } - } - - if (useConstructor) { - return rewriter - .create(loc, convertedTy, *tileTypeString, - ArrayAttr{}, ArrayAttr{}, - ValueRange(constructorArgs)) - .getResult(0); - } - - return rewriter - .create(loc, convertedTy, - emitc::OpaqueAttr::get(ctx, "")) - .getResult(); - }; - - if (!isSubview && !forceDynamicValid && isTileLike(source)) { - if (auto srcTy = dyn_cast(source.getType())) { - if (srcTy.getValue() == *tileTypeString) { - rewriter.replaceOp(op, source); - return success(); - } - } - } - - Value tile = createTileValue(); - if (sourceIsDeclaredTile) { - rewriter.replaceOp(op, tile); - return success(); - } - - if (isReshape && isTileLike(source)) { - rewriter.create(loc, TypeRange{}, "TRESHAPE", - ArrayAttr{}, ArrayAttr{}, - ValueRange{tile, source}); - rewriter.replaceOp(op, tile); - return success(); - } - - pto::AddressSpace as = pto::AddressSpace::GM; - if (auto asAttr = - dyn_cast_or_null(tileTy.getMemorySpace())) - as = asAttr.getAddressSpace(); - std::string elemTok = getEmitCScalarTypeToken(tileTy.getElementType()); - - Value rawPtr = source; - if (isTileLike(rawPtr)) - rawPtr = materializeTileDataValue(rewriter, loc, rawPtr, as, elemTok); - - auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); - Value addr = rawPtr; - if (isSetFFTsPointerLikeType(rawPtr.getType())) { - auto rcU64 = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); - addr = rewriter - .create(loc, u64Ty, "reinterpret_cast", - ArrayAttr{}, rcU64, - ValueRange{rawPtr}) - .getResult(0); - } else if (rawPtr.getType() != u64Ty) { - addr = rewriter.create(loc, u64Ty, rawPtr).getResult(); - } - - rewriter.create(loc, TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{tile, addr}); - rewriter.replaceOp(op, tile); - return success(); - } -}; - -// ============================================================================= -// Arith CmpI -> EmitC Cmp -// ============================================================================= 
-class ArithCmpIToEmitC : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - // 将 arith.cmpi 转换为 emitc.cmp - // 映射 Predicate: eq -> equal, slt -> less, etc. - emitc::CmpPredicate emitcPred = emitc::CmpPredicate::eq; - const bool isUnsignedPred = - op.getPredicate() == arith::CmpIPredicate::ult || - op.getPredicate() == arith::CmpIPredicate::ule || - op.getPredicate() == arith::CmpIPredicate::ugt || - op.getPredicate() == arith::CmpIPredicate::uge; - switch (op.getPredicate()) { - case arith::CmpIPredicate::eq: emitcPred = emitc::CmpPredicate::eq; break; - case arith::CmpIPredicate::ne: emitcPred = emitc::CmpPredicate::ne; break; - case arith::CmpIPredicate::slt: emitcPred = emitc::CmpPredicate::lt; break; - case arith::CmpIPredicate::sle: emitcPred = emitc::CmpPredicate::le; break; - case arith::CmpIPredicate::sgt: emitcPred = emitc::CmpPredicate::gt; break; - case arith::CmpIPredicate::sge: emitcPred = emitc::CmpPredicate::ge; break; - // ... 处理无符号比较 (ult, ule 等) ... - case arith::CmpIPredicate::ult: emitcPred = emitc::CmpPredicate::lt; break; - case arith::CmpIPredicate::ule: emitcPred = emitc::CmpPredicate::le; break; - case arith::CmpIPredicate::ugt: emitcPred = emitc::CmpPredicate::gt; break; - case arith::CmpIPredicate::uge: emitcPred = emitc::CmpPredicate::ge; break; - } - - Type resTy = getTypeConverter()->convertType(op.getType()); - if (!resTy) - return failure(); - - Value lhs = adaptor.getLhs(); - Value rhs = adaptor.getRhs(); - if (isUnsignedPred) { - Type opTy = op.getLhs().getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure( - op, "expected scalar integer or index operands"); - - const unsigned bitWidth = - intTy ? 
intTy.getWidth() : static_cast(kPTOIndexBitWidth); - if (bitWidth != 1) { - lhs = castSignlessIntToUnsignedSameWidth(rewriter, loc, lhs, bitWidth); - rhs = castSignlessIntToUnsignedSameWidth(rewriter, loc, rhs, bitWidth); - } - } - - rewriter.replaceOpWithNewOp( - op, - /*resultType=*/resTy, // i1 -> bool/i1 - emitcPred, - lhs, - rhs - ); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// Section Op Lowering -//===----------------------------------------------------------------------===// -static bool isA5NoSplitPipeOp(Operation *op) { - if (auto talloc = dyn_cast(op)) - return talloc.getSplit() == 0; - if (auto tpush = dyn_cast(op)) - return tpush.getSplit() == 0; - if (auto tpop = dyn_cast(op)) - return tpop.getSplit() == 0; - if (auto tfree = dyn_cast(op)) - return tfree.getSplit() == 0; - if (auto tpush = dyn_cast(op)) - return tpush.getSplit() == 0; - if (auto tpush = dyn_cast(op)) - return tpush.getSplit() == 0; - if (auto talloc = dyn_cast(op)) - return talloc.getSplit() == 0; - if (auto talloc = dyn_cast(op)) - return talloc.getSplit() == 0; - if (auto tpop = dyn_cast(op)) - return tpop.getSplit() == 0; - if (auto tpop = dyn_cast(op)) - return tpop.getSplit() == 0; - if (auto tfree = dyn_cast(op)) - return tfree.getSplit() == 0; - if (auto tfree = dyn_cast(op)) - return tfree.getSplit() == 0; - return false; -} - -static bool hasExplicitSubblockControl(Operation *op) { - bool hasControl = false; - op->walk([&](Operation *nested) { - if (isa(nested)) { - hasControl = true; - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - return hasControl; -} - -static bool needsA5NoSplitVectorGuard(Operation *op) { - auto arch = getTargetArch(op); - if (arch != PTOArch::A5) - return false; - bool isVectorScope = isa(op); - if (auto func = dyn_cast(op)) { - if (auto kernelKindAttr = - func->getAttrOfType( - FunctionKernelKindAttr::name)) { - isVectorScope = - 
kernelKindAttr.getKernelKind() == FunctionKernelKind::Vector; - } - } - if (!isVectorScope) - return false; - if (hasExplicitSubblockControl(op)) - return false; - - bool hasNoSplitPipe = false; - op->walk([&](Operation *nested) { - if (!isA5NoSplitPipeOp(nested)) - return WalkResult::advance(); - hasNoSplitPipe = true; - return WalkResult::interrupt(); - }); - return hasNoSplitPipe; -} - -template -struct SectionToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - std::string getMacroName() const { - if (std::is_same::value) - return "__DAV_CUBE__"; - if (std::is_same::value) - return "__DAV_VEC__"; - return "UNKNOWN_MACRO"; - } - - LogicalResult - matchAndRewrite(SectionOpTy op, typename SectionOpTy::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - bool needsNoSplitGuard = needsA5NoSplitVectorGuard(op.getOperation()); - - std::string startMacro = "\n#if defined(" + getMacroName() + ")"; - rewriter.create(loc, startMacro); - - if constexpr (std::is_same_v) { - // Vector mask is a global HW state and may be modified by previous kernels - // (or earlier sections). Reset it to a well-defined state for deterministic - // execution of VEC ops. 
- rewriter.create(loc, "set_mask_norm();"); - rewriter.create(loc, "set_vector_mask(-1, -1);"); - } - - if (needsNoSplitGuard) { - rewriter.create( - loc, "if (get_subblockid() == 0) {"); - } - - Block &innerBlock = op.getBody().front(); - if (!innerBlock.empty()) { - rewriter.inlineBlockBefore(&innerBlock, op.getOperation(), ValueRange{}); - } - - if (needsNoSplitGuard) - rewriter.create(loc, "}"); - - std::string endMacro = "#endif // " + getMacroName() + "\n"; - rewriter.create(loc, endMacro); - - rewriter.eraseOp(op); - - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// SCF Control-Flow Pre-Lowering -// -// EmitC translation supports `emitc.for`/`emitc.if` plus CFG-style -// `cf.br`/`cf.cond_br`. Upstream SCFToEmitC patterns only cover `scf.for` and -// `scf.if`, so we pre-lower some SCF ops into those supported forms. -//===----------------------------------------------------------------------===// - -namespace { - -static bool isTriviallyInlineableExecuteRegion(scf::ExecuteRegionOp op) { - Region &r = op.getRegion(); - if (!r.hasOneBlock()) - return false; - Block &b = r.front(); - return isa_and_nonnull(b.getTerminator()); -} - -static bool needsWholeFunctionSCFToCF(func::FuncOp func) { - bool needs = false; - func.walk([&](Operation *op) { - if (!isa(op)) - return WalkResult::advance(); - Operation *parentOp = op->getParentOp(); - - // `scf.execute_region` can legally appear in single-block parents. Only - // require whole-function SCFToCF if we need to lower it into CFG blocks - // (multi-block region / non-trivial terminators). 
- if (auto exec = dyn_cast(op)) { - if (parentOp && parentOp->hasTrait() && - !isTriviallyInlineableExecuteRegion(exec)) { - needs = true; - return WalkResult::interrupt(); - } - return WalkResult::advance(); - } - - if (parentOp && parentOp->hasTrait()) { - needs = true; - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - return needs; -} - -// scf.execute_region is semantically just an inlined region producing results -// via scf.yield. Inline it to the parent block to avoid extra lowering needs. -struct SCFExecuteRegionInline - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(scf::ExecuteRegionOp op, - PatternRewriter &rewriter) const override { - if (op.getRegion().empty()) - return rewriter.notifyMatchFailure(op, "expected non-empty region"); - - Block &innerBlock = op.getRegion().front(); - auto yield = dyn_cast(innerBlock.getTerminator()); - if (!yield) - return rewriter.notifyMatchFailure(op, "expected scf.yield terminator"); - - // Move the body operations before the execute_region op. - rewriter.inlineBlockBefore(&innerBlock, op.getOperation(), ValueRange{}); - - // Replace execute_region results with yielded values, then erase the yield. - rewriter.replaceOp(op, yield.getOperands()); - rewriter.eraseOp(yield); - return success(); - } -}; - -// Lower scf.execute_region into CFG blocks with cf.br/cf.cond_br by inlining the -// region blocks into the parent region and rewriting scf.yield to branch into a -// continuation block carrying results. -// -// Note: This requires the parent region to allow multiple blocks (e.g. the -// function body CFG region). For execute_region nested in single-block regions -// (scf.for/scf.if), run SCFToCF first to eliminate the single-block constraint. 
-struct SCFExecuteRegionToCF : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(scf::ExecuteRegionOp op, - PatternRewriter &rewriter) const override { - if (isTriviallyInlineableExecuteRegion(op)) - return rewriter.notifyMatchFailure(op, "trivially inlineable"); - - Operation *parentOp = op->getParentOp(); - if (parentOp && parentOp->hasTrait()) { - return rewriter.notifyMatchFailure( - op, "cannot lower scf.execute_region inside a single-block parent region"); - } - - if (op.getRegion().empty()) - return rewriter.notifyMatchFailure(op, "expected non-empty region"); - - Location loc = op.getLoc(); - Block *curBlock = op->getBlock(); - Region *parentRegion = curBlock->getParent(); - - // Split the parent block so we can branch to a continuation block with phi - // arguments for the execute_region results. - auto execIt = Block::iterator(op.getOperation()); - Block *continueBlock = rewriter.splitBlock(curBlock, std::next(execIt)); - - SmallVector contArgs; - contArgs.reserve(op.getNumResults()); - for (Type t : op.getResultTypes()) - contArgs.push_back(continueBlock->addArgument(t, loc)); - - for (auto it : llvm::enumerate(op.getResults())) - it.value().replaceAllUsesWith(contArgs[it.index()]); - - // Capture blocks before moving the region. - SmallVector movedBlocks; - movedBlocks.reserve(op.getRegion().getBlocks().size()); - for (Block &b : op.getRegion()) - movedBlocks.push_back(&b); - Block *entryBlock = &op.getRegion().front(); - - // Inline the execute_region blocks into the parent region right before the - // continuation block. - rewriter.inlineRegionBefore(op.getRegion(), *parentRegion, - continueBlock->getIterator()); - - // Replace all scf.yield terminators with a branch to the continuation. 
- for (Block *b : movedBlocks) { - auto yield = dyn_cast(b->getTerminator()); - if (!yield) - continue; - rewriter.setInsertionPoint(yield); - rewriter.create(loc, continueBlock, yield.getOperands()); - rewriter.eraseOp(yield); - } - - // Replace execute_region itself with a branch to the inlined entry block. - rewriter.setInsertionPoint(op); - rewriter.create(loc, entryBlock, ValueRange{}); - rewriter.eraseOp(op); - return success(); - } -}; - -// Lower scf.index_switch into CFG blocks with cf.cond_br/cf.br so that we can -// avoid `scf.if` result materialization quirks (and avoid relying on cf.switch, -// which is not supported by EmitC C++ translation). -struct SCFIndexSwitchToCF : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - static LogicalResult cloneYieldingBlockAndBranchTo( - PatternRewriter &rewriter, Location loc, Block &srcBlock, Block *destBlock, - Block *continueBlock) { - rewriter.setInsertionPointToEnd(destBlock); - - IRMapping mapping; - for (Operation &inner : srcBlock.without_terminator()) - rewriter.clone(inner, mapping); - - auto yield = dyn_cast(srcBlock.getTerminator()); - if (!yield) - return failure(); - - SmallVector yieldOperands; - yieldOperands.reserve(yield.getNumOperands()); - for (Value v : yield.getOperands()) - yieldOperands.push_back(mapping.lookupOrDefault(v)); - - rewriter.create(loc, continueBlock, yieldOperands); - return success(); - } - - static Block *splitBlockForContinuation(PatternRewriter &rewriter, - scf::IndexSwitchOp op) { - auto switchIt = Block::iterator(op.getOperation()); - return rewriter.splitBlock(op->getBlock(), std::next(switchIt)); - } - - static void addContinuationArguments(PatternRewriter &rewriter, - scf::IndexSwitchOp op, Location loc, - Block *continueBlock) { - SmallVector contArgs; - contArgs.reserve(op.getNumResults()); - for (Type type : op.getResultTypes()) - contArgs.push_back(continueBlock->addArgument(type, loc)); - for (auto result : llvm::enumerate(op.getResults())) 
- result.value().replaceAllUsesWith(contArgs[result.index()]); - } - - static void createIndexSwitchBlocks(PatternRewriter &rewriter, - Region *parentRegion, - Region::iterator insertPt, - unsigned numCases, - SmallVectorImpl &checkBlocks, - Block *&defaultBlock, - SmallVectorImpl &caseBlocks) { - checkBlocks.reserve(numCases); - caseBlocks.reserve(numCases); - for (unsigned i = 0; i < numCases; ++i) - checkBlocks.push_back(rewriter.createBlock(parentRegion, insertPt)); - defaultBlock = rewriter.createBlock(parentRegion, insertPt); - for (unsigned i = 0; i < numCases; ++i) - caseBlocks.push_back(rewriter.createBlock(parentRegion, insertPt)); - } - - static void populateIndexSwitchCheckBlocks( - PatternRewriter &rewriter, Location loc, Value selector, - ArrayRef cases, ArrayRef checkBlocks, - ArrayRef caseBlocks, Block *defaultBlock) { - for (unsigned i = 0; i < checkBlocks.size(); ++i) { - rewriter.setInsertionPointToEnd(checkBlocks[i]); - Value caseVal = rewriter.create(loc, cases[i]); - Value cond = rewriter.create( - loc, arith::CmpIPredicate::eq, selector, caseVal); - Block *falseDest = - (i + 1 < checkBlocks.size()) ? 
checkBlocks[i + 1] : defaultBlock; - rewriter.create(loc, cond, caseBlocks[i], ValueRange{}, - falseDest, ValueRange{}); - } - } - - LogicalResult matchAndRewrite(scf::IndexSwitchOp op, - PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - Operation *parentOp = op->getParentOp(); - if (parentOp && parentOp->hasTrait()) { - return rewriter.notifyMatchFailure( - op, "cannot lower scf.index_switch inside a single-block parent region"); - } - - Block *curBlock = op->getBlock(); - Region *parentRegion = curBlock->getParent(); - Block *continueBlock = splitBlockForContinuation(rewriter, op); - addContinuationArguments(rewriter, op, loc, continueBlock); - - unsigned numCases = op.getCases().size(); - auto insertPt = continueBlock->getIterator(); - - SmallVector checkBlocks; - SmallVector caseBlocks; - Block *defaultBlock = nullptr; - createIndexSwitchBlocks(rewriter, parentRegion, insertPt, numCases, - checkBlocks, defaultBlock, caseBlocks); - - Value selector = op.getArg(); - auto cases = op.getCases(); - populateIndexSwitchCheckBlocks(rewriter, loc, selector, cases, checkBlocks, - caseBlocks, defaultBlock); - - // Fill case blocks and default block with cloned bodies + branch to cont. - for (unsigned i = 0; i < numCases; ++i) { - if (failed(cloneYieldingBlockAndBranchTo( - rewriter, loc, op.getCaseBlock(i), caseBlocks[i], continueBlock))) - return rewriter.notifyMatchFailure(op, "expected scf.yield terminator"); - } - if (failed(cloneYieldingBlockAndBranchTo(rewriter, loc, op.getDefaultBlock(), - defaultBlock, continueBlock))) - return rewriter.notifyMatchFailure(op, "expected scf.yield terminator"); - - // Replace the original switch op with a branch into the check chain. - Block *entryDest = numCases ? checkBlocks[0] : defaultBlock; - rewriter.setInsertionPointAfter(op); - rewriter.create(loc, entryDest, ValueRange{}); - rewriter.eraseOp(op); - return success(); - } -}; - -// Lower scf.while into CFG blocks with cf.br/cf.cond_br. 
-// -// Note: This requires the parent region to allow multiple blocks. In -// particular, scf.if/scf.for regions are single-block and cannot contain this -// lowering. -struct SCFWhileToCF : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - static LogicalResult validateWhileResultUses(scf::WhileOp op) { - Block *parentBlock = op->getBlock(); - for (Value result : op.getResults()) { - for (OpOperand &use : result.getUses()) { - if (use.getOwner()->getBlock() != parentBlock) - return failure(); - } - } - return success(); - } - - static Block *splitAfterWhileBlock(PatternRewriter &rewriter, - scf::WhileOp op) { - auto whileIt = Block::iterator(op.getOperation()); - return rewriter.splitBlock(op->getBlock(), std::next(whileIt)); - } - - static void addWhileExitArguments(PatternRewriter &rewriter, scf::WhileOp op, - Location loc, Block *afterWhileBlock) { - SmallVector exitArgs; - exitArgs.reserve(op.getNumResults()); - for (Type type : op.getResultTypes()) - exitArgs.push_back(afterWhileBlock->addArgument(type, loc)); - for (auto result : llvm::enumerate(op.getResults())) - result.value().replaceAllUsesWith(exitArgs[result.index()]); - } - - static Block *createWhileHeaderBlock(PatternRewriter &rewriter, - scf::WhileOp op, Location loc, - Block *afterWhileBlock) { - SmallVector headerArgTypes; - for (Value init : op.getInits()) - headerArgTypes.push_back(init.getType()); - SmallVector headerArgLocs(headerArgTypes.size(), loc); - return rewriter.createBlock(afterWhileBlock->getParent(), - afterWhileBlock->getIterator(), headerArgTypes, - headerArgLocs); - } - - static Block *createWhileBodyBlock(PatternRewriter &rewriter, scf::WhileOp op, - Location loc, Block *afterWhileBlock) { - Block &afterRegionBlock = op.getAfter().front(); - SmallVector bodyArgTypes(afterRegionBlock.getArgumentTypes().begin(), - afterRegionBlock.getArgumentTypes().end()); - SmallVector bodyArgLocs(bodyArgTypes.size(), loc); - return 
rewriter.createBlock(afterWhileBlock->getParent(), - afterWhileBlock->getIterator(), bodyArgTypes, - bodyArgLocs); - } - - static void rewriteWhileTerminators(PatternRewriter &rewriter, Location loc, - Block *headerBlock, Block *bodyBlock, - Block *afterWhileBlock) { - auto condOp = cast(headerBlock->getTerminator()); - rewriter.setInsertionPoint(condOp); - rewriter.create(loc, condOp.getCondition(), - /*trueDest=*/bodyBlock, - /*trueOperands=*/condOp.getArgs(), - /*falseDest=*/afterWhileBlock, - /*falseOperands=*/condOp.getArgs()); - rewriter.eraseOp(condOp); - - auto yieldOp = cast(bodyBlock->getTerminator()); - rewriter.setInsertionPoint(yieldOp); - rewriter.create(loc, headerBlock, yieldOp.getOperands()); - rewriter.eraseOp(yieldOp); - } - - LogicalResult matchAndRewrite(scf::WhileOp op, - PatternRewriter &rewriter) const override { - Operation *parentOp = op->getParentOp(); - if (parentOp && parentOp->hasTrait()) { - return rewriter.notifyMatchFailure( - op, "cannot lower scf.while inside a single-block parent region"); - } - - if (failed(validateWhileResultUses(op))) - return rewriter.notifyMatchFailure( - op, "unsupported: while results used outside the parent block"); - - auto loc = op.getLoc(); - Block *afterWhileBlock = splitAfterWhileBlock(rewriter, op); - addWhileExitArguments(rewriter, op, loc, afterWhileBlock); - Block *headerBlock = createWhileHeaderBlock(rewriter, op, loc, - afterWhileBlock); - Block *bodyBlock = createWhileBodyBlock(rewriter, op, loc, afterWhileBlock); - - // Move the before/after region bodies into the new CFG blocks. - Block &afterRegionBlock = op.getAfter().front(); - rewriter.mergeBlocks(&op.getBefore().front(), headerBlock, - headerBlock->getArguments()); - rewriter.mergeBlocks(&afterRegionBlock, bodyBlock, bodyBlock->getArguments()); - rewriteWhileTerminators(rewriter, loc, headerBlock, bodyBlock, - afterWhileBlock); - - // Replace scf.while itself with a branch to the header. 
- rewriter.setInsertionPoint(op); - rewriter.create(loc, headerBlock, op.getInits()); - rewriter.eraseOp(op); - return success(); - } -}; - -// Lower cf.switch into chained comparisons and cf.cond_br/cf.br. -// -// EmitC C++ translation currently supports cf.br/cf.cond_br, but not cf.switch. -struct CFSwitchToCondBr : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - static SmallVector> - collectSwitchCaseOperands(cf::SwitchOp op) { - SmallVector> caseOperands; - caseOperands.reserve(op.getCaseDestinations().size()); - for (auto range : op.getCaseOperands()) - caseOperands.emplace_back(range.begin(), range.end()); - return caseOperands; - } - - static SmallVector getSwitchCaseValues(cf::SwitchOp op) { - SmallVector caseValues; - if (auto caseValuesAttr = op.getCaseValues()) { - for (APInt value : caseValuesAttr->getValues()) - caseValues.push_back(value); - } - return caseValues; - } - - static SmallVector createSwitchCheckBlocks(PatternRewriter &rewriter, - Region *parentRegion, - Block *curBlock, - size_t numCases) { - auto insertPt = std::next(curBlock->getIterator()); - SmallVector checkBlocks; - checkBlocks.reserve(numCases); - for (size_t i = 0; i < numCases; ++i) - checkBlocks.push_back(rewriter.createBlock(parentRegion, insertPt)); - return checkBlocks; - } - - static LogicalResult populateSwitchCheckBlocks( - PatternRewriter &rewriter, Location loc, Value flag, IntegerType flagTy, - ArrayRef caseValues, ArrayRef caseDests, - ArrayRef> caseOperands, Block *defaultDest, - ValueRange defaultOperands, ArrayRef checkBlocks, - cf::SwitchOp op) { - for (size_t i = 0; i < caseDests.size(); ++i) { - rewriter.setInsertionPointToEnd(checkBlocks[i]); - APInt caseVal = caseValues[i]; - if (caseVal.getBitWidth() != flagTy.getWidth()) { - return rewriter.notifyMatchFailure( - op, "case value bitwidth doesn't match flag type"); - } - - Value caseConst = rewriter.create( - loc, flagTy, rewriter.getIntegerAttr(flagTy, caseVal)); - Value cond = 
rewriter.create( - loc, arith::CmpIPredicate::eq, flag, caseConst); - Block *falseDest = - (i + 1 < checkBlocks.size()) ? checkBlocks[i + 1] : defaultDest; - ValueRange falseOperands = - (i + 1 < checkBlocks.size()) ? ValueRange{} : defaultOperands; - rewriter.create(loc, cond, caseDests[i], - caseOperands[i], falseDest, - falseOperands); - } - return success(); - } - - LogicalResult matchAndRewrite(cf::SwitchOp op, - PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - Operation *parentOp = op->getParentOp(); - if (parentOp && parentOp->hasTrait()) { - return rewriter.notifyMatchFailure( - op, "cannot lower cf.switch inside a single-block parent region"); - } - - Block *curBlock = op->getBlock(); - Region *parentRegion = curBlock->getParent(); - - Value flag = op.getFlag(); - auto flagTy = dyn_cast(flag.getType()); - if (!flagTy) - return rewriter.notifyMatchFailure(op, "expected integer switch flag"); - - SmallVector defaultOperands(op.getDefaultOperands().begin(), - op.getDefaultOperands().end()); - Block *defaultDest = op.getDefaultDestination(); - - SmallVector caseDests(op.getCaseDestinations().begin(), - op.getCaseDestinations().end()); - SmallVector> caseOperands = collectSwitchCaseOperands(op); - - if (caseDests.empty()) { - rewriter.replaceOpWithNewOp(op, defaultDest, defaultOperands); - return success(); - } - - if (!op.getCaseValues()) - return rewriter.notifyMatchFailure(op, "missing case_values"); - SmallVector caseValues = getSwitchCaseValues(op); - - if (caseValues.size() != caseDests.size()) - return rewriter.notifyMatchFailure(op, "case_values/destinations mismatch"); - if (caseOperands.size() != caseDests.size()) - return rewriter.notifyMatchFailure(op, "case_operands/destinations mismatch"); - - SmallVector checkBlocks = - createSwitchCheckBlocks(rewriter, parentRegion, curBlock, - caseDests.size()); - if (failed(populateSwitchCheckBlocks(rewriter, loc, flag, flagTy, - caseValues, caseDests, caseOperands, - defaultDest, 
defaultOperands, - checkBlocks, op))) { - return failure(); - } - - // Replace the switch terminator with a branch into the first check block. - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp(op, checkBlocks.front(), - ValueRange{}); - return success(); - } -}; - -} // namespace - -static void populatePTOToEmitCPatterns(RewritePatternSet &patterns, - TypeConverter &typeConverter, - MLIRContext *ctx, - DataFlowSolver &solver, - PTOArch targetArch) { - (void)solver; - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx, "pto.set_flag_dyn", - "set_flag"); - patterns.add(typeConverter, ctx, "pto.wait_flag_dyn", - "wait_flag"); - // Backward-compatible aliases used in some downstream branches. - patterns.add(typeConverter, ctx, "pto.set_flag_d", - "set_flag"); - patterns.add(typeConverter, ctx, "pto.wait_flag_d", - "wait_flag"); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - 
patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add( - typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - 
patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add( - typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add>( - typeConverter, ctx); - patterns.add>( - typeConverter, ctx); - patterns.add>( - typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add>(typeConverter, - ctx); - patterns.add>(typeConverter, - ctx); - patterns.add>(typeConverter, - ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add>(typeConverter, 
ctx); - patterns.add>(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add>(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add>(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add>( - typeConverter, ctx, - "pto::comm::TPUT_ASYNC"); - patterns.add>( - typeConverter, ctx, - "pto::comm::TGET_ASYNC"); - patterns.add>(typeConverter, ctx, - "pto::comm::TPUT"); - patterns.add>(typeConverter, ctx, - "pto::comm::TGET"); - patterns.add>(typeConverter, ctx, - "pto::comm::TNOTIFY"); - patterns.add>(typeConverter, ctx, - "pto::comm::TWAIT"); - patterns.add>(typeConverter, ctx, - "pto::comm::TTEST"); - patterns.add>(typeConverter, ctx, - "TBROADCAST"); - patterns.add>(typeConverter, ctx, - "TGATHER"); - patterns.add>(typeConverter, ctx, - "TSCATTER"); - patterns.add>(typeConverter, ctx, - "TREDUCE"); - patterns.add>( - typeConverter, ctx, "PTOAS__ASYNC_EVENT_WAIT"); - patterns.add>( - typeConverter, ctx, "PTOAS__ASYNC_EVENT_TEST"); 
- patterns.add(typeConverter, ctx, targetArch); - patterns.add(typeConverter, ctx, targetArch); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx, targetArch); - patterns.add(typeConverter, ctx, targetArch); - patterns.add(typeConverter, ctx, targetArch); - patterns.add(typeConverter, ctx, targetArch); - patterns.add(typeConverter, ctx, targetArch); - patterns.add(typeConverter, ctx, targetArch); - patterns.add>(typeConverter, ctx); - patterns.add>(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add< - PTOTMatmulBiasToTMATMUL_BIAS, - PTOTMatmulMXToTMATMUL_MX, - PTOTMatmulMXAccToTMATMUL_MX_ACC, - PTOTMatmulMXBiasToTMATMUL_MX_BIAS, - PTOTMatmulBiasToTMATMUL_BIAS, - PTOTMatmulMXToTMATMUL_MX, - PTOTMatmulMXAccToTMATMUL_MX_ACC, - PTOTMatmulMXBiasToTMATMUL_MX_BIAS, - PTOTGemvBiasToTGEMV_BIAS, - PTOTGemvMXToTGEMV_MX, - PTOTGemvMXAccToTGEMV_MX, - PTOTGemvMXBiasToTGEMV_MX, - PTOBarrierToEmitC - >(typeConverter, ctx); - - patterns.add(typeConverter, ctx); - - populateSCFToEmitCConversionPatterns(patterns); - // Keep CFG-style branches type-consistent when block argument types are - // converted (e.g. after lowering scf.while to cf.br/cf.cond_br). 
- populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter); -} - -//===----------------------------------------------------------------------===// -// Pass -//===----------------------------------------------------------------------===// - -namespace { -struct EmitPTOManualPass - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(EmitPTOManualPass) - - PTOArch targetArch; - - EmitPTOManualPass() : targetArch(PTOArch::A3) {} - - explicit EmitPTOManualPass(PTOArch arch) : targetArch(arch) {} - - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - void runOnOperation() override { - LLVM_DEBUG(llvm::dbgs() << "DEBUG: Start PTOToEmitC Pass\n"); - MLIRContext *ctx = &getContext(); - ModuleOp mop = getOperation(); - - if (failed(pto::validatePTOEntryFunctions(mop))) - return signalPassFailure(); - pto::annotatePTOEntryFunctions(mop); - - // A3 requires explicit FFTS base setup for inter-core sync ops. - if (targetArch == PTOArch::A3) { - bool hasMissingSetFFTs = false; - for (auto func : mop.getOps()) { - if (!hasInterCoreSyncOp(func)) - continue; - if (hasSetFFTsOp(func)) - continue; - hasMissingSetFFTs = true; - func.emitError() - << "A3 inter-core sync requires explicit `pto.set_ffts` in the " - "same function when using `pto.sync.set`/`pto.sync.wait`"; - } - if (hasMissingSetFFTs) - return signalPassFailure(); - } - - bool needsEventIdArrayHelper = false; - bool needsTRandomHelper = false; - bool needsGlobalTensorDataHelper = false; - bool needsCommInclude = false; - mop.walk([&](Operation *op) { - if (isa(op)) - needsEventIdArrayHelper = true; - if (isa(op)) - needsTRandomHelper = true; - if (isa(op)) - needsGlobalTensorDataHelper = true; - if (isa(op)) - needsCommInclude = true; - }); - - // 1. 
插入头文件 - auto loc = mop->getLoc(); - OpBuilder builder(ctx); - builder.setInsertionPointToStart(mop.getBody()); - builder.create( - loc, "pto/pto-inst.hpp", /*is_standard_include=*/false); - if (needsCommInclude) { - builder.create( - loc, builder.getStringAttr(R"cpp( -#ifndef PIPE_FIX -#define PIPE_FIX PIPE_M -#endif -)cpp")); - builder.create( - loc, "pto/comm/pto_comm_inst.hpp", /*is_standard_include=*/false); - } - builder.create( - loc, builder.getStringAttr("using namespace pto;")); - if (needsGlobalTensorDataHelper) { - builder.create( - loc, builder.getStringAttr(R"cpp( -template -static AICORE inline auto PTOAS__GLOBAL_TENSOR_DATA(Tensor &tensor) - -> decltype(tensor.data()) { - return tensor.data(); -} -)cpp")); - } - if (needsEventIdArrayHelper) { - builder.create( - loc, builder.getStringAttr(R"cpp( -template -struct PTOAS_EventIdArray { - static_assert(N > 0, "PTOAS_EventIdArray requires a positive static size"); - int32_t data[N] = {}; - - AICORE inline int32_t &operator[](int32_t idx) { return data[idx]; } - AICORE inline const int32_t &operator[](int32_t idx) const { return data[idx]; } -}; -)cpp")); - } - if (needsTRandomHelper) { - builder.create( - loc, builder.getStringAttr(R"cpp( -template -static AICORE inline void PTOAS__TRANDOM( - DstTile &dst, uint32_t key0, uint32_t key1, uint32_t counter0, - uint32_t counter1, uint32_t counter2, uint32_t counter3) { - TRandomKey key = {key0, key1}; - TRandomCounter counter = {counter0, counter1, counter2, counter3}; - TRANDOM(dst, key, counter); -} -)cpp")); - } - builder.create( - loc, builder.getStringAttr(R"cpp( -enum class PTOAutoSyncTailMode : int { - kBarrierAll = 0, - kSetWaitMte3ToSEvent0 = 1, -}; - -static AICORE inline void ptoas_auto_sync_tail( - PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { - switch (mode) { - case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: - set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); - wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); - break; - case 
PTOAutoSyncTailMode::kBarrierAll: - default: - pipe_barrier(PIPE_ALL); - break; - } -} -)cpp")); - // Only inject the bitcast helper when we actually lower ops that need it - // (e.g. arith.bitcast or arith.maximumf/minimumf tie-breaking on zeros). - bool needsBitcastHelper = false; - mop.walk([&](Operation *op) { - if (isa(op)) { - needsBitcastHelper = true; - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - if (needsBitcastHelper) { - builder.create( - loc, builder.getStringAttr(R"cpp( - template - static inline To ptoas_bitcast(From from) { - static_assert(sizeof(To) == sizeof(From), "ptoas_bitcast: size mismatch"); - To to; - __builtin_memcpy(&to, &from, sizeof(To)); - return to; - } - )cpp")); - } - - // 1.5 Pre-lower SCF constructs not handled by SCFToEmitC. - { - // scf.while / scf.index_switch are lowered via CFG blocks. This is not - // possible inside ops that require single-block regions (e.g. scf.for / - // scf.if). If we see such nesting, lower the entire function to the - // ControlFlow dialect first. - bool needsAnySCFToCF = false; - for (auto func : mop.getOps()) { - if (needsWholeFunctionSCFToCF(func)) { - needsAnySCFToCF = true; - break; - } - } - if (needsAnySCFToCF) { - RewritePatternSet scfToCfPatterns(ctx); - populateSCFToControlFlowConversionPatterns(scfToCfPatterns); - FrozenRewritePatternSet frozenSCFToCF(std::move(scfToCfPatterns)); - - ConversionTarget scfToCfTarget(*ctx); - // Only eliminate the single-block SCF constructs; we'll pre-lower - // scf.while/index_switch/execute_region ourselves afterwards. 
- scfToCfTarget.addIllegalOp(); - scfToCfTarget.markUnknownOpDynamicallyLegal( - [](Operation *) { return true; }); - - for (auto func : mop.getOps()) { - if (!needsWholeFunctionSCFToCF(func)) - continue; - if (failed(applyPartialConversion(func, scfToCfTarget, - frozenSCFToCF))) { - func.emitError() - << "failed to lower nested SCF to ControlFlow (SCFToCF)"; - return signalPassFailure(); - } - } - } - - RewritePatternSet scfLoweringPatterns(ctx); - scfLoweringPatterns.add(ctx); - (void)applyPatternsAndFoldGreedily(mop, std::move(scfLoweringPatterns)); - - bool hasUnsupportedSCF = false; - mop.walk([&](Operation *op) { - if (isa(op)) { - hasUnsupportedSCF = true; - op->emitError() << "Unsupported SCF op remained after pre-lowering"; - return WalkResult::interrupt(); - } - if (isa(op)) { - hasUnsupportedSCF = true; - op->emitError() - << "Unsupported CF op remained after pre-lowering: cf.switch"; - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - if (hasUnsupportedSCF) - return signalPassFailure(); - } - - PTOToEmitCTypeConverter typeConverter(ctx, targetArch); - - // 2. Pre-convert SCF structural op types (e.g. scf.if/scf.for results) - // using the same type converter. This avoids creating emitc.variable with - // unsupported types such as memref. - { - RewritePatternSet scfTypePatterns(ctx); - ConversionTarget scfTypeTarget(*ctx); - scf::populateSCFStructuralTypeConversionsAndLegality( - typeConverter, scfTypePatterns, scfTypeTarget); - scfTypeTarget.markUnknownOpDynamicallyLegal( - [](Operation *) { return true; }); - - if (failed(applyPartialConversion(mop, scfTypeTarget, - std::move(scfTypePatterns)))) { - mop.emitError("failed to reconcile SCF structural types"); - return signalPassFailure(); - } - } - - // 3. 配置转换目标 - ConversionTarget target(*ctx); - - target.addIllegalDialect(); - target.addIllegalDialect(); - target.addIllegalDialect(); - target.addIllegalDialect(); - - // If we introduced CFG branches (e.g. 
from scf.while), make sure they are - // updated to use legalized operand types. - target.addDynamicallyLegalOp( - [&](Operation *op) { - return isLegalForBranchOpInterfaceTypeConversionPattern(op, - typeConverter); - }); - - // [关键] 允许 Cast 存在,最后统一清理 - target.addLegalOp(); - - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - - target.addLegalDialect(); - target.addLegalOp(); - - auto solver = std::make_unique(); - solver->load(); - solver->load(); - if (failed(solver->initializeAndRun(getOperation()))) - return signalPassFailure(); - - RewritePatternSet patterns(ctx); - populatePTOToEmitCPatterns(patterns, typeConverter, ctx, *solver, targetArch); - - // 4. 执行转换 - if (failed(applyPartialConversion(mop, target, std::move(patterns)))) { - llvm::errs() << "Conversion FAILED! Rolling back executed.\n"; - return signalPassFailure(); - } - - // ========================================================================= - // 5. [终极清理] - // 顺序至关重要: - // Step A: 先移除所有 Cast,让 Loop 的 Operand 类型变成底层类型 (如 int32) - // Step B: 再根据新的 Operand 类型,修复 Loop IV 的类型 - // ========================================================================= - - // --- Step A: 清理 UnrealizedConversionCastOp --- - // Prefer dropping redundant/unused casts; otherwise lower to emitc.cast - // so the C++ emitter can print it. 
- auto isEmitCTileLikeType = [](Type ty) { - auto opaqueTy = dyn_cast(ty); - if (!opaqueTy) - return false; - StringRef value = opaqueTy.getValue(); - return value.contains("Tile<") || value.contains("ConvTile<"); - }; - - llvm::SmallVector castsToErase; - bool castCleanupFailed = false; - mop.walk([&](UnrealizedConversionCastOp cast) { - if (castCleanupFailed) - return; - - if (cast->getNumOperands() != 1 || cast->getNumResults() != 1) { - cast.emitError() << "unsupported unrealized_conversion_cast shape"; - castCleanupFailed = true; - return; - } - - Value input = cast.getOperand(0); - Value output = cast.getResult(0); - Type inTy = input.getType(); - Type outTy = output.getType(); - - if (output.use_empty()) { - castsToErase.push_back(cast); - return; - } - - if (inTy == outTy) { - output.replaceAllUsesWith(input); - castsToErase.push_back(cast); - return; - } - - // SCF/CFG type conversion can transiently materialize pointer->memref - // bridge casts. At this stage, the producing value is already in the - // lowered EmitC pointer form; keep it and drop the bridge cast. - if (isEmitCPointerLikeType(inTy) && isa(outTy)) { - output.replaceAllUsesWith(input); - castsToErase.push_back(cast); - return; - } - - // SCF structural type conversion may leave a bridge from the converted - // EmitC tile value back to the original pto.tile_buf type for PTO op - // users. After PTO ops are lowered, the EmitC tile value is the value we - // want to keep. 
- if (isEmitCTileLikeType(inTy) && isa(outTy)) { - output.replaceAllUsesWith(input); - castsToErase.push_back(cast); - return; - } - - if (emitc::isSupportedEmitCType(inTy) && emitc::isSupportedEmitCType(outTy)) { - OpBuilder builder(cast); - auto c = builder.create(cast.getLoc(), outTy, input); - output.replaceAllUsesWith(c.getResult()); - castsToErase.push_back(cast); - return; - } - - cast.emitError() << "cannot lower unrealized_conversion_cast(" << inTy - << " -> " << outTy << ") to emitc.cast"; - castCleanupFailed = true; - }); - - for (auto cast : castsToErase) - cast.erase(); - - if (castCleanupFailed) - return signalPassFailure(); - - // --- Step A2: Sink casts of emitc.variable "reads" to their use sites --- - // - // SCFToEmitC lowers scf.if/scf.for results via mutable `emitc.variable` and - // `emitc.assign`. During type conversion, casts from the variable handle to - // the converted type may be materialized right after the variable - // declaration, effectively snapshotting the value *before* assignments. That - // produces wrong C++ (use-before-init / stale reads). - // - // Fix by re-materializing the cast at each use site so it reads the variable - // at the point of use. - { - SmallVector castOpsToSink; - mop.walk([&](emitc::CastOp castOp) { - if (castOp.getSource().getDefiningOp()) - castOpsToSink.push_back(castOp); - }); - - for (emitc::CastOp castOp : castOpsToSink) { - Value src = castOp.getSource(); - Type dstTy = castOp.getResult().getType(); - Value oldRes = castOp.getResult(); - - // Replace each use with a freshly inserted cast right before the user. 
- for (OpOperand &use : llvm::make_early_inc_range(oldRes.getUses())) { - Operation *user = use.getOwner(); - OpBuilder b(user); - b.setInsertionPoint(user); - auto newCast = b.create(castOp.getLoc(), dstTy, src); - use.set(newCast.getResult()); - } - - castOp.erase(); - } - } - - // --- Step B: 修复 Loop 归纳变量 (IV) --- - // 此时 emitc.for 的 operand 已经是 int32 了,我们检查 IV 是否匹配,不匹配则修正 - mop.walk([&](emitc::ForOp forOp) { - Type boundTy = forOp.getLowerBound().getType(); - BlockArgument iv = forOp.getBody()->getArgument(0); - - if (iv.getType() != boundTy) { - iv.setType(boundTy); // 强制将 IV 类型 (index) 修改为与边界一致 (int32) - } - }); - - // --- Step C: 消除冗余 Tile 变量 (Dead Code Elimination) [新增] --- - // 逻辑:如果一个 emitc.variable 没有被读取(use_empty), - // 那么它自己,以及给它赋值的 TASSIGN 都可以删除。 - // 注意:TASSIGN(v15, v9) 会把 v15 作为 Operand 0 使用,所以 v15 不是严格的 use_empty。 - // 我们需要检查:v15 是否除了 TASSIGN 之外没有其他 User。 - - llvm::SmallVector deadVars; - mop.walk([&](emitc::VariableOp varOp) { - // 检查该变量的所有 User - bool isRead = false; - for (Operation* user : varOp.getResult().getUsers()) { - // 如果 User 是 TASSIGN 且变量是第0个参数(dst),不算"读取" - if (auto call = dyn_cast(user)) { - if (call.getCallee() == "TASSIGN" && call.getOperand(0) == varOp.getResult()) { - continue; // 这是一个赋值操作,不算有效使用 - } - } - // 如果还有其他用途(如 TLOAD, TMOV, TMATMUL),则该变量有用 - isRead = true; - break; - } - - if (!isRead) { - deadVars.push_back(varOp); - } - }); - - for (auto varOp : deadVars) { - // 1. 先删除所有使用该变量的 TASSIGN - llvm::SmallVector usersToErase; - for (Operation* user : varOp.getResult().getUsers()) { - // 我们上面已经确认过,剩下的 user 只能是 TASSIGN - usersToErase.push_back(user); - } - for (auto u : usersToErase) u->erase(); - - // 2. 
删除变量定义本身 - varOp.erase(); - } - - llvm::SmallVector deadConsts; - mop.walk([&](emitc::ConstantOp constOp) { - if (constOp.getResult().use_empty()) - deadConsts.push_back(constOp); - }); - for (auto constOp : deadConsts) - constOp.erase(); - - // ========================================================================= - } - }; -} // namespace - -std::unique_ptr mlir::pto::createEmitPTOManualPass() { - return std::make_unique(); -} - -std::unique_ptr mlir::pto::createEmitPTOManualPass(PTOArch arch) { - return std::make_unique(arch); -} +#include "PTOToEmitC.def" diff --git a/lib/PTO/Transforms/PTOToEmitC.def b/lib/PTO/Transforms/PTOToEmitC.def new file mode 100644 index 000000000..ea9466da1 --- /dev/null +++ b/lib/PTO/Transforms/PTOToEmitC.def @@ -0,0 +1,12903 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +//===- PTOToEmitC.cpp - PTO to EmitC conversion pass ----------------------===// +//===----------------------------------------------------------------------===// + +#pragma GCC diagnostic ignored "-Woverloaded-virtual" +// https://discourse.llvm.org/t/matchandrewrite-hiding-virtual-functions/84933/8 + +#include +#include + +#include "PTO/IR/PTO.h" +#include "PTO/IR/PTOTypeUtils.h" +#include "PTO/IR/PTOSyncUtils.h" +#include "PTO/Transforms/Passes.h" + +#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" +#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" +#include "mlir/Analysis/DataFlowFramework.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" + +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeRange.h" + +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Target/Cpp/CppEmitter.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h" +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" + +#include +#include +#include +#include + +#define DEBUG_TYPE "pto-emitc" + +namespace mlir { +#define GEN_PASS_DEF_EMITPTOMANUAL +#include "PTO/Transforms/Passes.h.inc" +} // namespace mlir + +using namespace mlir; +using namespace mlir::pto; + 
+static std::string getElemTypeStringForGT(Type elemTy); +static bool getStaticMemrefLayout(MemRefType mrTy, + SmallVectorImpl &strides, + int64_t &offset); +static int64_t multiplyOrDynamic(int64_t lhs, int64_t rhs); +static void buildGlobalTensorShapeAndStride(ArrayRef shape, + ArrayRef strides, + SmallVectorImpl &shape5D, + SmallVectorImpl &stride5D); +static std::string joinIntTemplateParams(ArrayRef values); +static SmallVector buildRowMajorStrides(ArrayRef shape); +static std::string getGlobalTensorTypeStringFromShape(Type elemTy, + ArrayRef shape, + StringRef layoutEnum = + "pto::Layout::ND"); +static std::string getGlobalTensorTypeStringFromShapeAndStrides( + Type elemTy, ArrayRef shape, ArrayRef strides, + StringRef layoutEnum = "pto::Layout::ND"); +static emitc::OpaqueType getGlobalTensorOpaqueTypeFromShape( + MLIRContext *ctx, Type elemTy, ArrayRef shape, + StringRef layoutEnum = "pto::Layout::ND"); + +static const char *addrSpaceQualifier(pto::AddressSpace as) { + switch (as) { + case pto::AddressSpace::Zero: + return "__gm__"; + case pto::AddressSpace::VEC: + return "__ubuf__"; + case pto::AddressSpace::GM: + return "__gm__"; + case pto::AddressSpace::MAT: + return "__cbuf__"; + case pto::AddressSpace::LEFT: + return "__ca__"; + case pto::AddressSpace::RIGHT: + return "__cb__"; + case pto::AddressSpace::ACC: + return "__cc__"; + case pto::AddressSpace::BIAS: + // Bias tiles are special in pto-isa; keep a safe fallback qualifier. + return "__gm__"; + case pto::AddressSpace::SCALING: + // pto-isa TileType::Scaling maps to __fbuf__ (see pto/common/memory.hpp). 
+ return "__fbuf__"; + } + return "__gm__"; +} + +[[maybe_unused]] static constexpr llvm::StringLiteral kLoweredSetValidShapeAttrName = + "__pto.lowered_set_validshape"; +[[maybe_unused]] static constexpr llvm::StringLiteral kLoweredSetValidShapeConfigAttrName = + "__pto.lowered_set_validshape_config"; +static constexpr llvm::StringLiteral kForceDynamicValidShapeAttrName = + "__pto.force_dynamic_valid_shape"; +static constexpr llvm::StringLiteral kGlobalTensorStridesAttrName = + "__pto.globaltensor_strides"; + +static Value peelUnrealized(Value v) { + if (auto castOp = v.getDefiningOp()) + return castOp.getOperand(0); + return v; +} + +static Value buildGlobalTensorFromMemref(ConversionPatternRewriter &rewriter, + Location loc, Value basePtr, + MemRefType mrTy, Operation *anchor); + +static Value maybeWrapGlobalMemrefAsGlobalTensor( + ConversionPatternRewriter &rewriter, Location loc, Value loweredValue, + Type originalType, Operation *anchor); + +static bool hasCompatibleKnownExtentForMGather(int64_t lhs, int64_t rhs) { + return lhs == ShapedType::kDynamic || rhs == ShapedType::kDynamic || + lhs == rhs; +} + +static bool isKnownUnitExtentForMGather(int64_t value) { + return value == ShapedType::kDynamic || value == 1; +} + +struct GatherScatterShapeLayoutInfo { + SmallVector shape; + bool rowMajor = false; + bool colMajor = false; +}; + +static std::optional +getGatherScatterShapeLayoutInfo(Type ty) { + if (auto tileTy = dyn_cast(ty)) { + ArrayRef validShape = tileTy.getValidShape(); + if (validShape.size() != 2) + return std::nullopt; + + GatherScatterShapeLayoutInfo info; + info.shape.assign(validShape.begin(), validShape.end()); + int32_t blayout = tileTy.getBLayoutValueI32(); + info.rowMajor = blayout == static_cast(pto::BLayout::RowMajor); + info.colMajor = blayout == static_cast(pto::BLayout::ColMajor); + return info; + } + + auto memRefTy = dyn_cast(ty); + if (!memRefTy || memRefTy.getRank() != 2) + return std::nullopt; + + SmallVector strides; + int64_t 
offset = ShapedType::kDynamic; + if (failed(getStridesAndOffset(memRefTy, strides, offset)) || + strides.size() != 2) + return std::nullopt; + + GatherScatterShapeLayoutInfo info; + info.shape.assign(memRefTy.getShape().begin(), memRefTy.getShape().end()); + info.rowMajor = strides[1] == 1; + info.colMajor = strides[0] == 1; + return info; +} + +static bool isRowCoalescedMGatherIndexType(Type dataTy, Type idxTy) { + auto dataInfo = getGatherScatterShapeLayoutInfo(dataTy); + auto idxInfo = getGatherScatterShapeLayoutInfo(idxTy); + if (!dataInfo || !idxInfo) + return false; + + const bool rowCoalesce1xR = + idxInfo->rowMajor && isKnownUnitExtentForMGather(idxInfo->shape[0]) && + hasCompatibleKnownExtentForMGather(idxInfo->shape[1], dataInfo->shape[0]); + const bool rowCoalesceRx1 = + idxInfo->colMajor && + hasCompatibleKnownExtentForMGather(idxInfo->shape[0], dataInfo->shape[0]) && + isKnownUnitExtentForMGather(idxInfo->shape[1]); + return rowCoalesce1xR || rowCoalesceRx1; +} + +static std::optional getLayoutAttrFromOp(Operation *op) { + if (!op) + return std::nullopt; + if (auto attr = op->getAttrOfType("layout")) + return attr.getLayout(); + return std::nullopt; +} + +static std::optional resolveLayoutFromValueChain(Value v) { + v = peelUnrealized(v); + while (Operation *def = v.getDefiningOp()) { + if (auto layout = getLayoutAttrFromOp(def)) + return layout; + if (auto subview = dyn_cast(def)) { + v = peelUnrealized(subview.getSource()); + continue; + } + if (auto reinterpret = dyn_cast(def)) { + v = peelUnrealized(reinterpret.getSource()); + continue; + } + if (auto cast = dyn_cast(def)) { + v = peelUnrealized(cast.getSource()); + continue; + } + if (auto unrealized = dyn_cast(def)) { + if (unrealized->getNumOperands() == 0) + break; + v = peelUnrealized(unrealized.getOperand(0)); + continue; + } + break; + } + return std::nullopt; +} + +static std::optional +resolveLayoutForGlobalTensor(Operation *anchor, Value basePtr) { + if (auto layout = 
getLayoutAttrFromOp(anchor)) + return layout; + return resolveLayoutFromValueChain(basePtr); +} + +static std::string layoutToEmitCString(mlir::pto::Layout layout) { + switch (layout) { + case mlir::pto::Layout::ND: + return "pto::Layout::ND"; + case mlir::pto::Layout::DN: + return "pto::Layout::DN"; + case mlir::pto::Layout::NZ: + return "pto::Layout::NZ"; + } + return "pto::Layout::ND"; +} + +static bool isEmitCGlobalTensorLikeType(Type ty) { + auto opaqueTy = dyn_cast(ty); + return opaqueTy && opaqueTy.getValue().contains("GlobalTensor<"); +} + +static std::string getEmitCScalarTypeToken(Type elemTy) { + if (pto::isPTOFloat8Type(elemTy) && + (elemTy.isFloat8E4M3() || elemTy.isFloat8E4M3FN() || + elemTy.isFloat8E4M3FNUZ() || elemTy.isFloat8E4M3B11FNUZ())) + return "float8_e4m3_t"; + if (pto::isPTOFloat8Type(elemTy) && + (elemTy.isFloat8E5M2() || elemTy.isFloat8E5M2FNUZ())) + return "float8_e5m2_t"; + if (isa(elemTy)) + return "hifloat8_t"; + if (isa(elemTy)) + return "float4_e1m2x2_t"; + if (isa(elemTy)) + return "float4_e2m1x2_t"; + if (elemTy.isF16()) + return "half"; + if (elemTy.isBF16()) + return "bfloat16_t"; + if (elemTy.isF32()) + return "float"; + if (elemTy.isF64()) + return "double"; + if (elemTy.isInteger(8)) + return (elemTy.isSignlessInteger(8) || elemTy.isSignedInteger(8)) ? "int8_t" + : "uint8_t"; + if (elemTy.isInteger(16)) + return (elemTy.isSignlessInteger(16) || elemTy.isSignedInteger(16)) + ? "int16_t" + : "uint16_t"; + if (elemTy.isInteger(32)) + return (elemTy.isSignlessInteger(32) || elemTy.isSignedInteger(32)) + ? "int32_t" + : "uint32_t"; + if (elemTy.isInteger(64)) + return cast(elemTy).isUnsigned() ? 
"uint64_t" : "int64_t"; + return "float"; +} + +static emitc::PointerType getEmitCPointerType(MLIRContext *ctx, + StringRef pointeeTypeStr) { + return emitc::PointerType::get(emitc::OpaqueType::get(ctx, pointeeTypeStr)); +} + +static emitc::PointerType getEmitCPointerType(MLIRContext *ctx, + StringRef qualifier, + StringRef elemTypeStr) { + return getEmitCPointerType(ctx, (qualifier + " " + elemTypeStr).str()); +} + +static bool isEmitCPointerLikeType(Type ty) { + if (isa(ty)) + return true; + if (auto opaqueTy = dyn_cast(ty)) + return opaqueTy.getValue().ends_with("*"); + return false; +} + +static int64_t getEmitCScalarByteWidth(Type elemTy) { + if (pto::getPTOStorageElemByteSize(elemTy) == 1) + return 1; + if (elemTy.isF16() || elemTy.isBF16() || elemTy.isInteger(16)) + return 2; + if (elemTy.isF32() || elemTy.isInteger(32)) + return 4; + if (elemTy.isF64() || elemTy.isInteger(64)) + return 8; + return 4; +} + +static std::string tileBufBLayoutToken(pto::TileBufConfigAttr configAttr); +static std::string tileBufSLayoutToken(pto::TileBufConfigAttr configAttr); +static std::string tileBufPadToken(pto::TileBufConfigAttr configAttr); +static pto::BLayout getTileBufBLayoutValue(pto::TileBufConfigAttr configAttr); +static int64_t renderTileTemplateDim(int64_t rawDim, Type elemTy, + pto::BLayout blayout, int dimIdx); + +static const char *tileRoleToken(Attribute memorySpace) { + if (auto asAttr = dyn_cast_or_null(memorySpace)) { + switch (asAttr.getAddressSpace()) { + case pto::AddressSpace::VEC: + return "TileType::Vec"; + case pto::AddressSpace::MAT: + return "TileType::Mat"; + case pto::AddressSpace::LEFT: + return "TileType::Left"; + case pto::AddressSpace::RIGHT: + return "TileType::Right"; + case pto::AddressSpace::ACC: + return "TileType::Acc"; + case pto::AddressSpace::BIAS: + return "TileType::Bias"; + case pto::AddressSpace::SCALING: + return "TileType::Scaling"; + case pto::AddressSpace::GM: + case pto::AddressSpace::Zero: + return "TileType::Vec"; + } + } + 
return "TileType::Vec"; +} + +static std::string tileBufCompactToken(pto::TileBufConfigAttr configAttr) { + std::string compactTok = "CompactMode::Null"; + if (auto compactAttr = dyn_cast(configAttr.getCompactMode())) { + switch (static_cast(compactAttr.getValue())) { + case 1: + compactTok = "CompactMode::Normal"; + break; + case 2: + compactTok = "CompactMode::RowPlusOne"; + break; + default: + compactTok = "CompactMode::Null"; + break; + } + } + return compactTok; +} + +static std::optional getEmitCTileTypeString(pto::TileBufType type) { + if (type.getRank() != 2) + return std::nullopt; + auto validShape = type.getValidShape(); + if (validShape.size() != 2) + return std::nullopt; + + Type elemTy = type.getElementType(); + auto configAttr = type.getConfigAttr(); + pto::BLayout blayout = getTileBufBLayoutValue(configAttr); + ArrayRef shape = type.getShape(); + int64_t rows = shape[0]; + int64_t cols = shape[1]; + + auto render = [&](int64_t dim, int dimIdx) { + return renderTileTemplateDim(dim, elemTy, blayout, dimIdx); + }; + + std::string vrowTok = + validShape[0] == ShapedType::kDynamic + ? "-1" + : std::to_string(render(validShape[0], 0)); + std::string vcolTok = + validShape[1] == ShapedType::kDynamic + ? 
"-1" + : std::to_string(render(validShape[1], 1)); + + int32_t fractal = 512; + if (auto frAttr = dyn_cast(configAttr.getSFractalSize())) + fractal = frAttr.getInt(); + + return std::string("Tile<") + tileRoleToken(type.getMemorySpace()) + ", " + + getEmitCScalarTypeToken(elemTy) + ", " + + std::to_string(render(rows, 0)) + ", " + + std::to_string(render(cols, 1)) + ", " + + tileBufBLayoutToken(configAttr) + ", " + vrowTok + ", " + vcolTok + + ", " + tileBufSLayoutToken(configAttr) + ", " + + std::to_string(fractal) + ", " + tileBufPadToken(configAttr) + ", " + + tileBufCompactToken(configAttr) + ">"; +} + +//===----------------------------------------------------------------------===// +// Type Converter +//===----------------------------------------------------------------------===// + +class PTOToEmitCTypeConverter : public TypeConverter { +public: + PTOToEmitCTypeConverter(MLIRContext *Ctx, PTOArch targetArch) { + // --------------------------------------------------------- + // 1. 基本类型 (f32, i32, index) + // --------------------------------------------------------- + addConversion([Ctx](FloatType type) -> Type { + if (type.isFloat8E4M3() || type.isFloat8E4M3FN() || + type.isFloat8E4M3FNUZ() || type.isFloat8E4M3B11FNUZ()) + return emitc::OpaqueType::get(Ctx, "float8_e4m3_t"); + if (type.isFloat8E5M2() || type.isFloat8E5M2FNUZ()) + return emitc::OpaqueType::get(Ctx, "float8_e5m2_t"); + if (type.isF32()) return emitc::OpaqueType::get(Ctx, "float"); + if (type.isF16()) return emitc::OpaqueType::get(Ctx, "half"); + if (type.isBF16()) return emitc::OpaqueType::get(Ctx, "bfloat16_t"); + if (type.isF64()) return emitc::OpaqueType::get(Ctx, "double"); + llvm::errs() << "[Debug] Unsupported FloatType: " << type << "\n"; + return Type{}; + }); + + addConversion([Ctx](pto::HiF8Type) -> Type { + return emitc::OpaqueType::get(Ctx, "hifloat8_t"); + }); + addConversion([Ctx](pto::F4E1M2x2Type) -> Type { + return emitc::OpaqueType::get(Ctx, "float4_e1m2x2_t"); + }); + 
addConversion([Ctx](pto::F4E2M1x2Type) -> Type { + return emitc::OpaqueType::get(Ctx, "float4_e2m1x2_t"); + }); + + addConversion([Ctx](IntegerType type) -> Type { + if (type.getWidth() == 1) + return type; + + // Prefer fixed-width C types. Preserve signedness if the MLIR integer is + // explicitly signed/unsigned; treat signless as signed by default. + const bool isUnsigned = type.isUnsignedInteger(); + switch (type.getWidth()) { + case 8: + return emitc::OpaqueType::get(Ctx, isUnsigned ? "uint8_t" : "int8_t"); + case 16: + return emitc::OpaqueType::get(Ctx, + isUnsigned ? "uint16_t" : "int16_t"); + case 32: + return emitc::OpaqueType::get(Ctx, + isUnsigned ? "uint32_t" : "int32_t"); + case 64: + return emitc::OpaqueType::get(Ctx, + isUnsigned ? "uint64_t" : "int64_t"); + default: + llvm::errs() << "[Debug] Unsupported IntegerType width: " + << type.getWidth() << "\n"; + return emitc::OpaqueType::get(Ctx, "int32_t"); // Fallback + } + }); + + addConversion([Ctx](IndexType type) -> Type { + return emitc::OpaqueType::get(Ctx, "int32_t"); + }); + + // vector<4xi16> (e.g. TMRGSORT executedNumList) -> pto::MrgSortExecutedNumList + addConversion([Ctx](VectorType type) -> Type { + if (type.getRank() == 1 && type.getNumElements() == 4 && + type.getElementType().isInteger(16)) + return emitc::OpaqueType::get(Ctx, "pto::MrgSortExecutedNumList"); + return Type{}; + }); + + // --------------------------------------------------------- + // 2. 
PTO 特殊类型 (透传或转换) + // --------------------------------------------------------- + addConversion([](emitc::OpaqueType type) { return type; }); + addConversion([](emitc::PointerType type) { return type; }); + + // --------------------------------------------------------- + // 2.5 PtrType 转换 (指针类型) + // --------------------------------------------------------- + addConversion([this, Ctx](pto::PtrType type) -> std::optional { + Type elemType = type.getElementType(); + Type newElemType = convertType(elemType); + if (!newElemType) + return std::nullopt; + + std::string elemTypeStr; + if (auto opq = dyn_cast(newElemType)) { + elemTypeStr = opq.getValue().str(); + } else { + llvm::errs() << " [Error] PtrType elem type is not OpaqueType: " + << newElemType << "\n"; + return std::nullopt; + } + + std::string qualifier = "__gm__"; + + std::string finalTypeStr = qualifier + " " + elemTypeStr; + return getEmitCPointerType(Ctx, finalTypeStr); + }); + + addConversion([Ctx](pto::PipeType type) -> Type { + (void)type; + return emitc::OpaqueType::get(Ctx, "auto"); + }); + + addConversion([Ctx](pto::EventIdArrayType type) -> Type { + std::string tok = "PTOAS_EventIdArray<" + std::to_string(type.getSize()) + ">"; + return emitc::OpaqueType::get(Ctx, tok); + }); + + // !pto.local_array -> !emitc.array. + // Variables of this type render as `T a[D1][D2]...;` in the emitted C++. 
+ addConversion([this](pto::LocalArrayType type) -> std::optional { + Type convertedElem = convertType(type.getElementType()); + if (!convertedElem) + return std::nullopt; + return emitc::ArrayType::get(type.getShape(), convertedElem); + }); + + addConversion([Ctx](pto::AsyncSessionType type) -> Type { + (void)type; + return emitc::OpaqueType::get(Ctx, "pto::comm::AsyncSession"); + }); + + addConversion([Ctx](pto::AsyncEventType type) -> Type { + (void)type; + return emitc::OpaqueType::get(Ctx, "pto::comm::AsyncEvent"); + }); + + addConversion([Ctx](pto::PrefetchAsyncContextType type) -> Type { + (void)type; + return emitc::OpaqueType::get(Ctx, "pto::PrefetchAsyncContext"); + }); + + addConversion([Ctx](pto::TensorViewType type) -> Type { + return getGlobalTensorOpaqueTypeFromShape( + Ctx, type.getElementType(), type.getShape()); + }); + + addConversion([Ctx](pto::PartitionTensorViewType type) -> Type { + return getGlobalTensorOpaqueTypeFromShape( + Ctx, type.getElementType(), type.getShape()); + }); + + addConversion([Ctx](pto::TileBufType type) -> std::optional { + auto typeString = getEmitCTileTypeString(type); + if (!typeString) + return std::nullopt; + return emitc::OpaqueType::get(Ctx, *typeString); + }); + + // --------------------------------------------------------- + // 3. MemRef 转换 (Debug 重点) + // --------------------------------------------------------- + addConversion([this, Ctx](MemRefType type) -> std::optional { + LLVM_DEBUG(llvm::dbgs() << "Converting MemRef: " << type << "\n"); + + // A. 
转换元素类型 + Type elemType = type.getElementType(); + Type newElemType = convertType(elemType); + if (!newElemType) { + llvm::errs() << " [Error] Failed to convert element type: " << elemType << "\n"; + return std::nullopt; + } + + // 获取元素类型的字符串 + std::string elemTypeStr; + if (auto opq = dyn_cast(newElemType)) { + elemTypeStr = opq.getValue().str(); + } else { + llvm::errs() << " [Error] Converted element type is not OpaqueType: " << newElemType << "\n"; + return std::nullopt; + } + + // B. 处理 Memory Space + std::string qualifier = ""; + Attribute memorySpace = type.getMemorySpace(); + + if (!memorySpace) { + qualifier = "__gm__"; + } else if (auto ptoAttr = dyn_cast(memorySpace)) { + qualifier = addrSpaceQualifier(ptoAttr.getAddressSpace()); + } else { + llvm::errs() << " [Warning] Unknown MemorySpace Attribute type: " << memorySpace << "\n"; + qualifier = "__gm__"; // Fallback + } + + std::string finalTypeStr = qualifier + " " + elemTypeStr; + LLVM_DEBUG(llvm::dbgs() << " [Success] -> " << finalTypeStr << "*\n"); + + return getEmitCPointerType(Ctx, finalTypeStr); + }); + + // --------------------------------------------------------- + // 4. Function & Materialization + // --------------------------------------------------------- + addConversion([this](FunctionType type) -> Type { + SmallVector inputs; + if (failed(convertTypes(type.getInputs(), inputs))) return Type{}; + SmallVector results; + if (failed(convertTypes(type.getResults(), results))) return Type{}; + return FunctionType::get(type.getContext(), inputs, results); + }); + + auto materializeCast = [](OpBuilder &Builder, Type ResultType, + ValueRange Inputs, Location Loc) -> Value { + if (Inputs.size() != 1) return Value(); + return Builder.create(Loc, ResultType, Inputs[0]).getResult(0); + }; + + addSourceMaterialization(materializeCast); + addTargetMaterialization(materializeCast); + // Needed for region/block signature conversions (e.g. CFG block args). 
+ addArgumentMaterialization(materializeCast); + } +}; + +static constexpr unsigned kPTOIndexBitWidth = + 32; // keep consistent with IndexType conversion + +// Forward declarations (definitions below). +static inline std::string pipeTokFromPipeAttr(mlir::pto::PipeAttr a); +static emitc::OpaqueType getSignedIntOpaqueType(MLIRContext *ctx, + unsigned bitWidth); +static emitc::OpaqueType getUnsignedIntOpaqueType(MLIRContext *ctx, + unsigned bitWidth); +static emitc::OpaqueType getWiderSignedIntOpaqueType(MLIRContext *ctx, + unsigned bitWidth); +static emitc::OpaqueType getWiderUnsignedIntOpaqueType(MLIRContext *ctx, + unsigned bitWidth); +static Value makeEmitCOpaqueConstant(ConversionPatternRewriter &rewriter, + Location loc, Type type, + llvm::StringRef literal); +static Value makeEmitCIntConstant(ConversionPatternRewriter &rewriter, + Location loc, Type type, int64_t value); +static Value emitCCast(ConversionPatternRewriter &rewriter, Location loc, + Type dstType, Value src); +static FailureOr buildEmitCOpaqueConstantLiteral(Type targetType, + Attribute valueAttr); +static Value castSignlessIntToUnsignedSameWidth(ConversionPatternRewriter &rewriter, + Location loc, Value v, + unsigned bitWidth); +static bool needsA5NoSplitVectorGuard(Operation *op); + +static FailureOr getTileSplitToken(int64_t split) { + switch (split) { + case 0: + return std::string("TileSplitAxis::TILE_NO_SPLIT"); + case 1: + return std::string("TileSplitAxis::TILE_UP_DOWN"); + case 2: + return std::string("TileSplitAxis::TILE_LEFT_RIGHT"); + default: + return failure(); + } +} + +static FailureOr +getTPipeDirectionToken(bool isL2G2L, int8_t dirMask, PTOArch targetArch) { + if (dirMask == 1) { + if (isL2G2L && targetArch == PTOArch::A5) + return std::string("Direction::DIR_C2V_GM"); + return std::string("Direction::DIR_C2V"); + } + if (dirMask == 2) { + if (isL2G2L && targetArch == PTOArch::A5) + return std::string("Direction::DIR_V2C_GM"); + return std::string("Direction::DIR_V2C"); + } + if 
(dirMask == 3) + return std::string("Direction::DIR_BOTH"); + return failure(); +} + +static std::string buildTPipeToken(int32_t flagBase, llvm::StringRef dirTok, + int32_t slotSize, int32_t slotNum, + int32_t localSlotNum, bool nosplit) { + std::string token = "TPipe<" + std::to_string(flagBase) + ", " + dirTok.str() + + ", " + std::to_string(slotSize) + ", " + + std::to_string(slotNum); + token += ", " + std::to_string(localSlotNum); + token += nosplit ? ", true" : ", false"; + token += ">"; + return token; +} + +static FailureOr buildTPipeTokenFromInitOp(Operation *op, + PTOArch targetArch) { + if (auto initOp = dyn_cast(op)) { + if (!initOp.getFlagBaseAttr()) + return failure(); + auto dirTok = + getTPipeDirectionToken(/*isL2G2L=*/true, initOp.getDirMask(), targetArch); + if (failed(dirTok)) + return failure(); + int32_t localSlotNum = initOp.getLocalSlotNumAttr() + ? initOp.getLocalSlotNumAttr().getInt() + : initOp.getSlotNum(); + return buildTPipeToken(initOp.getFlagBaseAttr().getInt(), *dirTok, + initOp.getSlotSize(), initOp.getSlotNum(), + localSlotNum, + initOp.getNosplitAttr() && + initOp.getNosplitAttr().getValue()); + } + + if (auto initOp = dyn_cast(op)) { + if (!initOp.getFlagBaseAttr()) + return failure(); + auto dirTok = + getTPipeDirectionToken(/*isL2G2L=*/false, initOp.getDirMask(), targetArch); + if (failed(dirTok)) + return failure(); + return buildTPipeToken(initOp.getFlagBaseAttr().getInt(), *dirTok, + initOp.getSlotSize(), initOp.getSlotNum(), 2, + initOp.getNosplitAttr() && + initOp.getNosplitAttr().getValue()); + } + + return failure(); +} + +static FailureOr getTPipeTokenFromValue(Value pipeHandle, + PTOArch targetArch) { + pipeHandle = peelUnrealized(pipeHandle); + Operation *def = pipeHandle.getDefiningOp(); + if (!def) + return failure(); + return buildTPipeTokenFromInitOp(def, targetArch); +} + +static bool isSetFFTsPointerLikeType(Type ty) { + return isEmitCPointerLikeType(ty); +} + +static bool 
tileDataReturnsIntegralAddress(pto::AddressSpace as) { + return as == pto::AddressSpace::BIAS; +} + +static Type getTileDataResultType(MLIRContext *ctx, pto::AddressSpace as, + StringRef elemTok) { + if (tileDataReturnsIntegralAddress(as)) + return emitc::OpaqueType::get(ctx, "uint64_t"); + return getEmitCPointerType(ctx, addrSpaceQualifier(as), elemTok); +} + +static Value materializeTileDataValue(ConversionPatternRewriter &rewriter, + Location loc, Value tile, + pto::AddressSpace as, + StringRef elemTok) { + auto rawTy = getTileDataResultType(rewriter.getContext(), as, elemTok); + return rewriter + .create(loc, rawTy, "PTOAS__TILE_DATA", + ArrayAttr{}, ArrayAttr{}, + ValueRange{tile}) + .getResult(0); +} + +static Value materializeAddressAsPointer(ConversionPatternRewriter &rewriter, + Location loc, Value addr, + pto::AddressSpace as, + StringRef elemTok) { + auto *ctx = rewriter.getContext(); + std::string ptrTyStr = + std::string(addrSpaceQualifier(as)) + " " + elemTok.str() + "*"; + auto ptrTy = getEmitCPointerType(ctx, addrSpaceQualifier(as), elemTok); + if (isSetFFTsPointerLikeType(addr.getType())) { + if (addr.getType() == ptrTy) + return addr; + return rewriter.create(loc, ptrTy, addr).getResult(); + } + auto castTyAttr = + rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, ptrTyStr)}); + return rewriter + .create(loc, ptrTy, "reinterpret_cast", + ArrayAttr{}, castTyAttr, + ValueRange{addr}) + .getResult(0); +} + +struct InterCoreSyncCallDesc { + const char *callee = nullptr; + ArrayAttr args; + SmallVector operands; +}; + +static Value castInterCoreEventIdToI32(ConversionPatternRewriter &rewriter, + Location loc, Value eventId) { + auto i32Ty = emitc::OpaqueType::get(rewriter.getContext(), "int32_t"); + if (eventId.getType() == i32Ty) + return eventId; + return emitCCast(rewriter, loc, i32Ty, eventId); +} + +static Attribute getFFTSModeCodegenArg(ConversionPatternRewriter &rewriter, + int64_t fftsMode) { + auto *ctx = rewriter.getContext(); + if (fftsMode 
== 2) + return emitc::OpaqueAttr::get(ctx, "FFTS_MODE_VAL"); + return emitc::OpaqueAttr::get(ctx, std::to_string(fftsMode)); +} + +static Value createFFTSMsg(ConversionPatternRewriter &rewriter, Location loc, + Value eventI32, int64_t fftsMode) { + auto *ctx = rewriter.getContext(); + auto msgTy = emitc::OpaqueType::get(ctx, "uint16_t"); + auto msgArgs = rewriter.getArrayAttr({ + getFFTSModeCodegenArg(rewriter, fftsMode), + IntegerAttr::get(IndexType::get(ctx), 0), + }); + return rewriter + .create(loc, msgTy, "getFFTSMsg", + /*args=*/msgArgs, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{eventI32}) + .getResult(0); +} + +static InterCoreSyncCallDesc buildInterCoreSyncSetCall( + ConversionPatternRewriter &rewriter, Location loc, PTOArch targetArch, + pto::PipeAttr pipeAttr, IntegerAttr eventIdAttr, int64_t fftsMode) { + auto *ctx = rewriter.getContext(); + std::string pipeTok = pipeTokFromPipeAttr(pipeAttr); + + if (targetArch == PTOArch::A3) { + auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); + Value eventVal = + makeEmitCIntConstant(rewriter, loc, i32Ty, eventIdAttr.getInt()); + Value msgVal = createFFTSMsg(rewriter, loc, eventVal, fftsMode); + + InterCoreSyncCallDesc desc; + desc.callee = "ffts_cross_core_sync"; + desc.args = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, pipeTok), + IntegerAttr::get(IndexType::get(ctx), 0), + }); + desc.operands.push_back(msgVal); + return desc; + } + + InterCoreSyncCallDesc desc; + desc.callee = "set_intra_block"; + desc.args = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, pipeTok), eventIdAttr}); + return desc; +} + +static InterCoreSyncCallDesc buildInterCoreSyncSetCallDyn( + ConversionPatternRewriter &rewriter, Location loc, PTOArch targetArch, + pto::PipeAttr pipeAttr, Value eventIdVal, int64_t fftsMode) { + auto *ctx = rewriter.getContext(); + std::string pipeTok = pipeTokFromPipeAttr(pipeAttr); + Value eventI32 = castInterCoreEventIdToI32(rewriter, loc, eventIdVal); + + if (targetArch == 
PTOArch::A3) { + Value msgVal = createFFTSMsg(rewriter, loc, eventI32, fftsMode); + + InterCoreSyncCallDesc desc; + desc.callee = "ffts_cross_core_sync"; + desc.args = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, pipeTok), + IntegerAttr::get(IndexType::get(ctx), 0), + }); + desc.operands.push_back(msgVal); + return desc; + } + + InterCoreSyncCallDesc desc; + desc.callee = "set_intra_block"; + desc.args = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, pipeTok), + IntegerAttr::get(IndexType::get(ctx), 0), + }); + desc.operands.push_back(eventI32); + return desc; +} + +static InterCoreSyncCallDesc buildInterCoreSyncWaitCall( + ConversionPatternRewriter &rewriter, PTOArch targetArch, + pto::PipeAttr pipeAttr, IntegerAttr eventIdAttr) { + auto *ctx = rewriter.getContext(); + std::string pipeTok = pipeTokFromPipeAttr(pipeAttr); + + InterCoreSyncCallDesc desc; + if (targetArch == PTOArch::A3) { + desc.callee = "wait_flag_dev"; + desc.args = rewriter.getArrayAttr({eventIdAttr}); + return desc; + } + + desc.callee = "wait_intra_block"; + desc.args = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, pipeTok), eventIdAttr}); + return desc; +} + +static InterCoreSyncCallDesc buildInterCoreSyncWaitCallDyn( + ConversionPatternRewriter &rewriter, Location loc, PTOArch targetArch, + pto::PipeAttr pipeAttr, Value eventIdVal) { + auto *ctx = rewriter.getContext(); + std::string pipeTok = pipeTokFromPipeAttr(pipeAttr); + Value eventI32 = castInterCoreEventIdToI32(rewriter, loc, eventIdVal); + + InterCoreSyncCallDesc desc; + if (targetArch == PTOArch::A3) { + desc.callee = "wait_flag_dev"; + desc.args = rewriter.getArrayAttr({IntegerAttr::get(IndexType::get(ctx), 0)}); + desc.operands.push_back(eventI32); + return desc; + } + + desc.callee = "wait_intra_block"; + desc.args = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, pipeTok), + IntegerAttr::get(IndexType::get(ctx), 0), + }); + desc.operands.push_back(eventI32); + return desc; +} + +static bool 
hasInterCoreSyncOp(func::FuncOp func) { + bool found = false; + func.walk([&](Operation *op) { + if (isa(op)) { + found = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return found; +} + +static bool hasSetFFTsOp(func::FuncOp func) { + bool found = false; + func.walk([&](Operation *op) { + if (isa(op)) { + found = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return found; +} + +//===----------------------------------------------------------------------===// +// Arith -> EmitC (full dialect coverage for scalar ops) +//===----------------------------------------------------------------------===// + +template +struct ArithSimpleBinaryToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type dstTy = this->getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getOperands()); + return success(); + } +}; + +// Integer bitwise ops (andi/ori/xori) on signless integers: perform in unsigned +// to avoid signedness pitfalls, then cast back. +template +struct ArithUnsignedBitwiseBinaryToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const unsigned bitWidth = + intTy ? 
intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = this->getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + if (bitWidth == 1) { + rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getLhs(), + adaptor.getRhs()); + return success(); + } + + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value resU = rewriter.create(loc, uTy, lhsU, rhsU); + Value result = emitCCast(rewriter, loc, dstTy, resU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithDivUIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::DivUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const unsigned bitWidth = + intTy ? 
intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value divU = rewriter.create(loc, uTy, lhsU, rhsU); + Value result = emitCCast(rewriter, loc, dstTy, divU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithRemUIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::RemUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const unsigned bitWidth = + intTy ? 
intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value remU = rewriter.create(loc, uTy, lhsU, rhsU); + Value result = emitCCast(rewriter, loc, dstTy, remU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithCeilDivUIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::CeilDivUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const unsigned bitWidth = + intTy ? 
intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value one = makeEmitCIntConstant(rewriter, loc, uTy, 1); + Value rhsMinusOne = rewriter.create(loc, uTy, rhsU, one); + Value num = rewriter.create(loc, uTy, lhsU, rhsMinusOne); + Value divU = rewriter.create(loc, uTy, num, rhsU); + Value result = emitCCast(rewriter, loc, dstTy, divU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithCeilDivSIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::CeilDivSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + Value zero = makeEmitCIntConstant(rewriter, loc, dstTy, 0); + Value one = makeEmitCIntConstant(rewriter, loc, dstTy, 1); + + Value q0 = rewriter.create(loc, dstTy, adaptor.getLhs(), + adaptor.getRhs()); + Value r = rewriter.create(loc, dstTy, adaptor.getLhs(), + adaptor.getRhs()); + + Value rNeZero = rewriter.create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::ne, r, + zero); + Value lhsLt0 = + rewriter.create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, adaptor.getLhs(), + zero); + Value rhsLt0 = + rewriter.create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, adaptor.getRhs(), + zero); + Value signsSame = + rewriter.create(loc, 
rewriter.getI1Type(), + emitc::CmpPredicate::eq, lhsLt0, rhsLt0); + Value adjust = + rewriter.create(loc, rewriter.getI1Type(), + rNeZero, signsSame); + + Value qPlusOne = rewriter.create(loc, dstTy, q0, one); + Value result = rewriter.create(loc, dstTy, adjust, + qPlusOne, q0); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithFloorDivSIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::FloorDivSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + Value zero = makeEmitCIntConstant(rewriter, loc, dstTy, 0); + Value one = makeEmitCIntConstant(rewriter, loc, dstTy, 1); + + Value q0 = rewriter.create(loc, dstTy, adaptor.getLhs(), + adaptor.getRhs()); + Value r = rewriter.create(loc, dstTy, adaptor.getLhs(), + adaptor.getRhs()); + + Value rNeZero = rewriter.create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::ne, r, + zero); + Value lhsLt0 = + rewriter.create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, adaptor.getLhs(), + zero); + Value rhsLt0 = + rewriter.create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, adaptor.getRhs(), + zero); + Value signsDifferent = + rewriter.create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::ne, lhsLt0, rhsLt0); + Value adjust = + rewriter.create(loc, rewriter.getI1Type(), + rNeZero, signsDifferent); + + Value qMinusOne = rewriter.create(loc, dstTy, q0, one); + Value result = rewriter.create(loc, dstTy, adjust, + qMinusOne, q0); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithShiftLeftToEmitC : public OpConversionPattern { + using 
OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::ShLIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const unsigned bitWidth = + intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + if (bitWidth == 1) { + // Compute on u8 and truncate to i1. + auto u8Ty = getUnsignedIntOpaqueType(rewriter.getContext(), 8); + Value lhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getLhs()); + Value rhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getRhs()); + Value sh = rewriter.create(loc, u8Ty, lhsU8, + rhsU8); + Value masked = + rewriter.create(loc, u8Ty, sh, + makeEmitCIntConstant(rewriter, loc, + u8Ty, 1)); + rewriter.replaceOp(op, emitCCast(rewriter, loc, dstTy, masked)); + return success(); + } + + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value shU = + rewriter.create(loc, uTy, lhsU, rhsU); + Value result = emitCCast(rewriter, loc, dstTy, shU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithShiftRightUIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::ShRUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + 
const unsigned bitWidth = + intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + if (bitWidth == 1) { + // (x >> y) on i1 is either x (y==0) or 0 (y!=0); approximate in u8. + auto u8Ty = getUnsignedIntOpaqueType(rewriter.getContext(), 8); + Value lhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getLhs()); + Value rhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getRhs()); + Value sh = rewriter.create(loc, u8Ty, lhsU8, + rhsU8); + Value masked = + rewriter.create(loc, u8Ty, sh, + makeEmitCIntConstant(rewriter, loc, + u8Ty, 1)); + rewriter.replaceOp(op, emitCCast(rewriter, loc, dstTy, masked)); + return success(); + } + + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value shU = + rewriter.create(loc, uTy, lhsU, rhsU); + Value result = emitCCast(rewriter, loc, dstTy, shU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithShiftRightSIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::ShRSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const unsigned bitWidth = + intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + if (bitWidth == 1) { + // (x >> y) on i1 is either x (y==0) or 0 (y!=0); approximate in u8. 
+ auto u8Ty = getUnsignedIntOpaqueType(rewriter.getContext(), 8); + Value lhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getLhs()); + Value rhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getRhs()); + Value sh = rewriter.create(loc, u8Ty, lhsU8, + rhsU8); + Value masked = + rewriter.create(loc, u8Ty, sh, + makeEmitCIntConstant(rewriter, loc, + u8Ty, 1)); + rewriter.replaceOp(op, emitCCast(rewriter, loc, dstTy, masked)); + return success(); + } + + // Signed arithmetic shift; cast RHS to unsigned to interpret shift amount. + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value sh = + rewriter.create(loc, dstTy, adaptor.getLhs(), + rhsU); + rewriter.replaceOp(op, sh); + return success(); + } +}; + +struct ArithNegFToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::NegFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type dstTy = getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getOperand()); + return success(); + } +}; + +struct ArithRemFToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::RemFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type dstTy = getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + + // Use builtin `fmod` when possible. For f16, compute in float and cast back. 
+ Type callTy = dstTy; + Value lhs = adaptor.getLhs(); + Value rhs = adaptor.getRhs(); + + if (auto opFloatTy = dyn_cast(op.getType())) { + if (opFloatTy.isF16()) { + auto f32Ty = emitc::OpaqueType::get(rewriter.getContext(), "float"); + lhs = emitCCast(rewriter, loc, f32Ty, lhs); + rhs = emitCCast(rewriter, loc, f32Ty, rhs); + callTy = f32Ty; + } + } + + // Prefer `__builtin_fmod*` to avoid relying on extra headers. + llvm::StringRef callee = "__builtin_fmod"; + if (auto opFloatTy = dyn_cast(op.getType())) { + if (opFloatTy.isF32() || opFloatTy.isF16()) + callee = "__builtin_fmodf"; + else if (opFloatTy.isF64()) + callee = "__builtin_fmod"; + } + + auto call = rewriter.create( + loc, TypeRange{callTy}, callee, ValueRange{lhs, rhs}, + /*args=*/ArrayAttr{}, /*template_args=*/ArrayAttr{}); + Value result = call.getResult(0); + if (callTy != dstTy) + result = emitCCast(rewriter, loc, dstTy, result); + + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithSelectToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!op.getCondition().getType().isInteger(1)) + return rewriter.notifyMatchFailure( + op, "only scalar i1 conditions supported for arith.select"); + + Type dstTy = getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + + auto cond = + rewriter.create(op.getLoc(), dstTy, + adaptor.getCondition(), + adaptor.getTrueValue(), + adaptor.getFalseValue()); + rewriter.replaceOp(op, cond.getResult()); + return success(); + } +}; + +struct ArithExtUIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto dstIntTy = dyn_cast(op.getType()); + auto srcIntTy = 
dyn_cast(op.getIn().getType()); + if (!dstIntTy || !srcIntTy) + return rewriter.notifyMatchFailure(op, "expected scalar integer types"); + + Type dstTy = getTypeConverter()->convertType(dstIntTy); + if (!dstTy) + return failure(); + + // i1 -> iN: bool to integer already behaves as 0/1. + if (srcIntTy.getWidth() == 1) { + rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); + return success(); + } + + auto uDstTy = + getUnsignedIntOpaqueType(rewriter.getContext(), dstIntTy.getWidth()); + Value srcU = + castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getIn(), + srcIntTy.getWidth()); + Value extU = emitCCast(rewriter, loc, uDstTy, srcU); + Value result = emitCCast(rewriter, loc, dstTy, extU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithExtSIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto dstIntTy = dyn_cast(op.getType()); + auto srcIntTy = dyn_cast(op.getIn().getType()); + if (!dstIntTy || !srcIntTy) + return rewriter.notifyMatchFailure(op, "expected scalar integer types"); + + Type dstTy = getTypeConverter()->convertType(dstIntTy); + if (!dstTy) + return failure(); + + // i1 sign-extension: 0 -> 0, 1 -> -1. 
+ if (srcIntTy.getWidth() == 1) { + Value zero = makeEmitCIntConstant(rewriter, loc, dstTy, 0); + Value asInt = emitCCast(rewriter, loc, dstTy, adaptor.getIn()); + Value neg = rewriter.create(loc, dstTy, zero, asInt).getResult(); + rewriter.replaceOp(op, neg); + return success(); + } + + rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); + return success(); + } +}; + +template +struct ArithCastToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type dstTy = this->getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); + return success(); + } +}; + +struct ArithIndexCastUIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::IndexCastUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type dstTy = getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + + // MemRef casts are handled elsewhere; for safety, fall back to emitc.cast. 
+ if (isa(op.getIn().getType()) || isa(op.getType())) { + rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); + return success(); + } + + auto getBW = [](Type t) -> std::optional { + if (auto i = dyn_cast(t)) + return i.getWidth(); + if (isa(t)) + return kPTOIndexBitWidth; + return std::nullopt; + }; + + auto srcBW = getBW(op.getIn().getType()); + auto dstBW = getBW(op.getType()); + if (!srcBW || !dstBW) + return rewriter.notifyMatchFailure(op, "unsupported index_castui types"); + + if (*dstBW <= *srcBW) { + rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); + return success(); + } + + auto uSrcTy = getUnsignedIntOpaqueType(rewriter.getContext(), *srcBW); + auto uDstTy = getUnsignedIntOpaqueType(rewriter.getContext(), *dstBW); + Value srcU = emitCCast(rewriter, loc, uSrcTy, adaptor.getIn()); + Value extU = emitCCast(rewriter, loc, uDstTy, srcU); + Value result = emitCCast(rewriter, loc, dstTy, extU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithUIToFPToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto srcIntTy = dyn_cast(op.getIn().getType()); + if (!srcIntTy) + return rewriter.notifyMatchFailure(op, "expected scalar integer input"); + + Type dstTy = getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + + // Convert via an unsigned integer type of the same width. 
+ if (srcIntTy.getWidth() == 1) { + rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); + return success(); + } + Value srcU = + castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getIn(), + srcIntTy.getWidth()); + Value fp = rewriter.create(loc, dstTy, srcU).getResult(); + rewriter.replaceOp(op, fp); + return success(); + } +}; + +struct ArithFPToUIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::FPToUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto dstIntTy = dyn_cast(op.getType()); + if (!dstIntTy) + return rewriter.notifyMatchFailure(op, "expected scalar integer result"); + + Type dstTy = getTypeConverter()->convertType(dstIntTy); + if (!dstTy) + return failure(); + + auto uDstTy = + getUnsignedIntOpaqueType(rewriter.getContext(), dstIntTy.getWidth()); + Value asU = rewriter.create(loc, uDstTy, adaptor.getIn()).getResult(); + Value result = emitCCast(rewriter, loc, dstTy, asU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithBitcastToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type dstTy = getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + + // For pointer-like types, a regular cast is fine. + if (isa(dstTy)) { + rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); + return success(); + } + + // Only support scalar int/float/index bitcasts here. 
+ auto srcTy = op.getIn().getType(); + auto dstOrigTy = op.getType(); + + auto getBitWidth = [](Type t) -> std::optional { + if (auto it = dyn_cast(t)) + return it.getWidth(); + if (auto ft = dyn_cast(t)) + return ft.getWidth(); + if (isa(t)) + return kPTOIndexBitWidth; + return std::nullopt; + }; + auto srcBW = getBitWidth(srcTy); + auto dstBW = getBitWidth(dstOrigTy); + if (!srcBW || !dstBW || *srcBW != *dstBW) + return rewriter.notifyMatchFailure(op, "bitcast requires equal bitwidth"); + + // Determine the template argument from the destination type string. + auto dstOpaque = dyn_cast(dstTy); + if (!dstOpaque) + return rewriter.notifyMatchFailure(op, "expected emitc opaque dest type"); + + auto templateArgs = + rewriter.getArrayAttr({emitc::OpaqueAttr::get(rewriter.getContext(), + dstOpaque.getValue())}); + auto call = rewriter.create( + loc, TypeRange{dstTy}, "ptoas_bitcast", /*operands=*/ValueRange{adaptor.getIn()}, + /*args=*/ArrayAttr{}, /*template_args=*/templateArgs); + rewriter.replaceOp(op, call.getResult(0)); + return success(); + } +}; + +// arith.cmpf lowering with ordered/unordered semantics. 
+struct ArithCmpFToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + struct CmpFConfig { + bool unordered = false; + emitc::CmpPredicate predicate = emitc::CmpPredicate::eq; + }; + + static Value isNaN(ConversionPatternRewriter &rewriter, Location loc, + Value v) { + return rewriter + .create(loc, rewriter.getI1Type(), emitc::CmpPredicate::ne, + v, v) + .getResult(); + } + + static Value isNotNaN(ConversionPatternRewriter &rewriter, Location loc, + Value v) { + return rewriter + .create(loc, rewriter.getI1Type(), emitc::CmpPredicate::eq, + v, v) + .getResult(); + } + + static std::optional buildSpecialCmpFResult( + arith::CmpFPredicate predicate, ConversionPatternRewriter &rewriter, + Location loc, Type i1Ty, Value lhs, Value rhs) { + switch (predicate) { + case arith::CmpFPredicate::AlwaysFalse: + return makeEmitCOpaqueConstant(rewriter, loc, i1Ty, "false"); + case arith::CmpFPredicate::AlwaysTrue: + return makeEmitCOpaqueConstant(rewriter, loc, i1Ty, "true"); + case arith::CmpFPredicate::ORD: + return rewriter.create( + loc, i1Ty, isNotNaN(rewriter, loc, lhs), + isNotNaN(rewriter, loc, rhs)) + .getResult(); + case arith::CmpFPredicate::UNO: + return rewriter.create( + loc, i1Ty, isNaN(rewriter, loc, lhs), + isNaN(rewriter, loc, rhs)) + .getResult(); + default: + return std::nullopt; + } + } + + static std::optional + getCmpFConfig(arith::CmpFPredicate predicate) { + switch (predicate) { + case arith::CmpFPredicate::OEQ: + return CmpFConfig{false, emitc::CmpPredicate::eq}; + case arith::CmpFPredicate::OGT: + return CmpFConfig{false, emitc::CmpPredicate::gt}; + case arith::CmpFPredicate::OGE: + return CmpFConfig{false, emitc::CmpPredicate::ge}; + case arith::CmpFPredicate::OLT: + return CmpFConfig{false, emitc::CmpPredicate::lt}; + case arith::CmpFPredicate::OLE: + return CmpFConfig{false, emitc::CmpPredicate::le}; + case arith::CmpFPredicate::ONE: + return CmpFConfig{false, emitc::CmpPredicate::ne}; + case 
arith::CmpFPredicate::UEQ: + return CmpFConfig{true, emitc::CmpPredicate::eq}; + case arith::CmpFPredicate::UGT: + return CmpFConfig{true, emitc::CmpPredicate::gt}; + case arith::CmpFPredicate::UGE: + return CmpFConfig{true, emitc::CmpPredicate::ge}; + case arith::CmpFPredicate::ULT: + return CmpFConfig{true, emitc::CmpPredicate::lt}; + case arith::CmpFPredicate::ULE: + return CmpFConfig{true, emitc::CmpPredicate::le}; + case arith::CmpFPredicate::UNE: + return CmpFConfig{true, emitc::CmpPredicate::ne}; + default: + return std::nullopt; + } + } + + static Value buildCmpFResult(const CmpFConfig &config, + ConversionPatternRewriter &rewriter, + Location loc, Type i1Ty, Value lhs, Value rhs) { + Value cmp = rewriter + .create(loc, i1Ty, config.predicate, lhs, rhs) + .getResult(); + Value unord = rewriter.create( + loc, i1Ty, isNaN(rewriter, loc, lhs), isNaN(rewriter, loc, rhs)); + if (config.unordered) + return rewriter + .create(loc, i1Ty, unord, cmp) + .getResult(); + Value ord = rewriter.create( + loc, i1Ty, isNotNaN(rewriter, loc, lhs), isNotNaN(rewriter, loc, rhs)); + return rewriter + .create(loc, i1Ty, ord, cmp) + .getResult(); + } + + LogicalResult matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!isa(op.getLhs().getType())) + return rewriter.notifyMatchFailure(op, "cmpf only supported on scalar floats"); + + auto loc = op.getLoc(); + auto i1Ty = rewriter.getI1Type(); + if (auto special = buildSpecialCmpFResult(op.getPredicate(), rewriter, loc, + i1Ty, adaptor.getLhs(), + adaptor.getRhs())) { + rewriter.replaceOp(op, *special); + return success(); + } + + auto config = getCmpFConfig(op.getPredicate()); + if (!config) + return rewriter.notifyMatchFailure(op, "unsupported cmpf predicate"); + rewriter.replaceOp(op, buildCmpFResult(*config, rewriter, loc, i1Ty, + adaptor.getLhs(), adaptor.getRhs())); + return success(); + } +}; + +struct ArithAddUIExtendedToEmitC + : public OpConversionPattern { + 
using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getSum().getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, + "expected scalar integer or index operands"); + + const unsigned bitWidth = + intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + SmallVector newResultTypes; + if (failed(getTypeConverter()->convertTypes(op->getResultTypes(), + newResultTypes))) + return failure(); + if (newResultTypes.size() != 2) + return failure(); + + Type sumDstTy = newResultTypes[0]; + Type overflowDstTy = newResultTypes[1]; + + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + auto wideTy = getWiderUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + + Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value lhsWide = emitCCast(rewriter, loc, wideTy, lhsU); + Value rhsWide = emitCCast(rewriter, loc, wideTy, rhsU); + Value sumWide = + rewriter.create(loc, wideTy, lhsWide, rhsWide).getResult(); + + Value sumN = emitCCast(rewriter, loc, uTy, sumWide); + Value sum = emitCCast(rewriter, loc, sumDstTy, sumN); + + Value shiftAmt = makeEmitCIntConstant(rewriter, loc, wideTy, bitWidth); + Value high = rewriter + .create(loc, wideTy, sumWide, + shiftAmt) + .getResult(); + Value zeroWide = makeEmitCIntConstant(rewriter, loc, wideTy, 0); + Value overflow = + rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::ne, high, zeroWide) + .getResult(); + overflow = emitCCast(rewriter, loc, overflowDstTy, overflow); + + rewriter.replaceOp(op, {sum, overflow}); + return success(); + } +}; + +template +struct ArithMulExtendedToEmitC : public 
OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getResult(0).getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, + "expected scalar integer or index operands"); + + const unsigned bitWidth = + intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + SmallVector newResultTypes; + if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), + newResultTypes))) + return failure(); + if (newResultTypes.size() != 2) + return failure(); + + Type lowDstTy = newResultTypes[0]; + Type highDstTy = newResultTypes[1]; + + Type wideTy = isUnsigned ? (Type)getWiderUnsignedIntOpaqueType(rewriter.getContext(), + bitWidth) + : (Type)getWiderSignedIntOpaqueType(rewriter.getContext(), + bitWidth); + + Value lhsWide; + Value rhsWide; + if constexpr (isUnsigned) { + Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + lhsWide = emitCCast(rewriter, loc, wideTy, lhsU); + rhsWide = emitCCast(rewriter, loc, wideTy, rhsU); + } else { + lhsWide = emitCCast(rewriter, loc, wideTy, adaptor.getLhs()); + rhsWide = emitCCast(rewriter, loc, wideTy, adaptor.getRhs()); + } + + Value prodWide = + rewriter.create(loc, wideTy, lhsWide, rhsWide).getResult(); + Value low = emitCCast(rewriter, loc, lowDstTy, prodWide); + + Value shiftAmt = makeEmitCIntConstant(rewriter, loc, wideTy, bitWidth); + Value highWide = rewriter + .create(loc, wideTy, prodWide, + shiftAmt) + .getResult(); + Value high = emitCCast(rewriter, loc, highDstTy, highWide); + + rewriter.replaceOp(op, {low, high}); + return success(); + } +}; + +using ArithMulSIExtendedToEmitC = + 
ArithMulExtendedToEmitC; +using ArithMulUIExtendedToEmitC = + ArithMulExtendedToEmitC; + +struct ArithMinMaxIToEmitCBase { + static Value makeSelect(ConversionPatternRewriter &rewriter, Location loc, + Type dstTy, Value cond, Value trueV, Value falseV) { + return rewriter + .create(loc, dstTy, cond, trueV, falseV) + .getResult(); + } +}; + +struct ArithMaxSIToEmitC : public OpConversionPattern, + ArithMinMaxIToEmitCBase { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::MaxSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type dstTy = getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + Value cond = rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, + adaptor.getLhs(), adaptor.getRhs()) + .getResult(); + Value res = makeSelect(rewriter, loc, dstTy, cond, adaptor.getRhs(), + adaptor.getLhs()); + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct ArithMinSIToEmitC : public OpConversionPattern, + ArithMinMaxIToEmitCBase { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::MinSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type dstTy = getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + Value cond = rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, + adaptor.getLhs(), adaptor.getRhs()) + .getResult(); + Value res = makeSelect(rewriter, loc, dstTy, cond, adaptor.getLhs(), + adaptor.getRhs()); + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct ArithMaxUIToEmitC : public OpConversionPattern, + ArithMinMaxIToEmitCBase { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::MaxUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = 
op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const unsigned bitWidth = + intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + Value lhsU = + castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = + castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value cond = rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, lhsU, rhsU) + .getResult(); + Value res = makeSelect(rewriter, loc, dstTy, cond, adaptor.getRhs(), + adaptor.getLhs()); + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct ArithMinUIToEmitC : public OpConversionPattern, + ArithMinMaxIToEmitCBase { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::MinUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const unsigned bitWidth = + intTy ? 
intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + Value lhsU = + castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = + castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value cond = rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, lhsU, rhsU) + .getResult(); + Value res = makeSelect(rewriter, loc, dstTy, cond, adaptor.getLhs(), + adaptor.getRhs()); + rewriter.replaceOp(op, res); + return success(); + } +}; + +// Floating-point max/min variants. +struct ArithFloatMinMaxToEmitCBase { + static Value isNaN(ConversionPatternRewriter &rewriter, Location loc, + Value v) { + return rewriter + .create(loc, rewriter.getI1Type(), emitc::CmpPredicate::ne, + v, v) + .getResult(); + } + + static Value makeFZero(ConversionPatternRewriter &rewriter, Location loc, + Type ty) { + return makeEmitCOpaqueConstant(rewriter, loc, ty, "0.0f"); + } +}; + +struct ArithMaxNumFToEmitC : public OpConversionPattern, + ArithFloatMinMaxToEmitCBase { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::MaxNumFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type dstTy = getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + + Value lhsNaN = isNaN(rewriter, loc, adaptor.getLhs()); + Value rhsNaN = isNaN(rewriter, loc, adaptor.getRhs()); + + Value cmpLt = rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, + adaptor.getLhs(), adaptor.getRhs()) + .getResult(); + Value maxNoNaN = + rewriter + .create(loc, dstTy, cmpLt, adaptor.getRhs(), + adaptor.getLhs()) + .getResult(); + + Value rhsOrMax = + rewriter + .create(loc, dstTy, rhsNaN, adaptor.getLhs(), + maxNoNaN) + .getResult(); + Value res = + rewriter + .create(loc, dstTy, lhsNaN, adaptor.getRhs(), + rhsOrMax) + 
.getResult(); + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct ArithMinNumFToEmitC : public OpConversionPattern, + ArithFloatMinMaxToEmitCBase { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::MinNumFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type dstTy = getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + + Value lhsNaN = isNaN(rewriter, loc, adaptor.getLhs()); + Value rhsNaN = isNaN(rewriter, loc, adaptor.getRhs()); + + Value cmpLt = rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, + adaptor.getLhs(), adaptor.getRhs()) + .getResult(); + Value minNoNaN = + rewriter + .create(loc, dstTy, cmpLt, adaptor.getLhs(), + adaptor.getRhs()) + .getResult(); + + Value rhsOrMin = + rewriter + .create(loc, dstTy, rhsNaN, adaptor.getLhs(), + minNoNaN) + .getResult(); + Value res = + rewriter + .create(loc, dstTy, lhsNaN, adaptor.getRhs(), + rhsOrMin) + .getResult(); + rewriter.replaceOp(op, res); + return success(); + } +}; + +template +struct ArithMinMaxFPropagateNaNToEmitC : public OpConversionPattern, + ArithFloatMinMaxToEmitCBase { + using OpConversionPattern::OpConversionPattern; + + static Value buildPrimaryCandidate(ConversionPatternRewriter &rewriter, + Location loc, Type dstTy, Value lhs, + Value rhs) { + Value cmpLt = + rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, lhs, rhs) + .getResult(); + return rewriter + .create( + loc, dstTy, cmpLt, isMaximum ? rhs : lhs, isMaximum ? 
lhs : rhs) + .getResult(); + } + + static Value buildSignBitValue(ConversionPatternRewriter &rewriter, + Location loc, Value lhs, FloatType floatTy) { + auto bitsTy = + getUnsignedIntOpaqueType(rewriter.getContext(), floatTy.getWidth()); + auto templateArgs = rewriter.getArrayAttr({emitc::OpaqueAttr::get( + rewriter.getContext(), cast(bitsTy).getValue())}); + Value lhsBits = + rewriter + .create(loc, TypeRange{bitsTy}, "ptoas_bitcast", + ValueRange{lhs}, ArrayAttr{}, + templateArgs) + .getResult(0); + Value oneBits = makeEmitCIntConstant(rewriter, loc, bitsTy, 1); + Value shiftAmount = + makeEmitCIntConstant(rewriter, loc, bitsTy, floatTy.getWidth() - 1); + Value signMask = rewriter + .create(loc, bitsTy, oneBits, + shiftAmount) + .getResult(); + return rewriter + .create(loc, bitsTy, lhsBits, signMask) + .getResult(); + } + + static Value buildSignedZeroCandidate(ConversionPatternRewriter &rewriter, + Location loc, Type dstTy, Value lhs, + Value rhs, FloatType floatTy) { + Value zero = makeFZero(rewriter, loc, dstTy); + Value equal = rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::eq, lhs, rhs) + .getResult(); + Value lhsZero = rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::eq, lhs, + zero) + .getResult(); + Value bothZero = rewriter + .create(loc, rewriter.getI1Type(), + equal, lhsZero) + .getResult(); + auto bitsTy = + getUnsignedIntOpaqueType(rewriter.getContext(), floatTy.getWidth()); + Value zeroBits = makeEmitCIntConstant(rewriter, loc, bitsTy, 0); + Value lhsIsNegZero = + rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::ne, + buildSignBitValue(rewriter, loc, lhs, floatTy), + zeroBits) + .getResult(); + Value tie = rewriter + .create( + loc, dstTy, lhsIsNegZero, isMaximum ? rhs : lhs, + isMaximum ? 
lhs : rhs) + .getResult(); + return rewriter + .create(loc, dstTy, bothZero, tie, + buildPrimaryCandidate(rewriter, loc, dstTy, + lhs, rhs)) + .getResult(); + } + + static Value buildNaNPropagatingResult(ConversionPatternRewriter &rewriter, + Location loc, Type dstTy, Value lhs, + Value rhs, FloatType floatTy) { + Value lhsNaN = isNaN(rewriter, loc, lhs); + Value rhsNaN = isNaN(rewriter, loc, rhs); + Value noNaN = + buildSignedZeroCandidate(rewriter, loc, dstTy, lhs, rhs, floatTy); + Value rhsOrNoNaN = rewriter + .create(loc, dstTy, rhsNaN, rhs, + noNaN) + .getResult(); + return rewriter + .create(loc, dstTy, lhsNaN, lhs, rhsOrNoNaN) + .getResult(); + } + + LogicalResult + matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!isa(op.getType())) + return rewriter.notifyMatchFailure(op, "expected scalar float type"); + + auto loc = op.getLoc(); + Type dstTy = this->getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + + auto floatTy = cast(op.getType()); + rewriter.replaceOp(op, buildNaNPropagatingResult( + rewriter, loc, dstTy, adaptor.getLhs(), + adaptor.getRhs(), floatTy)); + return success(); + } +}; + +using ArithMaximumFToEmitC = + ArithMinMaxFPropagateNaNToEmitC; +using ArithMinimumFToEmitC = + ArithMinMaxFPropagateNaNToEmitC; + +//===----------------------------------------------------------------------===// +// Arith -> EmitC helpers +//===----------------------------------------------------------------------===// + +static emitc::OpaqueType getSignedIntOpaqueType(MLIRContext *ctx, + unsigned bitWidth) { + switch (bitWidth) { + case 1: + return emitc::OpaqueType::get(ctx, "int8_t"); + case 8: + return emitc::OpaqueType::get(ctx, "int8_t"); + case 16: + return emitc::OpaqueType::get(ctx, "int16_t"); + case 32: + return emitc::OpaqueType::get(ctx, "int32_t"); + case 64: + return emitc::OpaqueType::get(ctx, "int64_t"); + case 128: + return 
emitc::OpaqueType::get(ctx, "__int128"); + default: + llvm::errs() << "[Debug] Unsupported signed integer bitwidth: " << bitWidth + << "\n"; + return emitc::OpaqueType::get(ctx, "int64_t"); + } +} + +static emitc::OpaqueType getUnsignedIntOpaqueType(MLIRContext *ctx, + unsigned bitWidth) { + switch (bitWidth) { + case 1: + return emitc::OpaqueType::get(ctx, "uint8_t"); + case 8: + return emitc::OpaqueType::get(ctx, "uint8_t"); + case 16: + return emitc::OpaqueType::get(ctx, "uint16_t"); + case 32: + return emitc::OpaqueType::get(ctx, "uint32_t"); + case 64: + return emitc::OpaqueType::get(ctx, "uint64_t"); + case 128: + return emitc::OpaqueType::get(ctx, "unsigned __int128"); + default: + llvm::errs() << "[Debug] Unsupported unsigned integer bitwidth: " + << bitWidth << "\n"; + return emitc::OpaqueType::get(ctx, "uint64_t"); + } +} + +static emitc::OpaqueType getWiderSignedIntOpaqueType(MLIRContext *ctx, + unsigned bitWidth) { + switch (bitWidth) { + case 1: + case 8: + return getSignedIntOpaqueType(ctx, 16); + case 16: + return getSignedIntOpaqueType(ctx, 32); + case 32: + return getSignedIntOpaqueType(ctx, 64); + case 64: + return getSignedIntOpaqueType(ctx, 128); + default: + return getSignedIntOpaqueType(ctx, 128); + } +} + +static emitc::OpaqueType getWiderUnsignedIntOpaqueType(MLIRContext *ctx, + unsigned bitWidth) { + switch (bitWidth) { + case 1: + case 8: + return getUnsignedIntOpaqueType(ctx, 16); + case 16: + return getUnsignedIntOpaqueType(ctx, 32); + case 32: + return getUnsignedIntOpaqueType(ctx, 64); + case 64: + return getUnsignedIntOpaqueType(ctx, 128); + default: + return getUnsignedIntOpaqueType(ctx, 128); + } +} + +static Value makeEmitCOpaqueConstant(ConversionPatternRewriter &rewriter, + Location loc, Type type, + llvm::StringRef literal) { + auto attr = emitc::OpaqueAttr::get(rewriter.getContext(), literal); + return rewriter.create(loc, type, attr); +} + +static Value makeEmitCIntConstant(ConversionPatternRewriter &rewriter, + Location loc, 
Type type, int64_t value) { + return makeEmitCOpaqueConstant(rewriter, loc, type, std::to_string(value)); +} + +static FailureOr buildEmitCOpaqueConstantLiteral(Type targetType, + Attribute valueAttr) { + auto opaqueTy = dyn_cast(targetType); + if (!opaqueTy) + return failure(); + + if (opaqueTy.getValue() == "pto::MrgSortExecutedNumList") { + auto dense = dyn_cast_or_null(valueAttr); + if (!dense) + return failure(); + + auto vecTy = dyn_cast(dense.getType()); + if (!vecTy || vecTy.getRank() != 1 || vecTy.getNumElements() != 4 || + !vecTy.getElementType().isInteger(16)) + return failure(); + + std::string literal; + llvm::raw_string_ostream os(literal); + os << "pto::MrgSortExecutedNumList{"; + bool first = true; + for (APInt elem : dense.getValues()) { + if (!first) + os << ", "; + first = false; + os << elem.getZExtValue(); + } + os << "}"; + os.flush(); + return literal; + } + + return failure(); +} + +static Value emitCCast(ConversionPatternRewriter &rewriter, Location loc, + Type dstType, Value src) { + if (src.getType() == dstType) + return src; + return rewriter.createOrFold(loc, dstType, src); +} + +// For signless iN integers lowered to signed C++ types, this creates a value +// representing the same N-bit pattern in an unsigned C++ type of the same +// width. This avoids incorrect sign-extension when later widening to a larger +// unsigned type. 
+static Value castSignlessIntToUnsignedSameWidth(ConversionPatternRewriter &rewriter, + Location loc, Value v, + unsigned bitWidth) { + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + return emitCCast(rewriter, loc, uTy, v); +} + +struct ArithMulIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(arith::MulIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const unsigned bitWidth = + intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + // i1 mul is equivalent to bitwise AND (mod 2 arithmetic). + if (bitWidth == 1) { + rewriter.replaceOpWithNewOp(op, opTy, adaptor.getLhs(), + adaptor.getRhs()); + return success(); + } + + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value mulU = rewriter.create(loc, uTy, lhsU, rhsU); + Value result = emitCCast(rewriter, loc, dstTy, mulU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithAddIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(arith::AddIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const 
unsigned bitWidth = + intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + // i1 add is equivalent to XOR (mod 2 arithmetic). + if (bitWidth == 1) { + rewriter.replaceOpWithNewOp(op, opTy, adaptor.getLhs(), + adaptor.getRhs()); + return success(); + } + + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value addU = rewriter.create(loc, uTy, lhsU, rhsU); + Value result = emitCCast(rewriter, loc, dstTy, addU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithCastOPToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type newTy = getTypeConverter()->convertType(op.getType()); + if (!newTy) + return failure(); + if (adaptor.getIn().getType() == newTy) { + rewriter.replaceOp(op, adaptor.getIn()); + return success(); + } + rewriter.replaceOpWithNewOp(op, newTy, adaptor.getIn()); + return success(); + } +}; + +struct ArithSubIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(arith::SubIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const unsigned bitWidth = + intTy ? 
intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + // i1 sub is equivalent to XOR (mod 2 arithmetic). + if (bitWidth == 1) { + rewriter.replaceOpWithNewOp(op, opTy, adaptor.getLhs(), + adaptor.getRhs()); + return success(); + } + + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value subU = rewriter.create(loc, uTy, lhsU, rhsU); + Value result = emitCCast(rewriter, loc, dstTy, subU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithDivSIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(arith::DivSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type newTy = getTypeConverter()->convertType(op.getType()); + if (!newTy) + return failure(); + rewriter.replaceOpWithNewOp(op, newTy, adaptor.getLhs(), + adaptor.getRhs()); + return success(); + } +}; + +struct ArithRemSIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type newTy = getTypeConverter()->convertType(op.getType()); + if (!newTy) + return failure(); + rewriter.replaceOpWithNewOp(op, newTy, adaptor.getLhs(), + adaptor.getRhs()); + return success(); + } +}; + +struct ArithTruncIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + auto dstIntTy = dyn_cast(op.getType()); + auto srcIntTy = 
dyn_cast(op.getIn().getType()); + if (!dstIntTy || !srcIntTy) + return rewriter.notifyMatchFailure(op, "expected scalar integer types"); + + Type dstTy = getTypeConverter()->convertType(dstIntTy); + if (!dstTy) + return failure(); + + // to-i1 conversions: Arith wants truncation to the low bit, while C/C++ + // casts to bool are equivalent to `v != 0`. Implement as `(bool)(v & 1)`. + if (dstIntTy.getWidth() == 1) { + if (srcIntTy.getWidth() == 1) { + rewriter.replaceOp(op, adaptor.getIn()); + return success(); + } + + auto uSrcTy = + getUnsignedIntOpaqueType(rewriter.getContext(), srcIntTy.getWidth()); + Value inU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getIn(), + srcIntTy.getWidth()); + Value one = makeEmitCIntConstant(rewriter, loc, uSrcTy, 1); + Value masked = + rewriter.create(loc, uSrcTy, inU, one); + Value asBool = emitCCast(rewriter, loc, dstTy, masked); + rewriter.replaceOp(op, asBool); + return success(); + } + + rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); + return success(); + } +}; + +struct ArithConstantToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type newType = getTypeConverter()->convertType(op.getType()); + if (!newType) + return failure(); + + // `adaptor.getValue()` may be null if attribute conversion isn't defined. + // Use the original attribute as fallback and always cast null-safely. 
+ Attribute valueAttr = adaptor.getValue(); + if (!valueAttr) + valueAttr = op.getValue(); + + if (auto opaqueLiteral = buildEmitCOpaqueConstantLiteral(newType, valueAttr); + succeeded(opaqueLiteral)) { + auto constAttr = emitc::OpaqueAttr::get(rewriter.getContext(), *opaqueLiteral); + rewriter.replaceOpWithNewOp(op, newType, constAttr); + return success(); + } + + if (auto floatAttr = dyn_cast_or_null(valueAttr)) { + SmallString<32> valStr; + floatAttr.getValue().toString(valStr); + llvm::StringRef s(valStr); + // Ensure the literal parses as a floating-point constant in C/C++. + // `APFloat::toString` may emit "1" for integral values; make it "1.0". + const bool hasFloatMarker = + s.contains('.') || s.contains('e') || s.contains('E') || + s.contains('p') || s.contains('P') || s.starts_with("0x") || + s.starts_with("0X") || s.starts_with("nan") || + s.starts_with("-nan") || s.starts_with("inf") || + s.starts_with("-inf"); + if (!hasFloatMarker) + valStr.append(".0"); + // Suffix: keep `f` for f16/f32; omit for f64. 
+ if (!floatAttr.getType().isF64()) + valStr.append("f"); + auto constAttr = emitc::OpaqueAttr::get(rewriter.getContext(), valStr); + rewriter.replaceOpWithNewOp(op, newType, constAttr); + return success(); + } + + if (auto intAttr = dyn_cast_or_null(valueAttr)) { + std::string valStr = std::to_string(intAttr.getValue().getSExtValue()); + auto constAttr = emitc::OpaqueAttr::get(rewriter.getContext(), valStr); + rewriter.replaceOpWithNewOp(op, newType, constAttr); + return success(); + } + + return failure(); + } +}; +//===----------------------------------------------------------------------===// +// pto.mgather lowering -> MGATHER(dst, src, indexes) (pto-isa) +//===----------------------------------------------------------------------===// + +struct PTOMGatherToMGATHER : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::MGatherOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto *ctx = rewriter.getContext(); + Value mem = peelUnrealized(adaptor.getMem()); + Value idx = peelUnrealized(adaptor.getIdx()); + Value dst = peelUnrealized(adaptor.getDst()); + + Value memArg = maybeWrapGlobalMemrefAsGlobalTensor( + rewriter, op.getLoc(), mem, op.getMem().getType(), op.getOperation()); + + auto gatherOobTok = [&](pto::GatherOOB mode) -> StringRef { + switch (mode) { + case pto::GatherOOB::Undefined: + return "pto::GatherOOB::Undefined"; + case pto::GatherOOB::Clamp: + return "pto::GatherOOB::Clamp"; + case pto::GatherOOB::Wrap: + return "pto::GatherOOB::Wrap"; + case pto::GatherOOB::Zero: + return "pto::GatherOOB::Zero"; + } + llvm_unreachable("unknown GatherOOB"); + }; + + SmallVector templateArgVec; + const bool rowCoalesce = + isRowCoalescedMGatherIndexType(op.getDst().getType(), op.getIdx().getType()); + templateArgVec.push_back(emitc::OpaqueAttr::get( + ctx, rowCoalesce ? 
"pto::Coalesce::Row" : "pto::Coalesce::Elem")); + if (op.getGatherOob() != pto::GatherOOB::Undefined) { + templateArgVec.push_back( + emitc::OpaqueAttr::get(ctx, gatherOobTok(op.getGatherOob()))); + } + ArrayAttr templateArgs = rewriter.getArrayAttr(templateArgVec); + + rewriter.create( + op.getLoc(), TypeRange{}, "MGATHER", + ArrayAttr{}, templateArgs, + ValueRange{dst, memArg, idx}); + + if (op->getNumResults() == 0) { + rewriter.eraseOp(op); + } else { + rewriter.replaceOp(op, dst); + } + return success(); + } +}; + +struct AffineApplyMulConstToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(affine::AffineApplyOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto map = op.getAffineMap(); + + if (map.getNumDims() != 0 || map.getNumSymbols() != 1) + return failure(); + + auto expr = map.getResult(0); + auto bin = dyn_cast(expr); + if (!bin || bin.getKind() != AffineExprKind::Mul) + return failure(); + + auto lhs = bin.getLHS(); + auto rhs = bin.getRHS(); + + auto symExpr = dyn_cast(lhs); + auto constExpr = dyn_cast(rhs); + if (!symExpr || !constExpr) + return failure(); + + Value inputVal = adaptor.getMapOperands()[0]; + + std::string valStr = std::to_string(constExpr.getValue()); + auto cstAttr = emitc::OpaqueAttr::get(rewriter.getContext(), valStr); + auto cstOp = rewriter.create( + op.getLoc(), inputVal.getType(), cstAttr); + + rewriter.replaceOpWithNewOp( + op, inputVal.getType(), inputVal, cstOp); + + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Kernel inference helpers +//===----------------------------------------------------------------------===// + +enum class KernelKind { VecAdd, Matmul, Unknown }; + +[[maybe_unused]] static KernelKind inferKernelKind(func::FuncOp f) { + bool hasAdd = false; + bool hasMM = false; + f.walk([&](Operation *op) { + if (isa(op)) hasAdd = true; + if 
(isa(op)) hasMM = true; + if (isa(op)) hasMM = true; + }); + if (hasMM) return KernelKind::Matmul; + if (hasAdd) return KernelKind::VecAdd; + return KernelKind::Unknown; +} + +[[maybe_unused]] static void inferTileMNK(func::FuncOp f, int &M, int &N, int &K) { + M = 32; N = 32; K = 32; + SmallVector subs; + f.walk([&](memref::SubViewOp sv) { subs.push_back(sv); }); + + auto readShape2D = [&](memref::SubViewOp sv, int &d0, int &d1) { + auto resTy = mlir::cast(sv.getResult().getType()); + if (resTy.getRank() == 2 && resTy.hasStaticShape()) { + d0 = (int)resTy.getDimSize(0); + d1 = (int)resTy.getDimSize(1); + } + }; + + if (subs.empty()) return; + + int a0=32, a1=32; + readShape2D(subs[0], a0, a1); + M = a0; N = a1; + + if (subs.size() >= 2) { + int b0=32, b1=32; + readShape2D(subs[0], a0, a1); + readShape2D(subs[1], b0, b1); + M = a0; K = a1; N = b1; + } +} + +static std::optional getKernelKindMacro(func::FuncOp funcOp) { + auto kernelKindAttr = + funcOp->getAttrOfType(FunctionKernelKindAttr::name); + if (!kernelKindAttr) + return std::nullopt; + + switch (kernelKindAttr.getKernelKind()) { + case FunctionKernelKind::Cube: + return StringRef("__DAV_CUBE__"); + case FunctionKernelKind::Vector: + return StringRef("__DAV_VEC__"); + } + + llvm_unreachable("unexpected kernel kind"); +} + +struct FuncToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(func::FuncOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Convert the function signature with the type converter. + Type convertedTy = getTypeConverter()->convertType(op.getFunctionType()); + auto funcType = dyn_cast_or_null(convertedTy); + if (!funcType) + return rewriter.notifyMatchFailure(op, "failed to convert function type"); + if (funcType.getNumResults() > 1) + return rewriter.notifyMatchFailure( + op, "EmitC cannot return multiple values"); + + // Create the EmitC function with the converted signature. 
+ auto emitcFunc = + rewriter.create(op.getLoc(), op.getName(), funcType); + + for (const auto &namedAttr : op->getAttrs()) { + StringRef name = namedAttr.getName().strref(); + if (name == op.getFunctionTypeAttrName() || + name == SymbolTable::getSymbolAttrName() || + name == pto::kPTOEntryAttrName || + name == pto::kLegacyHACCEntryAttrName || + name == "pto.internal.entry") + continue; + emitcFunc->setAttr(namedAttr.getName(), namedAttr.getValue()); + } + + if (op.isDeclaration()) { + emitcFunc.setSpecifiersAttr(rewriter.getStrArrayAttr({"extern"})); + rewriter.eraseOp(op); + return success(); + } + + if (pto::isPTOEntryFunction(op)) { + emitcFunc.setSpecifiersAttr( + rewriter.getStrArrayAttr({"__global__ AICORE"})); + } else if (op.isPrivate()) { + emitcFunc.setSpecifiersAttr( + rewriter.getStrArrayAttr({"static", "AICORE"})); + } else { + emitcFunc.setSpecifiersAttr(rewriter.getStrArrayAttr({"AICORE"})); + } + + std::optional kernelKindMacro = getKernelKindMacro(op); + bool needsNoSplitGuard = needsA5NoSplitVectorGuard(op.getOperation()); + + // Inline the original body, then convert region/block argument types to + // match the converted signature (also covers CFG blocks introduced by + // pre-lowering, e.g. scf.while -> cf.br/cf.cond_br). + rewriter.inlineRegionBefore(op.getBody(), emitcFunc.getBody(), + emitcFunc.end()); + + TypeConverter::SignatureConversion entryConv(op.getNumArguments()); + for (unsigned i = 0; i < op.getNumArguments(); ++i) + entryConv.addInputs(i, funcType.getInput(i)); + + if (failed(rewriter.convertRegionTypes(&emitcFunc.getBody(), + *getTypeConverter(), &entryConv))) + return failure(); + + // Preserve the existing function prologue shape. `kernel_kind` functions are + // emitted with the same macro guard/reset sequence that used to come from + // early pto.section wrapping, but only after SCF pre-lowering has finished. 
+ { + Block &entryBlock = emitcFunc.getBody().front(); + rewriter.setInsertionPointToStart(&entryBlock); + rewriter.create(op.getLoc(), "using T = float;"); + if (kernelKindMacro) { + std::string startMacro = "\n#if defined(" + kernelKindMacro->str() + ")"; + rewriter.create(op.getLoc(), startMacro); + if (*kernelKindMacro == "__DAV_VEC__") { + rewriter.create(op.getLoc(), "set_mask_norm();"); + rewriter.create(op.getLoc(), + "set_vector_mask(-1, -1);"); + if (needsNoSplitGuard) + rewriter.create( + op.getLoc(), "if (get_subblockid() == 0) {"); + } + } + } + + if (kernelKindMacro) { + Block &lastBlock = emitcFunc.getBody().back(); + rewriter.setInsertionPoint(lastBlock.getTerminator()); + if (*kernelKindMacro == "__DAV_VEC__" && needsNoSplitGuard) + rewriter.create(op.getLoc(), "}"); + std::string endMacro = "#endif // " + kernelKindMacro->str() + "\n"; + rewriter.create(op.getLoc(), endMacro); + } + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// SubView lowering to GlobalTensor (keep your existing code) +//===----------------------------------------------------------------------=== + +enum class Role { A, B, C, Unknown }; + +template +static std::optional inferMatmulLikeSubviewRole(MatmulLikeOp op, + Value buffer) { + if (op.getLhs() == buffer) + return Role::A; + if (op.getRhs() == buffer) + return Role::B; + return std::nullopt; +} + +static std::optional inferSubviewRoleFromLoadUser(mlir::pto::TLoadOp load) { + Value buffer = load.getDst(); + if (!buffer) + return std::nullopt; + for (Operation *user : buffer.getUsers()) { + if (auto matmul = dyn_cast(user)) { + if (auto role = inferMatmulLikeSubviewRole(matmul, buffer)) + return role; + continue; + } + if (auto matmulAcc = dyn_cast(user)) { + if (auto role = inferMatmulLikeSubviewRole(matmulAcc, buffer)) + return role; + } + } + return std::nullopt; +} + +static std::optional inferSubviewRoleFromUser(Operation *user, 
Value result) { + if (auto load = dyn_cast(user)) + return inferSubviewRoleFromLoadUser(load); + if (auto store = dyn_cast(user)) { + if (store.getDst() == result) + return Role::C; + } + return std::nullopt; +} + +[[maybe_unused]] static Role inferSubviewRole(memref::SubViewOp sv) { + Value result = sv.getResult(); + for (Operation *user : result.getUsers()) { + if (auto role = inferSubviewRoleFromUser(user, result)) + return *role; + } + return Role::Unknown; +} + +// ============================================================================= +// 4. MemRef SubView -> Explicit Shape/Stride Construction (Full Implementation) +// ============================================================================= +struct SubviewToEmitCPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + // 辅助函数:尝试从 OpFoldResult 中提取静态整数值 + std::optional extractStaticInt(OpFoldResult ofr) const { + if (auto attr = ofr.dyn_cast()) { + if (auto intAttr = dyn_cast(attr)) + return intAttr.getInt(); + } else { + Value v = ofr.get(); + if (auto cOp = v.getDefiningOp()) { + if (auto iAttr = dyn_cast(cOp.getValue())) + return iAttr.getInt(); + } else if (auto idxOp = v.getDefiningOp()) { + return idxOp.value(); + } + } + return std::nullopt; + } + + LogicalResult matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + // 获取源 MemRef 类型信息 + auto srcType = mlir::cast(op.getSource().getType()); + int64_t rank = srcType.getRank(); + + auto elemTypeToString = [&](Type elemTy) -> std::string { + if (elemTy.isF16()) + return "half"; + if (elemTy.isBF16()) + return "bfloat16_t"; + if (elemTy.isF32()) + return "float"; + if (elemTy.isF64()) + return "double"; + if (elemTy.isInteger(8)) { + if (elemTy.isSignlessInteger(8) || elemTy.isSignedInteger(8)) + return "int8_t"; + return "uint8_t"; + } + if (elemTy.isInteger(16)) { + if 
(elemTy.isSignlessInteger(16) || elemTy.isSignedInteger(16)) + return "int16_t"; + return "uint16_t"; + } + if (elemTy.isInteger(32)) { + if (elemTy.isSignlessInteger(32) || elemTy.isSignedInteger(32)) + return "int32_t"; + return "uint32_t"; + } + if (elemTy.isInteger(64)) { + return cast(elemTy).isUnsigned() ? "uint64_t" : "int64_t"; + } + return "float"; + }; + + // ------------------------------------------------------------------------- + // Part 1: 指针偏移计算 (Runtime Pointer Arithmetic) + // ------------------------------------------------------------------------- + + // 准备类型: unsigned + Type u32Ty = emitc::OpaqueType::get(ctx, "unsigned"); + + // Helper: 创建 unsigned 常量 + auto mkU32 = [&](int64_t v) -> Value { + return rewriter.create( + loc, u32Ty, emitc::OpaqueAttr::get(ctx, std::to_string(v))); + }; + + // Helper: 将 OpFoldResult 转为 EmitC Value (用于计算) + auto ofrToEmitCValue = [&](OpFoldResult ofr) -> Value { + if (auto v = ofr.dyn_cast()) { + Value rv = rewriter.getRemappedValue(v); + // 如果类型不匹配,插入 Cast + if (rv.getType() != u32Ty) + return rewriter.create(loc, u32Ty, rv).getResult(); + return rv; + } + if (auto attr = ofr.dyn_cast()) { + if (auto ia = dyn_cast(attr)) + return mkU32(ia.getValue().getSExtValue()); + } + return mkU32(0); + }; + + // 1. 
获取 Source 的 Strides (支持动态 Stride 收集) + SmallVector sourceStrides; + + if (auto rc = op.getSource().getDefiningOp()) { + sourceStrides = rc.getMixedStrides(); + } else { + SmallVector strideInts; + int64_t offset = ShapedType::kDynamic; + bool useTypeStrides = succeeded(getStridesAndOffset(srcType, strideInts, offset)); + (void)offset; + if (useTypeStrides) { + for (int64_t s : strideInts) { + if (s == ShapedType::kDynamic) + useTypeStrides = false; + } + } + if (useTypeStrides) { + for (int64_t s : strideInts) { + sourceStrides.push_back(rewriter.getIndexAttr(s)); + } + } else { + // Fallback: Compact Layout + auto shape = srcType.getShape(); + int64_t current = 1; + sourceStrides.resize(rank); + for (int i = rank - 1; i >= 0; --i) { + sourceStrides[i] = rewriter.getIndexAttr(current); + if (shape[i] != ShapedType::kDynamic) current *= shape[i]; + } + } + } + + // 2. 计算运行时 Offset + auto staticOffsets = op.getStaticOffsets(); + auto dynamicOffsets = adaptor.getOffsets(); + int dynOffIdx = 0; + Value totalOffset = mkU32(0); + + for (int i = 0; i < rank; ++i) { + // A. 获取 Offset + Value offVal; + if (staticOffsets[i] == ShapedType::kDynamic) { + Value rawDyn = dynamicOffsets[dynOffIdx++]; + offVal = rewriter.create(loc, u32Ty, rawDyn); + } else { + offVal = mkU32(staticOffsets[i]); + } + + // B. 获取 Stride (用于指针计算) + Value strideVal = mkU32(1); + if (i < (int)sourceStrides.size()) { + strideVal = ofrToEmitCValue(sourceStrides[i]); + } + + // C. 累加 + Value term = rewriter.create(loc, u32Ty, offVal, strideVal); + totalOffset = rewriter.create(loc, u32Ty, totalOffset, term); + } + + // 3. 生成新指针 + // + // NOTE: Some toolchains may materialize kernel pointer params as `void*` even + // when the underlying element type is i16. Pointer arithmetic on `void*` + // is ill-formed in C++, so we explicitly cast to a typed pointer for i16. 
+ Value sourcePtr = adaptor.getSource(); + Value tileCandidate = sourcePtr; + if (auto castOp = sourcePtr.getDefiningOp()) { + tileCandidate = castOp.getOperand(); + } else if (auto uc = + sourcePtr.getDefiningOp()) { + tileCandidate = uc.getOperand(0); + } + if (auto ot = dyn_cast(tileCandidate.getType())) { + auto tyStr = ot.getValue(); + if (tyStr.find("Tile<") != std::string::npos || + tyStr.find("ConvTile<") != std::string::npos) { + std::string elemTok = elemTypeToString(srcType.getElementType()); + pto::AddressSpace as = pto::AddressSpace::GM; + if (auto asAttr = + dyn_cast_or_null(srcType.getMemorySpace())) + as = asAttr.getAddressSpace(); + sourcePtr = + materializeTileDataValue(rewriter, loc, tileCandidate, as, elemTok); + if (tileDataReturnsIntegralAddress(as)) + sourcePtr = + materializeAddressAsPointer(rewriter, loc, sourcePtr, as, elemTok); + } + } + Value newPtr; + { + auto resTy = mlir::cast(op.getResult().getType()); + Type elemTy = resTy.getElementType(); + if (elemTy.isInteger(16)) { + std::string castElemTypeStr = "int16_t"; + if (cast(elemTy).isUnsigned()) + castElemTypeStr = "uint16_t"; + + std::string qualifier = "__gm__"; + if (Attribute ms = srcType.getMemorySpace()) { + if (auto ptoAttr = dyn_cast(ms)) { + qualifier = addrSpaceQualifier(ptoAttr.getAddressSpace()); + } + } + + auto typedPtrTy = emitc::PointerType::get( + emitc::OpaqueType::get(ctx, qualifier + " " + castElemTypeStr)); + Value typedSourcePtr = rewriter.create(loc, typedPtrTy, sourcePtr); + newPtr = rewriter.create(loc, typedPtrTy, typedSourcePtr, totalOffset); + } else { + newPtr = rewriter.create(loc, sourcePtr.getType(), sourcePtr, totalOffset); + } + } + + + // ------------------------------------------------------------------------- + // Part 2: For non-GM memrefs, keep pointer (no GlobalTensor). 
+ // ------------------------------------------------------------------------- + bool isGlobal = true; + if (auto asAttr = dyn_cast_or_null(srcType.getMemorySpace())) { + auto as = asAttr.getAddressSpace(); + isGlobal = (as == pto::AddressSpace::GM || as == pto::AddressSpace::Zero); + } + if (!isGlobal) { + Type dstTy = getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + if (newPtr.getType() != dstTy) + newPtr = rewriter.create(loc, dstTy, newPtr); + rewriter.replaceOp(op, newPtr); + return success(); + } + + // ------------------------------------------------------------------------- + // Part 3: 生成 GlobalTensor 类型 (Shape/Stride Template Generation) + // ------------------------------------------------------------------------- + + // When emitting C++ with `declareVariablesAtTop`, value declarations are + // hoisted before body statements. Avoid introducing local `using` aliases + // for templated types (Shape/Stride/GlobalTensor) because those aliases + // would appear after the hoisted declarations and break compilation + // (`unknown type name`). + // + // Instead, use the fully spelled template types as EmitC opaque types. + + auto resTy = mlir::cast(op.getResult().getType()); + + // 1. 解析具体元素类型 + std::string elemTypeStr = getElemTypeStringForGT(resTy.getElementType()); + + // 2. 生成 Shape 模板参数,之后会右对齐有效维度并补齐到 5 维(高维填 1) + SmallVector shapeParamsVec; + SmallVector sizeValues; // 每个维度对应的运行时 size(统一为 unsigned) + auto resShape = resTy.getShape(); + auto mixedSizes = op.getMixedSizes(); + sizeValues.reserve(rank); + for (int i = 0; i < resTy.getRank(); ++i) { + if (resShape[i] == ShapedType::kDynamic) { + shapeParamsVec.push_back(-1); + } else { + shapeParamsVec.push_back(resShape[i]); + } + // size 值:优先从 op.getMixedSizes() 取(可动态/静态),否则退化为类型里的静态 shape。 + if (i < (int)mixedSizes.size()) + sizeValues.push_back(ofrToEmitCValue(mixedSizes[i])); + else + sizeValues.push_back( + mkU32(resShape[i] == ShapedType::kDynamic ? 
1 : resShape[i])); + } + + // 3. 生成 Stride 模板参数 + 运行时 stride 值(考虑 subview step) + SmallVector strideTemplateVec; + SmallVector strideValues; // 每个维度对应的运行时 stride(统一为 unsigned) + strideTemplateVec.reserve(rank); + strideValues.reserve(rank); + auto subViewSteps = op.getMixedStrides(); + for (int i = 0; i < rank; ++i) { + OpFoldResult srcStrideOfr = + (i < (int)sourceStrides.size()) ? sourceStrides[i] + : rewriter.getIndexAttr(1); + OpFoldResult stepOfr = (i < (int)subViewSteps.size()) + ? subViewSteps[i] + : rewriter.getIndexAttr(1); + + auto srcStatic = extractStaticInt(srcStrideOfr); + auto stepStatic = extractStaticInt(stepOfr); + if (srcStatic && stepStatic) { + int64_t finalStride = (*srcStatic) * (*stepStatic); + strideTemplateVec.push_back(finalStride); + strideValues.push_back(mkU32(finalStride)); + continue; + } + + strideTemplateVec.push_back(-1); + Value srcV = ofrToEmitCValue(srcStrideOfr); + Value stepV = ofrToEmitCValue(stepOfr); + // 尽量避免乘以 1 生成冗余指令 + if (stepStatic && *stepStatic == 1) + strideValues.push_back(srcV); + else if (srcStatic && *srcStatic == 1) + strideValues.push_back(stepV); + else + strideValues.push_back( + rewriter.create(loc, u32Ty, srcV, stepV)); + } + + // 3.1 右对齐到 5 维:shape 补 1;已有维度继承原 stride; + // 被补出来的高维按“紧密升维”规则连续推导:stride[i] = shape[i+1] * stride[i+1] + SmallVector finalShape; + SmallVector finalStride; + buildGlobalTensorShapeAndStride(shapeParamsVec, strideTemplateVec, + finalShape, finalStride); + Value oneU32 = mkU32(1); + SmallVector finalShapeValues(5, oneU32); + SmallVector finalStrideValues(5, oneU32); + int shift = 5 - rank; + + // 先放入原始 shape/stride(保持用户提供的值) + for (int i = 0; i < rank && i < 5; ++i) { + finalShapeValues[shift + i] = sizeValues[i]; + finalStrideValues[shift + i] = strideValues[i]; + } + + // 从低维到高维倒推补齐 stride(仅对补出来的前置维度生效) + for (int i = 3; i >= 0; --i) { + // 如果该维已由原始 rank 覆盖,则保持原值 + if (i >= shift) + continue; + if (finalStride[i] != -1) { + finalStrideValues[i] = mkU32(finalStride[i]); + 
continue; + } + // 动态推导:stride[i] = shape[i+1] * stride[i+1] + if (finalShape[i + 1] == 1) { + finalStrideValues[i] = finalStrideValues[i + 1]; + } else { + finalStrideValues[i] = rewriter.create( + loc, u32Ty, finalShapeValues[i + 1], finalStrideValues[i + 1]); + } + } + + std::string shapeParams = joinIntTemplateParams(finalShape); + std::string strideParams = joinIntTemplateParams(finalStride); + + // Spelled-out C++ types. + std::string shapeCppType = "pto::Shape<" + shapeParams + ">"; + std::string strideCppType = "pto::Stride<" + strideParams + ">"; + + // 3.0 Layout: prefer the attribute from InferPTOLayout; only fall back to + // local inference when the pass is disabled. + std::string layoutEnum = "pto::Layout::ND"; + if (auto layout = resolveLayoutForGlobalTensor(op, op.getSource())) { + layoutEnum = layoutToEmitCString(*layout); + } else { + bool allStatic = + llvm::all_of(finalShape, [](int64_t value) { return value != -1; }) && + llvm::all_of(finalStride, [](int64_t value) { return value != -1; }); + + int layoutTag = 0; // ND + auto elemBytes = 4; // default float + if (elemTypeStr.find("half") != std::string::npos || + elemTypeStr.find("f16") != std::string::npos || + elemTypeStr.find("bf16") != std::string::npos) + elemBytes = 2; + else if (elemTypeStr.find("double") != std::string::npos || + elemTypeStr.find("f64") != std::string::npos) + elemBytes = 8; + + if (allStatic) { + if (finalShape[2] == 16 && + finalShape[2] * finalShape[3] * elemBytes == 512 && + finalStride[4] == 1 && finalStride[3] == finalShape[4]) { + layoutTag = 2; // NZ + } else { + bool isRow = finalStride[4] == 1; + for (int i = 3; i >= 0; --i) + isRow &= (finalStride[i] == + multiplyOrDynamic(finalStride[i + 1], finalShape[i + 1])); + bool isCol = finalStride[0] == 1; + for (int i = 0; i < 4; ++i) + isCol &= (finalStride[i + 1] == + multiplyOrDynamic(finalStride[i], finalShape[i])); + if (isCol) + layoutTag = 1; // DN + else + layoutTag = isRow ? 
0 : 0; // fallback ND + } + } + + if (layoutTag == 1) + layoutEnum = "pto::Layout::DN"; + else if (layoutTag == 2) + layoutEnum = "pto::Layout::NZ"; + } + // GlobalTensor takes a Layout non-type template parameter; directly use the + // enum constant. + + + // ------------------------------------------------------------------------- + // Part 3: 显式对象实例化 (Explicit Object Instantiation) + // ------------------------------------------------------------------------- + + // A. Instantiate Shape object. + auto shapeTypeOpaque = emitc::OpaqueType::get(ctx, shapeCppType); + SmallVector shapeArgs; + // 从 adaptor.getSizes() 获取 subview 的所有 dynamic sizes + for (Value dynSize : adaptor.getSizes()) { + shapeArgs.push_back(dynSize); + } + + auto shapeInstOp = rewriter.create( + loc, + shapeTypeOpaque, // 返回类型 + shapeCppType, // 调用的“函数名”即类名构造函数 + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange(shapeArgs) + ); + + // B. Instantiate Stride object. + auto strideTypeOpaque = emitc::OpaqueType::get(ctx, strideCppType); + // 仅传入动态 stride 维度对应的值,匹配 pto::Stride 的 N-parameter ctor(并满足其 static_assert)。 + SmallVector strideCtorArgs; + strideCtorArgs.reserve(5); + for (int i = 0; i < 5; ++i) { + if (finalStride[i] == -1) + strideCtorArgs.push_back(finalStrideValues[i]); + } + auto strideInstOp = rewriter.create( + loc, strideTypeOpaque, strideCppType, + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange(strideCtorArgs)); + + // C. Instantiate GlobalTensor object (ptr + shape + stride). 
+ std::string gtCppType = "GlobalTensor<" + elemTypeStr + ", " + shapeCppType + + ", " + strideCppType + ", " + layoutEnum + ">"; + auto gtType = emitc::OpaqueType::get(ctx, gtCppType); + + // 准备构造参数: [ptr, shape_instance, stride_instance] + SmallVector gtConstructorArgs; + gtConstructorArgs.push_back(newPtr); + gtConstructorArgs.push_back(shapeInstOp.getResult(0)); // 拿到 shape_inst 的 SSA Value + gtConstructorArgs.push_back(strideInstOp.getResult(0)); // 拿到 stride_inst 的 SSA Value + + rewriter.replaceOpWithNewOp( + op, + gtType, + gtCppType, + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange(gtConstructorArgs) + ); + + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Helper: build GlobalTensor from a static MemRef (for TLOAD/TSTORE) +//===----------------------------------------------------------------------===// + +static std::string getElemTypeStringForGT(Type elemTy) { + return getEmitCScalarTypeToken(elemTy); +} + +static bool hasStaticShape(MemRefType mrTy) { + return llvm::none_of(mrTy.getShape(), [](int64_t dim) { + return dim == ShapedType::kDynamic; + }); +} + +static bool getStaticMemrefLayout(MemRefType mrTy, SmallVectorImpl &strides, + int64_t &offset) { + if (failed(getStridesAndOffset(mrTy, strides, offset))) { + strides.clear(); + int64_t stride = 1; + ArrayRef shape = mrTy.getShape(); + for (int i = static_cast(shape.size()) - 1; i >= 0; --i) { + strides.push_back(stride); + stride *= shape[i]; + } + std::reverse(strides.begin(), strides.end()); + offset = 0; + } + return offset != ShapedType::kDynamic && + llvm::none_of(strides, [](int64_t strideValue) { + return strideValue == ShapedType::kDynamic; + }); +} + +static Value applyStaticMemrefOffset(ConversionPatternRewriter &rewriter, + Location loc, Value basePtr, + int64_t offset) { + if (offset == 0) + return basePtr; + auto *ctx = rewriter.getContext(); + Type u32Ty = emitc::OpaqueType::get(ctx, 
"unsigned"); + auto offVal = rewriter.create( + loc, u32Ty, emitc::OpaqueAttr::get(ctx, std::to_string(offset))); + return rewriter.create(loc, basePtr.getType(), basePtr, offVal); +} + +static int getGlobalTensorElementBytes(Type elemTy) { + return static_cast(getPTOStorageElemByteSize(elemTy)); +} + +static int64_t multiplyOrDynamic(int64_t lhs, int64_t rhs) { + if (lhs < 0 || rhs < 0) + return -1; + return lhs * rhs; +} + +static void buildGlobalTensorShapeAndStride(ArrayRef shape, + ArrayRef strides, + SmallVectorImpl &shape5D, + SmallVectorImpl &stride5D) { + shape5D.assign(5, 1); + stride5D.assign(5, 1); + int rank = static_cast(shape.size()); + int shift = 5 - rank; + for (int i = 0; i < rank && i < 5; ++i) { + shape5D[shift + i] = shape[i]; + stride5D[shift + i] = strides[i]; + } + for (int i = 3; i >= 0; --i) { + if (i >= shift) + continue; + stride5D[i] = multiplyOrDynamic(shape5D[i + 1], stride5D[i + 1]); + } +} + +static std::string joinIntTemplateParams(ArrayRef values) { + std::string result; + for (size_t i = 0; i < values.size(); ++i) { + if (i != 0) + result += ", "; + result += std::to_string(values[i]); + } + return result; +} + +static SmallVector buildRowMajorStrides(ArrayRef shape) { + SmallVector strides(shape.size(), 1); + int64_t running = 1; + for (int i = static_cast(shape.size()) - 1; i >= 0; --i) { + strides[i] = running; + running = multiplyOrDynamic(running, shape[i]); + } + return strides; +} + +static std::string getGlobalTensorTypeStringFromShape(Type elemTy, + ArrayRef shape, + StringRef layoutEnum) { + SmallVector strides = buildRowMajorStrides(shape); + return getGlobalTensorTypeStringFromShapeAndStrides(elemTy, shape, strides, + layoutEnum); +} + +static std::string getGlobalTensorTypeStringFromShapeAndStrides( + Type elemTy, ArrayRef shape, ArrayRef strides, + StringRef layoutEnum) { + SmallVector shape5D; + SmallVector stride5D; + buildGlobalTensorShapeAndStride(shape, strides, shape5D, stride5D); + + std::string elemTypeStr 
= getElemTypeStringForGT(elemTy); + std::string shapeType = "pto::Shape<" + joinIntTemplateParams(shape5D) + ">"; + std::string strideType = + "pto::Stride<" + joinIntTemplateParams(stride5D) + ">"; + return "GlobalTensor<" + elemTypeStr + ", " + shapeType + ", " + + strideType + ", " + layoutEnum.str() + ">"; +} + +static emitc::OpaqueType getGlobalTensorOpaqueTypeFromShape( + MLIRContext *ctx, Type elemTy, ArrayRef shape, + StringRef layoutEnum) { + return emitc::OpaqueType::get( + ctx, getGlobalTensorTypeStringFromShape(elemTy, shape, layoutEnum)); +} + +static std::string inferFallbackGlobalTensorLayout(ArrayRef shape5D, + ArrayRef stride5D, + Type elemTy) { + int elemBytes = getGlobalTensorElementBytes(elemTy); + if (elemBytes == 0) + return "pto::Layout::ND"; + if (shape5D[2] == 16 && multiplyOrDynamic(shape5D[2], shape5D[3]) * elemBytes == 512 && + stride5D[4] == 1 && stride5D[3] == shape5D[4]) { + return "pto::Layout::NZ"; + } + + bool isRowMajor = stride5D[4] == 1; + for (int i = 3; i >= 0 && isRowMajor; --i) + isRowMajor = stride5D[i] == multiplyOrDynamic(stride5D[i + 1], shape5D[i + 1]); + + bool isColMajor = stride5D[0] == 1; + for (int i = 0; i < 4 && isColMajor; ++i) + isColMajor = stride5D[i + 1] == multiplyOrDynamic(stride5D[i], shape5D[i]); + + if (isColMajor) + return "pto::Layout::DN"; + return isRowMajor ? 
"pto::Layout::ND" : "pto::Layout::ND"; +} + +static std::string resolveGlobalTensorLayout(Operation *anchor, Value basePtr, + ArrayRef shape5D, + ArrayRef stride5D, + Type elemTy) { + if (auto layout = resolveLayoutForGlobalTensor(anchor, basePtr)) + return layoutToEmitCString(*layout); + return inferFallbackGlobalTensorLayout(shape5D, stride5D, elemTy); +} + +struct GlobalTensorTypeNames { + std::string shapeTypeName; + std::string strideTypeName; + std::string tensorTypeName; + std::string layoutConstName; +}; + +static GlobalTensorTypeNames getGlobalTensorTypeNames(Operation *anchor) { + std::string suffix = "_" + std::to_string(reinterpret_cast(anchor)); + return { + "GTShape" + suffix, + "GTStride" + suffix, + "GT" + suffix, + "GT" + suffix + "_layout", + }; +} +static Value buildGlobalTensorFromMemref(ConversionPatternRewriter &rewriter, + Location loc, Value basePtr, + MemRefType mrTy, + Operation *anchor) { + auto *ctx = rewriter.getContext(); + + ArrayRef shape = mrTy.getShape(); + if (!hasStaticShape(mrTy)) + return Value(); + + SmallVector strides; + int64_t offset = 0; + if (!getStaticMemrefLayout(mrTy, strides, offset)) + return Value(); + + Value ptr = applyStaticMemrefOffset(rewriter, loc, basePtr, offset); + GlobalTensorTypeNames names = getGlobalTensorTypeNames(anchor); + std::string elemTypeStr = getElemTypeStringForGT(mrTy.getElementType()); + SmallVector shape5D; + SmallVector stride5D; + buildGlobalTensorShapeAndStride(shape, strides, shape5D, stride5D); + + rewriter.create( + loc, "using " + names.shapeTypeName + " = pto::Shape<" + + joinIntTemplateParams(shape5D) + ">;"); + rewriter.create( + loc, "using " + names.strideTypeName + " = pto::Stride<" + + joinIntTemplateParams(stride5D) + ">;"); + + std::string layoutEnum = resolveGlobalTensorLayout( + anchor, basePtr, shape5D, stride5D, mrTy.getElementType()); + rewriter.create(loc, "constexpr pto::Layout " + + names.layoutConstName + " = " + + layoutEnum + ";"); + + auto shapeTypeOpaque = 
emitc::OpaqueType::get(ctx, names.shapeTypeName); + auto strideTypeOpaque = emitc::OpaqueType::get(ctx, names.strideTypeName); + auto shapeInstOp = rewriter.create( + loc, shapeTypeOpaque, names.shapeTypeName, ArrayAttr{}, ArrayAttr{}, + ValueRange{}); + auto strideInstOp = rewriter.create( + loc, strideTypeOpaque, names.strideTypeName, ArrayAttr{}, ArrayAttr{}, + ValueRange{}); + + rewriter.create( + loc, "using " + names.tensorTypeName + " = GlobalTensor<" + elemTypeStr + + ", " + names.shapeTypeName + ", " + names.strideTypeName + + ", " + names.layoutConstName + ">;"); + auto gtType = emitc::OpaqueType::get(ctx, names.tensorTypeName); + + SmallVector gtArgs; + gtArgs.push_back(ptr); + gtArgs.push_back(shapeInstOp.getResult(0)); + gtArgs.push_back(strideInstOp.getResult(0)); + + auto gtInst = rewriter.create( + loc, gtType, names.tensorTypeName, ArrayAttr{}, ArrayAttr{}, + ValueRange(gtArgs)); + + return gtInst.getResult(0); +} + +static Value maybeWrapGlobalMemrefAsGlobalTensor( + ConversionPatternRewriter &rewriter, Location loc, Value loweredValue, + Type originalType, Operation *anchor) { + auto mrTy = dyn_cast(originalType); + if (!mrTy) + return loweredValue; + + bool isGlobal = true; + if (auto asAttr = + dyn_cast_or_null(mrTy.getMemorySpace())) { + auto as = asAttr.getAddressSpace(); + isGlobal = (as == pto::AddressSpace::GM || as == pto::AddressSpace::Zero); + } + if (!isGlobal) + return loweredValue; + + if (Value gt = + buildGlobalTensorFromMemref(rewriter, loc, loweredValue, mrTy, anchor)) + return gt; + return loweredValue; +} + +static Value castToGMBytePointer(ConversionPatternRewriter &rewriter, + Location loc, Value value) { + auto *ctx = rewriter.getContext(); + auto targetTy = + emitc::PointerType::get(emitc::OpaqueType::get(ctx, "__gm__ uint8_t")); + if (value.getType() == targetTy) + return value; + + auto castTyAttr = + rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "__gm__ uint8_t*")}); + if (isSetFFTsPointerLikeType(value.getType())) 
{ + return rewriter + .create(loc, targetTy, "reinterpret_cast", + ArrayAttr{}, castTyAttr, + ValueRange{value}) + .getResult(0); + } + return rewriter.create(loc, targetTy, value).getResult(); +} + +static Value materializeTensorViewDataPointer( + ConversionPatternRewriter &rewriter, Location loc, Value value, + Type sourceType) { + auto tvTy = dyn_cast(sourceType); + if (!tvTy) + return value; + + auto *ctx = rewriter.getContext(); + std::string elemTypeStr = getElemTypeStringForGT(tvTy.getElementType()); + auto ptrTy = emitc::PointerType::get( + emitc::OpaqueType::get(ctx, "__gm__ " + elemTypeStr)); + return rewriter + .create(loc, ptrTy, "PTOAS__GLOBAL_TENSOR_DATA", + ArrayAttr{}, ArrayAttr{}, ValueRange{value}) + .getResult(0); +} + +static std::string tileBufBLayoutToken(pto::TileBufConfigAttr configAttr) { + std::string blTok = "BLayout::RowMajor"; + if (auto blAttr = dyn_cast(configAttr.getBLayout())) { + if (static_cast(blAttr.getValue()) == 1) + blTok = "BLayout::ColMajor"; + } + return blTok; +} + +static std::string tileBufSLayoutToken(pto::TileBufConfigAttr configAttr) { + std::string slTok = "SLayout::NoneBox"; + if (auto slAttr = dyn_cast(configAttr.getSLayout())) { + int32_t slVal = static_cast(slAttr.getValue()); + slTok = (slVal == 1) ? "SLayout::RowMajor" + : (slVal == 2) ? 
"SLayout::ColMajor" + : "SLayout::NoneBox"; + } + return slTok; +} + +static std::string tileBufPadToken(pto::TileBufConfigAttr configAttr) { + std::string padTok = "PadValue::Null"; + if (auto padAttr = dyn_cast(configAttr.getPad())) { + switch (static_cast(padAttr.getValue())) { + case 1: + padTok = "PadValue::Zero"; + break; + case 2: + padTok = "PadValue::Max"; + break; + case 3: + padTok = "PadValue::Min"; + break; + default: + padTok = "PadValue::Null"; + break; + } + } + return padTok; +} + +static pto::BLayout getTileBufBLayoutValue(pto::TileBufConfigAttr configAttr) { + if (auto blAttr = dyn_cast(configAttr.getBLayout())) + return blAttr.getValue(); + return pto::BLayout::RowMajor; +} + +static int64_t renderTileTemplateDim(int64_t rawDim, Type elemTy, + pto::BLayout blayout, int dimIdx) { + assert(dimIdx >= 0 && dimIdx < 2 && + "renderTileTemplateDim expects a rank-2 rows/cols dimension index"); + if (rawDim == ShapedType::kDynamic) + return rawDim; + if (!pto::isPTOFloat4PackedType(elemTy)) + return rawDim; + int packedDim = blayout == pto::BLayout::ColMajor ? 0 : 1; + return dimIdx == packedDim ? rawDim * 2 : rawDim; +} + +static FailureOr buildAsyncScratchTileValue( + ConversionPatternRewriter &rewriter, Location loc, Value originalScratch, + Value emittedScratch) { + Value scratch = peelUnrealized(emittedScratch); + if (auto opaqueTy = dyn_cast(scratch.getType())) { + StringRef typeStr = opaqueTy.getValue(); + if (typeStr.contains("Tile<") || typeStr.contains("ConvTile<")) + return scratch; + } + + auto memTy = dyn_cast(originalScratch.getType()); + if (!memTy) + return failure(); + + ArrayRef shape = memTy.getShape(); + if (!memTy.hasStaticShape() || shape.empty() || shape.size() > 2) + return failure(); + + int64_t rows = shape.size() == 1 ? 1 : shape[0]; + int64_t cols = shape.size() == 1 ? 
shape[0] : shape[1]; + + auto *ctx = rewriter.getContext(); + pto::TileBufConfigAttr configAttr = pto::TileBufConfigAttr::getDefault(ctx); + if (auto bind = originalScratch.getDefiningOp()) { + configAttr = bind.getConfig(); + } else if (auto cast = originalScratch.getDefiningOp()) { + if (auto config = cast.getConfig()) + configAttr = *config; + } + + int32_t fractal = 512; + if (auto frAttr = dyn_cast(configAttr.getSFractalSize())) + fractal = frAttr.getInt(); + + Type elemTy = memTy.getElementType(); + pto::BLayout blayout = getTileBufBLayoutValue(configAttr); + int64_t templateRows = renderTileTemplateDim(rows, elemTy, blayout, 0); + int64_t templateCols = renderTileTemplateDim(cols, elemTy, blayout, 1); + std::string elemTypeStr = getEmitCScalarTypeToken(elemTy); + std::string tileTypeStr = + "Tile"; + + Value tile = rewriter + .create( + loc, emitc::OpaqueType::get(ctx, tileTypeStr), + emitc::OpaqueAttr::get(ctx, "")) + .getResult(); + auto addr = rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); + Value scratchAddr = + rewriter + .create(loc, emitc::OpaqueType::get(ctx, "uint64_t"), + "reinterpret_cast", ArrayAttr{}, addr, + ValueRange{scratch}) + .getResult(0); + rewriter.create(loc, TypeRange{}, "TASSIGN", + ArrayAttr{}, ArrayAttr{}, + ValueRange{tile, scratchAddr}); + return tile; +} + +static FailureOr buildSyncAllWorkspaceTileValue( + ConversionPatternRewriter &rewriter, Location loc, Value originalWorkspace, + Value emittedWorkspace) { + Value workspace = peelUnrealized(emittedWorkspace); + if (auto opaqueTy = dyn_cast(workspace.getType())) { + StringRef typeStr = opaqueTy.getValue(); + if (typeStr.contains("Tile<") || typeStr.contains("ConvTile<")) + return workspace; + } + + auto memTy = dyn_cast(originalWorkspace.getType()); + if (!memTy) + return failure(); + if (!memTy.hasStaticShape()) + return failure(); + + ArrayRef rawShape = memTy.getShape(); + if (rawShape.empty() || rawShape.size() > 2) + return failure(); + + int64_t rows = 
rawShape.size() == 1 ? 1 : rawShape[0]; + int64_t cols = rawShape.size() == 1 ? rawShape[0] : rawShape[1]; + SmallVector shape{rows, cols}; + SmallVector validShape{rows, cols}; + + auto *ctx = rewriter.getContext(); + pto::TileBufConfigAttr configAttr = pto::TileBufConfigAttr::getDefault(ctx); + if (auto bind = originalWorkspace.getDefiningOp()) { + configAttr = bind.getConfig(); + } else if (auto cast = originalWorkspace.getDefiningOp()) { + if (auto config = cast.getConfig()) + configAttr = *config; + } + + Attribute memorySpace = memTy.getMemorySpace(); + if (!memorySpace) + return failure(); + + auto tileTy = pto::TileBufType::get(ctx, shape, memTy.getElementType(), + memorySpace, validShape, configAttr); + auto tileTypeString = getEmitCTileTypeString(tileTy); + if (!tileTypeString) + return failure(); + + auto tileEmitTy = emitc::OpaqueType::get(ctx, *tileTypeString); + Value tile = rewriter + .create(loc, tileEmitTy, + emitc::OpaqueAttr::get(ctx, "")) + .getResult(); + + Value rawPtr = workspace; + auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); + if (isSetFFTsPointerLikeType(rawPtr.getType())) { + auto rcU64 = + rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); + rawPtr = rewriter + .create(loc, u64Ty, "reinterpret_cast", + ArrayAttr{}, rcU64, + ValueRange{rawPtr}) + .getResult(0); + } else if (rawPtr.getType() != u64Ty) { + rawPtr = rewriter.create(loc, u64Ty, rawPtr).getResult(); + } + + rewriter.create(loc, TypeRange{}, "TASSIGN", + ArrayAttr{}, ArrayAttr{}, + ValueRange{tile, rawPtr}); + return tile; +} + +//===----------------------------------------------------------------------===// +// pto.pointer_cast lowering +//===----------------------------------------------------------------------=== +struct PointerCastConversion : public OpConversionPattern { + static bool getIndexConst(Value v, int64_t &out) { + if (auto cst = v.getDefiningOp()) { + if (auto ia = dyn_cast(cst.getValue())) { + out = ia.getValue().getSExtValue(); + return 
true; + } + } + return false; + } + + using OpConversionPattern::OpConversionPattern; + + enum class TileRole { Vec, Mat, Left, Right, Acc, Bias, Scaling }; + + static void collectUserOpsThroughCasts(Value v, SmallVectorImpl &out) { + for (Operation *u : v.getUsers()) { + if (auto castOp = dyn_cast(u)) { + for (Value r : castOp.getResults()) + collectUserOpsThroughCasts(r, out); + continue; + } + out.push_back(u); + } + } + + static Value peelUnrealized(Value v) { + while (auto castOp = v.getDefiningOp()) { + v = castOp.getOperand(0); + } + return v; + } + + static TileRole inferRole(pto::PointerCastOp op) { + // 1. 优先检查 AddressSpace + if (auto memRefTy = dyn_cast(op.getType())) { + Attribute memorySpace = memRefTy.getMemorySpace(); + if (auto ptoAttr = dyn_cast_or_null(memorySpace)) { + switch (ptoAttr.getAddressSpace()) { + case pto::AddressSpace::LEFT: return TileRole::Left; + case pto::AddressSpace::RIGHT: return TileRole::Right; + case pto::AddressSpace::ACC: return TileRole::Acc; + case pto::AddressSpace::BIAS: return TileRole::Bias; + case pto::AddressSpace::MAT: return TileRole::Mat; + case pto::AddressSpace::SCALING: return TileRole::Scaling; + default: break; + } + } + } + + // 2. 
通过 Usage 推导 (Fallback) + SmallVector users; + collectUserOpsThroughCasts(op.getResult(), users); + + for (Operation *user : users) { + if (auto mm = dyn_cast(user)) { + if (mm.getDst() && peelUnrealized(mm.getDst()) == op.getResult()) return TileRole::Acc; + if (peelUnrealized(mm.getLhs()) == op.getResult()) return TileRole::Left; + if (peelUnrealized(mm.getRhs()) == op.getResult()) return TileRole::Right; + } + if (auto mmacc = dyn_cast(user)) { + if (mmacc.getDst() && peelUnrealized(mmacc.getDst()) == op.getResult()) return TileRole::Acc; + if (peelUnrealized(mmacc.getAccIn()) == op.getResult()) return TileRole::Acc; + if (peelUnrealized(mmacc.getLhs()) == op.getResult()) return TileRole::Left; + if (peelUnrealized(mmacc.getRhs()) == op.getResult()) return TileRole::Right; + } + } + + return TileRole::Vec; + } + + // [新增] 辅助函数:判断 Value 是否源自 arith.constant + static bool isConstant(Value v, int64_t &outVal) { + if (!v) return false; + if (auto cst = v.getDefiningOp()) { + if (auto attr = dyn_cast(cst.getValue())) { + outVal = attr.getInt(); + return true; + } + } + return false; + } + + LogicalResult matchAndRewrite(pto::PointerCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + auto selfType = mlir::cast(op.getType()); + ArrayRef shape = selfType.getShape(); + Type elemType = selfType.getElementType(); + + // 1. 推导 Tile Role + TileRole role = inferRole(op); + + // 2. 类型字符串生成 (elemTypeStr, dimStr) + std::string elemTypeStr = getEmitCScalarTypeToken(elemType); + + std::string dimStr; + pto::BLayout blayout = pto::BLayout::RowMajor; + auto dimToString = [&](int64_t dim, const char *symbol, + int dimIdx) -> std::string { + if (dim == ShapedType::kDynamic) + return std::string(symbol); + return std::to_string(renderTileTemplateDim(dim, elemType, blayout, + dimIdx)); + }; + + // 3. 
Role Token + const char *roleTok = "TileType::Vec"; + switch (role) { + case TileRole::Left: roleTok = "TileType::Left"; break; + case TileRole::Right: roleTok = "TileType::Right"; break; + case TileRole::Acc: roleTok = "TileType::Acc"; break; + case TileRole::Bias: roleTok = "TileType::Bias"; break; + case TileRole::Mat: roleTok = "TileType::Mat"; break; + case TileRole::Vec: roleTok = "TileType::Vec"; break; + case TileRole::Scaling: roleTok = "TileType::Scaling"; break; + } + + // 4. Config & Layout (support BLayoutAttr/SLayoutAttr/PadValueAttr after namespace change) + std::string layoutParams = "BLayout::RowMajor"; + std::string extraParams = ""; + if (auto configOpt = op.getConfig()) { + auto config = *configOpt; + int32_t blVal = 0; + if (auto attr = dyn_cast(config.getBLayout())) + blVal = static_cast(attr.getValue()); + + if (blVal == 1) layoutParams = "BLayout::ColMajor"; + blayout = blVal == 1 ? pto::BLayout::ColMajor : pto::BLayout::RowMajor; + + int32_t slVal = 0; + if (auto attr = dyn_cast(config.getSLayout())) + slVal = static_cast(attr.getValue()); + + std::string slStr = (slVal == 1) ? "SLayout::RowMajor" : (slVal == 2) ? 
"SLayout::ColMajor" : "SLayout::NoneBox"; + + int32_t frVal = 0; + if (auto attr = dyn_cast(config.getSFractalSize())) frVal = attr.getInt(); + + int32_t padVal = 0; + if (auto attr = dyn_cast(config.getPad())) + padVal = static_cast(attr.getValue()); + + std::string padStr = "PadValue::Null"; + switch (padVal) { + case 1: padStr = "PadValue::Zero"; break; + case 2: padStr = "PadValue::Max"; break; + case 3: padStr = "PadValue::Min"; break; + } + + int32_t compactVal = 0; + if (auto attr = dyn_cast(config.getCompactMode())) + compactVal = static_cast(attr.getValue()); + + std::string compactStr = "CompactMode::Null"; + switch (compactVal) { + case 1: compactStr = "CompactMode::Normal"; break; + case 2: compactStr = "CompactMode::RowPlusOne"; break; + } + + if (!slStr.empty()) { + extraParams += ", " + slStr + ", " + std::to_string(frVal) + ", " + + padStr + ", " + compactStr; + } + } else { + extraParams = ", SLayout::NoneBox, 512, PadValue::Null, CompactMode::Null"; + } + + if (role == TileRole::Left) + dimStr = dimToString(shape[0], "M", 0) + ", " + + dimToString(shape[1], "K", 1); + else if (role == TileRole::Right) + dimStr = dimToString(shape[0], "K", 0) + ", " + + dimToString(shape[1], "N", 1); + else if (role == TileRole::Bias) + dimStr = "1, " + dimToString(shape[1], "N", 1); + else + dimStr = dimToString(shape[0], "M", 0) + ", " + + dimToString(shape[1], "N", 1); + + // [核心修改] Valid Dims 处理逻辑 (支持混合静态/动态) + std::string vrowTok, vcolTok; + bool useConstructor = false; + + bool rowIsDynamic = false; + bool colIsDynamic = false; + + SmallVector constructorArgs; + + Value vRow = op.getValidRow(); + Value vCol = op.getValidCol(); + Value vRowEmitC = adaptor.getValidRow(); + Value vColEmitC = adaptor.getValidCol(); + bool forceDynamicValid = op->hasAttr(kForceDynamicValidShapeAttrName); + + int64_t cRow = 0, cCol = 0; + bool rowIsConst = vRow && isConstant(vRow, cRow); + bool colIsConst = vCol && isConstant(vCol, cCol); + + auto makeCtorDimValue = [&](Value 
emitted, int64_t fallback) -> Value { + if (emitted) + return emitted; + return makeEmitCIntConstant( + rewriter, loc, emitc::OpaqueType::get(ctx, "int32_t"), fallback); + }; + auto maybeScaleDynamicValid = [&](Value emitted, int dimIdx) -> Value { + if (!emitted || !pto::isPTOFloat4PackedType(elemType)) + return emitted; + int packedDim = blayout == pto::BLayout::ColMajor ? 0 : 1; + if (dimIdx != packedDim) + return emitted; + auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); + Value two = makeEmitCIntConstant(rewriter, loc, i32Ty, 2); + return rewriter.create(loc, i32Ty, emitted, two).getResult(); + }; + + if (forceDynamicValid) { + vrowTok = "-1"; + vcolTok = "-1"; + useConstructor = true; + constructorArgs.push_back( + makeCtorDimValue(maybeScaleDynamicValid(vRowEmitC, 0), + renderTileTemplateDim(rowIsConst ? cRow : shape[0], + elemType, blayout, 0))); + constructorArgs.push_back( + makeCtorDimValue(maybeScaleDynamicValid(vColEmitC, 1), + renderTileTemplateDim(colIsConst ? cCol : shape[1], + elemType, blayout, 1))); + } else { + if (rowIsConst) { + vrowTok = std::to_string( + renderTileTemplateDim(cRow, elemType, blayout, 0)); + } else if (vRow) { + vrowTok = "-1"; + rowIsDynamic = true; + useConstructor = true; + } else { + vrowTok = std::to_string( + renderTileTemplateDim(shape[0], elemType, blayout, 0)); + } + + if (colIsConst) { + vcolTok = std::to_string( + renderTileTemplateDim(cCol, elemType, blayout, 1)); + } else if (vCol) { + vcolTok = "-1"; + colIsDynamic = true; + useConstructor = true; + } else { + vcolTok = std::to_string( + renderTileTemplateDim(shape[1], elemType, blayout, 1)); + } + + if (useConstructor) { + if (rowIsDynamic && vRowEmitC) + constructorArgs.push_back(maybeScaleDynamicValid(vRowEmitC, 0)); + if (colIsDynamic && vColEmitC) + constructorArgs.push_back(maybeScaleDynamicValid(vColEmitC, 1)); + } + } + + // 5. 
生成 Tile 类型字符串 + std::string tileTypeStr = + std::string("Tile<") + roleTok + ", " + elemTypeStr + ", " + dimStr + ", " + + layoutParams + ", " + vrowTok + ", " + vcolTok + extraParams + ">"; + + auto tileType = emitc::OpaqueType::get(ctx, tileTypeStr); + Value resultValue; + + if (useConstructor) { + // 使用 CallOpaqueOp 生成构造函数调用 (Tile v = Tile(...)) + auto ctorOp = rewriter.create( + loc, + tileType, // Result Type + tileTypeStr, // Callee Name (类名) + ArrayAttr{}, // args + ArrayAttr{}, // template_args + ValueRange(constructorArgs) // operands + ); + resultValue = ctorOp.getResult(0); + } else { + // 静态情况 (Tile v;) + auto varOp = rewriter.create( + loc, + tileType, + emitc::OpaqueAttr::get(ctx, "") + ); + resultValue = varOp.getResult(); + } + + // TASSIGN: pto-isa expects an integral address. + Value addr = adaptor.getAddrs()[0]; + if (isa(addr.getType()) || + (isa(addr.getType()) && + cast(addr.getType()).getValue().ends_with("*"))) { + auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); + auto rcU64 = rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); + addr = rewriter.create( + loc, u64Ty, "reinterpret_cast", + /*args=*/ArrayAttr{}, /*templateArgs=*/rcU64, + /*operands=*/ValueRange{addr}) + .getResult(0); + } + + rewriter.create( + loc, TypeRange{}, "TASSIGN", + ArrayAttr{}, ArrayAttr{}, + ValueRange{resultValue, addr}); + + rewriter.replaceOp(op, resultValue); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// pto.load_dps / pto.store_dps lowering (FIX: keep optional result) +//===----------------------------------------------------------------------=== + +struct PTOTLoadToTLOAD : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TLoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!op.getDst()) + return rewriter.notifyMatchFailure(op, "expected outs(dst) on pto.tload"); + 
+ Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + Value srcArg = src; + if (auto srcMrTy = dyn_cast(op.getSrc().getType())) { + bool isGlobal = true; + if (auto asAttr = dyn_cast_or_null(srcMrTy.getMemorySpace())) { + auto as = asAttr.getAddressSpace(); + isGlobal = (as == pto::AddressSpace::GM || as == pto::AddressSpace::Zero); + } + if (isGlobal) { + if (Value gt = buildGlobalTensorFromMemref(rewriter, op.getLoc(), src, srcMrTy, + op.getOperation())) + srcArg = gt; + } + } + + rewriter.create( + op.getLoc(), TypeRange{}, "TLOAD", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, srcArg}); + + if (op->getNumResults() == 1) { + rewriter.replaceOp(op, dst); + } else { + rewriter.eraseOp(op); + } + return success(); + } +}; + +struct PTOTPrefetchToTPREFETCH : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TPrefetchOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!op.getDst()) + return rewriter.notifyMatchFailure(op, "expected outs(dst) on pto.tprefetch"); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + Value srcArg = src; + if (auto srcMrTy = dyn_cast(op.getSrc().getType())) { + bool isGlobal = true; + if (auto asAttr = dyn_cast_or_null(srcMrTy.getMemorySpace())) { + auto as = asAttr.getAddressSpace(); + isGlobal = (as == pto::AddressSpace::GM || as == pto::AddressSpace::Zero); + } + if (isGlobal) { + if (Value gt = buildGlobalTensorFromMemref(rewriter, op.getLoc(), src, srcMrTy, + op.getOperation())) + srcArg = gt; + } + } + + rewriter.create( + op.getLoc(), TypeRange{}, "TPREFETCH", + ArrayAttr{}, ArrayAttr{}, ValueRange{dst, srcArg}); + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOTPrefetchAsyncToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TPrefetchAsyncOp op, 
OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src = peelUnrealized(adaptor.getSrc()); + Value srcArg = src; + if (!isEmitCGlobalTensorLikeType(srcArg.getType())) { + auto srcMrTy = dyn_cast(op.getSrc().getType()); + if (!srcMrTy) + return rewriter.notifyMatchFailure( + op, "expected src to lower to GlobalTensor or memref"); + srcArg = buildGlobalTensorFromMemref(rewriter, op.getLoc(), src, srcMrTy, + op.getSrc().getDefiningOp() + ? op.getSrc().getDefiningOp() + : op.getOperation()); + } + if (!srcArg) + return rewriter.notifyMatchFailure(op, + "failed to build GlobalTensor src"); + + Value prefetchCtx = peelUnrealized(adaptor.getCtx()); + + Type eventTy = getTypeConverter()->convertType(op.getEvent().getType()); + if (!eventTy) + return rewriter.notifyMatchFailure( + op, "failed to convert tprefetch_async result type"); + + Value event = rewriter + .create( + op.getLoc(), TypeRange{eventTy}, "TPREFETCH_ASYNC", + ArrayAttr{}, ArrayAttr{}, + ValueRange{srcArg, prefetchCtx}) + .getResult(0); + + rewriter.replaceOp(op, ValueRange{event}); + return success(); + } +}; + +struct PTOMakePrefetchAsyncContextToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::MakePrefetchAsyncContextOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type ctxTy = getTypeConverter()->convertType(op.getCtx().getType()); + if (!ctxTy) + return rewriter.notifyMatchFailure( + op, "failed to convert make_prefetch_async_context result type"); + + Value workspace = peelUnrealized(adaptor.getWorkspace()); + workspace = castToGMBytePointer(rewriter, op.getLoc(), workspace); + + Value ctx = rewriter + .create( + op.getLoc(), TypeRange{ctxTy}, "pto::PrefetchAsyncContext", + ArrayAttr{}, ArrayAttr{}, ValueRange{workspace}) + .getResult(0); + + rewriter.replaceOp(op, ValueRange{ctx}); + return success(); + } +}; + +struct PTOGetPrefetchAsyncSessionToEmitC + : 
public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::GetPrefetchAsyncSessionOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type sessionTy = getTypeConverter()->convertType(op.getSession().getType()); + if (!sessionTy) + return rewriter.notifyMatchFailure( + op, "failed to convert get_prefetch_async_session result type"); + + Value ctx = peelUnrealized(adaptor.getCtx()); + Value session = rewriter + .create( + op.getLoc(), TypeRange{sessionTy}, + "PTOAS__PREFETCH_CTX_SESSION", ArrayAttr{}, + ArrayAttr{}, ValueRange{ctx}) + .getResult(0); + + rewriter.replaceOp(op, ValueRange{session}); + return success(); + } +}; + +struct PTOTStoreToTSTORE : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + static std::string stPhaseTok(pto::STPhase phase) { + switch (phase) { + case pto::STPhase::Unspecified: return "STPhase::Unspecified"; + case pto::STPhase::Partial: return "STPhase::Partial"; + case pto::STPhase::Final: return "STPhase::Final"; + } + return "STPhase::Unspecified"; + } + + static std::string atomicTypeTok(pto::AtomicType atomicType) { + switch (atomicType) { + case pto::AtomicType::AtomicNone: return "AtomicType::AtomicNone"; + case pto::AtomicType::AtomicAdd: return "AtomicType::AtomicAdd"; + } + return "AtomicType::AtomicNone"; + } + + static std::string reluPreModeTok(pto::ReluPreMode reluPreMode) { + switch (reluPreMode) { + case pto::ReluPreMode::NoRelu: return "ReluPreMode::NoRelu"; + case pto::ReluPreMode::NormalRelu: return "ReluPreMode::NormalRelu"; + } + return "ReluPreMode::NoRelu"; + } + + LogicalResult matchAndRewrite(pto::TStoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!op.getDst()) + return rewriter.notifyMatchFailure(op, "expected outs(dst) on pto.tstore"); + + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + Value src = 
peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + Value preQuantScalar; + if (op.getPreQuantScalar()) + preQuantScalar = peelUnrealized(adaptor.getPreQuantScalar()); + Value dstArg = dst; + if (auto dstMrTy = dyn_cast(op.getDst().getType())) { + bool isGlobal = true; + if (auto asAttr = dyn_cast_or_null(dstMrTy.getMemorySpace())) { + auto as = asAttr.getAddressSpace(); + isGlobal = (as == pto::AddressSpace::GM || as == pto::AddressSpace::Zero); + } + if (isGlobal) { + if (Value gt = buildGlobalTensorFromMemref(rewriter, op.getLoc(), dst, dstMrTy, + op.getOperation())) + dstArg = gt; + } + } + + const auto phase = op.getStPhase(); + const auto atomicType = op.getAtomicType(); + const auto reluPreMode = op.getReluPreMode(); + const bool hasPreQuantScalar = static_cast(preQuantScalar); + const bool phaseNonDefault = phase != pto::STPhase::Unspecified; + const bool atomicNonDefault = atomicType != pto::AtomicType::AtomicNone; + const bool reluNonDefault = reluPreMode != pto::ReluPreMode::NoRelu; + + auto getOpaqueTok = [&](Value v, StringRef name) -> FailureOr { + if (auto ot = mlir::dyn_cast(v.getType())) + return ot.getValue().str(); + return rewriter.notifyMatchFailure(op, (name + " must be emitc::OpaqueType").str()); + }; + + ArrayAttr targs; + // Map op attributes/operands to the exact TSTORE overload family: + // 1) TSTORE(dst, src) + // 2) TSTORE(dst, src) + // 3) TSTORE(dst, src) + // 4) TSTORE(dst, src) + // 5) TSTORE(dst, src) + // 6) TSTORE(dst, src) + // 7) TSTORE(dst, src, preQuant) + // 8) TSTORE(dst, src, preQuant) + if (!hasPreQuantScalar && !reluNonDefault && !atomicNonDefault) { + if (phaseNonDefault) { + targs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, stPhaseTok(phase)), + }); + } else { + targs = ArrayAttr{}; + } + } else { + auto srcTokOr = getOpaqueTok(src, "src"); + auto dstTokOr = getOpaqueTok(dstArg, "dst"); + if (failed(srcTokOr) || failed(dstTokOr)) + return failure(); + + // If there is no 
preQuant and relu stays default, emit the atomic-only + // overloads (#3/#4) without ReluPreMode template argument. + if (!hasPreQuantScalar && !reluNonDefault) { + if (phaseNonDefault) { + targs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, stPhaseTok(phase)), + emitc::OpaqueAttr::get(ctx, *srcTokOr), + emitc::OpaqueAttr::get(ctx, *dstTokOr), + emitc::OpaqueAttr::get(ctx, atomicTypeTok(atomicType)), + }); + } else { + targs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, *srcTokOr), + emitc::OpaqueAttr::get(ctx, *dstTokOr), + emitc::OpaqueAttr::get(ctx, atomicTypeTok(atomicType)), + }); + } + } else { + // Relu/preQuant families (#5/#6/#7/#8): keep AtomicType + ReluPreMode. + if (phaseNonDefault) { + targs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, stPhaseTok(phase)), + emitc::OpaqueAttr::get(ctx, *srcTokOr), + emitc::OpaqueAttr::get(ctx, *dstTokOr), + emitc::OpaqueAttr::get(ctx, atomicTypeTok(atomicType)), + emitc::OpaqueAttr::get(ctx, reluPreModeTok(reluPreMode)), + }); + } else { + targs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, *srcTokOr), + emitc::OpaqueAttr::get(ctx, *dstTokOr), + emitc::OpaqueAttr::get(ctx, atomicTypeTok(atomicType)), + emitc::OpaqueAttr::get(ctx, reluPreModeTok(reluPreMode)), + }); + } + } + } + + SmallVector operands{dstArg, src}; + if (hasPreQuantScalar) + operands.push_back(preQuantScalar); + + rewriter.create( + loc, TypeRange{}, "TSTORE", + /*args=*/ArrayAttr{}, /*templateArgs=*/targs, + /*operands=*/operands); + + if (op->getNumResults() == 1) { + rewriter.replaceOp(op, dst); + } else { + rewriter.eraseOp(op); + } + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// pto.matmul_dps lowering (Simplified: No internal copy/sync) +//===----------------------------------------------------------------------===// +// +// Render `pto.tmatmul` as one of three forms depending on the optional +// `acc_phase` attribute: +// * absent / 
Unspecified -> `TMATMUL(dst, lhs, rhs)` +// * Partial -> `TMATMUL(dst, lhs, rhs)` +// * Final -> `TMATMUL(dst, lhs, rhs)` +// The Unspecified default keeps backward compatibility with all upstream IR +// that does not yet emit an explicit phase attribute. +static ArrayAttr buildAccPhaseTemplateArgs(ConversionPatternRewriter &rewriter, + pto::AccPhase phase) { + StringRef tmpl; + switch (phase) { + case pto::AccPhase::Unspecified: + return ArrayAttr{}; + case pto::AccPhase::Partial: + tmpl = "AccPhase::Partial"; + break; + case pto::AccPhase::Final: + tmpl = "AccPhase::Final"; + break; + } + if (tmpl.empty()) + return ArrayAttr{}; + return rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(rewriter.getContext(), tmpl)}); +} + +struct PTOTMatmulToTMATMUL : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMatmulOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // 1. 获取操作数 (剥离 Cast) + Value lhs = peelUnrealized(adaptor.getLhs()); // A (Left) + Value rhs = peelUnrealized(adaptor.getRhs()); // B (Right) + Value dst = peelUnrealized(adaptor.getDst()); // C (Acc) + + // 2. 根据 acc_phase 属性决定是否生成 TMATMUL(...) + ArrayAttr templateArgs = + buildAccPhaseTemplateArgs(rewriter, op.getAccPhase()); + + rewriter.create( + op.getLoc(), TypeRange{}, "TMATMUL", + /*args=*/ArrayAttr{}, /*template_args=*/templateArgs, + ValueRange{dst, lhs, rhs}); + + // 3. 
处理 Op 替换/删除 + if (op->getNumResults() == 1) { + rewriter.replaceOp(op, dst); + } else { + rewriter.eraseOp(op); + } + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// pto.tgemv lowering +//===----------------------------------------------------------------------===// +struct PTOTGemvToTGEMV : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TGemvOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // 1. 获取操作数 (剥离 Cast) + Value lhs = peelUnrealized(adaptor.getLhs()); // A (Matrix) + Value rhs = peelUnrealized(adaptor.getRhs()); // B (Vector) + Value dst = peelUnrealized(adaptor.getDst()); // C (Result) + + // 2. 直接生成函数调用 TGEMV(dst, lhs, rhs) + rewriter.create( + op.getLoc(), TypeRange{}, "TGEMV", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, lhs, rhs}); + + // 3. 处理 Op 替换/删除 + if (op->getNumResults() == 1) { + rewriter.replaceOp(op, dst); + } else { + rewriter.eraseOp(op); + } + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// pto.tgemv.acc lowering +//===----------------------------------------------------------------------===// +struct PTOTGemvAccToTGEMVACC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TGemvAccOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!op.getDst()) + return rewriter.notifyMatchFailure(op, "expected outs(dst) for pto.tgemv.acc"); + + // 1. 获取操作数 + Value accIn = peelUnrealized(adaptor.getAccIn()); // AccOld + Value lhs = peelUnrealized(adaptor.getLhs()); // A (Matrix) + Value rhs = peelUnrealized(adaptor.getRhs()); // B (Vector) + Value dst = peelUnrealized(adaptor.getDst()); // AccNew + + // 2. 
直接生成函数调用 TGEMV_ACC(dst, accIn, lhs, rhs) + rewriter.create( + op.getLoc(), TypeRange{}, "TGEMV_ACC", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, accIn, lhs, rhs}); + + // 3. 处理 Op 替换/删除 + if (op->getNumResults() == 1) { + rewriter.replaceOp(op, dst); + } else { + rewriter.eraseOp(op); + } + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// pto.matmul_acc_dps lowering (Simplified: No internal copy/sync) +//===----------------------------------------------------------------------===// +struct PTOTMatmulAccToTMATMULACC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMatmulAccOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!op.getDst()) + return rewriter.notifyMatchFailure(op, "expected outs(dst) for pto.tmatmul.acc"); + + // 1. 获取操作数 + Value accIn = peelUnrealized(adaptor.getAccIn()); // AccOld + Value lhs = peelUnrealized(adaptor.getLhs()); // A (Left) + Value rhs = peelUnrealized(adaptor.getRhs()); // B (Right) + Value dst = peelUnrealized(adaptor.getDst()); // AccNew + + // 2. 根据 acc_phase 属性决定是否生成 TMATMUL_ACC(...) + ArrayAttr templateArgs = + buildAccPhaseTemplateArgs(rewriter, op.getAccPhase()); + + rewriter.create( + op.getLoc(), TypeRange{}, "TMATMUL_ACC", + /*args=*/ArrayAttr{}, /*template_args=*/templateArgs, + ValueRange{dst, accIn, lhs, rhs}); + + // 3. 
处理 Op 替换/删除 + if (op->getNumResults() == 1) { + rewriter.replaceOp(op, dst); + } else { + rewriter.eraseOp(op); + } + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Return lowering +//===----------------------------------------------------------------------=== + +static constexpr llvm::StringLiteral kAutoSyncTailPendingModeAttr = + "__pto.auto_sync_tail_mode"; + +struct ReturnToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (auto emitcFunc = op->getParentOfType()) { + if (auto modeAttr = + emitcFunc->getAttrOfType(kAutoSyncTailPendingModeAttr)) { + auto *ctx = rewriter.getContext(); + rewriter.setInsertionPoint(op); + auto args = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, modeAttr.getValue())}); + rewriter.create( + op.getLoc(), TypeRange{}, "ptoas_auto_sync_tail", + args, ArrayAttr{}, ValueRange{}); + } + } + + auto vals = adaptor.getOperands(); + if (vals.empty()) { + rewriter.replaceOpWithNewOp(op, Value{}); + return success(); + } + if (vals.size() == 1) { + rewriter.replaceOpWithNewOp(op, vals[0]); + return success(); + } + return rewriter.notifyMatchFailure(op, "EmitC cannot return multiple values"); + } +}; + +struct CallToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(func::CallOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (op.getNumResults() > 1) + return rewriter.notifyMatchFailure( + op, "EmitC cannot lower calls with multiple results"); + + SmallVector resultTypes; + if (failed( + getTypeConverter()->convertTypes(op.getResultTypes(), resultTypes))) + return rewriter.notifyMatchFailure(op, + "failed to convert call result types"); + + rewriter.replaceOpWithNewOp(op, op.getCalleeAttr(), + 
resultTypes, + adaptor.getOperands()); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Sync lowering +//===----------------------------------------------------------------------=== + +static constexpr llvm::StringLiteral kAutoSyncTailBarrierAttr = + "pto.auto_sync_tail_barrier"; +static constexpr llvm::StringLiteral kAutoSyncTailHintAttr = + "pto.auto_sync_tail_hint"; +static constexpr llvm::StringLiteral kAutoSyncTailPolicyBarrierAll = + "barrier_all"; +static constexpr llvm::StringLiteral kAutoSyncTailPolicyMte3ToSEvent0 = + "setwait_mte3_to_s_event0"; +static constexpr llvm::StringLiteral kAutoSyncTailModeBarrierAllToken = + "PTOAutoSyncTailMode::kBarrierAll"; +static constexpr llvm::StringLiteral kAutoSyncTailModeMte3ToSEvent0Token = + "PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0"; + +static std::string getAutoSyncTailModeToken(Operation *op) { + if (op) { + if (auto hintAttr = op->getAttrOfType(kAutoSyncTailHintAttr)) { + if (hintAttr.getValue() == kAutoSyncTailPolicyBarrierAll) + return kAutoSyncTailModeBarrierAllToken.str(); + if (hintAttr.getValue() == kAutoSyncTailPolicyMte3ToSEvent0) + return kAutoSyncTailModeMte3ToSEvent0Token.str(); + } + } + + auto func = op ? op->getParentOfType() : func::FuncOp(); + if (!func) + return kAutoSyncTailModeBarrierAllToken.str(); + + auto hintAttr = func->getAttrOfType(kAutoSyncTailHintAttr); + if (!hintAttr) + return kAutoSyncTailModeBarrierAllToken.str(); + + if (hintAttr.getValue() == kAutoSyncTailPolicyBarrierAll) + return kAutoSyncTailModeBarrierAllToken.str(); + if (hintAttr.getValue() == kAutoSyncTailPolicyMte3ToSEvent0) + return kAutoSyncTailModeMte3ToSEvent0Token.str(); + + // Fallback to the conservative behavior when seeing unknown policies. 
+ return kAutoSyncTailModeBarrierAllToken.str(); +} + +[[maybe_unused]] static std::string getPipeName(pto::PIPE pipe) { + switch (pipe) { + case pto::PIPE::PIPE_S: return "PIPE_S"; + case pto::PIPE::PIPE_V: return "PIPE_V"; + case pto::PIPE::PIPE_M: return "PIPE_M"; + case pto::PIPE::PIPE_MTE1: return "PIPE_MTE1"; + case pto::PIPE::PIPE_MTE2: return "PIPE_MTE2"; + case pto::PIPE::PIPE_MTE3: return "PIPE_MTE3"; + case pto::PIPE::PIPE_ALL: return "PIPE_ALL"; + case pto::PIPE::PIPE_MTE4: return "PIPE_MTE4"; + case pto::PIPE::PIPE_MTE5: return "PIPE_MTE5"; + case pto::PIPE::PIPE_V2: return "PIPE_V2"; + case pto::PIPE::PIPE_FIX: return "PIPE_FIX"; + case pto::PIPE::VIRTUAL_PIPE_MTE2_L1A: return "VIRTUAL_PIPE_MTE2_L1A"; + case pto::PIPE::VIRTUAL_PIPE_MTE2_L1B: return "VIRTUAL_PIPE_MTE2_L1B"; + // 默认回退 + default: return "PIPE_ALL"; + } +} + +//===----------------------------------------------------------------------===// +// pto.barrier lowering -> pipe_barrier(...) +//===----------------------------------------------------------------------===// +struct PTOBarrierToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::BarrierOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (op->hasAttr(kAutoSyncTailBarrierAttr)) { + auto modeAttr = rewriter.getStringAttr(getAutoSyncTailModeToken(op)); + if (auto emitcFunc = op->getParentOfType()) { + emitcFunc->setAttr(kAutoSyncTailPendingModeAttr, modeAttr); + } else if (auto funcOp = op->getParentOfType()) { + funcOp->setAttr(kAutoSyncTailPendingModeAttr, modeAttr); + } + rewriter.eraseOp(op); + return success(); + } + + // [FIX] op.getPipe() returns PipeAttr. + // We must call .getPipe() on the attribute to get the actual Enum value. 
+ pto::PIPE pipeEnum = op.getPipe().getPipe(); + + // Convert Enum to String (e.g., PIPE_ALL -> "PIPE_ALL") + std::string pipeStr = pto::stringifyPIPE(pipeEnum).str(); + auto *ctx = rewriter.getContext(); + + auto args = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, pipeStr) + }); + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, // void return + "pipe_barrier", // function name + args, // arguments + ArrayAttr{}, // template args + ValueRange{} // operands + ); + + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Sync lowering (robust for bracket form pto.set_flag[...] / pto.wait_flag[...]) +// Replace your PTOSyncToRuntimeCall with the code below. +//===----------------------------------------------------------------------===// + +static bool tryConvertPipeAttrToToken(Attribute attr, std::string &token) { + if (!attr) + return false; + if (auto pipe = dyn_cast(attr)) { + token = mlir::pto::stringifyPIPE(pipe.getPipe()).str(); + return true; + } + if (auto stringAttr = dyn_cast(attr)) { + token = stringAttr.getValue().str(); + return true; + } + return false; +} + +static bool tryConvertEventAttrToToken(Attribute attr, std::string &token) { + if (!attr) + return false; + if (auto event = dyn_cast(attr)) { + token = mlir::pto::stringifyEVENT(event.getEvent()).str(); + return true; + } + if (auto stringAttr = dyn_cast(attr)) { + token = stringAttr.getValue().str(); + return true; + } + return false; +} + +static bool tryAssignSyncTokens(Attribute srcAttr, Attribute dstAttr, + Attribute evtAttr, std::string &srcTok, + std::string &dstTok, std::string &evtTok) { + std::string localSrc; + std::string localDst; + std::string localEvt; + if (!tryConvertPipeAttrToToken(srcAttr, localSrc) || + !tryConvertPipeAttrToToken(dstAttr, localDst) || + !tryConvertEventAttrToToken(evtAttr, localEvt)) { + return false; + } + srcTok = std::move(localSrc); + dstTok = std::move(localDst); + evtTok = 
std::move(localEvt); + return true; +} + +static bool tryExtractSyncTokensFromNamedAttrs(Operation *op, + StringRef srcName, + StringRef dstName, + StringRef evtName, + std::string &srcTok, + std::string &dstTok, + std::string &evtTok) { + return tryAssignSyncTokens(op->getAttr(srcName), op->getAttr(dstName), + op->getAttr(evtName), srcTok, dstTok, evtTok); +} + +static bool tryExtractSyncTokensFromArrayAttr(Operation *op, StringRef attrName, + std::string &srcTok, + std::string &dstTok, + std::string &evtTok) { + auto arrayAttr = op->getAttrOfType(attrName); + if (!arrayAttr || arrayAttr.size() < 3) + return false; + return tryAssignSyncTokens(arrayAttr[0], arrayAttr[1], arrayAttr[2], srcTok, + dstTok, evtTok); +} + +static bool tryExtractFallbackSyncTokens(Operation *op, std::string &srcTok, + std::string &dstTok, + std::string &evtTok) { + SmallVector pipes; + std::string event; + for (NamedAttribute namedAttr : op->getAttrs()) { + std::string token; + if (tryConvertPipeAttrToToken(namedAttr.getValue(), token)) { + pipes.push_back(std::move(token)); + continue; + } + if (event.empty() && + tryConvertEventAttrToToken(namedAttr.getValue(), token)) { + event = std::move(token); + } + } + if (pipes.size() < 2 || event.empty()) + return false; + srcTok = pipes[0]; + dstTok = pipes[1]; + evtTok = event; + return true; +} + +static LogicalResult extractSyncTripletTokens(Operation *op, + std::string &srcTok, + std::string &dstTok, + std::string &evtTok, + ConversionPatternRewriter &rewriter) { + if (tryExtractSyncTokensFromNamedAttrs(op, "src_pipe", "dst_pipe", "event_id", + srcTok, dstTok, evtTok) || + tryExtractSyncTokensFromNamedAttrs(op, "srcPipe", "dstPipe", "eventId", + srcTok, dstTok, evtTok) || + tryExtractSyncTokensFromNamedAttrs(op, "src", "dst", "event", srcTok, + dstTok, evtTok)) { + return success(); + } + + for (StringRef attrName : {"args", "pipes", "sync", "triplet", "attrs"}) { + if (tryExtractSyncTokensFromArrayAttr(op, attrName, srcTok, dstTok, + 
evtTok)) { + return success(); + } + } + + if (tryExtractFallbackSyncTokens(op, srcTok, dstTok, evtTok)) + return success(); + return rewriter.notifyMatchFailure( + op, "cannot extract PIPE/PIPE/EVENT tokens from pto.{set,wait}_flag"); +} +static inline std::string pipeTokFromPipeEnum(mlir::pto::PIPE p) { + return mlir::pto::stringifyPIPE(p).str(); +} +[[maybe_unused]] static inline std::string evtTokFromEventEnum(mlir::pto::EVENT e) { + return mlir::pto::stringifyEVENT(e).str(); +} +static inline std::string pipeTokFromPipeAttr(mlir::pto::PipeAttr a) { + return mlir::pto::stringifyPIPE(a.getPipe()).str(); +} +static inline std::string evtTokFromEventAttr(mlir::pto::EventAttr a) { + return mlir::pto::stringifyEVENT(a.getEvent()).str(); +} + +template +struct HasGetSrcPipe : std::false_type {}; +template +struct HasGetSrcPipe().getSrcPipe())>> : std::true_type {}; + +template +struct HasGetDstPipe : std::false_type {}; +template +struct HasGetDstPipe().getDstPipe())>> : std::true_type {}; + +template +struct HasGetEventId : std::false_type {}; +template +struct HasGetEventId().getEventId())>> : std::true_type {}; + +template +struct HasGetSrcPipeAttr : std::false_type {}; +template +struct HasGetSrcPipeAttr().getSrcPipeAttr())>> : std::true_type {}; + +template +struct HasGetDstPipeAttr : std::false_type {}; +template +struct HasGetDstPipeAttr().getDstPipeAttr())>> : std::true_type {}; + +template +struct HasGetEventIdAttr : std::false_type {}; +template +struct HasGetEventIdAttr().getEventIdAttr())>> : std::true_type {}; + +template +static LogicalResult extractSyncTokens(SyncOpT op, + std::string &srcTok, + std::string &dstTok, + std::string &evtTok, + ConversionPatternRewriter &rewriter) { + if constexpr (HasGetSrcPipe::value && + HasGetDstPipe::value && + HasGetEventId::value) { + auto s = op.getSrcPipe(); + auto d = op.getDstPipe(); + auto e = op.getEventId(); + + if constexpr (std::is_same::value) srcTok = pipeTokFromPipeEnum(s); + else srcTok = 
pipeTokFromPipeAttr(s); + + if constexpr (std::is_same::value) dstTok = pipeTokFromPipeEnum(d); + else dstTok = pipeTokFromPipeAttr(d); + + if constexpr (std::is_same::value) evtTok = evtTokFromEventEnum(e); + else evtTok = evtTokFromEventAttr(e); + + return success(); + } + + if constexpr (HasGetSrcPipeAttr::value && + HasGetDstPipeAttr::value && + HasGetEventIdAttr::value) { + auto s = op.getSrcPipeAttr(); + auto d = op.getDstPipeAttr(); + auto e = op.getEventIdAttr(); + srcTok = pipeTokFromPipeAttr(s); + dstTok = pipeTokFromPipeAttr(d); + evtTok = evtTokFromEventAttr(e); + return success(); + } + + return extractSyncTripletTokens(op.getOperation(), srcTok, dstTok, evtTok, rewriter); +} +struct PTOSetFlagToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::SetFlagOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + auto *ctx = rewriter.getContext(); + + std::string srcTok, dstTok, evtTok; + if (failed(extractSyncTokens(op, srcTok, dstTok, evtTok, rewriter))) + return failure(); + + auto argsAttr = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, srcTok), + emitc::OpaqueAttr::get(ctx, dstTok), + emitc::OpaqueAttr::get(ctx, evtTok), + }); + + rewriter.replaceOpWithNewOp( + op, TypeRange{}, "set_flag", + /*args=*/argsAttr, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{}); + return success(); + } +}; + +struct PTOWaitFlagToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::WaitFlagOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + auto *ctx = rewriter.getContext(); + + std::string srcTok, dstTok, evtTok; + if (failed(extractSyncTokens(op, srcTok, dstTok, evtTok, rewriter))) + return failure(); + + auto argsAttr = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, srcTok), + 
emitc::OpaqueAttr::get(ctx, dstTok), + emitc::OpaqueAttr::get(ctx, evtTok), + }); + + rewriter.replaceOpWithNewOp( + op, TypeRange{}, "wait_flag", + /*args=*/argsAttr, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{}); + return success(); + } +}; + +struct PTOSyncToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::TSyncOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector operands; + operands.reserve(adaptor.getEvents().size()); + for (Value event : adaptor.getEvents()) + operands.push_back(peelUnrealized(event)); + + rewriter.create( + op.getLoc(), TypeRange{}, "TSYNC", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange(operands)); + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOSyncAllToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + static StringRef coreTypeTok(pto::SyncCoreType coreType) { + switch (coreType) { + case pto::SyncCoreType::AIVOnly: + return "SyncCoreType::AIVOnly"; + case pto::SyncCoreType::AICOnly: + return "SyncCoreType::AICOnly"; + case pto::SyncCoreType::Mix: + return "SyncCoreType::Mix"; + } + llvm_unreachable("unhandled SyncCoreType"); + } + + LogicalResult matchAndRewrite(mlir::pto::SyncAllOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto mode = op.getMode().getValue(); + auto coreType = op.getCoreType().getValue(); + + auto buildGmWorkspace = [&]() -> FailureOr { + Value gm = peelUnrealized(adaptor.getGmWorkspace()); + if (isEmitCGlobalTensorLikeType(gm.getType())) + return gm; + + auto memTy = dyn_cast(op.getGmWorkspace().getType()); + if (!memTy) + return failure(); + + Value gt = buildGlobalTensorFromMemref(rewriter, op.getLoc(), gm, memTy, + op.getGmWorkspace().getDefiningOp() + ? 
op.getGmWorkspace().getDefiningOp() + : op.getOperation()); + if (!gt) + return failure(); + return gt; + }; + + if (mode == pto::SyncAllMode::Hard) { + std::string callee = "SYNCALL<" + coreTypeTok(coreType).str() + ">"; + rewriter.create(op.getLoc(), TypeRange{}, callee, + ArrayAttr{}, ArrayAttr{}, + ValueRange{}); + rewriter.eraseOp(op); + return success(); + } + + FailureOr gmWorkspace = buildGmWorkspace(); + if (failed(gmWorkspace)) + return rewriter.notifyMatchFailure(op, + "failed to build gm_workspace GlobalTensor"); + + auto i32Ty = emitc::OpaqueType::get(rewriter.getContext(), "int32_t"); + Value usedCores = adaptor.getUsedCores() + ? peelUnrealized(adaptor.getUsedCores()) + : makeEmitCIntConstant(rewriter, op.getLoc(), i32Ty, 0); + if (usedCores.getType() != i32Ty) + usedCores = rewriter.create(op.getLoc(), i32Ty, usedCores) + .getResult(); + + std::string callee = + "SYNCALL"; + + SmallVector operands{*gmWorkspace}; + switch (coreType) { + case pto::SyncCoreType::AIVOnly: { + FailureOr ubWorkspace = + buildSyncAllWorkspaceTileValue(rewriter, op.getLoc(), + op.getUbWorkspace(), + adaptor.getUbWorkspace()); + if (failed(ubWorkspace)) + return rewriter.notifyMatchFailure( + op, "failed to materialize ub_workspace tile"); + operands.push_back(*ubWorkspace); + break; + } + case pto::SyncCoreType::AICOnly: { + FailureOr l1Workspace = + buildSyncAllWorkspaceTileValue(rewriter, op.getLoc(), + op.getL1Workspace(), + adaptor.getL1Workspace()); + if (failed(l1Workspace)) + return rewriter.notifyMatchFailure( + op, "failed to materialize l1_workspace tile"); + operands.push_back(*l1Workspace); + break; + } + case pto::SyncCoreType::Mix: { + FailureOr ubWorkspace = + buildSyncAllWorkspaceTileValue(rewriter, op.getLoc(), + op.getUbWorkspace(), + adaptor.getUbWorkspace()); + FailureOr l1Workspace = + buildSyncAllWorkspaceTileValue(rewriter, op.getLoc(), + op.getL1Workspace(), + adaptor.getL1Workspace()); + if (failed(ubWorkspace) || failed(l1Workspace)) + return 
rewriter.notifyMatchFailure( + op, "failed to materialize mixed syncall workspace tiles"); + operands.push_back(*ubWorkspace); + operands.push_back(*l1Workspace); + break; + } + } + + operands.push_back(usedCores); + rewriter.create(op.getLoc(), TypeRange{}, callee, + ArrayAttr{}, ArrayAttr{}, + ValueRange(operands)); + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOSyncFlagDynToEmitC : public ConversionPattern { + PTOSyncFlagDynToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + StringRef opName, StringRef callee) + : ConversionPattern(typeConverter, opName, /*benefit=*/1, ctx), + callee(callee.str()) {} + + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + if (operands.size() != 1) + return rewriter.notifyMatchFailure(op, "expected exactly one dynamic event-id operand"); + + auto srcAttr = op->getAttrOfType("src_pipe"); + auto dstAttr = op->getAttrOfType("dst_pipe"); + if (!srcAttr || !dstAttr) + return rewriter.notifyMatchFailure(op, "missing PipeAttr src_pipe/dst_pipe attrs"); + + auto *ctx = rewriter.getContext(); + std::string srcTok = pipeTokFromPipeAttr(srcAttr); + std::string dstTok = pipeTokFromPipeAttr(dstAttr); + + Value eventVal = operands.front(); + eventVal = + emitCCast(rewriter, op->getLoc(), emitc::OpaqueType::get(ctx, "event_t"), eventVal); + + auto argsAttr = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, srcTok), + emitc::OpaqueAttr::get(ctx, dstTok), + IntegerAttr::get(IndexType::get(ctx), 0), + }); + + rewriter.replaceOpWithNewOp( + op, TypeRange{}, callee, + /*args=*/argsAttr, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{eventVal}); + return success(); + } + +private: + std::string callee; +}; + +struct PTOGetBufToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::GetBufOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { 
+ (void)adaptor; + auto *ctx = rewriter.getContext(); + + auto opTypeOr = parseSyncOpTypeLikeAttr(op.getOpTypeAttr()); + if (failed(opTypeOr)) + return rewriter.notifyMatchFailure(op, "get_buf expects pipe_event_type/sync_op_type attr"); + auto pipe = mapSyncOpTypeToPipe(*opTypeOr); + if (!isConcreteSyncPipe(pipe)) + return rewriter.notifyMatchFailure(op, "get_buf op_type cannot map to a concrete pipe"); + std::string pipeTok = pipeTokFromPipeEnum(pipe); + auto argsAttr = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, pipeTok), + op.getBufIdAttr(), + op.getModeAttr(), + }); + + rewriter.replaceOpWithNewOp( + op, TypeRange{}, "get_buf", + /*args=*/argsAttr, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{}); + return success(); + } +}; + +struct PTORlsBufToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::RlsBufOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + auto *ctx = rewriter.getContext(); + + auto opTypeOr = parseSyncOpTypeLikeAttr(op.getOpTypeAttr()); + if (failed(opTypeOr)) + return rewriter.notifyMatchFailure(op, "rls_buf expects pipe_event_type/sync_op_type attr"); + auto pipe = mapSyncOpTypeToPipe(*opTypeOr); + if (!isConcreteSyncPipe(pipe)) + return rewriter.notifyMatchFailure(op, "rls_buf op_type cannot map to a concrete pipe"); + std::string pipeTok = pipeTokFromPipeEnum(pipe); + auto argsAttr = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, pipeTok), + op.getBufIdAttr(), + op.getModeAttr(), + }); + + rewriter.replaceOpWithNewOp( + op, TypeRange{}, "rls_buf", + /*args=*/argsAttr, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{}); + return success(); + } +}; + +struct PTOSetFFTsToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::SetFFTsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const 
override { + auto *ctx = rewriter.getContext(); + auto loc = op.getLoc(); + + Value fftsAddr = peelUnrealized(adaptor.getFfts()); + auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); + + if (isSetFFTsPointerLikeType(fftsAddr.getType())) { + auto castTyAttr = + rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); + fftsAddr = + rewriter + .create(loc, u64Ty, "reinterpret_cast", + /*args=*/ArrayAttr{}, + /*templateArgs=*/castTyAttr, + /*operands=*/ValueRange{fftsAddr}) + .getResult(0); + } else if (fftsAddr.getType() != u64Ty) { + fftsAddr = + rewriter.create(loc, u64Ty, fftsAddr).getResult(); + } + + rewriter.replaceOpWithNewOp( + op, TypeRange{}, "set_ffts_base_addr", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{fftsAddr}); + return success(); + } +}; + +struct PTOSyncSetToEmitC : public OpConversionPattern { + PTOSyncSetToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + PTOArch targetArch) + : OpConversionPattern(typeConverter, ctx), + targetArch(targetArch) {} + + LogicalResult + matchAndRewrite(mlir::pto::SyncSetOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto *ctx = rewriter.getContext(); + IntegerAttr eventIdAttr = op.getEventIdAttr(); + Value eventIdDyn = adaptor.getEventIdDyn(); + int64_t fftsMode = 2; + if (IntegerAttr fftsModeAttr = op.getFftsModeAttr()) + fftsMode = fftsModeAttr.getInt(); + + if ((eventIdAttr != nullptr) == static_cast(eventIdDyn)) + return rewriter.notifyMatchFailure( + op, "expects exactly one of static event_id attr or dynamic event_id operand"); + + // A5 inter-core sync mirrors +16 only for cube-side producer (PIPE_FIX). + // Vec-side producer (PIPE_MTE3) emits a single set; hardware handles the + // subblock mapping in PTO-ISA custom flow. 
+ if (targetArch == PTOArch::A5) { + pto::PIPE pipe = op.getPipe().getPipe(); + bool needsMirrorPlus16 = (pipe == pto::PIPE::PIPE_FIX); + std::string pipeTok = pipeTokFromPipeAttr(op.getPipe()); + auto emitSet = [&](Value eventOperand, IntegerAttr eventLiteral, + bool isDynamic) { + if (isDynamic) { + auto args = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, pipeTok), + IntegerAttr::get(IndexType::get(ctx), 0), + }); + rewriter.create(loc, TypeRange{}, "set_intra_block", + /*args=*/args, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{eventOperand}); + return; + } + auto args = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, pipeTok), + eventLiteral, + }); + rewriter.create(loc, TypeRange{}, "set_intra_block", + /*args=*/args, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{}); + }; + + if (eventIdAttr) { + emitSet(Value{}, eventIdAttr, /*isDynamic=*/false); + if (needsMirrorPlus16) { + auto plus16 = IntegerAttr::get(eventIdAttr.getType(), + eventIdAttr.getInt() + 16); + emitSet(Value{}, plus16, /*isDynamic=*/false); + } + } else { + Value eventI32 = castInterCoreEventIdToI32(rewriter, loc, eventIdDyn); + emitSet(eventI32, IntegerAttr{}, /*isDynamic=*/true); + if (needsMirrorPlus16) { + auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); + Value c16 = makeEmitCIntConstant(rewriter, loc, i32Ty, 16); + Value eventI32Plus16 = + rewriter.create(loc, i32Ty, eventI32, c16).getResult(); + emitSet(eventI32Plus16, IntegerAttr{}, /*isDynamic=*/true); + } + } + + rewriter.eraseOp(op); + return success(); + } + + InterCoreSyncCallDesc desc; + if (eventIdAttr) { + desc = buildInterCoreSyncSetCall(rewriter, loc, targetArch, op.getPipe(), + eventIdAttr, fftsMode); + } else { + desc = buildInterCoreSyncSetCallDyn(rewriter, loc, targetArch, op.getPipe(), + eventIdDyn, fftsMode); + } + rewriter.create(loc, TypeRange{}, desc.callee, + /*args=*/desc.args, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/desc.operands); + + rewriter.eraseOp(op); + 
return success(); + } + + PTOArch targetArch; +}; + +struct PTOSyncWaitToEmitC : public OpConversionPattern { + PTOSyncWaitToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + PTOArch targetArch) + : OpConversionPattern(typeConverter, ctx), + targetArch(targetArch) {} + + LogicalResult + matchAndRewrite(mlir::pto::SyncWaitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + IntegerAttr eventIdAttr = op.getEventIdAttr(); + Value eventIdDyn = adaptor.getEventIdDyn(); + + if ((eventIdAttr != nullptr) == static_cast(eventIdDyn)) + return rewriter.notifyMatchFailure( + op, "expects exactly one of static event_id attr or dynamic event_id operand"); + + InterCoreSyncCallDesc desc; + if (eventIdAttr) { + desc = buildInterCoreSyncWaitCall(rewriter, targetArch, op.getPipe(), + eventIdAttr); + } else { + desc = buildInterCoreSyncWaitCallDyn(rewriter, loc, targetArch, op.getPipe(), + eventIdDyn); + } + rewriter.create(loc, TypeRange{}, desc.callee, + desc.args, ArrayAttr{}, desc.operands); + + rewriter.eraseOp(op); + return success(); + } + + PTOArch targetArch; +}; + +// GetBlockIdxOp Lowering (pto.get_block_idx -> get_block_idx()) +struct PTOGetBlockIdxToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(mlir::pto::GetBlockIdxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + rewriter.replaceOpWithNewOp( + op, op.getType(), "get_block_idx", ValueRange{}, ArrayAttr{}, + ArrayAttr{}); + + return success(); + } +}; + +// GetBlockNumOp Lowering (pto.get_block_num -> get_block_num()) +struct PTOGetBlockNumToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(mlir::pto::GetBlockNumOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + rewriter.replaceOpWithNewOp( + op, op.getType(), "get_block_num", 
ValueRange{}, ArrayAttr{}, + ArrayAttr{}); + + return success(); + } +}; + +// GetSubBlockIdxOp Lowering (pto.get_block_idx -> get_subblockid()) +struct PTOGetSubBlockIdxToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(mlir::pto::GetSubBlockIdxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + rewriter.replaceOpWithNewOp( + op, op.getType(), "get_subblockid", ValueRange{}, ArrayAttr{}, + ArrayAttr{}); + + return success(); + } +}; + +// GetSubBlockNumOp Lowering. +struct PTOGetSubBlockNumToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(mlir::pto::GetSubBlockNumOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + rewriter.replaceOpWithNewOp( + op, op.getType(), "get_subblockdim", ValueRange{}, ArrayAttr{}, + ArrayAttr{}); + + return success(); + } +}; + + +struct PTOMScatterToMSCATTER : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::MScatterOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto *ctx = rewriter.getContext(); + Value src = peelUnrealized(adaptor.getSrc()); + Value idx = peelUnrealized(adaptor.getIdx()); + Value mem = peelUnrealized(adaptor.getMem()); + + Value memArg = maybeWrapGlobalMemrefAsGlobalTensor( + rewriter, op.getLoc(), mem, op.getMem().getType(), op.getOperation()); + + auto scatterAtomicTok = [&](pto::ScatterAtomicOp atomic) -> StringRef { + switch (atomic) { + case pto::ScatterAtomicOp::None: + return "pto::ScatterAtomicOp::None"; + case pto::ScatterAtomicOp::Add: + return "pto::ScatterAtomicOp::Add"; + case pto::ScatterAtomicOp::Max: + return "pto::ScatterAtomicOp::Max"; + case pto::ScatterAtomicOp::Min: + return "pto::ScatterAtomicOp::Min"; + } + llvm_unreachable("unknown ScatterAtomicOp"); + }; + auto 
scatterOobTok = [&](pto::ScatterOOB mode) -> StringRef { + switch (mode) { + case pto::ScatterOOB::Undefined: + return "pto::ScatterOOB::Undefined"; + case pto::ScatterOOB::Skip: + return "pto::ScatterOOB::Skip"; + case pto::ScatterOOB::Clamp: + return "pto::ScatterOOB::Clamp"; + case pto::ScatterOOB::Wrap: + return "pto::ScatterOOB::Wrap"; + } + llvm_unreachable("unknown ScatterOOB"); + }; + + SmallVector templateArgVec; + const bool rowCoalesce = + isRowCoalescedMGatherIndexType(op.getSrc().getType(), op.getIdx().getType()); + templateArgVec.push_back(emitc::OpaqueAttr::get( + ctx, rowCoalesce ? "pto::Coalesce::Row" : "pto::Coalesce::Elem")); + if (op.getScatterAtomicOp() != pto::ScatterAtomicOp::None || + op.getScatterOob() != pto::ScatterOOB::Undefined) { + templateArgVec.push_back(emitc::OpaqueAttr::get( + ctx, scatterAtomicTok(op.getScatterAtomicOp()))); + if (op.getScatterOob() != pto::ScatterOOB::Undefined) + templateArgVec.push_back( + emitc::OpaqueAttr::get(ctx, scatterOobTok(op.getScatterOob()))); + } + ArrayAttr templateArgs = rewriter.getArrayAttr(templateArgVec); + + rewriter.create( + op.getLoc(), TypeRange{}, "MSCATTER", + ArrayAttr{}, templateArgs, + ValueRange{memArg, src, idx}); + + rewriter.eraseOp(op); + return success(); + } +}; +struct PTOSetValToSETVAL : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TSetValOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value dst = peelUnrealized(adaptor.getDst()); + Value val = peelUnrealized(adaptor.getVal()); + + // ---- offset: SSA index operand ---- + Value offset = peelUnrealized(adaptor.getOffset()); + + // Emit a marker call and let the ptoas post-processing step lower it to + // the corresponding tile setter. 
+ rewriter.create( + op.getLoc(), TypeRange{}, "PTOAS__TILE_SET_VALUE", + ArrayAttr{}, ArrayAttr{}, ValueRange{dst, offset, val}); + + rewriter.eraseOp(op); + return success(); + } +}; +struct PTOGetValToGETVAL : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TGetValOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src = peelUnrealized(adaptor.getSrc()); + + // ---- offset: SSA index operand ---- + Value offset = peelUnrealized(adaptor.getOffset()); + + // Emit a marker call and let the ptoas post-processing step lower it to + // the corresponding tile getter. + Type dstTy = getTypeConverter()->convertType(op.getDst().getType()); + if (!dstTy) + return failure(); + auto call = rewriter.create( + op.getLoc(), + TypeRange{dstTy}, + "PTOAS__TILE_GET_VALUE", + ArrayAttr{}, ArrayAttr{}, + ValueRange{src, offset}); + + rewriter.replaceOp(op, call.getResults()); + return success(); + } +}; + +struct PTOTAxpyToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TAxpyOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + Value scalar = peelUnrealized(adaptor.getScalar()); + + rewriter.create( + loc, TypeRange{}, "TAXPY", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src, scalar}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOHistogramToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::THistogramOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value idx = 
peelUnrealized(adaptor.getIdx()); + Value dst = peelUnrealized(adaptor.getDst()); + + auto templateArgs = rewriter.getArrayAttr({emitc::OpaqueAttr::get( + ctx, op.getIsMSB() ? "HistByte::BYTE_1" : "HistByte::BYTE_0")}); + rewriter.create( + loc, TypeRange{}, "THISTOGRAM", + /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, + /*operands=*/ValueRange{dst, src, idx}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOGetScaleAddrToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TGetScaleAddrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TGET_SCALE_ADDR", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOSetValidShapeToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::SetValidShapeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto peelAllCasts = [](Value v) { + while (auto castOp = v.getDefiningOp()) + v = castOp.getOperand(0); + if (auto castOp = v.getDefiningOp()) + v = castOp.getOperand(); + return v; + }; + auto isTileLike = [](Value v) -> bool { + auto ot = dyn_cast(v.getType()); + if (!ot) + return false; + StringRef s = ot.getValue(); + return s.contains("Tile<") || s.contains("ConvTile<"); + }; + + Value src = peelAllCasts(peelUnrealized(adaptor.getSource())); + Value row = peelUnrealized(adaptor.getValidRow()); + Value col = peelUnrealized(adaptor.getValidCol()); + + if (!isTileLike(src)) + return rewriter.notifyMatchFailure( + op, "set_validshape source must lower to a tile-like value"); + + rewriter.create( + op.getLoc(), TypeRange{}, 
"PTOAS__TILE_SET_VALIDSHAPE", ArrayAttr{}, + ArrayAttr{}, ValueRange{src, row, col}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOGetValidShapeToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::GetValidShapeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto peelAllCasts = [](Value v) { + while (auto castOp = v.getDefiningOp()) + v = castOp.getOperand(0); + if (auto castOp = v.getDefiningOp()) + v = castOp.getOperand(); + return v; + }; + auto isTileLike = [](Value v) -> bool { + auto ot = dyn_cast(v.getType()); + if (!ot) + return false; + StringRef s = ot.getValue(); + return s.contains("Tile<") || s.contains("ConvTile<"); + }; + + Value src = peelAllCasts(peelUnrealized(adaptor.getSource())); + if (!isTileLike(src)) + return rewriter.notifyMatchFailure( + op, "get_validshape source must lower to a tile-like value"); + + auto resultTy = getTypeConverter()->convertType(rewriter.getIndexType()); + if (!resultTy) + return failure(); + + Value row = rewriter + .create( + op.getLoc(), resultTy, + "PTOAS__TILE_GET_VALID_ROW", ArrayAttr{}, + ArrayAttr{}, ValueRange{src}) + .getResult(0); + Value col = rewriter + .create( + op.getLoc(), resultTy, + "PTOAS__TILE_GET_VALID_COL", ArrayAttr{}, + ArrayAttr{}, ValueRange{src}) + .getResult(0); + rewriter.replaceOp(op, ValueRange{row, col}); + return success(); + } +}; + +struct PTOTAssignToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TAssignOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto peelAllCasts = [](Value v) { + while (auto castOp = v.getDefiningOp()) + v = castOp.getOperand(0); + if (auto castOp = v.getDefiningOp()) + v = castOp.getOperand(); + return v; + }; + auto isTileLike = [](Value v) -> bool { + auto ot = dyn_cast(v.getType()); + if (!ot) + return false; + 
StringRef s = ot.getValue(); + return s.contains("Tile<") || s.contains("ConvTile<"); + }; + + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value tile = peelAllCasts(peelUnrealized(adaptor.getTile())); + if (!isTileLike(tile)) + return rewriter.notifyMatchFailure( + op, "tassign tile must lower to a tile-like value"); + + Value addr = peelUnrealized(adaptor.getAddr()); + auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); + if (isa(addr.getType()) || + (isa(addr.getType()) && + cast(addr.getType()).getValue().ends_with("*"))) { + auto rcU64 = + rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); + addr = rewriter + .create(loc, u64Ty, "reinterpret_cast", + ArrayAttr{}, rcU64, + ValueRange{addr}) + .getResult(0); + } else if (addr.getType() != u64Ty) { + addr = rewriter.create(loc, u64Ty, addr).getResult(); + } + + rewriter.create(loc, TypeRange{}, "TASSIGN", + ArrayAttr{}, ArrayAttr{}, + ValueRange{tile, addr}); + rewriter.replaceOp(op, tile); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// pto.load_scalar / pto.store_scalar lowering -> ptr[offset] +//===----------------------------------------------------------------------===// + +static Type getPointerLikeElementType(Type type) { + if (auto ptrTy = dyn_cast(type)) + return ptrTy.getElementType(); + if (auto memTy = dyn_cast(type)) + return memTy.getElementType(); + return Type(); +} + +struct PTOPtrToIntToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::PtrToIntOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value ptr = peelUnrealized(adaptor.getPtr()); + Type dstTy = getTypeConverter()->convertType(op.getResult().getType()); + if (!dstTy) + return failure(); + + auto dstOpaque = dyn_cast(dstTy); + if (!dstOpaque) + return failure(); + + auto templateArgs = + 
rewriter.getArrayAttr({emitc::OpaqueAttr::get(rewriter.getContext(), + dstOpaque.getValue())}); + auto cast = rewriter.create( + op.getLoc(), dstTy, "reinterpret_cast", ArrayAttr{}, templateArgs, + ValueRange{ptr}); + rewriter.replaceOp(op, cast.getResult(0)); + return success(); + } +}; + +struct PTOIntToPtrToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::IntToPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value addr = peelUnrealized(adaptor.getAddr()); + Type dstTy = getTypeConverter()->convertType(op.getResult().getType()); + if (!dstTy) + return failure(); + + Type dstElemTy = getPointerLikeElementType(op.getResult().getType()); + if (!dstElemTy) + return failure(); + + std::string castType = + std::string("__gm__ ") + getEmitCScalarTypeToken(dstElemTy) + "*"; + auto templateArgs = + rewriter.getArrayAttr({emitc::OpaqueAttr::get(rewriter.getContext(), + castType)}); + auto cast = rewriter.create( + op.getLoc(), dstTy, "reinterpret_cast", ArrayAttr{}, templateArgs, + ValueRange{addr}); + rewriter.replaceOp(op, cast.getResult(0)); + return success(); + } +}; + +struct PTOLoadScalarToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::LoadScalarOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value ptr = peelUnrealized(adaptor.getPtr()); + Value offset = peelUnrealized(adaptor.getOffset()); + + Type dstTy = getTypeConverter()->convertType(op.getValue().getType()); + if (!dstTy) + return failure(); + + auto call = rewriter.create( + op.getLoc(), TypeRange{dstTy}, "PTOAS__PTR_LOAD", + ArrayAttr{}, ArrayAttr{}, ValueRange{ptr, offset}); + + rewriter.replaceOp(op, call.getResults()); + return success(); + } +}; + +struct PTOStoreScalarToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult 
matchAndRewrite(pto::StoreScalarOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value ptr = peelUnrealized(adaptor.getPtr()); + Value offset = peelUnrealized(adaptor.getOffset()); + Value val = peelUnrealized(adaptor.getValue()); + + rewriter.create( + op.getLoc(), TypeRange{}, "PTOAS__PTR_STORE", + ArrayAttr{}, ArrayAttr{}, ValueRange{ptr, offset, val}); + rewriter.create( + op.getLoc(), TypeRange{}, "PTOAS__SCALAR_GM_STORE_FLUSH", + ArrayAttr{}, ArrayAttr{}, ValueRange{ptr}); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// pto.tabs lowering -> TABS(dst, src) +//===----------------------------------------------------------------------===// + +struct PTOTAbsToTABS : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TAbsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + // intrinsic: TABS(dst, src) + rewriter.create( + op.getLoc(), TypeRange{}, "TABS", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.tadd lowering -> TADD(dst, src0, src1) +//===----------------------------------------------------------------------===// + +struct PTOTAddToTADD : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TAddOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + op.getLoc(), TypeRange{}, "TADD", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src0, 
src1}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOInitializeL2G2LPipeToEmitC + : public OpConversionPattern { + PTOInitializeL2G2LPipeToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + PTOArch targetArch) + : OpConversionPattern(typeConverter, ctx), + targetArch(targetArch) {} + + LogicalResult matchAndRewrite(mlir::pto::InitializeL2G2LPipeOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto tpipeTok = buildTPipeTokenFromInitOp(op.getOperation(), targetArch); + if (failed(tpipeTok)) + return rewriter.notifyMatchFailure(op, "failed to build TPipe token"); + + auto *ctx = rewriter.getContext(); + auto emitPipeTy = + cast(getTypeConverter()->convertType(op.getPipe().getType())); + + Value gmAddr = peelUnrealized(adaptor.getGmAddr()); + gmAddr = materializeTensorViewDataPointer( + rewriter, op.getLoc(), gmAddr, op.getGmAddr().getType()); + Value localAddr = + op.getLocalAddr() ? peelUnrealized(adaptor.getLocalAddr()) : Value(); + auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); + Value zero = makeEmitCIntConstant(rewriter, op.getLoc(), i32Ty, 0); + + Value c2vBuf = zero; + Value v2cBuf = zero; + if (op.getDirMask() == 1) + c2vBuf = localAddr ? localAddr : zero; + else if (op.getDirMask() == 2) + v2cBuf = localAddr ? 
localAddr : zero; + else if (op.getDirMask() == 3) { + if (localAddr) { + if (!op.getPeerLocalAddr()) + return rewriter.notifyMatchFailure( + op, "bidirectional l2g2l pipe requires peer local buffer"); + c2vBuf = localAddr; + v2cBuf = peelUnrealized(adaptor.getPeerLocalAddr()); + } + } else + return rewriter.notifyMatchFailure(op, "unsupported dir_mask"); + + rewriter.replaceOpWithNewOp( + op, TypeRange{emitPipeTy}, *tpipeTok, ArrayAttr{}, ArrayAttr{}, + ValueRange{gmAddr, c2vBuf, v2cBuf}); + return success(); + } + + PTOArch targetArch; +}; + +struct PTOInitializeL2LPipeToEmitC + : public OpConversionPattern { + PTOInitializeL2LPipeToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + PTOArch targetArch) + : OpConversionPattern(typeConverter, ctx), + targetArch(targetArch) {} + + LogicalResult matchAndRewrite(mlir::pto::InitializeL2LPipeOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto tpipeTok = buildTPipeTokenFromInitOp(op.getOperation(), targetArch); + if (failed(tpipeTok)) + return rewriter.notifyMatchFailure(op, "failed to build TPipe token"); + + auto *ctx = rewriter.getContext(); + auto emitPipeTy = + cast(getTypeConverter()->convertType(op.getPipe().getType())); + + auto gmPtrTy = + emitc::PointerType::get(emitc::OpaqueType::get(ctx, "__gm__ void")); + Value nullGm = + makeEmitCOpaqueConstant(rewriter, op.getLoc(), gmPtrTy, "nullptr"); + auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); + Value zero = makeEmitCIntConstant(rewriter, op.getLoc(), i32Ty, 0); + Value localAddr = peelUnrealized(adaptor.getLocalAddr()); + + Value c2vBuf = zero; + Value v2cBuf = zero; + if (op.getDirMask() == 1) + c2vBuf = localAddr; + else if (op.getDirMask() == 2) + v2cBuf = localAddr; + else if (op.getDirMask() == 3) { + c2vBuf = localAddr; + v2cBuf = peelUnrealized(adaptor.getPeerLocalAddr()); + } else + return rewriter.notifyMatchFailure(op, "unsupported dir_mask"); + + rewriter.replaceOpWithNewOp( + op, TypeRange{emitPipeTy}, 
*tpipeTok, ArrayAttr{}, ArrayAttr{}, + ValueRange{nullGm, c2vBuf, v2cBuf}); + return success(); + } + + PTOArch targetArch; +}; + +struct PTOBuildAsyncSessionToEmitC + : public OpConversionPattern { + PTOBuildAsyncSessionToEmitC(TypeConverter &typeConverter, MLIRContext *ctx) + : OpConversionPattern(typeConverter, ctx) {} + + LogicalResult matchAndRewrite(mlir::pto::BuildAsyncSessionOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto *ctx = rewriter.getContext(); + Location loc = op.getLoc(); + + auto sessionTy = + dyn_cast(getTypeConverter()->convertType(op.getSession().getType())); + if (!sessionTy) + return rewriter.notifyMatchFailure(op, "failed to convert async session type"); + + FailureOr scratchTile = + buildAsyncScratchTileValue(rewriter, loc, op.getScratch(), + adaptor.getScratch()); + if (failed(scratchTile)) + return rewriter.notifyMatchFailure(op, "failed to materialize async scratch tile"); + + Value workspace = + castToGMBytePointer(rewriter, loc, peelUnrealized(adaptor.getWorkspace())); + + Value session = rewriter + .create( + loc, sessionTy, emitc::OpaqueAttr::get(ctx, "")) + .getResult(); + + auto u32Ty = emitc::OpaqueType::get(ctx, "uint32_t"); + + auto makeU32Const = [&](uint64_t value) -> Value { + return makeEmitCOpaqueConstant(rewriter, loc, u32Ty, + std::to_string(value) + "u"); + }; + uint64_t syncId = op.getSyncIdAttr() ? op.getSyncIdAttr().getInt() : 0; + uint64_t blockBytes = + op.getBlockBytesAttr() ? op.getBlockBytesAttr().getInt() : 32 * 1024; + uint64_t commBlockOffset = + op.getCommBlockOffsetAttr() ? op.getCommBlockOffsetAttr().getInt() : 0; + uint64_t queueNum = op.getQueueNumAttr() ? op.getQueueNumAttr().getInt() : 1; + uint64_t channelGroupIdx = op.getChannelGroupIdxAttr() + ? op.getChannelGroupIdxAttr().getInt() + : UINT32_MAX; + + Value syncIdVal = makeU32Const(syncId); + Value channelGroupIdxVal = + channelGroupIdx == UINT32_MAX + ? 
makeEmitCOpaqueConstant(rewriter, loc, u32Ty, "UINT32_MAX") + : makeU32Const(channelGroupIdx); + + auto baseConfigTy = + emitc::OpaqueType::get(ctx, "pto::comm::sdma::SdmaBaseConfig"); + Value baseConfig = + rewriter + .create( + loc, baseConfigTy, + emitc::OpaqueAttr::get( + ctx, "{" + std::to_string(blockBytes) + "ULL, " + + std::to_string(commBlockOffset) + "ULL, " + + std::to_string(queueNum) + "u}")) + .getResult(); + + rewriter.create( + loc, TypeRange{}, "pto::comm::BuildAsyncSession", + ArrayAttr{}, ArrayAttr{}, + ValueRange{*scratchTile, workspace, session, syncIdVal, baseConfig, + channelGroupIdxVal}); + + rewriter.replaceOp(op, session); + return success(); + } +}; + +template +struct PTOAsyncTransferToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + explicit PTOAsyncTransferToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + StringRef callee) + : OpConversionPattern(typeConverter, ctx), callee(callee.str()) {} + + LogicalResult matchAndRewrite(AsyncOp op, typename AsyncOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value dst = peelUnrealized(adaptor.getDst()); + Value src = peelUnrealized(adaptor.getSrc()); + Value dstGT = dst; + Value srcGT = src; + if (!isEmitCGlobalTensorLikeType(dstGT.getType())) { + auto dstMrTy = dyn_cast(op.getDst().getType()); + if (!dstMrTy) + return rewriter.notifyMatchFailure(op, "expected dst to lower to GlobalTensor or memref"); + dstGT = buildGlobalTensorFromMemref(rewriter, op.getLoc(), dst, dstMrTy, + op.getDst().getDefiningOp() + ? op.getDst().getDefiningOp() + : op.getOperation()); + } + if (!isEmitCGlobalTensorLikeType(srcGT.getType())) { + auto srcMrTy = dyn_cast(op.getSrc().getType()); + if (!srcMrTy) + return rewriter.notifyMatchFailure(op, "expected src to lower to GlobalTensor or memref"); + srcGT = buildGlobalTensorFromMemref(rewriter, op.getLoc(), src, srcMrTy, + op.getSrc().getDefiningOp() + ? 
op.getSrc().getDefiningOp() + : op.getOperation()); + } + if (!dstGT || !srcGT) + return rewriter.notifyMatchFailure(op, "failed to build GlobalTensor operands"); + + Type eventTy = this->getTypeConverter()->convertType(op.getEvent().getType()); + if (!eventTy) + return rewriter.notifyMatchFailure(op, "failed to convert async event type"); + + rewriter.replaceOpWithNewOp( + op, TypeRange{eventTy}, callee, ArrayAttr{}, ArrayAttr{}, + ValueRange{dstGT, srcGT, peelUnrealized(adaptor.getSession())}); + return success(); + } + + std::string callee; +}; + +template +struct PTOAsyncEventToEmitC : public OpConversionPattern { + explicit PTOAsyncEventToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + StringRef callee) + : OpConversionPattern(typeConverter, ctx), + callee(callee.str()) {} + + LogicalResult matchAndRewrite(AsyncEventOp op, + typename AsyncEventOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type resultTy = + this->getTypeConverter()->convertType(op.getCompleted().getType()); + if (!resultTy) + return rewriter.notifyMatchFailure(op, "failed to convert async event result type"); + + rewriter.replaceOpWithNewOp( + op, TypeRange{resultTy}, callee, ArrayAttr{}, ArrayAttr{}, + ValueRange{peelUnrealized(adaptor.getEvent()), + peelUnrealized(adaptor.getSession())}); + return success(); + } + + std::string callee; +}; + +static FailureOr buildCommGlobalTensorValue( + ConversionPatternRewriter &rewriter, Location loc, Value originalValue, + Value emittedValue, Operation *anchor) { + Value value = peelUnrealized(emittedValue); + if (isEmitCGlobalTensorLikeType(value.getType())) + return value; + + auto memTy = dyn_cast(originalValue.getType()); + if (!memTy) + return failure(); + + Value gt = buildGlobalTensorFromMemref(rewriter, loc, value, memTy, anchor); + if (!gt) + return failure(); + return gt; +} + +static FailureOr buildCommTileValue(ConversionPatternRewriter &rewriter, + Location loc, Value originalValue, + Value 
emittedValue) { + Value value = peelUnrealized(emittedValue); + if (auto opaqueTy = dyn_cast(value.getType())) { + StringRef typeStr = opaqueTy.getValue(); + if (typeStr.contains("Tile<") || typeStr.contains("ConvTile<")) + return value; + } + return buildAsyncScratchTileValue(rewriter, loc, originalValue, emittedValue); +} + +static FailureOr buildCollectiveParallelGroup( + ConversionPatternRewriter &rewriter, Location loc, + ArrayRef groupGTs, int64_t root) { + if (groupGTs.empty()) + return failure(); + + auto firstTy = dyn_cast(groupGTs.front().getType()); + if (!firstTy) + return failure(); + + auto *ctx = rewriter.getContext(); + auto arrayTy = emitc::ArrayType::get({static_cast(groupGTs.size())}, + firstTy); + auto groupArray = cast>( + rewriter + .create(loc, arrayTy, + emitc::OpaqueAttr::get(ctx, "{}")) + .getResult()); + + auto indexTy = emitc::OpaqueType::get(ctx, "int"); + for (auto [idx, groupVal] : llvm::enumerate(groupGTs)) { + Value idxVal = + makeEmitCIntConstant(rewriter, loc, indexTy, static_cast(idx)); + Value slot = + rewriter.create(loc, groupArray, ValueRange{idxVal}) + .getResult(); + rewriter.create(loc, slot, groupVal); + } + + std::string pgTypeStr = + (Twine("pto::comm::ParallelGroup<") + firstTy.getValue() + ">").str(); + auto pgTy = emitc::OpaqueType::get(ctx, pgTypeStr); + Value sizeVal = makeEmitCIntConstant(rewriter, loc, indexTy, + static_cast(groupGTs.size())); + Value rootVal = makeEmitCIntConstant(rewriter, loc, indexTy, root); + return rewriter + .create( + loc, TypeRange{pgTy}, (Twine(pgTypeStr) + "::Create").str(), + ArrayAttr{}, ArrayAttr{}, ValueRange{groupArray, sizeVal, rootVal}) + .getResult(0); +} + +static std::string notifyOpTok(pto::NotifyOp op) { + switch (op) { + case pto::NotifyOp::AtomicAdd: + return "pto::comm::NotifyOp::AtomicAdd"; + case pto::NotifyOp::Set: + return "pto::comm::NotifyOp::Set"; + } + return "pto::comm::NotifyOp::Set"; +} + +static std::string waitCmpTok(pto::WaitCmp cmp) { + switch (cmp) { + 
case pto::WaitCmp::EQ: + return "pto::comm::WaitCmp::EQ"; + case pto::WaitCmp::NE: + return "pto::comm::WaitCmp::NE"; + case pto::WaitCmp::GT: + return "pto::comm::WaitCmp::GT"; + case pto::WaitCmp::GE: + return "pto::comm::WaitCmp::GE"; + case pto::WaitCmp::LT: + return "pto::comm::WaitCmp::LT"; + case pto::WaitCmp::LE: + return "pto::comm::WaitCmp::LE"; + } + return "pto::comm::WaitCmp::EQ"; +} + +static std::string reduceOpTok(pto::ReduceOp op) { + switch (op) { + case pto::ReduceOp::Sum: + return "pto::comm::ReduceOp::Sum"; + case pto::ReduceOp::Max: + return "pto::comm::ReduceOp::Max"; + case pto::ReduceOp::Min: + return "pto::comm::ReduceOp::Min"; + } + return "pto::comm::ReduceOp::Sum"; +} + +template +static FailureOr> buildCommGroupGlobalTensors( + ConversionPatternRewriter &rewriter, Location loc, OpTy op, + ValueRange originalGroup, ValueRange emittedGroup) { + SmallVector groupGTs; + groupGTs.reserve(originalGroup.size()); + for (auto [orig, emitted] : llvm::zip(originalGroup, emittedGroup)) { + FailureOr gt = + buildCommGlobalTensorValue(rewriter, loc, orig, emitted, op.getOperation()); + if (failed(gt)) + return failure(); + groupGTs.push_back(*gt); + } + return groupGTs; +} + +template +struct PTOCommCollectiveToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + explicit PTOCommCollectiveToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + StringRef apiName) + : OpConversionPattern(typeConverter, ctx), + apiName(apiName.str()) {} + + LogicalResult matchAndRewrite(CollectiveOp op, typename CollectiveOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + + auto buildPong = [&](Value original, Value emitted, StringRef name) -> FailureOr { + if (!original) + return failure(); + return buildCommTileValue(rewriter, loc, original, emitted); + }; + + if constexpr (std::is_same_v) { + FailureOr srcGT = + buildCommGlobalTensorValue(rewriter, loc, op.getSrc(), 
adaptor.getSrc(), + op.getOperation()); + FailureOr pingTile = + buildCommTileValue(rewriter, loc, op.getPing(), adaptor.getPing()); + auto groupGTs = + buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); + if (failed(srcGT) || failed(pingTile) || failed(groupGTs)) + return rewriter.notifyMatchFailure(op, "failed to materialize broadcast operands"); + FailureOr pg = buildCollectiveParallelGroup(rewriter, loc, *groupGTs, op.getRoot()); + if (failed(pg)) + return rewriter.notifyMatchFailure(op, "failed to materialize broadcast group"); + if (op.getPong()) { + FailureOr pongTile = + buildPong(op.getPong(), adaptor.getPong(), "__pong"); + if (failed(pongTile)) + return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); + rewriter.create( + loc, TypeRange{}, "pto::comm::TBROADCAST", ArrayAttr{}, ArrayAttr{}, + ValueRange{*pg, *srcGT, *pingTile, *pongTile}); + } else { + rewriter.create( + loc, TypeRange{}, "pto::comm::TBROADCAST", ArrayAttr{}, ArrayAttr{}, + ValueRange{*pg, *srcGT, *pingTile}); + } + } else if constexpr (std::is_same_v) { + FailureOr dstGT = + buildCommGlobalTensorValue(rewriter, loc, op.getDst(), adaptor.getDst(), + op.getOperation()); + FailureOr pingTile = + buildCommTileValue(rewriter, loc, op.getPing(), adaptor.getPing()); + auto groupGTs = + buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); + if (failed(dstGT) || failed(pingTile) || failed(groupGTs)) + return rewriter.notifyMatchFailure(op, "failed to materialize gather operands"); + FailureOr pg = buildCollectiveParallelGroup(rewriter, loc, *groupGTs, op.getRoot()); + if (failed(pg)) + return rewriter.notifyMatchFailure(op, "failed to materialize gather group"); + if (op.getPong()) { + FailureOr pongTile = + buildPong(op.getPong(), adaptor.getPong(), "__pong"); + if (failed(pongTile)) + return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); + rewriter.create( + loc, TypeRange{}, 
"pto::comm::TGATHER", ArrayAttr{}, ArrayAttr{}, + ValueRange{*pg, *dstGT, *pingTile, *pongTile}); + } else { + rewriter.create( + loc, TypeRange{}, "pto::comm::TGATHER", ArrayAttr{}, ArrayAttr{}, + ValueRange{*pg, *dstGT, *pingTile}); + } + } else if constexpr (std::is_same_v) { + FailureOr srcGT = + buildCommGlobalTensorValue(rewriter, loc, op.getSrc(), adaptor.getSrc(), + op.getOperation()); + FailureOr pingTile = + buildCommTileValue(rewriter, loc, op.getPing(), adaptor.getPing()); + auto groupGTs = + buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); + if (failed(srcGT) || failed(pingTile) || failed(groupGTs)) + return rewriter.notifyMatchFailure(op, "failed to materialize scatter operands"); + FailureOr pg = buildCollectiveParallelGroup(rewriter, loc, *groupGTs, op.getRoot()); + if (failed(pg)) + return rewriter.notifyMatchFailure(op, "failed to materialize scatter group"); + if (op.getPong()) { + FailureOr pongTile = + buildPong(op.getPong(), adaptor.getPong(), "__pong"); + if (failed(pongTile)) + return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); + rewriter.create( + loc, TypeRange{}, "pto::comm::TSCATTER", ArrayAttr{}, ArrayAttr{}, + ValueRange{*pg, *srcGT, *pingTile, *pongTile}); + } else { + rewriter.create( + loc, TypeRange{}, "pto::comm::TSCATTER", ArrayAttr{}, ArrayAttr{}, + ValueRange{*pg, *srcGT, *pingTile}); + } + } else { + FailureOr dstGT = + buildCommGlobalTensorValue(rewriter, loc, op.getDst(), adaptor.getDst(), + op.getOperation()); + FailureOr accTile = + buildCommTileValue(rewriter, loc, op.getAcc(), adaptor.getAcc()); + FailureOr recvPing = + buildCommTileValue(rewriter, loc, op.getRecvPing(), adaptor.getRecvPing()); + auto groupGTs = + buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); + if (failed(dstGT) || failed(accTile) || failed(recvPing) || failed(groupGTs)) + return rewriter.notifyMatchFailure(op, "failed to materialize reduce operands"); + 
FailureOr pg = buildCollectiveParallelGroup(rewriter, loc, *groupGTs, op.getRoot()); + if (failed(pg)) + return rewriter.notifyMatchFailure(op, "failed to materialize reduce group"); + if (op.getRecvPong()) { + FailureOr recvPong = + buildPong(op.getRecvPong(), adaptor.getRecvPong(), "__recv_pong"); + if (failed(recvPong)) + return rewriter.notifyMatchFailure(op, "failed to materialize recv_pong"); + auto reduceTy = + emitc::OpaqueType::get(rewriter.getContext(), "pto::comm::ReduceOp"); + Value reduceOp = makeEmitCOpaqueConstant(rewriter, loc, reduceTy, + reduceOpTok(op.getReduceOp())); + rewriter.create( + loc, TypeRange{}, "pto::comm::TREDUCE", ArrayAttr{}, ArrayAttr{}, + ValueRange{*pg, *dstGT, *accTile, *recvPing, *recvPong, reduceOp}); + } else { + auto reduceTy = + emitc::OpaqueType::get(rewriter.getContext(), "pto::comm::ReduceOp"); + Value reduceOp = makeEmitCOpaqueConstant(rewriter, loc, reduceTy, + reduceOpTok(op.getReduceOp())); + rewriter.create( + loc, TypeRange{}, "pto::comm::TREDUCE", ArrayAttr{}, ArrayAttr{}, + ValueRange{*pg, *dstGT, *accTile, *recvPing, reduceOp}); + } + } + rewriter.eraseOp(op); + return success(); + } + + std::string apiName; +}; + +template +struct PTOP2PCommToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + explicit PTOP2PCommToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + StringRef callee) + : OpConversionPattern(typeConverter, ctx), callee(callee.str()) {} + + LogicalResult matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr dstGT = + buildCommGlobalTensorValue(rewriter, op.getLoc(), op.getDst(), adaptor.getDst(), + op.getOperation()); + FailureOr srcGT = + buildCommGlobalTensorValue(rewriter, op.getLoc(), op.getSrc(), adaptor.getSrc(), + op.getOperation()); + FailureOr pingTile = + buildCommTileValue(rewriter, op.getLoc(), op.getPing(), adaptor.getPing()); + if (failed(dstGT) || failed(srcGT) || 
failed(pingTile)) + return rewriter.notifyMatchFailure(op, "failed to materialize p2p operands"); + + SmallVector operands{*dstGT, *srcGT, *pingTile}; + std::string actualCallee = callee; + if constexpr (std::is_same_v) { + if (op.getAtomicType() == pto::AtomicType::AtomicAdd) + actualCallee = "pto::comm::TPUT"; + } + if (op.getPong()) { + FailureOr pongTile = + buildCommTileValue(rewriter, op.getLoc(), op.getPong(), adaptor.getPong()); + if (failed(pongTile)) + return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); + operands.push_back(*pongTile); + } + + rewriter.create(op.getLoc(), TypeRange{}, actualCallee, + ArrayAttr{}, ArrayAttr{}, operands); + rewriter.eraseOp(op); + return success(); + } + + std::string callee; +}; + +template +struct PTOSignalCommToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + explicit PTOSignalCommToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + StringRef callee) + : OpConversionPattern(typeConverter, ctx), + callee(callee.str()) {} + + LogicalResult matchAndRewrite(SignalOp op, typename SignalOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr signalGT = buildCommGlobalTensorValue( + rewriter, op.getLoc(), op.getSignal(), adaptor.getSignal(), op.getOperation()); + if (failed(signalGT)) + return rewriter.notifyMatchFailure(op, "failed to materialize signal operand"); + + if constexpr (std::is_same_v) { + auto notifyTy = + emitc::OpaqueType::get(rewriter.getContext(), "pto::comm::NotifyOp"); + Value notifyOp = makeEmitCOpaqueConstant( + rewriter, op.getLoc(), notifyTy, notifyOpTok(op.getNotifyOp())); + SmallVector operands{*signalGT, peelUnrealized(adaptor.getValue()), + notifyOp}; + rewriter.create(op.getLoc(), TypeRange{}, callee, + ArrayAttr{}, ArrayAttr{}, operands); + rewriter.eraseOp(op); + } else { + auto waitCmpTy = + emitc::OpaqueType::get(rewriter.getContext(), "pto::comm::WaitCmp"); + Value waitCmp = 
makeEmitCOpaqueConstant( + rewriter, op.getLoc(), waitCmpTy, waitCmpTok(op.getCmp())); + SmallVector operands{*signalGT, peelUnrealized(adaptor.getCmpValue()), + waitCmp}; + if constexpr (std::is_same_v) { + Type resultTy = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultTy) + return rewriter.notifyMatchFailure(op, "failed to convert ttest result type"); + rewriter.replaceOpWithNewOp( + op, TypeRange{resultTy}, callee, ArrayAttr{}, ArrayAttr{}, operands); + } else { + rewriter.create(op.getLoc(), TypeRange{}, callee, + ArrayAttr{}, ArrayAttr{}, operands); + rewriter.eraseOp(op); + } + } + return success(); + } + + std::string callee; +}; + +struct PTODeclareTileMemRefToEmitC + : public OpConversionPattern { + using OpConversionPattern< + mlir::pto::DeclareTileMemRefOp>::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::DeclareTileMemRefOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + Type convertedType = getTypeConverter()->convertType(op.getResult().getType()); + if (!convertedType) + return rewriter.notifyMatchFailure( + op, "failed to convert declare_tile_memref result type"); + rewriter.replaceOp(op, makeEmitCOpaqueConstant(rewriter, op.getLoc(), + convertedType, "nullptr")); + return success(); + } +}; + +struct PTODeclareGlobalToEmitC + : public OpConversionPattern { + using OpConversionPattern< + mlir::pto::DeclareGlobalOp>::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::DeclareGlobalOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + Type convertedType = getTypeConverter()->convertType(op.getEntry().getType()); + if (!convertedType) + return rewriter.notifyMatchFailure( + op, "failed to convert declare_global result type"); + if (auto tvTy = dyn_cast(op.getEntry().getType())) { + if (auto stridesAttr = + op->getAttrOfType(kGlobalTensorStridesAttrName)) { + auto strides = 
stridesAttr.asArrayRef(); + if (strides.size() == static_cast(tvTy.getRank())) { + convertedType = emitc::OpaqueType::get( + rewriter.getContext(), + getGlobalTensorTypeStringFromShapeAndStrides( + tvTy.getElementType(), tvTy.getShape(), strides)); + } + } + } + auto var = rewriter.create( + op.getLoc(), convertedType, + emitc::OpaqueAttr::get(rewriter.getContext(), "")); + rewriter.replaceOp(op, var.getResult()); + return success(); + } +}; + +struct PTODeclareEventIdArrayToEmitC + : public OpConversionPattern { + using OpConversionPattern< + mlir::pto::DeclareEventIdArrayOp>::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::DeclareEventIdArrayOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + Type arrayTy = getTypeConverter()->convertType(op.getArray().getType()); + if (!arrayTy) + return rewriter.notifyMatchFailure(op, + "failed to map declared eventid_array type"); + + auto array = rewriter + .create( + op.getLoc(), arrayTy, + emitc::OpaqueAttr::get(rewriter.getContext(), "")) + .getResult(); + rewriter.replaceOp(op, array); + return success(); + } +}; + +struct PTOEventIdArrayGetToEmitC + : public OpConversionPattern { + using OpConversionPattern< + mlir::pto::EventIdArrayGetOp>::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::EventIdArrayGetOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value array = peelUnrealized(adaptor.getArray()); + Value index = peelUnrealized(adaptor.getIndex()); + + Type resultTy = getTypeConverter()->convertType(op.getResult().getType()); + if (!resultTy) + return rewriter.notifyMatchFailure(op, + "failed to map eventid_array get result type"); + + auto load = + rewriter.create(op.getLoc(), resultTy, array, index); + rewriter.replaceOp(op, load.getResult()); + return success(); + } +}; + +struct PTOEventIdArraySetToEmitC + : public OpConversionPattern { + using OpConversionPattern< + 
mlir::pto::EventIdArraySetOp>::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::EventIdArraySetOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value array = peelUnrealized(adaptor.getArray()); + Value index = peelUnrealized(adaptor.getIndex()); + Value value = peelUnrealized(adaptor.getValue()); + + rewriter.create( + op.getLoc(), TypeRange{}, "PTOAS__EVENTID_ARRAY_STORE", + ArrayAttr{}, ArrayAttr{}, ValueRange{array, index, value}); + rewriter.eraseOp(op); + return success(); + } +}; + +// pto.declare_local_array -> emitc.variable of !emitc.array<...>. +// Renders as `T a[D1][D2]...;` in the emitted C++. +struct PTODeclareLocalArrayToEmitC + : public OpConversionPattern { + using OpConversionPattern< + mlir::pto::DeclareLocalArrayOp>::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::DeclareLocalArrayOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + Type arrayTy = getTypeConverter()->convertType(op.getArray().getType()); + if (!arrayTy) + return rewriter.notifyMatchFailure(op, + "failed to map !pto.local_array type"); + + auto var = rewriter + .create( + op.getLoc(), arrayTy, + emitc::OpaqueAttr::get(rewriter.getContext(), "")) + .getResult(); + rewriter.replaceOp(op, var); + return success(); + } +}; + +// pto.local_array_get %a[%i0, %i1, ...] -> rvalue. +// Lowers to a single emitc.subscript with the full index pack; the C++ emitter +// prints it as `a[i0][i1]...`. The adaptor already exposes target-typed values +// (the type converter has remapped !pto.local_array -> !emitc.array and +// index/integer indices), so they're forwarded directly to the builder. 
+struct PTOLocalArrayGetToEmitC + : public OpConversionPattern { /* The match fails when the result element type cannot be converted. */ + using OpConversionPattern< + mlir::pto::LocalArrayGetOp>::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::LocalArrayGetOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type resultTy = + getTypeConverter()->convertType(op.getResult().getType()); + if (!resultTy) + return rewriter.notifyMatchFailure( + op, "failed to map local_array element type"); + + auto sub = rewriter.create( + op.getLoc(), resultTy, adaptor.getArray(), adaptor.getIndices()); + rewriter.replaceOp(op, sub.getResult()); + return success(); + } +}; + +// pto.local_array_set %a[%i0, %i1, ...], %v -> emitc.assign to subscript slot. +// The C++ emitter prints this as `a[i0][i1]... = v;`. As above, adaptor values +// are already target-typed; pass them through directly. +struct PTOLocalArraySetToEmitC + : public OpConversionPattern { + using OpConversionPattern< + mlir::pto::LocalArraySetOp>::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::LocalArraySetOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value value = adaptor.getValue(); + Type elemTy = value.getType(); + /* The slot is typed with the type of the already-converted stored value; the source op has no replacement value and is erased. */ + Value slot = rewriter + .create(op.getLoc(), elemTy, + adaptor.getArray(), + adaptor.getIndices()) + .getResult(); + rewriter.create(op.getLoc(), slot, value); + rewriter.eraseOp(op); + return success(); + } +}; + /* Returns the compile-time integer behind value, or std::nullopt when value is null or its defining op is none of the three constant-like ops probed below. The third probe only succeeds when the constant attribute casts to the integer attribute type. */ +static std::optional getStaticIndexLikeValue(Value value) { + if (!value) + return std::nullopt; + if (auto cst = value.getDefiningOp()) + return cst.value(); + if (auto cst = value.getDefiningOp()) + return cst.value(); + if (auto cst = value.getDefiningOp()) { + if (auto intAttr = dyn_cast(cst.getValue())) + return intAttr.getInt(); + } + return std::nullopt; +} + /* Materializes a GlobalTensor value from ptr. Returns failure when any entry of shape is ShapedType::kDynamic. When strides is empty, buildRowMajorStrides(shape) supplies them. Shape and Stride values are created first and passed together with ptr to the final creation; result 0 of that op is returned. layoutEnum is forwarded to the type-string helper. */ +static FailureOr buildGlobalTensorViewFromPointer( + ConversionPatternRewriter &rewriter, Location loc, Value ptr, Type elemTy, + ArrayRef shape, ArrayRef strides = {}, + StringRef
layoutEnum = "pto::Layout::ND") { + if (llvm::any_of(shape, [](int64_t dim) { + return dim == ShapedType::kDynamic; + })) + return failure(); + /* rowMajorStrides owns the storage that effectiveStrides views when no explicit strides were supplied. */ + auto *ctx = rewriter.getContext(); + SmallVector rowMajorStrides; + ArrayRef effectiveStrides = strides; + if (effectiveStrides.empty()) { + rowMajorStrides = buildRowMajorStrides(shape); + effectiveStrides = rowMajorStrides; + } + SmallVector shape5D; + SmallVector stride5D; + buildGlobalTensorShapeAndStride(shape, effectiveStrides, shape5D, stride5D); + /* Each type string below is used twice: as the opaque result type and as the string argument of the op that creates the value. */ + std::string shapeType = "pto::Shape<" + joinIntTemplateParams(shape5D) + ">"; + std::string strideType = + "pto::Stride<" + joinIntTemplateParams(stride5D) + ">"; + auto shapeVal = rewriter + .create( + loc, emitc::OpaqueType::get(ctx, shapeType), + shapeType, ArrayAttr{}, ArrayAttr{}, ValueRange{}) + .getResult(0); + auto strideVal = rewriter + .create( + loc, emitc::OpaqueType::get(ctx, strideType), + strideType, ArrayAttr{}, ArrayAttr{}, ValueRange{}) + .getResult(0); + /* The GlobalTensor type string is derived from the caller shape and the effective strides, not from the 5D vectors. */ + std::string gtTypeStr = + getGlobalTensorTypeStringFromShapeAndStrides(elemTy, shape, + effectiveStrides, + layoutEnum); + auto gtType = emitc::OpaqueType::get(ctx, gtTypeStr); + auto gt = rewriter.create( + loc, gtType, gtTypeStr, ArrayAttr{}, ArrayAttr{}, + ValueRange{ptr, shapeVal, strideVal}); + return gt.getResult(0); +} + /* Parses the comma-separated base-10 integers found in token between the end of marker and the next > character. Returns false, leaving values untouched, when marker or the closing > is missing. Returns false with values possibly partially filled when a piece is not an integer. values is cleared before parsing starts. */ +static bool parseIntegerTemplateList(StringRef token, StringRef marker, + SmallVectorImpl &values) { + size_t pos = token.find(marker); + if (pos == StringRef::npos) + return false; + pos += marker.size(); + size_t end = token.find('>', pos); + if (end == StringRef::npos) + return false; + + SmallVector parts; + token.slice(pos, end).split(parts, ','); + values.clear(); + for (StringRef part : parts) { + int64_t value = 0; + if (part.trim().getAsInteger(10, value)) /* getAsInteger reports failure by returning true */ + return false; + values.push_back(value); + } + return true; +} + /* Fills strides for source, trying in order: (1) the stride operands of the defining view op, which must all be static constants and match the rank, otherwise failure; (2) the trailing rank entries of a Stride list parsed from the opaque type string of the converted value; (3) row-major strides of the source shape. Paths 2 and 3 always succeed. On failure in path 1, strides may be partially filled. */ +static LogicalResult getStaticTensorViewStrides( + Value source, Value convertedSource, pto::TensorViewType sourceType, + SmallVectorImpl
&strides) { + int64_t rank = sourceType.getRank(); + strides.clear(); + + if (auto makeView = source.getDefiningOp()) { + if ((int64_t)makeView.getStrides().size() != rank) + return failure(); + for (Value strideValue : makeView.getStrides()) { + auto cst = getStaticIndexLikeValue(strideValue); + if (!cst) + return failure(); + strides.push_back(*cst); + } + return success(); + } + /* NOTE(review): the second marker is a substring of the first, so the first parse attempt looks redundant; confirm before simplifying. */ + Value src = peelUnrealized(convertedSource); + if (auto opaqueTy = dyn_cast(src.getType())) { + SmallVector stride5D; + StringRef token = opaqueTy.getValue(); + if ((parseIntegerTemplateList(token, "pto::Stride<", stride5D) || + parseIntegerTemplateList(token, "Stride<", stride5D)) && + (int64_t)stride5D.size() >= rank) { + strides.append(stride5D.end() - rank, stride5D.end()); + return success(); + } + } + + auto fallback = buildRowMajorStrides(sourceType.getShape()); + strides.append(fallback.begin(), fallback.end()); + return success(); +} + /* Lowers mlir::pto::PartitionViewOp. Requirements checked below: source and result types cast successfully; offsets and sizes counts equal the source rank; every size is a static constant that agrees with any static result dimension; no source stride is ShapedType::kDynamic. Constant offsets fold into one linear element offset, dynamic offsets are scaled by their stride and summed at runtime, and the result is rebuilt from the offset pointer using the result shape and the source strides. */ +struct PTOPartitionViewToEmitC + : public OpConversionPattern { + using OpConversionPattern< + mlir::pto::PartitionViewOp>::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::PartitionViewOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto srcTy = dyn_cast(op.getSource().getType()); + auto resTy = dyn_cast(op.getResult().getType()); + if (!srcTy || !resTy) + return rewriter.notifyMatchFailure( + op, "expected tensor_view source and partition_tensor_view result"); + + if (op.getOffsets().size() != static_cast(srcTy.getRank()) || + op.getSizes().size() != static_cast(srcTy.getRank())) + return rewriter.notifyMatchFailure(op, "rank mismatch"); + + for (auto [idx, value] : llvm::enumerate(op.getSizes())) { + auto cst = getStaticIndexLikeValue(value); + if (!cst) + return rewriter.notifyMatchFailure( + op, "globaltensor partition_view requires static sizes"); + int64_t resultDim = resTy.getShape()[idx]; + if (resultDim != ShapedType::kDynamic && resultDim != *cst) + return
rewriter.notifyMatchFailure( + op, "partition_view static size does not match result type"); + } + + SmallVector srcStrides; + if (failed(getStaticTensorViewStrides(op.getSource(), adaptor.getSource(), + srcTy, srcStrides))) + return rewriter.notifyMatchFailure( + op, "partition_view requires static source strides"); /* Below: offsets whose original value is a constant are folded into staticLinearOffset; all others are kept as (converted value, stride) pairs. */ + int64_t staticLinearOffset = 0; + SmallVector> dynamicOffsetTerms; + for (auto [idx, values] : + llvm::enumerate(llvm::zip(op.getOffsets(), adaptor.getOffsets()))) { + Value originalOffset = std::get<0>(values); + Value convertedOffset = std::get<1>(values); + int64_t stride = srcStrides[idx]; + if (stride == ShapedType::kDynamic) + return rewriter.notifyMatchFailure( + op, "dynamic source stride is not supported"); + + if (auto cst = getStaticIndexLikeValue(originalOffset)) { + if (*cst != 0) + staticLinearOffset += (*cst) * stride; + continue; + } + dynamicOffsetTerms.push_back({convertedOffset, stride}); + } + /* The base data pointer is obtained from the converted source and typed as a pointer to the __gm__ element type. */ + auto *ctx = rewriter.getContext(); + std::string elemTypeStr = getElemTypeStringForGT(srcTy.getElementType()); + auto ptrTy = emitc::PointerType::get( + emitc::OpaqueType::get(ctx, "__gm__ " + elemTypeStr)); + Value src = peelUnrealized(adaptor.getSource()); + auto data = rewriter + .create( + op.getLoc(), ptrTy, "PTOAS__GLOBAL_TENSOR_DATA", + ArrayAttr{}, ArrayAttr{}, ValueRange{src}) + .getResult(0); + Value ptr = data; + if (!dynamicOffsetTerms.empty()) { /* NOTE(review): the runtime offset math is typed with the opaque type named unsigned; confirm offsets and strides always fit that type. */ + Type u32Ty = emitc::OpaqueType::get(ctx, "unsigned"); + auto makeU32 = [&](int64_t value) { + return makeEmitCIntConstant(rewriter, op.getLoc(), u32Ty, value); + }; + auto asU32 = [&](Value value) -> Value { + if (value.getType() == u32Ty) + return value; + return rewriter.create(op.getLoc(), u32Ty, value) + .getResult(); + }; + /* The running total starts at the folded static offset; a stride of 1 skips the multiply. */ + Value totalOffset = makeU32(staticLinearOffset); + for (auto [offsetValue, stride] : dynamicOffsetTerms) { + Value term = asU32(offsetValue); + if (stride != 1) { + Value strideValue = makeU32(stride); + term = rewriter + .create(op.getLoc(), u32Ty, term, +
strideValue) + .getResult(); + } + totalOffset = rewriter + .create(op.getLoc(), u32Ty, + totalOffset, term) + .getResult(); + } + ptr = rewriter + .create(op.getLoc(), data.getType(), data, + totalOffset) + .getResult(); + } else { /* all offsets were constants: delegate to the static-offset helper */ + ptr = applyStaticMemrefOffset(rewriter, op.getLoc(), data, + staticLinearOffset); + } + /* The partition keeps the strides of the source view and takes its shape from the result type. */ + auto resultOr = buildGlobalTensorViewFromPointer( + rewriter, op.getLoc(), ptr, resTy.getElementType(), resTy.getShape(), + srcStrides); + if (failed(resultOr)) + return rewriter.notifyMatchFailure( + op, "failed to materialize partition GlobalTensor"); + + rewriter.replaceOp(op, *resultOr); + return success(); + } +}; + /* Returns the type string of value when the type cast succeeds and the string contains Tile< or GlobalTensor<; failure otherwise. */ +static FailureOr getPipeDataTypeToken(Value value) { + auto opaqueTy = dyn_cast(value.getType()); + if (!opaqueTy) + return failure(); + StringRef token = opaqueTy.getValue(); + if (!token.contains("Tile<") && !token.contains("GlobalTensor<")) + return failure(); + return token.str(); +} + /* Lowers mlir::pto::TAllocOp to TALLOC<pipe, entry, split>(pipeHandle, entry). The pipe token comes from the original pipe handle and targetArch, the entry token from the converted entry type, the split token from the split attribute. Any token that cannot be resolved aborts the match. */ +struct PTOTAllocToEmitC : public OpConversionPattern { + PTOTAllocToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + PTOArch targetArch) + : OpConversionPattern(typeConverter, ctx), + targetArch(targetArch) {} + + LogicalResult matchAndRewrite(mlir::pto::TAllocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto pipeTok = getTPipeTokenFromValue(op.getPipeHandle(), targetArch); + if (failed(pipeTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve pipe token"); + Value entry = peelUnrealized(adaptor.getEntry()); + auto entryTok = getPipeDataTypeToken(entry); + if (failed(entryTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve entry token"); + auto splitTok = getTileSplitToken(op.getSplit()); + if (failed(splitTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve split token"); + + std::string callee = + "TALLOC<" + *pipeTok + ", " + *entryTok + ", " + *splitTok + ">"; + rewriter.replaceOpWithNewOp( + op, TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, +
ValueRange{peelUnrealized(adaptor.getPipeHandle()), entry}); + return success(); + } + + PTOArch targetArch; +}; + /* Lowers mlir::pto::TPushOp to TPUSH<pipe, tile, split>(pipeHandle, tile). Any token that cannot be resolved aborts the match. */ +struct PTOTPushToEmitC : public OpConversionPattern { + PTOTPushToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + PTOArch targetArch) + : OpConversionPattern(typeConverter, ctx), + targetArch(targetArch) {} + + LogicalResult matchAndRewrite(mlir::pto::TPushOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto pipeTok = getTPipeTokenFromValue(op.getPipeHandle(), targetArch); + if (failed(pipeTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve pipe token"); + // Read the tile type token from the already-converted OpaqueType, which + // preserves the exact layout produced by BindTileOp / PointerCastOp EmitC. + Value convertedTile = peelUnrealized(adaptor.getTile()); + auto tileTok = getPipeDataTypeToken(convertedTile); + if (failed(tileTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve tile token"); + auto splitTok = getTileSplitToken(op.getSplit()); + if (failed(splitTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve split token"); + + std::string callee = + "TPUSH<" + *pipeTok + ", " + *tileTok + ", " + *splitTok + ">"; + rewriter.replaceOpWithNewOp( + op, TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, + ValueRange{peelUnrealized(adaptor.getPipeHandle()), convertedTile}); + return success(); + } + + PTOArch targetArch; +}; + /* Lowers mlir::pto::TPopOp to TPOP<pipe, tile, split>(pipeHandle, tile); same token resolution as the push pattern. */ +struct PTOTPopToEmitC : public OpConversionPattern { + PTOTPopToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + PTOArch targetArch) + : OpConversionPattern(typeConverter, ctx), + targetArch(targetArch) {} + + LogicalResult matchAndRewrite(mlir::pto::TPopOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto pipeTok = getTPipeTokenFromValue(op.getPipeHandle(), targetArch); + if (failed(pipeTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve pipe token"); + Value convertedTile =
peelUnrealized(adaptor.getTile()); + auto tileTok = getPipeDataTypeToken(convertedTile); + if (failed(tileTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve tile token"); + auto splitTok = getTileSplitToken(op.getSplit()); + if (failed(splitTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve split token"); + + std::string callee = + "TPOP<" + *pipeTok + ", " + *tileTok + ", " + *splitTok + ">"; + rewriter.replaceOpWithNewOp( + op, TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, + ValueRange{peelUnrealized(adaptor.getPipeHandle()), convertedTile}); + return success(); + } + + PTOArch targetArch; +}; + /* Lowers mlir::pto::TFreeOp. With an entry operand the callee is TFREE<pipe, entry, split> and the operands are (pipeHandle, entry); without one the callee is TFREE<pipe, split> and the only operand is pipeHandle. */ +struct PTOTFreeToEmitC : public OpConversionPattern { + PTOTFreeToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + PTOArch targetArch) + : OpConversionPattern(typeConverter, ctx), + targetArch(targetArch) {} + + LogicalResult matchAndRewrite(mlir::pto::TFreeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto pipeTok = getTPipeTokenFromValue(op.getPipeHandle(), targetArch); + if (failed(pipeTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve pipe token"); + auto splitTok = getTileSplitToken(op.getSplit()); + if (failed(splitTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve split token"); + + SmallVector operands{peelUnrealized(adaptor.getPipeHandle())}; + std::string callee; + if (op.getEntry()) { + Value entry = peelUnrealized(adaptor.getEntry()); + auto entryTok = getPipeDataTypeToken(entry); + if (failed(entryTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve entry token"); + callee = "TFREE<" + *pipeTok + ", " + *entryTok + ", " + *splitTok + ">"; + operands.push_back(entry); + } else { + callee = "TFREE<" + *pipeTok + ", " + *splitTok + ">"; + } + rewriter.replaceOpWithNewOp( + op, TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, operands); + return success(); + } + + PTOArch targetArch; +}; +
+//===----------------------------------------------------------------------===// +// populate patterns +//===----------------------------------------------------------------------===// +// Lowers memref::ReinterpretCastOp. A result with no address-space attribute, or with AddressSpace::GM, stays a pointer: the source is forwarded unchanged when there is no offset, otherwise source plus the first offset is emitted. Any other address space creates a fresh Tile variable and binds it with TASSIGN to the base address plus offset times element byte width. Only offsets[0] is read. +struct ReinterpretCastToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + auto resMrTy = dyn_cast(op.getType()); + if (!resMrTy) + return failure(); + + auto asAttr = dyn_cast_or_null(resMrTy.getMemorySpace()); + const bool isGm = (!asAttr || asAttr.getAddressSpace() == pto::AddressSpace::GM); + + bool emitAddPtrTrace = op->hasAttr("pto.addptr_trace"); + Value source = peelUnrealized(adaptor.getSource()); + auto offsets = adaptor.getOffsets(); + Value offsetVal = offsets.empty() ? Value() : offsets[0]; + + // GM: keep pointer arithmetic. + if (isGm) { + if (!offsetVal) { + rewriter.replaceOp(op, source); + return success(); + } + + Type resultType = getTypeConverter()->convertType(op.getType()); + if (!resultType) + return failure(); + + auto addOp = rewriter.create(loc, resultType, source, offsetVal); + if (emitAddPtrTrace) { /* NOTE(review): the insertion point is moved here and not restored; confirm this is intended. */ + rewriter.setInsertionPointAfter(addOp); + rewriter.create( + loc, TypeRange{}, "PTOAS__ADDPTR_TRACE", + ArrayAttr{}, ArrayAttr{}, + ValueRange{addOp.getResult(), source, offsetVal}); + } + rewriter.replaceOp(op, addOp.getResult()); + return success(); + } + + // UB/L1/L0 tiles: materialize a new Tile view by assigning an adjusted + // underlying pointer (in elements). asAttr is non-null here because a null attribute took the GM path above. + pto::AddressSpace as = asAttr.getAddressSpace(); + + // Element type token. + Type elemTy = resMrTy.getElementType(); + std::string elemTok = getEmitCScalarTypeToken(elemTy); + int64_t elemBytes = getEmitCScalarByteWidth(elemTy); + + // Tile role.
+ const char *roleTok = "TileType::Vec"; + switch (as) { /* GM cannot reach this switch because it returned above; GM and Zero both map to the Vec role. */ + case pto::AddressSpace::VEC: + roleTok = "TileType::Vec"; + break; + case pto::AddressSpace::MAT: + roleTok = "TileType::Mat"; + break; + case pto::AddressSpace::LEFT: + roleTok = "TileType::Left"; + break; + case pto::AddressSpace::RIGHT: + roleTok = "TileType::Right"; + break; + case pto::AddressSpace::ACC: + roleTok = "TileType::Acc"; + break; + case pto::AddressSpace::BIAS: + roleTok = "TileType::Bias"; + break; + case pto::AddressSpace::GM: + roleTok = "TileType::Vec"; + break; + case pto::AddressSpace::Zero: + roleTok = "TileType::Vec"; + break; + case pto::AddressSpace::SCALING: + roleTok = "TileType::Scaling"; + break; + } + + // Shape (fallback to 32x32). The fallback applies when the rank is below 2 or the shape is not fully static. + int64_t rows = 32, cols = 32; + if (resMrTy.getRank() >= 2 && resMrTy.hasStaticShape()) { + rows = resMrTy.getDimSize(0); + cols = resMrTy.getDimSize(1); + } + int64_t templateRows = + renderTileTemplateDim(rows, elemTy, pto::BLayout::RowMajor, 0); + int64_t templateCols = + renderTileTemplateDim(cols, elemTy, pto::BLayout::RowMajor, 1); + + // Keep a conservative default config for now. + std::string tileTypeStr = + std::string("Tile<") + roleTok + ", " + elemTok + ", " + + std::to_string(templateRows) + ", " + std::to_string(templateCols) + + ", BLayout::RowMajor, " + std::to_string(templateRows) + ", " + + std::to_string(templateCols) + + ", SLayout::NoneBox, 512, PadValue::Null, CompactMode::Null>"; + + auto tileType = emitc::OpaqueType::get(ctx, tileTypeStr); + Value tile = rewriter + .create(loc, tileType, + emitc::OpaqueAttr::get(ctx, "")) + .getResult(); + + // Compute an integer address and assign it to the new tile. + // NOTE: pto-isa TASSIGN requires an integral address (not a pointer). + auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); + auto rcU64 = rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); + + // Non-GM reinterpret_cast operands come from UB/L1/L0 tiles.
+ // We need the underlying address, but `__cce_get_tile_ptr()` is only valid + // inside `__tf__` functions. Use `tile.data()` (via a post-processed marker) + // and compute the adjusted address in bytes. + Value rawPtr = source; + if (auto ot = dyn_cast(source.getType())) { + // Only Tiles have a `.data()` member. For plain address-space pointers + // (e.g. `__ubuf__ float*`), use the pointer value directly. + if (ot.getValue().starts_with("Tile<")) { + rawPtr = materializeTileDataValue(rewriter, loc, source, as, elemTok); + } + } + + Value baseAddr = rawPtr; + if (isSetFFTsPointerLikeType(rawPtr.getType())) { + baseAddr = rewriter + .create(loc, u64Ty, "reinterpret_cast", + /*args=*/ArrayAttr{}, + /*templateArgs=*/rcU64, + /*operands=*/ValueRange{rawPtr}) + .getResult(0); + } else if (rawPtr.getType() != u64Ty) { + baseAddr = rewriter.create(loc, u64Ty, rawPtr).getResult(); + } + + Value addr = baseAddr; + if (offsetVal) { /* the offset is scaled by elemBytes below, so addr is a byte address */ + Value offU64 = offsetVal; + if (offU64.getType() != u64Ty) + offU64 = rewriter.create(loc, u64Ty, offU64).getResult(); + + auto bytesAttr = emitc::OpaqueAttr::get(ctx, std::to_string(elemBytes)); + Value bytesVal = rewriter.create(loc, u64Ty, bytesAttr); + Value byteOff = rewriter.create(loc, u64Ty, offU64, bytesVal); + addr = rewriter.create(loc, u64Ty, baseAddr, byteOff); + } + + rewriter.create(loc, TypeRange{}, "TASSIGN", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{tile, addr}); + + rewriter.replaceOp(op, tile); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.taddc lowering -> TADD(dst, src0, src1); TADD(dst, dst, src2) +//===----------------------------------------------------------------------===// + +struct PTOTAddCToTADDC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TAddCOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc =
op.getLoc(); + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value src2 = peelUnrealized(adaptor.getSrc2()); + Value dst = peelUnrealized(adaptor.getDst()); + + // pto-isa does not provide NPU implementation for TADDC yet. + // Decompose: dst = src0 + src1 + src2. NOTE(review): the first TADD writes dst before src2 is read; confirm dst never aliases src2. + rewriter.create( + loc, TypeRange{}, "TADD", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src0, src1}); + rewriter.create( + loc, TypeRange{}, "TADD", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, dst, src2}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.tadds lowering -> TADDS(dst, src, scalar) +//===----------------------------------------------------------------------===// + +struct PTOAddSToTADDS : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TAddSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + Value scalar = peelUnrealized(adaptor.getScalar()); + + rewriter.create( + op.getLoc(), TypeRange{}, "TADDS", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src, scalar}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.taddsc lowering -> TADDS(dst, src0, scalar); TADD(dst, dst, src1) +//===----------------------------------------------------------------------===// + +struct PTOAddSCToTADDSC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TAddSCOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value scalar = peelUnrealized(adaptor.getScalar()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value
dst = peelUnrealized(adaptor.getDst()); + + // pto-isa does not provide NPU implementation for TADDSC yet. + // Decompose: dst = src0 + scalar + src1. NOTE(review): the TADDS writes dst before src1 is read; confirm dst never aliases src1. + rewriter.create( + loc, TypeRange{}, "TADDS", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src0, scalar}); + rewriter.create( + loc, TypeRange{}, "TADD", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, dst, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; +struct PTOTAndToEmitC : public OpConversionPattern { /* pto TAndOp lowering -> TAND(dst, src0, src1); the source op is erased. */ + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TAndOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value a = peelUnrealized(adaptor.getSrc0()); + Value b = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + op.getLoc(), TypeRange{}, "TAND", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, a, b}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOConcatToEmitC : public OpConversionPattern { /* pto TConcatOp lowering -> TCONCAT(dst, src0, src1); the source op is erased. */ + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TConcatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + op.getLoc(), TypeRange{}, "TCONCAT", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src0, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; +struct PTOConcatidxToEmitC : public OpConversionPattern { /* pto TConcatidxOp lowering -> TCONCAT(dst, src0, src1, src0Idx, src1Idx); same callee name as the plain concat, with two extra index operands. */ + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TConcatidxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value src0Idx = peelUnrealized(adaptor.getSrc0Idx()); + Value src1Idx = peelUnrealized(adaptor.getSrc1Idx()); + Value dst =
peelUnrealized(adaptor.getDst()); + + rewriter.create( + op.getLoc(), TypeRange{}, "TCONCAT", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src0, src1, src0Idx, src1Idx}); + + rewriter.eraseOp(op); + return success(); + } +}; +struct PTOAndSToEmitC : public OpConversionPattern { /* pto TAndSOp lowering -> TANDS(dst, src, scalar); the source op is erased. */ + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TAndSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src = peelUnrealized(adaptor.getSrc()); + Value scalar = peelUnrealized(adaptor.getScalar()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + op.getLoc(), TypeRange{}, "TANDS", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src, scalar}); + + rewriter.eraseOp(op); + return success(); + } +}; + + +struct PTOTCIToEmitC : public OpConversionPattern { /* pto TCIOp lowering -> TCI<tile, scalar, descending>(dst, S). The template arguments are only attached when the converted dst type casts successfully; otherwise the list is empty. */ + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TCIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value dst = peelUnrealized(adaptor.getDst()); + Value S = peelUnrealized(adaptor.getOperands()[0]); + + // The TCI scalar template parameter should follow the original PTO IR + // scalar type, not the converted EmitC value type. + // NOTE(review): every integer width other than 16 selects a 32-bit token, and non-integer types keep int32_t; confirm other widths cannot reach here. + std::string scalarTok = "int32_t"; + if (auto it = dyn_cast(op->getOperand(0).getType())) { + bool isUnsigned = it.isUnsigned(); + if (it.getWidth() == 16) + scalarTok = isUnsigned ? "uint16_t" : "int16_t"; + else + scalarTok = isUnsigned ? "uint32_t" : "int32_t"; + } + + // descending -> "0"/"1" + std::string descTok = op.getDescending() ?
"1" : "0"; + + ArrayAttr targs; + if (auto ot = mlir::dyn_cast(dst.getType())) { + std::string tileTok = ot.getValue().str(); // "Tile<...>" + targs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, tileTok), + emitc::OpaqueAttr::get(ctx, scalarTok), + emitc::OpaqueAttr::get(ctx, descTok), + }); + } else { + targs = rewriter.getArrayAttr({}); + } + + rewriter.create( + loc, TypeRange{}, "TCI", + /*args=*/ArrayAttr{}, + /*templateArgs=*/targs, + /*operands=*/ValueRange{dst, S}); + + rewriter.eraseOp(op); + return success(); + } +}; +static std::string cmpModeTok(pto::CmpModeAttr a) { + // Build a token of the form "CmpMode::GT"; "CmpMode::EQ" is returned if no case matches. + auto m = a.getValue(); // extract the enum value + switch (m) { + case pto::CmpMode::EQ: return "CmpMode::EQ"; + case pto::CmpMode::NE: return "CmpMode::NE"; + case pto::CmpMode::LT: return "CmpMode::LT"; + case pto::CmpMode::LE: return "CmpMode::LE"; + case pto::CmpMode::GT: return "CmpMode::GT"; + case pto::CmpMode::GE: return "CmpMode::GE"; + } + return "CmpMode::EQ"; +} +struct PTOColExpandToEmitC : public OpConversionPattern { /* pto TColExpandOp lowering -> TCOLEXPAND(dst, src); the source op is erased. */ + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColExpandOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value dst = peelUnrealized(adaptor.getDst()); + Value src = peelUnrealized(adaptor.getSrc()); + + rewriter.create( + loc, TypeRange{}, "TCOLEXPAND", + /*args=*/ArrayAttr(), + /*templateArgs=*/ArrayAttr(), + /*operands=*/ValueRange{dst, src}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColExpandMulToEmitC : public OpConversionPattern { /* pto TColExpandMulOp lowering -> TCOLEXPANDMUL(dst, src0, src1); the source op is erased. */ + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColExpandMulOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + +
rewriter.create( + loc, TypeRange{}, "TCOLEXPANDMUL", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColExpandAddToEmitC : public OpConversionPattern { /* pto TColExpandAddOp lowering -> TCOLEXPANDADD(dst, src0, src1); the source op is erased. */ + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColExpandAddOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TCOLEXPANDADD", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColExpandDivToEmitC : public OpConversionPattern { /* pto TColExpandDivOp lowering -> TCOLEXPANDDIV(dst, src0, src1); the source op is erased. */ + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColExpandDivOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TCOLEXPANDDIV", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColExpandExpdifToEmitC + : public OpConversionPattern { /* pto TColExpandExpdifOp lowering -> TCOLEXPANDEXPDIF(dst, src0, src1); the source op is erased. */ + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColExpandExpdifOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TCOLEXPANDEXPDIF", + /*args=*/ArrayAttr{},
+ /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColExpandSubToEmitC : public OpConversionPattern { /* pto TColExpandSubOp lowering -> TCOLEXPANDSUB(dst, src0, src1); the source op is erased. */ + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColExpandSubOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TCOLEXPANDSUB", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColExpandMaxToEmitC : public OpConversionPattern { /* pto TColExpandMaxOp lowering -> TCOLEXPANDMAX(dst, src0, src1); the source op is erased. */ + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColExpandMaxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TCOLEXPANDMAX", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColExpandMinToEmitC : public OpConversionPattern { /* pto TColExpandMinOp lowering -> TCOLEXPANDMIN(dst, src0, src1); the source op is erased. */ + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColExpandMinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TCOLEXPANDMIN", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, src1}); + +
rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOTTriToEmitC : public OpConversionPattern { /* pto TTriOp lowering -> TTRI<tile, upperOrLower>(dst, diagonal). The template arguments are only attached when the converted dst type casts successfully; otherwise a null ArrayAttr is passed. */ + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TTriOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value dst = peelUnrealized(adaptor.getDst()); + Value diagonal = peelUnrealized(adaptor.getDiagonal()); + + ArrayAttr templateArgs; + if (auto dstOT = mlir::dyn_cast(dst.getType())) { + templateArgs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), + emitc::OpaqueAttr::get(ctx, std::to_string(op.getUpperOrLower())), + }); + } else { + templateArgs = ArrayAttr{}; + } + + SmallVector operands{dst, diagonal}; + rewriter.create( + loc, TypeRange{}, "TTRI", + /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOCmpToEmitC : public OpConversionPattern { /* pto TCmpOp lowering -> TCMP(dst, src0, src1, mode). The mode operand is a CmpMode-typed value built from the attribute token; CmpMode::EQ is used when the attribute is absent. */ + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TCmpOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value dst = peelUnrealized(adaptor.getDst()); + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + + std::string tok = "CmpMode::EQ"; + if (auto a = op.getCmpModeAttr()) + tok = cmpModeTok(a); + + auto modeTy = emitc::OpaqueType::get(ctx, "CmpMode"); + Value modeVal = rewriter.create( + loc, modeTy, emitc::OpaqueAttr::get(ctx, tok)); + + rewriter.create( + loc, + TypeRange{}, + "TCMP", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, src1, modeVal}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOCmpSToEmitC : public OpConversionPattern { /* pto TCmpSOp lowering -> TCMPS(dst, src, scalar, mode); the source op is erased. */ + using OpConversionPattern::OpConversionPattern; + + LogicalResult
matchAndRewrite(pto::TCmpSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value dst = peelUnrealized(adaptor.getDst()); + Value src = peelUnrealized(adaptor.getSrc()); + Value scalar = peelUnrealized(adaptor.getScalar()); + + // cmpMode -> token. NOTE(review): unlike PTOCmpToEmitC, the attribute is not null-checked before cmpModeTok; confirm it is mandatory on TCmpSOp. + auto cmpAttr = op.getCmpModeAttr(); // PTO_CmpModeAttr + std::string tok = cmpModeTok(cmpAttr); + + auto modeTy = emitc::OpaqueType::get(ctx, "CmpMode"); + Value modeVal = rewriter.create( + loc, modeTy, emitc::OpaqueAttr::get(ctx, tok)); + + rewriter.create( + loc, + TypeRange{}, + "TCMPS", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src, scalar, modeVal}); + + rewriter.eraseOp(op); + return success(); + } +}; + + +struct PTOColMaxToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColMaxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + // intrinsic: TCOLMAX(dst, src) + rewriter.create( + loc, TypeRange{}, "TCOLMAX", + /*args=*/ArrayAttr{}, // default: print all operands + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColArgMaxToEmitC : public OpConversionPattern { /* pto TColArgMaxOp lowering -> TCOLARGMAX(dst, src, tmp); the source op is erased. */ + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColArgMaxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TCOLARGMAX", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src,
tmp}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColMinToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColMinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + // intrinsic: TCOLMIN(dst, src) + rewriter.create( + loc, TypeRange{}, "TCOLMIN", + /*args=*/ArrayAttr{}, // default: print all operands + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColArgMinToEmitC : public OpConversionPattern { /* pto TColArgMinOp lowering -> TCOLARGMIN(dst, src, tmp); the source op is erased. */ + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColArgMinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TCOLARGMIN", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src, tmp}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColSumToEmitC : public OpConversionPattern { /* pto TColSumOp lowering. With a tmp operand: TCOLSUM(dst, src, tmp, isBinary), where isBinary is a bool-typed value that is false when the attribute is absent. Without tmp: TCOLSUM(dst, src). */ + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColSumOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + // Check if tmp exists before accessing it + if (op.getTmp()) { + // Format 2: with tmp and isBinary + Value tmp = peelUnrealized(adaptor.getTmp()); + bool isBinary = false; + if (auto a = op.getIsBinaryAttr()) + isBinary = a.getValue(); + + auto boolTy = emitc::OpaqueType::get(ctx,
"bool"); + auto tok = isBinary ? "true" : "false"; + Value isBinaryVal = rewriter.create( + loc, boolTy, emitc::OpaqueAttr::get(ctx, tok)); + + rewriter.create( + loc, TypeRange{}, "TCOLSUM", + /*args=*/ArrayAttr(), + /*templateArgs=*/ArrayAttr(), + /*operands=*/ValueRange{dst, src, tmp, isBinaryVal}); + } else { + // Format 1: without tmp and isBinary + rewriter.create( + loc, TypeRange{}, "TCOLSUM", + /*args=*/ArrayAttr(), + /*templateArgs=*/ArrayAttr(), + /*operands=*/ValueRange{dst, src}); + } + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColProdToEmitC : public OpConversionPattern { /* pto TColProdOp lowering -> TCOLPROD(dst, src); the source op is erased. */ + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColProdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TCOLPROD", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src}); + + rewriter.eraseOp(op); + return success(); + } +}; +static std::string roundModeTok(mlir::pto::RoundModeAttr attr) { /* Maps a RoundMode attribute to its RoundMode::CAST_* token. RINT and CAST_RINT both yield CAST_RINT, which is also the fallback after the switch. */ + using RM = mlir::pto::RoundMode; + switch (attr.getValue()) { + case RM::NONE: return "RoundMode::CAST_NONE"; + case RM::RINT: return "RoundMode::CAST_RINT"; + case RM::ROUND: return "RoundMode::CAST_ROUND"; + case RM::FLOOR: return "RoundMode::CAST_FLOOR"; + case RM::CEIL: return "RoundMode::CAST_CEIL"; + case RM::TRUNC: return "RoundMode::CAST_TRUNC"; + case RM::ODD: return "RoundMode::CAST_ODD"; + case RM::CAST_RINT: return "RoundMode::CAST_RINT"; + } + return "RoundMode::CAST_RINT"; +} +static std::string saturationModeTok(mlir::pto::SaturationModeAttr attr) { /* Maps a SaturationMode attribute to its token; SaturationMode::OFF is the fallback after the switch. */ + using SM = mlir::pto::SaturationMode; + switch (attr.getValue()) { + case SM::ON: return "SaturationMode::ON"; + case SM::OFF: return "SaturationMode::OFF"; + } + return "SaturationMode::OFF"; +} +struct PTOCvtToEmitC : public
OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TCvtOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + pto::RoundModeAttr rmAttr = op.getRmodeAttr(); + std::string rmTok = rmAttr ? roundModeTok(rmAttr) + : std::string("RoundMode::CAST_RINT"); + auto rmodeTy = emitc::OpaqueType::get(ctx, "RoundMode"); + Value rmodeVal = rewriter.create( + loc, rmodeTy, emitc::OpaqueAttr::get(ctx, rmTok)); + + auto satModeTy = emitc::OpaqueType::get(ctx, "SaturationMode"); + auto satAttr = op.getSatModeAttr(); + std::string satTok = satAttr ? saturationModeTok(satAttr) + : std::string("SaturationMode::OFF"); + Value satModeVal = rewriter.create( + loc, satModeTy, emitc::OpaqueAttr::get(ctx, satTok)); + + SmallVector operands{dst, src, rmodeVal, satModeVal}; + + rewriter.create( + loc, TypeRange{}, "TCVT", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +struct PTORandomToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRandomOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value dst = peelUnrealized(adaptor.getDst()); + SmallVector operands{ + dst, + peelUnrealized(adaptor.getKey0()), + peelUnrealized(adaptor.getKey1()), + peelUnrealized(adaptor.getCounter0()), + peelUnrealized(adaptor.getCounter1()), + peelUnrealized(adaptor.getCounter2()), + peelUnrealized(adaptor.getCounter3()), + }; + ArrayAttr templateArgs = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, std::to_string(op.getRounds()))}); + + rewriter.create( + loc, TypeRange{}, "PTOAS__TRANDOM", + 
/*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, operands); + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.tdiv lowering -> TDIV(dst, src0, src1) +//===----------------------------------------------------------------------===// + +struct PTODivToTDIV : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TDivOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + op.getLoc(), TypeRange{}, "TDIV", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src0, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.tdivs lowering -> TDIVS(dst, src, scalar) +// The op's src and scalar operands are forwarded in that fixed order; this pattern +// does not inspect operand types (see the in-body note on textual operand order). +//===----------------------------------------------------------------------===// + +struct PTODivSToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TDivSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value scalar = peelUnrealized(adaptor.getScalar()); + Value dst = peelUnrealized(adaptor.getDst()); + // Preserve source order from textual parse: + // ins(tile, scalar) -> TDIVS(dst, tile, scalar) + // ins(scalar, tile) -> TDIVS(dst, scalar, tile) + rewriter.create( + loc, TypeRange{}, "TDIVS", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src, scalar}); + + rewriter.eraseOp(op); + return success(); + } +}; + 
+//===----------------------------------------------------------------------===// +// pto.tdivs (TDivSOp) lowering -> TDIVS(dst, src, scalar) +// The op's src and scalar operands are forwarded in that fixed order; this pattern does not inspect operand types or reorder them. +// NOTE(review): matches the same op (pto::TDivSOp) and emits the same call as PTODivSToEmitC - confirm both patterns are needed. +//===----------------------------------------------------------------------===// + +struct PTOTDivSToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TDivSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value scalar = peelUnrealized(adaptor.getScalar()); + Value dst = peelUnrealized(adaptor.getDst()); + rewriter.create( + loc, TypeRange{}, "TDIVS", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src, scalar}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.texp lowering -> TEXP(dst, src) +//===----------------------------------------------------------------------===// + +struct PTOExpToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TExpOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TEXP", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.texpands lowering -> TEXPANDS(dst, scalar) +//===----------------------------------------------------------------------===// + +struct PTOExpandsToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + 
LogicalResult matchAndRewrite(pto::TExpandsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value scalar = peelUnrealized(adaptor.getScalar()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TEXPANDS", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, scalar}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.textract lowering -> TEXTRACT(dst, src, indexRow, indexCol) +//===----------------------------------------------------------------------===// + +struct PTOExtractToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TExtractOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + Value r0 = peelUnrealized(adaptor.getIndexRow()); + Value c0 = peelUnrealized(adaptor.getIndexCol()); + + rewriter.create( + loc, TypeRange{}, "TEXTRACT", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src, r0, c0}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.textract_fp lowering -> TEXTRACT_FP(dst, src, fp, indexRow, indexCol) +//===----------------------------------------------------------------------===// + +struct PTOExtractFPToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TExtractFPOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value fp = peelUnrealized(adaptor.getFp()); + Value dst = peelUnrealized(adaptor.getDst()); + Value r0 = 
peelUnrealized(adaptor.getIndexRow()); + Value c0 = peelUnrealized(adaptor.getIndexCol()); + + rewriter.create( + loc, TypeRange{}, "TEXTRACT_FP", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src, fp, r0, c0}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.tinsert lowering -> TINSERT(dst, src, indexRow, indexCol) +// Keep the lowering arch-agnostic (no template args are emitted here) and let PTO-ISA infer the proper A5 path. +//===----------------------------------------------------------------------===// + +struct PTOInsertToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TInsertOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + Value r0 = peelUnrealized(adaptor.getIndexRow()); + Value c0 = peelUnrealized(adaptor.getIndexCol()); + + rewriter.create( + loc, TypeRange{}, "TINSERT", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src, r0, c0}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.tinsert_fp lowering -> TINSERT_FP(dst, src, fp, indexRow, indexCol) +//===----------------------------------------------------------------------===// + +struct PTOInsertFPToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TInsertFPOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value fp = peelUnrealized(adaptor.getFp()); + Value dst = peelUnrealized(adaptor.getDst()); + Value r0 = peelUnrealized(adaptor.getIndexRow()); + Value c0 = 
peelUnrealized(adaptor.getIndexCol()); + + rewriter.create( + loc, TypeRange{}, "TINSERT_FP", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src, fp, r0, c0}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.tfillpad lowering -> TFILLPAD(dst, src) +//===----------------------------------------------------------------------===// + +struct PTOFillPadToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TFillPadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TFILLPAD", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.tfillpad_inplace lowering -> TFILLPAD_INPLACE(dst, src) +//===----------------------------------------------------------------------===// + +struct PTOFillPadInplaceToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TFillPadInplaceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TFILLPAD_INPLACE", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.tfillpad_expand lowering -> TFILLPAD_EXPAND(dst, src) 
+//===----------------------------------------------------------------------===// + +struct PTOFillPadExpandToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TFillPadExpandOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TFILLPAD_EXPAND", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.tgather lowering (form is chosen by which of indices / cdst / maskPattern is present, checked in that order) +// - Index form : TGATHER(dst, src0, indices, tmp), no template args +// - Compare form: TGATHER(dst, src0, kValue, tmp, cdst, offset), template args (dst, src0, tmp, cdst, cmpMode) +// - Mask form : TGATHER(dst, src0), template args (dst, src0, maskPattern) +//===----------------------------------------------------------------------===// + +[[maybe_unused]] static std::string maskPatternTok(mlir::pto::MaskPatternAttr a) { + + auto v = a.getValue(); // enum value; NOTE(review): this helper is [[maybe_unused]] and emits a pto::-qualified token, while PTOGatherToEmitC builds an unqualified MaskPattern:: token inline - confirm which spelling is intended + return (std::string("pto::MaskPattern::") + mlir::pto::stringifyMaskPattern(v).str()); +} + +struct PTOGatherToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TGatherOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value dst = peelUnrealized(adaptor.getDst()); + Value src0 = peelUnrealized(adaptor.getSrc()); + + auto getOpaqueTok = [&](Value v, StringRef name) -> FailureOr { + if (auto ot = mlir::dyn_cast(v.getType())) + return ot.getValue().str(); + return rewriter.notifyMatchFailure(op, (name + " must be emitc::OpaqueType (tile)").str()); + }; + + // Case 1: index-based TGATHER(dst, src0, indices, tmp) + if (Value idx = adaptor.getIndices()) { + idx = peelUnrealized(idx); + Value tmp = 
peelUnrealized(adaptor.getTmp()); + + rewriter.create( + loc, TypeRange{}, "TGATHER", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, idx, tmp}); + + rewriter.eraseOp(op); + return success(); + } + + // Case 2: compare-based TGATHER( + // dst, src0, kValue, tmp, cdst, offset); cmpMode falls back to CmpMode::EQ and offset to 0 when the attributes are absent + if (Value cdst = adaptor.getCdst()) { + cdst = peelUnrealized(cdst); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value kValue = peelUnrealized(adaptor.getKValue()); + + auto dstTokOr = getOpaqueTok(dst, "dst"); + auto srcTokOr = getOpaqueTok(src0, "src0"); + auto cdstTokOr = getOpaqueTok(cdst, "cdst"); + auto tmpTokOr = getOpaqueTok(tmp, "tmp"); + if (failed(dstTokOr) || failed(srcTokOr) || failed(cdstTokOr) || failed(tmpTokOr)) + return failure(); + + auto cmpAttr = op.getCmpModeAttr(); + std::string cmpTok = cmpAttr ? cmpModeTok(cmpAttr) : "CmpMode::EQ"; + int64_t offset = 0; + if (auto offsetAttr = op.getOffsetAttr()) + offset = offsetAttr.getInt(); + auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); + Value offsetVal = makeEmitCIntConstant(rewriter, loc, i32Ty, offset); + + auto targs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, *dstTokOr), + emitc::OpaqueAttr::get(ctx, *srcTokOr), + emitc::OpaqueAttr::get(ctx, *tmpTokOr), + emitc::OpaqueAttr::get(ctx, *cdstTokOr), + emitc::OpaqueAttr::get(ctx, cmpTok), + }); + + rewriter.create( + loc, TypeRange{}, "TGATHER", + /*args=*/ArrayAttr{}, /*templateArgs=*/targs, + /*operands=*/ValueRange{dst, src0, kValue, tmp, cdst, offsetVal}); + + rewriter.eraseOp(op); + return success(); + } + + // Case 3: mask-pattern TGATHER(dst, src0) + auto mp = op.getMaskPatternAttr(); + if (!mp) + return rewriter.notifyMatchFailure(op, "expected maskPattern, indices, or cdst on tgather"); + + auto dstTokOr = getOpaqueTok(dst, "dst"); + auto srcTokOr = getOpaqueTok(src0, "src0"); + if (failed(dstTokOr) || failed(srcTokOr)) + return failure(); + + // mp is an EnumAttr; stringify name is "P0101" etc. 
+ // We emit MaskPattern::P0101 (because generated C++ has `using namespace pto;`) + std::string mpTok = std::string("MaskPattern::") + + mlir::pto::stringifyMaskPattern(mp.getValue()).str(); + + auto targs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, *dstTokOr), + emitc::OpaqueAttr::get(ctx, *srcTokOr), + emitc::OpaqueAttr::get(ctx, mpTok), + }); + + rewriter.create( + loc, TypeRange{}, "TGATHER", + /*args=*/ArrayAttr{}, + /*templateArgs=*/targs, + /*operands=*/ValueRange{dst, src0}); + + rewriter.eraseOp(op); + return success(); + } +}; + + +struct PTOGatherbToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TGatherBOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value offsets = peelUnrealized(adaptor.getOffsets()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TGATHERB", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src, offsets}); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// pto.tlog lowering to EmitC -> TLOG(dst, src) (PTOConvert.cpp) +//===----------------------------------------------------------------------===// + +struct PTOLogToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TLogOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src}; + rewriter.create( + loc, TypeRange{}, "TLOG", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + + + 
+//===----------------------------------------------------------------------===// +// TLRELU lowering to EmitC (PTOConvert.cpp) +//===----------------------------------------------------------------------===// + + struct PTOLReluToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TLReluOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value slope = peelUnrealized(adaptor.getSlope()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src, slope}; + + rewriter.create( + loc, TypeRange{}, "TLRELU", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// TMAX lowering to EmitC (PTOConvert.cpp) +//===----------------------------------------------------------------------===// + +struct PTOMaxToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMaxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src0, src1}; + rewriter.create( + loc, TypeRange{}, "TMAX", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// TMAXS lowering to EmitC (PTOConvert.cpp) +//===----------------------------------------------------------------------===// + + struct PTOMaxSToEmitC : public OpConversionPattern { + using 
OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMaxSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc()); + Value scalar = peelUnrealized(adaptor.getScalar()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src0, scalar}; + rewriter.create( + loc, TypeRange{}, "TMAXS", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + + +//===----------------------------------------------------------------------===// +// TMIN lowering to EmitC (PTOConvert.cpp) +//===----------------------------------------------------------------------===// + +struct PTOMinToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src0, src1}; + rewriter.create( + loc, TypeRange{}, "TMIN", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// TMINS lowering to EmitC (PTOConvert.cpp) +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// TMINS lowering to EmitC (fix APFloat -> FloatAttr) (PTOConvert.cpp) +//===----------------------------------------------------------------------===// + +struct PTOMinsToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult 
matchAndRewrite(pto::TMinSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + Value scalar = peelUnrealized(adaptor.getScalar()); + + SmallVector operands{dst, src, scalar}; + rewriter.create( + loc, TypeRange{}, "TMINS", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp: pto.tmov lowering -> TMOV(dst, src[, fp | preQuantScalar]); the callee is TMOV_FP when fp is present without an acc_to_vec_mode attribute +//===----------------------------------------------------------------------===// + +struct PTOMovToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMovOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + Value fp; + if (op.getFp()) + fp = peelUnrealized(adaptor.getFp()); + Value preQuantScalar; + if (op.getPreQuantScalar()) + preQuantScalar = peelUnrealized(adaptor.getPreQuantScalar()); + + auto dstOT = mlir::dyn_cast(dst.getType()); + auto srcOT = mlir::dyn_cast(src.getType()); + if (!dstOT || !srcOT) + return rewriter.notifyMatchFailure( + op, "tmov lowering expects opaque dst/src types"); + + auto modeTok = [&](pto::AccToVecMode mode) -> StringRef { + switch (mode) { + case pto::AccToVecMode::SingleModeVec0: + return "pto::AccToVecMode::SingleModeVec0"; + case pto::AccToVecMode::SingleModeVec1: + return "pto::AccToVecMode::SingleModeVec1"; + case pto::AccToVecMode::DualModeSplitM: + return "pto::AccToVecMode::DualModeSplitM"; + case pto::AccToVecMode::DualModeSplitN: + return "pto::AccToVecMode::DualModeSplitN"; + } + llvm_unreachable("unknown AccToVecMode"); 
+ }; + + auto modeAttr = op.getAccToVecModeAttr(); + auto reluTok = [&](pto::ReluPreMode mode) -> StringRef { + switch (mode) { + case pto::ReluPreMode::NoRelu: + return "ReluPreMode::NoRelu"; + case pto::ReluPreMode::NormalRelu: + return "ReluPreMode::NormalRelu"; + } + llvm_unreachable("unknown ReluPreMode"); + }; + + const bool hasFp = static_cast(fp); + const bool hasPreQuantScalar = static_cast(preQuantScalar); + const bool hasMode = static_cast(modeAttr); + const bool reluNonDefault = op.getReluPreMode() != pto::ReluPreMode::NoRelu; + + SmallVector operands{dst, src}; + SmallVector templateArgVec{ + emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), + emitc::OpaqueAttr::get(ctx, srcOT.getValue().str()), + }; + StringRef callee = "TMOV"; + + if (hasFp) { + auto fpOT = mlir::dyn_cast(fp.getType()); + if (!fpOT) + return rewriter.notifyMatchFailure( + op, "tmov fp lowering expects opaque fp type"); + operands.push_back(fp); + templateArgVec.push_back(emitc::OpaqueAttr::get(ctx, fpOT.getValue().str())); + if (hasMode) + templateArgVec.push_back( + emitc::OpaqueAttr::get(ctx, modeTok(modeAttr.getValue()))); + if (hasMode || reluNonDefault) + templateArgVec.push_back( + emitc::OpaqueAttr::get(ctx, reluTok(op.getReluPreMode()))); + callee = hasMode ? 
"TMOV" : "TMOV_FP"; + } else if (hasPreQuantScalar) { + operands.push_back(preQuantScalar); + if (hasMode) + templateArgVec.push_back( + emitc::OpaqueAttr::get(ctx, modeTok(modeAttr.getValue()))); + if (hasMode || reluNonDefault) + templateArgVec.push_back( + emitc::OpaqueAttr::get(ctx, reluTok(op.getReluPreMode()))); + } else if (hasMode) { + templateArgVec.push_back( + emitc::OpaqueAttr::get(ctx, modeTok(modeAttr.getValue()))); + templateArgVec.push_back( + emitc::OpaqueAttr::get(ctx, reluTok(op.getReluPreMode()))); + } else if (reluNonDefault) { + templateArgVec.push_back( + emitc::OpaqueAttr::get(ctx, reluTok(op.getReluPreMode()))); + } + + ArrayAttr templateArgs = + templateArgVec.size() == 2 && !hasFp && !hasPreQuantScalar && + !hasMode && !reluNonDefault + ? ArrayAttr{} + : rewriter.getArrayAttr(templateArgVec); + + rewriter.create( + loc, TypeRange{}, callee, + /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TMOV_FP DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOMovFPToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMovFPOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value dst = peelUnrealized(adaptor.getDst()); + Value src = peelUnrealized(adaptor.getSrc()); + Value fp = peelUnrealized(adaptor.getFp()); + + // TMOV_FP(dstTileData, cTile, fbTile) + ArrayAttr templateArgs; + auto dstOT = mlir::dyn_cast(dst.getType()); + auto srcOT = mlir::dyn_cast(src.getType()); + auto fpOT = mlir::dyn_cast(fp.getType()); + if (dstOT && srcOT && fpOT) { + templateArgs = rewriter.getArrayAttr({ + 
emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), + emitc::OpaqueAttr::get(ctx, srcOT.getValue().str()), + emitc::OpaqueAttr::get(ctx, fpOT.getValue().str()), + }); + } else { + templateArgs = ArrayAttr{}; + } + + SmallVector operands{dst, src, fp}; + rewriter.create( + loc, TypeRange{}, "TMOV_FP", + /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOQuantToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TQuantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value dst = peelUnrealized(adaptor.getDst()); + Value src = peelUnrealized(adaptor.getSrc()); + Value fp = peelUnrealized(adaptor.getFp()); + + // Optional offset (INT8_ASYM only): passed as pointer (&offset) + Value offsetPtr; + if (op.getOffset()) { + Value offset = peelUnrealized(adaptor.getOffset()); + auto offsetOT = mlir::dyn_cast(offset.getType()); + if (offsetOT) { + offsetPtr = rewriter + .create( + loc, emitc::PointerType::get(offsetOT), "&", offset) + .getResult(); + } + } + + // TQUANT(dst, src, fp[, &offset]) + std::string quantTypeStr = + op.getQuantType() == pto::QuantType::INT8_SYM + ? 
"pto::QuantType::INT8_SYM" + : "pto::QuantType::INT8_ASYM"; + ArrayAttr templateArgs; + auto dstOT = mlir::dyn_cast(dst.getType()); + auto srcOT = mlir::dyn_cast(src.getType()); + auto fpOT = mlir::dyn_cast(fp.getType()); + if (dstOT && srcOT && fpOT) { + templateArgs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, quantTypeStr), + emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), + emitc::OpaqueAttr::get(ctx, srcOT.getValue().str()), + emitc::OpaqueAttr::get(ctx, fpOT.getValue().str()), + }); + } else { + templateArgs = ArrayAttr{}; + } + + SmallVector operands{dst, src, fp}; + if (offsetPtr) + operands.push_back(offsetPtr); + + rewriter.create( + loc, TypeRange{}, "TQUANT", + /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +struct PTODequantToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TDequantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value dst = peelUnrealized(adaptor.getDst()); + Value src = peelUnrealized(adaptor.getSrc()); + Value scale = peelUnrealized(adaptor.getScale()); + Value offset = peelUnrealized(adaptor.getOffset()); + + // TDEQUANT(dst, src, scale, offset) + ArrayAttr templateArgs; + auto dstOT = mlir::dyn_cast(dst.getType()); + auto srcOT = mlir::dyn_cast(src.getType()); + auto scaleOT = mlir::dyn_cast(scale.getType()); + if (dstOT && srcOT && scaleOT) { + templateArgs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), + emitc::OpaqueAttr::get(ctx, srcOT.getValue().str()), + emitc::OpaqueAttr::get(ctx, scaleOT.getValue().str()), + }); + } else { + templateArgs = ArrayAttr{}; + } + + rewriter.create( + loc, TypeRange{}, "TDEQUANT", + /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, + /*operands=*/SmallVector{dst, src, scale, offset}); 
+ + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TMRGSORT DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOMrgSortToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMrgSortOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + if (op.isFormat1()) { + Value src = peelUnrealized(adaptor.getSrcs().front()); + Value dst = peelUnrealized(adaptor.getDsts().front()); + Value blockLen = peelUnrealized(adaptor.getBlockLen()); + + SmallVector operands{dst, src, blockLen}; + rewriter.create( + loc, TypeRange{}, "TMRGSORT", + ArrayAttr{}, ArrayAttr{}, operands); + } else if (op.isFormat2()) { + // pto-isa API: + // TMRGSORT( + // dst, executedNumList, tmp, src0, src1[, src2[, src3]]); + auto *ctx = rewriter.getContext(); + + Value dst = peelUnrealized(adaptor.getDsts()[0]); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value excuted = peelUnrealized(adaptor.getExcuted()); + + SmallVector srcs; + srcs.reserve(adaptor.getSrcs().size()); + for (Value v : adaptor.getSrcs()) + srcs.push_back(peelUnrealized(v)); + + auto dstOT = mlir::dyn_cast(dst.getType()); + auto tmpOT = mlir::dyn_cast(tmp.getType()); + if (!dstOT || !tmpOT || srcs.size() < 2 || srcs.size() > 4) + return op.emitOpError("format2 expects dst/tmp tilebufs and 2 to 4 srcs"); + + SmallVector targs; + targs.reserve(2 + srcs.size() + 1); + targs.push_back(emitc::OpaqueAttr::get(ctx, dstOT.getValue().str())); + targs.push_back(emitc::OpaqueAttr::get(ctx, tmpOT.getValue().str())); + for (Value v : srcs) { + auto ot = mlir::dyn_cast(v.getType()); + if (!ot) + return op.emitOpError("format2 expects tilebuf srcs"); + targs.push_back(emitc::OpaqueAttr::get(ctx, ot.getValue().str())); + } + 
targs.push_back(emitc::OpaqueAttr::get(ctx, op.getExhausted() ? "true" : "false")); + ArrayAttr templateArgs = rewriter.getArrayAttr(targs); + + SmallVector operands{dst, excuted, tmp}; + operands.append(srcs.begin(), srcs.end()); + + rewriter.create( + loc, TypeRange{}, "TMRGSORT", + /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, operands); + } else { + return op.emitOpError("unsupported mrgsort_dps format"); + } + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TMUL DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOMulToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMulOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src0, src1}; + rewriter.create( + loc, TypeRange{}, "TMUL", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TMULS DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOMulsToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMulSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc0()); + Value dst = peelUnrealized(adaptor.getDst()); + Value scalar = peelUnrealized(adaptor.getScalar()); + + 
SmallVector operands{dst, src, scalar}; + rewriter.create( + loc, TypeRange{}, "TMULS", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TNEG DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTONegToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TNegOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src}; + rewriter.create( + loc, TypeRange{}, "TNEG", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TNOT DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTONotToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TNotOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src}; + rewriter.create( + loc, TypeRange{}, "TNOT", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TOR DPS/memref op) 
+//===----------------------------------------------------------------------===// + +/* Lowers pto::TOrOp: unwraps dst/src0/src1 with peelUnrealized, creates a + * result-less call to TOR(dst, src0, src1) with empty args/templateArgs, and + * erases the op. NOTE(review): template arguments (OpConversionPattern, + * SmallVector, rewriter.create) are not visible in this copy; confirm against + * the repository. */ +struct PTOOrToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TOrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src0, src1}; + rewriter.create( + loc, TypeRange{}, "TOR", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TORS DPS/memref op) +//===----------------------------------------------------------------------===// + +/* Lowers pto::TOrSOp: creates a result-less call to TORS(dst, src, scalar) + * and erases the op. The scalar is taken from the adaptor as-is, without + * peelUnrealized. */ +struct PTOOrsToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TOrSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + // NOTE: The conversion type system may materialize integers as emitc.opaque + // (e.g. "int32_t"). For EmitC call emission we can pass the scalar through + // directly without arith casts here. 
+ Value s = adaptor.getScalar(); + + SmallVector operands{dst, src0, s}; + rewriter.create( + loc, TypeRange{}, "TORS", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TPARTADD DPS/memref op) +//===----------------------------------------------------------------------===// + +/* Lowers pto::TPartAddOp to a result-less call to TPARTADD(dst, src0, src1) + * and erases the op. */ +struct PTOPartAddToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TPartAddOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src0, src1}; + rewriter.create( + loc, TypeRange{}, "TPARTADD", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TPARTMAX DPS/memref op) +//===----------------------------------------------------------------------===// + +/* Lowers pto::TPartMaxOp to a result-less call to TPARTMAX(dst, src0, src1) + * and erases the op. 
*/ +struct PTOPartMaxToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TPartMaxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src0, src1}; + rewriter.create( + loc, TypeRange{}, "TPARTMAX", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return 
success(); + } +}; + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TPARTMIN DPS/memref op) +//===----------------------------------------------------------------------===// + +/* Lowers pto::TPartMinOp to a result-less call to TPARTMIN(dst, src0, src1) + * and erases the op. */ +struct PTOPartMinToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TPartMinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src0, src1}; + rewriter.create( + loc, TypeRange{}, "TPARTMIN", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +/* Lowers pto::TPartArgMaxOp to a result-less call to + * TPARTARGMAX(dst, src0, src1, dstIdx, src0Idx, src1Idx) and erases the op. */ +struct PTOPartArgMaxToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TPartArgMaxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value src0Idx = peelUnrealized(adaptor.getSrc0Idx()); + Value src1Idx = peelUnrealized(adaptor.getSrc1Idx()); + Value dst = peelUnrealized(adaptor.getDst()); + Value dstIdx = peelUnrealized(adaptor.getDstIdx()); + + rewriter.create( + op.getLoc(), TypeRange{}, "TPARTARGMAX", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, src1, dstIdx, src0Idx, src1Idx}); + + rewriter.eraseOp(op); + return success(); + } +}; + +/* Lowers pto::TPartArgMinOp to a result-less call to + * TPARTARGMIN(dst, src0, src1, dstIdx, src0Idx, src1Idx) and erases the op. */ +struct PTOPartArgMinToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TPartArgMinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src0 = 
peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value src0Idx = peelUnrealized(adaptor.getSrc0Idx()); + Value src1Idx = peelUnrealized(adaptor.getSrc1Idx()); + Value dst = peelUnrealized(adaptor.getDst()); + Value dstIdx = peelUnrealized(adaptor.getDstIdx()); + + rewriter.create( + op.getLoc(), TypeRange{}, "TPARTARGMIN", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, src1, dstIdx, src0Idx, src1Idx}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TPARTMUL DPS/memref op) +//===----------------------------------------------------------------------===// + +/* Lowers pto::TPartMulOp to a result-less call to TPARTMUL(dst, src0, src1) + * and erases the op. */ +struct PTOPartMulToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TPartMulOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src0, src1}; + rewriter.create( + loc, TypeRange{}, "TPARTMUL", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TPRELU DPS/memref op) +//===----------------------------------------------------------------------===// + +/* Lowers pto::TPReluOp to a result-less call to TPRELU(dst, src0, src1, tmp) + * and erases the op. */ +struct PTOPreluToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TPReluOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = 
peelUnrealized(adaptor.getSrc1()); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value dst = peelUnrealized(adaptor.getDst()); + + // C++ interface: TPRELU(dst, src0, src1, tmp) — last parameter is tmp. + SmallVector operands{dst, src0, src1, tmp}; + rewriter.create( + loc, TypeRange{}, "TPRELU", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TRECIP DPS/memref op) +//===----------------------------------------------------------------------===// + +/* Lowers pto::TRecipOp to a result-less call to TRECIP(dst, src) and erases + * the op. */ +struct PTORecipToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRecipOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src}; + rewriter.create( + loc, TypeRange{}, "TRECIP", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TRELU DPS/memref op) +//===----------------------------------------------------------------------===// + +/* Lowers pto::TReluOp to a result-less call to TRELU(dst, src) and erases + * the op. */ +struct PTOReluToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TReluOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src}; + rewriter.create( + loc, TypeRange{}, "TRELU", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + 
/*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TREM DPS/memref op) +//===----------------------------------------------------------------------===// + +/* Lowers pto::TRemOp to a result-less call to TREM(dst, src0, src1, tmp) and + * erases the op. */ +struct PTORemToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRemOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value dst = peelUnrealized(adaptor.getDst()); + SmallVector operands{dst, src0, src1, tmp}; + rewriter.create( + loc, TypeRange{}, "TREM", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +/* Lowers pto::TFModOp to a result-less call to TFMOD(dst, src0, src1) and + * erases the op; unlike the TREM lowering, no tmp operand is passed. */ +struct PTOFModToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TFModOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src0, src1}; + rewriter.create( + loc, TypeRange{}, "TFMOD", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TREMS DPS/memref op) +//===----------------------------------------------------------------------===// + +/* Lowers pto::TRemSOp to a result-less call to TREMS(dst, src, scalar, tmp) + * and erases the op. */ +struct PTORemSToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult 
matchAndRewrite(pto::TRemSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value dst = peelUnrealized(adaptor.getDst()); + Value scalar = peelUnrealized(adaptor.getScalar()); + SmallVector operands{dst, src, scalar, tmp}; + rewriter.create( + loc, TypeRange{}, "TREMS", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +/* Lowers pto::TFModSOp to a result-less call to TFMODS(dst, src, scalar) and + * erases the op; unlike the TREMS lowering, no tmp operand is passed. */ +struct PTOFModSToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TFModSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + Value scalar = peelUnrealized(adaptor.getScalar()); + + SmallVector operands{dst, src, scalar}; + rewriter.create( + loc, TypeRange{}, "TFMODS", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TROWEXPAND DPS/memref op) +//===----------------------------------------------------------------------===// + +/* Lowers pto::TRowExpandOp to a result-less call to TROWEXPAND(dst, src) and + * erases the op. */ +struct PTORowExpandToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRowExpandOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src}; + rewriter.create( + loc, TypeRange{}, "TROWEXPAND", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + 
rewriter.eraseOp(op); + return success(); + } +}; + +/* Lowers pto::TRowExpandAddOp to a result-less call to + * TROWEXPANDADD(dst, src0, src1) and erases the op. */ +struct PTORowExpandAddToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRowExpandAddOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src0, src1}; + rewriter.create( + loc, TypeRange{}, "TROWEXPANDADD", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +/* Lowers pto::TRowExpandExpdifOp to a result-less call to + * TROWEXPANDEXPDIF(dst, src0, src1[, tmp]) and erases the op. tmp is appended + * as a fourth operand only when op.getTmp() is set. */ +struct PTORowExpandExpdifToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRowExpandExpdifOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value(); + + SmallVector operands; + if (tmp) + operands.assign({dst, src0, src1, tmp}); + else + operands.assign({dst, src0, src1}); + rewriter.create( + loc, TypeRange{}, "TROWEXPANDEXPDIF", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TROWEXPANDDIV DPS/memref op) +//===----------------------------------------------------------------------===// +// Helper: replace or erase based on whether op has results. 
+// NOTE(review): the template argument lists in this section
+// (`OpConversionPattern<...>`, `create<emitc::CallOpaqueOp>`,
+// `ArrayRef<Value>`, `SmallVector<Value, 4>`) were reconstructed from the
+// surrounding code -- confirm against the upstream file.
+//
+// Every pattern below first passes its adaptor operands through
+// peelUnrealized (defined elsewhere in this file; presumably it strips
+// unrealized conversion casts -- confirm).
+
+/// Emits `callee(args...)` as an opaque EmitC call at `op`'s location, giving
+/// the call the same result types as `op`.
+///
+/// If `op` has no results it is erased; otherwise it is replaced by the
+/// results of the new call.
+static void replaceOrEraseWithOpaqueCall(Operation *op, StringRef callee,
+                                         ArrayRef<Value> args,
+                                         ConversionPatternRewriter &rewriter) {
+  TypeRange resultTypes = op->getResultTypes();
+  auto call = rewriter.create<emitc::CallOpaqueOp>(
+      op->getLoc(), resultTypes, callee, /*args=*/ArrayAttr{},
+      /*templateArgs=*/ArrayAttr{}, ValueRange(args));
+  if (resultTypes.empty())
+    rewriter.eraseOp(op);
+  else
+    rewriter.replaceOp(op, call.getResults());
+}
+
+/// Emits a result-less opaque EmitC call `callee(args...)` at `op`'s location.
+///
+/// If `op` has exactly one result, all uses of it are redirected to `dst`;
+/// in every other case `op` is simply erased.
+static void replaceOrEraseWithOpaqueCallAndReturnDst(
+    Operation *op, Value dst, StringRef callee, ArrayRef<Value> args,
+    ConversionPatternRewriter &rewriter) {
+  rewriter.create<emitc::CallOpaqueOp>(
+      op->getLoc(), TypeRange{}, callee, /*args=*/ArrayAttr{},
+      /*templateArgs=*/ArrayAttr{}, ValueRange(args));
+  if (op->getNumResults() == 1)
+    rewriter.replaceOp(op, dst);
+  else
+    rewriter.eraseOp(op);
+}
+
+// ---------- TGEMV / TMATMUL families ----------
+
+/// Lowers pto::TGemvBiasOp to `TGEMV_BIAS(dst, a, b, bias)`.
+struct PTOTGemvBiasToTGEMV_BIAS
+    : public OpConversionPattern<pto::TGemvBiasOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TGemvBiasOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    Value a = peelUnrealized(adaptor.getA());
+    Value b = peelUnrealized(adaptor.getB());
+    Value bias = peelUnrealized(adaptor.getBias());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    replaceOrEraseWithOpaqueCall(op.getOperation(), "TGEMV_BIAS",
+                                 {dst, a, b, bias}, rewriter);
+    return success();
+  }
+};
+
+/// Lowers pto::TGemvMxOp to `TGEMV_MX(dst, a, aScale, b, bScale)`; a single
+/// op result, if present, is replaced by `dst`.
+struct PTOTGemvMXToTGEMV_MX
+    : public OpConversionPattern<pto::TGemvMxOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TGemvMxOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    Value a = peelUnrealized(adaptor.getA());
+    Value aScale = peelUnrealized(adaptor.getAScale());
+    Value b = peelUnrealized(adaptor.getB());
+    Value bScale = peelUnrealized(adaptor.getBScale());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TGEMV_MX",
+                                             {dst, a, aScale, b, bScale},
+                                             rewriter);
+    return success();
+  }
+};
+
+/// Lowers pto::TGemvMxAccOp to `TGEMV_MX(dst, cIn, a, aScale, b, bScale)`.
+///
+/// NOTE(review): the callee is the same "TGEMV_MX" name used by the non-acc
+/// and bias variants; presumably the target overloads it on arity -- confirm.
+struct PTOTGemvMXAccToTGEMV_MX
+    : public OpConversionPattern<pto::TGemvMxAccOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TGemvMxAccOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    Value cIn = peelUnrealized(adaptor.getCIn());
+    Value a = peelUnrealized(adaptor.getA());
+    Value aScale = peelUnrealized(adaptor.getAScale());
+    Value b = peelUnrealized(adaptor.getB());
+    Value bScale = peelUnrealized(adaptor.getBScale());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TGEMV_MX",
+                                             {dst, cIn, a, aScale, b, bScale},
+                                             rewriter);
+    return success();
+  }
+};
+
+/// Lowers pto::TGemvMxBiasOp to `TGEMV_MX(dst, a, aScale, b, bScale, bias)`.
+struct PTOTGemvMXBiasToTGEMV_MX
+    : public OpConversionPattern<pto::TGemvMxBiasOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TGemvMxBiasOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    Value a = peelUnrealized(adaptor.getA());
+    Value aScale = peelUnrealized(adaptor.getAScale());
+    Value b = peelUnrealized(adaptor.getB());
+    Value bScale = peelUnrealized(adaptor.getBScale());
+    Value bias = peelUnrealized(adaptor.getBias());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TGEMV_MX",
+                                             {dst, a, aScale, b, bScale, bias},
+                                             rewriter);
+    return success();
+  }
+};
+
+/// Lowers pto::TMatmulBiasOp to `TMATMUL_BIAS(dst, a, b, bias)`.
+struct PTOTMatmulBiasToTMATMUL_BIAS
+    : public OpConversionPattern<pto::TMatmulBiasOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TMatmulBiasOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    Value a = peelUnrealized(adaptor.getA());
+    Value b = peelUnrealized(adaptor.getB());
+    Value bias = peelUnrealized(adaptor.getBias());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    replaceOrEraseWithOpaqueCall(op.getOperation(), "TMATMUL_BIAS",
+                                 {dst, a, b, bias}, rewriter);
+    return success();
+  }
+};
+
+/// Lowers pto::TMatmulMxOp to `TMATMUL_MX(dst, a, aScale, b, bScale)`.
+struct PTOTMatmulMXToTMATMUL_MX
+    : public OpConversionPattern<pto::TMatmulMxOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TMatmulMxOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    Value a = peelUnrealized(adaptor.getA());
+    Value aScale = peelUnrealized(adaptor.getAScale());
+    Value b = peelUnrealized(adaptor.getB());
+    Value bScale = peelUnrealized(adaptor.getBScale());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    replaceOrEraseWithOpaqueCall(op.getOperation(), "TMATMUL_MX",
+                                 {dst, a, aScale, b, bScale}, rewriter);
+    return success();
+  }
+};
+
+/// Lowers pto::TMatmulMxAccOp to `TMATMUL_MX(dst, cIn, a, aScale, b, bScale)`.
+///
+/// NOTE(review): the struct name says TMATMUL_MX_ACC but the emitted callee is
+/// "TMATMUL_MX"; presumably an arity overload -- confirm this is intended.
+struct PTOTMatmulMXAccToTMATMUL_MX_ACC
+    : public OpConversionPattern<pto::TMatmulMxAccOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TMatmulMxAccOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    Value cIn = peelUnrealized(adaptor.getCIn());
+    Value a = peelUnrealized(adaptor.getA());
+    Value aScale = peelUnrealized(adaptor.getAScale());
+    Value b = peelUnrealized(adaptor.getB());
+    Value bScale = peelUnrealized(adaptor.getBScale());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    replaceOrEraseWithOpaqueCall(op.getOperation(), "TMATMUL_MX",
+                                 {dst, cIn, a, aScale, b, bScale}, rewriter);
+    return success();
+  }
+};
+
+/// Lowers pto::TMatmulMxBiasOp to
+/// `TMATMUL_MX(dst, a, aScale, b, bScale, bias)`.
+///
+/// NOTE(review): the struct name says TMATMUL_MX_BIAS but the emitted callee
+/// is "TMATMUL_MX"; presumably an arity overload -- confirm this is intended.
+struct PTOTMatmulMXBiasToTMATMUL_MX_BIAS
+    : public OpConversionPattern<pto::TMatmulMxBiasOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TMatmulMxBiasOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    Value a = peelUnrealized(adaptor.getA());
+    Value aScale = peelUnrealized(adaptor.getAScale());
+    Value b = peelUnrealized(adaptor.getB());
+    Value bScale = peelUnrealized(adaptor.getBScale());
+    Value bias = peelUnrealized(adaptor.getBias());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    replaceOrEraseWithOpaqueCall(op.getOperation(), "TMATMUL_MX",
+                                 {dst, a, aScale, b, bScale, bias}, rewriter);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// TROWEXPAND{DIV,MUL,SUB,MAX,MIN}: call is (dst, src0, src1[, tmp]); the tmp
+// operand is appended only when the op carries one. The op is then erased.
+//===----------------------------------------------------------------------===//
+
+/// Lowers pto::TRowExpandDivOp to `TROWEXPANDDIV(dst, src0, src1[, tmp])`.
+struct PTORowExpandDivToEmitC
+    : public OpConversionPattern<pto::TRowExpandDivOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TRowExpandDivOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src0 = peelUnrealized(adaptor.getSrc0());
+    Value src1 = peelUnrealized(adaptor.getSrc1());
+    Value dst = peelUnrealized(adaptor.getDst());
+    Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value();
+
+    SmallVector<Value, 4> operands{dst, src0, src1};
+    if (tmp)
+      operands.push_back(tmp);
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TROWEXPANDDIV",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Lowers pto::TRowExpandMulOp to `TROWEXPANDMUL(dst, src0, src1[, tmp])`.
+struct PTORowExpandMulToEmitC
+    : public OpConversionPattern<pto::TRowExpandMulOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TRowExpandMulOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src0 = peelUnrealized(adaptor.getSrc0());
+    Value src1 = peelUnrealized(adaptor.getSrc1());
+    Value dst = peelUnrealized(adaptor.getDst());
+    Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value();
+
+    SmallVector<Value, 4> operands{dst, src0, src1};
+    if (tmp)
+      operands.push_back(tmp);
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TROWEXPANDMUL",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Lowers pto::TRowExpandSubOp to `TROWEXPANDSUB(dst, src0, src1[, tmp])`.
+struct PTORowExpandSubToEmitC
+    : public OpConversionPattern<pto::TRowExpandSubOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TRowExpandSubOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src0 = peelUnrealized(adaptor.getSrc0());
+    Value src1 = peelUnrealized(adaptor.getSrc1());
+    Value dst = peelUnrealized(adaptor.getDst());
+    Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value();
+
+    SmallVector<Value, 4> operands{dst, src0, src1};
+    if (tmp)
+      operands.push_back(tmp);
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TROWEXPANDSUB",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Lowers pto::TRowExpandMaxOp to `TROWEXPANDMAX(dst, src0, src1[, tmp])`.
+struct PTORowExpandMaxToEmitC
+    : public OpConversionPattern<pto::TRowExpandMaxOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TRowExpandMaxOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src0 = peelUnrealized(adaptor.getSrc0());
+    Value src1 = peelUnrealized(adaptor.getSrc1());
+    Value dst = peelUnrealized(adaptor.getDst());
+    Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value();
+
+    SmallVector<Value, 4> operands{dst, src0, src1};
+    if (tmp)
+      operands.push_back(tmp);
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TROWEXPANDMAX",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Lowers pto::TRowExpandMinOp to `TROWEXPANDMIN(dst, src0, src1[, tmp])`.
+struct PTORowExpandMinToEmitC
+    : public OpConversionPattern<pto::TRowExpandMinOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TRowExpandMinOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src0 = peelUnrealized(adaptor.getSrc0());
+    Value src1 = peelUnrealized(adaptor.getSrc1());
+    Value dst = peelUnrealized(adaptor.getDst());
+    Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value();
+
+    SmallVector<Value, 4> operands{dst, src0, src1};
+    if (tmp)
+      operands.push_back(tmp);
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TROWEXPANDMIN",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// Row reductions: each call is (dst, src, tmp); the op is then erased.
+//===----------------------------------------------------------------------===//
+
+/// Lowers pto::TRowMaxOp to `TROWMAX(dst, src, tmp)`.
+struct PTORowMaxToEmitC : public OpConversionPattern<pto::TRowMaxOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TRowMaxOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src = peelUnrealized(adaptor.getSrc());
+    Value tmp = peelUnrealized(adaptor.getTmp());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    SmallVector<Value, 4> operands{dst, src, tmp};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TROWMAX",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Lowers pto::TRowArgMaxOp to `TROWARGMAX(dst, src, tmp)`.
+struct PTORowArgMaxToEmitC
+    : public OpConversionPattern<pto::TRowArgMaxOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TRowArgMaxOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src = peelUnrealized(adaptor.getSrc());
+    Value tmp = peelUnrealized(adaptor.getTmp());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TROWARGMAX",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/ValueRange{dst, src, tmp});
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Lowers pto::TRowMinOp to `TROWMIN(dst, src, tmp)`.
+struct PTORowMinToEmitC : public OpConversionPattern<pto::TRowMinOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TRowMinOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src = peelUnrealized(adaptor.getSrc());
+    Value tmp = peelUnrealized(adaptor.getTmp());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    SmallVector<Value, 4> operands{dst, src, tmp};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TROWMIN",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Lowers pto::TRowArgMinOp to `TROWARGMIN(dst, src, tmp)`.
+struct PTORowArgMinToEmitC
+    : public OpConversionPattern<pto::TRowArgMinOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TRowArgMinOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src = peelUnrealized(adaptor.getSrc());
+    Value tmp = peelUnrealized(adaptor.getTmp());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TROWARGMIN",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/ValueRange{dst, src, tmp});
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Lowers pto::TRowSumOp to `TROWSUM(dst, src, tmp)`.
+struct PTORowSumToEmitC : public OpConversionPattern<pto::TRowSumOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TRowSumOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src = peelUnrealized(adaptor.getSrc());
+    Value tmp = peelUnrealized(adaptor.getTmp());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    SmallVector<Value, 4> operands{dst, src, tmp};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TROWSUM",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Lowers pto::TRowProdOp to `TROWPROD(dst, src, tmp)`.
+struct PTORowProdToEmitC : public OpConversionPattern<pto::TRowProdOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TRowProdOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src = peelUnrealized(adaptor.getSrc());
+    Value tmp = peelUnrealized(adaptor.getTmp());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    SmallVector<Value, 4> operands{dst, src, tmp};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TROWPROD",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+//===----------------------------------------------------------------------===//
+// PTOConvert.cpp (add lowering + patterns.add for TRSQRT DPS/memref op)
+// - no-tmp form : TRSQRT(dst, src)
+// - tmp form : TRSQRT(dst, src, tmp)
+//===----------------------------------------------------------------------===//
+
+// NOTE(review): the template argument lists in this section
+// (`OpConversionPattern<...>`, `create<emitc::CallOpaqueOp>`,
+// `SmallVector<Value, 4>`, `static_cast<bool>`) were reconstructed from the
+// surrounding code -- confirm against the upstream file.
+//
+// Every pattern below first passes its adaptor operands through
+// peelUnrealized (defined elsewhere in this file; presumably it strips
+// unrealized conversion casts -- confirm), emits one or more result-less
+// opaque EmitC calls, and then erases the original op.
+
+/// Lowers pto::TRsqrtOp to `TRSQRT(dst, src)` or, when the adaptor provides a
+/// tmp operand, `TRSQRT(dst, src, tmp)`.
+struct PTORsqrtToEmitC : public OpConversionPattern<pto::TRsqrtOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TRsqrtOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src = peelUnrealized(adaptor.getSrc());
+    Value dst = peelUnrealized(adaptor.getDst());
+    SmallVector<Value, 4> operands{dst, src};
+    if (Value tmp = adaptor.getTmp())
+      operands.push_back(peelUnrealized(tmp));
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TRSQRT",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// TSCATTER
+//===----------------------------------------------------------------------===//
+
+/// Lowers pto::TScatterOp to a `TSCATTER` call.
+///
+/// Exactly one of the `indexes` operand and the `maskPattern` attribute must
+/// be present; otherwise the match fails.
+///  - maskPattern form: `TSCATTER<maskPatternTok(mp)>(dst, src)`
+///  - indexes form:     `TSCATTER(dst, src, idx)`
+struct PTOScatterToEmitC : public OpConversionPattern<pto::TScatterOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TScatterOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    const bool hasMaskPattern = static_cast<bool>(op.getMaskPatternAttr());
+    const bool hasIndexes = static_cast<bool>(op.getIndexes());
+    // Both present or both absent is rejected.
+    if (hasMaskPattern == hasIndexes) {
+      return rewriter.notifyMatchFailure(
+          op, "expected exactly one of indexes operand or maskPattern attribute");
+    }
+
+    Value src = peelUnrealized(adaptor.getSrc());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    if (auto mp = op.getMaskPatternAttr()) {
+      auto *ctx = rewriter.getContext();
+      // The mask pattern is passed as a template argument, not an operand.
+      auto targs = rewriter.getArrayAttr({
+          emitc::OpaqueAttr::get(ctx, maskPatternTok(mp)),
+      });
+      rewriter.create<emitc::CallOpaqueOp>(
+          loc, TypeRange{}, "TSCATTER",
+          /*args=*/ArrayAttr{}, /*templateArgs=*/targs,
+          /*operands=*/ValueRange{dst, src});
+    } else {
+      Value idx = peelUnrealized(adaptor.getIndexes());
+      rewriter.create<emitc::CallOpaqueOp>(
+          loc, TypeRange{}, "TSCATTER",
+          /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+          /*operands=*/ValueRange{dst, src, idx});
+    }
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// TSEL / TSELS
+//===----------------------------------------------------------------------===//
+
+/// Lowers pto::TSelOp to `TSEL(dst, mask, src0, src1, tmp)`.
+struct PTOSelToEmitC : public OpConversionPattern<pto::TSelOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TSelOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value mask = peelUnrealized(adaptor.getMask());
+    Value src0 = peelUnrealized(adaptor.getSrc0());
+    Value src1 = peelUnrealized(adaptor.getSrc1());
+    Value tmp = peelUnrealized(adaptor.getTmp());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    SmallVector<Value, 4> operands{dst, mask, src0, src1, tmp};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TSEL",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Lowers pto::TSelSOp to `TSELS(dst, mask, src, tmp, scalar)`.
+///
+/// Note the operand order: here tmp precedes the scalar, whereas TSEL places
+/// tmp last.
+struct PTOSelSToEmitC : public OpConversionPattern<pto::TSelSOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TSelSOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value mask = peelUnrealized(adaptor.getMask());
+    Value src = peelUnrealized(adaptor.getSrc());
+    Value scalar = peelUnrealized(adaptor.getScalar());
+    Value tmp = peelUnrealized(adaptor.getTmp());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    SmallVector<Value, 4> operands{dst, mask, src, tmp, scalar};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TSELS",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// TSHL / TSHR (tile shift amounts) and TSHLS / TSHRS (scalar shift amount)
+//===----------------------------------------------------------------------===//
+
+/// Lowers pto::TShlOp to `TSHL(dst, src0, src1)`.
+struct PTOShlSToEmitC : public OpConversionPattern<pto::TShlOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TShlOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src0 = peelUnrealized(adaptor.getSrc0());
+    Value src1 = peelUnrealized(adaptor.getSrc1());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    SmallVector<Value, 4> operands{dst, src0, src1};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TSHL",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Lowers pto::TShrOp to `TSHR(dst, src0, src1)`.
+struct PTOShrSToEmitC : public OpConversionPattern<pto::TShrOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TShrOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src0 = peelUnrealized(adaptor.getSrc0());
+    Value src1 = peelUnrealized(adaptor.getSrc1());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    SmallVector<Value, 4> operands{dst, src0, src1};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TSHR",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Lowers pto::TShlSOp to `TSHLS(dst, src, scalar)`.
+struct PTOShlSConstToEmitC : public OpConversionPattern<pto::TShlSOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TShlSOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    Value dst = peelUnrealized(adaptor.getDst());
+    Value src = peelUnrealized(adaptor.getSrc());
+    Value scalar = peelUnrealized(adaptor.getScalar());
+    SmallVector<Value, 4> operands{dst, src, scalar};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TSHLS",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Lowers pto::TShrSOp to `TSHRS(dst, src, scalar)`.
+struct PTOShrSConstToEmitC : public OpConversionPattern<pto::TShrSOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TShrSOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    Value dst = peelUnrealized(adaptor.getDst());
+    Value src = peelUnrealized(adaptor.getSrc());
+    Value scalar = peelUnrealized(adaptor.getScalar());
+    SmallVector<Value, 4> operands{dst, src, scalar};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TSHRS",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// TSORT32: ins(src, idx[, tmp]) outs(dst)
+//===----------------------------------------------------------------------===//
+
+/// Lowers pto::TSort32Op to `TSORT32(dst, src, idx[, tmp])`; tmp is appended
+/// only when the op carries one.
+struct PTOSORT32SToEmitC : public OpConversionPattern<pto::TSort32Op> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TSort32Op op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src = peelUnrealized(adaptor.getSrc());
+    Value dst = peelUnrealized(adaptor.getDst());
+    Value idx = peelUnrealized(adaptor.getIdx());
+    Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value();
+
+    SmallVector<Value, 4> operands{dst, src, idx};
+    if (tmp)
+      operands.push_back(tmp);
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TSORT32",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// TSQRT / TSTORE_FP
+//===----------------------------------------------------------------------===//
+
+/// Lowers pto::TSqrtOp to `TSQRT(dst, src)`.
+struct PTOSqrtSToEmitC : public OpConversionPattern<pto::TSqrtOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TSqrtOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src = peelUnrealized(adaptor.getSrc());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    SmallVector<Value, 4> operands{dst, src};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TSQRT",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Lowers pto::TStoreFPOp to `TSTORE_FP(dst, src, fp)`.
+struct PTOStoreFPSToEmitC : public OpConversionPattern<pto::TStoreFPOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TStoreFPOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src = peelUnrealized(adaptor.getSrc());
+    Value fp = peelUnrealized(adaptor.getFp());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    SmallVector<Value, 4> operands{dst, src, fp};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TSTORE_FP",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// TSUB / TSUBC / TSUBS / TSUBSC
+//===----------------------------------------------------------------------===//
+
+/// Lowers pto::TSubOp to `TSUB(dst, src0, src1)`.
+struct PTOSubSToEmitC : public OpConversionPattern<pto::TSubOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TSubOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src0 = peelUnrealized(adaptor.getSrc0());
+    Value src1 = peelUnrealized(adaptor.getSrc1());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    SmallVector<Value, 4> operands{dst, src0, src1};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TSUB",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Lowers pto::TSubCOp by decomposition into two calls:
+/// `TSUB(dst, src0, src1)` followed by `TADD(dst, dst, src2)`.
+struct PTOSubCSToEmitC : public OpConversionPattern<pto::TSubCOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TSubCOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src0 = peelUnrealized(adaptor.getSrc0());
+    Value src1 = peelUnrealized(adaptor.getSrc1());
+    Value src2 = peelUnrealized(adaptor.getSrc2());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    // pto-isa does not provide NPU implementation for TSUBC yet.
+    // Decompose: dst = src0 - src1 + src2
+    // NOTE(review): src2 is read after dst has been written by TSUB; if dst
+    // may alias src2 the result would differ from a fused op -- confirm that
+    // such aliasing is ruled out upstream.
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TSUB",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/ValueRange{dst, src0, src1});
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TADD",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/ValueRange{dst, dst, src2});
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Lowers pto::TSubSOp to `TSUBS(dst, src, scalar)`.
+struct PTOSubSSToEmitC : public OpConversionPattern<pto::TSubSOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TSubSOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src = peelUnrealized(adaptor.getSrc());
+    Value scalar = peelUnrealized(adaptor.getScalar());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    SmallVector<Value, 4> operands{dst, src, scalar};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TSUBS",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Lowers pto::TSubSCOp by decomposition into two calls:
+/// `TSUBS(dst, src0, scalar)` followed by `TADD(dst, dst, src1)`.
+struct PTOSubSCToEmitC : public OpConversionPattern<pto::TSubSCOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TSubSCOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src0 = peelUnrealized(adaptor.getSrc0());
+    Value scalar = peelUnrealized(adaptor.getScalar());
+    Value src1 = peelUnrealized(adaptor.getSrc1());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    // pto-isa does not provide NPU implementation for TSUBSC yet.
+    // Decompose: dst = src0 - scalar + src1
+    // NOTE(review): src1 is read after dst has been written by TSUBS; if dst
+    // may alias src1 the result would differ from a fused op -- confirm that
+    // such aliasing is ruled out upstream.
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TSUBS",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/ValueRange{dst, src0, scalar});
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TADD",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/ValueRange{dst, dst, src1});
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// TXOR / TTRANS
+//===----------------------------------------------------------------------===//
+
+/// Lowers pto::TXorOp to `TXOR(dst, src0, src1, tmp)`.
+struct PTOXORToEmitC : public OpConversionPattern<pto::TXorOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TXorOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src0 = peelUnrealized(adaptor.getSrc0());
+    Value src1 = peelUnrealized(adaptor.getSrc1());
+    Value dst = peelUnrealized(adaptor.getDst());
+    Value tmp = peelUnrealized(adaptor.getTmp());
+    SmallVector<Value, 4> operands{dst, src0, src1, tmp};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TXOR",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Lowers pto::TTransOp to `TTRANS(dst, src, tmp)`.
+struct PTOTTransToEmitC : public OpConversionPattern<pto::TTransOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TTransOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src = peelUnrealized(adaptor.getSrc());
+    Value tmp = peelUnrealized(adaptor.getTmp());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    SmallVector<Value, 4> operands{dst, src, tmp};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TTRANS",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+//===----------------------------------------------------------------------===//
+// TXORS / TPRINT / pto.print / pto.trap
+//===----------------------------------------------------------------------===//
+
+// NOTE(review): the template argument lists in this section
+// (`OpConversionPattern<...>`, `create<emitc::CallOpaqueOp>`,
+// `SmallVector<Value, 4>`) were reconstructed from the surrounding code --
+// confirm against the upstream file.
+//
+// Operands are passed through peelUnrealized (defined elsewhere in this file;
+// presumably it strips unrealized conversion casts -- confirm).
+
+/// Lowers pto::TXorSOp to `TXORS(dst, src, scalar, tmp)` and erases the op.
+struct PTOXORSToEmitC : public OpConversionPattern<pto::TXorSOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TXorSOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src = peelUnrealized(adaptor.getSrc());
+    Value scalar = peelUnrealized(adaptor.getScalar());
+    Value tmp = peelUnrealized(adaptor.getTmp());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    SmallVector<Value, 4> operands{dst, src, scalar, tmp};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TXORS",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Lowers pto::TPrintOp to `TPRINT(src)` and erases the op.
+struct PTOPrintToTPRINT : public OpConversionPattern<pto::TPrintOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TPrintOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src = peelUnrealized(adaptor.getSrc());
+
+    SmallVector<Value, 4> operands{src};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TPRINT",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// pto.print "format", %scalar -> cce::printf("format", scalar)
+///
+/// The format string is emitted as a quoted C string literal. An empty format
+/// falls back to "%f". The call's `args` array is
+/// `[opaque "<quoted format>", index 0]`, where the index attribute 0 refers
+/// to the call's only SSA operand (the scalar).
+struct PTOPrintOpToEmitC : public OpConversionPattern<pto::PrintOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::PrintOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto *ctx = rewriter.getContext();
+
+    std::string fmt = op.getFormat().str();
+    if (fmt.empty())
+      fmt = "%f";
+
+    // Build the C string literal. A quote or backslash must be emitted as
+    // backslash + the character itself; previously only the backslash was
+    // appended and the character was dropped, corrupting the literal.
+    std::string quoted = "\"";
+    for (char c : fmt) {
+      switch (c) {
+      case '"':
+      case '\\':
+        quoted += '\\';
+        quoted += c;
+        break;
+      case '\n':
+        quoted += "\\n";
+        break;
+      case '\t':
+        quoted += "\\t";
+        break;
+      case '\r':
+        // A raw carriage return is not valid inside a C string literal.
+        quoted += "\\r";
+        break;
+      default:
+        quoted += c;
+        break;
+      }
+    }
+    quoted += "\"";
+
+    Value scalar = peelUnrealized(adaptor.getScalar());
+    auto argsAttr = rewriter.getArrayAttr(
+        {emitc::OpaqueAttr::get(ctx, quoted),
+         IntegerAttr::get(IndexType::get(ctx), 0)});
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "cce::printf",
+        /*args=*/argsAttr,
+        /*templateArgs=*/ArrayAttr{},
+        /*operands=*/ValueRange{scalar});
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// pto.trap -> trap()
+///
+/// Emits a result-less, operand-less opaque call to `trap` and erases the op.
+struct PTOTrapOpToEmitC : public OpConversionPattern<pto::TrapOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TrapOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "trap",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/ValueRange{});
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+// =============================================================================
+// 2. 
BindTileOp Lowering (FIX: Trace back to physical address) +// ============================================================================= +struct PTOBindTileToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + struct TileBuildSpec { + std::string tileTypeStr; + bool useConstructor = false; + SmallVector constructorArgs; + }; + + static bool getIndexConst(Value v, int64_t &out) { + if (!v) + return false; + if (auto cst = v.getDefiningOp()) { + if (auto ia = dyn_cast(cst.getValue())) { + out = ia.getValue().getSExtValue(); + return true; + } + } + return false; + } + + static bool getTilePointerStrides(pto::TileBufConfigAttr configAttr, + Type elemTy, int64_t rows, int64_t cols, + int64_t &rowStride, + int64_t &colStride) { + if (rows == ShapedType::kDynamic || cols == ShapedType::kDynamic) + return false; + + int32_t blVal = 0; + if (auto blAttr = dyn_cast(configAttr.getBLayout())) + blVal = static_cast(blAttr.getValue()); + else if (auto intAttr = dyn_cast(configAttr.getBLayout())) + blVal = static_cast(intAttr.getInt()); + + int32_t slVal = 0; + if (auto slAttr = dyn_cast(configAttr.getSLayout())) + slVal = static_cast(slAttr.getValue()); + else if (auto intAttr = dyn_cast(configAttr.getSLayout())) + slVal = static_cast(intAttr.getInt()); + + bool boxed = slVal != 0; + int64_t innerRows = 1; + int64_t innerCols = 1; + if (boxed) { + int32_t fractal = 512; + if (auto frAttr = dyn_cast(configAttr.getSFractalSize())) + fractal = static_cast(frAttr.getInt()); + + unsigned elemBytes = pto::getPTOStorageElemByteSize(elemTy); + if (elemBytes == 0) + return false; + + switch (fractal) { + case 1024: + innerRows = 16; + innerCols = 16; + break; + case 32: + innerRows = 16; + innerCols = 2; + break; + case 512: + if (slVal == 1) { + innerRows = 16; + innerCols = 32 / elemBytes; + } else if (slVal == 2) { + innerRows = 32 / elemBytes; + innerCols = 16; + } else { + return false; + } + break; + default: + return false; + } + if 
(innerRows <= 0 || innerCols <= 0) + return false; + } + + if (!boxed) { + if (blVal == 1) { + rowStride = 1; + colStride = rows; + } else { + rowStride = cols; + colStride = 1; + } + return true; + } + + if (blVal == 1) { + if (slVal != 1) + return false; + rowStride = innerCols; + colStride = rows; + return true; + } + + rowStride = cols; + colStride = innerRows; + return true; + } + + LogicalResult matchAndRewrite(pto::BindTileOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + auto configAttr = op.getConfigAttr(); + auto viewSemantics = op->getAttrOfType("pto.view_semantics"); + bool isSubView = viewSemantics && viewSemantics.getValue() == "subview"; + + auto peelAllCasts = [](Value v) { + while (auto castOp = v.getDefiningOp()) + v = castOp.getOperand(0); + if (auto castOp = v.getDefiningOp()) + v = castOp.getOperand(); + return v; + }; + auto isTileLike = [](Value v) -> bool { + auto ot = dyn_cast(v.getType()); + if (!ot) + return false; + StringRef s = ot.getValue(); + return s.contains("Tile<") || s.contains("ConvTile<"); + }; + auto buildTileSpec = [&]() -> FailureOr { + auto resMrTy = dyn_cast(op.getType()); + if (!resMrTy) + return failure(); + + const char *roleTok = "TileType::Vec"; + if (auto asAttr = + dyn_cast_or_null(resMrTy.getMemorySpace())) { + switch (asAttr.getAddressSpace()) { + case pto::AddressSpace::VEC: + roleTok = "TileType::Vec"; + break; + case pto::AddressSpace::MAT: + roleTok = "TileType::Mat"; + break; + case pto::AddressSpace::LEFT: + roleTok = "TileType::Left"; + break; + case pto::AddressSpace::RIGHT: + roleTok = "TileType::Right"; + break; + case pto::AddressSpace::ACC: + roleTok = "TileType::Acc"; + break; + case pto::AddressSpace::BIAS: + roleTok = "TileType::Bias"; + break; + case pto::AddressSpace::SCALING: + roleTok = "TileType::Scaling"; + break; + case pto::AddressSpace::GM: + case pto::AddressSpace::Zero: + roleTok = 
"TileType::Vec"; + break; + } + } + + Type elemTy = resMrTy.getElementType(); + Type emitElemTy = getTypeConverter()->convertType(elemTy); + if (!emitElemTy) + return failure(); + auto emitElemOpaque = dyn_cast(emitElemTy); + if (!emitElemOpaque) + return failure(); + std::string elemTypeStr = emitElemOpaque.getValue().str(); + + if (resMrTy.getRank() < 2) + return failure(); + int64_t rows = resMrTy.getDimSize(0); + int64_t cols = resMrTy.getDimSize(1); + if (rows == ShapedType::kDynamic || cols == ShapedType::kDynamic) + return failure(); + + std::string blTok = "BLayout::RowMajor"; + if (auto blAttr = dyn_cast(configAttr.getBLayout())) { + if (static_cast(blAttr.getValue()) == 1) + blTok = "BLayout::ColMajor"; + } + pto::BLayout blayout = getTileBufBLayoutValue(configAttr); + + if (isSubView) { + auto subMrTy = dyn_cast(op.getSource().getType()); + auto subViewOp = op.getSource().getDefiningOp(); + if (subMrTy && subMrTy.getRank() >= 2 && subViewOp) { + int64_t subRows = subMrTy.getDimSize(0); + int64_t subCols = subMrTy.getDimSize(1); + SmallVector inheritedStrides; + int64_t inheritedOffset = ShapedType::kDynamic; + + if (!pto::isPTOFloat4PackedType(elemTy) && + subRows != ShapedType::kDynamic && + subCols != ShapedType::kDynamic && + succeeded(getStridesAndOffset(subMrTy, inheritedStrides, + inheritedOffset)) && + inheritedStrides.size() >= 2) { + int64_t childRowStride = 0; + int64_t childColStride = 0; + bool sameStrides = getTilePointerStrides( + configAttr, elemTy, subRows, subCols, childRowStride, + childColStride); + sameStrides = sameStrides && + inheritedStrides[0] == childRowStride && + inheritedStrides[1] == childColStride; + if (sameStrides) { + rows = subRows; + cols = subCols; + } + } + } + } + + std::string slTok = "SLayout::NoneBox"; + if (auto slAttr = dyn_cast(configAttr.getSLayout())) { + int32_t slVal = static_cast(slAttr.getValue()); + slTok = (slVal == 1) ? "SLayout::RowMajor" + : (slVal == 2) ? 
"SLayout::ColMajor" + : "SLayout::NoneBox"; + } + + int32_t fractal = 512; + if (auto frAttr = dyn_cast(configAttr.getSFractalSize())) + fractal = frAttr.getInt(); + + std::string padTok = "PadValue::Null"; + if (auto padAttr = dyn_cast(configAttr.getPad())) { + switch (static_cast(padAttr.getValue())) { + case 1: + padTok = "PadValue::Zero"; + break; + case 2: + padTok = "PadValue::Max"; + break; + case 3: + padTok = "PadValue::Min"; + break; + default: + padTok = "PadValue::Null"; + break; + } + } + + std::string compactTok = "CompactMode::Null"; + if (auto compactAttr = dyn_cast(configAttr.getCompactMode())) { + switch (static_cast(compactAttr.getValue())) { + case 1: + compactTok = "CompactMode::Normal"; + break; + case 2: + compactTok = "CompactMode::RowPlusOne"; + break; + default: + compactTok = "CompactMode::Null"; + break; + } + } + + std::string vrowTok, vcolTok; + bool useConstructor = false; + bool rowIsDynamic = false; + bool colIsDynamic = false; + SmallVector constructorArgs; + + Value vRow = op.getValidRow(); + Value vCol = op.getValidCol(); + Value vRowEmitC = adaptor.getValidRow(); + Value vColEmitC = adaptor.getValidCol(); + bool forceDynamicValid = op->hasAttr(kForceDynamicValidShapeAttrName); + int64_t cRow = 0, cCol = 0; + bool rowIsConst = vRow && getIndexConst(vRow, cRow); + bool colIsConst = vCol && getIndexConst(vCol, cCol); + + auto makeCtorDimValue = [&](Value emitted, int64_t fallback) -> Value { + if (emitted) + return emitted; + return makeEmitCIntConstant( + rewriter, loc, emitc::OpaqueType::get(ctx, "int32_t"), fallback); + }; + auto maybeScaleDynamicValid = [&](Value emitted, int dimIdx) -> Value { + if (!emitted || !pto::isPTOFloat4PackedType(elemTy)) + return emitted; + int packedDim = blayout == pto::BLayout::ColMajor ? 
0 : 1; + if (dimIdx != packedDim) + return emitted; + auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); + Value two = makeEmitCIntConstant(rewriter, loc, i32Ty, 2); + return rewriter.create(loc, i32Ty, emitted, two).getResult(); + }; + + if (forceDynamicValid) { + vrowTok = "-1"; + vcolTok = "-1"; + useConstructor = true; + constructorArgs.push_back( + makeCtorDimValue(maybeScaleDynamicValid(vRowEmitC, 0), + renderTileTemplateDim(rowIsConst ? cRow : rows, + elemTy, blayout, 0))); + constructorArgs.push_back( + makeCtorDimValue(maybeScaleDynamicValid(vColEmitC, 1), + renderTileTemplateDim(colIsConst ? cCol : cols, + elemTy, blayout, 1))); + } else { + if (rowIsConst) { + vrowTok = std::to_string( + renderTileTemplateDim(cRow, elemTy, blayout, 0)); + } else if (vRow) { + vrowTok = "-1"; + rowIsDynamic = true; + useConstructor = true; + } else { + vrowTok = std::to_string( + renderTileTemplateDim(rows, elemTy, blayout, 0)); + } + + if (colIsConst) { + vcolTok = std::to_string( + renderTileTemplateDim(cCol, elemTy, blayout, 1)); + } else if (vCol) { + vcolTok = "-1"; + colIsDynamic = true; + useConstructor = true; + } else { + vcolTok = std::to_string( + renderTileTemplateDim(cols, elemTy, blayout, 1)); + } + + if (useConstructor) { + if (rowIsDynamic && vRowEmitC) + constructorArgs.push_back(maybeScaleDynamicValid(vRowEmitC, 0)); + if (colIsDynamic && vColEmitC) + constructorArgs.push_back(maybeScaleDynamicValid(vColEmitC, 1)); + } + } + + std::string tileTypeStr = std::string("Tile<") + roleTok + ", " + + elemTypeStr + ", " + + std::to_string(renderTileTemplateDim( + rows, elemTy, blayout, 0)) + + ", " + + std::to_string(renderTileTemplateDim( + cols, elemTy, blayout, 1)) + + ", " + blTok + + ", " + vrowTok + ", " + vcolTok + ", " + slTok + + ", " + std::to_string(fractal) + ", " + padTok + + ", " + compactTok + + ">"; + return TileBuildSpec{tileTypeStr, useConstructor, constructorArgs}; + }; + + auto buildTileValue = [&](const TileBuildSpec &spec, + bool 
forceDeclaration = false) -> Value { + auto tileType = emitc::OpaqueType::get(ctx, spec.tileTypeStr); + if (spec.useConstructor && !forceDeclaration) { + return rewriter + .create(loc, tileType, spec.tileTypeStr, + ArrayAttr{}, ArrayAttr{}, + ValueRange(spec.constructorArgs)) + .getResult(0); + } + + return rewriter + .create(loc, tileType, emitc::OpaqueAttr::get(ctx, "")) + .getResult(); + }; + + auto emitElemTypeToString = [&](Type elemTy) -> std::string { + return getEmitCScalarTypeToken(elemTy); + }; + + auto buildIntegralAddress = [&](Value sourceValue) -> FailureOr { + auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); + auto rcU64 = + rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); + + Value rawPtr = sourceValue; + if (auto ot = dyn_cast(sourceValue.getType())) { + StringRef tyStr = ot.getValue(); + if (tyStr.contains("Tile<") || tyStr.contains("ConvTile<")) { + auto srcMrTy = dyn_cast(op.getSource().getType()); + if (!srcMrTy) + return failure(); + std::string elemTok = emitElemTypeToString(srcMrTy.getElementType()); + pto::AddressSpace as = pto::AddressSpace::GM; + if (auto asAttr = + dyn_cast_or_null(srcMrTy.getMemorySpace())) + as = asAttr.getAddressSpace(); + rawPtr = materializeTileDataValue(rewriter, loc, sourceValue, as, + elemTok); + } + } + + if (isSetFFTsPointerLikeType(rawPtr.getType())) { + return rewriter + .create(loc, u64Ty, "reinterpret_cast", + ArrayAttr{}, rcU64, ValueRange{rawPtr}) + .getResult(0); + } + + if (rawPtr.getType() == u64Ty) + return rawPtr; + return rewriter.create(loc, u64Ty, rawPtr).getResult(); + }; + + if (op.getSource().getDefiningOp()) { + FailureOr tileSpec = buildTileSpec(); + if (failed(tileSpec)) + return failure(); + rewriter.replaceOp(op, buildTileValue(*tileSpec)); + return success(); + } + + Value tileCandidate = peelAllCasts(adaptor.getSource()); + if (viewSemantics && viewSemantics.getValue() == "bitcast" && + isTileLike(tileCandidate)) { + FailureOr tileSpec = buildTileSpec(); + if 
(failed(tileSpec)) + return failure(); + Value dstTile = buildTileValue(*tileSpec); + FailureOr addr = buildIntegralAddress(tileCandidate); + if (failed(addr)) + return failure(); + + rewriter.create(loc, TypeRange{}, "TASSIGN", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dstTile, *addr}); + rewriter.replaceOp(op, dstTile); + return success(); + } + + if (viewSemantics && viewSemantics.getValue() == "treshape" && + isTileLike(tileCandidate)) { + FailureOr tileSpec = buildTileSpec(); + if (failed(tileSpec)) + return failure(); + Value dstTile = buildTileValue(*tileSpec, /*forceDeclaration=*/true); + + rewriter.create(loc, TypeRange{}, "TRESHAPE", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dstTile, tileCandidate}); + rewriter.replaceOp(op, dstTile); + return success(); + } + + // Subview origins are kept distinct from generic tile rebinding: + // even when source/destination C++ tile types match, subview may carry + // shifted base address semantics and should materialize a fresh handle. + if (isSubView) { + FailureOr tileSpec = buildTileSpec(); + if (failed(tileSpec)) + return failure(); + Value dstTile = buildTileValue(*tileSpec); + FailureOr addr = buildIntegralAddress(tileCandidate); + if (failed(addr)) + return failure(); + + rewriter.create(loc, TypeRange{}, "TASSIGN", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dstTile, *addr}); + rewriter.replaceOp(op, dstTile); + return success(); + } + + // Generic tile-to-tile rebind path: preserve the same backing storage and + // rebuild a sibling tile with updated metadata/valid dims. 
+ if (isTileLike(tileCandidate)) { + FailureOr tileSpec = buildTileSpec(); + if (failed(tileSpec)) + return failure(); + + if (!tileSpec->useConstructor) { + if (auto srcTy = dyn_cast(tileCandidate.getType())) { + if (srcTy.getValue() == tileSpec->tileTypeStr) { + rewriter.replaceOp(op, tileCandidate); + return success(); + } + } + } + + Value dstTile = buildTileValue(*tileSpec); + FailureOr addr = buildIntegralAddress(tileCandidate); + if (failed(addr)) + return failure(); + + rewriter.create(loc, TypeRange{}, "TASSIGN", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dstTile, *addr}); + rewriter.replaceOp(op, dstTile); + return success(); + } + + SmallVector physAddrs; + Value source = op.getSource(); + + while (auto castOp = source.getDefiningOp()) + source = castOp.getOperand(0); + + if (auto upstreamCast = source.getDefiningOp()) { + auto upstreamOperands = upstreamCast.getAddrs(); + physAddrs.append(upstreamOperands.begin(), upstreamOperands.end()); + } else { + physAddrs.push_back(adaptor.getSource()); + } + + Value vRow = op.getValidRow(); + Value vCol = op.getValidCol(); + + auto newCast = rewriter.create( + loc, op.getType(), physAddrs, vRow ? vRow : Value(), + vCol ? 
vCol : Value(), configAttr); + if (viewSemantics) + newCast->setAttr("pto.view_semantics", viewSemantics); + if (op->hasAttr(kForceDynamicValidShapeAttrName)) + newCast->setAttr(kForceDynamicValidShapeAttrName, + op->getAttr(kForceDynamicValidShapeAttrName)); + rewriter.replaceOp(op, newCast.getResult()); + + return success(); + } +}; + +struct PTOAllocTileToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::AllocTileOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + MLIRContext *ctx = rewriter.getContext(); + auto tileTy = cast(op.getResult().getType()); + auto tileTypeString = getEmitCTileTypeString(tileTy); + if (!tileTypeString) + return rewriter.notifyMatchFailure( + op, "only rank-2 alloc_tile handles can be converted to EmitC"); + + Type convertedTy = getTypeConverter()->convertType(tileTy); + if (!convertedTy) + convertedTy = emitc::OpaqueType::get(ctx, *tileTypeString); + + auto validShape = tileTy.getValidShape(); + bool hasDynamicValidDim = + llvm::any_of(validShape, [](int64_t dim) { return dim < 0; }); + bool useConstructor = hasDynamicValidDim; + + SmallVector constructorArgs; + if (useConstructor) { + Type elemTy = tileTy.getElementType(); + pto::BLayout blayout = getTileBufBLayoutValue(tileTy.getConfigAttr()); + auto maybeScaleDynamicValid = [&](Value emitted, int dimIdx) -> Value { + if (!emitted || !pto::isPTOFloat4PackedType(elemTy)) + return emitted; + int packedDim = blayout == pto::BLayout::ColMajor ? 
0 : 1; + if (dimIdx != packedDim) + return emitted; + auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); + Value two = makeEmitCIntConstant(rewriter, loc, i32Ty, 2); + return rewriter.create(loc, i32Ty, emitted, two) + .getResult(); + }; + + if (validShape.size() > 0 && validShape[0] < 0) { + Value validRow = adaptor.getValidRow(); + if (!validRow) + return rewriter.notifyMatchFailure( + op, "dynamic alloc_tile valid row must have an operand"); + if (validRow) + validRow = peelUnrealized(validRow); + constructorArgs.push_back(maybeScaleDynamicValid(validRow, 0)); + } + if (validShape.size() > 1 && validShape[1] < 0) { + Value validCol = adaptor.getValidCol(); + if (!validCol) + return rewriter.notifyMatchFailure( + op, "dynamic alloc_tile valid col must have an operand"); + if (validCol) + validCol = peelUnrealized(validCol); + constructorArgs.push_back(maybeScaleDynamicValid(validCol, 1)); + } + } + + Value tile; + if (useConstructor) { + tile = rewriter + .create( + loc, convertedTy, *tileTypeString, ArrayAttr{}, + ArrayAttr{}, ValueRange(constructorArgs)) + .getResult(0); + } else { + tile = + rewriter + .create( + loc, convertedTy, emitc::OpaqueAttr::get(ctx, "")) + .getResult(); + } + + Value addr = adaptor.getAddr(); + if (addr) { + addr = peelUnrealized(addr); + auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); + if (isa(addr.getType()) || + (isa(addr.getType()) && + cast(addr.getType()).getValue().ends_with("*"))) { + auto rcU64 = + rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); + addr = rewriter + .create(loc, u64Ty, "reinterpret_cast", + ArrayAttr{}, rcU64, + ValueRange{addr}) + .getResult(0); + } else if (addr.getType() != u64Ty) { + addr = rewriter.create(loc, u64Ty, addr).getResult(); + } + + rewriter.create(loc, TypeRange{}, "TASSIGN", + ArrayAttr{}, ArrayAttr{}, + ValueRange{tile, addr}); + } + + rewriter.replaceOp(op, tile); + return success(); + } +}; + +static FailureOr +createEmitCTileVariable(ConversionPatternRewriter 
&rewriter, Location loc, + const TypeConverter *typeConverter, + pto::TileBufType tileTy) { + auto tileTypeString = getEmitCTileTypeString(tileTy); + if (!tileTypeString) + return failure(); + + Type convertedTy = typeConverter->convertType(tileTy); + if (!convertedTy) + convertedTy = emitc::OpaqueType::get(rewriter.getContext(), *tileTypeString); + + return rewriter + .create( + loc, convertedTy, emitc::OpaqueAttr::get(rewriter.getContext(), "")) + .getResult(); +} + +struct PTOTReshapeToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TReshapeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto tileTy = dyn_cast(op.getResult().getType()); + if (!tileTy) + return failure(); + + FailureOr dst = + createEmitCTileVariable(rewriter, op.getLoc(), getTypeConverter(), tileTy); + if (failed(dst)) + return failure(); + + Value src = peelUnrealized(adaptor.getSrc()); + if (auto castOp = src.getDefiningOp()) + src = castOp.getOperand(); + + rewriter.create(op.getLoc(), TypeRange{}, "TRESHAPE", + ArrayAttr{}, ArrayAttr{}, + ValueRange{*dst, src}); + rewriter.replaceOp(op, *dst); + return success(); + } +}; + +struct PTOBitcastToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::BitcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto dstTy = dyn_cast(op.getResult().getType()); + auto srcTy = dyn_cast(op.getSrc().getType()); + if (!dstTy || !srcTy) + return failure(); + + FailureOr dst = + createEmitCTileVariable(rewriter, op.getLoc(), getTypeConverter(), dstTy); + if (failed(dst)) + return failure(); + + Value src = peelUnrealized(adaptor.getSrc()); + if (auto castOp = src.getDefiningOp()) + src = castOp.getOperand(); + + pto::AddressSpace as = pto::AddressSpace::GM; + if (auto asAttr = + dyn_cast_or_null(srcTy.getMemorySpace())) + as = 
asAttr.getAddressSpace(); + std::string elemTok = getEmitCScalarTypeToken(srcTy.getElementType()); + + Value rawPtr = materializeTileDataValue(rewriter, op.getLoc(), src, as, elemTok); + auto u64Ty = emitc::OpaqueType::get(rewriter.getContext(), "uint64_t"); + Value addr = rawPtr; + if (isSetFFTsPointerLikeType(rawPtr.getType())) { + auto rcU64 = + rewriter.getArrayAttr({emitc::OpaqueAttr::get(rewriter.getContext(), + "uint64_t")}); + addr = rewriter + .create(op.getLoc(), u64Ty, + "reinterpret_cast", ArrayAttr{}, + rcU64, ValueRange{rawPtr}) + .getResult(0); + } else if (addr.getType() != u64Ty) { + addr = rewriter.create(op.getLoc(), u64Ty, addr).getResult(); + } + + rewriter.create(op.getLoc(), TypeRange{}, "TASSIGN", + ArrayAttr{}, ArrayAttr{}, + ValueRange{*dst, addr}); + rewriter.replaceOp(op, *dst); + return success(); + } +}; + +struct PTOMaterializeTileToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + static bool isTileLike(Value v) { + auto ot = dyn_cast(v.getType()); + if (!ot) + return false; + StringRef s = ot.getValue(); + return s.contains("Tile<") || s.contains("ConvTile<"); + } + + LogicalResult matchAndRewrite(pto::MaterializeTileOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + MLIRContext *ctx = rewriter.getContext(); + auto tileTy = cast(op.getResult().getType()); + auto tileTypeString = getEmitCTileTypeString(tileTy); + if (!tileTypeString) + return rewriter.notifyMatchFailure( + op, "only rank-2 tile_buf handles can be materialized to EmitC"); + + Type convertedTy = getTypeConverter()->convertType(tileTy); + if (!convertedTy) + convertedTy = emitc::OpaqueType::get(ctx, *tileTypeString); + + Value source = peelUnrealized(adaptor.getSource()); + if (auto castOp = source.getDefiningOp()) + source = castOp.getOperand(); + + auto viewSemantics = op->getAttrOfType("pto.view_semantics"); + bool forceDynamicValid = 
op->hasAttr(kForceDynamicValidShapeAttrName); + bool isReshape = viewSemantics && viewSemantics.getValue() == "treshape"; + bool isSubview = viewSemantics && viewSemantics.getValue() == "subview"; + bool sourceIsDeclaredTile = + op.getSource().getDefiningOp(); + + auto createTileValue = [&]() -> Value { + SmallVector constructorArgs; + bool useConstructor = false; + pto::BLayout blayout = getTileBufBLayoutValue(tileTy.getConfigAttr()); + Type elemTy = tileTy.getElementType(); + auto shape = tileTy.getShape(); + auto validShape = tileTy.getValidShape(); + + auto makeCtorDimValue = [&](Value emitted, int64_t fallback) -> Value { + if (emitted) + return emitted; + return makeEmitCIntConstant( + rewriter, loc, emitc::OpaqueType::get(ctx, "int32_t"), fallback); + }; + auto maybeScaleDynamicValid = [&](Value emitted, int dimIdx) -> Value { + if (!emitted || !pto::isPTOFloat4PackedType(elemTy)) + return emitted; + int packedDim = blayout == pto::BLayout::ColMajor ? 0 : 1; + if (dimIdx != packedDim) + return emitted; + auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); + Value two = makeEmitCIntConstant(rewriter, loc, i32Ty, 2); + return rewriter.create(loc, i32Ty, emitted, two).getResult(); + }; + auto fallbackDim = [&](int dimIdx) { + return renderTileTemplateDim(shape[dimIdx], elemTy, blayout, dimIdx); + }; + + if (forceDynamicValid) { + useConstructor = true; + constructorArgs.push_back(makeCtorDimValue( + maybeScaleDynamicValid(adaptor.getValidRow(), 0), fallbackDim(0))); + constructorArgs.push_back(makeCtorDimValue( + maybeScaleDynamicValid(adaptor.getValidCol(), 1), fallbackDim(1))); + } else { + if (validShape[0] == ShapedType::kDynamic) { + useConstructor = true; + constructorArgs.push_back(makeCtorDimValue( + maybeScaleDynamicValid(adaptor.getValidRow(), 0), fallbackDim(0))); + } + if (validShape[1] == ShapedType::kDynamic) { + useConstructor = true; + constructorArgs.push_back(makeCtorDimValue( + maybeScaleDynamicValid(adaptor.getValidCol(), 1), 
fallbackDim(1))); + } + } + + if (useConstructor) { + return rewriter + .create(loc, convertedTy, *tileTypeString, + ArrayAttr{}, ArrayAttr{}, + ValueRange(constructorArgs)) + .getResult(0); + } + + return rewriter + .create(loc, convertedTy, + emitc::OpaqueAttr::get(ctx, "")) + .getResult(); + }; + + if (!isSubview && !forceDynamicValid && isTileLike(source)) { + if (auto srcTy = dyn_cast(source.getType())) { + if (srcTy.getValue() == *tileTypeString) { + rewriter.replaceOp(op, source); + return success(); + } + } + } + + Value tile = createTileValue(); + if (sourceIsDeclaredTile) { + rewriter.replaceOp(op, tile); + return success(); + } + + if (isReshape && isTileLike(source)) { + rewriter.create(loc, TypeRange{}, "TRESHAPE", + ArrayAttr{}, ArrayAttr{}, + ValueRange{tile, source}); + rewriter.replaceOp(op, tile); + return success(); + } + + pto::AddressSpace as = pto::AddressSpace::GM; + if (auto asAttr = + dyn_cast_or_null(tileTy.getMemorySpace())) + as = asAttr.getAddressSpace(); + std::string elemTok = getEmitCScalarTypeToken(tileTy.getElementType()); + + Value rawPtr = source; + if (isTileLike(rawPtr)) + rawPtr = materializeTileDataValue(rewriter, loc, rawPtr, as, elemTok); + + auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); + Value addr = rawPtr; + if (isSetFFTsPointerLikeType(rawPtr.getType())) { + auto rcU64 = + rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); + addr = rewriter + .create(loc, u64Ty, "reinterpret_cast", + ArrayAttr{}, rcU64, + ValueRange{rawPtr}) + .getResult(0); + } else if (rawPtr.getType() != u64Ty) { + addr = rewriter.create(loc, u64Ty, rawPtr).getResult(); + } + + rewriter.create(loc, TypeRange{}, "TASSIGN", + ArrayAttr{}, ArrayAttr{}, + ValueRange{tile, addr}); + rewriter.replaceOp(op, tile); + return success(); + } +}; + +// ============================================================================= +// Arith CmpI -> EmitC Cmp +// ============================================================================= 
+class ArithCmpIToEmitC : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + // Convert arith.cmpi into emitc.cmp. + // Map the predicate: eq -> eq, slt -> lt, etc. + emitc::CmpPredicate emitcPred = emitc::CmpPredicate::eq; + const bool isUnsignedPred = + op.getPredicate() == arith::CmpIPredicate::ult || + op.getPredicate() == arith::CmpIPredicate::ule || + op.getPredicate() == arith::CmpIPredicate::ugt || + op.getPredicate() == arith::CmpIPredicate::uge; + switch (op.getPredicate()) { + case arith::CmpIPredicate::eq: emitcPred = emitc::CmpPredicate::eq; break; + case arith::CmpIPredicate::ne: emitcPred = emitc::CmpPredicate::ne; break; + case arith::CmpIPredicate::slt: emitcPred = emitc::CmpPredicate::lt; break; + case arith::CmpIPredicate::sle: emitcPred = emitc::CmpPredicate::le; break; + case arith::CmpIPredicate::sgt: emitcPred = emitc::CmpPredicate::gt; break; + case arith::CmpIPredicate::sge: emitcPred = emitc::CmpPredicate::ge; break; + // Unsigned comparisons (ult, ule, etc.) use the same predicates; operands are cast to unsigned below. + case arith::CmpIPredicate::ult: emitcPred = emitc::CmpPredicate::lt; break; + case arith::CmpIPredicate::ule: emitcPred = emitc::CmpPredicate::le; break; + case arith::CmpIPredicate::ugt: emitcPred = emitc::CmpPredicate::gt; break; + case arith::CmpIPredicate::uge: emitcPred = emitc::CmpPredicate::ge; break; + } + + Type resTy = getTypeConverter()->convertType(op.getType()); + if (!resTy) + return failure(); + + Value lhs = adaptor.getLhs(); + Value rhs = adaptor.getRhs(); + if (isUnsignedPred) { + Type opTy = op.getLhs().getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure( + op, "expected scalar integer or index operands"); + + const unsigned bitWidth = + intTy ? 
intTy.getWidth() : static_cast(kPTOIndexBitWidth); + if (bitWidth != 1) { + lhs = castSignlessIntToUnsignedSameWidth(rewriter, loc, lhs, bitWidth); + rhs = castSignlessIntToUnsignedSameWidth(rewriter, loc, rhs, bitWidth); + } + } + + rewriter.replaceOpWithNewOp( + op, + /*resultType=*/resTy, // i1 -> bool/i1 + emitcPred, + lhs, + rhs + ); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Section Op Lowering +//===----------------------------------------------------------------------===// +static bool isA5NoSplitPipeOp(Operation *op) { + if (auto talloc = dyn_cast(op)) + return talloc.getSplit() == 0; + if (auto tpush = dyn_cast(op)) + return tpush.getSplit() == 0; + if (auto tpop = dyn_cast(op)) + return tpop.getSplit() == 0; + if (auto tfree = dyn_cast(op)) + return tfree.getSplit() == 0; + if (auto tpush = dyn_cast(op)) + return tpush.getSplit() == 0; + if (auto tpush = dyn_cast(op)) + return tpush.getSplit() == 0; + if (auto talloc = dyn_cast(op)) + return talloc.getSplit() == 0; + if (auto talloc = dyn_cast(op)) + return talloc.getSplit() == 0; + if (auto tpop = dyn_cast(op)) + return tpop.getSplit() == 0; + if (auto tpop = dyn_cast(op)) + return tpop.getSplit() == 0; + if (auto tfree = dyn_cast(op)) + return tfree.getSplit() == 0; + if (auto tfree = dyn_cast(op)) + return tfree.getSplit() == 0; + return false; +} + +static bool hasExplicitSubblockControl(Operation *op) { + bool hasControl = false; + op->walk([&](Operation *nested) { + if (isa(nested)) { + hasControl = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return hasControl; +} + +static bool needsA5NoSplitVectorGuard(Operation *op) { + auto arch = getTargetArch(op); + if (arch != PTOArch::A5) + return false; + bool isVectorScope = isa(op); + if (auto func = dyn_cast(op)) { + if (auto kernelKindAttr = + func->getAttrOfType( + FunctionKernelKindAttr::name)) { + isVectorScope = + 
kernelKindAttr.getKernelKind() == FunctionKernelKind::Vector; + } + } + if (!isVectorScope) + return false; + if (hasExplicitSubblockControl(op)) + return false; + + bool hasNoSplitPipe = false; + op->walk([&](Operation *nested) { + if (!isA5NoSplitPipeOp(nested)) + return WalkResult::advance(); + hasNoSplitPipe = true; + return WalkResult::interrupt(); + }); + return hasNoSplitPipe; +} + +template +struct SectionToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + std::string getMacroName() const { + if (std::is_same::value) + return "__DAV_CUBE__"; + if (std::is_same::value) + return "__DAV_VEC__"; + return "UNKNOWN_MACRO"; + } + + LogicalResult + matchAndRewrite(SectionOpTy op, typename SectionOpTy::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + bool needsNoSplitGuard = needsA5NoSplitVectorGuard(op.getOperation()); + + std::string startMacro = "\n#if defined(" + getMacroName() + ")"; + rewriter.create(loc, startMacro); + + if constexpr (std::is_same_v) { + // Vector mask is a global HW state and may be modified by previous kernels + // (or earlier sections). Reset it to a well-defined state for deterministic + // execution of VEC ops. 
+ rewriter.create(loc, "set_mask_norm();"); + rewriter.create(loc, "set_vector_mask(-1, -1);"); + } + + if (needsNoSplitGuard) { + rewriter.create( + loc, "if (get_subblockid() == 0) {"); + } + + Block &innerBlock = op.getBody().front(); + if (!innerBlock.empty()) { + rewriter.inlineBlockBefore(&innerBlock, op.getOperation(), ValueRange{}); + } + + if (needsNoSplitGuard) + rewriter.create(loc, "}"); + + std::string endMacro = "#endif // " + getMacroName() + "\n"; + rewriter.create(loc, endMacro); + + rewriter.eraseOp(op); + + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// SCF Control-Flow Pre-Lowering +// +// EmitC translation supports `emitc.for`/`emitc.if` plus CFG-style +// `cf.br`/`cf.cond_br`. Upstream SCFToEmitC patterns only cover `scf.for` and +// `scf.if`, so we pre-lower some SCF ops into those supported forms. +//===----------------------------------------------------------------------===// + +namespace { + +static bool isTriviallyInlineableExecuteRegion(scf::ExecuteRegionOp op) { + Region &r = op.getRegion(); + if (!r.hasOneBlock()) + return false; + Block &b = r.front(); + return isa_and_nonnull(b.getTerminator()); +} + +static bool needsWholeFunctionSCFToCF(func::FuncOp func) { + bool needs = false; + func.walk([&](Operation *op) { + if (!isa(op)) + return WalkResult::advance(); + Operation *parentOp = op->getParentOp(); + + // `scf.execute_region` can legally appear in single-block parents. Only + // require whole-function SCFToCF if we need to lower it into CFG blocks + // (multi-block region / non-trivial terminators). 
+ if (auto exec = dyn_cast(op)) { + if (parentOp && parentOp->hasTrait() && + !isTriviallyInlineableExecuteRegion(exec)) { + needs = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + } + + if (parentOp && parentOp->hasTrait()) { + needs = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return needs; +} + +// scf.execute_region is semantically just an inlined region producing results +// via scf.yield. Inline it to the parent block to avoid extra lowering needs. +struct SCFExecuteRegionInline + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::ExecuteRegionOp op, + PatternRewriter &rewriter) const override { + if (op.getRegion().empty()) + return rewriter.notifyMatchFailure(op, "expected non-empty region"); + + Block &innerBlock = op.getRegion().front(); + auto yield = dyn_cast(innerBlock.getTerminator()); + if (!yield) + return rewriter.notifyMatchFailure(op, "expected scf.yield terminator"); + + // Move the body operations before the execute_region op. + rewriter.inlineBlockBefore(&innerBlock, op.getOperation(), ValueRange{}); + + // Replace execute_region results with yielded values, then erase the yield. + rewriter.replaceOp(op, yield.getOperands()); + rewriter.eraseOp(yield); + return success(); + } +}; + +// Lower scf.execute_region into CFG blocks with cf.br/cf.cond_br by inlining the +// region blocks into the parent region and rewriting scf.yield to branch into a +// continuation block carrying results. +// +// Note: This requires the parent region to allow multiple blocks (e.g. the +// function body CFG region). For execute_region nested in single-block regions +// (scf.for/scf.if), run SCFToCF first to eliminate the single-block constraint. 
+// Match fails (without touching the IR) when the region is trivially
+// inlineable, when the parent op carries the trait queried via `hasTrait`
+// (reported as a single-block parent), or when the region is empty.
+// NOTE(review): template arguments (`OpRewritePattern`, `hasTrait`,
+// `SmallVector`, `dyn_cast`, `rewriter.create`) are missing from this extract.
+struct SCFExecuteRegionToCF : public OpRewritePattern {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(scf::ExecuteRegionOp op,
+                                PatternRewriter &rewriter) const override {
+    if (isTriviallyInlineableExecuteRegion(op))
+      return rewriter.notifyMatchFailure(op, "trivially inlineable");
+
+    Operation *parentOp = op->getParentOp();
+    if (parentOp && parentOp->hasTrait()) {
+      return rewriter.notifyMatchFailure(
+          op, "cannot lower scf.execute_region inside a single-block parent region");
+    }
+
+    if (op.getRegion().empty())
+      return rewriter.notifyMatchFailure(op, "expected non-empty region");
+
+    Location loc = op.getLoc();
+    Block *curBlock = op->getBlock();
+    Region *parentRegion = curBlock->getParent();
+
+    // Split the parent block so we can branch to a continuation block with phi
+    // arguments for the execute_region results.
+    auto execIt = Block::iterator(op.getOperation());
+    Block *continueBlock = rewriter.splitBlock(curBlock, std::next(execIt));
+
+    // One continuation block argument per execute_region result, in order.
+    SmallVector contArgs;
+    contArgs.reserve(op.getNumResults());
+    for (Type t : op.getResultTypes())
+      contArgs.push_back(continueBlock->addArgument(t, loc));
+
+    // NOTE(review): the block arguments above and this use replacement are done
+    // directly on the IR, not through `rewriter` -- confirm the pattern driver
+    // in use tolerates un-notified mutations.
+    for (auto it : llvm::enumerate(op.getResults()))
+      it.value().replaceAllUsesWith(contArgs[it.index()]);
+
+    // Capture blocks before moving the region.
+    SmallVector movedBlocks;
+    movedBlocks.reserve(op.getRegion().getBlocks().size());
+    for (Block &b : op.getRegion())
+      movedBlocks.push_back(&b);
+    Block *entryBlock = &op.getRegion().front();
+
+    // Inline the execute_region blocks into the parent region right before the
+    // continuation block.
+    rewriter.inlineRegionBefore(op.getRegion(), *parentRegion,
+                                continueBlock->getIterator());
+
+    // Replace all scf.yield terminators with a branch to the continuation.
+    // Blocks whose terminator does not pass the `dyn_cast` are left untouched.
+    for (Block *b : movedBlocks) {
+      auto yield = dyn_cast(b->getTerminator());
+      if (!yield)
+        continue;
+      rewriter.setInsertionPoint(yield);
+      rewriter.create(loc, continueBlock, yield.getOperands());
+      rewriter.eraseOp(yield);
+    }
+
+    // Replace execute_region itself with a branch to the inlined entry block.
+    rewriter.setInsertionPoint(op);
+    rewriter.create(loc, entryBlock, ValueRange{});
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+// Lower scf.index_switch into CFG blocks with cf.cond_br/cf.br so that we can
+// avoid `scf.if` result materialization quirks (and avoid relying on cf.switch,
+// which is not supported by EmitC C++ translation).
+//
+// Resulting block order inside the parent region (all new blocks are created
+// right before the continuation block): one check block per case, the default
+// block, one body block per case, then the continuation block.
+// NOTE(review): template arguments are missing from this extract, so the
+// concrete op types built by `rewriter.create` are not visible here.
+struct SCFIndexSwitchToCF : public OpRewritePattern {
+  using OpRewritePattern::OpRewritePattern;
+
+  // Clones every non-terminator op of `srcBlock` to the end of `destBlock`, then
+  // appends a branch to `continueBlock` carrying the remapped terminator
+  // operands. Returns failure if the terminator of `srcBlock` does not pass the
+  // `dyn_cast`; note that the clones have already been created at that point.
+  static LogicalResult cloneYieldingBlockAndBranchTo(
+      PatternRewriter &rewriter, Location loc, Block &srcBlock, Block *destBlock,
+      Block *continueBlock) {
+    rewriter.setInsertionPointToEnd(destBlock);
+
+    IRMapping mapping;
+    for (Operation &inner : srcBlock.without_terminator())
+      rewriter.clone(inner, mapping);
+
+    auto yield = dyn_cast(srcBlock.getTerminator());
+    if (!yield)
+      return failure();
+
+    // Yielded values defined inside `srcBlock` must refer to their clones;
+    // values defined outside map to themselves via lookupOrDefault.
+    SmallVector yieldOperands;
+    yieldOperands.reserve(yield.getNumOperands());
+    for (Value v : yield.getOperands())
+      yieldOperands.push_back(mapping.lookupOrDefault(v));
+
+    rewriter.create(loc, continueBlock, yieldOperands);
+    return success();
+  }
+
+  // Splits the block holding `op` right after `op` and returns the new tail
+  // block (the continuation).
+  static Block *splitBlockForContinuation(PatternRewriter &rewriter,
+                                          scf::IndexSwitchOp op) {
+    auto switchIt = Block::iterator(op.getOperation());
+    return rewriter.splitBlock(op->getBlock(), std::next(switchIt));
+  }
+
+  // Adds one argument per result of `op` to `continueBlock` and redirects all
+  // uses of the results to those arguments. `rewriter` is not used here.
+  static void addContinuationArguments(PatternRewriter &rewriter,
+                                       scf::IndexSwitchOp op, Location loc,
+                                       Block *continueBlock) {
+    SmallVector contArgs;
+    contArgs.reserve(op.getNumResults());
+    for (Type type : op.getResultTypes())
+      contArgs.push_back(continueBlock->addArgument(type, loc));
+    for (auto result : llvm::enumerate(op.getResults()))
+      result.value().replaceAllUsesWith(contArgs[result.index()]);
+  }
+
+  // Creates `numCases` check blocks, one default block and `numCases` case
+  // blocks, each inserted before `insertPt`, and returns them through the
+  // output parameters.
+  static void createIndexSwitchBlocks(PatternRewriter &rewriter,
+                                      Region *parentRegion,
+                                      Region::iterator insertPt,
+                                      unsigned numCases,
+                                      SmallVectorImpl &checkBlocks,
+                                      Block *&defaultBlock,
+                                      SmallVectorImpl &caseBlocks) {
+    checkBlocks.reserve(numCases);
+    caseBlocks.reserve(numCases);
+    for (unsigned i = 0; i < numCases; ++i)
+      checkBlocks.push_back(rewriter.createBlock(parentRegion, insertPt));
+    defaultBlock = rewriter.createBlock(parentRegion, insertPt);
+    for (unsigned i = 0; i < numCases; ++i)
+      caseBlocks.push_back(rewriter.createBlock(parentRegion, insertPt));
+  }
+
+  // Fills check block i with: materialize `cases[i]`, compare it for equality
+  // against `selector`, and conditionally branch to `caseBlocks[i]` on true or
+  // to the next check block (the default block for the last case) on false.
+  static void populateIndexSwitchCheckBlocks(
+      PatternRewriter &rewriter, Location loc, Value selector,
+      ArrayRef cases, ArrayRef checkBlocks,
+      ArrayRef caseBlocks, Block *defaultBlock) {
+    for (unsigned i = 0; i < checkBlocks.size(); ++i) {
+      rewriter.setInsertionPointToEnd(checkBlocks[i]);
+      Value caseVal = rewriter.create(loc, cases[i]);
+      Value cond = rewriter.create(
+          loc, arith::CmpIPredicate::eq, selector, caseVal);
+      Block *falseDest =
+          (i + 1 < checkBlocks.size()) ? checkBlocks[i + 1] : defaultBlock;
+      rewriter.create(loc, cond, caseBlocks[i], ValueRange{},
+                               falseDest, ValueRange{});
+    }
+  }
+
+  // Match fails up front when the parent op carries the trait queried via
+  // `hasTrait` (reported as a single-block parent).
+  // NOTE(review): the two later "expected scf.yield terminator" failures are
+  // returned after the parent block was split and new blocks were created, i.e.
+  // with the IR already modified -- confirm they are unreachable for verified
+  // IR, or validate the terminators before mutating.
+  LogicalResult matchAndRewrite(scf::IndexSwitchOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Operation *parentOp = op->getParentOp();
+    if (parentOp && parentOp->hasTrait()) {
+      return rewriter.notifyMatchFailure(
+          op, "cannot lower scf.index_switch inside a single-block parent region");
+    }
+
+    Block *curBlock = op->getBlock();
+    Region *parentRegion = curBlock->getParent();
+    Block *continueBlock = splitBlockForContinuation(rewriter, op);
+    addContinuationArguments(rewriter, op, loc, continueBlock);
+
+    unsigned numCases = op.getCases().size();
+    auto insertPt = continueBlock->getIterator();
+
+    SmallVector checkBlocks;
+    SmallVector caseBlocks;
+    Block *defaultBlock = nullptr;
+    createIndexSwitchBlocks(rewriter, parentRegion, insertPt, numCases,
+                            checkBlocks, defaultBlock, caseBlocks);
+
+    Value selector = op.getArg();
+    auto cases = op.getCases();
+    populateIndexSwitchCheckBlocks(rewriter, loc, selector, cases, checkBlocks,
+                                   caseBlocks, defaultBlock);
+
+    // Fill case blocks and default block with cloned bodies + branch to cont.
+    for (unsigned i = 0; i < numCases; ++i) {
+      if (failed(cloneYieldingBlockAndBranchTo(
+              rewriter, loc, op.getCaseBlock(i), caseBlocks[i], continueBlock)))
+        return rewriter.notifyMatchFailure(op, "expected scf.yield terminator");
+    }
+    if (failed(cloneYieldingBlockAndBranchTo(rewriter, loc, op.getDefaultBlock(),
+                                             defaultBlock, continueBlock)))
+      return rewriter.notifyMatchFailure(op, "expected scf.yield terminator");
+
+    // Replace the original switch op with a branch into the check chain.
+    // With zero cases the branch goes straight to the default block.
+    Block *entryDest = numCases ? checkBlocks[0] : defaultBlock;
+    rewriter.setInsertionPointAfter(op);
+    rewriter.create(loc, entryDest, ValueRange{});
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+// Lower scf.while into CFG blocks with cf.br/cf.cond_br.
+//
+// Note: This requires the parent region to allow multiple blocks. In
+// particular, scf.if/scf.for regions are single-block and cannot contain this
+// lowering.
+//
+// Shape of the result: the block holding the op is split after it; a header
+// block (receiving the `before` region body) and a body block (receiving the
+// `after` region body) are created in front of the split-off tail; the op is
+// replaced by a branch to the header carrying the init operands.
+// NOTE(review): template arguments are missing from this extract, so the
+// concrete op types used by `cast` / `rewriter.create` are not visible here.
+struct SCFWhileToCF : public OpRewritePattern {
+  using OpRewritePattern::OpRewritePattern;
+
+  // Succeeds only if every user of every result of `op` lives directly in the
+  // same block as `op`. Users nested in regions of ops in that block have a
+  // different immediate block and therefore make this fail.
+  static LogicalResult validateWhileResultUses(scf::WhileOp op) {
+    Block *parentBlock = op->getBlock();
+    for (Value result : op.getResults()) {
+      for (OpOperand &use : result.getUses()) {
+        if (use.getOwner()->getBlock() != parentBlock)
+          return failure();
+      }
+    }
+    return success();
+  }
+
+  // Splits the block holding `op` right after `op` and returns the new tail
+  // block.
+  static Block *splitAfterWhileBlock(PatternRewriter &rewriter,
+                                     scf::WhileOp op) {
+    auto whileIt = Block::iterator(op.getOperation());
+    return rewriter.splitBlock(op->getBlock(), std::next(whileIt));
+  }
+
+  // Adds one argument per result of `op` to `afterWhileBlock` and redirects all
+  // uses of the results to those arguments. `rewriter` is not used here.
+  static void addWhileExitArguments(PatternRewriter &rewriter, scf::WhileOp op,
+                                    Location loc, Block *afterWhileBlock) {
+    SmallVector exitArgs;
+    exitArgs.reserve(op.getNumResults());
+    for (Type type : op.getResultTypes())
+      exitArgs.push_back(afterWhileBlock->addArgument(type, loc));
+    for (auto result : llvm::enumerate(op.getResults()))
+      result.value().replaceAllUsesWith(exitArgs[result.index()]);
+  }
+
+  // Creates the loop header block before `afterWhileBlock`; its arguments have
+  // the types of the init operands of `op`.
+  static Block *createWhileHeaderBlock(PatternRewriter &rewriter,
+                                       scf::WhileOp op, Location loc,
+                                       Block *afterWhileBlock) {
+    SmallVector headerArgTypes;
+    for (Value init : op.getInits())
+      headerArgTypes.push_back(init.getType());
+    SmallVector headerArgLocs(headerArgTypes.size(), loc);
+    return rewriter.createBlock(afterWhileBlock->getParent(),
+                                afterWhileBlock->getIterator(), headerArgTypes,
+                                headerArgLocs);
+  }
+
+  // Creates the loop body block before `afterWhileBlock`; its arguments mirror
+  // the argument types of the first block of the `after` region.
+  static Block *createWhileBodyBlock(PatternRewriter &rewriter, scf::WhileOp op,
+                                     Location loc, Block *afterWhileBlock) {
+    Block &afterRegionBlock = op.getAfter().front();
+    SmallVector bodyArgTypes(afterRegionBlock.getArgumentTypes().begin(),
+                             afterRegionBlock.getArgumentTypes().end());
+    SmallVector bodyArgLocs(bodyArgTypes.size(), loc);
+    return rewriter.createBlock(afterWhileBlock->getParent(),
+                                afterWhileBlock->getIterator(), bodyArgTypes,
+                                bodyArgLocs);
+  }
+
+  // Rewrites the two loop terminators after the region bodies were merged in:
+  // the header terminator becomes a conditional branch that forwards its args
+  // to the body block on true and to `afterWhileBlock` on false; the body
+  // terminator becomes an unconditional branch back to the header.
+  static void rewriteWhileTerminators(PatternRewriter &rewriter, Location loc,
+                                      Block *headerBlock, Block *bodyBlock,
+                                      Block *afterWhileBlock) {
+    auto condOp = cast(headerBlock->getTerminator());
+    rewriter.setInsertionPoint(condOp);
+    rewriter.create(loc, condOp.getCondition(),
+                    /*trueDest=*/bodyBlock,
+                    /*trueOperands=*/condOp.getArgs(),
+                    /*falseDest=*/afterWhileBlock,
+                    /*falseOperands=*/condOp.getArgs());
+    rewriter.eraseOp(condOp);
+
+    auto yieldOp = cast(bodyBlock->getTerminator());
+    rewriter.setInsertionPoint(yieldOp);
+    rewriter.create(loc, headerBlock, yieldOp.getOperands());
+    rewriter.eraseOp(yieldOp);
+  }
+
+  // Both match failures below are reported before any IR is modified.
+  LogicalResult matchAndRewrite(scf::WhileOp op,
+                                PatternRewriter &rewriter) const override {
+    Operation *parentOp = op->getParentOp();
+    if (parentOp && parentOp->hasTrait()) {
+      return rewriter.notifyMatchFailure(
+          op, "cannot lower scf.while inside a single-block parent region");
+    }
+
+    if (failed(validateWhileResultUses(op)))
+      return rewriter.notifyMatchFailure(
+          op, "unsupported: while results used outside the parent block");
+
+    auto loc = op.getLoc();
+    Block *afterWhileBlock = splitAfterWhileBlock(rewriter, op);
+    addWhileExitArguments(rewriter, op, loc, afterWhileBlock);
+    Block *headerBlock = createWhileHeaderBlock(rewriter, op, loc,
+                                                afterWhileBlock);
+    Block *bodyBlock = createWhileBodyBlock(rewriter, op, loc, afterWhileBlock);
+
+    // Move the before/after region bodies into the new CFG blocks.
+    Block &afterRegionBlock = op.getAfter().front();
+    rewriter.mergeBlocks(&op.getBefore().front(), headerBlock,
+                         headerBlock->getArguments());
+    rewriter.mergeBlocks(&afterRegionBlock, bodyBlock, bodyBlock->getArguments());
+    rewriteWhileTerminators(rewriter, loc, headerBlock, bodyBlock,
+                            afterWhileBlock);
+
+    // Replace scf.while itself with a branch to the header.
+    rewriter.setInsertionPoint(op);
+    rewriter.create(loc, headerBlock, op.getInits());
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+// Lower cf.switch into chained comparisons and cf.cond_br/cf.br.
+//
+// EmitC C++ translation currently supports cf.br/cf.cond_br, but not cf.switch.
+//
+// NOTE(review): template arguments are missing from this extract, so element
+// types of the `SmallVector` / `ArrayRef` values below are not visible here.
+struct CFSwitchToCondBr : public OpRewritePattern {
+  using OpRewritePattern::OpRewritePattern;
+
+  // Copies the operand range of every case of `op` into an owned vector, one
+  // entry per case in case order.
+  static SmallVector>
+  collectSwitchCaseOperands(cf::SwitchOp op) {
+    SmallVector> caseOperands;
+    caseOperands.reserve(op.getCaseDestinations().size());
+    for (auto range : op.getCaseOperands())
+      caseOperands.emplace_back(range.begin(), range.end());
+    return caseOperands;
+  }
+
+  // Returns the case values of `op` as APInts; empty when the op carries no
+  // case-values attribute.
+  static SmallVector getSwitchCaseValues(cf::SwitchOp op) {
+    SmallVector caseValues;
+    if (auto caseValuesAttr = op.getCaseValues()) {
+      for (APInt value : caseValuesAttr->getValues())
+        caseValues.push_back(value);
+    }
+    return caseValues;
+  }
+
+  // Creates `numCases` empty blocks in `parentRegion`, each inserted before the
+  // block that originally followed `curBlock`, and returns them in creation
+  // order.
+  static SmallVector createSwitchCheckBlocks(PatternRewriter &rewriter,
+                                             Region *parentRegion,
+                                             Block *curBlock,
+                                             size_t numCases) {
+    auto insertPt = std::next(curBlock->getIterator());
+    SmallVector checkBlocks;
+    checkBlocks.reserve(numCases);
+    for (size_t i = 0; i < numCases; ++i)
+      checkBlocks.push_back(rewriter.createBlock(parentRegion, insertPt));
+    return checkBlocks;
+  }
+
+  // Fills check block i with: materialize case value i as a constant of type
+  // `flagTy`, compare it for equality against `flag`, and conditionally branch
+  // to `caseDests[i]` (with `caseOperands[i]`) on true. On false, control goes
+  // to the next check block without operands, or for the last case to
+  // `defaultDest` with `defaultOperands`.
+  // NOTE(review): the bitwidth mismatch failure is returned after the check
+  // blocks (and possibly earlier branches) were already created by the caller
+  // and by previous iterations -- confirm it is unreachable for verified IR, or
+  // hoist the check in front of createSwitchCheckBlocks.
+  static LogicalResult populateSwitchCheckBlocks(
+      PatternRewriter &rewriter, Location loc, Value flag, IntegerType flagTy,
+      ArrayRef caseValues, ArrayRef caseDests,
+      ArrayRef> caseOperands, Block *defaultDest,
+      ValueRange defaultOperands, ArrayRef checkBlocks,
+      cf::SwitchOp op) {
+    for (size_t i = 0; i < caseDests.size(); ++i) {
+      rewriter.setInsertionPointToEnd(checkBlocks[i]);
+      APInt caseVal = caseValues[i];
+      if (caseVal.getBitWidth() != flagTy.getWidth()) {
+        return rewriter.notifyMatchFailure(
+            op, "case value bitwidth doesn't match flag type");
+      }
+
+      Value caseConst = rewriter.create(
+          loc, flagTy, rewriter.getIntegerAttr(flagTy, caseVal));
+      Value cond = rewriter.create(
+          loc, arith::CmpIPredicate::eq, flag, caseConst);
+      Block *falseDest =
+          (i + 1 < checkBlocks.size()) ? checkBlocks[i + 1] : defaultDest;
+      ValueRange falseOperands =
+          (i + 1 < checkBlocks.size()) ? ValueRange{} : defaultOperands;
+      rewriter.create(loc, cond, caseDests[i],
+                      caseOperands[i], falseDest,
+                      falseOperands);
+    }
+    return success();
+  }
+
+  // All match failures reported directly in this function happen before any
+  // block is created. A switch without cases is replaced by a plain branch to
+  // the default destination.
+  LogicalResult matchAndRewrite(cf::SwitchOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Operation *parentOp = op->getParentOp();
+    if (parentOp && parentOp->hasTrait()) {
+      return rewriter.notifyMatchFailure(
+          op, "cannot lower cf.switch inside a single-block parent region");
+    }
+
+    Block *curBlock = op->getBlock();
+    Region *parentRegion = curBlock->getParent();
+
+    Value flag = op.getFlag();
+    auto flagTy = dyn_cast(flag.getType());
+    if (!flagTy)
+      return rewriter.notifyMatchFailure(op, "expected integer switch flag");
+
+    // Copy operands and destinations out of the op before it is replaced.
+    SmallVector defaultOperands(op.getDefaultOperands().begin(),
+                                op.getDefaultOperands().end());
+    Block *defaultDest = op.getDefaultDestination();
+
+    SmallVector caseDests(op.getCaseDestinations().begin(),
+                          op.getCaseDestinations().end());
+    SmallVector> caseOperands = collectSwitchCaseOperands(op);
+
+    if (caseDests.empty()) {
+      rewriter.replaceOpWithNewOp(op, defaultDest, defaultOperands);
+      return success();
+    }
+
+    if (!op.getCaseValues())
+      return rewriter.notifyMatchFailure(op, "missing case_values");
+    SmallVector caseValues = getSwitchCaseValues(op);
+
+    if (caseValues.size() != caseDests.size())
+      return rewriter.notifyMatchFailure(op, "case_values/destinations mismatch");
+    if (caseOperands.size() != caseDests.size())
+      return rewriter.notifyMatchFailure(op, "case_operands/destinations mismatch");
+
+    SmallVector checkBlocks =
+        createSwitchCheckBlocks(rewriter, parentRegion, curBlock,
+                                caseDests.size());
+    if (failed(populateSwitchCheckBlocks(rewriter, loc, flag, flagTy,
+                                         caseValues, caseDests, caseOperands,
+                                         defaultDest,
defaultOperands, + checkBlocks, op))) { + return failure(); + } + + // Replace the switch terminator with a branch into the first check block. + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp(op, checkBlocks.front(), + ValueRange{}); + return success(); + } +}; + +} // namespace + +static void populatePTOToEmitCPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter, + MLIRContext *ctx, + DataFlowSolver &solver, + PTOArch targetArch) { + (void)solver; + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx, "pto.set_flag_dyn", + "set_flag"); + patterns.add(typeConverter, ctx, "pto.wait_flag_dyn", + "wait_flag"); + // Backward-compatible aliases used in some downstream branches. + patterns.add(typeConverter, ctx, "pto.set_flag_d", + "set_flag"); + patterns.add(typeConverter, ctx, "pto.wait_flag_d", + "wait_flag"); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + 
patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add( + typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + 
patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add( + typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add>( + typeConverter, ctx); + patterns.add>( + typeConverter, ctx); + patterns.add>( + typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add>(typeConverter, + ctx); + patterns.add>(typeConverter, + ctx); + patterns.add>(typeConverter, + ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add>(typeConverter, 
ctx); + patterns.add>(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add>(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add>(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add>( + typeConverter, ctx, + "pto::comm::TPUT_ASYNC"); + patterns.add>( + typeConverter, ctx, + "pto::comm::TGET_ASYNC"); + patterns.add>(typeConverter, ctx, + "pto::comm::TPUT"); + patterns.add>(typeConverter, ctx, + "pto::comm::TGET"); + patterns.add>(typeConverter, ctx, + "pto::comm::TNOTIFY"); + patterns.add>(typeConverter, ctx, + "pto::comm::TWAIT"); + patterns.add>(typeConverter, ctx, + "pto::comm::TTEST"); + patterns.add>(typeConverter, ctx, + "TBROADCAST"); + patterns.add>(typeConverter, ctx, + "TGATHER"); + patterns.add>(typeConverter, ctx, + "TSCATTER"); + patterns.add>(typeConverter, ctx, + "TREDUCE"); + patterns.add>( + typeConverter, ctx, "PTOAS__ASYNC_EVENT_WAIT"); + patterns.add>( + typeConverter, ctx, "PTOAS__ASYNC_EVENT_TEST"); 
+ patterns.add(typeConverter, ctx, targetArch); + patterns.add(typeConverter, ctx, targetArch); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx, targetArch); + patterns.add(typeConverter, ctx, targetArch); + patterns.add(typeConverter, ctx, targetArch); + patterns.add(typeConverter, ctx, targetArch); + patterns.add(typeConverter, ctx, targetArch); + patterns.add(typeConverter, ctx, targetArch); + patterns.add>(typeConverter, ctx); + patterns.add>(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add< + PTOTMatmulBiasToTMATMUL_BIAS, + PTOTMatmulMXToTMATMUL_MX, + PTOTMatmulMXAccToTMATMUL_MX_ACC, + PTOTMatmulMXBiasToTMATMUL_MX_BIAS, + PTOTMatmulBiasToTMATMUL_BIAS, + PTOTMatmulMXToTMATMUL_MX, + PTOTMatmulMXAccToTMATMUL_MX_ACC, + PTOTMatmulMXBiasToTMATMUL_MX_BIAS, + PTOTGemvBiasToTGEMV_BIAS, + PTOTGemvMXToTGEMV_MX, + PTOTGemvMXAccToTGEMV_MX, + PTOTGemvMXBiasToTGEMV_MX, + PTOBarrierToEmitC + >(typeConverter, ctx); + + patterns.add(typeConverter, ctx); + + populateSCFToEmitCConversionPatterns(patterns); + // Keep CFG-style branches type-consistent when block argument types are + // converted (e.g. after lowering scf.while to cf.br/cf.cond_br). 
+ populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter); +} + +//===----------------------------------------------------------------------===// +// Pass +//===----------------------------------------------------------------------===// + +namespace { +struct EmitPTOManualPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(EmitPTOManualPass) + + PTOArch targetArch; + + EmitPTOManualPass() : targetArch(PTOArch::A3) {} + + explicit EmitPTOManualPass(PTOArch arch) : targetArch(arch) {} + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + LLVM_DEBUG(llvm::dbgs() << "DEBUG: Start PTOToEmitC Pass\n"); + MLIRContext *ctx = &getContext(); + ModuleOp mop = getOperation(); + + if (failed(pto::validatePTOEntryFunctions(mop))) + return signalPassFailure(); + pto::annotatePTOEntryFunctions(mop); + + // A3 requires explicit FFTS base setup for inter-core sync ops. + if (targetArch == PTOArch::A3) { + bool hasMissingSetFFTs = false; + for (auto func : mop.getOps()) { + if (!hasInterCoreSyncOp(func)) + continue; + if (hasSetFFTsOp(func)) + continue; + hasMissingSetFFTs = true; + func.emitError() + << "A3 inter-core sync requires explicit `pto.set_ffts` in the " + "same function when using `pto.sync.set`/`pto.sync.wait`"; + } + if (hasMissingSetFFTs) + return signalPassFailure(); + } + + bool needsEventIdArrayHelper = false; + bool needsTRandomHelper = false; + bool needsGlobalTensorDataHelper = false; + bool needsCommInclude = false; + mop.walk([&](Operation *op) { + if (isa(op)) + needsEventIdArrayHelper = true; + if (isa(op)) + needsTRandomHelper = true; + if (isa(op)) + needsGlobalTensorDataHelper = true; + if (isa(op)) + needsCommInclude = true; + }); + + // 1. 
插入头文件 + auto loc = mop->getLoc(); + OpBuilder builder(ctx); + builder.setInsertionPointToStart(mop.getBody()); + builder.create( + loc, "pto/pto-inst.hpp", /*is_standard_include=*/false); + if (needsCommInclude) { + builder.create( + loc, builder.getStringAttr(R"cpp( +#ifndef PIPE_FIX +#define PIPE_FIX PIPE_M +#endif +)cpp")); + builder.create( + loc, "pto/comm/pto_comm_inst.hpp", /*is_standard_include=*/false); + } + builder.create( + loc, builder.getStringAttr("using namespace pto;")); + if (needsGlobalTensorDataHelper) { + builder.create( + loc, builder.getStringAttr(R"cpp( +template +static AICORE inline auto PTOAS__GLOBAL_TENSOR_DATA(Tensor &tensor) + -> decltype(tensor.data()) { + return tensor.data(); +} +)cpp")); + } + if (needsEventIdArrayHelper) { + builder.create( + loc, builder.getStringAttr(R"cpp( +template +struct PTOAS_EventIdArray { + static_assert(N > 0, "PTOAS_EventIdArray requires a positive static size"); + int32_t data[N] = {}; + + AICORE inline int32_t &operator[](int32_t idx) { return data[idx]; } + AICORE inline const int32_t &operator[](int32_t idx) const { return data[idx]; } +}; +)cpp")); + } + if (needsTRandomHelper) { + builder.create( + loc, builder.getStringAttr(R"cpp( +template +static AICORE inline void PTOAS__TRANDOM( + DstTile &dst, uint32_t key0, uint32_t key1, uint32_t counter0, + uint32_t counter1, uint32_t counter2, uint32_t counter3) { + TRandomKey key = {key0, key1}; + TRandomCounter counter = {counter0, counter1, counter2, counter3}; + TRANDOM(dst, key, counter); +} +)cpp")); + } + builder.create( + loc, builder.getStringAttr(R"cpp( +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static AICORE inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case 
PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} +)cpp")); + // Only inject the bitcast helper when we actually lower ops that need it + // (e.g. arith.bitcast or arith.maximumf/minimumf tie-breaking on zeros). + bool needsBitcastHelper = false; + mop.walk([&](Operation *op) { + if (isa(op)) { + needsBitcastHelper = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + if (needsBitcastHelper) { + builder.create( + loc, builder.getStringAttr(R"cpp( + template + static inline To ptoas_bitcast(From from) { + static_assert(sizeof(To) == sizeof(From), "ptoas_bitcast: size mismatch"); + To to; + __builtin_memcpy(&to, &from, sizeof(To)); + return to; + } + )cpp")); + } + + // 1.5 Pre-lower SCF constructs not handled by SCFToEmitC. + { + // scf.while / scf.index_switch are lowered via CFG blocks. This is not + // possible inside ops that require single-block regions (e.g. scf.for / + // scf.if). If we see such nesting, lower the entire function to the + // ControlFlow dialect first. + bool needsAnySCFToCF = false; + for (auto func : mop.getOps()) { + if (needsWholeFunctionSCFToCF(func)) { + needsAnySCFToCF = true; + break; + } + } + if (needsAnySCFToCF) { + RewritePatternSet scfToCfPatterns(ctx); + populateSCFToControlFlowConversionPatterns(scfToCfPatterns); + FrozenRewritePatternSet frozenSCFToCF(std::move(scfToCfPatterns)); + + ConversionTarget scfToCfTarget(*ctx); + // Only eliminate the single-block SCF constructs; we'll pre-lower + // scf.while/index_switch/execute_region ourselves afterwards. 
+ scfToCfTarget.addIllegalOp(); + scfToCfTarget.markUnknownOpDynamicallyLegal( + [](Operation *) { return true; }); + + for (auto func : mop.getOps()) { + if (!needsWholeFunctionSCFToCF(func)) + continue; + if (failed(applyPartialConversion(func, scfToCfTarget, + frozenSCFToCF))) { + func.emitError() + << "failed to lower nested SCF to ControlFlow (SCFToCF)"; + return signalPassFailure(); + } + } + } + + RewritePatternSet scfLoweringPatterns(ctx); + scfLoweringPatterns.add(ctx); + (void)applyPatternsAndFoldGreedily(mop, std::move(scfLoweringPatterns)); + + bool hasUnsupportedSCF = false; + mop.walk([&](Operation *op) { + if (isa(op)) { + hasUnsupportedSCF = true; + op->emitError() << "Unsupported SCF op remained after pre-lowering"; + return WalkResult::interrupt(); + } + if (isa(op)) { + hasUnsupportedSCF = true; + op->emitError() + << "Unsupported CF op remained after pre-lowering: cf.switch"; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + if (hasUnsupportedSCF) + return signalPassFailure(); + } + + PTOToEmitCTypeConverter typeConverter(ctx, targetArch); + + // 2. Pre-convert SCF structural op types (e.g. scf.if/scf.for results) + // using the same type converter. This avoids creating emitc.variable with + // unsupported types such as memref. + { + RewritePatternSet scfTypePatterns(ctx); + ConversionTarget scfTypeTarget(*ctx); + scf::populateSCFStructuralTypeConversionsAndLegality( + typeConverter, scfTypePatterns, scfTypeTarget); + scfTypeTarget.markUnknownOpDynamicallyLegal( + [](Operation *) { return true; }); + + if (failed(applyPartialConversion(mop, scfTypeTarget, + std::move(scfTypePatterns)))) { + mop.emitError("failed to reconcile SCF structural types"); + return signalPassFailure(); + } + } + + // 3. 配置转换目标 + ConversionTarget target(*ctx); + + target.addIllegalDialect(); + target.addIllegalDialect(); + target.addIllegalDialect(); + target.addIllegalDialect(); + + // If we introduced CFG branches (e.g. 
from scf.while), make sure they are + // updated to use legalized operand types. + target.addDynamicallyLegalOp( + [&](Operation *op) { + return isLegalForBranchOpInterfaceTypeConversionPattern(op, + typeConverter); + }); + + // [关键] 允许 Cast 存在,最后统一清理 + target.addLegalOp(); + + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + + target.addLegalDialect(); + target.addLegalOp(); + + auto solver = std::make_unique(); + solver->load(); + solver->load(); + if (failed(solver->initializeAndRun(getOperation()))) + return signalPassFailure(); + + RewritePatternSet patterns(ctx); + populatePTOToEmitCPatterns(patterns, typeConverter, ctx, *solver, targetArch); + + // 4. 执行转换 + if (failed(applyPartialConversion(mop, target, std::move(patterns)))) { + llvm::errs() << "Conversion FAILED! Rolling back executed.\n"; + return signalPassFailure(); + } + + // ========================================================================= + // 5. [终极清理] + // 顺序至关重要: + // Step A: 先移除所有 Cast,让 Loop 的 Operand 类型变成底层类型 (如 int32) + // Step B: 再根据新的 Operand 类型,修复 Loop IV 的类型 + // ========================================================================= + + // --- Step A: 清理 UnrealizedConversionCastOp --- + // Prefer dropping redundant/unused casts; otherwise lower to emitc.cast + // so the C++ emitter can print it. 
+ auto isEmitCTileLikeType = [](Type ty) { + auto opaqueTy = dyn_cast(ty); + if (!opaqueTy) + return false; + StringRef value = opaqueTy.getValue(); + return value.contains("Tile<") || value.contains("ConvTile<"); + }; + + llvm::SmallVector castsToErase; + bool castCleanupFailed = false; + mop.walk([&](UnrealizedConversionCastOp cast) { + if (castCleanupFailed) + return; + + if (cast->getNumOperands() != 1 || cast->getNumResults() != 1) { + cast.emitError() << "unsupported unrealized_conversion_cast shape"; + castCleanupFailed = true; + return; + } + + Value input = cast.getOperand(0); + Value output = cast.getResult(0); + Type inTy = input.getType(); + Type outTy = output.getType(); + + if (output.use_empty()) { + castsToErase.push_back(cast); + return; + } + + if (inTy == outTy) { + output.replaceAllUsesWith(input); + castsToErase.push_back(cast); + return; + } + + // SCF/CFG type conversion can transiently materialize pointer->memref + // bridge casts. At this stage, the producing value is already in the + // lowered EmitC pointer form; keep it and drop the bridge cast. + if (isEmitCPointerLikeType(inTy) && isa(outTy)) { + output.replaceAllUsesWith(input); + castsToErase.push_back(cast); + return; + } + + // SCF structural type conversion may leave a bridge from the converted + // EmitC tile value back to the original pto.tile_buf type for PTO op + // users. After PTO ops are lowered, the EmitC tile value is the value we + // want to keep. 
+ if (isEmitCTileLikeType(inTy) && isa(outTy)) { + output.replaceAllUsesWith(input); + castsToErase.push_back(cast); + return; + } + + if (emitc::isSupportedEmitCType(inTy) && emitc::isSupportedEmitCType(outTy)) { + OpBuilder builder(cast); + auto c = builder.create(cast.getLoc(), outTy, input); + output.replaceAllUsesWith(c.getResult()); + castsToErase.push_back(cast); + return; + } + + cast.emitError() << "cannot lower unrealized_conversion_cast(" << inTy + << " -> " << outTy << ") to emitc.cast"; + castCleanupFailed = true; + }); + + for (auto cast : castsToErase) + cast.erase(); + + if (castCleanupFailed) + return signalPassFailure(); + + // --- Step A2: Sink casts of emitc.variable "reads" to their use sites --- + // + // SCFToEmitC lowers scf.if/scf.for results via mutable `emitc.variable` and + // `emitc.assign`. During type conversion, casts from the variable handle to + // the converted type may be materialized right after the variable + // declaration, effectively snapshotting the value *before* assignments. That + // produces wrong C++ (use-before-init / stale reads). + // + // Fix by re-materializing the cast at each use site so it reads the variable + // at the point of use. + { + SmallVector castOpsToSink; + mop.walk([&](emitc::CastOp castOp) { + if (castOp.getSource().getDefiningOp()) + castOpsToSink.push_back(castOp); + }); + + for (emitc::CastOp castOp : castOpsToSink) { + Value src = castOp.getSource(); + Type dstTy = castOp.getResult().getType(); + Value oldRes = castOp.getResult(); + + // Replace each use with a freshly inserted cast right before the user. 
+ for (OpOperand &use : llvm::make_early_inc_range(oldRes.getUses())) { + Operation *user = use.getOwner(); + OpBuilder b(user); + b.setInsertionPoint(user); + auto newCast = b.create(castOp.getLoc(), dstTy, src); + use.set(newCast.getResult()); + } + + castOp.erase(); + } + } + + // --- Step B: 修复 Loop 归纳变量 (IV) --- + // 此时 emitc.for 的 operand 已经是 int32 了,我们检查 IV 是否匹配,不匹配则修正 + mop.walk([&](emitc::ForOp forOp) { + Type boundTy = forOp.getLowerBound().getType(); + BlockArgument iv = forOp.getBody()->getArgument(0); + + if (iv.getType() != boundTy) { + iv.setType(boundTy); // 强制将 IV 类型 (index) 修改为与边界一致 (int32) + } + }); + + // --- Step C: 消除冗余 Tile 变量 (Dead Code Elimination) [新增] --- + // 逻辑:如果一个 emitc.variable 没有被读取(use_empty), + // 那么它自己,以及给它赋值的 TASSIGN 都可以删除。 + // 注意:TASSIGN(v15, v9) 会把 v15 作为 Operand 0 使用,所以 v15 不是严格的 use_empty。 + // 我们需要检查:v15 是否除了 TASSIGN 之外没有其他 User。 + + llvm::SmallVector deadVars; + mop.walk([&](emitc::VariableOp varOp) { + // 检查该变量的所有 User + bool isRead = false; + for (Operation* user : varOp.getResult().getUsers()) { + // 如果 User 是 TASSIGN 且变量是第0个参数(dst),不算"读取" + if (auto call = dyn_cast(user)) { + if (call.getCallee() == "TASSIGN" && call.getOperand(0) == varOp.getResult()) { + continue; // 这是一个赋值操作,不算有效使用 + } + } + // 如果还有其他用途(如 TLOAD, TMOV, TMATMUL),则该变量有用 + isRead = true; + break; + } + + if (!isRead) { + deadVars.push_back(varOp); + } + }); + + for (auto varOp : deadVars) { + // 1. 先删除所有使用该变量的 TASSIGN + llvm::SmallVector usersToErase; + for (Operation* user : varOp.getResult().getUsers()) { + // 我们上面已经确认过,剩下的 user 只能是 TASSIGN + usersToErase.push_back(user); + } + for (auto u : usersToErase) u->erase(); + + // 2. 
删除变量定义本身 + varOp.erase(); + } + + llvm::SmallVector deadConsts; + mop.walk([&](emitc::ConstantOp constOp) { + if (constOp.getResult().use_empty()) + deadConsts.push_back(constOp); + }); + for (auto constOp : deadConsts) + constOp.erase(); + + // ========================================================================= + } + }; +} // namespace + +std::unique_ptr mlir::pto::createEmitPTOManualPass() { + return std::make_unique(); +} + +std::unique_ptr mlir::pto::createEmitPTOManualPass(PTOArch arch) { + return std::make_unique(arch); +} diff --git a/lib/PTO/Transforms/PTOViewToMemref.cpp b/lib/PTO/Transforms/PTOViewToMemref.cpp index c21669b81..521032476 100644 --- a/lib/PTO/Transforms/PTOViewToMemref.cpp +++ b/lib/PTO/Transforms/PTOViewToMemref.cpp @@ -6,3610 +6,5 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -//===- PTOViewToMemref.cpp ------------------------------------------------===// -//===----------------------------------------------------------------------===// -// -// Lower PTO tile/view operations to memref-based IR while preserving tile -// metadata through binding ops and SSA backtracking. 
-#include "PTO/IR/PTO.h" -#include "PTO/IR/PTOTypeUtils.h" -#include "PTO/Transforms/Passes.h" - -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" - -#include "mlir/IR/AsmState.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/SymbolTable.h" -#include "mlir/Pass/Pass.h" - -namespace mlir { -namespace pto { -#define GEN_PASS_DEF_PTOVIEWTOMEMREF -#include "PTO/Transforms/Passes.h.inc" -} // namespace pto -} // namespace mlir - -#include "llvm/ADT/SmallVector.h" -#include "llvm/Support/raw_ostream.h" -#include "Utils.h" // 假设包含一些通用的工具函数 - -#include -#include -#include - -#define DEBUG_TYPE "pto-view-to-memref" - -using namespace mlir; - -namespace mlir { -namespace pto { - -static constexpr llvm::StringLiteral kLoweredSetValidShapeAttrName = - "__pto.lowered_set_validshape"; -static constexpr llvm::StringLiteral kForceDynamicValidShapeAttrName = - "__pto.force_dynamic_valid_shape"; - -namespace { - -static void markForceDynamicValidShape(Operation *op, bool force, - MLIRContext *ctx); - -static Type convertPTOTypeToMemRef(Type t); - -constexpr size_t kTileRank2D = 2; -constexpr size_t kRowDimensionIndex = 0; -constexpr size_t kColumnDimensionIndex = 1; -constexpr unsigned kShapeVectorInlineCapacity = 4; -constexpr unsigned kOperationVectorInlineCapacity = 8; - -constexpr int64_t kElementBytes1 = 1; -constexpr int64_t kElementBytes2 = 2; -constexpr int64_t kElementBytes4 = 4; -constexpr int64_t kElementBytes8 = 8; -constexpr int64_t kElementBytes16 = 16; -constexpr int64_t kElementBytes32 = 32; - -constexpr int64_t kInnerExtent1 = 1; -constexpr int64_t kInnerExtent2 = 2; -constexpr int64_t kInnerExtent4 = 4; -constexpr int64_t kInnerExtent8 = 8; -constexpr int64_t kInnerExtent16 = 16; -constexpr int64_t kInnerExtent32 = 32; - -constexpr 
int32_t kFractalSize32 = 32; -constexpr int32_t kFractalSize512 = 512; -constexpr int32_t kFractalSize1024 = 1024; - -constexpr int32_t kBLayoutColMajor = - static_cast(BLayout::ColMajor); -constexpr int32_t kSLayoutNoneBox = - static_cast(SLayout::NoneBox); -constexpr int32_t kSLayoutRowMajor = - static_cast(SLayout::RowMajor); -constexpr int32_t kSLayoutColMajor = - static_cast(SLayout::ColMajor); -constexpr int32_t kCompactModeRowPlusOne = - static_cast(CompactMode::RowPlusOne); - -constexpr unsigned kThirdOperandIndex = 2; -constexpr unsigned kFourthOperandIndex = 3; -constexpr unsigned kFifthOperandIndex = 4; -constexpr unsigned kSixthOperandIndex = 5; - -template -using SmallInlineVector = SmallVector; - -template -using DefaultInlineVector = SmallVector; - -// ============================================================================= -// Helper: Metadata Backtracking (核心机制) -// ============================================================================= -// 从一个 MemRef Value 向上回溯,找到它绑定的 TileBufConfig。 -// 这解决了 "Type Erasure" 问题:memref 类型本身不包含 config,但 SSA 定义链包含。 -static mlir::pto::TileBufConfigAttr lookupConfig(Value v) { - // 1. 最直接的情况:它就是 bind_tile 的结果 - if (auto bind = v.getDefiningOp()) { - return bind.getConfig(); - } - // PointerCastOp can also carry tile metadata (used when alloc_tile specifies - // an explicit address). - if (auto pc = v.getDefiningOp()) { - if (auto cfg = pc.getConfig()) - return *cfg; - return {}; - } - - // 2. 
穿透 View 操作 (SubView, Cast 等) 向上查找 - if (auto subview = v.getDefiningOp()) { - return lookupConfig(subview.getSource()); - } - if (auto cast = v.getDefiningOp()) { - return lookupConfig(cast.getSource()); - } - if (auto cast = v.getDefiningOp()) { - return lookupConfig(cast.getSource()); - } - - // 如果追溯到 BlockArgument (函数参数) 或其他无法穿透的 Op,则返回空 - return {}; -} - -// ============================================================================= -// Helper: Valid dims backtracking (v_row / v_col) -// ============================================================================= -static void lookupValidDims(Value v, Value &vRow, Value &vCol) { - if (auto bind = v.getDefiningOp()) { - vRow = bind.getValidRow(); - vCol = bind.getValidCol(); - return; - } - if (auto pc = v.getDefiningOp()) { - vRow = pc.getValidRow(); - vCol = pc.getValidCol(); - return; - } - if (auto subview = v.getDefiningOp()) { - lookupValidDims(subview.getSource(), vRow, vCol); - return; - } - if (auto cast = v.getDefiningOp()) { - lookupValidDims(cast.getSource(), vRow, vCol); - return; - } - if (auto cast = v.getDefiningOp()) { - lookupValidDims(cast.getSource(), vRow, vCol); - return; - } - vRow = Value(); - vCol = Value(); -} - -// ============================================================================= -// Helper Functions for Layout Normalization -// ============================================================================= - -struct TileLayoutInfo { - int64_t rowStride = 1; - int64_t colStride = 1; - int64_t innerRows = 1; - int64_t innerCols = 1; - bool boxed = false; // slayout != NoneBox -}; - -struct TileLayoutConfig { - int32_t bLayout = 0; - int32_t sLayout = 0; - int32_t fractalSize = kFractalSize512; - int32_t compactMode = 0; -}; - -static int64_t getElemBytes(Type elemTy) { - unsigned bytes = getPTOStorageElemByteSize(elemTy); - return bytes == 0 ? 
-1 : static_cast(bytes); -} - -template -static bool readEnumAttrOrIntegerI32(Attribute attr, int32_t &out) { - if (auto enumAttr = dyn_cast(attr)) { - out = static_cast(enumAttr.getValue()); - return true; - } - if (auto intAttr = dyn_cast(attr)) { - out = static_cast(intAttr.getInt()); - return true; - } - return false; -} - -static bool readBLayoutI32(Attribute attr, int32_t &out) { - return readEnumAttrOrIntegerI32(attr, out); -} - -static bool readSLayoutI32(Attribute attr, int32_t &out) { - return readEnumAttrOrIntegerI32(attr, out); -} - -static bool readCompactModeI32(Attribute attr, int32_t &out) { - return readEnumAttrOrIntegerI32(attr, out); -} - -static Value peelIndexLikeCast(Value value) { - while (true) { - if (auto castOp = value.getDefiningOp()) { - value = castOp.getIn(); - continue; - } - if (auto extOp = value.getDefiningOp()) { - value = extOp.getIn(); - continue; - } - if (auto extOp = value.getDefiningOp()) { - value = extOp.getIn(); - continue; - } - if (auto truncOp = value.getDefiningOp()) { - value = truncOp.getIn(); - continue; - } - return value; - } -} - -static bool getConstIndexValue(Value value, int64_t &out) { - value = peelIndexLikeCast(value); - if (auto constIndex = value.getDefiningOp()) { - out = constIndex.value(); - return true; - } - if (auto constInt = value.getDefiningOp()) { - out = constInt.value(); - return true; - } - auto constOp = value.getDefiningOp(); - auto intAttr = - constOp ? 
dyn_cast(constOp.getValue()) : IntegerAttr(); - if (!intAttr) - return false; - out = intAttr.getInt(); - return true; -} - -static TileLayoutConfig getTileLayoutConfig(mlir::pto::TileBufConfigAttr cfg) { - TileLayoutConfig config; - (void)readBLayoutI32(cfg.getBLayout(), config.bLayout); - (void)readSLayoutI32(cfg.getSLayout(), config.sLayout); - if (auto attr = dyn_cast(cfg.getSFractalSize())) - config.fractalSize = static_cast(attr.getInt()); - (void)readCompactModeI32(cfg.getCompactMode(), config.compactMode); - return config; -} - -static bool getFractal512InnerExtent(int64_t elemBytes, int64_t &extent) { - switch (elemBytes) { - case kElementBytes1: - extent = kInnerExtent32; - return true; - case kElementBytes2: - extent = kInnerExtent16; - return true; - case kElementBytes4: - extent = kInnerExtent8; - return true; - case kElementBytes8: - extent = kInnerExtent4; - return true; - case kElementBytes16: - extent = kInnerExtent2; - return true; - case kElementBytes32: - extent = kInnerExtent1; - return true; - default: - return false; - } -} - -static bool computeBoxInnerShape(const TileLayoutConfig &config, Type elemTy, - TileLayoutInfo &info) { - info.boxed = config.sLayout != kSLayoutNoneBox; - if (!info.boxed) { - info.innerRows = kInnerExtent1; - info.innerCols = kInnerExtent1; - return true; - } - - int64_t elemBytes = getElemBytes(elemTy); - if (elemBytes <= 0) - return false; - - switch (config.fractalSize) { - case kFractalSize1024: - info.innerRows = kInnerExtent16; - info.innerCols = kInnerExtent16; - return true; - case kFractalSize32: - info.innerRows = kInnerExtent16; - info.innerCols = kInnerExtent2; - return true; - case kFractalSize512: - if (config.sLayout == kSLayoutRowMajor) { - info.innerRows = kInnerExtent16; - return getFractal512InnerExtent(elemBytes, info.innerCols); - } - if (config.sLayout == kSLayoutColMajor) { - if (!getFractal512InnerExtent(elemBytes, info.innerRows)) - return false; - info.innerCols = kInnerExtent16; - return 
true; - } - return false; - default: - return false; - } -} - -static bool computeTilePointerStrides(const TileLayoutConfig &config, - ArrayRef shape, - TileLayoutInfo &info) { - int64_t rows = shape[0]; - int64_t cols = shape[1]; - auto applyCompactToMajorStride = [&](int64_t majorStride) -> int64_t { - if (config.compactMode == kCompactModeRowPlusOne) - return majorStride + kInnerExtent1; - return majorStride; - }; - if (!info.boxed) { - if (config.bLayout == kBLayoutColMajor) { - info.rowStride = kInnerExtent1; - info.colStride = applyCompactToMajorStride(rows); - return true; - } - info.rowStride = applyCompactToMajorStride(cols); - info.colStride = kInnerExtent1; - return true; - } - - if (config.bLayout == kBLayoutColMajor) { - if (config.sLayout != kSLayoutRowMajor) - return false; - info.rowStride = info.innerCols; - info.colStride = applyCompactToMajorStride(rows); - return true; - } - - info.rowStride = applyCompactToMajorStride(cols); - info.colStride = info.innerRows; - return true; -} - -static bool computeTileLayoutInfo(mlir::pto::TileBufConfigAttr cfg, Type elemTy, - ArrayRef shape, - TileLayoutInfo &info) { - if (shape.size() != kTileRank2D || - llvm::is_contained(shape, ShapedType::kDynamic)) - return false; - - TileLayoutConfig config = getTileLayoutConfig(cfg); - return computeBoxInnerShape(config, elemTy, info) && - computeTilePointerStrides(config, shape, info); -} - -static void collectAffineAddTerms(AffineExpr root, - SmallVectorImpl &terms) { - SmallInlineVector pending{root}; - while (!pending.empty()) { - AffineExpr current = pending.pop_back_val(); - auto addExpr = llvm::dyn_cast(current); - if (!addExpr || addExpr.getKind() != AffineExprKind::Add) { - terms.push_back(current); - continue; - } - pending.push_back(addExpr.getRHS()); - pending.push_back(addExpr.getLHS()); - } -} - -static bool tryAssignAffineStride(AffineExpr expr, - MutableArrayRef strides) { - if (auto dim = llvm::dyn_cast(expr)) { - strides[dim.getPosition()] = 1; - 
return true; - } - - auto mulExpr = llvm::dyn_cast(expr); - if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul) - return false; - - auto assignStride = [&](AffineExpr dimExpr, - AffineExpr constantExpr) -> bool { - auto dim = llvm::dyn_cast(dimExpr); - auto constant = llvm::dyn_cast(constantExpr); - if (!dim || !constant) - return false; - strides[dim.getPosition()] = constant.getValue(); - return true; - }; - return assignStride(mulExpr.getLHS(), mulExpr.getRHS()) || - assignStride(mulExpr.getRHS(), mulExpr.getLHS()); -} - -[[maybe_unused]] static void decomposeStridedLayout(AffineMap map, - SmallVectorImpl &strides) { - strides.assign(map.getNumDims(), 0); - if (map.getNumResults() != 1) - return; - - SmallInlineVector terms; - collectAffineAddTerms(map.getResult(0), terms); - for (AffineExpr term : terms) - (void)tryAssignAffineStride(term, strides); -} - -static Value makeIndexConstant(IRRewriter &rewriter, Location loc, - int64_t value) { - return rewriter.create(loc, rewriter.getIndexType(), - rewriter.getIndexAttr(value)); -} - -static SmallVector computeCompactStrides(ArrayRef shape) { - SmallVector strides(shape.size(), 1); - int64_t stride = 1; - for (int i = static_cast(shape.size()) - 1; i >= 0; --i) { - strides[i] = stride; - if (shape[i] != ShapedType::kDynamic) - stride *= shape[i]; - } - return strides; -} - -static void materializeStaticValidDims(IRRewriter &rewriter, Location loc, - mlir::pto::TileBufType tbTy, Value &vRow, - Value &vCol) { - ArrayRef validShape = tbTy.getValidShape(); - if (tbTy.hasDynamicValid()) - return; - if (!validShape.empty() && validShape[kRowDimensionIndex] >= 0) - vRow = makeIndexConstant(rewriter, loc, validShape[kRowDimensionIndex]); - if (validShape.size() >= kTileRank2D && - validShape[kColumnDimensionIndex] >= 0) - vCol = makeIndexConstant(rewriter, loc, validShape[kColumnDimensionIndex]); -} - -static bool checkMultipleOf(Operation *op, int64_t value, int64_t divisor, - StringRef label) { - if (divisor <= 0) 
{ - op->emitError("boxed layout requires positive divisor for ") << label; - return false; - } - if (value % divisor == 0) - return true; - op->emitError("boxed layout requires ") - << label << " multiple of " << divisor << ", got " << value; - return false; -} - -// 确保 Value 是 Index 类型 -static Value ensureIndex(IRRewriter &rewriter, Location loc, Value v, - Operation *anchorOp) { - if (v.getType().isIndex()) - return v; - if (isa(v.getType())) - return rewriter.create(loc, rewriter.getIndexType(), v); - if (anchorOp) - anchorOp->emitError() << "expected index or integer, but got " << v.getType(); - return Value(); -} - -static bool tryGetIndexAttrFromValue(IRRewriter &rewriter, Value v, - IntegerAttr &constAttr) { - if (auto cOp = v.getDefiningOp()) { - constAttr = rewriter.getIndexAttr(cOp.value()); - return true; - } - if (auto cInt = v.getDefiningOp()) { - constAttr = rewriter.getIndexAttr(cInt.value()); - return true; - } - return false; -} - -static void appendMixedIndex(IRRewriter &rewriter, Location loc, Value v, - Operation *anchorOp, - SmallVectorImpl &mixedVals) { - IntegerAttr constAttr; - if (tryGetIndexAttrFromValue(rewriter, v, constAttr)) { - mixedVals.push_back(constAttr); - return; - } - mixedVals.push_back(ensureIndex(rewriter, loc, v, anchorOp)); -} - -static bool foldAddPtrChainIntoOffset(IRRewriter &rewriter, Location loc, - Value &base, Value &totalOffset) { - bool folded = false; - while (auto add = base.getDefiningOp()) { - folded = true; - Value off = ensureIndex(rewriter, loc, add.getOperand(1), add); - totalOffset = - totalOffset ? 
rewriter.create(loc, totalOffset, off) : off; - base = add.getOperand(0); - } - return folded; -} - -static Value clampSubViewValidDim(IRRewriter &rewriter, Location loc, - Value explicitValid, int64_t size, - Operation *anchorOp) { - Value sizeVal = rewriter.create(loc, size); - if (!explicitValid) - return sizeVal; - - int64_t cst = 0; - if (getConstIndexValue(explicitValid, cst)) - return rewriter.create(loc, std::min(cst, size)); - - Value v = ensureIndex(rewriter, loc, explicitValid, anchorOp); - Value lt = rewriter.create(loc, arith::CmpIPredicate::slt, v, - sizeVal); - return rewriter.create(loc, lt, v, sizeVal); -} - -[[maybe_unused]] static void dumpPretty(Operation *op, llvm::raw_ostream &os) { - OpPrintingFlags flags; - flags.useLocalScope(); - AsmState state(op, flags); - op->print(os, state); - os << "\n"; - os.flush(); -} - -// ============================================================================= -// Type Converter Logic -// ============================================================================= - -static SmallVector buildTileMemRefStrides(mlir::pto::TileBufType tbTy) { - TileLayoutInfo info; - if (computeTileLayoutInfo(tbTy.getConfigAttr(), tbTy.getElementType(), - tbTy.getShape(), info)) { - return {info.rowStride, info.colStride}; - } - return computeCompactStrides(tbTy.getShape()); -} - -static Type convertTileBufTypeToMemRef(mlir::pto::TileBufType tbTy) { - auto layoutAttr = StridedLayoutAttr::get(tbTy.getContext(), - ShapedType::kDynamic, - buildTileMemRefStrides(tbTy)); - return MemRefType::get(tbTy.getShape(), tbTy.getElementType(), layoutAttr, - tbTy.getMemorySpace()); -} - -static Type convertPTOTypeToMemRef(Type t) { - // 1. 处理 !pto.ptr - if (auto pty = dyn_cast(t)) { - return MemRefType::get({ShapedType::kDynamic}, pty.getElementType(), - MemRefLayoutAttrInterface(), Attribute()); - } - - // 2. 
处理 !pto.tile_buf<...> - if (auto tbTy = dyn_cast(t)) - return convertTileBufTypeToMemRef(tbTy); - if (auto tvTy = dyn_cast(t)) - return MemRefType::get(tvTy.getShape(), tvTy.getElementType(), - MemRefLayoutAttrInterface(), Attribute()); - if (auto partTy = dyn_cast(t)) - return MemRefType::get(partTy.getShape(), partTy.getElementType(), - MemRefLayoutAttrInterface(), Attribute()); - // 其他类型透传 - return t; -} - -// Ensure scf.if result types follow the rewritten yield operand types. -// PTOViewToMemref rewrites tile values to memref in branch bodies, but scf.if -// result types are not auto-updated by those op-local rewrites. -static LogicalResult reconcileSCFIfResultTypes(func::FuncOp func) { - DefaultInlineVector ifOps; - func.walk([&](scf::IfOp ifOp) { ifOps.push_back(ifOp); }); - - for (scf::IfOp ifOp : ifOps) { - if (ifOp.getNumResults() == 0) - continue; - - auto thenYield = dyn_cast(ifOp.thenBlock()->getTerminator()); - auto elseYield = dyn_cast(ifOp.elseBlock()->getTerminator()); - if (!thenYield || !elseYield) { - ifOp.emitError("result-bearing scf.if must end with scf.yield in both " - "then/else regions"); - return failure(); - } - - if (thenYield.getNumOperands() != ifOp.getNumResults() || - elseYield.getNumOperands() != ifOp.getNumResults()) { - ifOp.emitError("scf.if result count does not match yielded values"); - return failure(); - } - - for (unsigned i = 0; i < ifOp.getNumResults(); ++i) { - Type thenTy = thenYield.getOperand(i).getType(); - Type elseTy = elseYield.getOperand(i).getType(); - if (thenTy != elseTy) { - ifOp.emitError() << "scf.if branch yield type mismatch at result #" << i - << ": then=" << thenTy << ", else=" << elseTy; - return failure(); - } - - if (ifOp.getResult(i).getType() != thenTy) - ifOp.getResult(i).setType(thenTy); - } - } - - return success(); -} - -static LogicalResult reconcileSCFForResultTypes(func::FuncOp func) { - DefaultInlineVector forOps; - func.walk([&](scf::ForOp forOp) { forOps.push_back(forOp); }); - - for 
(scf::ForOp forOp : forOps) { - if (forOp.getNumResults() == 0) - continue; - - auto yield = dyn_cast(forOp.getBody()->getTerminator()); - if (!yield) { - forOp.emitError("result-bearing scf.for must end with scf.yield"); - return failure(); - } - - if (yield.getNumOperands() != forOp.getNumResults() || - forOp.getInitArgs().size() != forOp.getNumResults()) { - forOp.emitError("scf.for result count does not match iter/yield values"); - return failure(); - } - - for (unsigned i = 0; i < forOp.getNumResults(); ++i) { - Type initTy = forOp.getInitArgs()[i].getType(); - Type yieldTy = yield.getOperand(i).getType(); - if (initTy != yieldTy) { - forOp.emitError() << "scf.for init/yield type mismatch at result #" << i - << ": init=" << initTy << ", yield=" << yieldTy; - return failure(); - } - - BlockArgument iterArg = forOp.getRegionIterArg(i); - if (iterArg.getType() != initTy) - iterArg.setType(initTy); - if (forOp.getResult(i).getType() != initTy) - forOp.getResult(i).setType(initTy); - } - } - - return success(); -} - -static LogicalResult markLoweredSetValidShapeOps(func::FuncOp func, - MLIRContext *ctx) { - WalkResult result = func.walk([&](mlir::pto::SetValidShapeOp op) { - if (isa(op.getSource().getType())) { - if (!lookupConfig(op.getSource())) { - op.emitError( - "set_validshape requires a locally bound tile source; function " - "arguments/results are unsupported"); - return WalkResult::interrupt(); - } - op->setAttr(kLoweredSetValidShapeAttrName, UnitAttr::get(ctx)); - return WalkResult::advance(); - } - op->removeAttr(kLoweredSetValidShapeAttrName); - return WalkResult::advance(); - }); - return result.wasInterrupted() ? 
failure() : success(); -} - -static void markForceDynamicValidShape(Operation *op, bool force, - MLIRContext *ctx) { - if (force) { - op->setAttr(kForceDynamicValidShapeAttrName, UnitAttr::get(ctx)); - return; - } - op->removeAttr(kForceDynamicValidShapeAttrName); -} - -[[maybe_unused]] static void rewriteFunctionSignature(func::FuncOp func, MLIRContext *ctx) { - Block &entry = func.front(); - auto fnTy = func.getFunctionType(); - - SmallVector newInputs; - for (Type type : fnTy.getInputs()) - newInputs.push_back(convertPTOTypeToMemRef(type)); - - SmallVector newResults; - for (Type type : fnTy.getResults()) - newResults.push_back(convertPTOTypeToMemRef(type)); - - for (unsigned i = 0; i < entry.getNumArguments(); ++i) { - if (entry.getArgument(i).getType() != newInputs[i]) - entry.getArgument(i).setType(newInputs[i]); - } - func.setFunctionType(FunctionType::get(ctx, newInputs, newResults)); -} - -[[maybe_unused]] static LogicalResult lowerAllocTileOps(func::FuncOp func, MLIRContext *ctx) { - DefaultInlineVector allocTiles; - func.walk([&](mlir::pto::AllocTileOp op) { allocTiles.push_back(op); }); - - for (auto op : allocTiles) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - Location loc = op.getLoc(); - - auto tbTy = dyn_cast(op.getResult().getType()); - if (!tbTy) - continue; - - SmallInlineVector shape(tbTy.getShape().begin(), - tbTy.getShape().end()); - Type elemTy = tbTy.getElementType(); - SmallVector strides = buildTileMemRefStrides(tbTy); - - auto targetLayout = - StridedLayoutAttr::get(ctx, ShapedType::kDynamic, strides); - auto targetType = - MemRefType::get(shape, elemTy, targetLayout, tbTy.getMemorySpace()); - - Value vRow = op.getValidRow(); - Value vCol = op.getValidCol(); - materializeStaticValidDims(rewriter, loc, tbTy, vRow, vCol); - - auto configAttr = tbTy.getConfigAttr(); - if (!configAttr) - configAttr = pto::TileBufConfigAttr::getDefault(ctx); - - if (Value addr = op.getAddr()) { - auto pc = rewriter.create( - loc, 
targetType, ValueRange{addr}, vRow ? vRow : Value(), - vCol ? vCol : Value(), configAttr); - markForceDynamicValidShape(pc, tbTy.hasDynamicValid(), ctx); - auto bindOp = rewriter.create( - loc, targetType, pc.getResult(), vRow ? vRow : Value(), - vCol ? vCol : Value(), configAttr); - markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); - rewriter.replaceOp(op, bindOp.getResult()); - continue; - } - - auto allocLayout = StridedLayoutAttr::get(ctx, 0, strides); - auto allocType = - MemRefType::get(shape, elemTy, allocLayout, tbTy.getMemorySpace()); - Value alloc = rewriter.create(loc, allocType); - auto bindOp = rewriter.create( - loc, targetType, alloc, vRow ? vRow : Value(), vCol ? vCol : Value(), - configAttr); - markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); - rewriter.replaceOp(op, bindOp.getResult()); - } - return success(); -} - -[[maybe_unused]] static LogicalResult lowerDeclareTileOps(func::FuncOp func, MLIRContext *ctx) { - DefaultInlineVector declaredTiles; - func.walk([&](mlir::pto::DeclareTileOp op) { declaredTiles.push_back(op); }); - - for (auto op : declaredTiles) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - Location loc = op.getLoc(); - - auto tbTy = dyn_cast(op.getTile().getType()); - if (!tbTy) { - op.emitError("declare_tile result must be tile_buf type"); - return failure(); - } - - auto targetType = dyn_cast(convertPTOTypeToMemRef(tbTy)); - if (!targetType) { - op.emitError("failed to convert declare_tile result to memref type"); - return failure(); - } - - auto configAttr = tbTy.getConfigAttr(); - if (!configAttr) - configAttr = pto::TileBufConfigAttr::getDefault(ctx); - - Value vRow; - Value vCol; - materializeStaticValidDims(rewriter, loc, tbTy, vRow, vCol); - - auto declaredMemRef = - rewriter.create(loc, targetType); - auto bindOp = rewriter.create( - loc, targetType, declaredMemRef.getResult(), vRow ? vRow : Value(), - vCol ? 
vCol : Value(), configAttr); - markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); - rewriter.replaceOp(op, bindOp.getResult()); - } - return success(); -} - -static Value castIndexToI64(IRRewriter &rewriter, Location loc, Value value) { - Type i64Ty = rewriter.getI64Type(); - if (value.getType() == i64Ty) - return value; - return rewriter.create(loc, i64Ty, value).getResult(); -} - -static FailureOr -materializePtrToIntAddPtrAddress(IRRewriter &rewriter, Location loc, - mlir::pto::PtrToIntOp anchor, Value source) { - SmallVector addPtrChain; - Value base = source; - while (auto add = base.getDefiningOp()) { - addPtrChain.push_back(add); - base = add.getOperand(0); - } - - if (addPtrChain.empty()) - return failure(); - - auto baseMemTy = dyn_cast(base.getType()); - if (!baseMemTy) { - anchor.emitOpError( - "pto.addptr source base could not be lowered to a GM memref"); - return failure(); - } - - Value byteAddress = rewriter.create( - loc, rewriter.getI64Type(), base); - for (auto add : addPtrChain) { - auto addPtrTy = dyn_cast(add.getResult().getType()); - if (!addPtrTy) { - anchor.emitOpError("requires pto.addptr source to have !pto.ptr result " - "type before byte-address lowering"); - return failure(); - } - - unsigned elemBytes = - mlir::pto::getPTOStorageElemByteSize(addPtrTy.getElementType()); - if (elemBytes == 0) { - anchor.emitOpError("cannot lower pto.addptr source with unknown element " - "byte size to a byte address"); - return failure(); - } - - Value byteOffset = castIndexToI64(rewriter, loc, add.getOffset()); - if (elemBytes != 1) { - Value elemBytesValue = - rewriter.create(loc, elemBytes, 64); - byteOffset = - rewriter.create(loc, byteOffset, elemBytesValue) - .getResult(); - } - byteAddress = - rewriter.create(loc, byteAddress, byteOffset).getResult(); - } - - return byteAddress; -} - -static LogicalResult lowerIntToPtrOps(func::FuncOp func, MLIRContext *ctx) { - DefaultInlineVector intToPtrs; - func.walk([&](mlir::pto::IntToPtrOp op) 
{ intToPtrs.push_back(op); }); - - for (auto op : intToPtrs) { - if (!isa(op.getResult().getType())) - continue; - - auto targetTy = - dyn_cast(convertPTOTypeToMemRef(op.getResult().getType())); - if (!targetTy) { - op.emitError("failed to convert inttoptr result to memref type"); - return failure(); - } - - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - auto lowered = - rewriter.create(op.getLoc(), targetTy, - op.getAddr()); - lowered->setAttrs(op->getAttrs()); - rewriter.replaceOp(op, lowered.getResult()); - } - - return success(); -} - -static LogicalResult lowerPtrToIntOps(func::FuncOp func, MLIRContext *ctx) { - DefaultInlineVector ptrToInts; - func.walk([&](mlir::pto::PtrToIntOp op) { ptrToInts.push_back(op); }); - - for (auto op : ptrToInts) { - Value source = op.getPtr(); - if (source.getDefiningOp()) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - FailureOr byteAddress = - materializePtrToIntAddPtrAddress(rewriter, op.getLoc(), op, source); - if (failed(byteAddress)) - return failure(); - rewriter.replaceOp(op, *byteAddress); - continue; - } - - if (isa(source.getType())) - continue; - } - - DefaultInlineVector remaining; - func.walk([&](mlir::pto::PtrToIntOp op) { - if (isa(op.getPtr().getType())) - remaining.push_back(op); - }); - for (auto op : remaining) { - op.emitError("ptrtoint source could not be lowered to a GM memref"); - return failure(); - } - - return success(); -} - -[[maybe_unused]] static LogicalResult lowerMakeTensorViewOps(func::FuncOp func, MLIRContext *ctx) { - DefaultInlineVector makeViews; - func.walk([&](mlir::pto::MakeTensorViewOp op) { makeViews.push_back(op); }); - - for (auto op : makeViews) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - Location loc = op.getLoc(); - - Value baseBuf = op.getOperand(0); - OpFoldResult off0 = rewriter.getIndexAttr(0); - bool foldedAddPtr = false; - { - Value cur = baseBuf; - Value totalOffset; - while (auto add = cur.getDefiningOp()) { - 
foldedAddPtr = true; - Value off = ensureIndex(rewriter, loc, add.getOperand(1), add); - totalOffset = totalOffset ? rewriter.create(loc, totalOffset, off) - : off; - cur = add.getOperand(0); - } - if (cur != baseBuf) { - baseBuf = cur; - off0 = totalOffset ? OpFoldResult(totalOffset) : off0; - } - } - - auto baseMr = dyn_cast(baseBuf.getType()); - if (!baseMr) { - op.emitError("make_tensor_view base must be memref"); - return failure(); - } - - size_t rank = op.getShape().size(); - int64_t dyn = ShapedType::kDynamic; - SmallVector dynStrides(rank, dyn); - auto layout = - StridedLayoutAttr::get(ctx, /*offset=*/dyn, /*strides=*/dynStrides); - SmallVector dynShape(rank, dyn); - auto mrTy = MemRefType::get(dynShape, baseMr.getElementType(), layout, - baseMr.getMemorySpace()); - - SmallInlineVector sizes; - for (Value value : op.getShape()) - sizes.push_back(ensureIndex(rewriter, loc, value, op)); - SmallInlineVector strides; - for (Value value : op.getStrides()) - strides.push_back(ensureIndex(rewriter, loc, value, op)); - - auto rc = rewriter.create(loc, mrTy, baseBuf, off0, - sizes, strides); - if (foldedAddPtr) - rc->setAttr("pto.addptr_trace", rewriter.getUnitAttr()); - if (auto layoutAttr = op.getLayoutAttr()) - rc->setAttr("layout", layoutAttr); - rewriter.replaceOp(op, rc.getResult()); - } - return success(); -} - -[[maybe_unused]] static LogicalResult lowerTensorViewDimOps(func::FuncOp func, MLIRContext *ctx) { - DefaultInlineVector tvDims; - func.walk([&](mlir::pto::GetTensorViewDimOp op) { tvDims.push_back(op); }); - - for (auto op : tvDims) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - Value view = op.getTensorView(); - auto mrTy = dyn_cast(view.getType()); - if (!mrTy) - continue; - Value dim = rewriter.create(op.getLoc(), view, op.getDimIndex()); - rewriter.replaceOp(op, dim); - } - return success(); -} - -[[maybe_unused]] static LogicalResult foldAddPtrIntoScalarOps(func::FuncOp func, MLIRContext *ctx) { - DefaultInlineVector 
loadScalars; - func.walk([&](mlir::pto::LoadScalarOp op) { loadScalars.push_back(op); }); - for (auto op : loadScalars) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - Location loc = op.getLoc(); - - Value base = op.getPtr(); - Value totalOffset = ensureIndex(rewriter, loc, op.getOffset(), op); - bool foldedAddPtr = foldAddPtrChainIntoOffset(rewriter, loc, base, totalOffset); - if (foldedAddPtr) { - auto newOp = - rewriter.create(loc, op.getValue().getType(), base, - totalOffset); - rewriter.replaceOp(op, newOp.getValue()); - } - } - - DefaultInlineVector storeScalars; - func.walk([&](mlir::pto::StoreScalarOp op) { storeScalars.push_back(op); }); - for (auto op : storeScalars) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - Location loc = op.getLoc(); - - Value base = op.getPtr(); - Value totalOffset = ensureIndex(rewriter, loc, op.getOffset(), op); - bool foldedAddPtr = foldAddPtrChainIntoOffset(rewriter, loc, base, totalOffset); - if (foldedAddPtr) { - rewriter.create(loc, base, totalOffset, op.getValue()); - rewriter.eraseOp(op); - } - } - - DefaultInlineVector addPtrs; - func.walk([&](mlir::pto::AddPtrOp op) { addPtrs.push_back(op.getOperation()); }); - bool changed = true; - while (changed) { - changed = false; - for (auto &op : addPtrs) { - if (!op) - continue; - if (op->use_empty()) { - op->erase(); - op = nullptr; - changed = true; - } - } - } - for (Operation *op : addPtrs) { - if (!op) - continue; - op->emitError( - "addptr must feed make_tensor_view or load/store_scalar for lowering"); - return failure(); - } - return success(); -} - -static LogicalResult lowerPartitionViewOps(func::FuncOp func, MLIRContext *ctx) { - DefaultInlineVector partitionViews; - func.walk([&](mlir::pto::PartitionViewOp op) { partitionViews.push_back(op); }); - - for (auto op : partitionViews) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - Location loc = op.getLoc(); - Value src = op.getOperand(0); - auto srcMrTy = 
dyn_cast(src.getType()); - if (!srcMrTy) - continue; - int64_t rank = srcMrTy.getRank(); - - SmallVector staticSizes; - SmallVector mixedSizes; - for (Value size : op.getSizes()) { - IntegerAttr constAttr; - bool isStatic = false; - if (auto cOp = size.getDefiningOp()) { - constAttr = rewriter.getIndexAttr(cOp.value()); - isStatic = true; - } else if (auto cInt = size.getDefiningOp()) { - constAttr = rewriter.getIndexAttr(cInt.value()); - isStatic = true; - } - - if (isStatic) { - mixedSizes.push_back(constAttr); - staticSizes.push_back(constAttr.getInt()); - } else { - mixedSizes.push_back(ensureIndex(rewriter, loc, size, op)); - staticSizes.push_back(ShapedType::kDynamic); - } - } - - SmallVector mixedOffsets; - for (Value offset : op.getOffsets()) { - appendMixedIndex(rewriter, loc, offset, op, mixedOffsets); - } - - int64_t dyn = ShapedType::kDynamic; - SmallVector dynStrides(rank, dyn); - auto layout = StridedLayoutAttr::get(ctx, dyn, dynStrides); - auto resTy = MemRefType::get(staticSizes, srcMrTy.getElementType(), layout, - srcMrTy.getMemorySpace()); - - SmallVector mixedStrides(rank, rewriter.getIndexAttr(1)); - auto sv = rewriter.create(loc, resTy, src, mixedOffsets, - mixedSizes, mixedStrides); - if (Operation *srcDef = src.getDefiningOp()) { - if (auto layoutAttr = srcDef->getAttrOfType("layout")) - sv->setAttr("layout", layoutAttr); - } - rewriter.replaceOp(op, sv.getResult()); - } - return success(); -} - -static LogicalResult lowerSubViewOps(func::FuncOp func, MLIRContext *ctx) { - DefaultInlineVector subViews; - func.walk([&](mlir::pto::SubViewOp op) { subViews.push_back(op); }); - - for (auto op : subViews) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - Location loc = op.getLoc(); - auto resultTileTy = - dyn_cast(op.getResult().getType()); - Value src = op->getOperand(0); - auto srcMrTy = dyn_cast(src.getType()); - if (!srcMrTy) { - op.emitError("pto.subview source must be lowered to memref first"); - return failure(); - } - - 
ArrayAttr sizeAttr = op.getSizes(); - SmallVector staticSizes; - SmallVector mixedSizes; - for (Attribute attr : sizeAttr) { - int64_t size = cast(attr).getInt(); - staticSizes.push_back(size); - mixedSizes.push_back(rewriter.getIndexAttr(size)); - } - - SmallVector mixedOffsets; - for (Value offset : op.getOffsets()) { - appendMixedIndex(rewriter, loc, offset, op, mixedOffsets); - } - - auto configAttr = lookupConfig(src); - if (!configAttr) - configAttr = pto::TileBufConfigAttr::getDefault(ctx); - - TileLayoutInfo layoutInfo; - if (!computeTileLayoutInfo(configAttr, srcMrTy.getElementType(), - srcMrTy.getShape(), layoutInfo)) { - op.emitError("unsupported tile layout for pto.subview"); - return failure(); - } - - if (layoutInfo.boxed) { - if (staticSizes.size() != kTileRank2D || - op.getOffsets().size() != kTileRank2D) { - op.emitError("boxed layout subview expects 2D sizes/offsets"); - return failure(); - } - if (!checkMultipleOf(op, staticSizes[0], layoutInfo.innerRows, "row size") || - !checkMultipleOf(op, staticSizes[1], layoutInfo.innerCols, "col size")) { - return failure(); - } - - int64_t off0 = 0; - int64_t off1 = 0; - bool off0Const = getConstIndexValue(op.getOffsets()[0], off0); - bool off1Const = getConstIndexValue(op.getOffsets()[1], off1); - if (off0Const && - !checkMultipleOf(op, off0, layoutInfo.innerRows, "row offset")) { - return failure(); - } - if (off1Const && - !checkMultipleOf(op, off1, layoutInfo.innerCols, "col offset")) { - return failure(); - } - - } - - SmallVector srcStrides; - int64_t srcOffset = ShapedType::kDynamic; - if (failed(getStridesAndOffset(srcMrTy, srcStrides, srcOffset))) - srcStrides = computeCompactStrides(srcMrTy.getShape()); - - // Keep parent physical shape + strides for bound tile semantics. 
- auto resultLayout = - StridedLayoutAttr::get(ctx, ShapedType::kDynamic, srcStrides); - auto parentShape = srcMrTy.getShape(); - auto resultMemRefType = - MemRefType::get(parentShape, srcMrTy.getElementType(), resultLayout, - srcMrTy.getMemorySpace()); - - // Intermediate memref.subview keeps logical subview size. - auto subViewMemRefType = - MemRefType::get(staticSizes, srcMrTy.getElementType(), resultLayout, - srcMrTy.getMemorySpace()); - - SmallVector mixedStrides(staticSizes.size(), - rewriter.getIndexAttr(1)); - auto sv = rewriter.create(loc, subViewMemRefType, src, - mixedOffsets, mixedSizes, - mixedStrides); - - Value vRow; - Value vCol; - if (!staticSizes.empty()) - vRow = clampSubViewValidDim(rewriter, loc, op.getValidRow(), - staticSizes[0], op); - if (staticSizes.size() > 1) - vCol = clampSubViewValidDim(rewriter, loc, op.getValidCol(), - staticSizes[1], op); - - auto bindOp = rewriter.create( - loc, resultMemRefType, sv.getResult(), vRow ? vRow : Value(), - vCol ? vCol : Value(), configAttr); - markForceDynamicValidShape(bindOp, - resultTileTy && resultTileTy.hasDynamicValid(), - ctx); - bindOp->setAttr("pto.view_semantics", rewriter.getStringAttr("subview")); - rewriter.replaceOp(op, bindOp.getResult()); - } - return success(); -} - -static Value buildTileBufViewLikeValue(Operation *anchorOp, Value src, - mlir::pto::TileBufType tbTy, - StringRef viewSemantics, - MLIRContext *ctx) { - Location loc = anchorOp->getLoc(); - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(anchorOp); - - auto srcMrTy = dyn_cast(src.getType()); - if (!srcMrTy) { - anchorOp->emitError("tile_buf view op src must be lowered to memref first"); - return Value(); - } - - auto targetType = dyn_cast(convertPTOTypeToMemRef(tbTy)); - if (!targetType) { - anchorOp->emitError("failed to convert tile_buf type to memref type"); - return Value(); - } - for (int64_t dim : targetType.getShape()) { - if (dim == ShapedType::kDynamic) { - anchorOp->emitError("dynamic shapes are not 
supported for tile_buf view ops"); - return Value(); - } - } - - Value parentVRow; - Value parentVCol; - lookupValidDims(src, parentVRow, parentVCol); - Value vRow = parentVRow; - Value vCol = parentVCol; - materializeStaticValidDims(rewriter, loc, tbTy, vRow, vCol); - - auto configAttr = tbTy.getConfigAttr(); - if (!configAttr) - configAttr = pto::TileBufConfigAttr::getDefault(ctx); - - auto bindOp = rewriter.create( - loc, targetType, src, vRow ? vRow : Value(), vCol ? vCol : Value(), - configAttr); - markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); - if (!viewSemantics.empty()) - bindOp->setAttr("pto.view_semantics", rewriter.getStringAttr(viewSemantics)); - return bindOp.getResult(); -} - -static LogicalResult lowerTileBufViewLikeOps(func::FuncOp func, MLIRContext *ctx) { - DefaultInlineVector reshapes; - func.walk([&](mlir::pto::TReshapeOp op) { reshapes.push_back(op); }); - for (auto op : reshapes) { - auto tbTy = dyn_cast(op.getResult().getType()); - if (!tbTy) { - op.emitError("treshape result must be tile_buf type"); - return failure(); - } - Value lowered = buildTileBufViewLikeValue(op, op->getOperand(0), tbTy, - "treshape", ctx); - if (!lowered) - return failure(); - IRRewriter rewriter(ctx); - rewriter.replaceOp(op, lowered); - } - - DefaultInlineVector bitcasts; - func.walk([&](mlir::pto::BitcastOp op) { bitcasts.push_back(op); }); - for (auto op : bitcasts) { - auto tbTy = dyn_cast(op.getResult().getType()); - if (!tbTy) { - op.emitError("bitcast result must be tile_buf type"); - return failure(); - } - Value lowered = buildTileBufViewLikeValue(op, op->getOperand(0), tbTy, - "bitcast", ctx); - if (!lowered) - return failure(); - IRRewriter rewriter(ctx); - rewriter.replaceOp(op, lowered); - } - return success(); -} - -// ============================================================================= -// The Pass Implementation -// ============================================================================= - -struct PTOViewToMemrefPass 
- : public mlir::pto::impl::PTOViewToMemrefBase { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PTOViewToMemrefPass) - - void runOnOperation() override { - ModuleOp mod = getOperation(); - MLIRContext *ctx = &getContext(); - - for (auto func : mod.getOps()) { - if (func.isExternal()) continue; - - // ------------------------------------------------------------------ - // Stage 0: ensure inttoptr values remain scalar-load/store only. - // ------------------------------------------------------------------ - if (failed(validateIntToPtrUses(func))) { - signalPassFailure(); - return; - } - - Block &entry = func.front(); - auto fnTy = func.getFunctionType(); - - // ------------------------------------------------------------------ - // Stage 0.10: Rewrite Function Signature - // ------------------------------------------------------------------ - SmallVector newInputs; - for (Type t : fnTy.getInputs()) newInputs.push_back(convertPTOTypeToMemRef(t)); - - SmallVector newResults; - for (Type t : fnTy.getResults()) newResults.push_back(convertPTOTypeToMemRef(t)); - - // Update entry block arguments - for (unsigned i = 0; i < entry.getNumArguments(); ++i) { - if (entry.getArgument(i).getType() != newInputs[i]) { - entry.getArgument(i).setType(newInputs[i]); - } - } - - // Update function type - func.setFunctionType(FunctionType::get(ctx, newInputs, newResults)); - - // ------------------------------------------------------------------ - // Stage 0.20: lower pto.inttoptr result types to GM memrefs. - // ------------------------------------------------------------------ - if (failed(lowerIntToPtrOps(func, ctx))) { - signalPassFailure(); - return; - } - - // ------------------------------------------------------------------ - // Stage 0.30: materialize pto.ptrtoint(addptr ...) byte offsets. 
- // ------------------------------------------------------------------ - if (failed(lowerPtrToIntOps(func, ctx))) { - signalPassFailure(); - return; - } - - // ------------------------------------------------------------------ - // Stage 0.5: lower pto.alloc_tile -> memref.alloc + pto.bind_tile - // ------------------------------------------------------------------ - DefaultInlineVector allocTiles; - func.walk([&](mlir::pto::AllocTileOp op) { allocTiles.push_back(op); }); - - for (auto op : allocTiles) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - Location loc = op.getLoc(); - - auto tbTy = dyn_cast(op.getResult().getType()); - if (!tbTy) continue; - - // 1. 获取 Shape 和 ElementType - SmallInlineVector shape(tbTy.getShape().begin(), tbTy.getShape().end()); - Type elemTy = tbTy.getElementType(); - - // 2. 计算 Strides (layout-aware when possible) - SmallVector strides; - TileLayoutInfo info; - if (computeTileLayoutInfo(tbTy.getConfigAttr(), elemTy, shape, info)) { - strides = {info.rowStride, info.colStride}; - } else { - strides.resize(shape.size()); - int64_t s = 1; - for (int i = (int)shape.size() - 1; i >= 0; --i) { - strides[i] = s; - if (shape[i] != ShapedType::kDynamic) s *= shape[i]; - } - } - - // 3. 构造 [BindTile 输出] 的动态类型 (Offset: ?) - // 这必须与 convertPTOTypeToMemRef 返回的类型一致,以便与 Subview 兼容 - auto targetLayout = - StridedLayoutAttr::get(ctx, ShapedType::kDynamic, strides); // offset = ? - auto targetType = - MemRefType::get(shape, elemTy, targetLayout, tbTy.getMemorySpace()); - - // 4. Preserve tile valid dims (v_row / v_col). - // - // `pto.alloc_tile` encodes the valid shape in the result TileBufType - // (e.g. acc tile may be rows=16 but v_row=1). The alloc op itself does - // not necessarily carry explicit operands for static valid dims, so we - // must materialize them from the type to keep them through - // tile_buf -> memref lowering. 
- // - // For dynamically valid tiles (validShape == [-1, -1]), preserve the - // runtime operands if present. - Value vRow = op.getValidRow(); - Value vCol = op.getValidCol(); - // TileBuf valid dims use a negative sentinel (e.g. '?' / -1), which is - // distinct from MLIR's ShapedType::kDynamic (INT64_MIN). Treat any - // negative value as dynamic here. - materializeStaticValidDims(rewriter, loc, tbTy, vRow, vCol); - - // 5. 获取 Config (保持不变) - auto configAttr = tbTy.getConfigAttr(); - if (!configAttr) configAttr = pto::TileBufConfigAttr::getDefault(ctx); - - // 6. If alloc_tile provides an explicit address, keep the original - // pointer_cast lowering intact and additionally rebind through - // pto.bind_tile. PointerCastOp continues to carry the tile metadata - // used by existing lowering paths, while BindTileOp provides the - // unified anchor EmitC uses to recover tile_buf information. - if (Value addr = op.getAddr()) { - auto pc = rewriter.create( - loc, targetType, ValueRange{addr}, vRow ? vRow : Value(), - vCol ? vCol : Value(), configAttr); - markForceDynamicValidShape(pc, tbTy.hasDynamicValid(), ctx); - auto bindOp = rewriter.create( - loc, targetType, pc.getResult(), vRow ? vRow : Value(), - vCol ? vCol : Value(), configAttr); - markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); - rewriter.replaceOp(op, bindOp.getResult()); - continue; - } - - // 7. Otherwise allocate a concrete memref buffer and bind tile. - // memref.alloc 要求明确的 layout,不能是动态 offset。 - auto allocLayout = StridedLayoutAttr::get(ctx, 0, strides); // offset = 0 - auto allocType = MemRefType::get(shape, elemTy, allocLayout, tbTy.getMemorySpace()); - Value alloc = rewriter.create(loc, allocType); - - // BindTileOp 的 Builder 会自动处理空的 Value,将其视为静态维度 - auto bindOp = rewriter.create( - loc, targetType, alloc, vRow ? vRow : Value(), vCol ? 
vCol : Value(), - configAttr); - markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); - - rewriter.replaceOp(op, bindOp.getResult()); - } - - // ------------------------------------------------------------------ - // Stage 0.75: lower pto.declare_tile -> pto.declare_tile_memref + - // pto.bind_tile - // ------------------------------------------------------------------ - DefaultInlineVector declaredTiles; - func.walk([&](mlir::pto::DeclareTileOp op) { declaredTiles.push_back(op); }); - - for (auto op : declaredTiles) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - Location loc = op.getLoc(); - - auto tbTy = dyn_cast(op.getTile().getType()); - if (!tbTy) { - op.emitError("declare_tile result must be tile_buf type"); - signalPassFailure(); - return; - } - - auto targetType = dyn_cast(convertPTOTypeToMemRef(tbTy)); - if (!targetType) { - op.emitError("failed to convert declare_tile result to memref type"); - signalPassFailure(); - return; - } - - auto configAttr = tbTy.getConfigAttr(); - if (!configAttr) - configAttr = pto::TileBufConfigAttr::getDefault(ctx); - - Value vRow; - Value vCol; - materializeStaticValidDims(rewriter, loc, tbTy, vRow, vCol); - - auto declaredMemRef = - rewriter.create(loc, targetType); - auto bindOp = rewriter.create( - loc, targetType, declaredMemRef.getResult(), - vRow ? vRow : Value(), vCol ? vCol : Value(), configAttr); - markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); - - rewriter.replaceOp(op, bindOp.getResult()); - } - - // ------------------------------------------------------------------ - // Stage 0.8: normalize pto.tassign result type to match tile operand - // after tile_buf -> memref lowering (required for verifier consistency). 
- // ------------------------------------------------------------------ - DefaultInlineVector tassignOps; - func.walk([&](mlir::pto::TAssignOp op) { tassignOps.push_back(op); }); - for (auto op : tassignOps) { - Type targetTy = op.getTile().getType(); - if (op.getResult().getType() == targetTy) - continue; - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - auto normalized = - rewriter.create(op.getLoc(), targetTy, op.getTile(), - op.getAddr()); - rewriter.replaceOp(op, normalized.getResult()); - } - - // ------------------------------------------------------------------ - // Stage 1: Lower pto.make_tensor_view -> memref.reinterpret_cast - // ------------------------------------------------------------------ - DefaultInlineVector makeViews; - func.walk([&](mlir::pto::MakeTensorViewOp op) { makeViews.push_back(op); }); - - for (auto op : makeViews) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - Location loc = op.getLoc(); - - Value baseBuf = op.getOperand(0); - OpFoldResult off0 = rewriter.getIndexAttr(0); - - // Fold pto.addptr chains into the view base to avoid nested reinterpret_cast. - bool foldedAddPtr = false; - { - Value cur = baseBuf; - Value totalOffset; - while (auto add = cur.getDefiningOp()) { - foldedAddPtr = true; - Value off = ensureIndex(rewriter, loc, add.getOperand(1), add); - if (totalOffset) - totalOffset = rewriter.create(loc, totalOffset, off); - else - totalOffset = off; - cur = add.getOperand(0); - } - if (cur != baseBuf) { - baseBuf = cur; - off0 = totalOffset ? 
OpFoldResult(totalOffset) : off0; - } - } - - auto baseMr = dyn_cast(baseBuf.getType()); - if (!baseMr) { - op.emitError("make_tensor_view base must be memref"); signalPassFailure(); return; - } - - // [修复] 获取动态 Rank (根据 shape 输入的数量) - size_t rank = op.getShape().size(); - - // Construct target type with dynamic offset/strides - Type elemTy = baseMr.getElementType(); - int64_t dyn = ShapedType::kDynamic; - - // [修复] 构建 N 维 Strided Layout - // strides 数组长度必须等于 rank - SmallVector dynStrides(rank, dyn); - auto layout = StridedLayoutAttr::get(ctx, /*offset=*/dyn, /*strides=*/dynStrides); - - // [修复] 构建 N 维 Shape - SmallVector dynShape(rank, dyn); - auto mrTy = MemRefType::get(dynShape, elemTy, layout, baseMr.getMemorySpace()); - - SmallInlineVector sizes; - for (Value v : op.getShape()) sizes.push_back(ensureIndex(rewriter, loc, v, op)); - - SmallInlineVector strides; - for (Value v : op.getStrides()) strides.push_back(ensureIndex(rewriter, loc, v, op)); - - auto rc = rewriter.create( - loc, mrTy, baseBuf, off0, sizes, strides); - if (foldedAddPtr) { - rc->setAttr("pto.addptr_trace", rewriter.getUnitAttr()); - } - if (auto layoutAttr = op.getLayoutAttr()) { - rc->setAttr("layout", layoutAttr); - } - - rewriter.replaceOp(op, rc.getResult()); - } - - // ------------------------------------------------------------------ - // Stage 1.25: Lower pto.get_tensor_view_dim -> memref.dim - // ------------------------------------------------------------------ - DefaultInlineVector tvDims; - func.walk([&](mlir::pto::GetTensorViewDimOp op) { tvDims.push_back(op); }); - - for (auto op : tvDims) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - Location loc = op.getLoc(); - - Value view = op.getTensorView(); - auto mrTy = dyn_cast(view.getType()); - if (!mrTy) - continue; // leave it to later passes if it hasn't been lowered yet - - Value dimIdx = op.getDimIndex(); - Value dim = rewriter.create(loc, view, dimIdx); - rewriter.replaceOp(op, dim); - } - - // 
------------------------------------------------------------------ - // Stage 1.3: Lower pto.partition_view -> memref.subview - // ------------------------------------------------------------------ - if (failed(lowerPartitionViewOps(func, ctx))) { - signalPassFailure(); - return; - } - - // ------------------------------------------------------------------ - // Stage 1.35: Lower pto.subview -> memref.subview + pto.bind_tile - // ------------------------------------------------------------------ - if (failed(lowerSubViewOps(func, ctx))) { - signalPassFailure(); - return; - } - - // ------------------------------------------------------------------ - // Stage 1.4: Lower tile_buf view-like ops (treshape/bitcast) - // ------------------------------------------------------------------ - if (failed(lowerTileBufViewLikeOps(func, ctx))) { - signalPassFailure(); - return; - } - - // ------------------------------------------------------------------ - // Stage 1.5: Fold pto.addptr chains into load/store_scalar. 
- // ------------------------------------------------------------------ - DefaultInlineVector loadScalars; - func.walk([&](mlir::pto::LoadScalarOp op) { loadScalars.push_back(op); }); - - for (auto op : loadScalars) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - Location loc = op.getLoc(); - - Value base = op.getPtr(); - Value totalOffset = ensureIndex(rewriter, loc, op.getOffset(), op); - - bool foldedAddPtr = false; - while (auto add = base.getDefiningOp()) { - foldedAddPtr = true; - Value off = ensureIndex(rewriter, loc, add.getOperand(1), add); - if (totalOffset) - totalOffset = rewriter.create(loc, totalOffset, off); - else - totalOffset = off; - base = add.getOperand(0); - } - - if (foldedAddPtr) { - auto newOp = rewriter.create( - loc, op.getValue().getType(), base, totalOffset); - rewriter.replaceOp(op, newOp.getValue()); - } - } - - DefaultInlineVector storeScalars; - func.walk([&](mlir::pto::StoreScalarOp op) { storeScalars.push_back(op); }); - - for (auto op : storeScalars) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - Location loc = op.getLoc(); - - Value base = op.getPtr(); - Value totalOffset = ensureIndex(rewriter, loc, op.getOffset(), op); - - bool foldedAddPtr = false; - while (auto add = base.getDefiningOp()) { - foldedAddPtr = true; - Value off = ensureIndex(rewriter, loc, add.getOperand(1), add); - if (totalOffset) - totalOffset = rewriter.create(loc, totalOffset, off); - else - totalOffset = off; - base = add.getOperand(0); - } - - if (foldedAddPtr) { - rewriter.create( - loc, base, totalOffset, op.getValue()); - rewriter.eraseOp(op); - } - } - - // ------------------------------------------------------------------ - // Stage 1.75: Fold addptr used by initialize_l2g2l_pipe(gm_addr). - // This keeps IR well-typed after function arguments are rewritten from - // !pto.ptr to memref. 
- // ------------------------------------------------------------------ - bool foldedPipeInitAddPtr = true; - while (foldedPipeInitAddPtr) { - foldedPipeInitAddPtr = false; - DefaultInlineVector addPtrsForPipeInit; - func.walk([&](mlir::pto::AddPtrOp op) { - bool eligible = !op->use_empty(); - for (Operation *user : op->getUsers()) { - auto init = dyn_cast(user); - if (!init || init.getGmAddr() != op->getResult(0)) { - eligible = false; - break; - } - } - if (eligible) - addPtrsForPipeInit.push_back(op); - }); - - for (auto op : addPtrsForPipeInit) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - Location loc = op.getLoc(); - - Value base = op->getOperand(0); - Value totalOffset = ensureIndex(rewriter, loc, op->getOperand(1), op); - while (auto add = base.getDefiningOp()) { - Value off = ensureIndex(rewriter, loc, add->getOperand(1), add); - totalOffset = rewriter.create(loc, totalOffset, off); - base = add->getOperand(0); - } - - auto baseMrTy = dyn_cast(base.getType()); - if (!baseMrTy || baseMrTy.getRank() != 1) - continue; - - int64_t dyn = ShapedType::kDynamic; - auto layout = StridedLayoutAttr::get(ctx, dyn, {dyn}); - auto targetTy = MemRefType::get({dyn}, baseMrTy.getElementType(), layout, - baseMrTy.getMemorySpace()); - SmallVector sizes{rewriter.getIndexAttr(1)}; - SmallVector strides{rewriter.getIndexAttr(1)}; - auto rc = rewriter.create( - loc, targetTy, base, OpFoldResult(totalOffset), sizes, strides); - rc->setAttr("pto.addptr_trace", rewriter.getUnitAttr()); - rewriter.replaceOp(op, rc.getResult()); - foldedPipeInitAddPtr = true; - } - } - - // Clean up: addptr should be folded into make_tensor_view. 
- DefaultInlineVector addPtrs; - func.walk([&](mlir::pto::AddPtrOp op) { addPtrs.push_back(op.getOperation()); }); - bool changed = true; - while (changed) { - changed = false; - for (auto &op : addPtrs) { - if (!op) - continue; - if (op->use_empty()) { - op->erase(); - op = nullptr; - changed = true; - } - } - } - for (auto *op : addPtrs) { - if (!op) - continue; - op->emitError("addptr must feed make_tensor_view, initialize_l2g2l_pipe(gm_addr) or load/store_scalar for lowering"); - signalPassFailure(); - return; - } - - // ------------------------------------------------------------------ - // Stage 3: Rewrite Compute Ops - // [关键] 全面使用 op->getOperand(i) 避免 Typed Accessor Crash - // ------------------------------------------------------------------ - - // --- TLoadOp [Src, Dst] --- - DefaultInlineVector loads; - func.walk([&](mlir::pto::TLoadOp op) { loads.push_back(op); }); - for (auto op : loads) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op->getOperand(0); - Value dst = op->getOperand(1); - - auto newOp = - rewriter.create(op.getLoc(), TypeRange{}, src, dst); - newOp->setAttrs(op->getAttrs()); - rewriter.replaceOp(op, newOp->getResults()); - } - - // --- TStoreOp [Src, Dst] --- - DefaultInlineVector storeops; - func.walk([&](mlir::pto::TStoreOp op) { storeops.push_back(op); }); - for (auto op : storeops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op->getOperand(0); - Value dst = op->getOperand(1); - Value preQuant = op.getPreQuantScalar(); - - pto::TStoreOp newOp; - if (preQuant) { - newOp = rewriter.create(op.getLoc(), TypeRange{}, - src, dst, preQuant); - } else { - newOp = rewriter.create(op.getLoc(), TypeRange{}, - src, dst, Value{}); - } - newOp->setAttrs(op->getAttrs()); - rewriter.replaceOp(op, newOp->getResults()); - } - - // --- TTransOp [Src, Tmp, Dst] --- - DefaultInlineVector trans; - func.walk([&](mlir::pto::TTransOp op) { trans.push_back(op); }); - for (auto op : trans) { - 
IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex)); - } - - // --- TExpOp [Src, Dst] --- - DefaultInlineVector exp; - func.walk([&](mlir::pto::TExpOp op) { exp.push_back(op); }); - for (auto op : exp) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, op->getOperand(0), op->getOperand(1)); - } - - // --- TMulOp [Src, Scalar, Dst] --- - DefaultInlineVector mul; - func.walk([&](mlir::pto::TMulOp op) { mul.push_back(op); }); - for (auto op : mul) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, op->getOperand(0), op.getOperand(1), - op->getOperand(kThirdOperandIndex)); - } - - // --- TMulSOp [Src, Scalar, Dst] --- - DefaultInlineVector muls; - func.walk([&](mlir::pto::TMulSOp op) { muls.push_back(op); }); - for (auto op : muls) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, op->getOperand(0), op.getScalar(), - op->getOperand(kThirdOperandIndex)); - } - - // --- TAddOp [Src0, Src1, Dst] --- - DefaultInlineVector addops; - func.walk([&](mlir::pto::TAddOp op) { addops.push_back(op); }); - for (auto op : addops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex)); - } - - // --- TMatmulOp [Lhs, Rhs, Dst] (no optional bias in ODS) --- - DefaultInlineVector matmuls; - func.walk([&](mlir::pto::TMatmulOp op) { matmuls.push_back(op); }); - for (auto op : matmuls) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - Value lhs = op->getOperand(0); - Value rhs = op->getOperand(1); - Value dst = op->getOperand(kThirdOperandIndex); - - rewriter.replaceOpWithNewOp( - op, TypeRange{}, lhs, rhs, dst, op.getAccPhaseAttr()); - } - - // 
--- TMatmulAccOp [Acc, Lhs, Rhs, Dst] --- - DefaultInlineVector matmulAccs; - func.walk([&](mlir::pto::TMatmulAccOp op) { matmulAccs.push_back(op); }); - for (auto op : matmulAccs) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex), op.getAccPhaseAttr()); - } - - // --- TMatmulBiasOp [Acc, Lhs, Rhs, Bias, Dst] --- - DefaultInlineVector matmulBiass; - func.walk([&](mlir::pto::TMatmulBiasOp op) { matmulBiass.push_back(op); }); - for (auto op : matmulBiass) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex)); - } - - // --- TMatmulMxOp--- - DefaultInlineVector matmulMxs; - func.walk([&](mlir::pto::TMatmulMxOp op) { matmulMxs.push_back(op); }); - for (auto op : matmulMxs) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex), - op->getOperand(kFifthOperandIndex)); - } - - // --- TMatmulMxAccOp --- - DefaultInlineVector matmulMxAccs; - func.walk([&](mlir::pto::TMatmulMxAccOp op) { matmulMxAccs.push_back(op); }); - for (auto op : matmulMxAccs) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex), - op->getOperand(kFifthOperandIndex), - op->getOperand(kSixthOperandIndex)); - } - - // --- TMatmulMxBiasOp --- - DefaultInlineVector matmulMxBiass; - func.walk([&](mlir::pto::TMatmulMxBiasOp op) { matmulMxBiass.push_back(op); }); - for (auto op : matmulMxBiass) { - IRRewriter 
rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex), - op->getOperand(kFifthOperandIndex), - op->getOperand(kSixthOperandIndex)); - } - - // --- TGemvOp [Lhs, Rhs, Dst] --- - DefaultInlineVector gemvs; - func.walk([&](mlir::pto::TGemvOp op) { gemvs.push_back(op); }); - for (auto op : gemvs) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value lhs = op->getOperand(0); - Value rhs = op->getOperand(1); - Value dst = op->getOperand(kThirdOperandIndex); - - rewriter.replaceOpWithNewOp( - op, TypeRange{}, lhs, rhs, dst); - } - - // --- TGemvAccOp [Acc, Lhs, Rhs, Dst] --- - DefaultInlineVector gemvAccs; - func.walk([&](mlir::pto::TGemvAccOp op) { gemvAccs.push_back(op); }); - for (auto op : gemvAccs) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex)); - } - - // --- TGemvBiasOp [Acc, Lhs, Rhs, Bias, Dst] --- - DefaultInlineVector gemvBiass; - func.walk([&](mlir::pto::TGemvBiasOp op) { gemvBiass.push_back(op); }); - for (auto op : gemvBiass) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex)); - } - - // --- TGemvMxOp [A, AScale, B, BScale, Dst] --- - DefaultInlineVector gemvMxs; - func.walk([&](mlir::pto::TGemvMxOp op) { gemvMxs.push_back(op); }); - for (auto op : gemvMxs) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex), - op->getOperand(kFifthOperandIndex)); - 
} - - // --- TGemvMxAccOp [CIn, A, AScale, B, BScale, Dst] --- - DefaultInlineVector gemvMxAccs; - func.walk([&](mlir::pto::TGemvMxAccOp op) { gemvMxAccs.push_back(op); }); - for (auto op : gemvMxAccs) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex), - op->getOperand(kFifthOperandIndex), - op->getOperand(kSixthOperandIndex)); - } - - // --- TGemvMxBiasOp [A, AScale, B, BScale, Bias, Dst] --- - DefaultInlineVector gemvMxBiass; - func.walk([&](mlir::pto::TGemvMxBiasOp op) { gemvMxBiass.push_back(op); }); - for (auto op : gemvMxBiass) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex), - op->getOperand(kFifthOperandIndex), - op->getOperand(kSixthOperandIndex)); - } - - // --- TMovOp [Src, Dst] --- - DefaultInlineVector movs; - func.walk([&](mlir::pto::TMovOp op) { movs.push_back(op); }); - for (auto op : movs) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, op.getSrc(), op.getDst(), op.getFp(), - op.getPreQuantScalar(), op.getAccToVecModeAttr(), - op.getReluPreModeAttr()); - } - - DefaultInlineVector abseops; - func.walk([&](mlir::pto::TAbsOp op) { abseops.push_back(op); }); - - for (auto op : abseops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst); - } - - DefaultInlineVector addcops; - func.walk([&](mlir::pto::TAddCOp 
op) { addcops.push_back(op); }); - - for (auto op : addcops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value src2 = op.getSrc2(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto src2Ty = dyn_cast(src2.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !src2Ty ||!dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - src1, - src2, - dst); - } - - DefaultInlineVector addsops; - func.walk([&](mlir::pto::TAddSOp op) { addsops.push_back(op); }); - - for (auto op : addsops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value scalar = op.getScalar(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - scalar, - dst); - } - - DefaultInlineVector addscops; - func.walk([&](mlir::pto::TAddSCOp op) { addscops.push_back(op); }); - - for (auto op : addscops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value scalar = op.getScalar(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - scalar, - src1, - dst); - } - - DefaultInlineVector andops; - func.walk([&](mlir::pto::TAndOp op) { andops.push_back(op); }); - - for (auto op : andops) { - IRRewriter rewriter(ctx); 
- rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - src1, - dst); - } - - DefaultInlineVector concats; - func.walk([&](mlir::pto::TConcatOp op) { concats.push_back(op); }); - - for (auto op : concats) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - src1, - dst); - } - - DefaultInlineVector concatIdxs; - func.walk([&](mlir::pto::TConcatidxOp op) { concatIdxs.push_back(op); }); - - IRRewriter rewriter(ctx); - for (auto op : concatIdxs) { - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value src0Idx = op.getSrc0Idx(); - Value src1Idx = op.getSrc1Idx(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto src0IdxTy = dyn_cast(src0Idx.getType()); - auto src1IdxTy = dyn_cast(src1Idx.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !src0IdxTy || !src1IdxTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - src1, - src0Idx, - src1Idx, - dst); - } - - DefaultInlineVector andsops; - func.walk([&](mlir::pto::TAndSOp op) { andsops.push_back(op); 
}); - - for (auto op : andsops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value scalar = op.getScalar(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - scalar, - dst); - } - - DefaultInlineVector ciops; - func.walk([&](mlir::pto::TCIOp op) { ciops.push_back(op); }); - - for (auto op : ciops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value s = op->getOperand(0); - Value dst = op.getDst(); - bool descending = op.getDescending(); - - auto sTy = dyn_cast(s.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!sTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - s, - dst, - descending); - } - - DefaultInlineVector cmpops; - func.walk([&](mlir::pto::TCmpOp op) { cmpops.push_back(op); }); - - for (auto op : cmpops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - auto newOp = rewriter.create( - op.getLoc(), - TypeRange{}, - src0, - src1, - dst); - - if (auto a = op.getCmpModeAttr()) - newOp->setAttr("cmpMode", a); - - rewriter.replaceOp(op, newOp->getResults()); // 0 results -> OK - } - - DefaultInlineVector cmpsops; - func.walk([&](mlir::pto::TCmpSOp op) { cmpsops.push_back(op); }); - - for (auto op : cmpsops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value 
scalar = op.getScalar(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - auto scalarTy = scalar.getType(); - bool scalarOk = - isa(scalarTy); // ScalarType in ODS: int/float - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - if (!scalarOk) { - op.emitError("expects scalar to be an integer or float type"); - signalPassFailure(); - return; - } - - auto cmpMode = op.getCmpModeAttr(); - auto newOp = rewriter.create( - op.getLoc(), - TypeRange{}, - src, - scalar, - cmpMode, - dst); - - rewriter.replaceOp(op, newOp->getResults()); // 0 results -> OK - } - - DefaultInlineVector colexpand; - func.walk([&](mlir::pto::TColExpandOp op) { colexpand.push_back(op); }); - - for (auto op : colexpand) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if ( !srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst); - } - - DefaultInlineVector colmaxops; - func.walk([&](mlir::pto::TColMaxOp op) { colmaxops.push_back(op); }); - - for (auto op : colmaxops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if ( !srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst); - } - - DefaultInlineVector colminops; - func.walk([&](mlir::pto::TColMinOp op) { colminops.push_back(op); }); - - for (auto op : colminops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - - auto srcTy 
= dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if ( !srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst); - } - - DefaultInlineVector colexpandmulops; - func.walk([&](mlir::pto::TColExpandMulOp op) { - colexpandmulops.push_back(op); - }); - - for (auto op : colexpandmulops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - src1, - dst); - } - - DefaultInlineVector colexpandmaxops; - func.walk([&](mlir::pto::TColExpandMaxOp op) { - colexpandmaxops.push_back(op); - }); - - for (auto op : colexpandmaxops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - src1, - dst); - } - - DefaultInlineVector colexpandminops; - func.walk([&](mlir::pto::TColExpandMinOp op) { - colexpandminops.push_back(op); - }); - - for (auto op : colexpandminops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty 
|| !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - src1, - dst); - } - - DefaultInlineVector colsumops; - func.walk([&](mlir::pto::TColSumOp op) { colsumops.push_back(op); }); - - for (auto op : colsumops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - Value tmp = op.getTmp(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !dstTy) { - op.emitError("src/dst are not memref yet"); - signalPassFailure(); - return; - } - - // If tmp exists, it must have isBinary attribute - if (tmp) { - auto tmpTy = dyn_cast(tmp.getType()); - if (!tmpTy) { - op.emitError("tmp is not memref yet"); - signalPassFailure(); - return; - } - - // Get isBinary attribute (should exist if tmp exists) - BoolAttr isBinaryAttr = op.getIsBinaryAttr(); - if (!isBinaryAttr) { - isBinaryAttr = BoolAttr::get(ctx, false); - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - tmp, - dst, - isBinaryAttr); - } else { - // Format 1: no tmp, no isBinary - // Use generic builder to avoid adding default isBinary attribute - SmallVector operands = {src, dst}; - SmallVector attrs; - // Copy all attributes except isBinary - for (auto attr : op->getAttrs()) { - if (attr.getName() != "isBinary") { - attrs.push_back(attr); - } - } - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - operands, - attrs); - } - } - - DefaultInlineVector cvtops; - func.walk([&](mlir::pto::TCvtOp op) { cvtops.push_back(op); }); - - for (auto op : cvtops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - auto rmodeAttr = 
op.getRmodeAttr(); // PTO_RoundModeAttr - auto satModeAttr = op.getSatModeAttr(); - - auto newOp = rewriter.create( - op.getLoc(), - TypeRange{}, - src, - dst, - rmodeAttr, - satModeAttr); - - rewriter.replaceOp(op, newOp->getResults()); - } - - DefaultInlineVector divops; - func.walk([&](mlir::pto::TDivOp op) { divops.push_back(op); }); - - for (auto op : divops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - src1, - dst); - } - - DefaultInlineVector divsops; - func.walk([&](mlir::pto::TDivSOp op) { divsops.push_back(op); }); - - for (auto op : divsops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value scale = op.getScalar(); - Value dst = op.getDst(); - - // Check types - they might still be TileBufType or already converted to MemRefType - auto srcTy = dyn_cast(src.getType()); - auto srcTileTy = dyn_cast(src.getType()); - auto scaleTileTy = dyn_cast(scale.getType()); - auto dstTy = dyn_cast(dst.getType()); - auto dstTileTy = dyn_cast(dst.getType()); - - // Determine which operand is tile-like and which is scalar-like. - // Keep the original operand order (set by parser textual form). 
- // Check if src is memref/tensor/tile (not scalar) - bool srcIsMemref = (srcTy != nullptr || srcTileTy != nullptr || - isa(src.getType()) || - isa(src.getType())); - // Check if scale is memref/tensor/tile (not scalar) - bool scaleIsMemref = (isa(scale.getType()) || - scaleTileTy != nullptr || - isa(scale.getType()) || - isa(scale.getType())); - - // Type validation - ensure we have the right types - if (!srcIsMemref && !scaleIsMemref) { - op.emitError("at least one operand (src or scale) must be tile_buf or memref"); - signalPassFailure(); - return; - } - if (srcIsMemref && scaleIsMemref) { - op.emitError("exactly one operand (src or scale) must be tile_buf or memref, the other must be scalar"); - signalPassFailure(); - return; - } - - if (!dstTy && !dstTileTy) { - op.emitError("dst operand must be tile_buf or memref"); - signalPassFailure(); - return; - } - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - scale, - dst); - } - - DefaultInlineVector expandsops; - func.walk([&](mlir::pto::TExpandsOp op) { expandsops.push_back(op); }); - - for (auto op : expandsops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value scalar = op.getScalar(); - Value dst = op.getDst(); - - auto dstTy = dyn_cast(dst.getType()); - if (!dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - scalar, - dst); - } - - DefaultInlineVector extractops; - func.walk([&](mlir::pto::TExtractOp op) { extractops.push_back(op); }); - - for (auto op : extractops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value indexRow = op.getIndexRow(); - Value indexCol = op.getIndexCol(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto indexRowTy = dyn_cast(indexRow.getType()); - auto indexColTy = dyn_cast(indexCol.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !indexRowTy || !indexColTy || 
!dstTy) { - op.emitError("ins/outs are not correct yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - indexRow, - indexCol, - dst); - } - - DefaultInlineVector fillpadops; - func.walk([&](mlir::pto::TFillPadOp op) { fillpadops.push_back(op); }); - - for (auto op : fillpadops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst); - } - - DefaultInlineVector fillpadInplaceOps; - func.walk( - [&](mlir::pto::TFillPadInplaceOp op) { fillpadInplaceOps.push_back(op); }); - - for (auto op : fillpadInplaceOps) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst); - } - - // --- TSetValOp [Dst, Offset, Val] --- - // Lower tile-world scalar write to memref-world SETVAL DPS op. 
- DefaultInlineVector tsetvalops; - func.walk([&](mlir::pto::TSetValOp op) { tsetvalops.push_back(op); }); - - for (auto op : tsetvalops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value dst = op.getDst(); - Value offset = op.getOffset(); - Value val = op.getVal(); - - auto dstTy = dyn_cast(dst.getType()); - if (!dstTy) { - op.emitError("dst is not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - dst, - offset, - val); - } - - // --- TGetValOp [Src, Offset] -> Scalar --- - // Lower tile-world scalar read to memref-world GETVAL DPS op. - DefaultInlineVector tgetvalops; - func.walk([&](mlir::pto::TGetValOp op) { tgetvalops.push_back(op); }); - - for (auto op : tgetvalops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value offset = op.getOffset(); - Type dstType = op.getDst().getType(); - - auto srcTy = dyn_cast(src.getType()); - if (!srcTy) { - op.emitError("src is not memref yet"); - signalPassFailure(); - return; - } - - auto newOp = rewriter.create( - op.getLoc(), - dstType, - src, - offset); - rewriter.replaceOp(op, newOp.getDst()); - } - - DefaultInlineVector gatherops; - func.walk([&](mlir::pto::TGatherOp op) { gatherops.push_back(op); }); - - for (auto op : gatherops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - Value cdst = op.getCdst(); - Value indices = op.getIndices(); - Value tmp = op.getTmp(); - Value kValue = op.getKValue(); - auto maskPattern = op.getMaskPatternAttr(); - auto cmpMode = op.getCmpModeAttr(); - auto offset = op.getOffsetAttr(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - if (maskPattern) { - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst, - /*cdst=*/Value(), - 
/*indices=*/Value(), - /*tmp=*/Value(), - /*kValue=*/Value(), - /*maskPattern=*/maskPattern, - /*cmpMode=*/pto::CmpModeAttr(), - /*offset=*/IntegerAttr()); - continue; - } - - if (cdst || kValue) { - auto cdstTy = dyn_cast(cdst.getType()); - auto tmpTy = dyn_cast(tmp.getType()); - if (!cdstTy || !tmpTy) { - op.emitError("compare-form tgather expects cdst/tmp to be memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst, - cdst, - /*indices=*/Value(), - tmp, - kValue, - /*maskPattern=*/pto::MaskPatternAttr(), - cmpMode, - offset); - continue; - } - - if (indices || tmp) { - auto indicesTy = dyn_cast(indices.getType()); - auto tmpTy = dyn_cast(tmp.getType()); - if (!indicesTy || !tmpTy) { - op.emitError("index-form tgather expects indices/tmp to be memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst, - /*cdst=*/Value(), - indices, - tmp, - /*kValue=*/Value(), - /*maskPattern=*/pto::MaskPatternAttr(), - /*cmpMode=*/pto::CmpModeAttr(), - /*offset=*/IntegerAttr()); - continue; - } - - op.emitError("expects tgather to be in mask, index+tmp, or compare+tmp form"); - signalPassFailure(); - return; - } - - DefaultInlineVector gatherbops; - func.walk([&](mlir::pto::TGatherBOp op) { gatherbops.push_back(op); }); - - for (auto op : gatherbops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value offsets = op.getOffsets(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto offsetsTy = dyn_cast(offsets.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !offsetsTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - offsets, - dst); - } - - DefaultInlineVector logops; - func.walk([&](mlir::pto::TLogOp op) { logops.push_back(op); }); - - for (auto op : logops) { 
- IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst); - } - - DefaultInlineVector lreluops; - func.walk([&](mlir::pto::TLReluOp op) { lreluops.push_back(op); }); - - for (auto op : lreluops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value slope = op.getSlope(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto slopeTy = dyn_cast(slope.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !slopeTy || !dstTy) { - op.emitError("ins/outs are not correct type yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - slope, - dst); - } - - DefaultInlineVector maxops; - func.walk([&](mlir::pto::TMaxOp op) { maxops.push_back(op); }); - - for (auto op : maxops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - src1, - dst); - } - - DefaultInlineVector maxsops; - func.walk([&](mlir::pto::TMaxSOp op) { maxsops.push_back(op); }); - - for (auto op : maxsops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value scalar = op.getScalar(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - bool scalarIsScalar = 
isa(scalar.getType()); - if (!srcTy || !scalarIsScalar || !dstTy) { - op.emitError("expects src/dst to be memref and scalar to be integer/float"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - scalar, - dst); - } - - DefaultInlineVector minops; - func.walk([&](mlir::pto::TMinOp op) { minops.push_back(op); }); - - for (auto op : minops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - src1, - dst); - } - - DefaultInlineVector minsops; - func.walk([&](mlir::pto::TMinSOp op) { minsops.push_back(op); }); - - for (auto op : minsops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value scalar = op.getScalar(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - bool scalarIsScalar = isa(scalar.getType()); - if (!srcTy || !scalarIsScalar || !dstTy) { - op.emitError("expects src/dst to be memref and scalar to be integer/float"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - scalar, - dst); - } - - DefaultInlineVector movfpops; - func.walk([&](mlir::pto::TMovFPOp op) { movfpops.push_back(op); }); - - for (auto op : movfpops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value fp = op.getFp(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto fpTy = dyn_cast(fp.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !fpTy || !dstTy) { - op.emitError("ins/outs are not 
memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - fp, - dst); - } - - DefaultInlineVector quantops; - func.walk([&](mlir::pto::TQuantOp op) { quantops.push_back(op); }); - - for (auto op : quantops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value fp = op.getFp(); - Value offset = op.getOffset(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto fpTy = dyn_cast(fp.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !fpTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - if (offset && !dyn_cast(offset.getType())) { - op.emitError("offset is not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - fp, - offset, - dst, - op.getQuantTypeAttr()); - } - - DefaultInlineVector mrgsortops; - func.walk([&](mlir::pto::TMrgSortOp op) { mrgsortops.push_back(op); }); - - for (auto op : mrgsortops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - if (op.isFormat1()) { - Value src = op.getSrc(); - Value dst = op.getDst(); - Value blockLenVal = op.getBlockLen(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - ValueRange{src}, - blockLenVal, - ValueRange{dst}, - Value() /*tmp*/, - Value() /*excuted*/, - op.getExhaustedAttr()); - } else if (op.isFormat2()) { - bool allMemRef = true; - for (Value v : op.getSrcs()) - if (!dyn_cast(v.getType())) { allMemRef = false; break; } - if (!allMemRef) { - op.emitError("format2 ins/outs are not memref yet"); - signalPassFailure(); - return; - } - if (op.getDsts().size() != 1u || !op.getTmp()) { - op.emitError("format2 expects outs(dst) and ins(tmp)"); - 
signalPassFailure(); - return; - } - - Value dst = op.getDst(); - Value tmp = op.getTmp(); - Value excuted = op.getExcuted(); - if (!dyn_cast(dst.getType()) || !dyn_cast(tmp.getType())) { - op.emitError("format2 dst/tmp must be memref"); - signalPassFailure(); - return; - } - if (!dyn_cast(excuted.getType())) { - op.emitError("format2 outs(excuted) must be vector"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - op.getSrcs(), - Value() /*blockLen*/, - ValueRange{dst}, - tmp, - excuted, - op.getExhaustedAttr()); - } else { - op.emitError("tmrgsort must be format1 or format2"); - signalPassFailure(); - return; - } - } - - DefaultInlineVector negops; - func.walk([&](mlir::pto::TNegOp op) { negops.push_back(op); }); - - for (auto op : negops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst); - } - - DefaultInlineVector notops; - func.walk([&](mlir::pto::TNotOp op) { notops.push_back(op); }); - - for (auto op : notops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst); - } - - DefaultInlineVector orops; - func.walk([&](mlir::pto::TOrOp op) { orops.push_back(op); }); - - for (auto op : orops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = 
dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - src0, - src1, - dst); - } - - DefaultInlineVector orsops; - func.walk([&](mlir::pto::TOrSOp op) { orsops.push_back(op); }); - - for (auto op : orsops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value scalar = op.getScalar(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto scalarTy = dyn_cast(scalar.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !scalarTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - scalar, - dst); - } - - DefaultInlineVector partaddops; - func.walk([&](mlir::pto::TPartAddOp op) { partaddops.push_back(op); }); - - for (auto op : partaddops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - src0, - src1, - dst); - } - - DefaultInlineVector partmulops; - func.walk([&](mlir::pto::TPartMulOp op) { partmulops.push_back(op); }); - - for (auto op : partmulops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs 
are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - src0, - src1, - dst); - } - - DefaultInlineVector mgatherops; - func.walk([&](mlir::pto::MGatherOp op) { mgatherops.push_back(op); }); - - for (auto op : mgatherops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value dst = op.getDst(); - Value idx = op.getIdx(); - Value mem = op.getMem(); - - auto dstTy = dyn_cast(dst.getType()); - auto idxTy = dyn_cast(idx.getType()); - auto memTy = dyn_cast(mem.getType()); - if (!dstTy || !idxTy || !memTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - mem, - idx, - dst, - op.getGatherOobAttr()); - } - - DefaultInlineVector mascatterops; - func.walk([&](mlir::pto::MScatterOp op) { mascatterops.push_back(op); }); - - for (auto op : mascatterops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value idx = op.getIdx(); - Value mem = op.getMem(); - - auto srcTy = dyn_cast(src.getType()); - auto idxTy = dyn_cast(idx.getType()); - auto memTy = dyn_cast(mem.getType()); - if (!srcTy || !idxTy || !memTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - idx, - mem, - op.getScatterAtomicOpAttr(), - op.getScatterOobAttr()); - } - DefaultInlineVector printops; - func.walk([&](mlir::pto::TPrintOp op) { printops.push_back(op); }); - - for (auto op : printops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - - auto srcTy = dyn_cast(src.getType()); - if (!srcTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src); - } - - // ------------------------------------------------------------------ - // Stage 4: Reconcile control-flow result types - // 
------------------------------------------------------------------ - if (failed(reconcileSCFIfResultTypes(func))) { - signalPassFailure(); - return; - } - if (failed(reconcileSCFForResultTypes(func))) { - signalPassFailure(); - return; - } - - // Mark memref-form set_validshape only after control-flow result-type - // reconciliation. Values such as scf.if results can stay tile_buf until - // this late stage. - if (failed(markLoweredSetValidShapeOps(func, ctx))) { - signalPassFailure(); - return; - } - } - - // Debug Output - LLVM_DEBUG(llvm::dbgs() << mod.getOperation()); - } -}; - -} // namespace - -std::unique_ptr createPTOViewToMemrefPass() { - return std::make_unique(); -} - -} // namespace pto -} // namespace mlir +#include "PTOViewToMemref.def" diff --git a/lib/PTO/Transforms/PTOViewToMemref.def b/lib/PTO/Transforms/PTOViewToMemref.def new file mode 100644 index 000000000..c21669b81 --- /dev/null +++ b/lib/PTO/Transforms/PTOViewToMemref.def @@ -0,0 +1,3615 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- PTOViewToMemref.cpp ------------------------------------------------===// +//===----------------------------------------------------------------------===// +// +// Lower PTO tile/view operations to memref-based IR while preserving tile +// metadata through binding ops and SSA backtracking. 
+ +#include "PTO/IR/PTO.h" +#include "PTO/IR/PTOTypeUtils.h" +#include "PTO/Transforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" + +#include "mlir/IR/AsmState.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_PTOVIEWTOMEMREF +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/raw_ostream.h" +#include "Utils.h" // 假设包含一些通用的工具函数 + +#include +#include +#include + +#define DEBUG_TYPE "pto-view-to-memref" + +using namespace mlir; + +namespace mlir { +namespace pto { + +static constexpr llvm::StringLiteral kLoweredSetValidShapeAttrName = + "__pto.lowered_set_validshape"; +static constexpr llvm::StringLiteral kForceDynamicValidShapeAttrName = + "__pto.force_dynamic_valid_shape"; + +namespace { + +static void markForceDynamicValidShape(Operation *op, bool force, + MLIRContext *ctx); + +static Type convertPTOTypeToMemRef(Type t); + +constexpr size_t kTileRank2D = 2; +constexpr size_t kRowDimensionIndex = 0; +constexpr size_t kColumnDimensionIndex = 1; +constexpr unsigned kShapeVectorInlineCapacity = 4; +constexpr unsigned kOperationVectorInlineCapacity = 8; + +constexpr int64_t kElementBytes1 = 1; +constexpr int64_t kElementBytes2 = 2; +constexpr int64_t kElementBytes4 = 4; +constexpr int64_t kElementBytes8 = 8; +constexpr int64_t kElementBytes16 = 16; +constexpr int64_t kElementBytes32 = 32; + +constexpr int64_t kInnerExtent1 = 1; +constexpr int64_t kInnerExtent2 = 2; +constexpr int64_t kInnerExtent4 = 4; +constexpr int64_t kInnerExtent8 = 8; +constexpr int64_t kInnerExtent16 = 16; +constexpr int64_t kInnerExtent32 = 32; + +constexpr 
int32_t kFractalSize32 = 32; +constexpr int32_t kFractalSize512 = 512; +constexpr int32_t kFractalSize1024 = 1024; + +constexpr int32_t kBLayoutColMajor = + static_cast(BLayout::ColMajor); +constexpr int32_t kSLayoutNoneBox = + static_cast(SLayout::NoneBox); +constexpr int32_t kSLayoutRowMajor = + static_cast(SLayout::RowMajor); +constexpr int32_t kSLayoutColMajor = + static_cast(SLayout::ColMajor); +constexpr int32_t kCompactModeRowPlusOne = + static_cast(CompactMode::RowPlusOne); + +constexpr unsigned kThirdOperandIndex = 2; +constexpr unsigned kFourthOperandIndex = 3; +constexpr unsigned kFifthOperandIndex = 4; +constexpr unsigned kSixthOperandIndex = 5; + +template +using SmallInlineVector = SmallVector; + +template +using DefaultInlineVector = SmallVector; + +// ============================================================================= +// Helper: Metadata Backtracking (核心机制) +// ============================================================================= +// 从一个 MemRef Value 向上回溯,找到它绑定的 TileBufConfig。 +// 这解决了 "Type Erasure" 问题:memref 类型本身不包含 config,但 SSA 定义链包含。 +static mlir::pto::TileBufConfigAttr lookupConfig(Value v) { + // 1. 最直接的情况:它就是 bind_tile 的结果 + if (auto bind = v.getDefiningOp()) { + return bind.getConfig(); + } + // PointerCastOp can also carry tile metadata (used when alloc_tile specifies + // an explicit address). + if (auto pc = v.getDefiningOp()) { + if (auto cfg = pc.getConfig()) + return *cfg; + return {}; + } + + // 2. 
穿透 View 操作 (SubView, Cast 等) 向上查找 + if (auto subview = v.getDefiningOp()) { + return lookupConfig(subview.getSource()); + } + if (auto cast = v.getDefiningOp()) { + return lookupConfig(cast.getSource()); + } + if (auto cast = v.getDefiningOp()) { + return lookupConfig(cast.getSource()); + } + + // 如果追溯到 BlockArgument (函数参数) 或其他无法穿透的 Op,则返回空 + return {}; +} + +// ============================================================================= +// Helper: Valid dims backtracking (v_row / v_col) +// ============================================================================= +static void lookupValidDims(Value v, Value &vRow, Value &vCol) { + if (auto bind = v.getDefiningOp()) { + vRow = bind.getValidRow(); + vCol = bind.getValidCol(); + return; + } + if (auto pc = v.getDefiningOp()) { + vRow = pc.getValidRow(); + vCol = pc.getValidCol(); + return; + } + if (auto subview = v.getDefiningOp()) { + lookupValidDims(subview.getSource(), vRow, vCol); + return; + } + if (auto cast = v.getDefiningOp()) { + lookupValidDims(cast.getSource(), vRow, vCol); + return; + } + if (auto cast = v.getDefiningOp()) { + lookupValidDims(cast.getSource(), vRow, vCol); + return; + } + vRow = Value(); + vCol = Value(); +} + +// ============================================================================= +// Helper Functions for Layout Normalization +// ============================================================================= + +struct TileLayoutInfo { + int64_t rowStride = 1; + int64_t colStride = 1; + int64_t innerRows = 1; + int64_t innerCols = 1; + bool boxed = false; // slayout != NoneBox +}; + +struct TileLayoutConfig { + int32_t bLayout = 0; + int32_t sLayout = 0; + int32_t fractalSize = kFractalSize512; + int32_t compactMode = 0; +}; + +static int64_t getElemBytes(Type elemTy) { + unsigned bytes = getPTOStorageElemByteSize(elemTy); + return bytes == 0 ? 
-1 : static_cast(bytes); +} + +template +static bool readEnumAttrOrIntegerI32(Attribute attr, int32_t &out) { + if (auto enumAttr = dyn_cast(attr)) { + out = static_cast(enumAttr.getValue()); + return true; + } + if (auto intAttr = dyn_cast(attr)) { + out = static_cast(intAttr.getInt()); + return true; + } + return false; +} + +static bool readBLayoutI32(Attribute attr, int32_t &out) { + return readEnumAttrOrIntegerI32(attr, out); +} + +static bool readSLayoutI32(Attribute attr, int32_t &out) { + return readEnumAttrOrIntegerI32(attr, out); +} + +static bool readCompactModeI32(Attribute attr, int32_t &out) { + return readEnumAttrOrIntegerI32(attr, out); +} + +static Value peelIndexLikeCast(Value value) { + while (true) { + if (auto castOp = value.getDefiningOp()) { + value = castOp.getIn(); + continue; + } + if (auto extOp = value.getDefiningOp()) { + value = extOp.getIn(); + continue; + } + if (auto extOp = value.getDefiningOp()) { + value = extOp.getIn(); + continue; + } + if (auto truncOp = value.getDefiningOp()) { + value = truncOp.getIn(); + continue; + } + return value; + } +} + +static bool getConstIndexValue(Value value, int64_t &out) { + value = peelIndexLikeCast(value); + if (auto constIndex = value.getDefiningOp()) { + out = constIndex.value(); + return true; + } + if (auto constInt = value.getDefiningOp()) { + out = constInt.value(); + return true; + } + auto constOp = value.getDefiningOp(); + auto intAttr = + constOp ? 
dyn_cast(constOp.getValue()) : IntegerAttr(); + if (!intAttr) + return false; + out = intAttr.getInt(); + return true; +} + +static TileLayoutConfig getTileLayoutConfig(mlir::pto::TileBufConfigAttr cfg) { + TileLayoutConfig config; + (void)readBLayoutI32(cfg.getBLayout(), config.bLayout); + (void)readSLayoutI32(cfg.getSLayout(), config.sLayout); + if (auto attr = dyn_cast(cfg.getSFractalSize())) + config.fractalSize = static_cast(attr.getInt()); + (void)readCompactModeI32(cfg.getCompactMode(), config.compactMode); + return config; +} + +static bool getFractal512InnerExtent(int64_t elemBytes, int64_t &extent) { + switch (elemBytes) { + case kElementBytes1: + extent = kInnerExtent32; + return true; + case kElementBytes2: + extent = kInnerExtent16; + return true; + case kElementBytes4: + extent = kInnerExtent8; + return true; + case kElementBytes8: + extent = kInnerExtent4; + return true; + case kElementBytes16: + extent = kInnerExtent2; + return true; + case kElementBytes32: + extent = kInnerExtent1; + return true; + default: + return false; + } +} + +static bool computeBoxInnerShape(const TileLayoutConfig &config, Type elemTy, + TileLayoutInfo &info) { + info.boxed = config.sLayout != kSLayoutNoneBox; + if (!info.boxed) { + info.innerRows = kInnerExtent1; + info.innerCols = kInnerExtent1; + return true; + } + + int64_t elemBytes = getElemBytes(elemTy); + if (elemBytes <= 0) + return false; + + switch (config.fractalSize) { + case kFractalSize1024: + info.innerRows = kInnerExtent16; + info.innerCols = kInnerExtent16; + return true; + case kFractalSize32: + info.innerRows = kInnerExtent16; + info.innerCols = kInnerExtent2; + return true; + case kFractalSize512: + if (config.sLayout == kSLayoutRowMajor) { + info.innerRows = kInnerExtent16; + return getFractal512InnerExtent(elemBytes, info.innerCols); + } + if (config.sLayout == kSLayoutColMajor) { + if (!getFractal512InnerExtent(elemBytes, info.innerRows)) + return false; + info.innerCols = kInnerExtent16; + return 
true; + } + return false; + default: + return false; + } +} + +static bool computeTilePointerStrides(const TileLayoutConfig &config, + ArrayRef shape, + TileLayoutInfo &info) { + int64_t rows = shape[0]; + int64_t cols = shape[1]; + auto applyCompactToMajorStride = [&](int64_t majorStride) -> int64_t { + if (config.compactMode == kCompactModeRowPlusOne) + return majorStride + kInnerExtent1; + return majorStride; + }; + if (!info.boxed) { + if (config.bLayout == kBLayoutColMajor) { + info.rowStride = kInnerExtent1; + info.colStride = applyCompactToMajorStride(rows); + return true; + } + info.rowStride = applyCompactToMajorStride(cols); + info.colStride = kInnerExtent1; + return true; + } + + if (config.bLayout == kBLayoutColMajor) { + if (config.sLayout != kSLayoutRowMajor) + return false; + info.rowStride = info.innerCols; + info.colStride = applyCompactToMajorStride(rows); + return true; + } + + info.rowStride = applyCompactToMajorStride(cols); + info.colStride = info.innerRows; + return true; +} + +static bool computeTileLayoutInfo(mlir::pto::TileBufConfigAttr cfg, Type elemTy, + ArrayRef shape, + TileLayoutInfo &info) { + if (shape.size() != kTileRank2D || + llvm::is_contained(shape, ShapedType::kDynamic)) + return false; + + TileLayoutConfig config = getTileLayoutConfig(cfg); + return computeBoxInnerShape(config, elemTy, info) && + computeTilePointerStrides(config, shape, info); +} + +static void collectAffineAddTerms(AffineExpr root, + SmallVectorImpl &terms) { + SmallInlineVector pending{root}; + while (!pending.empty()) { + AffineExpr current = pending.pop_back_val(); + auto addExpr = llvm::dyn_cast(current); + if (!addExpr || addExpr.getKind() != AffineExprKind::Add) { + terms.push_back(current); + continue; + } + pending.push_back(addExpr.getRHS()); + pending.push_back(addExpr.getLHS()); + } +} + +static bool tryAssignAffineStride(AffineExpr expr, + MutableArrayRef strides) { + if (auto dim = llvm::dyn_cast(expr)) { + strides[dim.getPosition()] = 1; + 
return true; + } + + auto mulExpr = llvm::dyn_cast(expr); + if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul) + return false; + + auto assignStride = [&](AffineExpr dimExpr, + AffineExpr constantExpr) -> bool { + auto dim = llvm::dyn_cast(dimExpr); + auto constant = llvm::dyn_cast(constantExpr); + if (!dim || !constant) + return false; + strides[dim.getPosition()] = constant.getValue(); + return true; + }; + return assignStride(mulExpr.getLHS(), mulExpr.getRHS()) || + assignStride(mulExpr.getRHS(), mulExpr.getLHS()); +} + +[[maybe_unused]] static void decomposeStridedLayout(AffineMap map, + SmallVectorImpl &strides) { + strides.assign(map.getNumDims(), 0); + if (map.getNumResults() != 1) + return; + + SmallInlineVector terms; + collectAffineAddTerms(map.getResult(0), terms); + for (AffineExpr term : terms) + (void)tryAssignAffineStride(term, strides); +} + +static Value makeIndexConstant(IRRewriter &rewriter, Location loc, + int64_t value) { + return rewriter.create(loc, rewriter.getIndexType(), + rewriter.getIndexAttr(value)); +} + +static SmallVector computeCompactStrides(ArrayRef shape) { + SmallVector strides(shape.size(), 1); + int64_t stride = 1; + for (int i = static_cast(shape.size()) - 1; i >= 0; --i) { + strides[i] = stride; + if (shape[i] != ShapedType::kDynamic) + stride *= shape[i]; + } + return strides; +} + +static void materializeStaticValidDims(IRRewriter &rewriter, Location loc, + mlir::pto::TileBufType tbTy, Value &vRow, + Value &vCol) { + ArrayRef validShape = tbTy.getValidShape(); + if (tbTy.hasDynamicValid()) + return; + if (!validShape.empty() && validShape[kRowDimensionIndex] >= 0) + vRow = makeIndexConstant(rewriter, loc, validShape[kRowDimensionIndex]); + if (validShape.size() >= kTileRank2D && + validShape[kColumnDimensionIndex] >= 0) + vCol = makeIndexConstant(rewriter, loc, validShape[kColumnDimensionIndex]); +} + +static bool checkMultipleOf(Operation *op, int64_t value, int64_t divisor, + StringRef label) { + if (divisor <= 0) 
{ + op->emitError("boxed layout requires positive divisor for ") << label; + return false; + } + if (value % divisor == 0) + return true; + op->emitError("boxed layout requires ") + << label << " multiple of " << divisor << ", got " << value; + return false; +} + +// 确保 Value 是 Index 类型 +static Value ensureIndex(IRRewriter &rewriter, Location loc, Value v, + Operation *anchorOp) { + if (v.getType().isIndex()) + return v; + if (isa(v.getType())) + return rewriter.create(loc, rewriter.getIndexType(), v); + if (anchorOp) + anchorOp->emitError() << "expected index or integer, but got " << v.getType(); + return Value(); +} + +static bool tryGetIndexAttrFromValue(IRRewriter &rewriter, Value v, + IntegerAttr &constAttr) { + if (auto cOp = v.getDefiningOp()) { + constAttr = rewriter.getIndexAttr(cOp.value()); + return true; + } + if (auto cInt = v.getDefiningOp()) { + constAttr = rewriter.getIndexAttr(cInt.value()); + return true; + } + return false; +} + +static void appendMixedIndex(IRRewriter &rewriter, Location loc, Value v, + Operation *anchorOp, + SmallVectorImpl &mixedVals) { + IntegerAttr constAttr; + if (tryGetIndexAttrFromValue(rewriter, v, constAttr)) { + mixedVals.push_back(constAttr); + return; + } + mixedVals.push_back(ensureIndex(rewriter, loc, v, anchorOp)); +} + +static bool foldAddPtrChainIntoOffset(IRRewriter &rewriter, Location loc, + Value &base, Value &totalOffset) { + bool folded = false; + while (auto add = base.getDefiningOp()) { + folded = true; + Value off = ensureIndex(rewriter, loc, add.getOperand(1), add); + totalOffset = + totalOffset ? 
rewriter.create(loc, totalOffset, off) : off; + base = add.getOperand(0); + } + return folded; +} + +static Value clampSubViewValidDim(IRRewriter &rewriter, Location loc, + Value explicitValid, int64_t size, + Operation *anchorOp) { + Value sizeVal = rewriter.create(loc, size); + if (!explicitValid) + return sizeVal; + + int64_t cst = 0; + if (getConstIndexValue(explicitValid, cst)) + return rewriter.create(loc, std::min(cst, size)); + + Value v = ensureIndex(rewriter, loc, explicitValid, anchorOp); + Value lt = rewriter.create(loc, arith::CmpIPredicate::slt, v, + sizeVal); + return rewriter.create(loc, lt, v, sizeVal); +} + +[[maybe_unused]] static void dumpPretty(Operation *op, llvm::raw_ostream &os) { + OpPrintingFlags flags; + flags.useLocalScope(); + AsmState state(op, flags); + op->print(os, state); + os << "\n"; + os.flush(); +} + +// ============================================================================= +// Type Converter Logic +// ============================================================================= + +static SmallVector buildTileMemRefStrides(mlir::pto::TileBufType tbTy) { + TileLayoutInfo info; + if (computeTileLayoutInfo(tbTy.getConfigAttr(), tbTy.getElementType(), + tbTy.getShape(), info)) { + return {info.rowStride, info.colStride}; + } + return computeCompactStrides(tbTy.getShape()); +} + +static Type convertTileBufTypeToMemRef(mlir::pto::TileBufType tbTy) { + auto layoutAttr = StridedLayoutAttr::get(tbTy.getContext(), + ShapedType::kDynamic, + buildTileMemRefStrides(tbTy)); + return MemRefType::get(tbTy.getShape(), tbTy.getElementType(), layoutAttr, + tbTy.getMemorySpace()); +} + +static Type convertPTOTypeToMemRef(Type t) { + // 1. 处理 !pto.ptr + if (auto pty = dyn_cast(t)) { + return MemRefType::get({ShapedType::kDynamic}, pty.getElementType(), + MemRefLayoutAttrInterface(), Attribute()); + } + + // 2. 
处理 !pto.tile_buf<...> + if (auto tbTy = dyn_cast(t)) + return convertTileBufTypeToMemRef(tbTy); + if (auto tvTy = dyn_cast(t)) + return MemRefType::get(tvTy.getShape(), tvTy.getElementType(), + MemRefLayoutAttrInterface(), Attribute()); + if (auto partTy = dyn_cast(t)) + return MemRefType::get(partTy.getShape(), partTy.getElementType(), + MemRefLayoutAttrInterface(), Attribute()); + // 其他类型透传 + return t; +} + +// Ensure scf.if result types follow the rewritten yield operand types. +// PTOViewToMemref rewrites tile values to memref in branch bodies, but scf.if +// result types are not auto-updated by those op-local rewrites. +static LogicalResult reconcileSCFIfResultTypes(func::FuncOp func) { + DefaultInlineVector ifOps; + func.walk([&](scf::IfOp ifOp) { ifOps.push_back(ifOp); }); + + for (scf::IfOp ifOp : ifOps) { + if (ifOp.getNumResults() == 0) + continue; + + auto thenYield = dyn_cast(ifOp.thenBlock()->getTerminator()); + auto elseYield = dyn_cast(ifOp.elseBlock()->getTerminator()); + if (!thenYield || !elseYield) { + ifOp.emitError("result-bearing scf.if must end with scf.yield in both " + "then/else regions"); + return failure(); + } + + if (thenYield.getNumOperands() != ifOp.getNumResults() || + elseYield.getNumOperands() != ifOp.getNumResults()) { + ifOp.emitError("scf.if result count does not match yielded values"); + return failure(); + } + + for (unsigned i = 0; i < ifOp.getNumResults(); ++i) { + Type thenTy = thenYield.getOperand(i).getType(); + Type elseTy = elseYield.getOperand(i).getType(); + if (thenTy != elseTy) { + ifOp.emitError() << "scf.if branch yield type mismatch at result #" << i + << ": then=" << thenTy << ", else=" << elseTy; + return failure(); + } + + if (ifOp.getResult(i).getType() != thenTy) + ifOp.getResult(i).setType(thenTy); + } + } + + return success(); +} + +static LogicalResult reconcileSCFForResultTypes(func::FuncOp func) { + DefaultInlineVector forOps; + func.walk([&](scf::ForOp forOp) { forOps.push_back(forOp); }); + + for 
(scf::ForOp forOp : forOps) { + if (forOp.getNumResults() == 0) + continue; + + auto yield = dyn_cast(forOp.getBody()->getTerminator()); + if (!yield) { + forOp.emitError("result-bearing scf.for must end with scf.yield"); + return failure(); + } + + if (yield.getNumOperands() != forOp.getNumResults() || + forOp.getInitArgs().size() != forOp.getNumResults()) { + forOp.emitError("scf.for result count does not match iter/yield values"); + return failure(); + } + + for (unsigned i = 0; i < forOp.getNumResults(); ++i) { + Type initTy = forOp.getInitArgs()[i].getType(); + Type yieldTy = yield.getOperand(i).getType(); + if (initTy != yieldTy) { + forOp.emitError() << "scf.for init/yield type mismatch at result #" << i + << ": init=" << initTy << ", yield=" << yieldTy; + return failure(); + } + + BlockArgument iterArg = forOp.getRegionIterArg(i); + if (iterArg.getType() != initTy) + iterArg.setType(initTy); + if (forOp.getResult(i).getType() != initTy) + forOp.getResult(i).setType(initTy); + } + } + + return success(); +} + +static LogicalResult markLoweredSetValidShapeOps(func::FuncOp func, + MLIRContext *ctx) { + WalkResult result = func.walk([&](mlir::pto::SetValidShapeOp op) { + if (isa(op.getSource().getType())) { + if (!lookupConfig(op.getSource())) { + op.emitError( + "set_validshape requires a locally bound tile source; function " + "arguments/results are unsupported"); + return WalkResult::interrupt(); + } + op->setAttr(kLoweredSetValidShapeAttrName, UnitAttr::get(ctx)); + return WalkResult::advance(); + } + op->removeAttr(kLoweredSetValidShapeAttrName); + return WalkResult::advance(); + }); + return result.wasInterrupted() ? 
failure() : success(); +} + +static void markForceDynamicValidShape(Operation *op, bool force, + MLIRContext *ctx) { + if (force) { + op->setAttr(kForceDynamicValidShapeAttrName, UnitAttr::get(ctx)); + return; + } + op->removeAttr(kForceDynamicValidShapeAttrName); +} + +[[maybe_unused]] static void rewriteFunctionSignature(func::FuncOp func, MLIRContext *ctx) { + Block &entry = func.front(); + auto fnTy = func.getFunctionType(); + + SmallVector newInputs; + for (Type type : fnTy.getInputs()) + newInputs.push_back(convertPTOTypeToMemRef(type)); + + SmallVector newResults; + for (Type type : fnTy.getResults()) + newResults.push_back(convertPTOTypeToMemRef(type)); + + for (unsigned i = 0; i < entry.getNumArguments(); ++i) { + if (entry.getArgument(i).getType() != newInputs[i]) + entry.getArgument(i).setType(newInputs[i]); + } + func.setFunctionType(FunctionType::get(ctx, newInputs, newResults)); +} + +[[maybe_unused]] static LogicalResult lowerAllocTileOps(func::FuncOp func, MLIRContext *ctx) { + DefaultInlineVector allocTiles; + func.walk([&](mlir::pto::AllocTileOp op) { allocTiles.push_back(op); }); + + for (auto op : allocTiles) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); + + auto tbTy = dyn_cast(op.getResult().getType()); + if (!tbTy) + continue; + + SmallInlineVector shape(tbTy.getShape().begin(), + tbTy.getShape().end()); + Type elemTy = tbTy.getElementType(); + SmallVector strides = buildTileMemRefStrides(tbTy); + + auto targetLayout = + StridedLayoutAttr::get(ctx, ShapedType::kDynamic, strides); + auto targetType = + MemRefType::get(shape, elemTy, targetLayout, tbTy.getMemorySpace()); + + Value vRow = op.getValidRow(); + Value vCol = op.getValidCol(); + materializeStaticValidDims(rewriter, loc, tbTy, vRow, vCol); + + auto configAttr = tbTy.getConfigAttr(); + if (!configAttr) + configAttr = pto::TileBufConfigAttr::getDefault(ctx); + + if (Value addr = op.getAddr()) { + auto pc = rewriter.create( + loc, 
targetType, ValueRange{addr}, vRow ? vRow : Value(), + vCol ? vCol : Value(), configAttr); + markForceDynamicValidShape(pc, tbTy.hasDynamicValid(), ctx); + auto bindOp = rewriter.create( + loc, targetType, pc.getResult(), vRow ? vRow : Value(), + vCol ? vCol : Value(), configAttr); + markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); + rewriter.replaceOp(op, bindOp.getResult()); + continue; + } + + auto allocLayout = StridedLayoutAttr::get(ctx, 0, strides); + auto allocType = + MemRefType::get(shape, elemTy, allocLayout, tbTy.getMemorySpace()); + Value alloc = rewriter.create(loc, allocType); + auto bindOp = rewriter.create( + loc, targetType, alloc, vRow ? vRow : Value(), vCol ? vCol : Value(), + configAttr); + markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); + rewriter.replaceOp(op, bindOp.getResult()); + } + return success(); +} + +[[maybe_unused]] static LogicalResult lowerDeclareTileOps(func::FuncOp func, MLIRContext *ctx) { + DefaultInlineVector declaredTiles; + func.walk([&](mlir::pto::DeclareTileOp op) { declaredTiles.push_back(op); }); + + for (auto op : declaredTiles) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); + + auto tbTy = dyn_cast(op.getTile().getType()); + if (!tbTy) { + op.emitError("declare_tile result must be tile_buf type"); + return failure(); + } + + auto targetType = dyn_cast(convertPTOTypeToMemRef(tbTy)); + if (!targetType) { + op.emitError("failed to convert declare_tile result to memref type"); + return failure(); + } + + auto configAttr = tbTy.getConfigAttr(); + if (!configAttr) + configAttr = pto::TileBufConfigAttr::getDefault(ctx); + + Value vRow; + Value vCol; + materializeStaticValidDims(rewriter, loc, tbTy, vRow, vCol); + + auto declaredMemRef = + rewriter.create(loc, targetType); + auto bindOp = rewriter.create( + loc, targetType, declaredMemRef.getResult(), vRow ? vRow : Value(), + vCol ? 
vCol : Value(), configAttr); + markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); + rewriter.replaceOp(op, bindOp.getResult()); + } + return success(); +} + +static Value castIndexToI64(IRRewriter &rewriter, Location loc, Value value) { + Type i64Ty = rewriter.getI64Type(); + if (value.getType() == i64Ty) + return value; + return rewriter.create(loc, i64Ty, value).getResult(); +} + +static FailureOr +materializePtrToIntAddPtrAddress(IRRewriter &rewriter, Location loc, + mlir::pto::PtrToIntOp anchor, Value source) { + SmallVector addPtrChain; + Value base = source; + while (auto add = base.getDefiningOp()) { + addPtrChain.push_back(add); + base = add.getOperand(0); + } + + if (addPtrChain.empty()) + return failure(); + + auto baseMemTy = dyn_cast(base.getType()); + if (!baseMemTy) { + anchor.emitOpError( + "pto.addptr source base could not be lowered to a GM memref"); + return failure(); + } + + Value byteAddress = rewriter.create( + loc, rewriter.getI64Type(), base); + for (auto add : addPtrChain) { + auto addPtrTy = dyn_cast(add.getResult().getType()); + if (!addPtrTy) { + anchor.emitOpError("requires pto.addptr source to have !pto.ptr result " + "type before byte-address lowering"); + return failure(); + } + + unsigned elemBytes = + mlir::pto::getPTOStorageElemByteSize(addPtrTy.getElementType()); + if (elemBytes == 0) { + anchor.emitOpError("cannot lower pto.addptr source with unknown element " + "byte size to a byte address"); + return failure(); + } + + Value byteOffset = castIndexToI64(rewriter, loc, add.getOffset()); + if (elemBytes != 1) { + Value elemBytesValue = + rewriter.create(loc, elemBytes, 64); + byteOffset = + rewriter.create(loc, byteOffset, elemBytesValue) + .getResult(); + } + byteAddress = + rewriter.create(loc, byteAddress, byteOffset).getResult(); + } + + return byteAddress; +} + +static LogicalResult lowerIntToPtrOps(func::FuncOp func, MLIRContext *ctx) { + DefaultInlineVector intToPtrs; + func.walk([&](mlir::pto::IntToPtrOp op) 
{ intToPtrs.push_back(op); }); + + for (auto op : intToPtrs) { + if (!isa(op.getResult().getType())) + continue; + + auto targetTy = + dyn_cast(convertPTOTypeToMemRef(op.getResult().getType())); + if (!targetTy) { + op.emitError("failed to convert inttoptr result to memref type"); + return failure(); + } + + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + auto lowered = + rewriter.create(op.getLoc(), targetTy, + op.getAddr()); + lowered->setAttrs(op->getAttrs()); + rewriter.replaceOp(op, lowered.getResult()); + } + + return success(); +} + +static LogicalResult lowerPtrToIntOps(func::FuncOp func, MLIRContext *ctx) { + DefaultInlineVector ptrToInts; + func.walk([&](mlir::pto::PtrToIntOp op) { ptrToInts.push_back(op); }); + + for (auto op : ptrToInts) { + Value source = op.getPtr(); + if (source.getDefiningOp()) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + FailureOr byteAddress = + materializePtrToIntAddPtrAddress(rewriter, op.getLoc(), op, source); + if (failed(byteAddress)) + return failure(); + rewriter.replaceOp(op, *byteAddress); + continue; + } + + if (isa(source.getType())) + continue; + } + + DefaultInlineVector remaining; + func.walk([&](mlir::pto::PtrToIntOp op) { + if (isa(op.getPtr().getType())) + remaining.push_back(op); + }); + for (auto op : remaining) { + op.emitError("ptrtoint source could not be lowered to a GM memref"); + return failure(); + } + + return success(); +} + +[[maybe_unused]] static LogicalResult lowerMakeTensorViewOps(func::FuncOp func, MLIRContext *ctx) { + DefaultInlineVector makeViews; + func.walk([&](mlir::pto::MakeTensorViewOp op) { makeViews.push_back(op); }); + + for (auto op : makeViews) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); + + Value baseBuf = op.getOperand(0); + OpFoldResult off0 = rewriter.getIndexAttr(0); + bool foldedAddPtr = false; + { + Value cur = baseBuf; + Value totalOffset; + while (auto add = cur.getDefiningOp()) { + 
foldedAddPtr = true; + Value off = ensureIndex(rewriter, loc, add.getOperand(1), add); + totalOffset = totalOffset ? rewriter.create(loc, totalOffset, off) + : off; + cur = add.getOperand(0); + } + if (cur != baseBuf) { + baseBuf = cur; + off0 = totalOffset ? OpFoldResult(totalOffset) : off0; + } + } + + auto baseMr = dyn_cast(baseBuf.getType()); + if (!baseMr) { + op.emitError("make_tensor_view base must be memref"); + return failure(); + } + + size_t rank = op.getShape().size(); + int64_t dyn = ShapedType::kDynamic; + SmallVector dynStrides(rank, dyn); + auto layout = + StridedLayoutAttr::get(ctx, /*offset=*/dyn, /*strides=*/dynStrides); + SmallVector dynShape(rank, dyn); + auto mrTy = MemRefType::get(dynShape, baseMr.getElementType(), layout, + baseMr.getMemorySpace()); + + SmallInlineVector sizes; + for (Value value : op.getShape()) + sizes.push_back(ensureIndex(rewriter, loc, value, op)); + SmallInlineVector strides; + for (Value value : op.getStrides()) + strides.push_back(ensureIndex(rewriter, loc, value, op)); + + auto rc = rewriter.create(loc, mrTy, baseBuf, off0, + sizes, strides); + if (foldedAddPtr) + rc->setAttr("pto.addptr_trace", rewriter.getUnitAttr()); + if (auto layoutAttr = op.getLayoutAttr()) + rc->setAttr("layout", layoutAttr); + rewriter.replaceOp(op, rc.getResult()); + } + return success(); +} + +[[maybe_unused]] static LogicalResult lowerTensorViewDimOps(func::FuncOp func, MLIRContext *ctx) { + DefaultInlineVector tvDims; + func.walk([&](mlir::pto::GetTensorViewDimOp op) { tvDims.push_back(op); }); + + for (auto op : tvDims) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Value view = op.getTensorView(); + auto mrTy = dyn_cast(view.getType()); + if (!mrTy) + continue; + Value dim = rewriter.create(op.getLoc(), view, op.getDimIndex()); + rewriter.replaceOp(op, dim); + } + return success(); +} + +[[maybe_unused]] static LogicalResult foldAddPtrIntoScalarOps(func::FuncOp func, MLIRContext *ctx) { + DefaultInlineVector 
loadScalars; + func.walk([&](mlir::pto::LoadScalarOp op) { loadScalars.push_back(op); }); + for (auto op : loadScalars) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); + + Value base = op.getPtr(); + Value totalOffset = ensureIndex(rewriter, loc, op.getOffset(), op); + bool foldedAddPtr = foldAddPtrChainIntoOffset(rewriter, loc, base, totalOffset); + if (foldedAddPtr) { + auto newOp = + rewriter.create(loc, op.getValue().getType(), base, + totalOffset); + rewriter.replaceOp(op, newOp.getValue()); + } + } + + DefaultInlineVector storeScalars; + func.walk([&](mlir::pto::StoreScalarOp op) { storeScalars.push_back(op); }); + for (auto op : storeScalars) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); + + Value base = op.getPtr(); + Value totalOffset = ensureIndex(rewriter, loc, op.getOffset(), op); + bool foldedAddPtr = foldAddPtrChainIntoOffset(rewriter, loc, base, totalOffset); + if (foldedAddPtr) { + rewriter.create(loc, base, totalOffset, op.getValue()); + rewriter.eraseOp(op); + } + } + + DefaultInlineVector addPtrs; + func.walk([&](mlir::pto::AddPtrOp op) { addPtrs.push_back(op.getOperation()); }); + bool changed = true; + while (changed) { + changed = false; + for (auto &op : addPtrs) { + if (!op) + continue; + if (op->use_empty()) { + op->erase(); + op = nullptr; + changed = true; + } + } + } + for (Operation *op : addPtrs) { + if (!op) + continue; + op->emitError( + "addptr must feed make_tensor_view or load/store_scalar for lowering"); + return failure(); + } + return success(); +} + +static LogicalResult lowerPartitionViewOps(func::FuncOp func, MLIRContext *ctx) { + DefaultInlineVector partitionViews; + func.walk([&](mlir::pto::PartitionViewOp op) { partitionViews.push_back(op); }); + + for (auto op : partitionViews) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); + Value src = op.getOperand(0); + auto srcMrTy = 
dyn_cast(src.getType()); + if (!srcMrTy) + continue; + int64_t rank = srcMrTy.getRank(); + + SmallVector staticSizes; + SmallVector mixedSizes; + for (Value size : op.getSizes()) { + IntegerAttr constAttr; + bool isStatic = false; + if (auto cOp = size.getDefiningOp()) { + constAttr = rewriter.getIndexAttr(cOp.value()); + isStatic = true; + } else if (auto cInt = size.getDefiningOp()) { + constAttr = rewriter.getIndexAttr(cInt.value()); + isStatic = true; + } + + if (isStatic) { + mixedSizes.push_back(constAttr); + staticSizes.push_back(constAttr.getInt()); + } else { + mixedSizes.push_back(ensureIndex(rewriter, loc, size, op)); + staticSizes.push_back(ShapedType::kDynamic); + } + } + + SmallVector mixedOffsets; + for (Value offset : op.getOffsets()) { + appendMixedIndex(rewriter, loc, offset, op, mixedOffsets); + } + + int64_t dyn = ShapedType::kDynamic; + SmallVector dynStrides(rank, dyn); + auto layout = StridedLayoutAttr::get(ctx, dyn, dynStrides); + auto resTy = MemRefType::get(staticSizes, srcMrTy.getElementType(), layout, + srcMrTy.getMemorySpace()); + + SmallVector mixedStrides(rank, rewriter.getIndexAttr(1)); + auto sv = rewriter.create(loc, resTy, src, mixedOffsets, + mixedSizes, mixedStrides); + if (Operation *srcDef = src.getDefiningOp()) { + if (auto layoutAttr = srcDef->getAttrOfType("layout")) + sv->setAttr("layout", layoutAttr); + } + rewriter.replaceOp(op, sv.getResult()); + } + return success(); +} + +static LogicalResult lowerSubViewOps(func::FuncOp func, MLIRContext *ctx) { + DefaultInlineVector subViews; + func.walk([&](mlir::pto::SubViewOp op) { subViews.push_back(op); }); + + for (auto op : subViews) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); + auto resultTileTy = + dyn_cast(op.getResult().getType()); + Value src = op->getOperand(0); + auto srcMrTy = dyn_cast(src.getType()); + if (!srcMrTy) { + op.emitError("pto.subview source must be lowered to memref first"); + return failure(); + } + + 
ArrayAttr sizeAttr = op.getSizes(); + SmallVector staticSizes; + SmallVector mixedSizes; + for (Attribute attr : sizeAttr) { + int64_t size = cast(attr).getInt(); + staticSizes.push_back(size); + mixedSizes.push_back(rewriter.getIndexAttr(size)); + } + + SmallVector mixedOffsets; + for (Value offset : op.getOffsets()) { + appendMixedIndex(rewriter, loc, offset, op, mixedOffsets); + } + + auto configAttr = lookupConfig(src); + if (!configAttr) + configAttr = pto::TileBufConfigAttr::getDefault(ctx); + + TileLayoutInfo layoutInfo; + if (!computeTileLayoutInfo(configAttr, srcMrTy.getElementType(), + srcMrTy.getShape(), layoutInfo)) { + op.emitError("unsupported tile layout for pto.subview"); + return failure(); + } + + if (layoutInfo.boxed) { + if (staticSizes.size() != kTileRank2D || + op.getOffsets().size() != kTileRank2D) { + op.emitError("boxed layout subview expects 2D sizes/offsets"); + return failure(); + } + if (!checkMultipleOf(op, staticSizes[0], layoutInfo.innerRows, "row size") || + !checkMultipleOf(op, staticSizes[1], layoutInfo.innerCols, "col size")) { + return failure(); + } + + int64_t off0 = 0; + int64_t off1 = 0; + bool off0Const = getConstIndexValue(op.getOffsets()[0], off0); + bool off1Const = getConstIndexValue(op.getOffsets()[1], off1); + if (off0Const && + !checkMultipleOf(op, off0, layoutInfo.innerRows, "row offset")) { + return failure(); + } + if (off1Const && + !checkMultipleOf(op, off1, layoutInfo.innerCols, "col offset")) { + return failure(); + } + + } + + SmallVector srcStrides; + int64_t srcOffset = ShapedType::kDynamic; + if (failed(getStridesAndOffset(srcMrTy, srcStrides, srcOffset))) + srcStrides = computeCompactStrides(srcMrTy.getShape()); + + // Keep parent physical shape + strides for bound tile semantics. 
+ auto resultLayout = + StridedLayoutAttr::get(ctx, ShapedType::kDynamic, srcStrides); + auto parentShape = srcMrTy.getShape(); + auto resultMemRefType = + MemRefType::get(parentShape, srcMrTy.getElementType(), resultLayout, + srcMrTy.getMemorySpace()); + + // Intermediate memref.subview keeps logical subview size. + auto subViewMemRefType = + MemRefType::get(staticSizes, srcMrTy.getElementType(), resultLayout, + srcMrTy.getMemorySpace()); + + SmallVector mixedStrides(staticSizes.size(), + rewriter.getIndexAttr(1)); + auto sv = rewriter.create(loc, subViewMemRefType, src, + mixedOffsets, mixedSizes, + mixedStrides); + + Value vRow; + Value vCol; + if (!staticSizes.empty()) + vRow = clampSubViewValidDim(rewriter, loc, op.getValidRow(), + staticSizes[0], op); + if (staticSizes.size() > 1) + vCol = clampSubViewValidDim(rewriter, loc, op.getValidCol(), + staticSizes[1], op); + + auto bindOp = rewriter.create( + loc, resultMemRefType, sv.getResult(), vRow ? vRow : Value(), + vCol ? vCol : Value(), configAttr); + markForceDynamicValidShape(bindOp, + resultTileTy && resultTileTy.hasDynamicValid(), + ctx); + bindOp->setAttr("pto.view_semantics", rewriter.getStringAttr("subview")); + rewriter.replaceOp(op, bindOp.getResult()); + } + return success(); +} + +static Value buildTileBufViewLikeValue(Operation *anchorOp, Value src, + mlir::pto::TileBufType tbTy, + StringRef viewSemantics, + MLIRContext *ctx) { + Location loc = anchorOp->getLoc(); + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(anchorOp); + + auto srcMrTy = dyn_cast(src.getType()); + if (!srcMrTy) { + anchorOp->emitError("tile_buf view op src must be lowered to memref first"); + return Value(); + } + + auto targetType = dyn_cast(convertPTOTypeToMemRef(tbTy)); + if (!targetType) { + anchorOp->emitError("failed to convert tile_buf type to memref type"); + return Value(); + } + for (int64_t dim : targetType.getShape()) { + if (dim == ShapedType::kDynamic) { + anchorOp->emitError("dynamic shapes are not 
supported for tile_buf view ops"); + return Value(); + } + } + + Value parentVRow; + Value parentVCol; + lookupValidDims(src, parentVRow, parentVCol); + Value vRow = parentVRow; + Value vCol = parentVCol; + materializeStaticValidDims(rewriter, loc, tbTy, vRow, vCol); + + auto configAttr = tbTy.getConfigAttr(); + if (!configAttr) + configAttr = pto::TileBufConfigAttr::getDefault(ctx); + + auto bindOp = rewriter.create( + loc, targetType, src, vRow ? vRow : Value(), vCol ? vCol : Value(), + configAttr); + markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); + if (!viewSemantics.empty()) + bindOp->setAttr("pto.view_semantics", rewriter.getStringAttr(viewSemantics)); + return bindOp.getResult(); +} + +static LogicalResult lowerTileBufViewLikeOps(func::FuncOp func, MLIRContext *ctx) { + DefaultInlineVector reshapes; + func.walk([&](mlir::pto::TReshapeOp op) { reshapes.push_back(op); }); + for (auto op : reshapes) { + auto tbTy = dyn_cast(op.getResult().getType()); + if (!tbTy) { + op.emitError("treshape result must be tile_buf type"); + return failure(); + } + Value lowered = buildTileBufViewLikeValue(op, op->getOperand(0), tbTy, + "treshape", ctx); + if (!lowered) + return failure(); + IRRewriter rewriter(ctx); + rewriter.replaceOp(op, lowered); + } + + DefaultInlineVector bitcasts; + func.walk([&](mlir::pto::BitcastOp op) { bitcasts.push_back(op); }); + for (auto op : bitcasts) { + auto tbTy = dyn_cast(op.getResult().getType()); + if (!tbTy) { + op.emitError("bitcast result must be tile_buf type"); + return failure(); + } + Value lowered = buildTileBufViewLikeValue(op, op->getOperand(0), tbTy, + "bitcast", ctx); + if (!lowered) + return failure(); + IRRewriter rewriter(ctx); + rewriter.replaceOp(op, lowered); + } + return success(); +} + +// ============================================================================= +// The Pass Implementation +// ============================================================================= + +struct PTOViewToMemrefPass 
+ : public mlir::pto::impl::PTOViewToMemrefBase { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PTOViewToMemrefPass) + + void runOnOperation() override { + ModuleOp mod = getOperation(); + MLIRContext *ctx = &getContext(); + + for (auto func : mod.getOps()) { + if (func.isExternal()) continue; + + // ------------------------------------------------------------------ + // Stage 0: ensure inttoptr values remain scalar-load/store only. + // ------------------------------------------------------------------ + if (failed(validateIntToPtrUses(func))) { + signalPassFailure(); + return; + } + + Block &entry = func.front(); + auto fnTy = func.getFunctionType(); + + // ------------------------------------------------------------------ + // Stage 0.10: Rewrite Function Signature + // ------------------------------------------------------------------ + SmallVector newInputs; + for (Type t : fnTy.getInputs()) newInputs.push_back(convertPTOTypeToMemRef(t)); + + SmallVector newResults; + for (Type t : fnTy.getResults()) newResults.push_back(convertPTOTypeToMemRef(t)); + + // Update entry block arguments + for (unsigned i = 0; i < entry.getNumArguments(); ++i) { + if (entry.getArgument(i).getType() != newInputs[i]) { + entry.getArgument(i).setType(newInputs[i]); + } + } + + // Update function type + func.setFunctionType(FunctionType::get(ctx, newInputs, newResults)); + + // ------------------------------------------------------------------ + // Stage 0.20: lower pto.inttoptr result types to GM memrefs. + // ------------------------------------------------------------------ + if (failed(lowerIntToPtrOps(func, ctx))) { + signalPassFailure(); + return; + } + + // ------------------------------------------------------------------ + // Stage 0.30: materialize pto.ptrtoint(addptr ...) byte offsets. 
+ // ------------------------------------------------------------------ + if (failed(lowerPtrToIntOps(func, ctx))) { + signalPassFailure(); + return; + } + + // ------------------------------------------------------------------ + // Stage 0.5: lower pto.alloc_tile -> memref.alloc + pto.bind_tile + // ------------------------------------------------------------------ + DefaultInlineVector allocTiles; + func.walk([&](mlir::pto::AllocTileOp op) { allocTiles.push_back(op); }); + + for (auto op : allocTiles) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); + + auto tbTy = dyn_cast(op.getResult().getType()); + if (!tbTy) continue; + + // 1. 获取 Shape 和 ElementType + SmallInlineVector shape(tbTy.getShape().begin(), tbTy.getShape().end()); + Type elemTy = tbTy.getElementType(); + + // 2. 计算 Strides (layout-aware when possible) + SmallVector strides; + TileLayoutInfo info; + if (computeTileLayoutInfo(tbTy.getConfigAttr(), elemTy, shape, info)) { + strides = {info.rowStride, info.colStride}; + } else { + strides.resize(shape.size()); + int64_t s = 1; + for (int i = (int)shape.size() - 1; i >= 0; --i) { + strides[i] = s; + if (shape[i] != ShapedType::kDynamic) s *= shape[i]; + } + } + + // 3. 构造 [BindTile 输出] 的动态类型 (Offset: ?) + // 这必须与 convertPTOTypeToMemRef 返回的类型一致,以便与 Subview 兼容 + auto targetLayout = + StridedLayoutAttr::get(ctx, ShapedType::kDynamic, strides); // offset = ? + auto targetType = + MemRefType::get(shape, elemTy, targetLayout, tbTy.getMemorySpace()); + + // 4. Preserve tile valid dims (v_row / v_col). + // + // `pto.alloc_tile` encodes the valid shape in the result TileBufType + // (e.g. acc tile may be rows=16 but v_row=1). The alloc op itself does + // not necessarily carry explicit operands for static valid dims, so we + // must materialize them from the type to keep them through + // tile_buf -> memref lowering. 
+ // + // For dynamically valid tiles (validShape == [-1, -1]), preserve the + // runtime operands if present. + Value vRow = op.getValidRow(); + Value vCol = op.getValidCol(); + // TileBuf valid dims use a negative sentinel (e.g. '?' / -1), which is + // distinct from MLIR's ShapedType::kDynamic (INT64_MIN). Treat any + // negative value as dynamic here. + materializeStaticValidDims(rewriter, loc, tbTy, vRow, vCol); + + // 5. 获取 Config (保持不变) + auto configAttr = tbTy.getConfigAttr(); + if (!configAttr) configAttr = pto::TileBufConfigAttr::getDefault(ctx); + + // 6. If alloc_tile provides an explicit address, keep the original + // pointer_cast lowering intact and additionally rebind through + // pto.bind_tile. PointerCastOp continues to carry the tile metadata + // used by existing lowering paths, while BindTileOp provides the + // unified anchor EmitC uses to recover tile_buf information. + if (Value addr = op.getAddr()) { + auto pc = rewriter.create( + loc, targetType, ValueRange{addr}, vRow ? vRow : Value(), + vCol ? vCol : Value(), configAttr); + markForceDynamicValidShape(pc, tbTy.hasDynamicValid(), ctx); + auto bindOp = rewriter.create( + loc, targetType, pc.getResult(), vRow ? vRow : Value(), + vCol ? vCol : Value(), configAttr); + markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); + rewriter.replaceOp(op, bindOp.getResult()); + continue; + } + + // 7. Otherwise allocate a concrete memref buffer and bind tile. + // memref.alloc 要求明确的 layout,不能是动态 offset。 + auto allocLayout = StridedLayoutAttr::get(ctx, 0, strides); // offset = 0 + auto allocType = MemRefType::get(shape, elemTy, allocLayout, tbTy.getMemorySpace()); + Value alloc = rewriter.create(loc, allocType); + + // BindTileOp 的 Builder 会自动处理空的 Value,将其视为静态维度 + auto bindOp = rewriter.create( + loc, targetType, alloc, vRow ? vRow : Value(), vCol ? 
vCol : Value(), + configAttr); + markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); + + rewriter.replaceOp(op, bindOp.getResult()); + } + + // ------------------------------------------------------------------ + // Stage 0.75: lower pto.declare_tile -> pto.declare_tile_memref + + // pto.bind_tile + // ------------------------------------------------------------------ + DefaultInlineVector declaredTiles; + func.walk([&](mlir::pto::DeclareTileOp op) { declaredTiles.push_back(op); }); + + for (auto op : declaredTiles) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); + + auto tbTy = dyn_cast(op.getTile().getType()); + if (!tbTy) { + op.emitError("declare_tile result must be tile_buf type"); + signalPassFailure(); + return; + } + + auto targetType = dyn_cast(convertPTOTypeToMemRef(tbTy)); + if (!targetType) { + op.emitError("failed to convert declare_tile result to memref type"); + signalPassFailure(); + return; + } + + auto configAttr = tbTy.getConfigAttr(); + if (!configAttr) + configAttr = pto::TileBufConfigAttr::getDefault(ctx); + + Value vRow; + Value vCol; + materializeStaticValidDims(rewriter, loc, tbTy, vRow, vCol); + + auto declaredMemRef = + rewriter.create(loc, targetType); + auto bindOp = rewriter.create( + loc, targetType, declaredMemRef.getResult(), + vRow ? vRow : Value(), vCol ? vCol : Value(), configAttr); + markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); + + rewriter.replaceOp(op, bindOp.getResult()); + } + + // ------------------------------------------------------------------ + // Stage 0.8: normalize pto.tassign result type to match tile operand + // after tile_buf -> memref lowering (required for verifier consistency). 
+ // ------------------------------------------------------------------ + DefaultInlineVector tassignOps; + func.walk([&](mlir::pto::TAssignOp op) { tassignOps.push_back(op); }); + for (auto op : tassignOps) { + Type targetTy = op.getTile().getType(); + if (op.getResult().getType() == targetTy) + continue; + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + auto normalized = + rewriter.create(op.getLoc(), targetTy, op.getTile(), + op.getAddr()); + rewriter.replaceOp(op, normalized.getResult()); + } + + // ------------------------------------------------------------------ + // Stage 1: Lower pto.make_tensor_view -> memref.reinterpret_cast + // ------------------------------------------------------------------ + DefaultInlineVector makeViews; + func.walk([&](mlir::pto::MakeTensorViewOp op) { makeViews.push_back(op); }); + + for (auto op : makeViews) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); + + Value baseBuf = op.getOperand(0); + OpFoldResult off0 = rewriter.getIndexAttr(0); + + // Fold pto.addptr chains into the view base to avoid nested reinterpret_cast. + bool foldedAddPtr = false; + { + Value cur = baseBuf; + Value totalOffset; + while (auto add = cur.getDefiningOp()) { + foldedAddPtr = true; + Value off = ensureIndex(rewriter, loc, add.getOperand(1), add); + if (totalOffset) + totalOffset = rewriter.create(loc, totalOffset, off); + else + totalOffset = off; + cur = add.getOperand(0); + } + if (cur != baseBuf) { + baseBuf = cur; + off0 = totalOffset ? 
OpFoldResult(totalOffset) : off0; + } + } + + auto baseMr = dyn_cast(baseBuf.getType()); + if (!baseMr) { + op.emitError("make_tensor_view base must be memref"); signalPassFailure(); return; + } + + // [修复] 获取动态 Rank (根据 shape 输入的数量) + size_t rank = op.getShape().size(); + + // Construct target type with dynamic offset/strides + Type elemTy = baseMr.getElementType(); + int64_t dyn = ShapedType::kDynamic; + + // [修复] 构建 N 维 Strided Layout + // strides 数组长度必须等于 rank + SmallVector dynStrides(rank, dyn); + auto layout = StridedLayoutAttr::get(ctx, /*offset=*/dyn, /*strides=*/dynStrides); + + // [修复] 构建 N 维 Shape + SmallVector dynShape(rank, dyn); + auto mrTy = MemRefType::get(dynShape, elemTy, layout, baseMr.getMemorySpace()); + + SmallInlineVector sizes; + for (Value v : op.getShape()) sizes.push_back(ensureIndex(rewriter, loc, v, op)); + + SmallInlineVector strides; + for (Value v : op.getStrides()) strides.push_back(ensureIndex(rewriter, loc, v, op)); + + auto rc = rewriter.create( + loc, mrTy, baseBuf, off0, sizes, strides); + if (foldedAddPtr) { + rc->setAttr("pto.addptr_trace", rewriter.getUnitAttr()); + } + if (auto layoutAttr = op.getLayoutAttr()) { + rc->setAttr("layout", layoutAttr); + } + + rewriter.replaceOp(op, rc.getResult()); + } + + // ------------------------------------------------------------------ + // Stage 1.25: Lower pto.get_tensor_view_dim -> memref.dim + // ------------------------------------------------------------------ + DefaultInlineVector tvDims; + func.walk([&](mlir::pto::GetTensorViewDimOp op) { tvDims.push_back(op); }); + + for (auto op : tvDims) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); + + Value view = op.getTensorView(); + auto mrTy = dyn_cast(view.getType()); + if (!mrTy) + continue; // leave it to later passes if it hasn't been lowered yet + + Value dimIdx = op.getDimIndex(); + Value dim = rewriter.create(loc, view, dimIdx); + rewriter.replaceOp(op, dim); + } + + // 
------------------------------------------------------------------ + // Stage 1.3: Lower pto.partition_view -> memref.subview + // ------------------------------------------------------------------ + if (failed(lowerPartitionViewOps(func, ctx))) { + signalPassFailure(); + return; + } + + // ------------------------------------------------------------------ + // Stage 1.35: Lower pto.subview -> memref.subview + pto.bind_tile + // ------------------------------------------------------------------ + if (failed(lowerSubViewOps(func, ctx))) { + signalPassFailure(); + return; + } + + // ------------------------------------------------------------------ + // Stage 1.4: Lower tile_buf view-like ops (treshape/bitcast) + // ------------------------------------------------------------------ + if (failed(lowerTileBufViewLikeOps(func, ctx))) { + signalPassFailure(); + return; + } + + // ------------------------------------------------------------------ + // Stage 1.5: Fold pto.addptr chains into load/store_scalar. 
+ // ------------------------------------------------------------------ + DefaultInlineVector loadScalars; + func.walk([&](mlir::pto::LoadScalarOp op) { loadScalars.push_back(op); }); + + for (auto op : loadScalars) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); + + Value base = op.getPtr(); + Value totalOffset = ensureIndex(rewriter, loc, op.getOffset(), op); + + bool foldedAddPtr = false; + while (auto add = base.getDefiningOp()) { + foldedAddPtr = true; + Value off = ensureIndex(rewriter, loc, add.getOperand(1), add); + if (totalOffset) + totalOffset = rewriter.create(loc, totalOffset, off); + else + totalOffset = off; + base = add.getOperand(0); + } + + if (foldedAddPtr) { + auto newOp = rewriter.create( + loc, op.getValue().getType(), base, totalOffset); + rewriter.replaceOp(op, newOp.getValue()); + } + } + + DefaultInlineVector storeScalars; + func.walk([&](mlir::pto::StoreScalarOp op) { storeScalars.push_back(op); }); + + for (auto op : storeScalars) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); + + Value base = op.getPtr(); + Value totalOffset = ensureIndex(rewriter, loc, op.getOffset(), op); + + bool foldedAddPtr = false; + while (auto add = base.getDefiningOp()) { + foldedAddPtr = true; + Value off = ensureIndex(rewriter, loc, add.getOperand(1), add); + if (totalOffset) + totalOffset = rewriter.create(loc, totalOffset, off); + else + totalOffset = off; + base = add.getOperand(0); + } + + if (foldedAddPtr) { + rewriter.create( + loc, base, totalOffset, op.getValue()); + rewriter.eraseOp(op); + } + } + + // ------------------------------------------------------------------ + // Stage 1.75: Fold addptr used by initialize_l2g2l_pipe(gm_addr). + // This keeps IR well-typed after function arguments are rewritten from + // !pto.ptr to memref. 
+ // ------------------------------------------------------------------ + bool foldedPipeInitAddPtr = true; + while (foldedPipeInitAddPtr) { + foldedPipeInitAddPtr = false; + DefaultInlineVector addPtrsForPipeInit; + func.walk([&](mlir::pto::AddPtrOp op) { + bool eligible = !op->use_empty(); + for (Operation *user : op->getUsers()) { + auto init = dyn_cast(user); + if (!init || init.getGmAddr() != op->getResult(0)) { + eligible = false; + break; + } + } + if (eligible) + addPtrsForPipeInit.push_back(op); + }); + + for (auto op : addPtrsForPipeInit) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); + + Value base = op->getOperand(0); + Value totalOffset = ensureIndex(rewriter, loc, op->getOperand(1), op); + while (auto add = base.getDefiningOp()) { + Value off = ensureIndex(rewriter, loc, add->getOperand(1), add); + totalOffset = rewriter.create(loc, totalOffset, off); + base = add->getOperand(0); + } + + auto baseMrTy = dyn_cast(base.getType()); + if (!baseMrTy || baseMrTy.getRank() != 1) + continue; + + int64_t dyn = ShapedType::kDynamic; + auto layout = StridedLayoutAttr::get(ctx, dyn, {dyn}); + auto targetTy = MemRefType::get({dyn}, baseMrTy.getElementType(), layout, + baseMrTy.getMemorySpace()); + SmallVector sizes{rewriter.getIndexAttr(1)}; + SmallVector strides{rewriter.getIndexAttr(1)}; + auto rc = rewriter.create( + loc, targetTy, base, OpFoldResult(totalOffset), sizes, strides); + rc->setAttr("pto.addptr_trace", rewriter.getUnitAttr()); + rewriter.replaceOp(op, rc.getResult()); + foldedPipeInitAddPtr = true; + } + } + + // Clean up: addptr should be folded into make_tensor_view. 
+ DefaultInlineVector addPtrs; + func.walk([&](mlir::pto::AddPtrOp op) { addPtrs.push_back(op.getOperation()); }); + bool changed = true; + while (changed) { + changed = false; + for (auto &op : addPtrs) { + if (!op) + continue; + if (op->use_empty()) { + op->erase(); + op = nullptr; + changed = true; + } + } + } + for (auto *op : addPtrs) { + if (!op) + continue; + op->emitError("addptr must feed make_tensor_view, initialize_l2g2l_pipe(gm_addr) or load/store_scalar for lowering"); + signalPassFailure(); + return; + } + + // ------------------------------------------------------------------ + // Stage 3: Rewrite Compute Ops + // [关键] 全面使用 op->getOperand(i) 避免 Typed Accessor Crash + // ------------------------------------------------------------------ + + // --- TLoadOp [Src, Dst] --- + DefaultInlineVector loads; + func.walk([&](mlir::pto::TLoadOp op) { loads.push_back(op); }); + for (auto op : loads) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op->getOperand(0); + Value dst = op->getOperand(1); + + auto newOp = + rewriter.create(op.getLoc(), TypeRange{}, src, dst); + newOp->setAttrs(op->getAttrs()); + rewriter.replaceOp(op, newOp->getResults()); + } + + // --- TStoreOp [Src, Dst] --- + DefaultInlineVector storeops; + func.walk([&](mlir::pto::TStoreOp op) { storeops.push_back(op); }); + for (auto op : storeops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op->getOperand(0); + Value dst = op->getOperand(1); + Value preQuant = op.getPreQuantScalar(); + + pto::TStoreOp newOp; + if (preQuant) { + newOp = rewriter.create(op.getLoc(), TypeRange{}, + src, dst, preQuant); + } else { + newOp = rewriter.create(op.getLoc(), TypeRange{}, + src, dst, Value{}); + } + newOp->setAttrs(op->getAttrs()); + rewriter.replaceOp(op, newOp->getResults()); + } + + // --- TTransOp [Src, Tmp, Dst] --- + DefaultInlineVector trans; + func.walk([&](mlir::pto::TTransOp op) { trans.push_back(op); }); + for (auto op : trans) { + 
IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex)); + } + + // --- TExpOp [Src, Dst] --- + DefaultInlineVector exp; + func.walk([&](mlir::pto::TExpOp op) { exp.push_back(op); }); + for (auto op : exp) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, op->getOperand(0), op->getOperand(1)); + } + + // --- TMulOp [Src, Scalar, Dst] --- + DefaultInlineVector mul; + func.walk([&](mlir::pto::TMulOp op) { mul.push_back(op); }); + for (auto op : mul) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, op->getOperand(0), op.getOperand(1), + op->getOperand(kThirdOperandIndex)); + } + + // --- TMulSOp [Src, Scalar, Dst] --- + DefaultInlineVector muls; + func.walk([&](mlir::pto::TMulSOp op) { muls.push_back(op); }); + for (auto op : muls) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, op->getOperand(0), op.getScalar(), + op->getOperand(kThirdOperandIndex)); + } + + // --- TAddOp [Src0, Src1, Dst] --- + DefaultInlineVector addops; + func.walk([&](mlir::pto::TAddOp op) { addops.push_back(op); }); + for (auto op : addops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex)); + } + + // --- TMatmulOp [Lhs, Rhs, Dst] (no optional bias in ODS) --- + DefaultInlineVector matmuls; + func.walk([&](mlir::pto::TMatmulOp op) { matmuls.push_back(op); }); + for (auto op : matmuls) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Value lhs = op->getOperand(0); + Value rhs = op->getOperand(1); + Value dst = op->getOperand(kThirdOperandIndex); + + rewriter.replaceOpWithNewOp( + op, TypeRange{}, lhs, rhs, dst, op.getAccPhaseAttr()); + } + + // 
--- TMatmulAccOp [Acc, Lhs, Rhs, Dst] --- + DefaultInlineVector matmulAccs; + func.walk([&](mlir::pto::TMatmulAccOp op) { matmulAccs.push_back(op); }); + for (auto op : matmulAccs) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex), + op->getOperand(kFourthOperandIndex), op.getAccPhaseAttr()); + } + + // --- TMatmulBiasOp [Acc, Lhs, Rhs, Bias, Dst] --- + DefaultInlineVector matmulBiass; + func.walk([&](mlir::pto::TMatmulBiasOp op) { matmulBiass.push_back(op); }); + for (auto op : matmulBiass) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex), + op->getOperand(kFourthOperandIndex)); + } + + // --- TMatmulMxOp--- + DefaultInlineVector matmulMxs; + func.walk([&](mlir::pto::TMatmulMxOp op) { matmulMxs.push_back(op); }); + for (auto op : matmulMxs) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex), + op->getOperand(kFourthOperandIndex), + op->getOperand(kFifthOperandIndex)); + } + + // --- TMatmulMxAccOp --- + DefaultInlineVector matmulMxAccs; + func.walk([&](mlir::pto::TMatmulMxAccOp op) { matmulMxAccs.push_back(op); }); + for (auto op : matmulMxAccs) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex), + op->getOperand(kFourthOperandIndex), + op->getOperand(kFifthOperandIndex), + op->getOperand(kSixthOperandIndex)); + } + + // --- TMatmulMxBiasOp --- + DefaultInlineVector matmulMxBiass; + func.walk([&](mlir::pto::TMatmulMxBiasOp op) { matmulMxBiass.push_back(op); }); + for (auto op : matmulMxBiass) { + IRRewriter 
rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex), + op->getOperand(kFourthOperandIndex), + op->getOperand(kFifthOperandIndex), + op->getOperand(kSixthOperandIndex)); + } + + // --- TGemvOp [Lhs, Rhs, Dst] --- + DefaultInlineVector gemvs; + func.walk([&](mlir::pto::TGemvOp op) { gemvs.push_back(op); }); + for (auto op : gemvs) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value lhs = op->getOperand(0); + Value rhs = op->getOperand(1); + Value dst = op->getOperand(kThirdOperandIndex); + + rewriter.replaceOpWithNewOp( + op, TypeRange{}, lhs, rhs, dst); + } + + // --- TGemvAccOp [Acc, Lhs, Rhs, Dst] --- + DefaultInlineVector gemvAccs; + func.walk([&](mlir::pto::TGemvAccOp op) { gemvAccs.push_back(op); }); + for (auto op : gemvAccs) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex), + op->getOperand(kFourthOperandIndex)); + } + + // --- TGemvBiasOp [Acc, Lhs, Rhs, Bias, Dst] --- + DefaultInlineVector gemvBiass; + func.walk([&](mlir::pto::TGemvBiasOp op) { gemvBiass.push_back(op); }); + for (auto op : gemvBiass) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex), + op->getOperand(kFourthOperandIndex)); + } + + // --- TGemvMxOp [A, AScale, B, BScale, Dst] --- + DefaultInlineVector gemvMxs; + func.walk([&](mlir::pto::TGemvMxOp op) { gemvMxs.push_back(op); }); + for (auto op : gemvMxs) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex), + op->getOperand(kFourthOperandIndex), + op->getOperand(kFifthOperandIndex)); + 
} + + // --- TGemvMxAccOp [CIn, A, AScale, B, BScale, Dst] --- + DefaultInlineVector gemvMxAccs; + func.walk([&](mlir::pto::TGemvMxAccOp op) { gemvMxAccs.push_back(op); }); + for (auto op : gemvMxAccs) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex), + op->getOperand(kFourthOperandIndex), + op->getOperand(kFifthOperandIndex), + op->getOperand(kSixthOperandIndex)); + } + + // --- TGemvMxBiasOp [A, AScale, B, BScale, Bias, Dst] --- + DefaultInlineVector gemvMxBiass; + func.walk([&](mlir::pto::TGemvMxBiasOp op) { gemvMxBiass.push_back(op); }); + for (auto op : gemvMxBiass) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex), + op->getOperand(kFourthOperandIndex), + op->getOperand(kFifthOperandIndex), + op->getOperand(kSixthOperandIndex)); + } + + // --- TMovOp [Src, Dst] --- + DefaultInlineVector movs; + func.walk([&](mlir::pto::TMovOp op) { movs.push_back(op); }); + for (auto op : movs) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, op.getSrc(), op.getDst(), op.getFp(), + op.getPreQuantScalar(), op.getAccToVecModeAttr(), + op.getReluPreModeAttr()); + } + + DefaultInlineVector abseops; + func.walk([&](mlir::pto::TAbsOp op) { abseops.push_back(op); }); + + for (auto op : abseops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst); + } + + DefaultInlineVector addcops; + func.walk([&](mlir::pto::TAddCOp 
op) { addcops.push_back(op); }); + + for (auto op : addcops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value src2 = op.getSrc2(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto src2Ty = dyn_cast(src2.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !src2Ty ||!dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src0, + src1, + src2, + dst); + } + + DefaultInlineVector addsops; + func.walk([&](mlir::pto::TAddSOp op) { addsops.push_back(op); }); + + for (auto op : addsops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value scalar = op.getScalar(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + scalar, + dst); + } + + DefaultInlineVector addscops; + func.walk([&](mlir::pto::TAddSCOp op) { addscops.push_back(op); }); + + for (auto op : addscops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value scalar = op.getScalar(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src0, + scalar, + src1, + dst); + } + + DefaultInlineVector andops; + func.walk([&](mlir::pto::TAndOp op) { andops.push_back(op); }); + + for (auto op : andops) { + IRRewriter rewriter(ctx); 
+ rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src0, + src1, + dst); + } + + DefaultInlineVector concats; + func.walk([&](mlir::pto::TConcatOp op) { concats.push_back(op); }); + + for (auto op : concats) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src0, + src1, + dst); + } + + DefaultInlineVector concatIdxs; + func.walk([&](mlir::pto::TConcatidxOp op) { concatIdxs.push_back(op); }); + + IRRewriter rewriter(ctx); + for (auto op : concatIdxs) { + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value src0Idx = op.getSrc0Idx(); + Value src1Idx = op.getSrc1Idx(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto src0IdxTy = dyn_cast(src0Idx.getType()); + auto src1IdxTy = dyn_cast(src1Idx.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !src0IdxTy || !src1IdxTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src0, + src1, + src0Idx, + src1Idx, + dst); + } + + DefaultInlineVector andsops; + func.walk([&](mlir::pto::TAndSOp op) { andsops.push_back(op); 
}); + + for (auto op : andsops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value scalar = op.getScalar(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + scalar, + dst); + } + + DefaultInlineVector ciops; + func.walk([&](mlir::pto::TCIOp op) { ciops.push_back(op); }); + + for (auto op : ciops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value s = op->getOperand(0); + Value dst = op.getDst(); + bool descending = op.getDescending(); + + auto sTy = dyn_cast(s.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!sTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + s, + dst, + descending); + } + + DefaultInlineVector cmpops; + func.walk([&](mlir::pto::TCmpOp op) { cmpops.push_back(op); }); + + for (auto op : cmpops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + auto newOp = rewriter.create( + op.getLoc(), + TypeRange{}, + src0, + src1, + dst); + + if (auto a = op.getCmpModeAttr()) + newOp->setAttr("cmpMode", a); + + rewriter.replaceOp(op, newOp->getResults()); // 0 results -> OK + } + + DefaultInlineVector cmpsops; + func.walk([&](mlir::pto::TCmpSOp op) { cmpsops.push_back(op); }); + + for (auto op : cmpsops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value 
scalar = op.getScalar(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + auto scalarTy = scalar.getType(); + bool scalarOk = + isa(scalarTy); // ScalarType in ODS: int/float + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + if (!scalarOk) { + op.emitError("expects scalar to be an integer or float type"); + signalPassFailure(); + return; + } + + auto cmpMode = op.getCmpModeAttr(); + auto newOp = rewriter.create( + op.getLoc(), + TypeRange{}, + src, + scalar, + cmpMode, + dst); + + rewriter.replaceOp(op, newOp->getResults()); // 0 results -> OK + } + + DefaultInlineVector colexpand; + func.walk([&](mlir::pto::TColExpandOp op) { colexpand.push_back(op); }); + + for (auto op : colexpand) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if ( !srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst); + } + + DefaultInlineVector colmaxops; + func.walk([&](mlir::pto::TColMaxOp op) { colmaxops.push_back(op); }); + + for (auto op : colmaxops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if ( !srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst); + } + + DefaultInlineVector colminops; + func.walk([&](mlir::pto::TColMinOp op) { colminops.push_back(op); }); + + for (auto op : colminops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + + auto srcTy 
= dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if ( !srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst); + } + + DefaultInlineVector colexpandmulops; + func.walk([&](mlir::pto::TColExpandMulOp op) { + colexpandmulops.push_back(op); + }); + + for (auto op : colexpandmulops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src0, + src1, + dst); + } + + DefaultInlineVector colexpandmaxops; + func.walk([&](mlir::pto::TColExpandMaxOp op) { + colexpandmaxops.push_back(op); + }); + + for (auto op : colexpandmaxops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src0, + src1, + dst); + } + + DefaultInlineVector colexpandminops; + func.walk([&](mlir::pto::TColExpandMinOp op) { + colexpandminops.push_back(op); + }); + + for (auto op : colexpandminops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty 
|| !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src0, + src1, + dst); + } + + DefaultInlineVector colsumops; + func.walk([&](mlir::pto::TColSumOp op) { colsumops.push_back(op); }); + + for (auto op : colsumops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + Value tmp = op.getTmp(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !dstTy) { + op.emitError("src/dst are not memref yet"); + signalPassFailure(); + return; + } + + // If tmp exists, it must have isBinary attribute + if (tmp) { + auto tmpTy = dyn_cast(tmp.getType()); + if (!tmpTy) { + op.emitError("tmp is not memref yet"); + signalPassFailure(); + return; + } + + // Get isBinary attribute (should exist if tmp exists) + BoolAttr isBinaryAttr = op.getIsBinaryAttr(); + if (!isBinaryAttr) { + isBinaryAttr = BoolAttr::get(ctx, false); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + tmp, + dst, + isBinaryAttr); + } else { + // Format 1: no tmp, no isBinary + // Use generic builder to avoid adding default isBinary attribute + SmallVector operands = {src, dst}; + SmallVector attrs; + // Copy all attributes except isBinary + for (auto attr : op->getAttrs()) { + if (attr.getName() != "isBinary") { + attrs.push_back(attr); + } + } + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + operands, + attrs); + } + } + + DefaultInlineVector cvtops; + func.walk([&](mlir::pto::TCvtOp op) { cvtops.push_back(op); }); + + for (auto op : cvtops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + auto rmodeAttr = 
op.getRmodeAttr(); // PTO_RoundModeAttr + auto satModeAttr = op.getSatModeAttr(); + + auto newOp = rewriter.create( + op.getLoc(), + TypeRange{}, + src, + dst, + rmodeAttr, + satModeAttr); + + rewriter.replaceOp(op, newOp->getResults()); + } + + DefaultInlineVector divops; + func.walk([&](mlir::pto::TDivOp op) { divops.push_back(op); }); + + for (auto op : divops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src0, + src1, + dst); + } + + DefaultInlineVector divsops; + func.walk([&](mlir::pto::TDivSOp op) { divsops.push_back(op); }); + + for (auto op : divsops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value scale = op.getScalar(); + Value dst = op.getDst(); + + // Check types - they might still be TileBufType or already converted to MemRefType + auto srcTy = dyn_cast(src.getType()); + auto srcTileTy = dyn_cast(src.getType()); + auto scaleTileTy = dyn_cast(scale.getType()); + auto dstTy = dyn_cast(dst.getType()); + auto dstTileTy = dyn_cast(dst.getType()); + + // Determine which operand is tile-like and which is scalar-like. + // Keep the original operand order (set by parser textual form). 
+ // Check if src is memref/tensor/tile (not scalar) + bool srcIsMemref = (srcTy != nullptr || srcTileTy != nullptr || + isa(src.getType()) || + isa(src.getType())); + // Check if scale is memref/tensor/tile (not scalar) + bool scaleIsMemref = (isa(scale.getType()) || + scaleTileTy != nullptr || + isa(scale.getType()) || + isa(scale.getType())); + + // Type validation - ensure we have the right types + if (!srcIsMemref && !scaleIsMemref) { + op.emitError("at least one operand (src or scale) must be tile_buf or memref"); + signalPassFailure(); + return; + } + if (srcIsMemref && scaleIsMemref) { + op.emitError("exactly one operand (src or scale) must be tile_buf or memref, the other must be scalar"); + signalPassFailure(); + return; + } + + if (!dstTy && !dstTileTy) { + op.emitError("dst operand must be tile_buf or memref"); + signalPassFailure(); + return; + } + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + scale, + dst); + } + + DefaultInlineVector expandsops; + func.walk([&](mlir::pto::TExpandsOp op) { expandsops.push_back(op); }); + + for (auto op : expandsops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value scalar = op.getScalar(); + Value dst = op.getDst(); + + auto dstTy = dyn_cast(dst.getType()); + if (!dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + scalar, + dst); + } + + DefaultInlineVector extractops; + func.walk([&](mlir::pto::TExtractOp op) { extractops.push_back(op); }); + + for (auto op : extractops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value indexRow = op.getIndexRow(); + Value indexCol = op.getIndexCol(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto indexRowTy = dyn_cast(indexRow.getType()); + auto indexColTy = dyn_cast(indexCol.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !indexRowTy || !indexColTy || 
!dstTy) { + op.emitError("ins/outs are not correct yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + indexRow, + indexCol, + dst); + } + + DefaultInlineVector fillpadops; + func.walk([&](mlir::pto::TFillPadOp op) { fillpadops.push_back(op); }); + + for (auto op : fillpadops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst); + } + + DefaultInlineVector fillpadInplaceOps; + func.walk( + [&](mlir::pto::TFillPadInplaceOp op) { fillpadInplaceOps.push_back(op); }); + + for (auto op : fillpadInplaceOps) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst); + } + + // --- TSetValOp [Dst, Offset, Val] --- + // Lower tile-world scalar write to memref-world SETVAL DPS op. 
+ DefaultInlineVector tsetvalops; + func.walk([&](mlir::pto::TSetValOp op) { tsetvalops.push_back(op); }); + + for (auto op : tsetvalops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value dst = op.getDst(); + Value offset = op.getOffset(); + Value val = op.getVal(); + + auto dstTy = dyn_cast(dst.getType()); + if (!dstTy) { + op.emitError("dst is not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + dst, + offset, + val); + } + + // --- TGetValOp [Src, Offset] -> Scalar --- + // Lower tile-world scalar read to memref-world GETVAL DPS op. + DefaultInlineVector tgetvalops; + func.walk([&](mlir::pto::TGetValOp op) { tgetvalops.push_back(op); }); + + for (auto op : tgetvalops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value offset = op.getOffset(); + Type dstType = op.getDst().getType(); + + auto srcTy = dyn_cast(src.getType()); + if (!srcTy) { + op.emitError("src is not memref yet"); + signalPassFailure(); + return; + } + + auto newOp = rewriter.create( + op.getLoc(), + dstType, + src, + offset); + rewriter.replaceOp(op, newOp.getDst()); + } + + DefaultInlineVector gatherops; + func.walk([&](mlir::pto::TGatherOp op) { gatherops.push_back(op); }); + + for (auto op : gatherops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + Value cdst = op.getCdst(); + Value indices = op.getIndices(); + Value tmp = op.getTmp(); + Value kValue = op.getKValue(); + auto maskPattern = op.getMaskPatternAttr(); + auto cmpMode = op.getCmpModeAttr(); + auto offset = op.getOffsetAttr(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + if (maskPattern) { + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst, + /*cdst=*/Value(), + 
/*indices=*/Value(), + /*tmp=*/Value(), + /*kValue=*/Value(), + /*maskPattern=*/maskPattern, + /*cmpMode=*/pto::CmpModeAttr(), + /*offset=*/IntegerAttr()); + continue; + } + + if (cdst || kValue) { + auto cdstTy = dyn_cast(cdst.getType()); + auto tmpTy = dyn_cast(tmp.getType()); + if (!cdstTy || !tmpTy) { + op.emitError("compare-form tgather expects cdst/tmp to be memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst, + cdst, + /*indices=*/Value(), + tmp, + kValue, + /*maskPattern=*/pto::MaskPatternAttr(), + cmpMode, + offset); + continue; + } + + if (indices || tmp) { + auto indicesTy = dyn_cast(indices.getType()); + auto tmpTy = dyn_cast(tmp.getType()); + if (!indicesTy || !tmpTy) { + op.emitError("index-form tgather expects indices/tmp to be memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst, + /*cdst=*/Value(), + indices, + tmp, + /*kValue=*/Value(), + /*maskPattern=*/pto::MaskPatternAttr(), + /*cmpMode=*/pto::CmpModeAttr(), + /*offset=*/IntegerAttr()); + continue; + } + + op.emitError("expects tgather to be in mask, index+tmp, or compare+tmp form"); + signalPassFailure(); + return; + } + + DefaultInlineVector gatherbops; + func.walk([&](mlir::pto::TGatherBOp op) { gatherbops.push_back(op); }); + + for (auto op : gatherbops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value offsets = op.getOffsets(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto offsetsTy = dyn_cast(offsets.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !offsetsTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + offsets, + dst); + } + + DefaultInlineVector logops; + func.walk([&](mlir::pto::TLogOp op) { logops.push_back(op); }); + + for (auto op : logops) { 
+ IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst); + } + + DefaultInlineVector lreluops; + func.walk([&](mlir::pto::TLReluOp op) { lreluops.push_back(op); }); + + for (auto op : lreluops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value slope = op.getSlope(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto slopeTy = dyn_cast(slope.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !slopeTy || !dstTy) { + op.emitError("ins/outs are not correct type yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + slope, + dst); + } + + DefaultInlineVector maxops; + func.walk([&](mlir::pto::TMaxOp op) { maxops.push_back(op); }); + + for (auto op : maxops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src0, + src1, + dst); + } + + DefaultInlineVector maxsops; + func.walk([&](mlir::pto::TMaxSOp op) { maxsops.push_back(op); }); + + for (auto op : maxsops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value scalar = op.getScalar(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + bool scalarIsScalar = 
isa(scalar.getType()); + if (!srcTy || !scalarIsScalar || !dstTy) { + op.emitError("expects src/dst to be memref and scalar to be integer/float"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + scalar, + dst); + } + + DefaultInlineVector minops; + func.walk([&](mlir::pto::TMinOp op) { minops.push_back(op); }); + + for (auto op : minops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src0, + src1, + dst); + } + + DefaultInlineVector minsops; + func.walk([&](mlir::pto::TMinSOp op) { minsops.push_back(op); }); + + for (auto op : minsops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value scalar = op.getScalar(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + bool scalarIsScalar = isa(scalar.getType()); + if (!srcTy || !scalarIsScalar || !dstTy) { + op.emitError("expects src/dst to be memref and scalar to be integer/float"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + scalar, + dst); + } + + DefaultInlineVector movfpops; + func.walk([&](mlir::pto::TMovFPOp op) { movfpops.push_back(op); }); + + for (auto op : movfpops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value fp = op.getFp(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto fpTy = dyn_cast(fp.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !fpTy || !dstTy) { + op.emitError("ins/outs are not 
memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + fp, + dst); + } + + DefaultInlineVector quantops; + func.walk([&](mlir::pto::TQuantOp op) { quantops.push_back(op); }); + + for (auto op : quantops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value fp = op.getFp(); + Value offset = op.getOffset(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto fpTy = dyn_cast(fp.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !fpTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + if (offset && !dyn_cast(offset.getType())) { + op.emitError("offset is not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + fp, + offset, + dst, + op.getQuantTypeAttr()); + } + + DefaultInlineVector mrgsortops; + func.walk([&](mlir::pto::TMrgSortOp op) { mrgsortops.push_back(op); }); + + for (auto op : mrgsortops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + if (op.isFormat1()) { + Value src = op.getSrc(); + Value dst = op.getDst(); + Value blockLenVal = op.getBlockLen(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + ValueRange{src}, + blockLenVal, + ValueRange{dst}, + Value() /*tmp*/, + Value() /*excuted*/, + op.getExhaustedAttr()); + } else if (op.isFormat2()) { + bool allMemRef = true; + for (Value v : op.getSrcs()) + if (!dyn_cast(v.getType())) { allMemRef = false; break; } + if (!allMemRef) { + op.emitError("format2 ins/outs are not memref yet"); + signalPassFailure(); + return; + } + if (op.getDsts().size() != 1u || !op.getTmp()) { + op.emitError("format2 expects outs(dst) and ins(tmp)"); + 
signalPassFailure(); + return; + } + + Value dst = op.getDst(); + Value tmp = op.getTmp(); + Value excuted = op.getExcuted(); + if (!dyn_cast(dst.getType()) || !dyn_cast(tmp.getType())) { + op.emitError("format2 dst/tmp must be memref"); + signalPassFailure(); + return; + } + if (!dyn_cast(excuted.getType())) { + op.emitError("format2 outs(excuted) must be vector"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + op.getSrcs(), + Value() /*blockLen*/, + ValueRange{dst}, + tmp, + excuted, + op.getExhaustedAttr()); + } else { + op.emitError("tmrgsort must be format1 or format2"); + signalPassFailure(); + return; + } + } + + DefaultInlineVector negops; + func.walk([&](mlir::pto::TNegOp op) { negops.push_back(op); }); + + for (auto op : negops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst); + } + + DefaultInlineVector notops; + func.walk([&](mlir::pto::TNotOp op) { notops.push_back(op); }); + + for (auto op : notops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst); + } + + DefaultInlineVector orops; + func.walk([&](mlir::pto::TOrOp op) { orops.push_back(op); }); + + for (auto op : orops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = 
dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + src0, + src1, + dst); + } + + DefaultInlineVector orsops; + func.walk([&](mlir::pto::TOrSOp op) { orsops.push_back(op); }); + + for (auto op : orsops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value scalar = op.getScalar(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto scalarTy = dyn_cast(scalar.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !scalarTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + scalar, + dst); + } + + DefaultInlineVector partaddops; + func.walk([&](mlir::pto::TPartAddOp op) { partaddops.push_back(op); }); + + for (auto op : partaddops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + src0, + src1, + dst); + } + + DefaultInlineVector partmulops; + func.walk([&](mlir::pto::TPartMulOp op) { partmulops.push_back(op); }); + + for (auto op : partmulops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs 
are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + src0, + src1, + dst); + } + + DefaultInlineVector mgatherops; + func.walk([&](mlir::pto::MGatherOp op) { mgatherops.push_back(op); }); + + for (auto op : mgatherops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value dst = op.getDst(); + Value idx = op.getIdx(); + Value mem = op.getMem(); + + auto dstTy = dyn_cast(dst.getType()); + auto idxTy = dyn_cast(idx.getType()); + auto memTy = dyn_cast(mem.getType()); + if (!dstTy || !idxTy || !memTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + mem, + idx, + dst, + op.getGatherOobAttr()); + } + + DefaultInlineVector mascatterops; + func.walk([&](mlir::pto::MScatterOp op) { mascatterops.push_back(op); }); + + for (auto op : mascatterops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value idx = op.getIdx(); + Value mem = op.getMem(); + + auto srcTy = dyn_cast(src.getType()); + auto idxTy = dyn_cast(idx.getType()); + auto memTy = dyn_cast(mem.getType()); + if (!srcTy || !idxTy || !memTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + idx, + mem, + op.getScatterAtomicOpAttr(), + op.getScatterOobAttr()); + } + DefaultInlineVector printops; + func.walk([&](mlir::pto::TPrintOp op) { printops.push_back(op); }); + + for (auto op : printops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + + auto srcTy = dyn_cast(src.getType()); + if (!srcTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src); + } + + // ------------------------------------------------------------------ + // Stage 4: Reconcile control-flow result types + // 
------------------------------------------------------------------ + if (failed(reconcileSCFIfResultTypes(func))) { + signalPassFailure(); + return; + } + if (failed(reconcileSCFForResultTypes(func))) { + signalPassFailure(); + return; + } + + // Mark memref-form set_validshape only after control-flow result-type + // reconciliation. Values such as scf.if results can stay tile_buf until + // this late stage. + if (failed(markLoweredSetValidShapeOps(func, ctx))) { + signalPassFailure(); + return; + } + } + + // Debug Output + LLVM_DEBUG(llvm::dbgs() << mod.getOperation()); + } +}; + +} // namespace + +std::unique_ptr createPTOViewToMemrefPass() { + return std::make_unique(); +} + +} // namespace pto +} // namespace mlir diff --git a/tools/ptobc/generated/ptobc_opcodes_v0.def b/tools/ptobc/generated/ptobc_opcodes_v0.def new file mode 100644 index 000000000..8303e1261 --- /dev/null +++ b/tools/ptobc/generated/ptobc_opcodes_v0.def @@ -0,0 +1,722 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// Generated by docs/bytecode/tools/gen_v0_tables.py +#pragma once + +#include +#include + +#include +#include + +namespace ptobc::v0 { + +inline constexpr uint8_t kVariantDefault = 0; +inline constexpr uint8_t kVariantAcc = 1; +inline constexpr uint8_t kVariantBias = 2; +inline constexpr uint8_t kVariantMx = 3; +inline constexpr uint8_t kVariantMxAcc = 4; +inline constexpr uint8_t kVariantMxBias = 5; +inline constexpr uint8_t kSectionCubeVariant = 0; +inline constexpr uint8_t kSectionVectorVariant = 1; +inline constexpr uint8_t kHasVariant = 1; +inline constexpr uint16_t kTscatterMaskOpcode = 0x109C; + +inline constexpr int kTgemvOperandCount = 3; +inline constexpr int kTgemvAccOperandCount = 4; +inline constexpr int kTgemvBiasOperandCount = 4; +inline constexpr int kTgemvMxOperandCount = 5; +inline constexpr int kTgemvMxAccOperandCount = 6; +inline constexpr int kTgemvMxBiasOperandCount = 6; +inline constexpr int kTmatmulOperandCount = 3; +inline constexpr int kTmatmulAccOperandCount = 4; +inline constexpr int kTmatmulBiasOperandCount = 4; +inline constexpr int kTmatmulMxOperandCount = 5; +inline constexpr int kTmatmulMxAccOperandCount = 6; +inline constexpr int kTmatmulMxBiasOperandCount = 6; + +struct OpInfo { + uint16_t opcode; + const char *name; + uint8_t has_variant_u8; + uint8_t result_type_mode; + uint8_t operand_mode; + uint16_t num_operands; + uint16_t num_results; + uint16_t num_regions; + uint8_t imm_kind; +}; + +inline constexpr OpInfo kOpTable[] = { + {0x0000, "pto.get_block_idx", 0, 0x01, 0x00, 0, 1, 0, 0x00}, + {0x0001, "pto.get_block_num", 0, 0x01, 0x00, 0, 1, 0, 0x00}, + {0x0002, "pto.get_subblock_idx", 0, 0x01, 0x00, 0, 1, 0, 0x00}, + {0x0003, "pto.get_subblock_num", 0, 0x01, 0x00, 0, 1, 0, 0x00}, + {0x0004, "pto.make_tensor_view", 0, 0x01, 0x03, 1, 1, 0, 0x06}, + {0x0005, "pto.partition_view", 0, 0x01, 0x03, 1, 1, 0, 0x07}, + {0x0006, "pto.section", 1, 0x00, 0x00, 0, 0, 1, 0x00}, + {0x1000, "pto.addptr", 0, 0x01, 0x00, 2, 1, 0, 0x00}, + 
{0x1001, "pto.alloc_tile", 0, 0x01, 0x04, 0, 1, 0, 0x08}, + {0x1002, "pto.barrier", 0, 0x00, 0x00, 0, 0, 0, 0x00}, + {0x1003, "pto.mgather", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1004, "pto.mscatter", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1005, "pto.record_event", 0, 0x00, 0x00, 0, 0, 0, 0x02}, + {0x1006, "pto.tabs", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1007, "pto.tadd", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1008, "pto.taddc", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x1009, "pto.tadds", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x100A, "pto.taddsc", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x100B, "pto.tand", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x100C, "pto.tands", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x100D, "pto.tci", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x100E, "pto.tcmp", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x100F, "pto.tcmps", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1010, "pto.tcolexpand", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1011, "pto.tcolexpandadd", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1012, "pto.tcolexpanddiv", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1013, "pto.tcolexpandexpdif", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1014, "pto.tcolexpandmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1015, "pto.tcolexpandmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1016, "pto.tcolexpandmul", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1017, "pto.tcolexpandsub", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1018, "pto.tcolmax", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1019, "pto.tcolmin", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x101A, "pto.tcolprod", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x101B, "pto.tcolsum", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x101C, "pto.tcvt", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x101D, "pto.tdiv", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x101E, "pto.tdivs", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x101F, "pto.texp", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1020, "pto.texpands", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1021, "pto.textract", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x1022, "pto.textract_fp", 0, 0x00, 0x00, 5, 0, 0, 0x00}, + {0x1023, "pto.tfillpad", 0, 0x00, 
0x00, 2, 0, 0, 0x00}, + {0x1024, "pto.tfillpad_expand", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1025, "pto.tfillpad_inplace", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1026, "pto.tfmod", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1027, "pto.tfmods", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1028, "pto.tgather", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1029, "pto.tgatherb", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x102A, "pto.tgemv", 1, 0x00, 0x01, 0, 0, 0, 0x00}, + {0x102B, "pto.tgetval", 0, 0x01, 0x00, 2, 1, 0, 0x00}, + {0x102C, "pto.timg2col", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x102D, "pto.tinsert", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x102E, "pto.tinsert_fp", 0, 0x00, 0x00, 5, 0, 0, 0x00}, + {0x102F, "pto.tload", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1030, "pto.tlog", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1031, "pto.tlrelu", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1032, "pto.tmatmul", 1, 0x00, 0x01, 0, 0, 0, 0x00}, + {0x1033, "pto.tmatmul.mx", 1, 0x00, 0x01, 0, 0, 0, 0x00}, + {0x1034, "pto.tmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1035, "pto.tmaxs", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1036, "pto.tmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1037, "pto.tmins", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1038, "pto.tmov", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1039, "pto.tmov.fp", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x103A, "pto.tmrgsort", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x103B, "pto.tmul", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x103C, "pto.tmuls", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x103D, "pto.tneg", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x103E, "pto.tnot", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x103F, "pto.tor", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1040, "pto.tors", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1041, "pto.tpartadd", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1042, "pto.tpartmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1043, "pto.tpartmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1044, "pto.tpartmul", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1045, "pto.tprefetch", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1046, "pto.tprelu", 0, 0x00, 0x00, 4, 0, 0, 
0x00}, + {0x1047, "pto.tquant", 0, 0x00, 0x02, 3, 0, 0, 0x00}, + {0x1048, "pto.trecip", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1049, "pto.trelu", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x104A, "pto.trem", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x104B, "pto.trems", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x104C, "pto.treshape", 0, 0x01, 0x00, 1, 1, 0, 0x00}, + {0x104D, "pto.trowexpand", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x104E, "pto.trowexpandadd", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x104F, "pto.trowexpandexpdif", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1050, "pto.trowexpandmax", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1051, "pto.trowexpandmin", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1052, "pto.trowmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1053, "pto.trowmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1054, "pto.trowsum", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1055, "pto.trsqrt", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1056, "pto.tscatter", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1057, "pto.tsel", 0, 0x00, 0x00, 5, 0, 0, 0x00}, + {0x1058, "pto.tsels", 0, 0x00, 0x00, 5, 0, 0, 0x00}, + {0x1059, "pto.tset_img2col_padding", 0, 0x00, 0x00, 1, 0, 0, 0x00}, + {0x105A, "pto.tset_img2col_rpt", 0, 0x00, 0x00, 1, 0, 0, 0x00}, + {0x105B, "pto.tsetfmatrix", 0, 0x00, 0x00, 1, 0, 0, 0x00}, + {0x105C, "pto.tsethf32mode", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x105D, "pto.tsettf32mode", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x105E, "pto.tsetval", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x105F, "pto.tshl", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1060, "pto.tshls", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1061, "pto.tshr", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1062, "pto.tshrs", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1063, "pto.tsort32", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1064, "pto.tsqrt", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1065, "pto.tstore", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1066, "pto.tstore_fp", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1067, "pto.tsub", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1068, "pto.tsubc", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x1069, 
"pto.tsubs", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x106A, "pto.tsubsc", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x106B, "pto.trowexpandsub", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x106C, "pto.ttrans", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x106D, "pto.ttri", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x106E, "pto.txor", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x106F, "pto.txors", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x1070, "pto.wait_event", 0, 0x00, 0x00, 0, 0, 0, 0x02}, + {0x1071, "pto.tprint", 0, 0x00, 0x00, 1, 0, 0, 0x00}, + {0x1072, "pto.subview", 0, 0x01, 0x02, 0, 1, 0, 0x00}, + {0x1073, "pto.trowexpanddiv", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1074, "pto.trowexpandmul", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1075, "pto.tdequant", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x1076, "pto.taxpy", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1077, "pto.thistogram", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1078, "pto.tget_scale_addr", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1079, "pto.trowargmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x107A, "pto.trowargmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x107B, "pto.tcolargmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x107C, "pto.tcolargmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x107D, "pto.tsync", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x107E, "pto.reserve_buffer", 0, 0x01, 0x00, 0, 1, 0, 0x00}, + {0x107F, "pto.import_reserved_buffer", 0, 0x01, 0x00, 0, 1, 0, 0x00}, + {0x1080, "pto.aic_initialize_pipe", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1081, "pto.aiv_initialize_pipe", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1082, "pto.tpush_to_aiv", 0, 0x00, 0x00, 1, 0, 0, 0x00}, + {0x1083, "pto.tpush_to_aic", 0, 0x00, 0x00, 1, 0, 0, 0x00}, + {0x1084, "pto.tpop_from_aic", 0, 0x01, 0x02, 0, 1, 0, 0x00}, + {0x1085, "pto.tpop_from_aiv", 0, 0x01, 0x02, 0, 1, 0, 0x00}, + {0x1086, "pto.tfree_from_aic", 0, 0x00, 0x00, 0, 0, 0, 0x00}, + {0x1087, "pto.tfree_from_aiv", 0, 0x00, 0x00, 0, 0, 0, 0x00}, + {0x1088, "pto.set_validshape", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1089, "pto.tconcat", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x108A, 
"pto.trowprod", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x108B, "pto.initialize_l2g2l_pipe", 0, 0x01, 0x02, 0, 1, 0, 0x00}, + {0x108C, "pto.initialize_l2l_pipe", 0, 0x01, 0x02, 0, 1, 0, 0x00}, + {0x108D, "pto.tpush", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x108E, "pto.declare_tile", 0, 0x01, 0x00, 0, 1, 0, 0x00}, + {0x108F, "pto.tpop", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1090, "pto.tfree", 0, 0x00, 0x00, 1, 0, 0, 0x00}, + {0x1091, "pto.comm.tput", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1092, "pto.comm.tget", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1093, "pto.comm.tnotify", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1094, "pto.comm.twait", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1095, "pto.comm.ttest", 0, 0x01, 0x02, 0, 1, 0, 0x00}, + {0x1096, "pto.comm.tbroadcast", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1097, "pto.comm.tgather", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1098, "pto.comm.tscatter", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1099, "pto.comm.treduce", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x109A, "pto.tpartargmax", 0, 0x00, 0x00, 6, 0, 0, 0x00}, + {0x109B, "pto.tpartargmin", 0, 0x00, 0x00, 6, 0, 0, 0x00}, + {0x109C, "pto.tscatter.maskpattern", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x2000, "arith.addi", 0, 0x01, 0x00, 2, 1, 0, 0x00}, + {0x2001, "arith.ceildivsi", 0, 0x01, 0x00, 2, 1, 0, 0x00}, + {0x2002, "arith.cmpi", 0, 0x01, 0x00, 2, 1, 0, 0x01}, + {0x2003, "arith.constant", 0, 0x01, 0x00, 0, 1, 0, 0x05}, + {0x2004, "arith.index_cast", 0, 0x01, 0x00, 1, 1, 0, 0x00}, + {0x2005, "arith.minui", 0, 0x01, 0x00, 2, 1, 0, 0x00}, + {0x2006, "arith.muli", 0, 0x01, 0x00, 2, 1, 0, 0x00}, + {0x2007, "arith.select", 0, 0x01, 0x00, 3, 1, 0, 0x00}, + {0x2008, "arith.subi", 0, 0x01, 0x00, 2, 1, 0, 0x00}, + {0x4000, "scf.for", 0, 0x00, 0x00, 3, 0, 1, 0x00}, + {0x4001, "scf.if", 0, 0x00, 0x00, 1, 0, 2, 0x00}, + {0x4002, "scf.yield", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x6000, "func.func", 0, 0x00, 0x00, 0, 0, 0, 0x00}, + {0x6001, "func.return", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x6002, "func.call", 0, 0x02, 0x02, 0, 0, 0, 
0x00}, +}; + +inline const OpInfo *lookupByOpcode(uint16_t opcode) { + // Binary search on kOpTable (sorted by opcode). + size_t lo = 0, hi = sizeof(kOpTable) / sizeof(kOpTable[0]); + while (lo < hi) { + size_t mid = lo + (hi - lo) / 2; + uint16_t v = kOpTable[mid].opcode; + if (v == opcode) return &kOpTable[mid]; + if (v < opcode) lo = mid + 1; else hi = mid; + } + return nullptr; +} + +inline std::optional lookupOpcodeByName(llvm::StringRef name) { + uint16_t v = llvm::StringSwitch(name) + .Case("arith.addi", 0x2000) + .Case("arith.ceildivsi", 0x2001) + .Case("arith.cmpi", 0x2002) + .Case("arith.constant", 0x2003) + .Case("arith.index_cast", 0x2004) + .Case("arith.minui", 0x2005) + .Case("arith.muli", 0x2006) + .Case("arith.select", 0x2007) + .Case("arith.subi", 0x2008) + .Case("func.func", 0x6000) + .Case("func.return", 0x6001) + .Case("func.call", 0x6002) + .Case("pto.addptr", 0x1000) + .Case("pto.alloc_tile", 0x1001) + .Case("pto.barrier", 0x1002) + .Case("pto.get_block_idx", 0x0000) + .Case("pto.get_block_num", 0x0001) + .Case("pto.get_subblock_idx", 0x0002) + .Case("pto.get_subblock_num", 0x0003) + .Case("pto.make_tensor_view", 0x0004) + .Case("pto.mgather", 0x1003) + .Case("pto.mscatter", 0x1004) + .Case("pto.partition_view", 0x0005) + .Case("pto.record_event", 0x1005) + .Case("pto.section", 0x0006) + .Case("pto.tabs", 0x1006) + .Case("pto.tadd", 0x1007) + .Case("pto.taddc", 0x1008) + .Case("pto.tadds", 0x1009) + .Case("pto.taddsc", 0x100A) + .Case("pto.tand", 0x100B) + .Case("pto.tands", 0x100C) + .Case("pto.tci", 0x100D) + .Case("pto.tcmp", 0x100E) + .Case("pto.tcmps", 0x100F) + .Case("pto.tcolexpand", 0x1010) + .Case("pto.tcolexpandadd", 0x1011) + .Case("pto.tcolexpanddiv", 0x1012) + .Case("pto.tcolexpandexpdif", 0x1013) + .Case("pto.tcolexpandmax", 0x1014) + .Case("pto.tcolexpandmin", 0x1015) + .Case("pto.tcolexpandmul", 0x1016) + .Case("pto.tcolexpandsub", 0x1017) + .Case("pto.tcolmax", 0x1018) + .Case("pto.tcolmin", 0x1019) + .Case("pto.tcolprod", 
0x101A) + .Case("pto.tcolsum", 0x101B) + .Case("pto.tcvt", 0x101C) + .Case("pto.tdiv", 0x101D) + .Case("pto.tdivs", 0x101E) + .Case("pto.texp", 0x101F) + .Case("pto.texpands", 0x1020) + .Case("pto.textract", 0x1021) + .Case("pto.textract_fp", 0x1022) + .Case("pto.tfillpad", 0x1023) + .Case("pto.tfillpad_expand", 0x1024) + .Case("pto.tfillpad_inplace", 0x1025) + .Case("pto.tfmod", 0x1026) + .Case("pto.tfmods", 0x1027) + .Case("pto.tgather", 0x1028) + .Case("pto.tgatherb", 0x1029) + .Case("pto.tgemv", 0x102A) + .Case("pto.tgetval", 0x102B) + .Case("pto.timg2col", 0x102C) + .Case("pto.tinsert", 0x102D) + .Case("pto.tinsert_fp", 0x102E) + .Case("pto.tload", 0x102F) + .Case("pto.tlog", 0x1030) + .Case("pto.tlrelu", 0x1031) + .Case("pto.tmatmul", 0x1032) + .Case("pto.tmatmul.mx", 0x1033) + .Case("pto.tmax", 0x1034) + .Case("pto.tmaxs", 0x1035) + .Case("pto.tmin", 0x1036) + .Case("pto.tmins", 0x1037) + .Case("pto.tmov", 0x1038) + .Case("pto.tmov.fp", 0x1039) + .Case("pto.tmrgsort", 0x103A) + .Case("pto.tmul", 0x103B) + .Case("pto.tmuls", 0x103C) + .Case("pto.tneg", 0x103D) + .Case("pto.tnot", 0x103E) + .Case("pto.tor", 0x103F) + .Case("pto.tors", 0x1040) + .Case("pto.tpartadd", 0x1041) + .Case("pto.tpartmax", 0x1042) + .Case("pto.tpartmin", 0x1043) + .Case("pto.tpartmul", 0x1044) + .Case("pto.tprefetch", 0x1045) + .Case("pto.tprelu", 0x1046) + .Case("pto.tquant", 0x1047) + .Case("pto.trecip", 0x1048) + .Case("pto.trelu", 0x1049) + .Case("pto.trem", 0x104A) + .Case("pto.trems", 0x104B) + .Case("pto.treshape", 0x104C) + .Case("pto.trowexpand", 0x104D) + .Case("pto.trowexpandadd", 0x104E) + .Case("pto.trowexpandexpdif", 0x104F) + .Case("pto.trowexpandmax", 0x1050) + .Case("pto.trowexpandmin", 0x1051) + .Case("pto.trowmax", 0x1052) + .Case("pto.trowmin", 0x1053) + .Case("pto.trowsum", 0x1054) + .Case("pto.trsqrt", 0x1055) + .Case("pto.tscatter", 0x1056) + .Case("pto.tsel", 0x1057) + .Case("pto.tsels", 0x1058) + .Case("pto.tset_img2col_padding", 0x1059) + 
.Case("pto.tset_img2col_rpt", 0x105A) + .Case("pto.tsetfmatrix", 0x105B) + .Case("pto.tsethf32mode", 0x105C) + .Case("pto.tsettf32mode", 0x105D) + .Case("pto.tsetval", 0x105E) + .Case("pto.tshl", 0x105F) + .Case("pto.tshls", 0x1060) + .Case("pto.tshr", 0x1061) + .Case("pto.tshrs", 0x1062) + .Case("pto.tsort32", 0x1063) + .Case("pto.tsqrt", 0x1064) + .Case("pto.tstore", 0x1065) + .Case("pto.tstore_fp", 0x1066) + .Case("pto.tsub", 0x1067) + .Case("pto.tsubc", 0x1068) + .Case("pto.tsubs", 0x1069) + .Case("pto.tsubsc", 0x106A) + .Case("pto.trowexpandsub", 0x106B) + .Case("pto.ttrans", 0x106C) + .Case("pto.ttri", 0x106D) + .Case("pto.txor", 0x106E) + .Case("pto.txors", 0x106F) + .Case("pto.wait_event", 0x1070) + .Case("pto.tprint", 0x1071) + .Case("pto.subview", 0x1072) + .Case("pto.trowexpanddiv", 0x1073) + .Case("pto.trowexpandmul", 0x1074) + .Case("pto.tdequant", 0x1075) + .Case("pto.taxpy", 0x1076) + .Case("pto.thistogram", 0x1077) + .Case("pto.tget_scale_addr", 0x1078) + .Case("pto.trowargmax", 0x1079) + .Case("pto.trowargmin", 0x107A) + .Case("pto.tcolargmax", 0x107B) + .Case("pto.tcolargmin", 0x107C) + .Case("pto.tsync", 0x107D) + .Case("pto.reserve_buffer", 0x107E) + .Case("pto.import_reserved_buffer", 0x107F) + .Case("pto.aic_initialize_pipe", 0x1080) + .Case("pto.aiv_initialize_pipe", 0x1081) + .Case("pto.tpush_to_aiv", 0x1082) + .Case("pto.tpush_to_aic", 0x1083) + .Case("pto.tpop_from_aic", 0x1084) + .Case("pto.tpop_from_aiv", 0x1085) + .Case("pto.tfree_from_aic", 0x1086) + .Case("pto.tfree_from_aiv", 0x1087) + .Case("pto.set_validshape", 0x1088) + .Case("pto.tconcat", 0x1089) + .Case("pto.trowprod", 0x108A) + .Case("pto.initialize_l2g2l_pipe", 0x108B) + .Case("pto.initialize_l2l_pipe", 0x108C) + .Case("pto.tpush", 0x108D) + .Case("pto.declare_tile", 0x108E) + .Case("pto.tpop", 0x108F) + .Case("pto.tfree", 0x1090) + .Case("pto.comm.tput", 0x1091) + .Case("pto.comm.tget", 0x1092) + .Case("pto.comm.tnotify", 0x1093) + .Case("pto.comm.twait", 0x1094) + 
.Case("pto.comm.ttest", 0x1095) + .Case("pto.comm.tbroadcast", 0x1096) + .Case("pto.comm.tgather", 0x1097) + .Case("pto.comm.tscatter", 0x1098) + .Case("pto.comm.treduce", 0x1099) + .Case("pto.tpartargmax", 0x109A) + .Case("pto.tpartargmin", 0x109B) + .Case("scf.for", 0x4000) + .Case("scf.if", 0x4001) + .Case("scf.yield", 0x4002) + .Default(0xFFFF); + if (v == 0xFFFF) return std::nullopt; + return v; +} + +inline const OpInfo *lookupByName(llvm::StringRef name) { + auto o = lookupOpcodeByName(name); + if (!o) return nullptr; + return lookupByOpcode(*o); +} + +struct OpcodeAndVariant { uint16_t opcode; uint8_t hasVariant; uint8_t variant; }; + +inline std::optional lookupOpcodeAndVariantByFullName(llvm::StringRef fullName) { + // For non-family ops, variant is 0. For family ops, variant is the assigned u8. + // NOTE: `pto.section` is not a real op name; use `pto.section.cube`/`pto.section.vector`. + return llvm::StringSwitch>(fullName) + .Case("arith.addi", OpcodeAndVariant{0x2000, 0, 0}) + .Case("arith.ceildivsi", OpcodeAndVariant{0x2001, 0, 0}) + .Case("arith.cmpi", OpcodeAndVariant{0x2002, 0, 0}) + .Case("arith.constant", OpcodeAndVariant{0x2003, 0, 0}) + .Case("arith.index_cast", OpcodeAndVariant{0x2004, 0, 0}) + .Case("arith.minui", OpcodeAndVariant{0x2005, 0, 0}) + .Case("arith.muli", OpcodeAndVariant{0x2006, 0, 0}) + .Case("arith.select", OpcodeAndVariant{0x2007, 0, 0}) + .Case("arith.subi", OpcodeAndVariant{0x2008, 0, 0}) + .Case("func.func", OpcodeAndVariant{0x6000, 0, 0}) + .Case("func.return", OpcodeAndVariant{0x6001, 0, 0}) + .Case("func.call", OpcodeAndVariant{0x6002, 0, 0}) + .Case("pto.addptr", OpcodeAndVariant{0x1000, 0, 0}) + .Case("pto.alloc_tile", OpcodeAndVariant{0x1001, 0, 0}) + .Case("pto.barrier", OpcodeAndVariant{0x1002, 0, 0}) + .Case("pto.get_block_idx", OpcodeAndVariant{0x0000, 0, 0}) + .Case("pto.get_block_num", OpcodeAndVariant{0x0001, 0, 0}) + .Case("pto.get_subblock_idx", OpcodeAndVariant{0x0002, 0, 0}) + .Case("pto.get_subblock_num", 
OpcodeAndVariant{0x0003, 0, 0}) + .Case("pto.make_tensor_view", OpcodeAndVariant{0x0004, 0, 0}) + .Case("pto.mgather", OpcodeAndVariant{0x1003, 0, 0}) + .Case("pto.mscatter", OpcodeAndVariant{0x1004, 0, 0}) + .Case("pto.partition_view", OpcodeAndVariant{0x0005, 0, 0}) + .Case("pto.record_event", OpcodeAndVariant{0x1005, 0, 0}) + .Case("pto.tabs", OpcodeAndVariant{0x1006, 0, 0}) + .Case("pto.tadd", OpcodeAndVariant{0x1007, 0, 0}) + .Case("pto.taddc", OpcodeAndVariant{0x1008, 0, 0}) + .Case("pto.tadds", OpcodeAndVariant{0x1009, 0, 0}) + .Case("pto.taddsc", OpcodeAndVariant{0x100A, 0, 0}) + .Case("pto.tand", OpcodeAndVariant{0x100B, 0, 0}) + .Case("pto.tands", OpcodeAndVariant{0x100C, 0, 0}) + .Case("pto.tci", OpcodeAndVariant{0x100D, 0, 0}) + .Case("pto.tcmp", OpcodeAndVariant{0x100E, 0, 0}) + .Case("pto.tcmps", OpcodeAndVariant{0x100F, 0, 0}) + .Case("pto.tcolexpand", OpcodeAndVariant{0x1010, 0, 0}) + .Case("pto.tcolexpandadd", OpcodeAndVariant{0x1011, 0, 0}) + .Case("pto.tcolexpanddiv", OpcodeAndVariant{0x1012, 0, 0}) + .Case("pto.tcolexpandexpdif", OpcodeAndVariant{0x1013, 0, 0}) + .Case("pto.tcolexpandmax", OpcodeAndVariant{0x1014, 0, 0}) + .Case("pto.tcolexpandmin", OpcodeAndVariant{0x1015, 0, 0}) + .Case("pto.tcolexpandmul", OpcodeAndVariant{0x1016, 0, 0}) + .Case("pto.tcolexpandsub", OpcodeAndVariant{0x1017, 0, 0}) + .Case("pto.tcolmax", OpcodeAndVariant{0x1018, 0, 0}) + .Case("pto.tcolmin", OpcodeAndVariant{0x1019, 0, 0}) + .Case("pto.tcolprod", OpcodeAndVariant{0x101A, 0, 0}) + .Case("pto.tcolsum", OpcodeAndVariant{0x101B, 0, 0}) + .Case("pto.tcvt", OpcodeAndVariant{0x101C, 0, 0}) + .Case("pto.tdiv", OpcodeAndVariant{0x101D, 0, 0}) + .Case("pto.tdivs", OpcodeAndVariant{0x101E, 0, 0}) + .Case("pto.texp", OpcodeAndVariant{0x101F, 0, 0}) + .Case("pto.texpands", OpcodeAndVariant{0x1020, 0, 0}) + .Case("pto.textract", OpcodeAndVariant{0x1021, 0, 0}) + .Case("pto.textract_fp", OpcodeAndVariant{0x1022, 0, 0}) + .Case("pto.tfillpad", OpcodeAndVariant{0x1023, 0, 0}) 
+ .Case("pto.tfillpad_expand", OpcodeAndVariant{0x1024, 0, 0}) + .Case("pto.tfillpad_inplace", OpcodeAndVariant{0x1025, 0, 0}) + .Case("pto.tfmod", OpcodeAndVariant{0x1026, 0, 0}) + .Case("pto.tfmods", OpcodeAndVariant{0x1027, 0, 0}) + .Case("pto.tgather", OpcodeAndVariant{0x1028, 0, 0}) + .Case("pto.tgatherb", OpcodeAndVariant{0x1029, 0, 0}) + .Case("pto.tgetval", OpcodeAndVariant{0x102B, 0, 0}) + .Case("pto.timg2col", OpcodeAndVariant{0x102C, 0, 0}) + .Case("pto.tinsert", OpcodeAndVariant{0x102D, 0, 0}) + .Case("pto.tinsert_fp", OpcodeAndVariant{0x102E, 0, 0}) + .Case("pto.tload", OpcodeAndVariant{0x102F, 0, 0}) + .Case("pto.tlog", OpcodeAndVariant{0x1030, 0, 0}) + .Case("pto.tlrelu", OpcodeAndVariant{0x1031, 0, 0}) + .Case("pto.tmax", OpcodeAndVariant{0x1034, 0, 0}) + .Case("pto.tmaxs", OpcodeAndVariant{0x1035, 0, 0}) + .Case("pto.tmin", OpcodeAndVariant{0x1036, 0, 0}) + .Case("pto.tmins", OpcodeAndVariant{0x1037, 0, 0}) + .Case("pto.tmov", OpcodeAndVariant{0x1038, 0, 0}) + .Case("pto.tmov.fp", OpcodeAndVariant{0x1039, 0, 0}) + .Case("pto.tmrgsort", OpcodeAndVariant{0x103A, 0, 0}) + .Case("pto.tmul", OpcodeAndVariant{0x103B, 0, 0}) + .Case("pto.tmuls", OpcodeAndVariant{0x103C, 0, 0}) + .Case("pto.tneg", OpcodeAndVariant{0x103D, 0, 0}) + .Case("pto.tnot", OpcodeAndVariant{0x103E, 0, 0}) + .Case("pto.tor", OpcodeAndVariant{0x103F, 0, 0}) + .Case("pto.tors", OpcodeAndVariant{0x1040, 0, 0}) + .Case("pto.tpartadd", OpcodeAndVariant{0x1041, 0, 0}) + .Case("pto.tpartmax", OpcodeAndVariant{0x1042, 0, 0}) + .Case("pto.tpartmin", OpcodeAndVariant{0x1043, 0, 0}) + .Case("pto.tpartmul", OpcodeAndVariant{0x1044, 0, 0}) + .Case("pto.tprefetch", OpcodeAndVariant{0x1045, 0, 0}) + .Case("pto.tprelu", OpcodeAndVariant{0x1046, 0, 0}) + .Case("pto.tquant", OpcodeAndVariant{0x1047, 0, 0}) + .Case("pto.trecip", OpcodeAndVariant{0x1048, 0, 0}) + .Case("pto.trelu", OpcodeAndVariant{0x1049, 0, 0}) + .Case("pto.trem", OpcodeAndVariant{0x104A, 0, 0}) + .Case("pto.trems", 
OpcodeAndVariant{0x104B, 0, 0}) + .Case("pto.treshape", OpcodeAndVariant{0x104C, 0, 0}) + .Case("pto.trowexpand", OpcodeAndVariant{0x104D, 0, 0}) + .Case("pto.trowexpandadd", OpcodeAndVariant{0x104E, 0, 0}) + .Case("pto.trowexpandexpdif", OpcodeAndVariant{0x104F, 0, 0}) + .Case("pto.trowexpandmax", OpcodeAndVariant{0x1050, 0, 0}) + .Case("pto.trowexpandmin", OpcodeAndVariant{0x1051, 0, 0}) + .Case("pto.trowmax", OpcodeAndVariant{0x1052, 0, 0}) + .Case("pto.trowmin", OpcodeAndVariant{0x1053, 0, 0}) + .Case("pto.trowsum", OpcodeAndVariant{0x1054, 0, 0}) + .Case("pto.trsqrt", OpcodeAndVariant{0x1055, 0, 0}) + .Case("pto.tscatter", OpcodeAndVariant{0x1056, 0, 0}) + .Case("pto.tsel", OpcodeAndVariant{0x1057, 0, 0}) + .Case("pto.tsels", OpcodeAndVariant{0x1058, 0, 0}) + .Case("pto.tset_img2col_padding", OpcodeAndVariant{0x1059, 0, 0}) + .Case("pto.tset_img2col_rpt", OpcodeAndVariant{0x105A, 0, 0}) + .Case("pto.tsetfmatrix", OpcodeAndVariant{0x105B, 0, 0}) + .Case("pto.tsethf32mode", OpcodeAndVariant{0x105C, 0, 0}) + .Case("pto.tsettf32mode", OpcodeAndVariant{0x105D, 0, 0}) + .Case("pto.tsetval", OpcodeAndVariant{0x105E, 0, 0}) + .Case("pto.tshl", OpcodeAndVariant{0x105F, 0, 0}) + .Case("pto.tshls", OpcodeAndVariant{0x1060, 0, 0}) + .Case("pto.tshr", OpcodeAndVariant{0x1061, 0, 0}) + .Case("pto.tshrs", OpcodeAndVariant{0x1062, 0, 0}) + .Case("pto.tsort32", OpcodeAndVariant{0x1063, 0, 0}) + .Case("pto.tsqrt", OpcodeAndVariant{0x1064, 0, 0}) + .Case("pto.tstore", OpcodeAndVariant{0x1065, 0, 0}) + .Case("pto.tstore_fp", OpcodeAndVariant{0x1066, 0, 0}) + .Case("pto.tsub", OpcodeAndVariant{0x1067, 0, 0}) + .Case("pto.tsubc", OpcodeAndVariant{0x1068, 0, 0}) + .Case("pto.tsubs", OpcodeAndVariant{0x1069, 0, 0}) + .Case("pto.tsubsc", OpcodeAndVariant{0x106A, 0, 0}) + .Case("pto.trowexpandsub", OpcodeAndVariant{0x106B, 0, 0}) + .Case("pto.ttrans", OpcodeAndVariant{0x106C, 0, 0}) + .Case("pto.ttri", OpcodeAndVariant{0x106D, 0, 0}) + .Case("pto.txor", OpcodeAndVariant{0x106E, 0, 0}) 
+ .Case("pto.txors", OpcodeAndVariant{0x106F, 0, 0}) + .Case("pto.wait_event", OpcodeAndVariant{0x1070, 0, 0}) + .Case("pto.tprint", OpcodeAndVariant{0x1071, 0, 0}) + .Case("pto.subview", OpcodeAndVariant{0x1072, 0, 0}) + .Case("pto.trowexpanddiv", OpcodeAndVariant{0x1073, 0, 0}) + .Case("pto.trowexpandmul", OpcodeAndVariant{0x1074, 0, 0}) + .Case("pto.tdequant", OpcodeAndVariant{0x1075, 0, 0}) + .Case("pto.taxpy", OpcodeAndVariant{0x1076, 0, 0}) + .Case("pto.thistogram", OpcodeAndVariant{0x1077, 0, 0}) + .Case("pto.tget_scale_addr", OpcodeAndVariant{0x1078, 0, 0}) + .Case("pto.trowargmax", OpcodeAndVariant{0x1079, 0, 0}) + .Case("pto.trowargmin", OpcodeAndVariant{0x107A, 0, 0}) + .Case("pto.tcolargmax", OpcodeAndVariant{0x107B, 0, 0}) + .Case("pto.tcolargmin", OpcodeAndVariant{0x107C, 0, 0}) + .Case("pto.tsync", OpcodeAndVariant{0x107D, 0, 0}) + .Case("pto.reserve_buffer", OpcodeAndVariant{0x107E, 0, 0}) + .Case("pto.import_reserved_buffer", OpcodeAndVariant{0x107F, 0, 0}) + .Case("pto.aic_initialize_pipe", OpcodeAndVariant{0x1080, 0, 0}) + .Case("pto.aiv_initialize_pipe", OpcodeAndVariant{0x1081, 0, 0}) + .Case("pto.tpush_to_aiv", OpcodeAndVariant{0x1082, 0, 0}) + .Case("pto.tpush_to_aic", OpcodeAndVariant{0x1083, 0, 0}) + .Case("pto.tpop_from_aic", OpcodeAndVariant{0x1084, 0, 0}) + .Case("pto.tpop_from_aiv", OpcodeAndVariant{0x1085, 0, 0}) + .Case("pto.tfree_from_aic", OpcodeAndVariant{0x1086, 0, 0}) + .Case("pto.tfree_from_aiv", OpcodeAndVariant{0x1087, 0, 0}) + .Case("pto.set_validshape", OpcodeAndVariant{0x1088, 0, 0}) + .Case("pto.tconcat", OpcodeAndVariant{0x1089, 0, 0}) + .Case("pto.trowprod", OpcodeAndVariant{0x108A, 0, 0}) + .Case("pto.initialize_l2g2l_pipe", OpcodeAndVariant{0x108B, 0, 0}) + .Case("pto.initialize_l2l_pipe", OpcodeAndVariant{0x108C, 0, 0}) + .Case("pto.tpush", OpcodeAndVariant{0x108D, 0, 0}) + .Case("pto.declare_tile", OpcodeAndVariant{0x108E, 0, 0}) + .Case("pto.tpop", OpcodeAndVariant{0x108F, 0, 0}) + .Case("pto.tfree", 
OpcodeAndVariant{0x1090, 0, 0}) + .Case("pto.comm.tput", OpcodeAndVariant{0x1091, 0, 0}) + .Case("pto.comm.tget", OpcodeAndVariant{0x1092, 0, 0}) + .Case("pto.comm.tnotify", OpcodeAndVariant{0x1093, 0, 0}) + .Case("pto.comm.twait", OpcodeAndVariant{0x1094, 0, 0}) + .Case("pto.comm.ttest", OpcodeAndVariant{0x1095, 0, 0}) + .Case("pto.comm.tbroadcast", OpcodeAndVariant{0x1096, 0, 0}) + .Case("pto.comm.tgather", OpcodeAndVariant{0x1097, 0, 0}) + .Case("pto.comm.tscatter", OpcodeAndVariant{0x1098, 0, 0}) + .Case("pto.comm.treduce", OpcodeAndVariant{0x1099, 0, 0}) + .Case("pto.tpartargmax", OpcodeAndVariant{0x109A, 0, 0}) + .Case("pto.tpartargmin", OpcodeAndVariant{0x109B, 0, 0}) + .Case("scf.for", OpcodeAndVariant{0x4000, 0, 0}) + .Case("scf.if", OpcodeAndVariant{0x4001, 0, 0}) + .Case("scf.yield", OpcodeAndVariant{0x4002, 0, 0}) + .Case("pto.section.cube", + OpcodeAndVariant{0x0006, kHasVariant, kSectionCubeVariant}) + .Case("pto.section.vector", + OpcodeAndVariant{0x0006, kHasVariant, kSectionVectorVariant}) + .Case("pto.tgemv", + OpcodeAndVariant{0x102A, kHasVariant, kVariantDefault}) + .Case("pto.tgemv.acc", + OpcodeAndVariant{0x102A, kHasVariant, kVariantAcc}) + .Case("pto.tgemv.bias", + OpcodeAndVariant{0x102A, kHasVariant, kVariantBias}) + .Case("pto.tgemv.mx", + OpcodeAndVariant{0x102A, kHasVariant, kVariantMx}) + .Case("pto.tgemv.mx.acc", + OpcodeAndVariant{0x102A, kHasVariant, kVariantMxAcc}) + .Case("pto.tgemv.mx.bias", + OpcodeAndVariant{0x102A, kHasVariant, kVariantMxBias}) + .Case("pto.tmatmul", + OpcodeAndVariant{0x1032, kHasVariant, kVariantDefault}) + .Case("pto.tmatmul.acc", + OpcodeAndVariant{0x1032, kHasVariant, kVariantAcc}) + .Case("pto.tmatmul.bias", + OpcodeAndVariant{0x1032, kHasVariant, kVariantBias}) + .Case("pto.tmatmul.mx", + OpcodeAndVariant{0x1033, kHasVariant, kVariantDefault}) + .Case("pto.tmatmul.mx.acc", + OpcodeAndVariant{0x1033, kHasVariant, kVariantAcc}) + .Case("pto.tmatmul.mx.bias", + OpcodeAndVariant{0x1033, kHasVariant, 
kVariantBias}) + .Default(std::nullopt); +} + +inline const char *fullNameFromOpcodeVariant(uint16_t opcode, uint8_t variant) { + const OpInfo *info = lookupByOpcode(opcode); + if (!info) return nullptr; + if (opcode == kTscatterMaskOpcode) return "pto.tscatter"; + if (!info->has_variant_u8) return info->name; + switch (opcode) { + case 0x0006: + switch (variant) { + case kSectionCubeVariant: return "pto.section.cube"; + case kSectionVectorVariant: return "pto.section.vector"; + default: return info->name; + } + case 0x102A: + switch (variant) { + case kVariantDefault: return "pto.tgemv"; + case kVariantAcc: return "pto.tgemv.acc"; + case kVariantBias: return "pto.tgemv.bias"; + case kVariantMx: return "pto.tgemv.mx"; + case kVariantMxAcc: return "pto.tgemv.mx.acc"; + case kVariantMxBias: return "pto.tgemv.mx.bias"; + default: return info->name; + } + case 0x1032: + switch (variant) { + case kVariantDefault: return "pto.tmatmul"; + case kVariantAcc: return "pto.tmatmul.acc"; + case kVariantBias: return "pto.tmatmul.bias"; + default: return info->name; + } + case 0x1033: + switch (variant) { + case kVariantDefault: return "pto.tmatmul.mx"; + case kVariantAcc: return "pto.tmatmul.mx.acc"; + case kVariantBias: return "pto.tmatmul.mx.bias"; + default: return info->name; + } + default: return info->name; + } +} + +inline std::optional lookupOperandsByVariant(uint16_t opcode, uint8_t variant) { + switch (opcode) { + case 0x102A: + switch (variant) { + case kVariantDefault: return kTgemvOperandCount; + case kVariantAcc: return kTgemvAccOperandCount; + case kVariantBias: return kTgemvBiasOperandCount; + case kVariantMx: return kTgemvMxOperandCount; + case kVariantMxAcc: return kTgemvMxAccOperandCount; + case kVariantMxBias: return kTgemvMxBiasOperandCount; + default: return std::nullopt; + } + case 0x1032: + switch (variant) { + case kVariantDefault: return kTmatmulOperandCount; + case kVariantAcc: return kTmatmulAccOperandCount; + case kVariantBias: return 
kTmatmulBiasOperandCount; + default: return std::nullopt; + } + case 0x1033: + switch (variant) { + case kVariantDefault: return kTmatmulMxOperandCount; + case kVariantAcc: return kTmatmulMxAccOperandCount; + case kVariantBias: return kTmatmulMxBiasOperandCount; + default: return std::nullopt; + } + default: return std::nullopt; + } +} + +} // namespace ptobc::v0 diff --git a/tools/ptobc/generated/ptobc_opcodes_v0.h b/tools/ptobc/generated/ptobc_opcodes_v0.h index 8303e1261..3f0faf5f1 100644 --- a/tools/ptobc/generated/ptobc_opcodes_v0.h +++ b/tools/ptobc/generated/ptobc_opcodes_v0.h @@ -6,717 +6,5 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// Generated by docs/bytecode/tools/gen_v0_tables.py -#pragma once -#include -#include - -#include -#include - -namespace ptobc::v0 { - -inline constexpr uint8_t kVariantDefault = 0; -inline constexpr uint8_t kVariantAcc = 1; -inline constexpr uint8_t kVariantBias = 2; -inline constexpr uint8_t kVariantMx = 3; -inline constexpr uint8_t kVariantMxAcc = 4; -inline constexpr uint8_t kVariantMxBias = 5; -inline constexpr uint8_t kSectionCubeVariant = 0; -inline constexpr uint8_t kSectionVectorVariant = 1; -inline constexpr uint8_t kHasVariant = 1; -inline constexpr uint16_t kTscatterMaskOpcode = 0x109C; - -inline constexpr int kTgemvOperandCount = 3; -inline constexpr int kTgemvAccOperandCount = 4; -inline constexpr int kTgemvBiasOperandCount = 4; -inline constexpr int kTgemvMxOperandCount = 5; -inline constexpr int kTgemvMxAccOperandCount = 6; -inline constexpr int kTgemvMxBiasOperandCount = 6; -inline constexpr int kTmatmulOperandCount = 3; -inline constexpr int kTmatmulAccOperandCount = 4; -inline constexpr int kTmatmulBiasOperandCount = 4; -inline constexpr int kTmatmulMxOperandCount = 5; -inline constexpr int kTmatmulMxAccOperandCount = 6; -inline constexpr int 
kTmatmulMxBiasOperandCount = 6; - -struct OpInfo { - uint16_t opcode; - const char *name; - uint8_t has_variant_u8; - uint8_t result_type_mode; - uint8_t operand_mode; - uint16_t num_operands; - uint16_t num_results; - uint16_t num_regions; - uint8_t imm_kind; -}; - -inline constexpr OpInfo kOpTable[] = { - {0x0000, "pto.get_block_idx", 0, 0x01, 0x00, 0, 1, 0, 0x00}, - {0x0001, "pto.get_block_num", 0, 0x01, 0x00, 0, 1, 0, 0x00}, - {0x0002, "pto.get_subblock_idx", 0, 0x01, 0x00, 0, 1, 0, 0x00}, - {0x0003, "pto.get_subblock_num", 0, 0x01, 0x00, 0, 1, 0, 0x00}, - {0x0004, "pto.make_tensor_view", 0, 0x01, 0x03, 1, 1, 0, 0x06}, - {0x0005, "pto.partition_view", 0, 0x01, 0x03, 1, 1, 0, 0x07}, - {0x0006, "pto.section", 1, 0x00, 0x00, 0, 0, 1, 0x00}, - {0x1000, "pto.addptr", 0, 0x01, 0x00, 2, 1, 0, 0x00}, - {0x1001, "pto.alloc_tile", 0, 0x01, 0x04, 0, 1, 0, 0x08}, - {0x1002, "pto.barrier", 0, 0x00, 0x00, 0, 0, 0, 0x00}, - {0x1003, "pto.mgather", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1004, "pto.mscatter", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1005, "pto.record_event", 0, 0x00, 0x00, 0, 0, 0, 0x02}, - {0x1006, "pto.tabs", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1007, "pto.tadd", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1008, "pto.taddc", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x1009, "pto.tadds", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x100A, "pto.taddsc", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x100B, "pto.tand", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x100C, "pto.tands", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x100D, "pto.tci", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x100E, "pto.tcmp", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x100F, "pto.tcmps", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1010, "pto.tcolexpand", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1011, "pto.tcolexpandadd", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1012, "pto.tcolexpanddiv", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1013, "pto.tcolexpandexpdif", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1014, "pto.tcolexpandmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1015, "pto.tcolexpandmin", 0, 
0x00, 0x00, 3, 0, 0, 0x00}, - {0x1016, "pto.tcolexpandmul", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1017, "pto.tcolexpandsub", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1018, "pto.tcolmax", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1019, "pto.tcolmin", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x101A, "pto.tcolprod", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x101B, "pto.tcolsum", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x101C, "pto.tcvt", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x101D, "pto.tdiv", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x101E, "pto.tdivs", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x101F, "pto.texp", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1020, "pto.texpands", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1021, "pto.textract", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x1022, "pto.textract_fp", 0, 0x00, 0x00, 5, 0, 0, 0x00}, - {0x1023, "pto.tfillpad", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1024, "pto.tfillpad_expand", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1025, "pto.tfillpad_inplace", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1026, "pto.tfmod", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1027, "pto.tfmods", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1028, "pto.tgather", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1029, "pto.tgatherb", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x102A, "pto.tgemv", 1, 0x00, 0x01, 0, 0, 0, 0x00}, - {0x102B, "pto.tgetval", 0, 0x01, 0x00, 2, 1, 0, 0x00}, - {0x102C, "pto.timg2col", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x102D, "pto.tinsert", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x102E, "pto.tinsert_fp", 0, 0x00, 0x00, 5, 0, 0, 0x00}, - {0x102F, "pto.tload", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1030, "pto.tlog", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1031, "pto.tlrelu", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1032, "pto.tmatmul", 1, 0x00, 0x01, 0, 0, 0, 0x00}, - {0x1033, "pto.tmatmul.mx", 1, 0x00, 0x01, 0, 0, 0, 0x00}, - {0x1034, "pto.tmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1035, "pto.tmaxs", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1036, "pto.tmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1037, "pto.tmins", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1038, 
"pto.tmov", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1039, "pto.tmov.fp", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x103A, "pto.tmrgsort", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x103B, "pto.tmul", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x103C, "pto.tmuls", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x103D, "pto.tneg", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x103E, "pto.tnot", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x103F, "pto.tor", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1040, "pto.tors", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1041, "pto.tpartadd", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1042, "pto.tpartmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1043, "pto.tpartmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1044, "pto.tpartmul", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1045, "pto.tprefetch", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1046, "pto.tprelu", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x1047, "pto.tquant", 0, 0x00, 0x02, 3, 0, 0, 0x00}, - {0x1048, "pto.trecip", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1049, "pto.trelu", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x104A, "pto.trem", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x104B, "pto.trems", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x104C, "pto.treshape", 0, 0x01, 0x00, 1, 1, 0, 0x00}, - {0x104D, "pto.trowexpand", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x104E, "pto.trowexpandadd", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x104F, "pto.trowexpandexpdif", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1050, "pto.trowexpandmax", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1051, "pto.trowexpandmin", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1052, "pto.trowmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1053, "pto.trowmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1054, "pto.trowsum", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1055, "pto.trsqrt", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1056, "pto.tscatter", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1057, "pto.tsel", 0, 0x00, 0x00, 5, 0, 0, 0x00}, - {0x1058, "pto.tsels", 0, 0x00, 0x00, 5, 0, 0, 0x00}, - {0x1059, "pto.tset_img2col_padding", 0, 0x00, 0x00, 1, 0, 0, 0x00}, - {0x105A, "pto.tset_img2col_rpt", 0, 0x00, 0x00, 1, 0, 
0, 0x00}, - {0x105B, "pto.tsetfmatrix", 0, 0x00, 0x00, 1, 0, 0, 0x00}, - {0x105C, "pto.tsethf32mode", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x105D, "pto.tsettf32mode", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x105E, "pto.tsetval", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x105F, "pto.tshl", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1060, "pto.tshls", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1061, "pto.tshr", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1062, "pto.tshrs", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1063, "pto.tsort32", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1064, "pto.tsqrt", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1065, "pto.tstore", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1066, "pto.tstore_fp", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1067, "pto.tsub", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1068, "pto.tsubc", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x1069, "pto.tsubs", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x106A, "pto.tsubsc", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x106B, "pto.trowexpandsub", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x106C, "pto.ttrans", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x106D, "pto.ttri", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x106E, "pto.txor", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x106F, "pto.txors", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x1070, "pto.wait_event", 0, 0x00, 0x00, 0, 0, 0, 0x02}, - {0x1071, "pto.tprint", 0, 0x00, 0x00, 1, 0, 0, 0x00}, - {0x1072, "pto.subview", 0, 0x01, 0x02, 0, 1, 0, 0x00}, - {0x1073, "pto.trowexpanddiv", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1074, "pto.trowexpandmul", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1075, "pto.tdequant", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x1076, "pto.taxpy", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1077, "pto.thistogram", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1078, "pto.tget_scale_addr", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1079, "pto.trowargmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x107A, "pto.trowargmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x107B, "pto.tcolargmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x107C, "pto.tcolargmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x107D, "pto.tsync", 
0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x107E, "pto.reserve_buffer", 0, 0x01, 0x00, 0, 1, 0, 0x00}, - {0x107F, "pto.import_reserved_buffer", 0, 0x01, 0x00, 0, 1, 0, 0x00}, - {0x1080, "pto.aic_initialize_pipe", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1081, "pto.aiv_initialize_pipe", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1082, "pto.tpush_to_aiv", 0, 0x00, 0x00, 1, 0, 0, 0x00}, - {0x1083, "pto.tpush_to_aic", 0, 0x00, 0x00, 1, 0, 0, 0x00}, - {0x1084, "pto.tpop_from_aic", 0, 0x01, 0x02, 0, 1, 0, 0x00}, - {0x1085, "pto.tpop_from_aiv", 0, 0x01, 0x02, 0, 1, 0, 0x00}, - {0x1086, "pto.tfree_from_aic", 0, 0x00, 0x00, 0, 0, 0, 0x00}, - {0x1087, "pto.tfree_from_aiv", 0, 0x00, 0x00, 0, 0, 0, 0x00}, - {0x1088, "pto.set_validshape", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1089, "pto.tconcat", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x108A, "pto.trowprod", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x108B, "pto.initialize_l2g2l_pipe", 0, 0x01, 0x02, 0, 1, 0, 0x00}, - {0x108C, "pto.initialize_l2l_pipe", 0, 0x01, 0x02, 0, 1, 0, 0x00}, - {0x108D, "pto.tpush", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x108E, "pto.declare_tile", 0, 0x01, 0x00, 0, 1, 0, 0x00}, - {0x108F, "pto.tpop", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1090, "pto.tfree", 0, 0x00, 0x00, 1, 0, 0, 0x00}, - {0x1091, "pto.comm.tput", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1092, "pto.comm.tget", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1093, "pto.comm.tnotify", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1094, "pto.comm.twait", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1095, "pto.comm.ttest", 0, 0x01, 0x02, 0, 1, 0, 0x00}, - {0x1096, "pto.comm.tbroadcast", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1097, "pto.comm.tgather", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1098, "pto.comm.tscatter", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1099, "pto.comm.treduce", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x109A, "pto.tpartargmax", 0, 0x00, 0x00, 6, 0, 0, 0x00}, - {0x109B, "pto.tpartargmin", 0, 0x00, 0x00, 6, 0, 0, 0x00}, - {0x109C, "pto.tscatter.maskpattern", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x2000, "arith.addi", 0, 
0x01, 0x00, 2, 1, 0, 0x00}, - {0x2001, "arith.ceildivsi", 0, 0x01, 0x00, 2, 1, 0, 0x00}, - {0x2002, "arith.cmpi", 0, 0x01, 0x00, 2, 1, 0, 0x01}, - {0x2003, "arith.constant", 0, 0x01, 0x00, 0, 1, 0, 0x05}, - {0x2004, "arith.index_cast", 0, 0x01, 0x00, 1, 1, 0, 0x00}, - {0x2005, "arith.minui", 0, 0x01, 0x00, 2, 1, 0, 0x00}, - {0x2006, "arith.muli", 0, 0x01, 0x00, 2, 1, 0, 0x00}, - {0x2007, "arith.select", 0, 0x01, 0x00, 3, 1, 0, 0x00}, - {0x2008, "arith.subi", 0, 0x01, 0x00, 2, 1, 0, 0x00}, - {0x4000, "scf.for", 0, 0x00, 0x00, 3, 0, 1, 0x00}, - {0x4001, "scf.if", 0, 0x00, 0x00, 1, 0, 2, 0x00}, - {0x4002, "scf.yield", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x6000, "func.func", 0, 0x00, 0x00, 0, 0, 0, 0x00}, - {0x6001, "func.return", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x6002, "func.call", 0, 0x02, 0x02, 0, 0, 0, 0x00}, -}; - -inline const OpInfo *lookupByOpcode(uint16_t opcode) { - // Binary search on kOpTable (sorted by opcode). - size_t lo = 0, hi = sizeof(kOpTable) / sizeof(kOpTable[0]); - while (lo < hi) { - size_t mid = lo + (hi - lo) / 2; - uint16_t v = kOpTable[mid].opcode; - if (v == opcode) return &kOpTable[mid]; - if (v < opcode) lo = mid + 1; else hi = mid; - } - return nullptr; -} - -inline std::optional lookupOpcodeByName(llvm::StringRef name) { - uint16_t v = llvm::StringSwitch(name) - .Case("arith.addi", 0x2000) - .Case("arith.ceildivsi", 0x2001) - .Case("arith.cmpi", 0x2002) - .Case("arith.constant", 0x2003) - .Case("arith.index_cast", 0x2004) - .Case("arith.minui", 0x2005) - .Case("arith.muli", 0x2006) - .Case("arith.select", 0x2007) - .Case("arith.subi", 0x2008) - .Case("func.func", 0x6000) - .Case("func.return", 0x6001) - .Case("func.call", 0x6002) - .Case("pto.addptr", 0x1000) - .Case("pto.alloc_tile", 0x1001) - .Case("pto.barrier", 0x1002) - .Case("pto.get_block_idx", 0x0000) - .Case("pto.get_block_num", 0x0001) - .Case("pto.get_subblock_idx", 0x0002) - .Case("pto.get_subblock_num", 0x0003) - .Case("pto.make_tensor_view", 0x0004) - .Case("pto.mgather", 
0x1003) - .Case("pto.mscatter", 0x1004) - .Case("pto.partition_view", 0x0005) - .Case("pto.record_event", 0x1005) - .Case("pto.section", 0x0006) - .Case("pto.tabs", 0x1006) - .Case("pto.tadd", 0x1007) - .Case("pto.taddc", 0x1008) - .Case("pto.tadds", 0x1009) - .Case("pto.taddsc", 0x100A) - .Case("pto.tand", 0x100B) - .Case("pto.tands", 0x100C) - .Case("pto.tci", 0x100D) - .Case("pto.tcmp", 0x100E) - .Case("pto.tcmps", 0x100F) - .Case("pto.tcolexpand", 0x1010) - .Case("pto.tcolexpandadd", 0x1011) - .Case("pto.tcolexpanddiv", 0x1012) - .Case("pto.tcolexpandexpdif", 0x1013) - .Case("pto.tcolexpandmax", 0x1014) - .Case("pto.tcolexpandmin", 0x1015) - .Case("pto.tcolexpandmul", 0x1016) - .Case("pto.tcolexpandsub", 0x1017) - .Case("pto.tcolmax", 0x1018) - .Case("pto.tcolmin", 0x1019) - .Case("pto.tcolprod", 0x101A) - .Case("pto.tcolsum", 0x101B) - .Case("pto.tcvt", 0x101C) - .Case("pto.tdiv", 0x101D) - .Case("pto.tdivs", 0x101E) - .Case("pto.texp", 0x101F) - .Case("pto.texpands", 0x1020) - .Case("pto.textract", 0x1021) - .Case("pto.textract_fp", 0x1022) - .Case("pto.tfillpad", 0x1023) - .Case("pto.tfillpad_expand", 0x1024) - .Case("pto.tfillpad_inplace", 0x1025) - .Case("pto.tfmod", 0x1026) - .Case("pto.tfmods", 0x1027) - .Case("pto.tgather", 0x1028) - .Case("pto.tgatherb", 0x1029) - .Case("pto.tgemv", 0x102A) - .Case("pto.tgetval", 0x102B) - .Case("pto.timg2col", 0x102C) - .Case("pto.tinsert", 0x102D) - .Case("pto.tinsert_fp", 0x102E) - .Case("pto.tload", 0x102F) - .Case("pto.tlog", 0x1030) - .Case("pto.tlrelu", 0x1031) - .Case("pto.tmatmul", 0x1032) - .Case("pto.tmatmul.mx", 0x1033) - .Case("pto.tmax", 0x1034) - .Case("pto.tmaxs", 0x1035) - .Case("pto.tmin", 0x1036) - .Case("pto.tmins", 0x1037) - .Case("pto.tmov", 0x1038) - .Case("pto.tmov.fp", 0x1039) - .Case("pto.tmrgsort", 0x103A) - .Case("pto.tmul", 0x103B) - .Case("pto.tmuls", 0x103C) - .Case("pto.tneg", 0x103D) - .Case("pto.tnot", 0x103E) - .Case("pto.tor", 0x103F) - .Case("pto.tors", 0x1040) - 
.Case("pto.tpartadd", 0x1041) - .Case("pto.tpartmax", 0x1042) - .Case("pto.tpartmin", 0x1043) - .Case("pto.tpartmul", 0x1044) - .Case("pto.tprefetch", 0x1045) - .Case("pto.tprelu", 0x1046) - .Case("pto.tquant", 0x1047) - .Case("pto.trecip", 0x1048) - .Case("pto.trelu", 0x1049) - .Case("pto.trem", 0x104A) - .Case("pto.trems", 0x104B) - .Case("pto.treshape", 0x104C) - .Case("pto.trowexpand", 0x104D) - .Case("pto.trowexpandadd", 0x104E) - .Case("pto.trowexpandexpdif", 0x104F) - .Case("pto.trowexpandmax", 0x1050) - .Case("pto.trowexpandmin", 0x1051) - .Case("pto.trowmax", 0x1052) - .Case("pto.trowmin", 0x1053) - .Case("pto.trowsum", 0x1054) - .Case("pto.trsqrt", 0x1055) - .Case("pto.tscatter", 0x1056) - .Case("pto.tsel", 0x1057) - .Case("pto.tsels", 0x1058) - .Case("pto.tset_img2col_padding", 0x1059) - .Case("pto.tset_img2col_rpt", 0x105A) - .Case("pto.tsetfmatrix", 0x105B) - .Case("pto.tsethf32mode", 0x105C) - .Case("pto.tsettf32mode", 0x105D) - .Case("pto.tsetval", 0x105E) - .Case("pto.tshl", 0x105F) - .Case("pto.tshls", 0x1060) - .Case("pto.tshr", 0x1061) - .Case("pto.tshrs", 0x1062) - .Case("pto.tsort32", 0x1063) - .Case("pto.tsqrt", 0x1064) - .Case("pto.tstore", 0x1065) - .Case("pto.tstore_fp", 0x1066) - .Case("pto.tsub", 0x1067) - .Case("pto.tsubc", 0x1068) - .Case("pto.tsubs", 0x1069) - .Case("pto.tsubsc", 0x106A) - .Case("pto.trowexpandsub", 0x106B) - .Case("pto.ttrans", 0x106C) - .Case("pto.ttri", 0x106D) - .Case("pto.txor", 0x106E) - .Case("pto.txors", 0x106F) - .Case("pto.wait_event", 0x1070) - .Case("pto.tprint", 0x1071) - .Case("pto.subview", 0x1072) - .Case("pto.trowexpanddiv", 0x1073) - .Case("pto.trowexpandmul", 0x1074) - .Case("pto.tdequant", 0x1075) - .Case("pto.taxpy", 0x1076) - .Case("pto.thistogram", 0x1077) - .Case("pto.tget_scale_addr", 0x1078) - .Case("pto.trowargmax", 0x1079) - .Case("pto.trowargmin", 0x107A) - .Case("pto.tcolargmax", 0x107B) - .Case("pto.tcolargmin", 0x107C) - .Case("pto.tsync", 0x107D) - .Case("pto.reserve_buffer", 0x107E) - 
.Case("pto.import_reserved_buffer", 0x107F) - .Case("pto.aic_initialize_pipe", 0x1080) - .Case("pto.aiv_initialize_pipe", 0x1081) - .Case("pto.tpush_to_aiv", 0x1082) - .Case("pto.tpush_to_aic", 0x1083) - .Case("pto.tpop_from_aic", 0x1084) - .Case("pto.tpop_from_aiv", 0x1085) - .Case("pto.tfree_from_aic", 0x1086) - .Case("pto.tfree_from_aiv", 0x1087) - .Case("pto.set_validshape", 0x1088) - .Case("pto.tconcat", 0x1089) - .Case("pto.trowprod", 0x108A) - .Case("pto.initialize_l2g2l_pipe", 0x108B) - .Case("pto.initialize_l2l_pipe", 0x108C) - .Case("pto.tpush", 0x108D) - .Case("pto.declare_tile", 0x108E) - .Case("pto.tpop", 0x108F) - .Case("pto.tfree", 0x1090) - .Case("pto.comm.tput", 0x1091) - .Case("pto.comm.tget", 0x1092) - .Case("pto.comm.tnotify", 0x1093) - .Case("pto.comm.twait", 0x1094) - .Case("pto.comm.ttest", 0x1095) - .Case("pto.comm.tbroadcast", 0x1096) - .Case("pto.comm.tgather", 0x1097) - .Case("pto.comm.tscatter", 0x1098) - .Case("pto.comm.treduce", 0x1099) - .Case("pto.tpartargmax", 0x109A) - .Case("pto.tpartargmin", 0x109B) - .Case("scf.for", 0x4000) - .Case("scf.if", 0x4001) - .Case("scf.yield", 0x4002) - .Default(0xFFFF); - if (v == 0xFFFF) return std::nullopt; - return v; -} - -inline const OpInfo *lookupByName(llvm::StringRef name) { - auto o = lookupOpcodeByName(name); - if (!o) return nullptr; - return lookupByOpcode(*o); -} - -struct OpcodeAndVariant { uint16_t opcode; uint8_t hasVariant; uint8_t variant; }; - -inline std::optional lookupOpcodeAndVariantByFullName(llvm::StringRef fullName) { - // For non-family ops, variant is 0. For family ops, variant is the assigned u8. - // NOTE: `pto.section` is not a real op name; use `pto.section.cube`/`pto.section.vector`. 
- return llvm::StringSwitch>(fullName) - .Case("arith.addi", OpcodeAndVariant{0x2000, 0, 0}) - .Case("arith.ceildivsi", OpcodeAndVariant{0x2001, 0, 0}) - .Case("arith.cmpi", OpcodeAndVariant{0x2002, 0, 0}) - .Case("arith.constant", OpcodeAndVariant{0x2003, 0, 0}) - .Case("arith.index_cast", OpcodeAndVariant{0x2004, 0, 0}) - .Case("arith.minui", OpcodeAndVariant{0x2005, 0, 0}) - .Case("arith.muli", OpcodeAndVariant{0x2006, 0, 0}) - .Case("arith.select", OpcodeAndVariant{0x2007, 0, 0}) - .Case("arith.subi", OpcodeAndVariant{0x2008, 0, 0}) - .Case("func.func", OpcodeAndVariant{0x6000, 0, 0}) - .Case("func.return", OpcodeAndVariant{0x6001, 0, 0}) - .Case("func.call", OpcodeAndVariant{0x6002, 0, 0}) - .Case("pto.addptr", OpcodeAndVariant{0x1000, 0, 0}) - .Case("pto.alloc_tile", OpcodeAndVariant{0x1001, 0, 0}) - .Case("pto.barrier", OpcodeAndVariant{0x1002, 0, 0}) - .Case("pto.get_block_idx", OpcodeAndVariant{0x0000, 0, 0}) - .Case("pto.get_block_num", OpcodeAndVariant{0x0001, 0, 0}) - .Case("pto.get_subblock_idx", OpcodeAndVariant{0x0002, 0, 0}) - .Case("pto.get_subblock_num", OpcodeAndVariant{0x0003, 0, 0}) - .Case("pto.make_tensor_view", OpcodeAndVariant{0x0004, 0, 0}) - .Case("pto.mgather", OpcodeAndVariant{0x1003, 0, 0}) - .Case("pto.mscatter", OpcodeAndVariant{0x1004, 0, 0}) - .Case("pto.partition_view", OpcodeAndVariant{0x0005, 0, 0}) - .Case("pto.record_event", OpcodeAndVariant{0x1005, 0, 0}) - .Case("pto.tabs", OpcodeAndVariant{0x1006, 0, 0}) - .Case("pto.tadd", OpcodeAndVariant{0x1007, 0, 0}) - .Case("pto.taddc", OpcodeAndVariant{0x1008, 0, 0}) - .Case("pto.tadds", OpcodeAndVariant{0x1009, 0, 0}) - .Case("pto.taddsc", OpcodeAndVariant{0x100A, 0, 0}) - .Case("pto.tand", OpcodeAndVariant{0x100B, 0, 0}) - .Case("pto.tands", OpcodeAndVariant{0x100C, 0, 0}) - .Case("pto.tci", OpcodeAndVariant{0x100D, 0, 0}) - .Case("pto.tcmp", OpcodeAndVariant{0x100E, 0, 0}) - .Case("pto.tcmps", OpcodeAndVariant{0x100F, 0, 0}) - .Case("pto.tcolexpand", OpcodeAndVariant{0x1010, 0, 
0}) - .Case("pto.tcolexpandadd", OpcodeAndVariant{0x1011, 0, 0}) - .Case("pto.tcolexpanddiv", OpcodeAndVariant{0x1012, 0, 0}) - .Case("pto.tcolexpandexpdif", OpcodeAndVariant{0x1013, 0, 0}) - .Case("pto.tcolexpandmax", OpcodeAndVariant{0x1014, 0, 0}) - .Case("pto.tcolexpandmin", OpcodeAndVariant{0x1015, 0, 0}) - .Case("pto.tcolexpandmul", OpcodeAndVariant{0x1016, 0, 0}) - .Case("pto.tcolexpandsub", OpcodeAndVariant{0x1017, 0, 0}) - .Case("pto.tcolmax", OpcodeAndVariant{0x1018, 0, 0}) - .Case("pto.tcolmin", OpcodeAndVariant{0x1019, 0, 0}) - .Case("pto.tcolprod", OpcodeAndVariant{0x101A, 0, 0}) - .Case("pto.tcolsum", OpcodeAndVariant{0x101B, 0, 0}) - .Case("pto.tcvt", OpcodeAndVariant{0x101C, 0, 0}) - .Case("pto.tdiv", OpcodeAndVariant{0x101D, 0, 0}) - .Case("pto.tdivs", OpcodeAndVariant{0x101E, 0, 0}) - .Case("pto.texp", OpcodeAndVariant{0x101F, 0, 0}) - .Case("pto.texpands", OpcodeAndVariant{0x1020, 0, 0}) - .Case("pto.textract", OpcodeAndVariant{0x1021, 0, 0}) - .Case("pto.textract_fp", OpcodeAndVariant{0x1022, 0, 0}) - .Case("pto.tfillpad", OpcodeAndVariant{0x1023, 0, 0}) - .Case("pto.tfillpad_expand", OpcodeAndVariant{0x1024, 0, 0}) - .Case("pto.tfillpad_inplace", OpcodeAndVariant{0x1025, 0, 0}) - .Case("pto.tfmod", OpcodeAndVariant{0x1026, 0, 0}) - .Case("pto.tfmods", OpcodeAndVariant{0x1027, 0, 0}) - .Case("pto.tgather", OpcodeAndVariant{0x1028, 0, 0}) - .Case("pto.tgatherb", OpcodeAndVariant{0x1029, 0, 0}) - .Case("pto.tgetval", OpcodeAndVariant{0x102B, 0, 0}) - .Case("pto.timg2col", OpcodeAndVariant{0x102C, 0, 0}) - .Case("pto.tinsert", OpcodeAndVariant{0x102D, 0, 0}) - .Case("pto.tinsert_fp", OpcodeAndVariant{0x102E, 0, 0}) - .Case("pto.tload", OpcodeAndVariant{0x102F, 0, 0}) - .Case("pto.tlog", OpcodeAndVariant{0x1030, 0, 0}) - .Case("pto.tlrelu", OpcodeAndVariant{0x1031, 0, 0}) - .Case("pto.tmax", OpcodeAndVariant{0x1034, 0, 0}) - .Case("pto.tmaxs", OpcodeAndVariant{0x1035, 0, 0}) - .Case("pto.tmin", OpcodeAndVariant{0x1036, 0, 0}) - .Case("pto.tmins", 
OpcodeAndVariant{0x1037, 0, 0}) - .Case("pto.tmov", OpcodeAndVariant{0x1038, 0, 0}) - .Case("pto.tmov.fp", OpcodeAndVariant{0x1039, 0, 0}) - .Case("pto.tmrgsort", OpcodeAndVariant{0x103A, 0, 0}) - .Case("pto.tmul", OpcodeAndVariant{0x103B, 0, 0}) - .Case("pto.tmuls", OpcodeAndVariant{0x103C, 0, 0}) - .Case("pto.tneg", OpcodeAndVariant{0x103D, 0, 0}) - .Case("pto.tnot", OpcodeAndVariant{0x103E, 0, 0}) - .Case("pto.tor", OpcodeAndVariant{0x103F, 0, 0}) - .Case("pto.tors", OpcodeAndVariant{0x1040, 0, 0}) - .Case("pto.tpartadd", OpcodeAndVariant{0x1041, 0, 0}) - .Case("pto.tpartmax", OpcodeAndVariant{0x1042, 0, 0}) - .Case("pto.tpartmin", OpcodeAndVariant{0x1043, 0, 0}) - .Case("pto.tpartmul", OpcodeAndVariant{0x1044, 0, 0}) - .Case("pto.tprefetch", OpcodeAndVariant{0x1045, 0, 0}) - .Case("pto.tprelu", OpcodeAndVariant{0x1046, 0, 0}) - .Case("pto.tquant", OpcodeAndVariant{0x1047, 0, 0}) - .Case("pto.trecip", OpcodeAndVariant{0x1048, 0, 0}) - .Case("pto.trelu", OpcodeAndVariant{0x1049, 0, 0}) - .Case("pto.trem", OpcodeAndVariant{0x104A, 0, 0}) - .Case("pto.trems", OpcodeAndVariant{0x104B, 0, 0}) - .Case("pto.treshape", OpcodeAndVariant{0x104C, 0, 0}) - .Case("pto.trowexpand", OpcodeAndVariant{0x104D, 0, 0}) - .Case("pto.trowexpandadd", OpcodeAndVariant{0x104E, 0, 0}) - .Case("pto.trowexpandexpdif", OpcodeAndVariant{0x104F, 0, 0}) - .Case("pto.trowexpandmax", OpcodeAndVariant{0x1050, 0, 0}) - .Case("pto.trowexpandmin", OpcodeAndVariant{0x1051, 0, 0}) - .Case("pto.trowmax", OpcodeAndVariant{0x1052, 0, 0}) - .Case("pto.trowmin", OpcodeAndVariant{0x1053, 0, 0}) - .Case("pto.trowsum", OpcodeAndVariant{0x1054, 0, 0}) - .Case("pto.trsqrt", OpcodeAndVariant{0x1055, 0, 0}) - .Case("pto.tscatter", OpcodeAndVariant{0x1056, 0, 0}) - .Case("pto.tsel", OpcodeAndVariant{0x1057, 0, 0}) - .Case("pto.tsels", OpcodeAndVariant{0x1058, 0, 0}) - .Case("pto.tset_img2col_padding", OpcodeAndVariant{0x1059, 0, 0}) - .Case("pto.tset_img2col_rpt", OpcodeAndVariant{0x105A, 0, 0}) - 
.Case("pto.tsetfmatrix", OpcodeAndVariant{0x105B, 0, 0}) - .Case("pto.tsethf32mode", OpcodeAndVariant{0x105C, 0, 0}) - .Case("pto.tsettf32mode", OpcodeAndVariant{0x105D, 0, 0}) - .Case("pto.tsetval", OpcodeAndVariant{0x105E, 0, 0}) - .Case("pto.tshl", OpcodeAndVariant{0x105F, 0, 0}) - .Case("pto.tshls", OpcodeAndVariant{0x1060, 0, 0}) - .Case("pto.tshr", OpcodeAndVariant{0x1061, 0, 0}) - .Case("pto.tshrs", OpcodeAndVariant{0x1062, 0, 0}) - .Case("pto.tsort32", OpcodeAndVariant{0x1063, 0, 0}) - .Case("pto.tsqrt", OpcodeAndVariant{0x1064, 0, 0}) - .Case("pto.tstore", OpcodeAndVariant{0x1065, 0, 0}) - .Case("pto.tstore_fp", OpcodeAndVariant{0x1066, 0, 0}) - .Case("pto.tsub", OpcodeAndVariant{0x1067, 0, 0}) - .Case("pto.tsubc", OpcodeAndVariant{0x1068, 0, 0}) - .Case("pto.tsubs", OpcodeAndVariant{0x1069, 0, 0}) - .Case("pto.tsubsc", OpcodeAndVariant{0x106A, 0, 0}) - .Case("pto.trowexpandsub", OpcodeAndVariant{0x106B, 0, 0}) - .Case("pto.ttrans", OpcodeAndVariant{0x106C, 0, 0}) - .Case("pto.ttri", OpcodeAndVariant{0x106D, 0, 0}) - .Case("pto.txor", OpcodeAndVariant{0x106E, 0, 0}) - .Case("pto.txors", OpcodeAndVariant{0x106F, 0, 0}) - .Case("pto.wait_event", OpcodeAndVariant{0x1070, 0, 0}) - .Case("pto.tprint", OpcodeAndVariant{0x1071, 0, 0}) - .Case("pto.subview", OpcodeAndVariant{0x1072, 0, 0}) - .Case("pto.trowexpanddiv", OpcodeAndVariant{0x1073, 0, 0}) - .Case("pto.trowexpandmul", OpcodeAndVariant{0x1074, 0, 0}) - .Case("pto.tdequant", OpcodeAndVariant{0x1075, 0, 0}) - .Case("pto.taxpy", OpcodeAndVariant{0x1076, 0, 0}) - .Case("pto.thistogram", OpcodeAndVariant{0x1077, 0, 0}) - .Case("pto.tget_scale_addr", OpcodeAndVariant{0x1078, 0, 0}) - .Case("pto.trowargmax", OpcodeAndVariant{0x1079, 0, 0}) - .Case("pto.trowargmin", OpcodeAndVariant{0x107A, 0, 0}) - .Case("pto.tcolargmax", OpcodeAndVariant{0x107B, 0, 0}) - .Case("pto.tcolargmin", OpcodeAndVariant{0x107C, 0, 0}) - .Case("pto.tsync", OpcodeAndVariant{0x107D, 0, 0}) - .Case("pto.reserve_buffer", 
OpcodeAndVariant{0x107E, 0, 0}) - .Case("pto.import_reserved_buffer", OpcodeAndVariant{0x107F, 0, 0}) - .Case("pto.aic_initialize_pipe", OpcodeAndVariant{0x1080, 0, 0}) - .Case("pto.aiv_initialize_pipe", OpcodeAndVariant{0x1081, 0, 0}) - .Case("pto.tpush_to_aiv", OpcodeAndVariant{0x1082, 0, 0}) - .Case("pto.tpush_to_aic", OpcodeAndVariant{0x1083, 0, 0}) - .Case("pto.tpop_from_aic", OpcodeAndVariant{0x1084, 0, 0}) - .Case("pto.tpop_from_aiv", OpcodeAndVariant{0x1085, 0, 0}) - .Case("pto.tfree_from_aic", OpcodeAndVariant{0x1086, 0, 0}) - .Case("pto.tfree_from_aiv", OpcodeAndVariant{0x1087, 0, 0}) - .Case("pto.set_validshape", OpcodeAndVariant{0x1088, 0, 0}) - .Case("pto.tconcat", OpcodeAndVariant{0x1089, 0, 0}) - .Case("pto.trowprod", OpcodeAndVariant{0x108A, 0, 0}) - .Case("pto.initialize_l2g2l_pipe", OpcodeAndVariant{0x108B, 0, 0}) - .Case("pto.initialize_l2l_pipe", OpcodeAndVariant{0x108C, 0, 0}) - .Case("pto.tpush", OpcodeAndVariant{0x108D, 0, 0}) - .Case("pto.declare_tile", OpcodeAndVariant{0x108E, 0, 0}) - .Case("pto.tpop", OpcodeAndVariant{0x108F, 0, 0}) - .Case("pto.tfree", OpcodeAndVariant{0x1090, 0, 0}) - .Case("pto.comm.tput", OpcodeAndVariant{0x1091, 0, 0}) - .Case("pto.comm.tget", OpcodeAndVariant{0x1092, 0, 0}) - .Case("pto.comm.tnotify", OpcodeAndVariant{0x1093, 0, 0}) - .Case("pto.comm.twait", OpcodeAndVariant{0x1094, 0, 0}) - .Case("pto.comm.ttest", OpcodeAndVariant{0x1095, 0, 0}) - .Case("pto.comm.tbroadcast", OpcodeAndVariant{0x1096, 0, 0}) - .Case("pto.comm.tgather", OpcodeAndVariant{0x1097, 0, 0}) - .Case("pto.comm.tscatter", OpcodeAndVariant{0x1098, 0, 0}) - .Case("pto.comm.treduce", OpcodeAndVariant{0x1099, 0, 0}) - .Case("pto.tpartargmax", OpcodeAndVariant{0x109A, 0, 0}) - .Case("pto.tpartargmin", OpcodeAndVariant{0x109B, 0, 0}) - .Case("scf.for", OpcodeAndVariant{0x4000, 0, 0}) - .Case("scf.if", OpcodeAndVariant{0x4001, 0, 0}) - .Case("scf.yield", OpcodeAndVariant{0x4002, 0, 0}) - .Case("pto.section.cube", - OpcodeAndVariant{0x0006, 
kHasVariant, kSectionCubeVariant}) - .Case("pto.section.vector", - OpcodeAndVariant{0x0006, kHasVariant, kSectionVectorVariant}) - .Case("pto.tgemv", - OpcodeAndVariant{0x102A, kHasVariant, kVariantDefault}) - .Case("pto.tgemv.acc", - OpcodeAndVariant{0x102A, kHasVariant, kVariantAcc}) - .Case("pto.tgemv.bias", - OpcodeAndVariant{0x102A, kHasVariant, kVariantBias}) - .Case("pto.tgemv.mx", - OpcodeAndVariant{0x102A, kHasVariant, kVariantMx}) - .Case("pto.tgemv.mx.acc", - OpcodeAndVariant{0x102A, kHasVariant, kVariantMxAcc}) - .Case("pto.tgemv.mx.bias", - OpcodeAndVariant{0x102A, kHasVariant, kVariantMxBias}) - .Case("pto.tmatmul", - OpcodeAndVariant{0x1032, kHasVariant, kVariantDefault}) - .Case("pto.tmatmul.acc", - OpcodeAndVariant{0x1032, kHasVariant, kVariantAcc}) - .Case("pto.tmatmul.bias", - OpcodeAndVariant{0x1032, kHasVariant, kVariantBias}) - .Case("pto.tmatmul.mx", - OpcodeAndVariant{0x1033, kHasVariant, kVariantDefault}) - .Case("pto.tmatmul.mx.acc", - OpcodeAndVariant{0x1033, kHasVariant, kVariantAcc}) - .Case("pto.tmatmul.mx.bias", - OpcodeAndVariant{0x1033, kHasVariant, kVariantBias}) - .Default(std::nullopt); -} - -inline const char *fullNameFromOpcodeVariant(uint16_t opcode, uint8_t variant) { - const OpInfo *info = lookupByOpcode(opcode); - if (!info) return nullptr; - if (opcode == kTscatterMaskOpcode) return "pto.tscatter"; - if (!info->has_variant_u8) return info->name; - switch (opcode) { - case 0x0006: - switch (variant) { - case kSectionCubeVariant: return "pto.section.cube"; - case kSectionVectorVariant: return "pto.section.vector"; - default: return info->name; - } - case 0x102A: - switch (variant) { - case kVariantDefault: return "pto.tgemv"; - case kVariantAcc: return "pto.tgemv.acc"; - case kVariantBias: return "pto.tgemv.bias"; - case kVariantMx: return "pto.tgemv.mx"; - case kVariantMxAcc: return "pto.tgemv.mx.acc"; - case kVariantMxBias: return "pto.tgemv.mx.bias"; - default: return info->name; - } - case 0x1032: - switch (variant) { - 
case kVariantDefault: return "pto.tmatmul"; - case kVariantAcc: return "pto.tmatmul.acc"; - case kVariantBias: return "pto.tmatmul.bias"; - default: return info->name; - } - case 0x1033: - switch (variant) { - case kVariantDefault: return "pto.tmatmul.mx"; - case kVariantAcc: return "pto.tmatmul.mx.acc"; - case kVariantBias: return "pto.tmatmul.mx.bias"; - default: return info->name; - } - default: return info->name; - } -} - -inline std::optional lookupOperandsByVariant(uint16_t opcode, uint8_t variant) { - switch (opcode) { - case 0x102A: - switch (variant) { - case kVariantDefault: return kTgemvOperandCount; - case kVariantAcc: return kTgemvAccOperandCount; - case kVariantBias: return kTgemvBiasOperandCount; - case kVariantMx: return kTgemvMxOperandCount; - case kVariantMxAcc: return kTgemvMxAccOperandCount; - case kVariantMxBias: return kTgemvMxBiasOperandCount; - default: return std::nullopt; - } - case 0x1032: - switch (variant) { - case kVariantDefault: return kTmatmulOperandCount; - case kVariantAcc: return kTmatmulAccOperandCount; - case kVariantBias: return kTmatmulBiasOperandCount; - default: return std::nullopt; - } - case 0x1033: - switch (variant) { - case kVariantDefault: return kTmatmulMxOperandCount; - case kVariantAcc: return kTmatmulMxAccOperandCount; - case kVariantBias: return kTmatmulMxBiasOperandCount; - default: return std::nullopt; - } - default: return std::nullopt; - } -} - -} // namespace ptobc::v0 +#include "ptobc_opcodes_v0.def" From c0faaaa4db8240da7ac03fef2eca9bcfe1c6fd58 Mon Sep 17 00:00:00 2001 From: HecreReed <821896444@qq.com> Date: Tue, 26 May 2026 12:22:50 +0800 Subject: [PATCH 2/8] Revert "refactor: split nbnc hotspot files" This reverts commit 17080f6484aad9b35575e1b0bb8c5eaf89b2f1d1. 
--- lib/PTO/IR/PTO.cpp | 12925 ++++++++++++++- lib/PTO/IR/PTO.def | 12933 ---------------- .../Transforms/GraphSyncSolver/SyncSolver.cpp | 2568 ++- .../Transforms/GraphSyncSolver/SyncSolver.def | 2576 --- lib/PTO/Transforms/PTOToEmitC.cpp | 12895 ++++++++++++++- lib/PTO/Transforms/PTOToEmitC.def | 12903 --------------- lib/PTO/Transforms/PTOViewToMemref.cpp | 3607 ++++- lib/PTO/Transforms/PTOViewToMemref.def | 3615 ----- tools/ptobc/generated/ptobc_opcodes_v0.def | 722 - tools/ptobc/generated/ptobc_opcodes_v0.h | 714 +- 10 files changed, 32704 insertions(+), 32754 deletions(-) delete mode 100644 lib/PTO/IR/PTO.def delete mode 100644 lib/PTO/Transforms/GraphSyncSolver/SyncSolver.def delete mode 100644 lib/PTO/Transforms/PTOToEmitC.def delete mode 100644 lib/PTO/Transforms/PTOViewToMemref.def delete mode 100644 tools/ptobc/generated/ptobc_opcodes_v0.def diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index e9dc72235..376b9c017 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -6,5 +6,12928 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. 
+//===- PTO.cpp - PTO Dialect ----------------------------------------------===// +//===----------------------------------------------------------------------===// -#include "PTO.def" +#include "PTO/IR/PTO.h" +#include "PTO/IR/PTOTypeUtils.h" +#include "PTO/IR/PTOSyncUtils.h" + +#include "mlir/AsmParser/AsmParser.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/AsmState.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/Types.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Parser/Parser.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/TypeSwitch.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "llvm/Support/ErrorHandling.h" + +#include +#include +#include +#include +#include + +using namespace mlir; +using namespace mlir::pto; + +// Forward declarations for custom shape/type printers used by tensor_view and +// partition_tensor_view. 
+namespace mlir { +namespace pto { +static LogicalResult parseShapeAndElem(AsmParser &parser, + SmallVectorImpl &shape, + Type &elementType, + bool allowDynamic = true); +static void printShapeAndElem(AsmPrinter &printer, + ArrayRef shape, + Type elementType); +} // namespace pto +} // namespace mlir + +// ============================================================================= +// TileBufType 的自定义 Shape 解析与打印函数 +// ============================================================================= + +// 解析逻辑:解析形如 "32x32" 的维度列表 +[[maybe_unused]] static ParseResult parseShape(AsmParser &parser, SmallVectorImpl &shape) { + // parseDimensionList 会解析 "dim x dim x ...", 遇到无法解析为维度的字符停止 + // 参数 allowDynamic=true (允许 ?), withTrailingX=false (不吞掉末尾的 x) + if (parser.parseDimensionList(shape, /*allowDynamic=*/true, /*withTrailingX=*/false)) + return failure(); + return success(); +} + +// 打印逻辑:打印形如 "32x32" 的维度列表 +[[maybe_unused]] static void printShape(AsmPrinter &printer, ArrayRef shape) { + for (auto it = shape.begin(); it != shape.end(); ++it) { + if (it != shape.begin()) printer << "x"; // 维度间的分隔符 + if (*it == ShapedType::kDynamic) + printer << "?"; + else + printer << *it; + } + // 注意:我们不在这里打印末尾的 'x',因为 assemblyFormat 中已经写了 `x` $elementType +} + +static std::optional getPTOMemorySpaceEnum(Type ty); +enum class VerifierTargetArch { + A2A3, + A5, +}; +static VerifierTargetArch getVerifierTargetArch(Operation *op); +static std::optional getVerifierArchName(Operation *op); +static bool isSupportedVecElemType(Type ty, bool allowBf16 = true, + bool allowInt8 = true); +static bool isSupportedLoadStoreElemTypeA2A3(Type ty); +static bool isSupportedGatherElemTypeA2A3(Type ty); +static bool isSupportedGatherElemTypeA5(Type ty); +static bool isA5TLoadStoreTransferElemType(Type ty); +static bool isA5AccStorePreQuantDstType(Type srcElem, Type dstElem); +static bool isA5LowPrecisionTCvtPair(Type srcElem, Type dstElem); +static bool isA5SupportedTCvtPair(Type srcElem, Type dstElem); 
+static ParseResult parseSyncEventOpCommon(OpAsmParser &parser, + OperationState &result, + StringAttr pipeAttrName, + StringAttr eventIdAttrName); +static void printSyncEventOpCommon(OpAsmPrinter &p, Operation *op, + PipeAttr pipeAttr, IntegerAttr eventAttr, + Value eventDyn, StringRef pipeAttrName, + StringRef eventIdAttrName); +static bool isTileLikeType(Type ty); +static SmallVector getShapeVec(Type ty); +static SmallVector getValidShapeVec(Type ty); +static SmallVector getValidShapeVec(Value value); +static bool isByteIntegerType(Type ty); +static LogicalResult verifyTileBufCommon(Operation *op, Type ty, StringRef name, + bool allowLowPrecision = false); +static LogicalResult verifyTileBufSameElemType(Operation *op, Type lhs, Type rhs, + StringRef lhsName, + StringRef rhsName); +static LogicalResult verifyTileBufSameLogicalExtent(Operation *op, Type lhs, + Type rhs, StringRef lhsName, + StringRef rhsName, + bool compareValidShape); + +static LogicalResult verifyTileBufSameValidShape(Operation *op, Type lhs, Type rhs, + StringRef lhsName, StringRef rhsName); +static LogicalResult verifyVecTileCommon(Operation *op, Type ty, StringRef name); +static LogicalResult verifyVecTileCommonA2A3(Operation *op, Type ty, + StringRef name); +static LogicalResult verifyVecTileCommonA5(Operation *op, Type ty, + StringRef name); +static LogicalResult verifyVecTileUnaryOp(Operation *op, Type srcTy, Type dstTy, + StringRef srcName = "src", + StringRef dstName = "dst", + bool allowBf16 = true, + bool allowInt8 = true); +static LogicalResult verifyAccTileCommon(Operation *op, Type ty, StringRef name); +static LogicalResult verifyAccTileCommonA2A3(Operation *op, Type ty, + StringRef name); +static LogicalResult verifyAccTileCommonA5(Operation *op, Type ty, + StringRef name); +static LogicalResult verifyMatTileOperands(Operation *op, Type lhsTy, Type rhsTy, + Type dstTy); +static LogicalResult verifyMatTileOperandsA2A3(Operation *op, Type lhsTy, + Type rhsTy, Type dstTy); +static 
LogicalResult verifyMatTileOperandsA5(Operation *op, Type lhsTy, + Type rhsTy, Type dstTy); +static LogicalResult verifyGemvTileOperands(Operation *op, Type lhsTy, Type rhsTy, + Type dstTy); +static LogicalResult verifyAsyncFlatContiguous1DGMViewLike(Operation *op, + Value value, + StringRef name); +static LogicalResult verifyGemvTileOperandsA2A3(Operation *op, Type lhsTy, + Type rhsTy, Type dstTy); +static LogicalResult verifyGemvTileOperandsA5(Operation *op, Type lhsTy, + Type rhsTy, Type dstTy); +static LogicalResult verifyMatBiasTile(Operation *op, Type biasTy, Type dstTy, + bool requireFloatBias = false); +static LogicalResult verifyMatBiasTileA2A3(Operation *op, Type biasTy, Type dstTy, + bool requireFloatBias = false); +static LogicalResult verifyMatBiasTileA5(Operation *op, Type biasTy, Type dstTy, + bool requireFloatBias = false); +static LogicalResult verifyMatmulTypeTriple(Operation *op, Type lhsElemTy, + Type rhsElemTy, Type dstElemTy); +static std::optional getLogicalViewLayout(Value value); +static std::optional getTileBufLogicalLayout(pto::TileBufType type); +static std::optional getConstantIntegerValue(Value value); +static LogicalResult verifyPartialValidPattern(Operation *op, Type src0Ty, + Type src1Ty, Type dstTy); +static Type getElemTy(Type ty); +static FailureOr +verifyMatchingRowMajorBinaryTileOpCommon(Operation *op, Type src0Ty, + Type src1Ty, Type dstTy); +static FailureOr +verifyNumericScalarTileOpCommon(Operation *op, Type srcTy, Type dstTy, + Type scalarTy, bool requireValidRowsEqual); +static FailureOr +verifyShiftLikeBinaryTileOpCommon(Operation *op, Type src0Ty, Type src1Ty, + Type dstTy); +static LogicalResult verifyArithmeticElemTypeForArch( + Operation *op, Type elemTy, PTOArch targetArch, bool allowInt8OnA5, + bool allowBf16OnA5, StringRef a2a3Error, StringRef a5Error); +static bool isRowMajorTileBuf(Type ty); + +#define GET_ENUM_CLASSES +#include "PTO/IR/PTOEnums.cpp.inc" + +#define GET_TYPEDEF_CLASSES +#include 
"PTO/IR/PTOTypeDefs.cpp.inc" + +#define GET_ATTRDEF_CLASSES +#include "PTO/IR/PTOAttrs.cpp.inc" + +#include "PTO/IR/PTODialect.cpp.inc" + +[[maybe_unused]] static LogicalResult parseShapeAndElemStable(mlir::AsmParser &parser, + llvm::SmallVectorImpl &shape, + mlir::Type &elementType) { + if (failed(parser.parseLess())) + return failure(); + + if (failed(parser.parseDimensionList(shape, /*allowDynamic=*/true))) + return failure(); + + if (failed(parser.parseType(elementType))) + return failure(); + + if (failed(parser.parseGreater())) + return failure(); + + return success(); +} + +static int64_t getPTOTypeRank(Type type) { + // 1. 处理标准的 MLIR 类型 (MemRef, Tensor, Vector) + if (auto shapedTy = dyn_cast(type)) { + if (shapedTy.hasRank()) + return shapedTy.getRank(); + return -1; // Unranked type + } + + // 2. 处理 PTO 自定义类型 + if (auto tvTy = dyn_cast(type)) + return tvTy.getRank(); + + if (auto tileTy = dyn_cast(type)) + return tileTy.getRank(); + + if (auto tileViewTy = dyn_cast(type)) + return tileViewTy.getRank(); + + if (auto tileBufTy = dyn_cast(type)) + return tileBufTy.getRank(); + + // 3. 
不支持的类型 + return -1; +} + +static bool isGmAddressSpaceAttr(Attribute memorySpace) { + if (!memorySpace) + return true; + if (auto addr = mlir::dyn_cast(memorySpace)) + return addr.getAddressSpace() == pto::AddressSpace::GM; + if (auto intAttr = mlir::dyn_cast(memorySpace)) + return intAttr.getInt() == 0; + return false; +} + +PTOArch mlir::pto::getTargetArch(ModuleOp module) { + if (!module) + return PTOArch::A3; + + auto arch = module->getAttrOfType(kPTOTargetArchAttrName); + if (arch && arch.getValue().equals_insensitive("a5")) + return PTOArch::A5; + return PTOArch::A3; +} + +PTOArch mlir::pto::getTargetArch(Operation *op) { + if (!op) + return PTOArch::A3; + return getTargetArch(op->getParentOfType()); +} + +bool mlir::pto::isTargetArchA3(ModuleOp module) { + return getTargetArch(module) == PTOArch::A3; +} + +bool mlir::pto::isTargetArchA5(ModuleOp module) { + return getTargetArch(module) == PTOArch::A5; +} + +bool mlir::pto::isTargetArchA3(Operation *op) { + return getTargetArch(op) == PTOArch::A3; +} + +bool mlir::pto::isTargetArchA5(Operation *op) { + return getTargetArch(op) == PTOArch::A5; +} + +static llvm::TypeSize getOneByteTypeSize() { + return llvm::TypeSize::getFixed(8); +} + +llvm::TypeSize mlir::pto::HiF8Type::getTypeSizeInBits( + const DataLayout &, DataLayoutEntryListRef) const { + return getOneByteTypeSize(); +} + +uint64_t mlir::pto::HiF8Type::getABIAlignment(const DataLayout &, + DataLayoutEntryListRef) const { + return 1; +} + +uint64_t mlir::pto::HiF8Type::getPreferredAlignment( + const DataLayout &, DataLayoutEntryListRef) const { + return 1; +} + +llvm::TypeSize mlir::pto::F4E1M2x2Type::getTypeSizeInBits( + const DataLayout &, DataLayoutEntryListRef) const { + return getOneByteTypeSize(); +} + +uint64_t mlir::pto::F4E1M2x2Type::getABIAlignment( + const DataLayout &, DataLayoutEntryListRef) const { + return 1; +} + +uint64_t mlir::pto::F4E1M2x2Type::getPreferredAlignment( + const DataLayout &, DataLayoutEntryListRef) const { + return 1; +} 
+ +llvm::TypeSize mlir::pto::F4E2M1x2Type::getTypeSizeInBits( + const DataLayout &, DataLayoutEntryListRef) const { + return getOneByteTypeSize(); +} + +uint64_t mlir::pto::F4E2M1x2Type::getABIAlignment( + const DataLayout &, DataLayoutEntryListRef) const { + return 1; +} + +uint64_t mlir::pto::F4E2M1x2Type::getPreferredAlignment( + const DataLayout &, DataLayoutEntryListRef) const { + return 1; +} + +static VerifierTargetArch getVerifierTargetArch(Operation *op) { + if (auto archName = getVerifierArchName(op)) { + return archName->equals_insensitive("a5") ? VerifierTargetArch::A5 + : VerifierTargetArch::A2A3; + } + + switch (getPTOParserTargetArch(op ? op->getContext() : nullptr)) { + case PTOParserTargetArch::A5: + return VerifierTargetArch::A5; + case PTOParserTargetArch::A3: + case PTOParserTargetArch::Unspecified: + return VerifierTargetArch::A2A3; + } + + return VerifierTargetArch::A2A3; +} + +static std::optional getVerifierArchName(Operation *op) { + auto module = op ? op->getParentOfType() : ModuleOp(); + if (!module) + return std::nullopt; + if (auto arch = module->getAttrOfType(kPTOTargetArchAttrName)) + return arch.getValue(); + return std::nullopt; +} + +static bool shouldBypassDecodedMemrefVerifier(Operation *op) { + if (!op) + return false; + for (Value operand : op->getOperands()) { + if (isa(operand.getType())) + return true; + if (operand.getDefiningOp()) + return true; + } + return false; +} + +static SmallVector canonicalizeTileBufValidShape(ArrayRef validShape) { + SmallVector canonical; + canonical.reserve(validShape.size()); + for (int64_t dim : validShape) + canonical.push_back(dim < 0 ? 
ShapedType::kDynamic : dim); + return canonical; +} + +template +static LogicalResult dispatchVerifierByArch(Operation *op, FnA2A3 &&verifyA2A3, + FnA5 &&verifyA5) { + if (shouldBypassDecodedMemrefVerifier(op)) + return success(); + switch (getVerifierTargetArch(op)) { + case VerifierTargetArch::A2A3: + return verifyA2A3(); + case VerifierTargetArch::A5: + return verifyA5(); + } + return failure(); +} + +static ParseResult parseSyncEventOpCommon(OpAsmParser &parser, + OperationState &result, + StringAttr pipeAttrName, + StringAttr eventIdAttrName) { + PipeAttr pipeAttr; + if (succeeded(parser.parseOptionalLess())) { + StringRef pipeTok; + if (parser.parseKeyword(&pipeTok) || parser.parseGreater()) + return failure(); + auto pipeOr = symbolizePIPE(pipeTok); + if (!pipeOr) + return parser.emitError(parser.getCurrentLocation()) + << "unknown pipe token: " << pipeTok; + pipeAttr = PipeAttr::get(parser.getContext(), *pipeOr); + result.addAttribute(pipeAttrName, pipeAttr); + } else if (parser.parseAttribute(pipeAttr, pipeAttrName, + result.attributes)) { + return failure(); + } + if (parser.parseComma()) + return failure(); + + OpAsmParser::UnresolvedOperand eventOperand; + OptionalParseResult parseEventOperand = + parser.parseOptionalOperand(eventOperand); + if (parseEventOperand.has_value()) { + if (failed(*parseEventOperand)) + return failure(); + if (parser.resolveOperand(eventOperand, parser.getBuilder().getIndexType(), + result.operands)) + return failure(); + } else { + IntegerAttr eventAttr; + if (parser.parseAttribute(eventAttr, parser.getBuilder().getI32Type(), + eventIdAttrName, result.attributes)) + return failure(); + } + + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + return success(); +} + +static void printSyncEventOpCommon(OpAsmPrinter &p, Operation *op, + PipeAttr pipeAttr, IntegerAttr eventAttr, + Value eventDyn, StringRef pipeAttrName, + StringRef eventIdAttrName) { + p << " <" << stringifyPIPE(pipeAttr.getPipe()) << ">, 
"; + if (eventAttr) + p << eventAttr.getInt(); + else + p << eventDyn; + p.printOptionalAttrDict(op->getAttrs(), {pipeAttrName, eventIdAttrName}); +} + +[[maybe_unused]] static mlir::Type parsePTOTypeAllowNoBang(mlir::OpAsmParser &parser) { + mlir::Type ty; + + mlir::OptionalParseResult opt = parser.parseOptionalType(ty); + + if (opt.has_value()) { + if (failed(*opt)) + return mlir::Type(); + return ty; + } + + + llvm::StringRef head; + if (failed(parser.parseKeyword(&head))) + return mlir::Type(); + + mlir::MLIRContext *ctx = parser.getContext(); + + auto parseShapeElemForOpParser = + [&](llvm::SmallVectorImpl &shape, mlir::Type &elem) -> mlir::LogicalResult { + if (failed(parser.parseLess())) + return failure(); + if (failed(parser.parseDimensionList(shape, /*allowDynamic=*/true))) + return failure(); + if (failed(parser.parseType(elem))) + return failure(); + if (failed(parser.parseGreater())) + return failure(); + return success(); + }; + + if (head == "pto.tile_view") { + llvm::SmallVector shape; + mlir::Type elem; + if (failed(parseShapeElemForOpParser(shape, elem))) + return mlir::Type(); + return mlir::pto::PartitionTensorViewType::get(ctx, shape, elem); + } + + if (head == "pto.tile") { + llvm::SmallVector shape; + mlir::Type elem; + if (failed(parseShapeElemForOpParser(shape, elem))) + return mlir::Type(); + return mlir::pto::TileType::get(ctx, shape, elem); + } + + if (head == "pto.ptr") { + if (failed(parser.parseLess())) + return mlir::Type(); + mlir::Type elem; + if (failed(parser.parseType(elem))) + return mlir::Type(); + if (succeeded(parser.parseOptionalComma())) { + // ptr no longer accepts an address space; consume the attr for recovery. 
+ mlir::Attribute memorySpace; + (void)parser.parseAttribute(memorySpace); + parser.emitError(parser.getCurrentLocation(), + "!pto.ptr no longer accepts address space; use !pto.ptr"); + return mlir::Type(); + } + if (failed(parser.parseGreater())) + return mlir::Type(); + return mlir::pto::PtrType::get(ctx, elem); + } + + if (head == "pto.tensor_view") { + llvm::SmallVector shape; + mlir::Type elem; + if (failed(parseShapeElemForOpParser(shape, elem))) + return mlir::Type(); + return mlir::pto::TensorViewType::get(ctx, shape, elem); + } + + return mlir::Type(); +} + +mlir::Type TensorViewType::parse(::mlir::AsmParser &parser) { + SmallVector shape; + Type elementType; + if (failed(parseShapeAndElem(parser, shape, elementType, /*allowDynamic=*/true))) + return Type(); + return TensorViewType::get(parser.getContext(), shape, elementType); +} + +void TensorViewType::print(::mlir::AsmPrinter &printer) const { + printShapeAndElem(printer, getShape(), getElementType()); +} + +//===----------------------------------------------------------------------===// +// pto.tdivs custom asm to support both: +// pto.tdivs ins(%src, %scalar : !pto.tile_buf<...>, f32) outs(%dst : !pto.tile_buf<...>) +// pto.tdivs ins(%scalar, %src : f32, !pto.tile_buf<...>) outs(%dst : !pto.tile_buf<...>) +// The operand order in the op follows textual input order. 
+//===----------------------------------------------------------------------===// + +ParseResult mlir::pto::TDivSOp::parse(OpAsmParser &parser, OperationState &result) { + OpAsmParser::UnresolvedOperand op0, op1, dst; + Type ty0, ty1, dstTy; + + if (parser.parseKeyword("ins") || parser.parseLParen() || + parser.parseOperand(op0) || parser.parseComma() || + parser.parseOperand(op1) || parser.parseColonType(ty0) || + parser.parseComma() || parser.parseType(ty1) || parser.parseRParen()) + return failure(); + + if (parser.parseKeyword("outs") || parser.parseLParen() || + parser.parseOperand(dst) || parser.parseColonType(dstTy) || + parser.parseRParen()) + return failure(); + + NamedAttrList attrs; + if (parser.parseOptionalAttrDict(attrs)) + return failure(); + + auto tile0 = dyn_cast(ty0); + auto tile1 = dyn_cast(ty1); + if ((tile0 && tile1) || (!tile0 && !tile1)) + return parser.emitError(parser.getCurrentLocation(), + "expected exactly one tile_buf operand and one scalar operand"); + + if (!dyn_cast(dstTy)) + return parser.emitError(parser.getCurrentLocation(), + "expected outs type to be !pto.tile_buf<...>"); + + // Keep textual order so later lowering can distinguish the two APIs by the + // first ins operand type. 
+ if (parser.resolveOperand(op0, ty0, result.operands) || + parser.resolveOperand(op1, ty1, result.operands)) + return failure(); + + if (parser.resolveOperand(dst, dstTy, result.operands)) + return failure(); + + result.addAttributes(attrs); + return success(); +} + +void mlir::pto::TDivSOp::print(OpAsmPrinter &p) { + p << " ins("; + p << getSrc() << ", " << getScalar() << " : " + << getSrc().getType() << ", " << getScalar().getType(); + p << ") outs(" << getDst() << " : " << getDst().getType() << ")"; + + p.printOptionalAttrDict((*this)->getAttrs()); +} + + +//===----------------------------------------------------------------------===// +// pto.tgather custom asm supports three PTO-ISA forms: +// 1) index+tmp : ins(%src, %indices, %tmp : srcTy, indicesTy, tmpTy) outs(%dst : dstTy) +// 2) compare+tmp : ins(%src, %kValue, %tmp : srcTy, scalarTy, tmpTy) +// outs(%dst, %cdst : dstTy, cdstTy) {cmpMode = #pto.cmp, offset = 7} +// 3) mask : ins(%src, {maskPattern = #pto.mask_pattern} : srcTy) outs(%dst : dstTy) +//===----------------------------------------------------------------------===// + +ParseResult mlir::pto::TGatherOp::parse(OpAsmParser &parser, OperationState &result) { + OpAsmParser::UnresolvedOperand src, dst, cdst; + SmallVector insOps; + SmallVector insTypes; + Type srcTy, dstTy, cdstTy; + bool hasCdst = false; + bool hasMask = false; + bool hasIndices = false; + bool hasTmp = false; + bool hasKValue = false; + + if (parser.parseKeyword("ins") || parser.parseLParen() || parser.parseOperand(src)) + return failure(); + + if (!succeeded(parser.parseOptionalComma())) { + return parser.emitError(parser.getCurrentLocation(), + "expected ',' after src operand in ins(...)"); + } + + if (succeeded(parser.parseOptionalLBrace())) { + if (parser.parseKeyword("maskPattern") || parser.parseEqual()) + return failure(); + + Attribute rawMaskAttr; + if (parser.parseAttribute(rawMaskAttr) || parser.parseRBrace()) + return failure(); + + auto mp = 
llvm::dyn_cast(rawMaskAttr); + if (!mp) { + return parser.emitError(parser.getCurrentLocation(), + "expected #pto.mask_pattern for maskPattern"); + } + + result.addAttribute("maskPattern", mp); + hasMask = true; + + if (parser.parseColonType(srcTy) || parser.parseRParen()) + return failure(); + } else { + OpAsmParser::UnresolvedOperand extra; + if (parser.parseOperand(extra)) + return failure(); + insOps.push_back(extra); + while (succeeded(parser.parseOptionalComma())) { + if (insOps.size() == 3) { + return parser.emitError(parser.getCurrentLocation(), + "expected at most 3 extra operands in tgather ins(...)"); + } + if (parser.parseOperand(extra)) + return failure(); + insOps.push_back(extra); + } + + if (parser.parseColon() || parser.parseType(srcTy)) + return failure(); + for (size_t i = 0; i < insOps.size(); ++i) { + Type ty; + if (parser.parseComma() || parser.parseType(ty)) + return failure(); + insTypes.push_back(ty); + } + if (parser.parseRParen()) + return failure(); + } + + if (parser.parseKeyword("outs") || parser.parseLParen() || parser.parseOperand(dst)) + return failure(); + if (succeeded(parser.parseOptionalComma())) { + if (parser.parseOperand(cdst)) + return failure(); + hasCdst = true; + } + if (parser.parseColonType(dstTy)) + return failure(); + if (hasCdst && (parser.parseComma() || parser.parseType(cdstTy))) + return failure(); + if (parser.parseRParen()) + return failure(); + + if (succeeded(parser.parseOptionalKeyword("maskPattern"))) { + if (hasMask) + return parser.emitError(parser.getCurrentLocation(), + "maskPattern may only be specified once"); + if (parser.parseEqual()) + return failure(); + Attribute rawMaskAttr; + if (parser.parseAttribute(rawMaskAttr)) + return failure(); + auto mp = llvm::dyn_cast(rawMaskAttr); + if (!mp) { + return parser.emitError(parser.getCurrentLocation(), + "expected #pto.mask_pattern for maskPattern"); + } + result.addAttribute("maskPattern", mp); + hasMask = true; + } + + if 
(parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + if (hasMask) { + if (!insOps.empty()) + return parser.emitError(parser.getCurrentLocation(), + "mask-pattern tgather does not take extra ins operands"); + if (hasCdst) + return parser.emitError(parser.getCurrentLocation(), + "mask-pattern tgather expects a single outs operand"); + } else if (hasCdst) { + if (insOps.empty() || + !(mlir::isa(insTypes.front()) || + mlir::isa(insTypes.front()))) + return parser.emitError(parser.getCurrentLocation(), + "compare-form tgather expects a scalar kValue operand"); + hasKValue = true; + if (insOps.size() >= 2) { + if (!isTileLikeType(insTypes[1])) + return parser.emitError(parser.getCurrentLocation(), + "compare-form tgather tmp must be tile-like"); + hasTmp = true; + } + if (insOps.size() == 3) { + return parser.emitError(parser.getCurrentLocation(), + "compare-form tgather expects at most src, kValue, tmp in ins(...)"); + } + } else { + if (!insOps.empty() && !isTileLikeType(insTypes.front())) { + return parser.emitError(parser.getCurrentLocation(), + "index-form tgather expects tile-like indices; " + "compare-form must use outs(dst, cdst)"); + } + if (!insOps.empty()) { + hasIndices = true; + if (insOps.size() >= 2) { + if (!isTileLikeType(insTypes[1])) + return parser.emitError(parser.getCurrentLocation(), + "index-form tgather tmp must be tile-like"); + hasTmp = true; + } + } + if (insOps.size() == 3) { + return parser.emitError(parser.getCurrentLocation(), + "index-form tgather expects at most src, indices, tmp in ins(...)"); + } + } + + if (parser.resolveOperand(src, srcTy, result.operands) || + parser.resolveOperand(dst, dstTy, result.operands)) + return failure(); + if (hasCdst && parser.resolveOperand(cdst, cdstTy, result.operands)) + return failure(); + if (hasIndices && parser.resolveOperand(insOps[0], insTypes[0], result.operands)) + return failure(); + if (hasTmp && parser.resolveOperand(insOps[hasIndices ? 
1 : 1], insTypes[1], result.operands)) + return failure(); + if (hasKValue && parser.resolveOperand(insOps[0], insTypes[0], result.operands)) + return failure(); + + result.addAttribute("operandSegmentSizes", + parser.getBuilder().getDenseI32ArrayAttr( + {1, 1, hasCdst ? 1 : 0, hasIndices ? 1 : 0, + hasTmp ? 1 : 0, hasKValue ? 1 : 0})); + return success(); +} + +void mlir::pto::TGatherOp::print(OpAsmPrinter &p) { + p << " ins(" << getSrc() << ", "; + if (auto mp = getMaskPatternAttr()) { + p << "{maskPattern = " << mp << "} : " << getSrc().getType(); + } else if (getCdst()) { + p << getKValue(); + if (getTmp()) { + p << ", " << getTmp(); + p << " : " << getSrc().getType() << ", " << getKValue().getType() + << ", " << getTmp().getType(); + } else { + p << " : " << getSrc().getType() << ", " << getKValue().getType(); + } + } else { + p << getIndices(); + if (getTmp()) { + p << ", " << getTmp(); + p << " : " << getSrc().getType() << ", " << getIndices().getType() + << ", " << getTmp().getType(); + } else { + p << " : " << getSrc().getType() << ", " << getIndices().getType(); + } + } + p << ") outs(" << getDst(); + if (getCdst()) + p << ", " << getCdst(); + p << " : " << getDst().getType(); + if (getCdst()) + p << ", " << getCdst().getType(); + p << ")"; + + if (getMaskPatternAttr()) { + p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"maskPattern", "operandSegmentSizes"}); + } else { + p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"operandSegmentSizes"}); + } +} + +ParseResult mlir::pto::TScatterOp::parse(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::UnresolvedOperand src, indexes, dst; + Type srcTy, idxTy, dstTy; + bool hasMask = false; + bool hasIndexes = false; + + if (parser.parseKeyword("ins") || parser.parseLParen() || + parser.parseOperand(src)) + return failure(); + + if (!succeeded(parser.parseOptionalComma())) + return parser.emitError(parser.getCurrentLocation(), + "expected ',' after src operand in 
ins(...)"); + + if (succeeded(parser.parseOptionalLBrace())) { + if (parser.parseKeyword("maskPattern") || parser.parseEqual()) + return failure(); + Attribute rawMaskAttr; + if (parser.parseAttribute(rawMaskAttr) || parser.parseRBrace()) + return failure(); + auto mp = llvm::dyn_cast(rawMaskAttr); + if (!mp) + return parser.emitError(parser.getCurrentLocation(), + "expected #pto.mask_pattern for maskPattern"); + result.addAttribute("maskPattern", mp); + hasMask = true; + if (parser.parseColonType(srcTy) || parser.parseRParen()) + return failure(); + } else { + if (parser.parseOperand(indexes)) + return failure(); + hasIndexes = true; + if (parser.parseColon() || parser.parseType(srcTy) || parser.parseComma() || + parser.parseType(idxTy) || parser.parseRParen()) + return failure(); + } + + if (parser.parseKeyword("outs") || parser.parseLParen() || + parser.parseOperand(dst) || parser.parseColonType(dstTy) || + parser.parseRParen()) + return failure(); + + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + if (result.attributes.get("maskPattern")) + hasMask = true; + + if (hasMask && hasIndexes) + return parser.emitError(parser.getCurrentLocation(), + "mask-pattern tscatter does not take indexes"); + if (!hasMask && !hasIndexes) + return parser.emitError(parser.getCurrentLocation(), + "expected indexes operand or maskPattern for tscatter"); + + if (parser.resolveOperand(src, srcTy, result.operands) || + parser.resolveOperand(dst, dstTy, result.operands) || + (hasIndexes && parser.resolveOperand(indexes, idxTy, result.operands))) + return failure(); + return success(); +} + +void mlir::pto::TScatterOp::print(OpAsmPrinter &p) { + p << " ins(" << getSrc() << ", "; + if (getMaskPatternAttr()) { + p << "{maskPattern = " << getMaskPatternAttr() << "} : " + << getSrc().getType(); + } else { + p << getIndexes() << " : " << getSrc().getType() << ", " + << getIndexes().getType(); + } + p << ") outs(" << getDst() << " : " << getDst().getType() << ")"; 
+ p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"maskPattern"}); +} + +namespace { + +struct CommRecvClause { + OpAsmParser::UnresolvedOperand ping; + std::optional pong; + Type pingTy; + Type pongTy; +}; + +static ParseResult parseCommRecvClause(OpAsmParser &parser, + CommRecvClause &recvClause) { + if (parser.parseKeyword("recv") || parser.parseLParen() || + parser.parseOperand(recvClause.ping)) + return failure(); + if (succeeded(parser.parseOptionalComma())) { + OpAsmParser::UnresolvedOperand pong; + if (parser.parseOperand(pong)) + return failure(); + recvClause.pong = pong; + } + return parser.parseRParen(); +} + +static ParseResult parseCommCollectiveTail( + OpAsmParser &parser, OperationState &result, + ArrayRef fixedOperands, + SmallVectorImpl &fixedTypes, CommRecvClause &recvClause, + SmallVectorImpl &groupOps, + SmallVectorImpl &groupTypes, ArrayRef operandSegmentsPrefix, + ArrayRef requiredAttrs) { + if (parser.parseComma() || parser.parseKeyword("group") || parser.parseLParen()) + return failure(); + + OpAsmParser::UnresolvedOperand group; + if (parser.parseOperand(group)) + return failure(); + groupOps.push_back(group); + while (succeeded(parser.parseOptionalComma())) { + if (parser.parseOperand(group)) + return failure(); + groupOps.push_back(group); + } + + if (parser.parseRParen()) + return failure(); + + if (parser.parseColon()) + return failure(); + + for (size_t i = 0; i < fixedTypes.size(); ++i) { + if (i != 0 && parser.parseComma()) + return failure(); + if (parser.parseType(fixedTypes[i])) + return failure(); + } + if (parser.parseComma() || parser.parseType(recvClause.pingTy)) + return failure(); + if (recvClause.pong) { + if (parser.parseComma() || parser.parseType(recvClause.pongTy)) + return failure(); + } + for (size_t i = 0; i < groupOps.size(); ++i) { + Type groupTy; + if (parser.parseComma() || parser.parseType(groupTy)) + return failure(); + groupTypes.push_back(groupTy); + } + if (parser.parseRParen()) + return 
failure(); + + NamedAttrList attrs; + if (parser.parseOptionalAttrDict(attrs)) + return failure(); + for (StringRef attrName : requiredAttrs) { + if (!attrs.get(attrName)) { + return parser.emitError(parser.getCurrentLocation()) + << "expected '" << attrName << "' attribute"; + } + } + result.addAttributes(attrs); + + for (auto [operand, type] : llvm::zip_equal(fixedOperands, fixedTypes)) { + if (parser.resolveOperand(operand, type, result.operands)) + return failure(); + } + if (parser.resolveOperand(recvClause.ping, recvClause.pingTy, result.operands)) + return failure(); + if (recvClause.pong && + parser.resolveOperand(*recvClause.pong, recvClause.pongTy, result.operands)) + return failure(); + if (parser.resolveOperands(groupOps, groupTypes, parser.getCurrentLocation(), + result.operands)) + return failure(); + + SmallVector segmentSizes(operandSegmentsPrefix.begin(), + operandSegmentsPrefix.end()); + segmentSizes.push_back(static_cast(groupOps.size())); + result.addAttribute("operandSegmentSizes", + parser.getBuilder().getDenseI32ArrayAttr(segmentSizes)); + return success(); +} + +static void printCommRecvClause(OpAsmPrinter &p, Value ping, Value pong) { + p << "recv(" << ping; + if (pong) + p << ", " << pong; + p << ")"; +} + +static void printCommGroupTypes(OpAsmPrinter &p, ValueRange group) { + for (Value groupValue : group) + p << ", " << groupValue.getType(); +} + +static void printCommGroupClause(OpAsmPrinter &p, ValueRange group) { + p << "group("; + p.printOperands(group); + p << ")"; +} + +} // namespace + +ParseResult mlir::pto::TBroadcastOp::parse(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::UnresolvedOperand src; + CommRecvClause recvClause; + SmallVector groupOps; + SmallVector groupTypes; + + if (parser.parseLParen() || parser.parseOperand(src) || parser.parseComma()) + return failure(); + if (failed(parseCommRecvClause(parser, recvClause))) + return failure(); + + SmallVector fixedOperands{src}; + SmallVector fixedTypes(1); + 
if (failed(parseCommCollectiveTail(parser, result, fixedOperands, fixedTypes, + recvClause, groupOps, groupTypes, + {1, 1, recvClause.pong ? 1 : 0}, {"root"}))) + return failure(); + return success(); +} + +void mlir::pto::TBroadcastOp::print(OpAsmPrinter &p) { + p << "(" << getSrc() << ", "; + printCommRecvClause(p, getPing(), getPong()); + p << ", "; + printCommGroupClause(p, getGroup()); + p << " : " << getSrc().getType() << ", " << getPing().getType(); + if (getPong()) + p << ", " << getPong().getType(); + printCommGroupTypes(p, getGroup()); + p << ")"; + p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"operandSegmentSizes"}); +} + +ParseResult mlir::pto::CommTGatherOp::parse(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::UnresolvedOperand dst; + CommRecvClause recvClause; + SmallVector groupOps; + SmallVector groupTypes; + + if (parser.parseLParen() || parser.parseOperand(dst) || parser.parseComma()) + return failure(); + if (failed(parseCommRecvClause(parser, recvClause))) + return failure(); + + SmallVector fixedOperands{dst}; + SmallVector fixedTypes(1); + if (failed(parseCommCollectiveTail( + parser, result, fixedOperands, fixedTypes, recvClause, groupOps, + groupTypes, {1, 1, recvClause.pong ? 
1 : 0}, + {"root"}))) + return failure(); + return success(); +} + +void mlir::pto::CommTGatherOp::print(OpAsmPrinter &p) { + p << "(" << getDst() << ", "; + printCommRecvClause(p, getPing(), getPong()); + p << ", "; + printCommGroupClause(p, getGroup()); + p << " : " << getDst().getType() << ", " << getPing().getType(); + if (getPong()) + p << ", " << getPong().getType(); + printCommGroupTypes(p, getGroup()); + p << ")"; + p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"operandSegmentSizes"}); +} + +ParseResult mlir::pto::CommTScatterOp::parse(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::UnresolvedOperand src; + CommRecvClause recvClause; + SmallVector groupOps; + SmallVector groupTypes; + + if (parser.parseLParen() || parser.parseOperand(src) || parser.parseComma()) + return failure(); + if (failed(parseCommRecvClause(parser, recvClause))) + return failure(); + + SmallVector fixedOperands{src}; + SmallVector fixedTypes(1); + if (failed(parseCommCollectiveTail( + parser, result, fixedOperands, fixedTypes, recvClause, groupOps, + groupTypes, {1, 1, recvClause.pong ? 
1 : 0}, + {"root"}))) + return failure(); + return success(); +} + +void mlir::pto::CommTScatterOp::print(OpAsmPrinter &p) { + p << "(" << getSrc() << ", "; + printCommRecvClause(p, getPing(), getPong()); + p << ", "; + printCommGroupClause(p, getGroup()); + p << " : " << getSrc().getType() << ", " << getPing().getType(); + if (getPong()) + p << ", " << getPong().getType(); + printCommGroupTypes(p, getGroup()); + p << ")"; + p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"operandSegmentSizes"}); +} + +ParseResult mlir::pto::TReduceOp::parse(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::UnresolvedOperand dst, acc; + CommRecvClause recvClause; + SmallVector groupOps; + SmallVector groupTypes; + + if (parser.parseLParen() || parser.parseOperand(dst) || parser.parseComma() || + parser.parseOperand(acc) || parser.parseComma()) + return failure(); + if (failed(parseCommRecvClause(parser, recvClause))) + return failure(); + + SmallVector fixedOperands{dst, acc}; + SmallVector fixedTypes(2); + if (failed(parseCommCollectiveTail( + parser, result, fixedOperands, fixedTypes, recvClause, groupOps, + groupTypes, {1, 1, 1, recvClause.pong ? 
1 : 0}, + {"reduceOp", "root"}))) + return failure(); + return success(); +} + +void mlir::pto::TReduceOp::print(OpAsmPrinter &p) { + p << "(" << getDst() << ", " << getAcc() << ", "; + printCommRecvClause(p, getRecvPing(), getRecvPong()); + p << ", "; + printCommGroupClause(p, getGroup()); + p << " : " << getDst().getType() << ", " << getAcc().getType() << ", " + << getRecvPing().getType(); + if (getRecvPong()) + p << ", " << getRecvPong().getType(); + printCommGroupTypes(p, getGroup()); + p << ")"; + p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"operandSegmentSizes"}); +} + +ParseResult mlir::pto::MakeTensorViewOp::parse(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::UnresolvedOperand ptr; + SmallVector shapeOps; + SmallVector strideOps; + + Type resultTy; + + // %ptr + if (parser.parseOperand(ptr)) + return failure(); + + // , shape = [ ... ] + if (parser.parseComma() || parser.parseKeyword("shape") || parser.parseEqual() || + parser.parseLSquare() || + parser.parseOperandList(shapeOps) || + parser.parseRSquare()) + return failure(); + + // strides = [ ... 
] + if (parser.parseComma() || parser.parseKeyword("strides") || parser.parseEqual() || + parser.parseLSquare() || + parser.parseOperandList(strideOps) || + parser.parseRSquare()) + return failure(); + + // attr-dict + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + // : result-type + if (parser.parseColonType(resultTy)) + return failure(); + result.addTypes(resultTy); + + auto tvTy = llvm::dyn_cast(resultTy); + if (!tvTy) + return parser.emitError(parser.getCurrentLocation(), + "expected result type pto.tensor_view<...>"); + + Type elemTy = tvTy.getElementType(); + + Type ptrTy = mlir::pto::PtrType::get(parser.getContext(), elemTy); + + // resolve %ptr + if (parser.resolveOperand(ptr, ptrTy, result.operands)) + return failure(); + + // resolve shape/strides 为 index + Type indexTy = parser.getBuilder().getIndexType(); + if (parser.resolveOperands(shapeOps, indexTy, result.operands)) + return failure(); + if (parser.resolveOperands(strideOps, indexTy, result.operands)) + return failure(); + + auto segAttr = parser.getBuilder().getDenseI32ArrayAttr( + {1, (int32_t)shapeOps.size(), (int32_t)strideOps.size()}); + result.addAttribute("operandSegmentSizes", segAttr); + + return success(); +} + +void mlir::pto::MakeTensorViewOp::print(OpAsmPrinter &p) { + p << " " << getPtr(); + + p << ", shape = ["; + p.printOperands(getShape()); + p << "]"; + + p << ", strides = ["; + p.printOperands(getStrides()); + p << "]"; + + p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"operandSegmentSizes"}); + + p << " : " << getResult().getType(); +} + +// Layout inference helpers for make_tensor_view +static std::optional getConstIndexValue(Value v) { + if (auto c = v.getDefiningOp()) + return c.value(); + if (auto c = v.getDefiningOp()) { + if (auto ia = dyn_cast(c.getValue())) + return ia.getInt(); + } + return std::nullopt; +} + +static FailureOr +inferPartitionViewResultTypeFromSizes(mlir::pto::TensorViewType sourceType, + ValueRange sizes) { 
+ if (!sourceType) + return failure(); + + if ((int64_t)sizes.size() != sourceType.getRank()) + return failure(); + + SmallVector shape; + shape.reserve(sizes.size()); + for (Value size : sizes) { + auto constSize = getConstIndexValue(size); + if (constSize && *constSize >= 0) + shape.push_back(*constSize); + else + shape.push_back(ShapedType::kDynamic); + } + + return mlir::pto::PartitionTensorViewType::get( + sourceType.getContext(), shape, sourceType.getElementType()); +} + +ParseResult mlir::pto::PartitionViewOp::parse(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::UnresolvedOperand source; + SmallVector offsets; + SmallVector sizes; + Type sourceTy; + Type resultTy; + bool hasExplicitResultTy = false; + + if (parser.parseOperand(source) || parser.parseComma() || + parser.parseKeyword("offsets") || parser.parseEqual() || + parser.parseLSquare() || parser.parseOperandList(offsets) || + parser.parseRSquare() || parser.parseComma() || + parser.parseKeyword("sizes") || parser.parseEqual() || + parser.parseLSquare() || parser.parseOperandList(sizes) || + parser.parseRSquare() || parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(sourceTy)) + return failure(); + + if (succeeded(parser.parseOptionalArrow())) { + if (parser.parseType(resultTy)) + return failure(); + hasExplicitResultTy = true; + } + + if (parser.resolveOperand(source, sourceTy, result.operands)) + return failure(); + + Type indexTy = parser.getBuilder().getIndexType(); + if (parser.resolveOperands(offsets, indexTy, result.operands) || + parser.resolveOperands(sizes, indexTy, result.operands)) + return failure(); + + auto &properties = result.getOrAddProperties(); + llvm::copy(ArrayRef( + {1, static_cast(offsets.size()), + static_cast(sizes.size())}), + properties.operandSegmentSizes.begin()); + + if (hasExplicitResultTy) { + result.addTypes(resultTy); + return success(); + } + + ValueRange allOperands(result.operands); + ValueRange sizeOperands = + 
allOperands.slice(1 + offsets.size(), sizes.size()); + auto inferredResultType = inferPartitionViewResultTypeFromSizes( + dyn_cast(sourceTy), sizeOperands); + if (failed(inferredResultType)) { + return parser.emitError(parser.getCurrentLocation(), + "failed to infer pto.partition_view result type"); + } + + result.addTypes(*inferredResultType); + return success(); +} + +void mlir::pto::PartitionViewOp::print(OpAsmPrinter &printer) { + printer << " " << getSource() << ", offsets = ["; + printer.printOperands(getOffsets()); + printer << "], sizes = ["; + printer.printOperands(getSizes()); + printer << "]"; + printer.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"operandSegmentSizes"}); + printer << " : " << getSource().getType(); + + auto inferredResultType = inferPartitionViewResultTypeFromSizes( + dyn_cast(getSource().getType()), getSizes()); + if (succeeded(inferredResultType) && *inferredResultType == getResult().getType()) + return; + + printer << " -> " << getResult().getType(); +} + +static std::optional getConstantIntegerValueEx( + Value v, bool includeIndexAndIntOpsInConstFold) { + if (includeIndexAndIntOpsInConstFold) { + if (auto c = v.getDefiningOp()) + return c.value(); + if (auto c = v.getDefiningOp()) + return c.value(); + } + if (auto c = v.getDefiningOp()) { + if (auto ia = dyn_cast(c.getValue())) + return ia.getInt(); + } + return std::nullopt; +} + +static LogicalResult verifyNonNegativeIndexRowCol( + Operation &op, Value indexRow, Value indexCol, + bool includeIndexAndIntOpsInConstFold) { + if (!indexRow.getType().isIndex() || !indexCol.getType().isIndex()) + return op.emitOpError("expects indexRow and indexCol to be index type"); + auto row = + getConstantIntegerValueEx(indexRow, includeIndexAndIntOpsInConstFold); + auto col = + getConstantIntegerValueEx(indexCol, includeIndexAndIntOpsInConstFold); + if (row && *row < 0) + return op.emitOpError("expects indexRow to be non-negative"); + if (col && *col < 0) + return 
op.emitOpError("expects indexCol to be non-negative"); + return success(); +} + +static LogicalResult verifyExtractStaticBoundsCommon( + Operation &op, Value indexRow, Value indexCol, Type srcTy, Type dstTy, + bool includeIndexAndIntOpsInConstFold) { + auto row = + getConstantIntegerValueEx(indexRow, includeIndexAndIntOpsInConstFold); + auto col = + getConstantIntegerValueEx(indexCol, includeIndexAndIntOpsInConstFold); + auto srcShape = getShapeVec(srcTy); + auto dstShape = getShapeVec(dstTy); + if (srcShape.size() != 2 || dstShape.size() != 2) + return op.emitOpError("expects src and dst to be rank-2 tile_buf"); + if (row && srcShape[0] != ShapedType::kDynamic && + dstShape[0] != ShapedType::kDynamic && + *row + dstShape[0] > srcShape[0]) + return op.emitOpError("expects indexRow + dst.rows <= src.rows"); + if (col && srcShape[1] != ShapedType::kDynamic && + dstShape[1] != ShapedType::kDynamic && + *col + dstShape[1] > srcShape[1]) + return op.emitOpError("expects indexCol + dst.cols <= src.cols"); + return success(); +} + +static LogicalResult verifyInsertStaticBoundsCommon( + Operation &op, Value indexRow, Value indexCol, Type srcTy, Type dstTy, + bool includeIndexAndIntOpsInConstFold) { + auto row = + getConstantIntegerValueEx(indexRow, includeIndexAndIntOpsInConstFold); + auto col = + getConstantIntegerValueEx(indexCol, includeIndexAndIntOpsInConstFold); + auto srcShape = getValidShapeVec(srcTy); + auto dstShape = getShapeVec(dstTy); + if (srcShape.size() != 2 || dstShape.size() != 2) + return op.emitOpError("expects src and dst to be rank-2 tile_buf"); + if (row && srcShape[0] != ShapedType::kDynamic && + dstShape[0] != ShapedType::kDynamic && + *row + srcShape[0] > dstShape[0]) + return op.emitOpError("expects indexRow + src.rows <= dst.rows"); + if (col && srcShape[1] != ShapedType::kDynamic && + dstShape[1] != ShapedType::kDynamic && + *col + srcShape[1] > dstShape[1]) + return op.emitOpError("expects indexCol + src.cols <= dst.cols"); + return success(); 
+} + +static unsigned getElemByteSize(Type ty) { + return getPTOStorageElemByteSize(ty); +} + +static LogicalResult verifyTileBufLayoutConstraints(Operation *op, + pto::TileBufType tb, + StringRef name) { + auto shape = tb.getShape(); + if (shape.size() != 2) + return op->emitOpError() << "expects " << name << " to be rank-2"; + + int64_t rows = shape[0]; + int64_t cols = shape[1]; + if (rows != ShapedType::kDynamic && rows <= 0) + return op->emitOpError() << "expects " << name << " rows to be positive"; + if (cols != ShapedType::kDynamic && cols <= 0) + return op->emitOpError() << "expects " << name << " cols to be positive"; + + unsigned elemBytes = getElemByteSize(tb.getElementType()); + if (elemBytes == 0) + return op->emitOpError() << "expects " << name + << " element type to have a byte size"; + + auto cfg = tb.getConfigAttr(); + if (!cfg) + cfg = TileBufConfigAttr::getDefault(tb.getContext()); + auto readBLayout = [](Attribute attr, int32_t &out) -> bool { + if (auto layout = dyn_cast_or_null(attr)) { + out = static_cast(layout.getValue()); + return true; + } + if (auto value = dyn_cast_or_null(attr)) { + out = static_cast(value.getInt()); + return true; + } + return false; + }; + auto readSLayout = [](Attribute attr, int32_t &out) -> bool { + if (auto layout = dyn_cast_or_null(attr)) { + out = static_cast(layout.getValue()); + return true; + } + if (auto value = dyn_cast_or_null(attr)) { + out = static_cast(value.getInt()); + return true; + } + return false; + }; + int32_t blayout = 0; + int32_t slayout = 0; + if (!readBLayout(cfg.getBLayout(), blayout) || + !readSLayout(cfg.getSLayout(), slayout)) + return op->emitOpError() << "expects " << name + << " to have concrete tile layout attributes"; + constexpr int64_t kAlignedBytes = 32; + + auto checkByteAlignment = [&](int64_t dim, StringRef layoutName, + StringRef byteExpr) -> LogicalResult { + if (dim == ShapedType::kDynamic) + return success(); + int64_t bytes = dim * static_cast(elemBytes); + if (bytes % 
kAlignedBytes == 0) + return success(); + return op->emitOpError() + << "expects " << name << " " << layoutName + << " none_box tile " << byteExpr + << " to be 32-byte aligned, but got " << bytes << " bytes"; + }; + + if (slayout == static_cast(SLayout::NoneBox)) { + if (blayout == static_cast(BLayout::RowMajor)) + return checkByteAlignment(cols, "row-major", + "row byte size (cols * sizeof(dtype))"); + return checkByteAlignment(rows, "col-major", + "column byte size (rows * sizeof(dtype))"); + } + + int64_t innerRows = 0; + int64_t innerCols = 0; + int32_t fractal = static_cast(cfg.getSFractalSize().getInt()); + switch (fractal) { + case 1024: + innerRows = 16; + innerCols = 16; + break; + case 32: + innerRows = 16; + innerCols = 2; + break; + case 512: + if (kAlignedBytes % elemBytes != 0) + return op->emitOpError() << "expects " << name + << " element byte size to divide 32 for boxed " + "fractal-512 tile layout"; + if (slayout == static_cast(SLayout::RowMajor)) { + innerRows = 16; + innerCols = kAlignedBytes / static_cast(elemBytes); + } else if (slayout == static_cast(SLayout::ColMajor)) { + innerRows = kAlignedBytes / static_cast(elemBytes); + innerCols = 16; + } + break; + default: + break; + } + if (innerRows <= 0 || innerCols <= 0) + return op->emitOpError() << "expects " << name + << " to use a supported boxed tile layout"; + + auto loc = getPTOMemorySpaceEnum(tb); + bool allowUnalignedRows = + (loc && *loc == pto::AddressSpace::VEC) || fractal == 32 || rows == 1; + if (!allowUnalignedRows && rows != ShapedType::kDynamic && + rows % innerRows != 0) + return op->emitOpError() + << "expects " << name + << " boxed tile rows to be a multiple of innerRows (" << innerRows + << "), but got " << rows; + if (cols != ShapedType::kDynamic && cols % innerCols != 0) + return op->emitOpError() + << "expects " << name + << " boxed tile cols to be a multiple of innerCols (" << innerCols + << "), but got " << cols; + + return success(); +} + +[[maybe_unused]] static bool 
isSupportedLoadStoreElemTypeA2A3(Type ty) { + if (ty.isF16() || ty.isBF16() || ty.isF32()) + return true; + if (auto it = dyn_cast(ty)) { + unsigned width = it.getWidth(); + return width == 8 || width == 16 || width == 32 || width == 64; + } + return false; +} + +static bool isSupportedGatherElemTypeA2A3(Type ty) { + if (ty.isF16() || ty.isF32()) + return true; + if (auto it = dyn_cast(ty)) { + unsigned width = it.getWidth(); + return width == 16 || width == 32; + } + return false; +} + +static bool isSupportedGatherElemTypeA5(Type ty) { + if (isSupportedGatherElemTypeA2A3(ty) || ty.isBF16()) + return true; + if (auto ft = dyn_cast(ty)) { + unsigned width = ft.getWidth(); + return width == 8; + } + if (auto it = dyn_cast(ty)) + return it.getWidth() == 8 || it.getWidth() == 16 || it.getWidth() == 32; + return false; +} + +static std::optional +inferLayout(ArrayRef shape, ArrayRef strides, + unsigned elemBytes) { + if (shape.size() != strides.size() || elemBytes == 0) + return std::nullopt; + + // NZ / fractal: rank>=5, check middle dims (sh3/sh4/sh5 per spec) + if (shape.size() >= 5) { + int64_t sh3 = shape[2], sh4 = shape[3], sh5 = shape[4]; + int64_t st4 = strides[3], st5 = strides[4]; + bool alignMatch = (sh3 == 16) && (sh3 * sh4 * elemBytes == 512); + bool strideMatch = (st5 == 1) && (st4 == sh5); + if (alignMatch && strideMatch) + return mlir::pto::Layout::NZ; + } + + // ND: row-major contiguous + bool isRowMajor = true; + for (int i = 0, e = (int)shape.size() - 1; i < e; ++i) { + if (strides[i] != strides[i + 1] * shape[i + 1]) { + isRowMajor = false; + break; + } + } + if (isRowMajor && strides.back() == 1) + return mlir::pto::Layout::ND; + + // DN: col-major + bool isColMajor = true; + for (int i = 0, e = (int)shape.size() - 1; i < e; ++i) { + if (strides[i + 1] != strides[i] * shape[i]) { + isColMajor = false; + break; + } + } + if (isColMajor && strides.front() == 1) + return mlir::pto::Layout::DN; + + return mlir::pto::Layout::ND; // fallback +} + +static 
std::optional getLogicalViewLayout(Value value) { + if (!value) + return std::nullopt; + if (auto part = value.getDefiningOp()) + return getLogicalViewLayout(part.getSource()); + if (auto make = value.getDefiningOp()) { + auto tvTy = dyn_cast(make.getResult().getType()); + if (!tvTy) + return std::nullopt; + SmallVector shape(tvTy.getShape().begin(), tvTy.getShape().end()); + SmallVector strides; + strides.reserve(make.getStrides().size()); + for (Value stride : make.getStrides()) { + auto cst = getConstIndexValue(stride); + if (!cst) + return std::nullopt; + strides.push_back(*cst); + } + return inferLayout(shape, strides, getElemByteSize(tvTy.getElementType())); + } + return std::nullopt; +} + +static std::optional getTileBufLogicalLayout(pto::TileBufType type) { + if (!type) + return std::nullopt; + int32_t sl = type.getSLayoutValueI32(); + int32_t bl = type.getBLayoutValueI32(); + if (sl != static_cast(pto::SLayout::NoneBox)) + return pto::Layout::NZ; + if (bl == static_cast(pto::BLayout::RowMajor)) + return pto::Layout::ND; + if (bl == static_cast(pto::BLayout::ColMajor)) + return pto::Layout::DN; + return std::nullopt; +} + +static bool isRowMajorTileBuf(Type ty) { + auto tb = mlir::dyn_cast(ty); + return tb && tb.getBLayoutValueI32() == static_cast(pto::BLayout::RowMajor); +} + +static LogicalResult verifyRowReductionSrcLayout(Operation *op, Type ty, + StringRef name) { + if (failed(verifyTileBufCommon(op, ty, name))) + return failure(); + auto as = getPTOMemorySpaceEnum(ty); + if (!as || *as != pto::AddressSpace::VEC) + return op->emitOpError() << "expects " << name << " to be in the vec address space"; + if (auto tb = dyn_cast(ty)) { + if (tb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor)) + return op->emitOpError() << "expects " << name << " to use the row_major blayout"; + } + if (auto mr = dyn_cast(ty)) + (void)mr; + if (auto tb = dyn_cast(ty)) { + if (tb.getSLayoutValueI32() != static_cast(pto::SLayout::NoneBox)) + return op->emitOpError() 
<< "expects " << name + << " to use the none_box slayout"; + } + if (auto tb = dyn_cast(ty)) { + auto layout = getTileBufLogicalLayout(tb); + if (layout && *layout != pto::Layout::ND) + return op->emitOpError() << "expects " << name + << " to use an ND-style tile layout"; + } + return success(); +} + +static LogicalResult verifyRowReductionDstLayout(Operation *op, Type ty, + StringRef name) { + if (failed(verifyTileBufCommon(op, ty, name))) + return failure(); + auto as = getPTOMemorySpaceEnum(ty); + if (!as || *as != pto::AddressSpace::VEC) + return op->emitOpError() << "expects " << name << " to be in the vec address space"; + if (auto tb = dyn_cast(ty)) { + if (tb.getSLayoutValueI32() != static_cast(pto::SLayout::NoneBox)) + return op->emitOpError() << "expects " << name + << " to use the none_box slayout"; + } + if (auto tb = dyn_cast(ty)) { + if (tb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor) && + tb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor)) + return op->emitOpError() << "expects " << name + << " to use the row_major or col_major blayout"; + } + if (auto mr = dyn_cast(ty)) + (void)mr; + if (auto tb = dyn_cast(ty)) { + auto layout = getTileBufLogicalLayout(tb); + if (layout && *layout == pto::Layout::DN) { + auto shape = getShapeVec(ty); + if (shape.size() == 2 && shape[1] != ShapedType::kDynamic && shape[1] != 1) + return op->emitOpError() << "expects DN-style " << name + << " to have shape[1] == 1"; + return success(); + } + if (layout && *layout == pto::Layout::ND) + return success(); + if (layout) + return op->emitOpError() << "expects " << name + << " to use a DN-style column vector tile or legacy ND-style tile"; + } + return success(); + auto valid = getValidShapeVec(ty); + if (valid.size() != 2) + return op->emitOpError() << "expects " << name << " to have rank-2 valid_shape"; + if (valid[1] != ShapedType::kDynamic && valid[1] != 1) + return op->emitOpError() << "expects " << name << " valid_shape[1] to be 1"; + return 
success(); +} + +static LogicalResult verifyRowReductionValidRegion(Operation *op, Type srcTy, + Type dstTy) { + auto srcValid = getValidShapeVec(srcTy); + auto dstValid = getValidShapeVec(dstTy); + if (srcValid.size() != 2 || dstValid.size() != 2) + return op->emitOpError("expects src and dst to have rank-2 valid_shape"); + if (srcValid[0] != ShapedType::kDynamic && srcValid[0] == 0) + return op->emitOpError("expects src valid_shape[0] to be non-zero"); + if (srcValid[1] != ShapedType::kDynamic && srcValid[1] == 0) + return op->emitOpError("expects src valid_shape[1] to be non-zero"); + if (srcValid[0] != ShapedType::kDynamic && dstValid[0] != ShapedType::kDynamic && + srcValid[0] != dstValid[0]) + return op->emitOpError("expects src and dst to have the same valid_shape[0]"); + if (dstValid[1] != ShapedType::kDynamic && dstValid[1] != 1) + return op->emitOpError("expects dst valid_shape[1] to be 1"); + return success(); +} + +static bool isSupportedRowReductionElemType(Type elem) { + return elem.isInteger(16) || elem.isInteger(32) || elem.isF16() || + elem.isF32(); +} + +static LogicalResult verifyTRowReductionNoTmpCommon(Operation *op, Type srcTy, + Type dstTy, + StringRef elemTypeError) { + if (failed(verifyRowReductionSrcLayout(op, srcTy, "src")) || + failed(verifyRowReductionDstLayout(op, dstTy, "dst"))) + return failure(); + if (getElemTy(srcTy) != getElemTy(dstTy)) + return op->emitOpError("expects src and dst to have the same element type"); + if (failed(verifyRowReductionValidRegion(op, srcTy, dstTy))) + return failure(); + if (!isSupportedRowReductionElemType(getElemTy(srcTy))) + return op->emitOpError(elemTypeError); + return success(); +} + +static LogicalResult verifyTRowReductionWithTmpCommon(Operation *op, Type srcTy, + Type tmpTy, Type dstTy, + StringRef elemTypeError) { + if (failed(verifyRowReductionSrcLayout(op, srcTy, "src")) || + failed(verifyVecTileCommon(op, tmpTy, "tmp")) || + failed(verifyRowReductionDstLayout(op, dstTy, "dst"))) + return 
failure(); + if (failed(verifyTileBufSameElemType(op, srcTy, tmpTy, "src", "tmp")) || + failed(verifyTileBufSameValidShape(op, srcTy, tmpTy, "src", "tmp"))) + return failure(); + if (getElemTy(srcTy) != getElemTy(dstTy)) + return op->emitOpError("expects src and dst to have the same element type"); + if (failed(verifyRowReductionValidRegion(op, srcTy, dstTy))) + return failure(); + if (!isSupportedRowReductionElemType(getElemTy(srcTy))) + return op->emitOpError(elemTypeError); + return success(); +} + +static LogicalResult verifyTRowArgReductionCommon(Operation *op, Type srcTy, + Type tmpTy, Type dstTy) { + if (failed(verifyRowReductionSrcLayout(op, srcTy, "src")) || + failed(verifyVecTileCommon(op, tmpTy, "tmp")) || + failed(verifyRowReductionDstLayout(op, dstTy, "dst"))) + return failure(); + if (failed(verifyTileBufSameElemType(op, srcTy, tmpTy, "src", "tmp")) || + failed(verifyTileBufSameValidShape(op, srcTy, tmpTy, "src", "tmp"))) + return failure(); + if (failed(verifyRowReductionValidRegion(op, srcTy, dstTy))) + return failure(); + Type srcElem = getElemTy(srcTy); + if (!isSupportedRowReductionElemType(srcElem)) + return op->emitOpError("expects src element type to be i16/i32/f16/f32"); + auto dstInt = dyn_cast(getElemTy(dstTy)); + if (!dstInt || dstInt.getWidth() != 32) + return op->emitOpError("expects dst element type to be i32 or ui32"); + return success(); +} + +static LogicalResult verifyNDStyleVecTile(Operation *op, Type ty, StringRef name) { + if (failed(verifyTileBufCommon(op, ty, name))) + return failure(); + auto as = getPTOMemorySpaceEnum(ty); + if (!as || *as != pto::AddressSpace::VEC) + return op->emitOpError() << "expects " << name << " to be in the vec address space"; + if (auto tb = dyn_cast(ty)) { + if (tb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor)) + return op->emitOpError() << "expects " << name << " to use the row_major blayout"; + if (tb.getSLayoutValueI32() != static_cast(pto::SLayout::NoneBox)) + return 
op->emitOpError() << "expects " << name << " to use the none_box slayout"; + } + return success(); +} + +static LogicalResult verifyColReductionValidRegion(Operation *op, Type srcTy, + Type dstTy, + bool requireNonZeroSrc) { + auto srcValid = getValidShapeVec(srcTy); + auto dstValid = getValidShapeVec(dstTy); + if (srcValid.size() != 2 || dstValid.size() != 2) + return op->emitOpError("expects src and dst to have rank-2 valid_shape"); + if (requireNonZeroSrc) { + if (srcValid[0] != ShapedType::kDynamic && srcValid[0] == 0) + return op->emitOpError("expects src valid_shape[0] to be non-zero"); + if (srcValid[1] != ShapedType::kDynamic && srcValid[1] == 0) + return op->emitOpError("expects src valid_shape[1] to be non-zero"); + } + if (srcValid[1] != ShapedType::kDynamic && dstValid[1] != ShapedType::kDynamic && + srcValid[1] != dstValid[1]) + return op->emitOpError("expects src and dst to have the same valid_shape[1]"); + return success(); +} + +static LogicalResult verifyColArgReductionDstLayout(Operation *op, Type ty, + StringRef name) { + if (failed(verifyNDStyleVecTile(op, ty, name))) + return failure(); + auto valid = getValidShapeVec(ty); + if (valid.size() != 2) + return op->emitOpError() << "expects " << name + << " to have rank-2 valid_shape"; + if (valid[0] != ShapedType::kDynamic && valid[0] != 1) + return op->emitOpError() << "expects " << name + << " valid_shape[0] to be 1"; + return success(); +} + +static std::optional getConstantIntegerValue(Value value) { + if (!value) + return std::nullopt; + if (auto arithCst = value.getDefiningOp()) { + if (auto intAttr = dyn_cast(arithCst.getValue())) + return intAttr.getInt(); + } + return std::nullopt; +} + +LogicalResult mlir::pto::MakeTensorViewOp::verify() { + auto tvTy = dyn_cast(getResult().getType()); + if (!tvTy) + return emitOpError("result must be pto.tensor_view<...>"); + + auto pty = dyn_cast(getPtr().getType()); + if (!pty) + return emitOpError("ptr operand must be !pto.ptr<...>"); + + if 
(pty.getElementType() != tvTy.getElementType()) + return emitOpError() << "ptr element type must match tensor_view element type, but got ptr=" + << pty.getElementType() << " view=" << tvTy.getElementType(); + + int64_t rank = tvTy.getRank(); + + if ((int64_t)getShape().size() != rank || (int64_t)getStrides().size() != rank) + return emitOpError() << "shape/strides operand counts must match tensor_view rank=" + << rank; + + // Detect dynamic shape/stride. + bool hasDynamicShape = llvm::any_of(tvTy.getShape(), [](int64_t v) { + return v == ShapedType::kDynamic; + }); + bool hasDynamicStride = llvm::any_of(getStrides(), [](Value s) { + return !getConstIndexValue(s).has_value(); + }); + + auto layoutAttr = getLayoutAttr(); + + // 1) Dynamic shape/stride without explicit layout: warn and keep going. + if ((hasDynamicShape || hasDynamicStride) && !layoutAttr) { + return success(); + } + + // 2) Static shape/stride with explicit layout: verify correctness. + bool allStaticStride = true; + SmallVector strideInts; + strideInts.reserve(getStrides().size()); + for (Value s : getStrides()) { + auto val = getConstIndexValue(s); + if (!val) { + allStaticStride = false; + break; + } + strideInts.push_back(*val); + } + + bool allStaticShape = + llvm::none_of(tvTy.getShape(), [](int64_t v) { return v == ShapedType::kDynamic; }); + + if (layoutAttr && allStaticShape && allStaticStride) { + SmallVector shapeInts(tvTy.getShape().begin(), tvTy.getShape().end()); + if (auto inferred = inferLayout(shapeInts, strideInts, + getElemByteSize(tvTy.getElementType()))) { + (void)inferred; + } + } + + return success(); +} + +LogicalResult mlir::pto::PartitionViewOp::verify() { + auto srcTy = dyn_cast(getSource().getType()); + auto resTy = dyn_cast(getResult().getType()); + if (!srcTy || !resTy) + return emitOpError("expects tensor_view source and partition_tensor_view result"); + + if (srcTy.getElementType() != resTy.getElementType()) + return emitOpError() << "element type mismatch between 
source and result: src=" + << srcTy.getElementType() << " result=" + << resTy.getElementType(); + + int64_t srcRank = srcTy.getRank(); + if ((int64_t)getOffsets().size() != srcRank) + return emitOpError() << "offset count (" << getOffsets().size() + << ") must match source rank (" << srcRank << ")"; + + if ((int64_t)getSizes().size() != srcRank) + return emitOpError() << "size count (" << getSizes().size() + << ") must match source rank (" << srcRank << ")"; + + ArrayRef srcShape = srcTy.getShape(); + ArrayRef resShape = resTy.getShape(); + bool sameRank = resTy.getRank() == srcRank; + + for (int64_t i = 0; i < srcRank; ++i) { + auto offVal = getConstIndexValue(getOffsets()[i]); + auto sizeVal = getConstIndexValue(getSizes()[i]); + + if (offVal && *offVal < 0) + return emitOpError() << "offset at dim " << i + << " must be non-negative, got " << *offVal; + + if (sizeVal && *sizeVal <= 0) + return emitOpError() << "size at dim " << i + << " must be positive, got " << *sizeVal; + + if (sameRank && sizeVal) { + int64_t resDim = resShape[i]; + if (resDim != ShapedType::kDynamic && *sizeVal != resDim) + return emitOpError() << "size/result mismatch at dim " << i + << ": size operand=" << *sizeVal + << " result type dim=" << resDim; + } + + int64_t srcDim = srcShape[i]; + if (srcDim == ShapedType::kDynamic) + continue; + + if (sizeVal && *sizeVal > srcDim) + return emitOpError() << "size at dim " << i << " (" << *sizeVal + << ") exceeds static source dim (" << srcDim << ")"; + + if (offVal && sizeVal && (*offVal + *sizeVal > srcDim)) + return emitOpError() << "offset+size at dim " << i << " (" + << (*offVal + *sizeVal) + << ") exceeds static source dim (" << srcDim << ")"; + } + + return success(); +} + +LogicalResult mlir::pto::AddPtrOp::verify() { + Value ptr = getOperation()->getOperand(0); + Value result = getOperation()->getResult(0); + + auto ptrTy = dyn_cast(ptr.getType()); + if (!ptrTy) + return emitOpError("ptr operand must be !pto.ptr<...>"); + + auto resTy = 
dyn_cast(result.getType()); + if (!resTy) + return emitOpError("result must be !pto.ptr<...>"); + + if (ptrTy != resTy) + return emitOpError("result type must match ptr operand type"); + + return success(); +} + +static LogicalResult verifyPtrLikeForAddressCast(Operation *op, Type type, + StringRef name) { + if (isa(type)) + return success(); + + auto memTy = dyn_cast(type); + if (!memTy) + return op->emitOpError() + << "expects " << name << " to be !pto.ptr<...> or a GM memref"; + + if (memTy.getRank() != 1) + return op->emitOpError() + << "expects lowered memref " << name << " to be rank-1"; + + if (!isGmAddressSpaceAttr(memTy.getMemorySpace())) + return op->emitOpError() + << "expects lowered memref " << name << " to use GM address space"; + + return success(); +} + +static Type getPointerLikeElementType(Type type) { + if (auto ptrTy = dyn_cast(type)) + return ptrTy.getElementType(); + if (auto memTy = dyn_cast(type)) + return memTy.getElementType(); + return Type(); +} + +static bool isEmitCSupportedScalarType(Type type) { + if (!type) + return false; + if (type.isF16() || type.isBF16() || type.isF32() || type.isF64()) + return true; + if (auto intTy = dyn_cast(type)) + return intTy.getWidth() == 8 || intTy.getWidth() == 16 || + intTy.getWidth() == 32 || intTy.getWidth() == 64; + if (mlir::pto::isPTOFloat8Type(type)) + return true; + if (isa(type)) + return true; + return false; +} + +LogicalResult mlir::pto::PtrToIntOp::verify() { + Type resultTy = getResult().getType(); + auto intTy = dyn_cast(resultTy); + if (!intTy || intTy.getWidth() != 64) + return emitOpError("result must be i64"); + + return verifyPtrLikeForAddressCast(getOperation(), getPtr().getType(), + "ptr operand"); +} + +LogicalResult mlir::pto::IntToPtrOp::verify() { + auto addrTy = dyn_cast(getAddr().getType()); + if (!addrTy || addrTy.getWidth() != 64) + return emitOpError("address operand must be i64"); + + if (failed(verifyPtrLikeForAddressCast(getOperation(), getResult().getType(), + 
"result"))) + return failure(); + + Type dstElem = getPointerLikeElementType(getResult().getType()); + if (!isEmitCSupportedScalarType(dstElem)) + return emitOpError("result element type is not supported by EmitC: ") + << dstElem; + + return success(); +} + +LogicalResult mlir::pto::LocalArrayGetOp::verify() { + auto arrayTy = getArray().getType(); + int64_t rank = arrayTy.getRank(); + int64_t numIdx = static_cast(getIndices().size()); + if (numIdx != rank) + return emitOpError() << "expects " << rank + << " indices for !pto.local_array of rank " << rank + << ", got " << numIdx; + if (getResult().getType() != arrayTy.getElementType()) + return emitOpError() + << "result type " << getResult().getType() + << " does not match array element type " + << arrayTy.getElementType(); + return success(); +} + +LogicalResult mlir::pto::LocalArraySetOp::verify() { + auto arrayTy = getArray().getType(); + int64_t rank = arrayTy.getRank(); + int64_t numIdx = static_cast(getIndices().size()); + if (numIdx != rank) + return emitOpError() << "expects " << rank + << " indices for !pto.local_array of rank " << rank + << ", got " << numIdx; + if (getValue().getType() != arrayTy.getElementType()) + return emitOpError() << "value type " << getValue().getType() + << " does not match array element type " + << arrayTy.getElementType(); + return success(); +} + + + + +void PTODialect::initialize() { + addTypes< +#define GET_TYPEDEF_LIST +#include "PTO/IR/PTOTypeDefs.cpp.inc" + >(); + + addOperations< +#define GET_OP_LIST +#include "PTO/IR/PTOOps.cpp.inc" + >(); + + addAttributes< +#define GET_ATTRDEF_LIST +#include "PTO/IR/PTOAttrs.cpp.inc" + >(); +} + + +AddressSpaceAttr mlir::pto::getPTOAddressSpaceAttr(Type type) { + auto memRefType = dyn_cast(type); + if (!memRefType) + return {}; + auto scopeAttr = dyn_cast(memRefType.getMemorySpace()); + if (!scopeAttr) + return {}; + return scopeAttr; +} + +bool mlir::pto::isScalarPtrOrMemRef(Type type) { + if (auto pty = dyn_cast(type)) + return 
true; + if (auto memTy = dyn_cast(type)) + return isGmAddressSpaceAttr(memTy.getMemorySpace()); + return false; +} + +bool mlir::pto::hasExplicitPTOEntryAttr(func::FuncOp func) { + return func && (func->hasAttrOfType(kPTOEntryAttrName) || + func->hasAttrOfType(kLegacyHACCEntryAttrName)); +} + +static constexpr StringLiteral kEffectivePTOEntryAttrName = + "pto.internal.entry"; + +static SmallVector getPTOFunctionDefinitions(ModuleOp module) { + SmallVector defs; + if (!module) + return defs; + for (auto func : module.getOps()) { + if (!func.isDeclaration()) + defs.push_back(func); + } + return defs; +} + +bool mlir::pto::isPTOEntryFunction(func::FuncOp func) { + if (!func || func.isDeclaration()) + return false; + if (auto attr = func->getAttrOfType(kEffectivePTOEntryAttrName)) + return attr.getValue(); + if (hasExplicitPTOEntryAttr(func)) + return true; + + ModuleOp module = func->getParentOfType(); + if (!module) + return false; + SmallVector defs = getPTOFunctionDefinitions(module); + return defs.size() == 1 && defs.front() == func; +} + +LogicalResult mlir::pto::validatePTOEntryFunctions(ModuleOp module) { + if (!module) + return success(); + + for (auto func : module.getOps()) { + if (!hasExplicitPTOEntryAttr(func)) + continue; + if (func.isDeclaration()) { + return func.emitOpError() + << "`" << kPTOEntryAttrName + << "` is only valid on function definitions"; + } + } + + for (auto func : module.getOps()) { + if (!isPTOEntryFunction(func)) + continue; + if (func.getFunctionType().getNumResults() != 0) { + return func.emitOpError() + << "PTO entry functions must return void"; + } + } + return success(); +} + +void mlir::pto::annotatePTOEntryFunctions(ModuleOp module) { + if (!module) + return; + + SmallVector defs = getPTOFunctionDefinitions(module); + for (auto func : module.getOps()) + func->removeAttr(kEffectivePTOEntryAttrName); + + if (defs.empty()) + return; + if (defs.size() == 1) { + defs.front()->setAttr(kEffectivePTOEntryAttrName, + 
BoolAttr::get(module.getContext(), true)); + return; + } + + for (auto func : defs) { + func->setAttr(kEffectivePTOEntryAttrName, + BoolAttr::get(module.getContext(), + hasExplicitPTOEntryAttr(func))); + } +} + +//===----------------------------------------------------------------------===// +// PTO Load/Store/Addf (non-DPS polymorphic) verification + inference. +// - If operands are memref/tensor: verify strictly. +// - Otherwise (tile_view/tile etc): accept (so old IR can still parse). +//===----------------------------------------------------------------------===// + +[[maybe_unused]] static LogicalResult verifyMemrefToTensorLoad(Operation *op, Value src, Value res) { + auto mr = dyn_cast(src.getType()); + auto rt = dyn_cast(res.getType()); + if (!mr) + return success(); // non-memref case: don't block old IR + if (!rt) + return op->emitOpError("when src is memref, result must be ranked tensor"); + + if (mr.getElementType() != rt.getElementType()) + return op->emitOpError() << "memref/tensor element type mismatch: memref=" + << mr.getElementType() << " tensor=" << rt.getElementType(); + + if (mr.getRank() != rt.getRank()) + return op->emitOpError() << "rank mismatch: memref rank=" << mr.getRank() + << " tensor rank=" << rt.getRank(); + + if (mr.hasStaticShape()) { + if (!rt.hasStaticShape()) + return op->emitOpError("memref has static shape but result tensor is not static"); + if (mr.getShape() != rt.getShape()) + return op->emitOpError() << "shape mismatch: memref=" << mr << " tensor=" << rt; + } else { + // For dynamic memref dims: if tensor dim is static, allow it; if it's dynamic too, also fine. + // We only reject when a memref static dim conflicts with tensor static dim. 
+ for (int64_t i = 0; i < mr.getRank(); ++i) { + int64_t md = mr.getDimSize(i); + int64_t td = rt.getDimSize(i); + if (md != ShapedType::kDynamic && td != ShapedType::kDynamic && md != td) + return op->emitOpError() << "dim mismatch at " << i << ": memref=" << md << " tensor=" << td; + } + } + return success(); +} + +[[maybe_unused]] static LogicalResult verifyMemrefTensorStore(Operation *op, Value dst, Value src) { + auto mr = dyn_cast(dst.getType()); + if (!mr) + return success(); // non-memref case: old tile IR allowed + auto rt = dyn_cast(src.getType()); + if (!rt) + return op->emitOpError("when dst is memref, src must be ranked tensor"); + + if (mr.getElementType() != rt.getElementType()) + return op->emitOpError() << "memref/tensor element type mismatch: memref=" + << mr.getElementType() << " tensor=" << rt.getElementType(); + + if (mr.getRank() != rt.getRank()) + return op->emitOpError() << "rank mismatch: memref rank=" << mr.getRank() + << " tensor rank=" << rt.getRank(); + + for (int64_t i = 0; i < mr.getRank(); ++i) { + int64_t md = mr.getDimSize(i); + int64_t td = rt.getDimSize(i); + if (md != ShapedType::kDynamic && td != ShapedType::kDynamic && md != td) + return op->emitOpError() << "dim mismatch at " << i << ": memref=" << md << " tensor=" << td; + } + return success(); +} + +LogicalResult AllocTileOp::verify() { + auto ty = getResult().getType(); // TileBufType + + if (failed(verifyTileBufLayoutConstraints(*this, ty, "result"))) + return failure(); + + // op 上有没有传 operands + bool hasVR = getValidRow() != nullptr; + bool hasVC = getValidCol() != nullptr; + + // type 上的 validShape + auto vs = ty.getValidShape(); + if (vs.size() != 2) + return emitOpError("result tile_buf must have rank-2 validShape"); + + // TileBuf valid dims use a negative sentinel (e.g. '?' / -1). Be robust to + // any negative value (some code may materialize MLIR dynamic sentinels). + bool needVR = (vs[0] < 0); + bool needVC = (vs[1] < 0); + + // 你要求的:v_row=?, v_col=? 
时必须同时给两个 + // (这条规则由下面两句自然实现) + if (hasVR != needVR) + return emitOpError() << "valid_row operand " + << (needVR ? "is required" : "must be absent") + << " because result type v_row is " + << (needVR ? "?" : std::to_string(vs[0])); + + if (hasVC != needVC) + return emitOpError() << "valid_col operand " + << (needVC ? "is required" : "must be absent") + << " because result type v_col is " + << (needVC ? "?" : std::to_string(vs[1])); + + return success(); +} + +LogicalResult MaterializeTileOp::verify() { + auto sourceTy = cast(getSource().getType()); + auto resultTy = cast(getResult().getType()); + + if (sourceTy.getRank() != 2) + return emitOpError("source memref must be rank-2 to materialize a tile handle"); + if (resultTy.getRank() != 2) + return emitOpError("result tile_buf must be rank-2"); + if (failed(verifyTileBufLayoutConstraints(*this, resultTy, "result"))) + return failure(); + + auto viewSemantics = (*this)->getAttrOfType("pto.view_semantics"); + bool isSubview = viewSemantics && viewSemantics.getValue() == "subview"; + if (!isSubview && sourceTy.getShape() != resultTy.getShape()) + return emitOpError() << "source/result shape mismatch: source=" + << sourceTy << " result=" << resultTy; + + if (sourceTy.getElementType() != resultTy.getElementType()) + return emitOpError() << "source/result element type mismatch: source=" + << sourceTy.getElementType() + << " result=" << resultTy.getElementType(); + + if (sourceTy.getMemorySpace() != resultTy.getMemorySpace()) + return emitOpError() << "source/result memory space mismatch"; + + if (getConfig() != resultTy.getConfigAttr()) + return emitOpError("config attribute must match the result tile_buf config"); + + auto shape = resultTy.getShape(); + auto validShape = resultTy.getValidShape(); + if (validShape.size() != 2) + return emitOpError("result tile_buf must have rank-2 validShape"); + for (unsigned i = 0; i < 2; ++i) { + if (shape[i] != ShapedType::kDynamic && + validShape[i] != ShapedType::kDynamic && 
validShape[i] > shape[i]) { + return emitOpError() << "valid_shape[" << i << "] must be <= shape[" + << i << "]"; + } + } + + return success(); +} + +LogicalResult TAssignOp::verify() { + if (getTile().getType() != getResult().getType()) { + return emitOpError("result type must match tile operand type"); + } + return success(); +} + +LogicalResult TLoadOp::verify() { + auto verifyCommon = + [&](bool allowLowPrecision) + -> FailureOr> { + auto srcPart = dyn_cast(getSrc().getType()); + auto dstTile = dyn_cast(getDst().getType()); + if (!srcPart || !dstTile) { + emitOpError("expects src to be !pto.partition_tensor_view and dst to be !pto.tile_buf"); + return failure(); + } + if (failed(verifyTileBufCommon(*this, dstTile, "dst", allowLowPrecision))) + return failure(); + + auto srcShape = srcPart.getShape(); + for (unsigned i = 0; i < srcShape.size(); ++i) { + if (srcShape[i] != ShapedType::kDynamic && srcShape[i] <= 0) { + emitOpError() << "expects src shape[" << i << "] to be positive"; + return failure(); + } + } + auto dstValid = dstTile.getValidShape(); + for (unsigned i = 0; i < dstValid.size(); ++i) { + if (dstValid[i] != ShapedType::kDynamic && dstValid[i] <= 0) { + emitOpError() << "expects dst valid_shape[" << i << "] to be positive"; + return failure(); + } + } + return std::make_pair(srcPart, dstTile); + }; + + auto verifyA2A3 = [&]() -> LogicalResult { + auto common = verifyCommon(/*allowLowPrecision=*/false); + if (failed(common)) + return failure(); + auto [srcPart, dstTile] = *common; + + Type srcElem = srcPart.getElementType(); + Type dstElem = dstTile.getElementType(); + if (isPTOLowPrecisionType(srcElem) || isPTOLowPrecisionType(dstElem)) + return emitOpError("expects A2/A3 tload low-precision element types to be unsupported"); + if (!(dstElem.isInteger(8) || dstElem.isInteger(16) || dstElem.isInteger(32) || + dstElem.isInteger(64) || dstElem.isF16() || dstElem.isBF16() || dstElem.isF32())) + return emitOpError("expects A2/A3 tload dst element type 
to be i8/i16/i32/i64/u64/f16/bf16/f32"); + + auto dstSpace = getPTOMemorySpaceEnum(dstTile); + if (!dstSpace || (*dstSpace != pto::AddressSpace::VEC && + *dstSpace != pto::AddressSpace::MAT)) + return emitOpError("expects A2/A3 tload dst to use loc=vec or loc=mat"); + + if (getElemByteSize(srcElem) != getElemByteSize(dstElem)) + return emitOpError("expects src and dst element types to have the same bitwidth"); + return success(); + }; + + auto verifyA5 = [&]() -> LogicalResult { + auto common = verifyCommon(/*allowLowPrecision=*/true); + if (failed(common)) + return failure(); + auto [srcPart, dstTile] = *common; + + Type srcElem = srcPart.getElementType(); + Type dstElem = dstTile.getElementType(); + unsigned srcBytes = getElemByteSize(srcElem); + unsigned dstBytes = getElemByteSize(dstElem); + if (srcBytes != dstBytes) + return emitOpError("expects src and dst element types to have the same element size"); + if (!(dstBytes == 1 || dstBytes == 2 || dstBytes == 4 || dstBytes == 8)) + return emitOpError("expects A5 tload dst element size to be 1, 2, 4, or 8 bytes"); + if (!isA5TLoadStoreTransferElemType(srcElem)) + return emitOpError("expects A5 tload src element type to be i8/i16/i32/i64/f16/bf16/f32/f8/hif8/fp4"); + if (!isA5TLoadStoreTransferElemType(dstElem)) + return emitOpError("expects A5 tload dst element type to be i8/i16/i32/i64/f16/bf16/f32/f8/hif8/fp4"); + + if (dstElem.isInteger(64)) { + auto pad = dstTile.getPadValueI32(); + if (pad != static_cast(pto::PadValue::Null) && + pad != static_cast(pto::PadValue::Zero)) + return emitOpError("expects A5 i64/u64 tload dst pad to be null or zero"); + } + return success(); + }; + + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +LogicalResult TPrefetchOp::verify() { + auto verifyImpl = [&](bool allowLowPrecision) -> LogicalResult { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + + Type srcElem; + Type dstElem; + + if (auto srcPart = dyn_cast(srcTy)) { + auto 
srcShape = srcPart.getShape(); + for (unsigned i = 0; i < srcShape.size(); ++i) { + if (srcShape[i] != ShapedType::kDynamic && srcShape[i] <= 0) + return emitOpError() << "expects src shape[" << i << "] to be positive"; + } + srcElem = srcPart.getElementType(); + } else if (auto srcMr = dyn_cast(srcTy)) { + if (!srcMr.hasRank()) + return emitOpError("expects src memref to be ranked"); + for (int64_t dim : srcMr.getShape()) { + if (dim != ShapedType::kDynamic && dim <= 0) + return emitOpError("expects src memref shape to be positive"); + } + srcElem = srcMr.getElementType(); + } else { + return emitOpError("expects src to be !pto.partition_tensor_view or memref"); + } + + if (auto dstTile = dyn_cast(dstTy)) { + if (failed(verifyTileBufCommon(*this, dstTile, "dst", allowLowPrecision))) + return failure(); + auto dstValid = dstTile.getValidShape(); + for (unsigned i = 0; i < dstValid.size(); ++i) { + if (dstValid[i] != ShapedType::kDynamic && dstValid[i] <= 0) + return emitOpError() << "expects dst valid_shape[" << i << "] to be positive"; + } + auto dstSpace = getPTOMemorySpaceEnum(dstTile); + if (!dstSpace || (*dstSpace != pto::AddressSpace::VEC && + *dstSpace != pto::AddressSpace::MAT)) + return emitOpError("expects dst to use loc=vec or loc=mat"); + dstElem = dstTile.getElementType(); + } else if (auto dstMr = dyn_cast(dstTy)) { + auto dstSpace = getPTOMemorySpaceEnum(dstMr); + if (!dstSpace || (*dstSpace != pto::AddressSpace::VEC && + *dstSpace != pto::AddressSpace::MAT)) + return emitOpError("expects dst memref to use loc=vec or loc=mat"); + if (!dstMr.hasRank()) + return emitOpError("expects dst memref to be ranked"); + if (failed(verifyTileBufCommon(*this, dstMr, "dst", allowLowPrecision))) + return failure(); + dstElem = dstMr.getElementType(); + } else { + return emitOpError("expects dst to be !pto.tile_buf or memref"); + } + + if (getElemByteSize(srcElem) != getElemByteSize(dstElem)) + return emitOpError("expects src and dst element types to have the same 
element size"); + if (!allowLowPrecision && + (isPTOLowPrecisionType(srcElem) || isPTOLowPrecisionType(dstElem))) + return emitOpError("expects A2/A3 tprefetch low-precision element types to be unsupported"); + if (allowLowPrecision && + (!isA5TLoadStoreTransferElemType(srcElem) || + !isA5TLoadStoreTransferElemType(dstElem))) + return emitOpError("expects A5 tprefetch element types to be i8/i16/i32/i64/f16/bf16/f32/f8/hif8/fp4"); + return success(); + }; + + auto verifyA2A3 = [&]() -> LogicalResult { + return verifyImpl(/*allowLowPrecision=*/false); + }; + auto verifyA5 = [&]() -> LogicalResult { + return verifyImpl(/*allowLowPrecision=*/true); + }; + switch (getVerifierTargetArch(getOperation())) { + case VerifierTargetArch::A2A3: + return verifyA2A3(); + case VerifierTargetArch::A5: + return verifyA5(); + } + return failure(); +} + +LogicalResult MakePrefetchAsyncContextOp::verify() { + Type workspaceTy = getWorkspace().getType(); + Type elemTy = nullptr; + if (auto ptrTy = dyn_cast(workspaceTy)) { + elemTy = ptrTy.getElementType(); + } else if (auto memTy = dyn_cast(workspaceTy)) { + if (!isGmAddressSpaceAttr(memTy.getMemorySpace())) + return emitOpError("expects workspace memref to be in GM address space"); + elemTy = memTy.getElementType(); + } else { + return emitOpError("expects workspace to be !pto.ptr or GM memref"); + } + if (!isByteIntegerType(elemTy)) + return emitOpError("expects workspace element type to be an 8-bit integer"); + return success(); +} + +LogicalResult TPrefetchAsyncOp::verify() { + if (failed(verifyAsyncFlatContiguous1DGMViewLike(getOperation(), getSrc(), + "src"))) + return failure(); + return success(); +} + +LogicalResult mlir::pto::SetFFTsOp::verify() { + auto mr = llvm::dyn_cast(getFfts().getType()); + if (!mr) + return emitOpError("expects a memref operand"); + + if (!mr.getElementType().isInteger(64) && !mr.getElementType().isInteger(8)) + return emitOpError("expects element type i64 (or i8)"); + + return mlir::success(); +} + 
+ParseResult mlir::pto::SyncSetOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseSyncEventOpCommon(parser, result, + SyncSetOp::getPipeAttrName(result.name), + SyncSetOp::getEventIdAttrName(result.name)); +} + +void mlir::pto::SyncSetOp::print(OpAsmPrinter &p) { + printSyncEventOpCommon(p, getOperation(), getPipe(), getEventIdAttr(), + getEventIdDyn(), getPipeAttrName().getValue(), + getEventIdAttrName().getValue()); +} + +LogicalResult mlir::pto::SyncSetOp::verify() { + bool hasStatic = getEventIdAttr() != nullptr; + bool hasDynamic = static_cast(getEventIdDyn()); + if (hasStatic == hasDynamic) + return emitOpError() + << "expects exactly one event-id form: static attr or dynamic index operand"; + if (IntegerAttr fftsModeAttr = getFftsModeAttr()) { + int64_t fftsMode = fftsModeAttr.getInt(); + if (fftsMode < 0 || fftsMode > 2) + return emitOpError() << "requires ffts_mode in range [0, 2], but got " + << fftsMode; + } + + auto verifyA2A3 = [&]() -> LogicalResult { return success(); }; + auto verifyA5 = [&]() -> LogicalResult { + switch (getPipe().getPipe()) { + case PIPE::PIPE_FIX: + case PIPE::PIPE_MTE3: + return success(); + default: + return emitOpError() + << "A5 sync.set expects pipe to be one of , "; + } + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +ParseResult mlir::pto::SyncWaitOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseSyncEventOpCommon(parser, result, + SyncWaitOp::getPipeAttrName(result.name), + SyncWaitOp::getEventIdAttrName(result.name)); +} + +void mlir::pto::SyncWaitOp::print(OpAsmPrinter &p) { + printSyncEventOpCommon(p, getOperation(), getPipe(), getEventIdAttr(), + getEventIdDyn(), getPipeAttrName().getValue(), + getEventIdAttrName().getValue()); +} + +ParseResult mlir::pto::SyncAllOp::parse(OpAsmParser &parser, + OperationState &result) { + SmallVector operands; + SmallVector operandTypes; + Attribute modeAttr; + Attribute coreTypeAttr; + + if 
(parser.parseLParen()) + return failure(); + + if (failed(parser.parseOptionalRParen())) { + if (parser.parseOperandList(operands) || parser.parseColonTypeList(operandTypes) || + parser.parseRParen()) + return failure(); + if (operands.size() != operandTypes.size()) + return parser.emitError(parser.getCurrentLocation()) + << "expects the same number of operands and operand types"; + } + + if (parser.parseKeyword("mode") || parser.parseEqual() || + parser.parseAttribute(modeAttr) || parser.parseComma() || + parser.parseKeyword("core_type") || parser.parseEqual() || + parser.parseAttribute(coreTypeAttr)) + return failure(); + + auto mode = dyn_cast(modeAttr); + if (!mode) + return parser.emitError(parser.getCurrentLocation()) + << "expects mode to be #pto.sync_all_mode<...>"; + + auto coreType = dyn_cast(coreTypeAttr); + if (!coreType) + return parser.emitError(parser.getCurrentLocation()) + << "expects core_type to be #pto.sync_core_type<...>"; + + result.addAttribute("mode", mode); + result.addAttribute("core_type", coreType); + + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + auto addSegmentSizes = [&](int32_t gm, int32_t ub, int32_t l1, + int32_t used) { + result.addAttribute("operandSegmentSizes", + parser.getBuilder().getDenseI32ArrayAttr( + {gm, ub, l1, used})); + }; + + switch (mode.getValue()) { + case pto::SyncAllMode::Hard: + if (!operands.empty()) + return parser.emitError(parser.getCurrentLocation()) + << "expects hard syncall to have no operands"; + addSegmentSizes(0, 0, 0, 0); + return success(); + case pto::SyncAllMode::Soft: + break; + } + + switch (coreType.getValue()) { + case pto::SyncCoreType::AIVOnly: + if (operands.size() != 2 && operands.size() != 3) + return parser.emitError(parser.getCurrentLocation()) + << "expects soft AIV-only syncall to have gm_workspace, " + "ub_workspace, and optional used_cores"; + if (parser.resolveOperand(operands[0], operandTypes[0], result.operands) || + 
parser.resolveOperand(operands[1], operandTypes[1], result.operands)) + return failure(); + if (operands.size() == 3 && + parser.resolveOperand(operands[2], operandTypes[2], result.operands)) + return failure(); + addSegmentSizes(1, 1, 0, operands.size() == 3 ? 1 : 0); + return success(); + case pto::SyncCoreType::AICOnly: + if (operands.size() != 2 && operands.size() != 3) + return parser.emitError(parser.getCurrentLocation()) + << "expects soft AIC-only syncall to have gm_workspace, " + "l1_workspace, and optional used_cores"; + if (parser.resolveOperand(operands[0], operandTypes[0], result.operands) || + parser.resolveOperand(operands[1], operandTypes[1], result.operands)) + return failure(); + if (operands.size() == 3 && + parser.resolveOperand(operands[2], operandTypes[2], result.operands)) + return failure(); + addSegmentSizes(1, 0, 1, operands.size() == 3 ? 1 : 0); + return success(); + case pto::SyncCoreType::Mix: + if (operands.size() != 3 && operands.size() != 4) + return parser.emitError(parser.getCurrentLocation()) + << "expects soft mixed syncall to have gm_workspace, " + "ub_workspace, l1_workspace, and optional used_cores"; + if (parser.resolveOperand(operands[0], operandTypes[0], result.operands) || + parser.resolveOperand(operands[1], operandTypes[1], result.operands) || + parser.resolveOperand(operands[2], operandTypes[2], result.operands)) + return failure(); + if (operands.size() == 4 && + parser.resolveOperand(operands[3], operandTypes[3], result.operands)) + return failure(); + addSegmentSizes(1, 1, 1, operands.size() == 4 ? 
1 : 0); + return success(); + } + + llvm_unreachable("unhandled SyncCoreType"); +} + +void mlir::pto::SyncAllOp::print(OpAsmPrinter &p) { + SmallVector operands; + if (getGmWorkspace()) + operands.push_back(getGmWorkspace()); + if (getUbWorkspace()) + operands.push_back(getUbWorkspace()); + if (getL1Workspace()) + operands.push_back(getL1Workspace()); + if (getUsedCores()) + operands.push_back(getUsedCores()); + + p << "("; + if (!operands.empty()) { + p.printOperands(operands); + p << " : "; + llvm::interleaveComma(operands, p, + [&](Value operand) { p.printType(operand.getType()); }); + } + p << ") mode = " << getMode() << ", core_type = " << getCoreType(); + p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"operandSegmentSizes", "mode", + "core_type"}); +} + +LogicalResult mlir::pto::SyncWaitOp::verify() { + bool hasStatic = getEventIdAttr() != nullptr; + bool hasDynamic = static_cast(getEventIdDyn()); + if (hasStatic == hasDynamic) + return emitOpError() + << "expects exactly one event-id form: static attr or dynamic index operand"; + + auto verifyA2A3 = [&]() -> LogicalResult { return success(); }; + auto verifyA5 = [&]() -> LogicalResult { + switch (getPipe().getPipe()) { + case PIPE::PIPE_FIX: + case PIPE::PIPE_MTE1: + case PIPE::PIPE_MTE2: + case PIPE::PIPE_MTE3: + case PIPE::PIPE_V: + return success(); + default: + return emitOpError() << "A5 sync.wait expects pipe to be one of " + ", , , " + ", "; + } + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +LogicalResult TStoreOp::verify() { + auto verifyCommon = + [&](bool allowLowPrecision) + -> FailureOr> { + auto srcTile = dyn_cast(getSrc().getType()); + auto dstPart = dyn_cast(getDst().getType()); + if (!srcTile || !dstPart) { + emitOpError("expects src to be !pto.tile_buf and dst to be !pto.partition_tensor_view"); + return failure(); + } + if (failed(verifyTileBufCommon(*this, srcTile, "src", allowLowPrecision))) + return failure(); + for (auto [idx, dim] : 
llvm::enumerate(dstPart.getShape())) { + if (dim != ShapedType::kDynamic && dim <= 0) { + emitOpError() << "expects dst shape[" << idx << "] to be positive"; + return failure(); + } + } + auto srcValid = srcTile.getValidShape(); + for (auto [idx, dim] : llvm::enumerate(srcValid)) { + if (dim != ShapedType::kDynamic && dim <= 0) { + emitOpError() << "expects src valid_shape[" << idx << "] to be positive"; + return failure(); + } + } + + // Keep TSTORE contract explicit while preserving existing legal layout + // reinterpretation paths (e.g. 1x1024 <-> 32x32, 5D partition views). + // When both sides are fully static, require equal element counts between + // dst shape and src valid_shape. + auto getStaticElemCount = [](ArrayRef shape) -> std::optional { + int64_t total = 1; + for (int64_t dim : shape) { + if (dim == ShapedType::kDynamic) + return std::nullopt; + if (dim <= 0) + return std::nullopt; + if (total > std::numeric_limits::max() / dim) + return std::nullopt; + total *= dim; + } + return total; + }; + + auto dstElemCount = getStaticElemCount(dstPart.getShape()); + auto srcValidElemCount = getStaticElemCount(srcValid); + if (dstElemCount && srcValidElemCount && *dstElemCount != *srcValidElemCount) { + emitOpError() << "expects dst static element count (" << *dstElemCount + << ") to match src valid_shape static element count (" + << *srcValidElemCount << ")"; + return failure(); + } + return std::make_pair(srcTile, dstPart); + }; + + auto isLoadStoreElemType = [&](Type ty) -> bool { + return ty.isInteger(8) || ty.isInteger(16) || ty.isInteger(32) || + ty.isInteger(64) || ty.isF16() || ty.isBF16() || ty.isF32(); + }; + auto isI8Like = [&](Type ty) -> bool { return ty.isInteger(8); }; + bool hasPreQuant = static_cast(getPreQuantScalar()); + auto reluMode = getReluPreMode(); + + auto verifyA2A3 = [&]() -> LogicalResult { + auto common = verifyCommon(/*allowLowPrecision=*/false); + if (failed(common)) + return failure(); + auto [srcTile, dstPart] = *common; + 
auto srcSpace = getPTOMemorySpaceEnum(srcTile); + if (!srcSpace || (*srcSpace != pto::AddressSpace::VEC && + *srcSpace != pto::AddressSpace::MAT && + *srcSpace != pto::AddressSpace::ACC)) + return emitOpError("expects A2/A3 tstore src to use loc=vec, loc=mat, or loc=acc"); + if (hasPreQuant && *srcSpace != pto::AddressSpace::ACC) + return emitOpError("expects preQuantScalar form to use loc=acc src"); + if (reluMode != pto::ReluPreMode::NoRelu && *srcSpace != pto::AddressSpace::ACC) + return emitOpError("expects reluPreMode form to use loc=acc src"); + + Type srcElem = srcTile.getElementType(); + Type dstElem = dstPart.getElementType(); + if (*srcSpace == pto::AddressSpace::VEC || *srcSpace == pto::AddressSpace::MAT) { + if (hasPreQuant) + return emitOpError("expects preQuantScalar form to use loc=acc src"); + if (isPTOLowPrecisionType(dstElem)) + return emitOpError("expects A2/A3 vec/mat tstore low-precision dst element types to be unsupported"); + if (!isLoadStoreElemType(srcElem)) + return emitOpError("expects A2/A3 vec/mat tstore src element type to be i8/i16/i32/i64/u64/f16/bf16/f32"); + if (getElemByteSize(srcElem) != getElemByteSize(dstElem)) + return emitOpError("expects A2/A3 vec/mat tstore src and dst element types to have the same bitwidth"); + return success(); + } + + if (!(srcElem.isInteger(32) || srcElem.isF32())) + return emitOpError("expects A2/A3 acc tstore src element type to be i32 or f32"); + if (hasPreQuant) { + if (srcElem.isInteger(32)) { + if (!(isI8Like(dstElem) || dstElem.isF16())) + return emitOpError("expects A2/A3 acc preQuantScalar tstore dst type to be i8/ui8/f16"); + } else if (srcElem.isF32()) { + if (!isI8Like(dstElem)) + return emitOpError("expects A2/A3 acc preQuantScalar tstore dst type to be i8/ui8"); + } + } else { + if (!(dstElem.isInteger(32) || dstElem.isF32() || dstElem.isF16() || + dstElem.isBF16())) + return emitOpError("expects A2/A3 acc tstore dst element type to be i32/f32/f16/bf16"); + } + + auto srcShape = 
srcTile.getShape(); + if (srcShape[1] != ShapedType::kDynamic && + (srcShape[1] < 1 || srcShape[1] > 4095)) + return emitOpError("expects A2/A3 acc tstore src cols to be in [1, 4095]"); + auto srcValid = srcTile.getValidShape(); + if (srcValid[1] != ShapedType::kDynamic && + (srcValid[1] < 1 || srcValid[1] > 4095)) + return emitOpError("expects A2/A3 acc tstore src valid_shape[1] to be in [1, 4095]"); + return success(); + }; + + auto verifyA5 = [&]() -> LogicalResult { + auto common = verifyCommon(/*allowLowPrecision=*/true); + if (failed(common)) + return failure(); + auto [srcTile, dstPart] = *common; + auto srcSpace = getPTOMemorySpaceEnum(srcTile); + if (!srcSpace || (*srcSpace != pto::AddressSpace::VEC && + *srcSpace != pto::AddressSpace::ACC)) + return emitOpError("expects A5 tstore src to use loc=vec or loc=acc"); + if (hasPreQuant && *srcSpace != pto::AddressSpace::ACC) + return emitOpError("expects preQuantScalar form to use loc=acc src"); + if (reluMode != pto::ReluPreMode::NoRelu && *srcSpace != pto::AddressSpace::ACC) + return emitOpError("expects reluPreMode form to use loc=acc src"); + + Type srcElem = srcTile.getElementType(); + Type dstElem = dstPart.getElementType(); + if (*srcSpace == pto::AddressSpace::VEC) { + if (hasPreQuant) + return emitOpError("expects preQuantScalar form to use loc=acc src"); + if (!isA5TLoadStoreTransferElemType(srcElem)) + return emitOpError("expects A5 vec tstore src element type to be i8/i16/i32/i64/f16/bf16/f32/f8/hif8/fp4"); + if (getElemByteSize(srcElem) != getElemByteSize(dstElem)) + return emitOpError("expects A5 vec tstore src and dst element types to have the same bitwidth"); + return success(); + } + + if (!(srcElem.isInteger(32) || srcElem.isF32())) + return emitOpError("expects A5 acc tstore src element type to be i32 or f32"); + if (hasPreQuant) { + if (!isA5AccStorePreQuantDstType(srcElem, dstElem)) + return emitOpError("expects A5 acc preQuantScalar tstore dst type to be i8/ui8/f16/bf16/f32/hif8/f8E4M3"); 
+ } else { + if (!(dstElem.isInteger(32) || dstElem.isF32() || dstElem.isF16() || + dstElem.isBF16())) + return emitOpError("expects A5 acc tstore dst element type to be i32/f32/f16/bf16"); + } + return success(); + }; + + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +LogicalResult pto::TAbsOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyVecTileCommon(*this, srcTy, "src")) || + failed(verifyVecTileCommon(*this, dstTy, "dst"))) + return failure(); + if (failed(verifyTileBufSameElemType(*this, srcTy, dstTy, "src", "dst")) || + failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) + return failure(); + + Type elemTy; + if (auto tb = dyn_cast(srcTy)) + elemTy = tb.getElementType(); + else if (auto mr = dyn_cast(srcTy)) + elemTy = mr.getElementType(); + if (!(elemTy.isF16() || elemTy.isF32())) + return emitOpError() << "expects element type to be f16 or f32"; + + return success(); +} +// PTO.cpp + +static bool isPTOShapedLike(Type ty) { + return mlir::isa(ty); +} + +static bool isTileLikeType(Type ty) { + return isa(ty); +} + +static Type getElemTy(Type ty) { + if (auto mr = mlir::dyn_cast(ty)) return mr.getElementType(); + if (auto tt = mlir::dyn_cast(ty)) return tt.getElementType(); + if (auto tv = mlir::dyn_cast(ty)) return tv.getElementType(); + if (auto tb = mlir::dyn_cast(ty)) return tb.getElementType(); + if (auto tv = mlir::dyn_cast(ty)) return tv.getElementType(); + return Type(); +} + +static SmallVector getShapeVec(Type ty) { + SmallVector s; + if (auto mr = mlir::dyn_cast(ty)) + return SmallVector(mr.getShape().begin(), mr.getShape().end()); + if (auto tt = mlir::dyn_cast(ty)) + return SmallVector(tt.getShape().begin(), tt.getShape().end()); + if (auto tv = mlir::dyn_cast(ty)) + return SmallVector(tv.getShape().begin(), tv.getShape().end()); + if (auto tb = mlir::dyn_cast(ty)) + 
return SmallVector(tb.getShape().begin(), tb.getShape().end()); + if (auto tv = mlir::dyn_cast(ty)) + return SmallVector(tv.getShape().begin(), tv.getShape().end()); + return {}; +} + +static SmallVector getValidShapeVec(Type ty) { + if (auto tb = dyn_cast(ty)) + return SmallVector(tb.getValidShape().begin(), tb.getValidShape().end()); + return getShapeVec(ty); +} + +static int64_t getLogicalTileDim(int64_t rawDim, Type elemTy, + std::optional blayout, + unsigned dimIdx) { + if (rawDim == ShapedType::kDynamic || !isPTOFloat4PackedType(elemTy)) + return rawDim; + pto::BLayout layout = blayout.value_or(pto::BLayout::RowMajor); + unsigned packedDim = layout == pto::BLayout::ColMajor ? 0 : 1; + return dimIdx == packedDim ? rawDim * 2 : rawDim; +} + +static std::optional getTileBufBLayout(Type ty) { + if (auto tb = dyn_cast(ty)) + return static_cast(tb.getBLayoutValueI32()); + return std::nullopt; +} + +static SmallVector getLogicalTileExtentVec(Type ty, + bool useValidShape) { + SmallVector dims = + useValidShape ? 
getValidShapeVec(ty) : getShapeVec(ty); + if (!isTileLikeType(ty) || dims.size() != 2) + return dims; + + Type elemTy = getElemTy(ty); + auto blayout = getTileBufBLayout(ty); + for (unsigned i = 0; i < dims.size(); ++i) + dims[i] = getLogicalTileDim(dims[i], elemTy, blayout, i); + return dims; +} + +static int64_t getConstantIndexOrDynamic(Value value) { + if (!value) + return ShapedType::kDynamic; + if (auto cst = value.getDefiningOp()) + return cst.value(); + if (auto cst = value.getDefiningOp()) + return cst.value(); + return ShapedType::kDynamic; +} + +static SmallVector getValidShapeVec(Value value) { + if (!value) + return {}; + auto valid = getValidShapeVec(value.getType()); + if (auto bind = value.getDefiningOp()) { + if (valid.size() >= 1 && bind.getValidRow()) + valid[0] = getConstantIndexOrDynamic(bind.getValidRow()); + if (valid.size() >= 2 && bind.getValidCol()) + valid[1] = getConstantIndexOrDynamic(bind.getValidCol()); + } + return valid; +} + +static SmallVector getMatmulLogicalShapeVec(Type ty) { + auto shape = getShapeVec(ty); + auto valid = getValidShapeVec(ty); + if (!isa(ty) || shape.size() != valid.size()) + return shape; + + for (size_t i = 0, e = shape.size(); i < e; ++i) { + if (valid[i] != ShapedType::kDynamic) + shape[i] = valid[i]; + } + return shape; +} + +static bool isByteIntegerType(Type ty) { + auto intTy = dyn_cast(ty); + return intTy && intTy.getWidth() == 8; +} + +static LogicalResult verifyAsyncFlatContiguous1DGMMemRef(Operation *op, + Value value, + StringRef name) { + auto memTy = dyn_cast(value.getType()); + if (!memTy) + return op->emitOpError() << "expects " << name << " to be a memref"; + if (!memTy.hasRank()) + return op->emitOpError() << "expects " << name << " to be a ranked memref"; + if (!isGmAddressSpaceAttr(memTy.getMemorySpace())) + return op->emitOpError() << "expects " << name + << " to be in GM address space"; + + ArrayRef shape = memTy.getShape(); + if (shape.empty()) + return op->emitOpError() << "expects " << 
name + << " to have rank >= 1"; + for (int64_t dim : shape) { + if (dim == ShapedType::kDynamic) + return op->emitOpError() << "expects " << name + << " to have a static shape"; + } + + SmallVector strides; + int64_t offset = 0; + if (failed(getStridesAndOffset(memTy, strides, offset))) + return op->emitOpError() << "expects " << name + << " to be a strided memref with a known layout"; + + bool hasDynamicLayout = + offset == ShapedType::kDynamic || + llvm::any_of(strides, [](int64_t stride) { + return stride == ShapedType::kDynamic; + }); + if (hasDynamicLayout) + return success(); + + bool packed = !strides.empty() && strides.back() == 1; + for (int i = static_cast(shape.size()) - 2; i >= 0 && packed; --i) + packed &= strides[i] == strides[i + 1] * shape[i + 1]; + if (!packed) + return op->emitOpError() + << "expects " << name + << " to be a static flat contiguous logical 1D GM memref"; + + bool logical1D = true; + for (int i = 0, e = static_cast(shape.size()) - 1; i < e; ++i) + logical1D &= shape[i] == 1; + if (!logical1D) + return op->emitOpError() + << "expects " << name + << " to be a static flat contiguous logical 1D GM memref"; + + return success(); +} + +static LogicalResult verifyAsyncFlatContiguous1DGMViewLike(Operation *op, + Value value, + StringRef name) { + Type ty = value.getType(); + if (isa(ty)) + return verifyAsyncFlatContiguous1DGMMemRef(op, value, name); + + if (!isa(ty)) + return op->emitOpError() << "expects " << name + << " to be a memref/tensor_view/partition_view"; + + SmallVector shape = getShapeVec(ty); + if (shape.empty()) + return op->emitOpError() << "expects " << name << " to have rank >= 1"; + for (int64_t dim : shape) { + if (dim == ShapedType::kDynamic) + return op->emitOpError() << "expects " << name + << " to have a static shape"; + } + + bool logical1D = true; + for (int i = 0, e = static_cast(shape.size()) - 1; i < e; ++i) + logical1D &= shape[i] == 1; + if (!logical1D) + return op->emitOpError() + << "expects " << name + << " 
to be a static flat contiguous logical 1D GM view"; + + return success(); +} + +static bool isCommGlobalLikeType(Type ty) { + if (auto memTy = dyn_cast(ty)) + return isGmAddressSpaceAttr(memTy.getMemorySpace()); + return isa(ty); +} + +static LogicalResult verifyCommGlobalLike(Operation *op, Value value, + StringRef name) { + Type ty = value.getType(); + if (!isCommGlobalLikeType(ty)) + return op->emitOpError() << "expects " << name + << " to be a GM memref/tensor_view/partition_view"; + + SmallVector shape = getShapeVec(ty); + if (shape.empty()) + return op->emitOpError() << "expects " << name << " to have rank >= 1"; + for (int64_t dim : shape) { + if (dim == ShapedType::kDynamic || dim <= 0) + return op->emitOpError() << "expects " << name + << " to have a positive static shape"; + } + return success(); +} + +static LogicalResult verifyCommSignalLike(Operation *op, Value value, + StringRef name) { + if (failed(verifyCommGlobalLike(op, value, name))) + return failure(); + Type elemTy = getElemTy(value.getType()); + if (!elemTy || !elemTy.isSignlessInteger(32)) + return op->emitOpError() << "expects " << name + << " element type to be i32"; + return success(); +} + +static LogicalResult verifyCommStagingTileLike(Operation *op, Value value, + StringRef name) { + Type ty = value.getType(); + if (!isa(ty)) + return op->emitOpError() << "expects " << name + << " to be a tile_buf or memref tile"; + auto as = getPTOMemorySpaceEnum(ty); + if (!as || *as != pto::AddressSpace::VEC) + return op->emitOpError() << "expects " << name + << " to be in vec address space"; + SmallVector shape = getShapeVec(ty); + if (shape.empty()) + return op->emitOpError() << "expects " << name << " to have rank >= 1"; + for (int64_t dim : shape) { + if (dim == ShapedType::kDynamic || dim <= 0) + return op->emitOpError() << "expects " << name + << " to have a positive static shape"; + } + return success(); +} + +static LogicalResult verifyCommGlobalGroup(Operation *op, ValueRange group, + 
StringRef name) { + if (group.empty()) + return op->emitOpError() << "expects at least one " << name << " operand"; + Type groupTy = group.front().getType(); + for (auto it : llvm::enumerate(group)) { + if (failed(verifyCommGlobalLike(op, it.value(), + (name + "[" + Twine(it.index()) + "]").str()))) + return failure(); + if (it.value().getType() != groupTy) + return op->emitOpError() << "expects all " << name + << " operands to have identical types"; + } + return success(); +} + +static LogicalResult verifyCommPingPongSameType(Operation *op, Value ping, + Value pong, StringRef pingName, + StringRef pongName) { + if (!pong) + return success(); + if (failed(verifyCommStagingTileLike(op, ping, pingName)) || + failed(verifyCommStagingTileLike(op, pong, pongName))) + return failure(); + if (ping.getType() != pong.getType()) + return op->emitOpError() << "expects " << pingName << " and " << pongName + << " to have identical types"; + return success(); +} + +static std::optional getStaticByteSize(Type ty) { + SmallVector shape = getShapeVec(ty); + if (shape.empty()) + return std::nullopt; + for (int64_t dim : shape) { + if (dim == ShapedType::kDynamic || dim < 0) + return std::nullopt; + } + + Type elemTy = getElemTy(ty); + uint64_t elemBytes = getElemByteSize(elemTy); + if (elemBytes == 0) + return std::nullopt; + + uint64_t total = elemBytes; + for (int64_t dim : shape) { + total *= static_cast(dim); + } + return total; +} + +static std::optional getPTOMemorySpaceEnum(Type ty) { + if (auto tb = dyn_cast(ty)) { + if (auto as = dyn_cast_or_null(tb.getMemorySpace())) + return as.getAddressSpace(); + return std::nullopt; + } + if (auto mr = dyn_cast(ty)) { + if (auto as = dyn_cast_or_null(mr.getMemorySpace())) + return as.getAddressSpace(); + if (!mr.getMemorySpace()) + return pto::AddressSpace::GM; + } + return std::nullopt; +} + +[[maybe_unused]] static bool isRank2TileBuf(Type ty) { + auto tb = dyn_cast(ty); + return tb && tb.getRank() == 2 && tb.getValidShape().size() 
== 2; +} + +static bool isSupportedVecElemType(Type ty, bool allowBf16, + bool allowInt8) { + if (ty.isF16() || ty.isF32()) + return true; + if (allowBf16 && ty.isBF16()) + return true; + if (auto it = dyn_cast(ty)) { + switch (it.getWidth()) { + case 32: + case 16: + return true; + case 8: + return allowInt8; + default: + return false; + } + } + return false; +} + +static bool isSupportedMGatherMScatterIndexElemType(Type ty) { + auto it = dyn_cast(ty); + if (!it || it.getWidth() != 32) + return false; + return true; +} + +static bool isSupportedMGatherMScatterPayloadElemType(Operation *op, Type ty) { + if (isSupportedVecElemType(ty, /*allowBf16=*/true, /*allowInt8=*/true)) + return true; + if (!isTargetArchA5(op)) + return false; + return ty.isFloat8E4M3() || ty.isFloat8E4M3FN() || ty.isFloat8E4M3FNUZ() || + ty.isFloat8E4M3B11FNUZ() || ty.isFloat8E5M2() || ty.isFloat8E5M2FNUZ(); +} + +static bool isSupportedMScatterAtomicPayloadElemType(Type ty, + pto::ScatterAtomicOp atomic) { + auto intTy = dyn_cast(ty); + switch (atomic) { + case pto::ScatterAtomicOp::None: + return true; + case pto::ScatterAtomicOp::Add: + return ty.isF16() || ty.isF32() || + (intTy && intTy.getWidth() == 32); + case pto::ScatterAtomicOp::Max: + case pto::ScatterAtomicOp::Min: + return ty.isF32() || + (intTy && intTy.getWidth() == 32); + } + llvm_unreachable("Unknown ScatterAtomicOp"); +} + +static LogicalResult verifyMGatherMScatterMemOperand(Operation *op, + Value memValue, + Type dataElemTy, + StringRef dataOperandLabel) { + Type memTy = memValue.getType(); + Type memElem = getElemTy(memTy); + if (!memElem || memElem != dataElemTy) + return op->emitOpError() << "expects mem element type to match " + << dataOperandLabel << " element type"; + + if (isa(memTy)) { + if (auto layout = getLogicalViewLayout(memValue)) { + if (*layout != pto::Layout::ND) + return op->emitOpError( + "expects mem partition view to use ND logical layout when layout " + "can be inferred"); + } + return success(); + } + 
+ if (auto mr = dyn_cast(memTy)) { + auto as = getPTOMemorySpaceEnum(mr); + if (!as || (*as != pto::AddressSpace::GM && + *as != pto::AddressSpace::Zero)) + return op->emitOpError( + "expects mem memref to use GM or zero address space"); + if (mr.getRank() == 5) { + auto shape = mr.getShape(); + bool allStatic = true; + for (int64_t d : shape) + if (d == ShapedType::kDynamic) + allStatic = false; + if (allStatic && (shape[0] != 1 || shape[1] != 1 || shape[2] != 1)) + return op->emitOpError( + "expects rank-5 GM memref leading dimensions to be [1,1,1,...] " + "(GlobalTensor table shape)"); + } + return success(); + } + + return op->emitOpError( + "expects mem to be !pto.partition_tensor_view or a GM/ZERO memref"); +} + +static bool hasCompatibleKnownExtent(int64_t lhs, int64_t rhs); +static bool isKnownUnitExtent(int64_t value); + +static LogicalResult verifyMGatherMScatterTileShape(Operation *op, Type dataTy, + Type idxTy, + StringRef dataName) { + auto dataValid = getValidShapeVec(dataTy); + auto idxValid = getValidShapeVec(idxTy); + if (dataValid.size() != 2 || idxValid.size() != 2) + return op->emitOpError() << "expects " << dataName + << " and idx to have rank-2 valid_shape"; + + auto idxTile = dyn_cast(idxTy); + if (!idxTile) + return op->emitOpError("expects idx to be a tile_buf type"); + + const bool idxRowMajor = + idxTile.getBLayoutValueI32() == + static_cast(pto::BLayout::RowMajor); + const bool idxColMajor = + idxTile.getBLayoutValueI32() == + static_cast(pto::BLayout::ColMajor); + + const bool rowCoalesce1xR = + idxRowMajor && isKnownUnitExtent(idxValid[0]) && + hasCompatibleKnownExtent(idxValid[1], dataValid[0]); + const bool rowCoalesceRx1 = + idxColMajor && hasCompatibleKnownExtent(idxValid[0], dataValid[0]) && + isKnownUnitExtent(idxValid[1]); + const bool elemCoalesce = + hasCompatibleKnownExtent(idxValid[0], dataValid[0]) && + hasCompatibleKnownExtent(idxValid[1], dataValid[1]); + + if (!(rowCoalesce1xR || rowCoalesceRx1 || elemCoalesce)) + return 
op->emitOpError() + << "expects idx valid_shape to be [1, " << dataName + << ".valid_row], [" << dataName + << ".valid_row, 1], or match " << dataName << " valid_shape"; + + return success(); +} + +static LogicalResult verifyMGatherMScatterIdxTile(Operation *op, Type ty, + StringRef name) { + if (failed(verifyTileBufCommon(op, ty, name))) + return failure(); + auto as = getPTOMemorySpaceEnum(ty); + if (!as || *as != pto::AddressSpace::VEC) + return op->emitOpError() << "expects " << name + << " to be in the vec address space"; + auto tb = dyn_cast(ty); + if (!tb) + return op->emitOpError() << "expects " << name << " to be a tile_buf type"; + int32_t blayout = tb.getBLayoutValueI32(); + if (blayout != static_cast(pto::BLayout::RowMajor) && + blayout != static_cast(pto::BLayout::ColMajor)) + return op->emitOpError() << "expects " << name + << " to use row_major or col_major blayout"; + if (tb.getSLayoutValueI32() != static_cast(pto::SLayout::NoneBox)) + return op->emitOpError() << "expects " << name + << " to use the none_box slayout"; + return success(); +} + +static bool isA5TLoadStoreTransferElemType(Type ty) { + return ty.isInteger(8) || ty.isInteger(16) || ty.isInteger(32) || + ty.isInteger(64) || ty.isF16() || ty.isBF16() || ty.isF32() || + isPTOLowPrecisionType(ty); +} + +static bool isA5AccStorePreQuantDstType(Type srcElem, Type dstElem) { + if (srcElem.isInteger(32)) + return dstElem.isInteger(8) || dstElem.isF16() || dstElem.isBF16(); + if (!srcElem.isF32()) + return false; + return dstElem.isInteger(8) || dstElem.isF16() || dstElem.isBF16() || + dstElem.isF32() || isPTOHiFloat8Type(dstElem) || + dstElem.isFloat8E4M3() || dstElem.isFloat8E4M3FN() || + dstElem.isFloat8E4M3FNUZ() || dstElem.isFloat8E4M3B11FNUZ(); +} + +static bool isA5LowPrecisionTCvtPair(Type srcElem, Type dstElem) { + if (srcElem.isF32()) + return isPTOFloat8Type(dstElem) || isPTOHiFloat8Type(dstElem); + if (srcElem.isF16()) + return isPTOHiFloat8Type(dstElem); + if (srcElem.isBF16()) + 
return isPTOFloat4PackedType(dstElem); + if (isPTOFloat4PackedType(srcElem)) + return dstElem.isBF16(); + if (isPTOFloat8Type(srcElem) || isPTOHiFloat8Type(srcElem)) + return dstElem.isF32(); + return false; +} + +static bool isA5SupportedTCvtPair(Type srcElem, Type dstElem) { + if (isPTOLowPrecisionType(srcElem) || isPTOLowPrecisionType(dstElem)) + return isA5LowPrecisionTCvtPair(srcElem, dstElem); + return true; +} + +static LogicalResult verifyTileBufCommon(Operation *op, Type ty, StringRef name, + bool allowLowPrecision) { + auto tb = dyn_cast(ty); + if (tb) { + if (tb.getRank() != 2) + return op->emitOpError() << "expects " << name << " to be a rank-2 tile_buf"; + Type elemTy = tb.getElementType(); + if (!allowLowPrecision && isPTOLowPrecisionType(elemTy)) + return op->emitOpError() << name << ": dtype " << elemTy + << " is not supported by this op yet"; + } else if (auto mr = dyn_cast(ty)) { + if (mr.getRank() != 2) + return op->emitOpError() << "expects " << name << " to be a rank-2 memref"; + if (!allowLowPrecision && isPTOLowPrecisionType(mr.getElementType())) + return op->emitOpError() << name << ": dtype " << mr.getElementType() + << " is not supported by this op yet"; + } else { + return op->emitOpError() << "expects " << name << " to be a !pto.tile_buf or rank-2 memref"; + } + + auto validShape = getValidShapeVec(ty); + if (validShape.size() != 2) + return op->emitOpError() << "expects " << name << " to have a rank-2 valid_shape"; + auto shape = getShapeVec(ty); + for (unsigned i = 0; i < 2; ++i) { + if (shape[i] != ShapedType::kDynamic && validShape[i] != ShapedType::kDynamic && + validShape[i] > shape[i]) + return op->emitOpError() << "expects " << name << " to satisfy valid_shape[" << i + << "] <= shape[" << i << "]"; + } + return success(); +} + +static LogicalResult verifyTileBufSameElemType(Operation *op, Type lhs, Type rhs, + StringRef lhsName, + StringRef rhsName) { + if (!isTileLikeType(lhs) || !isTileLikeType(rhs)) + return op->emitOpError() 
<< "expects " << lhsName << " and " << rhsName + << " to be !pto.tile_buf or memref"; + if (getElemTy(lhs) != getElemTy(rhs)) + return op->emitOpError() << "expects " << lhsName << " and " << rhsName + << " to have the same element type"; + return success(); +} + +static LogicalResult verifyTileBufSameValidShape(Operation *op, Type lhs, Type rhs, + StringRef lhsName, StringRef rhsName) { + if (!isTileLikeType(lhs) || !isTileLikeType(rhs)) + return success(); + auto lhsValid = getValidShapeVec(lhs); + auto rhsValid = getValidShapeVec(rhs); + for (size_t i = 0; i < lhsValid.size() && i < rhsValid.size(); ++i) { + if (lhsValid[i] != ShapedType::kDynamic && rhsValid[i] != ShapedType::kDynamic && + lhsValid[i] != rhsValid[i]) + return op->emitOpError() << "expects " << lhsName << " and " << rhsName + << " to have the same valid_shape"; + } + if (lhsValid.size() != rhsValid.size()) + return op->emitOpError() << "expects " << lhsName << " and " << rhsName + << " to have the same valid_shape"; + return success(); +} + +static LogicalResult verifyTileBufSameLogicalExtent(Operation *op, Type lhs, + Type rhs, StringRef lhsName, + StringRef rhsName, + bool compareValidShape) { + if (!isTileLikeType(lhs) || !isTileLikeType(rhs)) + return success(); + + auto lhsExtent = getLogicalTileExtentVec(lhs, compareValidShape); + auto rhsExtent = getLogicalTileExtentVec(rhs, compareValidShape); + auto emitMismatch = [&]() -> LogicalResult { + if (compareValidShape) + return op->emitOpError() << "expects " << lhsName << " and " << rhsName + << " to have the same valid_shape"; + return op->emitOpError() << "expects " << lhsName << " and " << rhsName + << " to have compatible shapes"; + }; + if (lhsExtent.size() != rhsExtent.size()) + return emitMismatch(); + + for (size_t i = 0, e = lhsExtent.size(); i < e; ++i) { + if (lhsExtent[i] != ShapedType::kDynamic && + rhsExtent[i] != ShapedType::kDynamic && lhsExtent[i] != rhsExtent[i]) + return emitMismatch(); + } + return success(); +} + +static 
LogicalResult verifyScaleTileMatchesOperand(Operation *op, Type scaleTy, + Type operandTy, + StringRef scaleName, + StringRef operandName) { + if (failed(verifyTileBufCommon(op, scaleTy, scaleName))) + return failure(); + auto scaleSpace = getPTOMemorySpaceEnum(scaleTy); + if (!scaleSpace || *scaleSpace != pto::AddressSpace::SCALING) + return op->emitOpError() << "expects " << scaleName + << " to be in the scaling address space"; + + auto scaleShape = getShapeVec(scaleTy); + auto operandShape = getShapeVec(operandTy); + if (scaleShape.size() != operandShape.size()) + return op->emitOpError() << "expects " << scaleName << " and " << operandName + << " to have the same rank"; + for (size_t i = 0; i < scaleShape.size(); ++i) { + if (scaleShape[i] != ShapedType::kDynamic && + operandShape[i] != ShapedType::kDynamic && + scaleShape[i] != operandShape[i]) + return op->emitOpError() << "expects " << scaleName << " and " << operandName + << " to have the same shape"; + } + + auto scaleValid = getValidShapeVec(scaleTy); + auto operandValid = getValidShapeVec(operandTy); + if (scaleValid.size() != operandValid.size()) + return op->emitOpError() << "expects " << scaleName << " and " << operandName + << " to have the same valid_shape"; + for (size_t i = 0; i < scaleValid.size(); ++i) { + if (scaleValid[i] != ShapedType::kDynamic && + operandValid[i] != ShapedType::kDynamic && + scaleValid[i] != operandValid[i]) + return op->emitOpError() << "expects " << scaleName << " and " << operandName + << " to have the same valid_shape"; + } + return success(); +} + +static LogicalResult verifyPartialValidPattern(Operation *op, Type src0Ty, + Type src1Ty, Type dstTy) { + auto src0Valid = getValidShapeVec(src0Ty); + auto src1Valid = getValidShapeVec(src1Ty); + auto dstValid = getValidShapeVec(dstTy); + if (src0Valid.size() != 2 || src1Valid.size() != 2 || dstValid.size() != 2) + return op->emitOpError("expects src0, src1, and dst to have rank-2 valid_shape"); + + auto lessEqualKnown = 
[](int64_t lhs, int64_t rhs) { + return lhs == ShapedType::kDynamic || rhs == ShapedType::kDynamic || lhs <= rhs; + }; + auto equalsKnown = [](ArrayRef lhs, ArrayRef rhs) { + for (auto [a, b] : llvm::zip(lhs, rhs)) { + if (a != ShapedType::kDynamic && b != ShapedType::kDynamic && a != b) + return false; + } + return true; + }; + + for (unsigned i = 0; i < 2; ++i) { + if (!lessEqualKnown(src0Valid[i], dstValid[i]) || + !lessEqualKnown(src1Valid[i], dstValid[i])) + return op->emitOpError( + "expects src0/src1 valid_shape to be less than or equal to dst valid_shape"); + } + if (!equalsKnown(src0Valid, dstValid) && !equalsKnown(src1Valid, dstValid)) + return op->emitOpError( + "expects at least one of src0/src1 valid_shape to match dst valid_shape"); + return success(); +} + +[[maybe_unused]] static bool hasKnownZeroValidRegion(Type ty) { + auto valid = getValidShapeVec(ty); + if (valid.size() != 2) + return false; + return valid[0] == 0 || valid[1] == 0; +} + +static LogicalResult verifyScalarTileOp(Operation *op, Type srcTy, Type dstTy, + StringRef srcName, StringRef dstName, + bool requireValidRowsEqual, + bool requireValidColsEqual) { + if (failed(verifyTileBufCommon(op, srcTy, srcName)) || + failed(verifyTileBufCommon(op, dstTy, dstName))) + return failure(); + auto srcSpace = getPTOMemorySpaceEnum(srcTy); + auto dstSpace = getPTOMemorySpaceEnum(dstTy); + if (!srcSpace || *srcSpace != pto::AddressSpace::VEC) + return op->emitOpError() << "expects " << srcName + << " to be in the vec address space"; + if (!dstSpace || *dstSpace != pto::AddressSpace::VEC) + return op->emitOpError() << "expects " << dstName + << " to be in the vec address space"; + if (failed(verifyTileBufSameElemType(op, srcTy, dstTy, srcName, dstName))) + return failure(); + + auto srcValid = getValidShapeVec(srcTy); + auto dstValid = getValidShapeVec(dstTy); + if (srcValid.size() != 2 || dstValid.size() != 2) + return op->emitOpError() + << "expects " << srcName << " and " << dstName + << " to 
have rank-2 valid_shape"; + if (requireValidRowsEqual && + srcValid[0] != ShapedType::kDynamic && dstValid[0] != ShapedType::kDynamic && + srcValid[0] != dstValid[0]) + return op->emitOpError() + << "expects " << srcName << " and " << dstName + << " to have the same valid_shape[0]"; + if (requireValidColsEqual && + srcValid[1] != ShapedType::kDynamic && dstValid[1] != ShapedType::kDynamic && + srcValid[1] != dstValid[1]) + return op->emitOpError() + << "expects " << srcName << " and " << dstName + << " to have the same valid_shape[1]"; + return success(); +} + +static FailureOr +verifyMatchingRowMajorBinaryTileOpCommon(Operation *op, Type src0Ty, Type src1Ty, + Type dstTy) { + if (failed(verifyTileBufCommon(op, src0Ty, "src0")) || + failed(verifyTileBufCommon(op, src1Ty, "src1")) || + failed(verifyTileBufCommon(op, dstTy, "dst"))) + return failure(); + if (failed(verifyTileBufSameElemType(op, src0Ty, src1Ty, "src0", "src1")) || + failed(verifyTileBufSameElemType(op, src0Ty, dstTy, "src0", "dst")) || + failed(verifyTileBufSameValidShape(op, src0Ty, src1Ty, "src0", "src1")) || + failed(verifyTileBufSameValidShape(op, src0Ty, dstTy, "src0", "dst"))) + return failure(); + if (!isRowMajorTileBuf(src0Ty) || !isRowMajorTileBuf(src1Ty) || + !isRowMajorTileBuf(dstTy)) { + op->emitOpError("expects src0, src1, and dst to use row-major layout"); + return failure(); + } + return getElemTy(src0Ty); +} + +static FailureOr +verifyNumericScalarTileOpCommon(Operation *op, Type srcTy, Type dstTy, + Type scalarTy, bool requireValidRowsEqual) { + if (failed(verifyScalarTileOp(op, srcTy, dstTy, "src", "dst", + requireValidRowsEqual, + /*requireValidColsEqual=*/true))) + return failure(); + if (!mlir::isa(scalarTy)) { + op->emitOpError("scalar must be a scalar type (integer/float)"); + return failure(); + } + return getElemTy(srcTy); +} + +static FailureOr +verifyShiftLikeBinaryTileOpCommon(Operation *op, Type src0Ty, Type src1Ty, + Type dstTy) { + if (failed(verifyTileBufCommon(op, 
src0Ty, "src0")) || + failed(verifyTileBufCommon(op, src1Ty, "src1")) || + failed(verifyTileBufCommon(op, dstTy, "dst"))) + return failure(); + Type e0 = getElemTy(src0Ty); + Type e1 = getElemTy(src1Ty); + if (!e0 || !e1) { + op->emitOpError("failed to get element type for operands"); + return failure(); + } + if (e0 != e1) { + op->emitOpError("expects src0 and src1 to have the same element type"); + return failure(); + } + if (!isRowMajorTileBuf(src0Ty) || !isRowMajorTileBuf(src1Ty) || + !isRowMajorTileBuf(dstTy)) { + op->emitOpError("expects src0, src1, and dst to use row-major layout"); + return failure(); + } + if (failed(verifyTileBufSameValidShape(op, src0Ty, dstTy, "src0", "dst")) || + failed(verifyTileBufSameValidShape(op, src1Ty, dstTy, "src1", "dst"))) + return failure(); + return e0; +} + +static FailureOr verifyDistinctRowMajorUnaryTileOpCommon( + Operation *op, Value src, Value dst, StringRef srcName = "src", + StringRef dstName = "dst") { + if (src == dst) { + op->emitOpError("expects src and dst to use different storage"); + return failure(); + } + Type srcTy = src.getType(); + Type dstTy = dst.getType(); + if (failed(verifyTileBufCommon(op, srcTy, srcName)) || + failed(verifyTileBufCommon(op, dstTy, dstName))) + return failure(); + + Type srcElem = getElemTy(srcTy); + Type dstElem = getElemTy(dstTy); + if (!srcElem || !dstElem) { + op->emitOpError("failed to get element type for src/dst"); + return failure(); + } + if (srcElem != dstElem) { + op->emitOpError("expects src and dst to have the same element type"); + return failure(); + } + if (!isRowMajorTileBuf(srcTy) || !isRowMajorTileBuf(dstTy)) { + op->emitOpError("expects src and dst to use row-major layout"); + return failure(); + } + if (failed(verifyTileBufSameValidShape(op, srcTy, dstTy, srcName, dstName))) + return failure(); + return srcElem; +} + +static LogicalResult verifyArithmeticElemTypeForArch( + Operation *op, Type elemTy, PTOArch targetArch, bool allowInt8OnA5, + bool allowBf16OnA5, 
StringRef a2a3Error, StringRef a5Error) { + bool supported = elemTy.isInteger(32) || elemTy.isInteger(16) || + elemTy.isF16() || elemTy.isF32(); + if (targetArch == PTOArch::A5) + supported = supported || (allowInt8OnA5 && elemTy.isInteger(8)) || + (allowBf16OnA5 && elemTy.isBF16()); + if (supported) + return success(); + return op->emitOpError(targetArch == PTOArch::A5 ? a5Error : a2a3Error); +} + +static LogicalResult verifyArithmeticBinaryTileOpWithArchDispatch( + Operation *op, Type src0Ty, Type src1Ty, Type dstTy, bool allowInt8OnA5, + bool allowBf16OnA5, StringRef a2a3Error, StringRef a5Error) { + auto verifyByArch = [&](PTOArch targetArch) -> LogicalResult { + FailureOr elemOr = + verifyMatchingRowMajorBinaryTileOpCommon(op, src0Ty, src1Ty, dstTy); + if (failed(elemOr)) + return failure(); + return verifyArithmeticElemTypeForArch(op, *elemOr, targetArch, + allowInt8OnA5, allowBf16OnA5, + a2a3Error, a5Error); + }; + auto verifyA2A3 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A3); }; + auto verifyA5 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A5); }; + return dispatchVerifierByArch(op, verifyA2A3, verifyA5); +} + +static LogicalResult verifyArithmeticScalarTileOpWithArchDispatch( + Operation *op, Type srcTy, Type dstTy, Type scalarTy, bool allowInt8OnA5, + bool allowBf16OnA5, StringRef a2a3Error, StringRef a5Error, + bool requireValidRowsEqualOnA2A3 = true, + bool requireValidRowsEqualOnA5 = false) { + auto verifyByArch = [&](PTOArch targetArch, + bool requireValidRowsEqual) -> LogicalResult { + FailureOr elemOr = verifyNumericScalarTileOpCommon( + op, srcTy, dstTy, scalarTy, requireValidRowsEqual); + if (failed(elemOr)) + return failure(); + return verifyArithmeticElemTypeForArch(op, *elemOr, targetArch, + allowInt8OnA5, allowBf16OnA5, + a2a3Error, a5Error); + }; + auto verifyA2A3 = [&]() -> LogicalResult { + return verifyByArch(PTOArch::A3, requireValidRowsEqualOnA2A3); + }; + auto verifyA5 = [&]() -> LogicalResult { + return 
verifyByArch(PTOArch::A5, requireValidRowsEqualOnA5); + }; + return dispatchVerifierByArch(op, verifyA2A3, verifyA5); +} + +static LogicalResult verifyTColReductionElemTypeForArch( + Operation *op, Type elemTy, PTOArch targetArch, bool allowInt8OnA5, + bool allowBf16OnA5, StringRef a2a3Error, StringRef a5Error) { + bool ok = elemTy.isF16() || elemTy.isF32() || elemTy.isInteger(16) || + elemTy.isInteger(32); + if (targetArch == PTOArch::A5) + ok = ok || (allowInt8OnA5 && elemTy.isInteger(8)) || + (allowBf16OnA5 && elemTy.isBF16()); + if (ok) + return success(); + return op->emitOpError(targetArch == PTOArch::A5 ? a5Error : a2a3Error); +} + +static LogicalResult verifyTColReductionOpWithArchDispatch( + Operation *op, Type srcTy, Type dstTy, bool requireNonZeroSrcOnA2A3, + bool requireNonZeroSrcOnA5, bool allowInt8OnA5, bool allowBf16OnA5, + StringRef a2a3Error, StringRef a5Error) { + auto verifyByArch = [&](PTOArch targetArch, + bool requireNonZeroSrc) -> LogicalResult { + if (failed(verifyNDStyleVecTile(op, srcTy, "src")) || + failed(verifyNDStyleVecTile(op, dstTy, "dst"))) + return failure(); + if (getElemTy(srcTy) != getElemTy(dstTy)) + return op->emitOpError("expects src and dst to have the same element type"); + if (failed(verifyColReductionValidRegion(op, srcTy, dstTy, requireNonZeroSrc))) + return failure(); + Type elem = getElemTy(srcTy); + return verifyTColReductionElemTypeForArch(op, elem, targetArch, allowInt8OnA5, + allowBf16OnA5, a2a3Error, a5Error); + }; + auto verifyA2A3 = [&]() -> LogicalResult { + return verifyByArch(PTOArch::A3, requireNonZeroSrcOnA2A3); + }; + auto verifyA5 = [&]() -> LogicalResult { + return verifyByArch(PTOArch::A5, requireNonZeroSrcOnA5); + }; + return dispatchVerifierByArch(op, verifyA2A3, verifyA5); +} + +static LogicalResult verifyTColArgReductionOpCommon(Operation *op, Type srcTy, + Type tmpTy, Type dstTy) { + if (failed(verifyNDStyleVecTile(op, srcTy, "src")) || + failed(verifyVecTileCommon(op, tmpTy, "tmp")) || + 
failed(verifyColArgReductionDstLayout(op, dstTy, "dst"))) + return failure(); + if (failed(verifyTileBufSameElemType(op, srcTy, tmpTy, "src", "tmp")) || + failed(verifyTileBufSameValidShape(op, srcTy, tmpTy, "src", "tmp"))) + return failure(); + if (failed(verifyColReductionValidRegion(op, srcTy, dstTy, + /*requireNonZeroSrc=*/true))) + return failure(); + Type srcElemTy = getElemTy(srcTy); + unsigned srcElemBits = srcElemTy ? getPTOStorageElemBitWidth(srcElemTy) : 0; + if (!(mlir::isa(srcElemTy) && + (srcElemBits == 8 || srcElemBits == 16 || srcElemBits == 32))) + return op->emitOpError( + "expects src/tmp element type to be 1, 2, or 4 bytes wide"); + auto dstInt = dyn_cast(getElemTy(dstTy)); + if (!dstInt || dstInt.getWidth() != 32) + return op->emitOpError("expects dst element type to be i32 or ui32"); + return success(); +} + +static bool hasCompatibleKnownExtent(int64_t lhs, int64_t rhs) { + return lhs == ShapedType::kDynamic || rhs == ShapedType::kDynamic || lhs == rhs; +} + +static bool isKnownUnitExtent(int64_t value) { + return value == ShapedType::kDynamic || value == 1; +} + +static LogicalResult verifyVecTileStorage(Operation *op, Type ty, StringRef name) { + if (failed(verifyTileBufCommon(op, ty, name))) + return failure(); + auto as = getPTOMemorySpaceEnum(ty); + if (!as || *as != pto::AddressSpace::VEC) + return op->emitOpError() << "expects " << name << " to be in the vec address space"; + return success(); +} + +static LogicalResult verifyVecTileCommonA2A3(Operation *op, Type ty, + StringRef name) { + if (failed(verifyTileBufCommon(op, ty, name))) + return failure(); + auto tb = dyn_cast(ty); + auto as = getPTOMemorySpaceEnum(ty); + if (as && *as != pto::AddressSpace::VEC) + return op->emitOpError() << "expects " << name << " to be in the vec address space"; + if (tb && tb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor)) + return op->emitOpError() << "expects " << name << " to use the row_major blayout"; + return success(); +} + +static 
LogicalResult verifyVecTileCommonA5(Operation *op, Type ty, + StringRef name) { + return verifyVecTileCommonA2A3(op, ty, name); +} + +static LogicalResult verifyVecTileCommon(Operation *op, Type ty, StringRef name) { + switch (getVerifierTargetArch(op)) { + case VerifierTargetArch::A2A3: + return verifyVecTileCommonA2A3(op, ty, name); + case VerifierTargetArch::A5: + return verifyVecTileCommonA5(op, ty, name); + } + return failure(); +} + +static LogicalResult verifyVecTileUnaryOp(Operation *op, Type srcTy, Type dstTy, + StringRef srcName, + StringRef dstName, + bool allowBf16, + bool allowInt8) { + if (failed(verifyVecTileCommon(op, srcTy, srcName)) || + failed(verifyVecTileCommon(op, dstTy, dstName))) + return failure(); + if (failed(verifyTileBufSameElemType(op, srcTy, dstTy, srcName, dstName))) + return failure(); + if (!isSupportedVecElemType(getElemTy(srcTy), allowBf16, allowInt8)) + return op->emitOpError() << "expects vec tile element types to be supported"; + return success(); +} + +static LogicalResult verifyAccTileCommonA2A3(Operation *op, Type ty, + StringRef name) { + if (failed(verifyTileBufCommon(op, ty, name))) + return failure(); + auto as = getPTOMemorySpaceEnum(ty); + if (!as || *as != pto::AddressSpace::ACC) + return op->emitOpError() << "expects " << name << " to be in the acc address space"; + return success(); +} + +static LogicalResult verifyAccTileCommonA5(Operation *op, Type ty, + StringRef name) { + return verifyAccTileCommonA2A3(op, ty, name); +} + +static LogicalResult verifyAccTileCommon(Operation *op, Type ty, StringRef name) { + switch (getVerifierTargetArch(op)) { + case VerifierTargetArch::A2A3: + return verifyAccTileCommonA2A3(op, ty, name); + case VerifierTargetArch::A5: + return verifyAccTileCommonA5(op, ty, name); + } + return failure(); +} + +static LogicalResult verifyMatTileOperandsA2A3(Operation *op, Type lhsTy, + Type rhsTy, Type dstTy) { + if (failed(verifyTileBufCommon(op, lhsTy, "lhs")) || + 
failed(verifyTileBufCommon(op, rhsTy, "rhs")) || + failed(verifyAccTileCommon(op, dstTy, "dst"))) + return failure(); + auto lhsSpace = getPTOMemorySpaceEnum(lhsTy); + auto rhsSpace = getPTOMemorySpaceEnum(rhsTy); + auto dstSpace = getPTOMemorySpaceEnum(dstTy); + if (!lhsSpace || !rhsSpace || !dstSpace) + return op->emitOpError("expects lhs, rhs, and dst to have explicit address spaces"); + if (*lhsSpace != pto::AddressSpace::LEFT || *rhsSpace != pto::AddressSpace::RIGHT || + *dstSpace != pto::AddressSpace::ACC) + return op->emitOpError( + "expects lhs, rhs, and dst to use the left, right, and acc address spaces"); + auto lhsShape = getMatmulLogicalShapeVec(lhsTy); + auto rhsShape = getMatmulLogicalShapeVec(rhsTy); + auto dstShape = getMatmulLogicalShapeVec(dstTy); + if ((lhsShape[0] != dstShape[0] || rhsShape[1] != dstShape[1] || lhsShape[1] != rhsShape[0])) + return op->emitOpError( + "expects static matmul tile shapes lhs[M,K], rhs[K,N], and dst[M,N]"); + auto lhsValid = getValidShapeVec(lhsTy); + auto rhsValid = getValidShapeVec(rhsTy); + if (lhsValid.size() == 2 && rhsValid.size() == 2) { + int64_t m = lhsValid[0]; + int64_t k = lhsValid[1]; + int64_t n = rhsValid[1]; + if ((m != ShapedType::kDynamic && (m < 1 || m > 4095)) || + (k != ShapedType::kDynamic && (k < 1 || k > 4095)) || + (n != ShapedType::kDynamic && (n < 1 || n > 4095))) + return op->emitOpError("expects m, k, and n valid sizes to be in [1, 4095]"); + } + return success(); +} + +static LogicalResult verifyMatTileOperandsA5(Operation *op, Type lhsTy, + Type rhsTy, Type dstTy) { + if (failed(verifyMatTileOperandsA2A3(op, lhsTy, rhsTy, dstTy))) + return failure(); + + auto lhsTb = mlir::dyn_cast(lhsTy); + auto rhsTb = mlir::dyn_cast(rhsTy); + auto dstTb = mlir::dyn_cast(dstTy); + if (!lhsTb || !rhsTb || !dstTb) + return success(); + + if (lhsTb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor)) + return op->emitOpError("expects lhs to use the col_major blayout on A5"); + if 
(rhsTb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor)) + return op->emitOpError("expects rhs to use the row_major blayout on A5"); + if (dstTb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor)) + return op->emitOpError("expects dst to use the col_major blayout on A5"); + + if (lhsTb.getSLayoutValueI32() != static_cast(pto::SLayout::RowMajor)) + return op->emitOpError("expects lhs to use the row_major slayout on A5"); + if (rhsTb.getSLayoutValueI32() != static_cast(pto::SLayout::ColMajor)) + return op->emitOpError("expects rhs to use the col_major slayout on A5"); + if (dstTb.getSLayoutValueI32() != static_cast(pto::SLayout::RowMajor)) + return op->emitOpError("expects dst to use the row_major slayout on A5"); + return success(); +} + +static LogicalResult verifyMatTileOperands(Operation *op, Type lhsTy, Type rhsTy, + Type dstTy) { + switch (getVerifierTargetArch(op)) { + case VerifierTargetArch::A2A3: + return verifyMatTileOperandsA2A3(op, lhsTy, rhsTy, dstTy); + case VerifierTargetArch::A5: + return verifyMatTileOperandsA5(op, lhsTy, rhsTy, dstTy); + } + return failure(); +} + +static LogicalResult verifyGemvTileOperandsA2A3(Operation *op, Type lhsTy, + Type rhsTy, Type dstTy) { + if (failed(verifyTileBufCommon(op, lhsTy, "lhs")) || + failed(verifyTileBufCommon(op, rhsTy, "rhs")) || + failed(verifyAccTileCommon(op, dstTy, "dst"))) + return failure(); + + auto lhsSpace = getPTOMemorySpaceEnum(lhsTy); + auto rhsSpace = getPTOMemorySpaceEnum(rhsTy); + if (!lhsSpace || !rhsSpace) + return op->emitOpError("expects lhs and rhs to have explicit address spaces"); + if (*lhsSpace != pto::AddressSpace::LEFT || *rhsSpace != pto::AddressSpace::RIGHT) + return op->emitOpError( + "expects lhs and rhs to use the left and right address spaces"); + + auto lhsValid = getValidShapeVec(lhsTy); + auto rhsValid = getValidShapeVec(rhsTy); + auto dstValid = getValidShapeVec(dstTy); + if (lhsValid[0] != ShapedType::kDynamic && lhsValid[0] != 1) + return 
op->emitOpError("expects lhs valid_shape[0] to be 1 for tgemv"); + if (isa(dstTy) && dstValid[0] != ShapedType::kDynamic && + dstValid[0] != 1) + return op->emitOpError("expects dst valid_shape[0] to be 1 for tgemv"); + if (lhsValid[1] != ShapedType::kDynamic && rhsValid[0] != ShapedType::kDynamic && + lhsValid[1] != rhsValid[0]) + return op->emitOpError() + << "expects lhs valid_shape[1] to equal rhs valid_shape[0], but got " + << lhsValid[1] << " vs " << rhsValid[0]; + if (rhsValid[1] != ShapedType::kDynamic && dstValid[1] != ShapedType::kDynamic && + rhsValid[1] != dstValid[1]) + return op->emitOpError() + << "expects rhs valid_shape[1] to equal dst valid_shape[1], but got " + << rhsValid[1] << " vs " << dstValid[1]; + return success(); +} + +static LogicalResult verifyGemvTileOperandsA5(Operation *op, Type lhsTy, + Type rhsTy, Type dstTy) { + if (failed(verifyGemvTileOperandsA2A3(op, lhsTy, rhsTy, dstTy))) + return failure(); + return verifyMatTileOperandsA5(op, lhsTy, rhsTy, dstTy); +} + +static LogicalResult verifyGemvTileOperands(Operation *op, Type lhsTy, Type rhsTy, + Type dstTy) { + switch (getVerifierTargetArch(op)) { + case VerifierTargetArch::A2A3: + return verifyGemvTileOperandsA2A3(op, lhsTy, rhsTy, dstTy); + case VerifierTargetArch::A5: + return verifyGemvTileOperandsA5(op, lhsTy, rhsTy, dstTy); + } + return failure(); +} + +static LogicalResult verifyMatBiasTileA2A3(Operation *op, Type biasTy, Type dstTy, + bool requireFloatBias) { + if (failed(verifyTileBufCommon(op, biasTy, "bias"))) + return failure(); + auto biasSpace = getPTOMemorySpaceEnum(biasTy); + if (!biasSpace || *biasSpace != pto::AddressSpace::BIAS) + return op->emitOpError("expects bias to be in the bias address space"); + auto biasShape = getShapeVec(biasTy); + if (biasShape[0] != ShapedType::kDynamic && biasShape[0] != 1) + return op->emitOpError("expects bias to have 1 row"); + if (requireFloatBias) { + if (!getElemTy(biasTy).isF32()) + return op->emitOpError("expects bias to have 
element type f32"); + } else if (getElemTy(biasTy) != getElemTy(dstTy)) { + return op->emitOpError("expects bias and dst to have the same element type"); + } + return success(); +} + +static LogicalResult verifyMatBiasTileA5(Operation *op, Type biasTy, Type dstTy, + bool requireFloatBias) { + if (failed(verifyMatBiasTileA2A3(op, biasTy, dstTy, requireFloatBias))) + return failure(); + if (auto biasTb = dyn_cast(biasTy)) { + if (biasTb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor)) + return op->emitOpError("expects bias to use the row_major blayout on A5"); + } + return success(); +} + +static LogicalResult verifyMatBiasTile(Operation *op, Type biasTy, Type dstTy, + bool requireFloatBias) { + switch (getVerifierTargetArch(op)) { + case VerifierTargetArch::A2A3: + return verifyMatBiasTileA2A3(op, biasTy, dstTy, requireFloatBias); + case VerifierTargetArch::A5: + return verifyMatBiasTileA5(op, biasTy, dstTy, requireFloatBias); + } + return failure(); +} + +static LogicalResult verifyMatmulTypeTriple(Operation *op, Type lhsElemTy, + Type rhsElemTy, Type dstElemTy) { + bool isA5 = getVerifierTargetArch(op) == VerifierTargetArch::A5; + auto isInt8 = [](Type ty) { + return ty.isInteger(8); + }; + if (dstElemTy.isInteger(32) && isInt8(lhsElemTy) && isInt8(rhsElemTy)) + return success(); + + auto isSupportedFpInput = [](Type ty) { + return ty.isF16() || ty.isBF16() || ty.isF32(); + }; + if (dstElemTy.isF32() && lhsElemTy == rhsElemTy && isSupportedFpInput(lhsElemTy)) + return success(); + + if (isA5 && dstElemTy.isF32() && lhsElemTy == rhsElemTy) { + if (auto ft = mlir::dyn_cast(lhsElemTy)) { + unsigned width = ft.getWidth(); + if (width == 8 || width == 16 || width == 32) + return success(); + } + } + + return op->emitOpError() + << "expects (dst, lhs, rhs) element types to match one of " + "(i32, i8, i8), (f32, f16, f16), (f32, bf16, bf16), (f32, f32, f32)" + << (isA5 ? 
", or an A5-supported fp8 pair" : ""); +} + +LogicalResult pto::TAddOp::verify() { + return verifyArithmeticBinaryTileOpWithArchDispatch( + getOperation(), getSrc0().getType(), getSrc1().getType(), getDst().getType(), + /*allowInt8OnA5=*/true, /*allowBf16OnA5=*/true, + "expects A2/A3 tadd element type to be i32/i16/f16/f32", + "expects A5 tadd element type to be i32/i16/i8/f16/bf16/f32"); +} + +LogicalResult pto::TAddCOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + Type t0 = getSrc0().getType(); + Type t1 = getSrc1().getType(); + Type t2 = getSrc2().getType(); + Type td = getDst().getType(); + + if (!isPTOShapedLike(t0) || !isPTOShapedLike(t1) || + !isPTOShapedLike(t2) || !isPTOShapedLike(td)) + return emitOpError("expects src0/src1/src2/dst to be memref/tile_buf types"); + + auto s0 = getShapeVec(t0); + auto s1 = getShapeVec(t1); + auto s2 = getShapeVec(t2); + auto sd = getShapeVec(td); + if (s0 != s1 || s0 != s2 || s0 != sd) + return emitOpError("expects src0/src1/src2/dst to have the same shape"); + return success(); +} +LogicalResult pto::TAddSOp::verify() { + return verifyArithmeticScalarTileOpWithArchDispatch( + getOperation(), getSrc().getType(), getDst().getType(), getScalar().getType(), + /*allowInt8OnA5=*/true, /*allowBf16OnA5=*/true, + "expects A2/A3 tadds element type to be i32/i16/f16/f32", + "expects A5 tadds element type to be i32/i16/i8/f16/bf16/f32", + /*requireValidRowsEqualOnA2A3=*/true, + /*requireValidRowsEqualOnA5=*/false); +} + +LogicalResult pto::TAxpyOp::verify() { + auto verifyCommon = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyVecTileCommon(*this, srcTy, "src")) || + failed(verifyVecTileCommon(*this, dstTy, "dst"))) + return failure(); + if (failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) + return failure(); + + Type scalarTy = getScalar().getType(); + Type srcElem = getElemTy(srcTy); + if (scalarTy 
!= srcElem) + return emitOpError("expects scalar type to match src element type"); + if (getShapeVec(srcTy) != getShapeVec(dstTy)) + return emitOpError("expects src and dst to have the same shape"); + return success(); + }; + + auto verifyA2A3 = [&]() -> LogicalResult { + if (failed(verifyCommon())) + return failure(); + Type srcElem = getElemTy(getSrc().getType()); + Type dstElem = getElemTy(getDst().getType()); + bool sameType = srcElem == dstElem; + bool widenF16ToF32 = srcElem.isF16() && dstElem.isF32(); + if (!(sameType || widenF16ToF32)) + return emitOpError( + "expects dst/src element types to match, or dst=f32 and src=f16"); + if (!(dstElem.isF16() || dstElem.isF32())) + return emitOpError("expects A2/A3 taxpy dst element type to be f16/f32"); + if (!(srcElem.isF16() || srcElem.isF32())) + return emitOpError("expects A2/A3 taxpy src element type to be f16/f32"); + return success(); + }; + + auto verifyA5 = [&]() -> LogicalResult { + if (failed(verifyCommon())) + return failure(); + Type srcElem = getElemTy(getSrc().getType()); + Type dstElem = getElemTy(getDst().getType()); + bool sameType = srcElem == dstElem; + bool widenF16ToF32 = srcElem.isF16() && dstElem.isF32(); + if (!(sameType || widenF16ToF32)) + return emitOpError( + "expects dst/src element types to match, or dst=f32 and src=f16"); + if (!(dstElem.isF16() || dstElem.isF32() || dstElem.isBF16())) + return emitOpError("expects A5 taxpy dst element type to be f16/bf16/f32"); + if (!(srcElem.isF16() || srcElem.isF32() || srcElem.isBF16())) + return emitOpError("expects A5 taxpy src element type to be f16/bf16/f32"); + return success(); + }; + + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +LogicalResult pto::TAddSCOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + Type ts0 = getSrc0().getType(); + Type ts1 = getSrc1().getType(); + Type td = getDst().getType(); + if (!isPTOShapedLike(ts0) || !isPTOShapedLike(ts1) || 
!isPTOShapedLike(td)) + return emitOpError("expects src0/src1/dst to be PTO shaped-like types"); + + auto s0 = getShapeVec(ts0); + auto s1 = getShapeVec(ts1); + auto sd = getShapeVec(td); + if (s0 != s1 || s0 != sd) + return emitOpError("expects src0/src1/dst to have the same shape"); + return success(); +} + +LogicalResult pto::TAndOp::verify() { + auto verifyCommon = [&]() -> FailureOr { + return verifyMatchingRowMajorBinaryTileOpCommon( + getOperation(), getSrc0().getType(), getSrc1().getType(), + getDst().getType()); + }; + + auto verifyA2A3 = [&]() -> LogicalResult { + FailureOr elemOr = verifyCommon(); + if (failed(elemOr)) + return failure(); + auto it = mlir::dyn_cast(*elemOr); + if (!it || (it.getWidth() != 8 && it.getWidth() != 16)) + return emitOpError( + "expects A2/A3 tand src0, src1, and dst element type to be i8/i16"); + return success(); + }; + + auto verifyA5 = [&]() -> LogicalResult { + FailureOr elemOr = verifyCommon(); + if (failed(elemOr)) + return failure(); + auto it = mlir::dyn_cast(*elemOr); + if (!it || (it.getWidth() != 8 && it.getWidth() != 16 && + it.getWidth() != 32)) + return emitOpError( + "expects A5 tand src0, src1, and dst element type to be i8/i16/i32"); + return success(); + }; + + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TConcatOp::verify() { + auto verifyCommon = [&]() -> FailureOr { + Type t0 = getSrc0().getType(); + Type t1 = getSrc1().getType(); + Type td = getDst().getType(); + if (failed(verifyTileBufCommon(*this, t0, "src0")) || + failed(verifyTileBufCommon(*this, t1, "src1")) || + failed(verifyTileBufCommon(*this, td, "dst"))) + return failure(); + + Type e0 = getElemTy(t0); + Type e1 = getElemTy(t1); + Type ed = getElemTy(td); + if (!e0 || !e1 || !ed) { + emitOpError("failed to get element type for operands"); + return failure(); + } + if (e0 != e1 || e0 != ed) { + emitOpError("expects src0, src1, and dst to have the same element type"); + return 
failure(); + } + + auto v0 = getValidShapeVec(getSrc0()); + auto v1 = getValidShapeVec(getSrc1()); + auto vd = getValidShapeVec(getDst()); + if (v0.size() != 2 || v1.size() != 2 || vd.size() != 2) + return emitOpError("expects src0, src1, and dst to have rank-2 valid_shape"); + + // validRow must match dst (when known). + if (v0[0] != ShapedType::kDynamic && vd[0] != ShapedType::kDynamic && v0[0] != vd[0]) + return emitOpError("expects src0 valid row to match dst valid row"); + if (v1[0] != ShapedType::kDynamic && vd[0] != ShapedType::kDynamic && v1[0] != vd[0]) + return emitOpError("expects src1 valid row to match dst valid row"); + + // Total valid columns must fit within dst static cols (when known). + auto sd = getShapeVec(td); + if (sd.size() == 2 && sd[1] != ShapedType::kDynamic && + v0[1] != ShapedType::kDynamic && v1[1] != ShapedType::kDynamic) { + if (v0[1] + v1[1] > sd[1]) + return emitOpError("expects src0.valid_col + src1.valid_col <= dst.cols"); + } + + return e0; + }; + + auto verifyElemType = [&](Type elem) -> LogicalResult { + if (elem.isF16() || elem.isF32() || elem.isBF16()) + return success(); + auto it = mlir::dyn_cast(elem); + if (!it || + (it.getWidth() != 8 && it.getWidth() != 16 && it.getWidth() != 32)) + return emitOpError("expects element type to be i8, i16, i32, f16, f32, or bf16"); + return success(); + }; + + auto verifyLocVec = [&](Type ty, StringRef name) -> LogicalResult { + auto as = getPTOMemorySpaceEnum(ty); + if (!as || *as != pto::AddressSpace::VEC) + return emitOpError() << "expects " << name << " to use loc=vec"; + return success(); + }; + + auto verifyA2A3 = [&]() -> LogicalResult { + FailureOr elemOr = verifyCommon(); + if (failed(elemOr)) + return failure(); + if (failed(verifyLocVec(getSrc0().getType(), "src0")) || + failed(verifyLocVec(getSrc1().getType(), "src1")) || + failed(verifyLocVec(getDst().getType(), "dst"))) + return failure(); + return verifyElemType(*elemOr); + }; + + auto verifyA5 = [&]() -> LogicalResult { + 
FailureOr elemOr = verifyCommon(); + if (failed(elemOr)) + return failure(); + if (failed(verifyLocVec(getSrc0().getType(), "src0")) || + failed(verifyLocVec(getSrc1().getType(), "src1")) || + failed(verifyLocVec(getDst().getType(), "dst"))) + return failure(); + if (!isRowMajorTileBuf(getSrc0().getType()) || !isRowMajorTileBuf(getSrc1().getType()) || + !isRowMajorTileBuf(getDst().getType())) + return emitOpError("expects src0, src1, and dst to use row-major layout"); + return verifyElemType(*elemOr); + }; + + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +LogicalResult pto::TConcatidxOp::verify() { + auto verifyCommon = [&]() -> FailureOr> { + Type t0 = getSrc0().getType(); + Type t1 = getSrc1().getType(); + Type ti0 = getSrc0Idx().getType(); + Type ti1 = getSrc1Idx().getType(); + Type td = getDst().getType(); + if (failed(verifyTileBufCommon(*this, t0, "src0")) || + failed(verifyTileBufCommon(*this, t1, "src1")) || + failed(verifyTileBufCommon(*this, ti0, "src0Idx")) || + failed(verifyTileBufCommon(*this, ti1, "src1Idx")) || + failed(verifyTileBufCommon(*this, td, "dst"))) + return failure(); + + // Check data element type consistency. + Type e0 = getElemTy(t0); + Type e1 = getElemTy(t1); + Type ed = getElemTy(td); + if (!e0 || !e1 || !ed) { + emitOpError("failed to get element type for data operands"); + return failure(); + } + if (e0 != e1 || e0 != ed) { + emitOpError("expects src0, src1, and dst to have the same element type"); + return failure(); + } + + // Check index element type consistency. + Type ei0 = getElemTy(ti0); + Type ei1 = getElemTy(ti1); + if (!ei0 || !ei1) { + emitOpError("failed to get element type for index operands"); + return failure(); + } + if (ei0 != ei1) { + emitOpError("expects src0Idx and src1Idx to have the same element type"); + return failure(); + } + + // All five tiles must be rank-2. 
+ auto v0 = getValidShapeVec(getSrc0()); + auto v1 = getValidShapeVec(getSrc1()); + auto vi0 = getValidShapeVec(getSrc0Idx()); + auto vi1 = getValidShapeVec(getSrc1Idx()); + auto vd = getValidShapeVec(getDst()); + if (v0.size() != 2 || v1.size() != 2 || vi0.size() != 2 || + vi1.size() != 2 || vd.size() != 2) + return emitOpError("expects all operands to have rank-2 valid_shape"); + + // validRow must match dst (when known). + auto checkValidRow = [&](const auto &v, StringRef name) -> LogicalResult { + if (v[0] != ShapedType::kDynamic && vd[0] != ShapedType::kDynamic && + v[0] != vd[0]) + return emitOpError("expects ") << name << " valid row to match dst valid row"; + return success(); + }; + if (failed(checkValidRow(v0, "src0")) || + failed(checkValidRow(v1, "src1")) || + failed(checkValidRow(vi0, "src0Idx")) || + failed(checkValidRow(vi1, "src1Idx"))) + return failure(); + + // Index tile must have cols >= 1 (when known). + if (vi0[1] != ShapedType::kDynamic && vi0[1] < 1) + return emitOpError("expects src0Idx valid_col >= 1"); + if (vi1[1] != ShapedType::kDynamic && vi1[1] < 1) + return emitOpError("expects src1Idx valid_col >= 1"); + + return std::make_pair(e0, ei0); + }; + + auto verifyElementTypes = [&](Type dataElem, Type idxElem) -> LogicalResult { + // Data element type: f16, f32, bf16, i8, i16, i32 (signless). + if (!dataElem.isF16() && !dataElem.isF32() && !dataElem.isBF16()) { + auto it = mlir::dyn_cast(dataElem); + if (!it || !it.isSignless() || + (it.getWidth() != 8 && it.getWidth() != 16 && it.getWidth() != 32)) + return emitOpError() + << "expects data element type to be i8, i16, i32, f16, f32, or bf16"; + } + + // Index element type: i8, i16, i32 (signless). 
+ auto it = mlir::dyn_cast(idxElem); + if (!it || !it.isSignless() || + (it.getWidth() != 8 && it.getWidth() != 16 && it.getWidth() != 32)) + return emitOpError() + << "expects index element type to be i8, i16, or i32"; + return success(); + }; + + auto verifyLocVec = [&](Type ty, StringRef name) -> LogicalResult { + auto as = getPTOMemorySpaceEnum(ty); + if (!as || *as != pto::AddressSpace::VEC) + return emitOpError() << "expects " << name << " to use loc=vec"; + return success(); + }; + + auto verifyA2A3 = [&]() -> LogicalResult { + auto elemOr = verifyCommon(); + if (failed(elemOr)) + return failure(); + if (failed(verifyLocVec(getSrc0().getType(), "src0")) || + failed(verifyLocVec(getSrc1().getType(), "src1")) || + failed(verifyLocVec(getSrc0Idx().getType(), "src0Idx")) || + failed(verifyLocVec(getSrc1Idx().getType(), "src1Idx")) || + failed(verifyLocVec(getDst().getType(), "dst"))) + return failure(); + return verifyElementTypes(elemOr->first, elemOr->second); + }; + + auto verifyA5 = [&]() -> LogicalResult { + auto elemOr = verifyCommon(); + if (failed(elemOr)) + return failure(); + if (failed(verifyLocVec(getSrc0().getType(), "src0")) || + failed(verifyLocVec(getSrc1().getType(), "src1")) || + failed(verifyLocVec(getSrc0Idx().getType(), "src0Idx")) || + failed(verifyLocVec(getSrc1Idx().getType(), "src1Idx")) || + failed(verifyLocVec(getDst().getType(), "dst"))) + return failure(); + if (!isRowMajorTileBuf(getSrc0().getType()) || + !isRowMajorTileBuf(getSrc1().getType()) || + !isRowMajorTileBuf(getSrc0Idx().getType()) || + !isRowMajorTileBuf(getSrc1Idx().getType()) || + !isRowMajorTileBuf(getDst().getType())) + return emitOpError( + "expects all operands to use row-major layout"); + return verifyElementTypes(elemOr->first, elemOr->second); + }; + + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +LogicalResult pto::TAndSOp::verify() { + auto verifyCommon = [&]() -> FailureOr { + return 
verifyDistinctRowMajorUnaryTileOpCommon(getOperation(), getSrc(), + getDst(), "src", "dst"); + }; + + auto verifyA2A3 = [&]() -> LogicalResult { + FailureOr elemOr = verifyCommon(); + if (failed(elemOr)) + return failure(); + auto it = mlir::dyn_cast(*elemOr); + if (!it || (it.getWidth() != 8 && it.getWidth() != 16)) + return emitOpError( + "expects A2/A3 tands src, scalar, and dst element type to be i8/i16"); + return success(); + }; + + auto verifyA5 = [&]() -> LogicalResult { + FailureOr elemOr = verifyCommon(); + if (failed(elemOr)) + return failure(); + auto it = mlir::dyn_cast(*elemOr); + if (!it || (it.getWidth() != 8 && it.getWidth() != 16 && + it.getWidth() != 32)) + return emitOpError( + "expects A5 tands src, scalar, and dst element type to be i8/i16/i32"); + return success(); + }; + + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +LogicalResult pto::TCIOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + Type dstTy = getDst().getType(); + if (failed(verifyTileBufCommon(*this, dstTy, "dst"))) + return failure(); + + auto elemTy = mlir::dyn_cast(getElemTy(dstTy)); + if (!elemTy) + return emitOpError("expects dst element type to be integer"); + + unsigned bw = elemTy.getWidth(); + if (bw != 16 && bw != 32) + return emitOpError("expects dst element type to be i16/i32"); + + auto sTy = mlir::dyn_cast(getOperand(0).getType()); + if (!sTy) + return emitOpError("expects S to be integer"); + + if (sTy != elemTy) + return emitOpError("expects S and dst element type to be exactly the same type"); + auto shape = getShapeVec(dstTy); + if (shape.size() != 2) + return emitOpError("expects dst to be rank-2"); + if (shape[1] != ShapedType::kDynamic && shape[1] == 1) + return emitOpError("expects dst cols to be different from 1"); + + return success(); +} + +LogicalResult pto::TTriOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + + Type dstTy = 
getDst().getType(); + if (failed(verifyVecTileCommon(*this, dstTy, "dst"))) + return failure(); + + auto diagonalTy = mlir::dyn_cast(getDiagonal().getType()); + if (!diagonalTy) + return emitOpError("expects diagonal to be an integer operand"); + + int32_t upperOrLower = getUpperOrLower(); + if (upperOrLower != 0 && upperOrLower != 1) + return emitOpError("expects upperOrLower to be 0 (lower) or 1 (upper)"); + + Type elemTy = getElemTy(dstTy); + return dispatchVerifierByArch( + getOperation(), + [&]() -> LogicalResult { + if (!isSupportedVecElemType(elemTy, /*allowBf16=*/false, + /*allowInt8=*/false)) + return emitOpError() + << "expects A2/A3 dst element type to be f16/f32/i16/i32/u16/u32"; + return success(); + }, + [&]() -> LogicalResult { + if (!isSupportedVecElemType(elemTy, /*allowBf16=*/true, + /*allowInt8=*/true)) + return emitOpError() + << "expects A5 dst element type to be f16/f32/bf16/i8/i16/i32/u8/u16/u32"; + return success(); + }); +} + +LogicalResult pto::TCmpOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + Type t0 = getSrc0().getType(); + Type t1 = getSrc1().getType(); + Type td = getDst().getType(); + if (failed(verifyVecTileStorage(*this, t0, "src0")) || + failed(verifyVecTileStorage(*this, t1, "src1")) || + failed(verifyVecTileStorage(*this, td, "dst"))) + return failure(); + + Type e0 = getElemTy(t0); + Type e1 = getElemTy(t1); + Type ed = getElemTy(td); + if (!e0 || !e1 || !ed) + return emitOpError("failed to get element type for src0/src1/dst"); + if (e0 != e1) + return emitOpError("expects src0 and src1 to have the same element type"); + if (!(e0.isInteger(32) || e0.isF16() || e0.isF32())) + return emitOpError("expects A2/A3 tcmp input element type to be i32/f16/f32"); + if (!ed.isInteger(8)) + return emitOpError("expects dst element type to be i8"); + + auto valid0 = getValidShapeVec(t0); + auto valid1 = getValidShapeVec(t1); + auto validd = getValidShapeVec(td); + if (valid0.size() != 2 || valid1.size() != 2 || validd.size() != 
2) + return emitOpError("expects src0, src1, and dst to have rank-2 valid_shape"); + if (!hasCompatibleKnownExtent(valid0[0], valid1[0])) + return emitOpError("expects src0 and src1 to have the same valid row"); + if (!hasCompatibleKnownExtent(valid0[1], valid1[1])) + return emitOpError("expects src0 and src1 to have the same valid column"); + if (!hasCompatibleKnownExtent(valid0[0], validd[0])) + return emitOpError("expects src0 valid row to equal dst valid row"); + return success(); + }; + + auto verifyA5 = [&]() -> LogicalResult { + Type t0 = getSrc0().getType(); + Type t1 = getSrc1().getType(); + Type td = getDst().getType(); + if (failed(verifyTileBufCommon(*this, t0, "src0")) || + failed(verifyTileBufCommon(*this, t1, "src1")) || + failed(verifyTileBufCommon(*this, td, "dst"))) + return failure(); + + Type e0 = getElemTy(t0); + Type e1 = getElemTy(t1); + Type ed = getElemTy(td); + if (!e0 || !e1 || !ed) + return emitOpError("failed to get element type for src0/src1/dst"); + if (e0 != e1) + return emitOpError("expects src0 and src1 to have the same element type"); + bool inputOk = e0.isF16() || e0.isF32() || e0.isBF16() || + e0.isInteger(8) || e0.isInteger(16) || e0.isInteger(32); + if (!inputOk) + return emitOpError("expects A5 tcmp input element type to be i8/i16/i32/f16/bf16/f32"); + if (auto it = dyn_cast(ed)) { + if (it.getWidth() != 8) + return emitOpError("expects dst element type to be i8"); + } else { + return emitOpError("expects dst element type to be i8"); + } + + if (getShapeVec(t0) != getShapeVec(t1) || getShapeVec(t0) != getShapeVec(td)) + return emitOpError("expects src0, src1, and dst to have the same shape"); + return success(); + }; + + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +// ---- TCMPS verify ---- +LogicalResult pto::TCmpSOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyVecTileStorage(*this, srcTy, 
"src")) || + failed(verifyVecTileStorage(*this, dstTy, "dst"))) + return failure(); + + Type elemTy = getElemTy(srcTy); + if (!(elemTy.isInteger(16) || elemTy.isInteger(32) || + elemTy.isF16() || elemTy.isF32())) + return emitOpError("expects A2/A3 tcmps input element type to be i16/i32/f16/f32"); + + auto scalarTy = getScalar().getType(); + if (!(scalarTy.isIntOrIndexOrFloat())) + return emitOpError("expects scalar to be integer, index, or float"); + + auto srcValid = getValidShapeVec(srcTy); + auto dstValid = getValidShapeVec(dstTy); + if (srcValid.size() != 2 || dstValid.size() != 2) + return emitOpError("expects src and dst to have rank-2 valid_shape"); + if (srcValid[0] != ShapedType::kDynamic && dstValid[0] != ShapedType::kDynamic && + srcValid[0] != dstValid[0]) + return emitOpError("expects src and dst to have the same valid_shape[0]"); + return success(); + }; + + auto verifyA5 = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyVecTileStorage(*this, srcTy, "src")) || + failed(verifyVecTileStorage(*this, dstTy, "dst"))) + return failure(); + + Type elemTy = getElemTy(srcTy); + if (!(elemTy.isInteger(8) || elemTy.isInteger(16) || elemTy.isInteger(32) || + elemTy.isF16() || elemTy.isF32())) + return emitOpError("expects A5 tcmps input element type to be i8/i16/i32/f16/f32"); + + auto scalarTy = getScalar().getType(); + if (!(scalarTy.isIntOrIndexOrFloat())) + return emitOpError("expects scalar to be integer, index, or float"); + + auto srcValid = getValidShapeVec(srcTy); + auto dstValid = getValidShapeVec(dstTy); + if (srcValid.size() != 2 || dstValid.size() != 2) + return emitOpError("expects src and dst to have rank-2 valid_shape"); + if (srcValid[0] != ShapedType::kDynamic && dstValid[0] != ShapedType::kDynamic && + srcValid[0] != dstValid[0]) + return emitOpError("expects src and dst to have the same valid_shape[0]"); + return success(); + }; + + return dispatchVerifierByArch(getOperation(), 
verifyA2A3, verifyA5); +} +LogicalResult pto::TColExpandOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyNDStyleVecTile(*this, srcTy, "src")) || + failed(verifyNDStyleVecTile(*this, dstTy, "dst"))) + return failure(); + if (getElemTy(srcTy) != getElemTy(dstTy)) + return emitOpError("expects src and dst to have the same element type"); + if (!isSupportedVecElemType(getElemTy(srcTy), /*allowBf16=*/true, + /*allowInt8=*/true)) + return emitOpError("expects tcolexpand element type to be supported"); + auto srcValid = getValidShapeVec(getSrc()); + auto dstValid = getValidShapeVec(getDst()); + if (srcValid.size() != 2 || dstValid.size() != 2) + return emitOpError("expects src and dst to have rank-2 valid_shape"); + if (srcValid[1] != ShapedType::kDynamic && dstValid[1] != ShapedType::kDynamic && + srcValid[1] != dstValid[1]) + return emitOpError("expects src and dst to have the same valid_shape[1]"); + return success(); +} +static LogicalResult verifyTColExpandBinaryLikeOp(Operation *op, Type t0, Type t1, + Type td, PTOArch targetArch, + StringRef opName, + bool allowIntegerTypes) { + if (!isPTOShapedLike(t0) || !isPTOShapedLike(t1) || !isPTOShapedLike(td)) + return op->emitOpError("expects src0/src1/dst to be PTO shaped-like types"); + + Type e0 = getElemTy(t0); + Type e1 = getElemTy(t1); + Type ed = getElemTy(td); + if (!e0 || !e1 || !ed) + return op->emitOpError("failed to get element type for src0/src1/dst"); + + auto isSupportedElem = [&](Type elemTy) { + if (elemTy.isF16() || elemTy.isF32()) + return true; + if (!allowIntegerTypes) + return false; + if (elemTy.isInteger(16) || elemTy.isInteger(32)) + return true; + return targetArch == PTOArch::A5 && elemTy.isInteger(8); + }; + if (!isSupportedElem(e0) || !isSupportedElem(e1) || !isSupportedElem(ed)) { + if (!allowIntegerTypes) + return op->emitOpError() << "expects " << opName + << " 
element type to be f16 or f32"; + if (targetArch == PTOArch::A5) + return op->emitOpError() << "expects A5 " << opName + << " element type to be i8/i16/i32/f16/f32"; + return op->emitOpError() << "expects A2/A3 " << opName + << " element type to be i16/i32/f16/f32"; + } + + if (getShapeVec(t0) != getShapeVec(td)) + return op->emitOpError("expects src0/dst to have same shape"); + if (failed(verifyTileBufSameValidShape(op, t0, td, "src0", "dst"))) + return failure(); + + if (auto src0TileTy = dyn_cast(t0)) { + if (src0TileTy.getBLayoutValueI32() != 0) + return op->emitOpError("expects src0 to use row-major layout"); + } + + if (auto src1TileTy = dyn_cast(t1)) { + if (src1TileTy.getBLayoutValueI32() != 0) + return op->emitOpError("expects src1 to use row-major layout"); + } + if (auto dstTileTy = dyn_cast(td)) { + if (dstTileTy.getBLayoutValueI32() != 0) + return op->emitOpError("expects dst to use row-major layout"); + } + + auto src1Valid = getValidShapeVec(t1); + auto dstValid = getValidShapeVec(td); + if (src1Valid.size() == 2 && dstValid.size() == 2 && + src1Valid[1] != ShapedType::kDynamic && dstValid[1] != ShapedType::kDynamic && + src1Valid[1] != dstValid[1]) + return op->emitOpError("expects src1 valid_shape[1] to equal dst valid_shape[1]"); + + return success(); +} +LogicalResult pto::TColExpandMulOp::verify() { + PTOArch arch = getTargetArch(getOperation()); + return verifyTColExpandBinaryLikeOp(getOperation(), getSrc0().getType(), + getSrc1().getType(), getDst().getType(), + arch, "tcolexpandmul", + /*allowIntegerTypes=*/true); +} +LogicalResult pto::TColExpandAddOp::verify() { + PTOArch arch = getTargetArch(getOperation()); + return verifyTColExpandBinaryLikeOp(getOperation(), getSrc0().getType(), + getSrc1().getType(), getDst().getType(), + arch, "tcolexpandadd", + /*allowIntegerTypes=*/true); +} +LogicalResult pto::TColExpandDivOp::verify() { + auto verifyByArch = [&](PTOArch targetArch) -> LogicalResult { + bool allowIntegerTypes = (targetArch == 
PTOArch::A5); + return verifyTColExpandBinaryLikeOp(getOperation(), getSrc0().getType(), + getSrc1().getType(), getDst().getType(), + targetArch, "tcolexpanddiv", + /*allowIntegerTypes=*/allowIntegerTypes); + }; + auto verifyA2A3 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A3); }; + auto verifyA5 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A5); }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} +LogicalResult pto::TColExpandSubOp::verify() { + PTOArch arch = getTargetArch(getOperation()); + return verifyTColExpandBinaryLikeOp(getOperation(), getSrc0().getType(), + getSrc1().getType(), getDst().getType(), + arch, "tcolexpandsub", + /*allowIntegerTypes=*/true); +} +LogicalResult pto::TColExpandExpdifOp::verify() { + PTOArch arch = getTargetArch(getOperation()); + return verifyTColExpandBinaryLikeOp(getOperation(), getSrc0().getType(), + getSrc1().getType(), getDst().getType(), + arch, "tcolexpandexpdif", + /*allowIntegerTypes=*/false); +} +LogicalResult pto::TColExpandMaxOp::verify() { + PTOArch arch = getTargetArch(getOperation()); + return verifyTColExpandBinaryLikeOp(getOperation(), getSrc0().getType(), + getSrc1().getType(), getDst().getType(), + arch, "tcolexpandmax", + /*allowIntegerTypes=*/true); +} +LogicalResult pto::TColExpandMinOp::verify() { + PTOArch arch = getTargetArch(getOperation()); + return verifyTColExpandBinaryLikeOp(getOperation(), getSrc0().getType(), + getSrc1().getType(), getDst().getType(), + arch, "tcolexpandmin", + /*allowIntegerTypes=*/true); +} +LogicalResult pto::TColMaxOp::verify() { + return verifyTColReductionOpWithArchDispatch( + getOperation(), getSrc().getType(), getDst().getType(), + /*requireNonZeroSrcOnA2A3=*/false, /*requireNonZeroSrcOnA5=*/true, + /*allowInt8OnA5=*/true, /*allowBf16OnA5=*/true, + "expects A2/A3 tcolmax element type to be f16/f32/i16/i32", + "expects A5 tcolmax element type to be i8/i16/i32/f16/bf16/f32"); +} + +LogicalResult pto::TColArgMaxOp::verify() { + 
auto verifyByArch = [&]() -> LogicalResult { + return verifyTColArgReductionOpCommon(*this, getSrc().getType(), + getTmp().getType(), getDst().getType()); + }; + + return dispatchVerifierByArch(getOperation(), verifyByArch, verifyByArch); +} + +LogicalResult pto::TColMinOp::verify() { + return verifyTColReductionOpWithArchDispatch( + getOperation(), getSrc().getType(), getDst().getType(), + /*requireNonZeroSrcOnA2A3=*/false, /*requireNonZeroSrcOnA5=*/true, + /*allowInt8OnA5=*/true, /*allowBf16OnA5=*/true, + "expects A2/A3 tcolmin element type to be f16/f32/i16/i32", + "expects A5 tcolmin element type to be i8/i16/i32/f16/bf16/f32"); +} + +LogicalResult pto::TColArgMinOp::verify() { + auto verifyByArch = [&]() -> LogicalResult { + return verifyTColArgReductionOpCommon(*this, getSrc().getType(), + getTmp().getType(), getDst().getType()); + }; + + return dispatchVerifierByArch(getOperation(), verifyByArch, verifyByArch); +} + + + +ParseResult mlir::pto::TColSumOp::parse(OpAsmParser &parser, OperationState &result) { + OpAsmParser::UnresolvedOperand src; + OpAsmParser::UnresolvedOperand tmp; + OpAsmParser::UnresolvedOperand dst; + Type srcTy, tmpTy, dstTy; + bool hasTmp = false; + + // Parse: ins(%src : type) or ins(%src, %tmp {isBinary = ...}: type, type) + if (parser.parseKeyword("ins") || parser.parseLParen() || parser.parseOperand(src)) + return failure(); + + // Check for optional tmp operand (format 2) + if (succeeded(parser.parseOptionalComma())) { + // Format 2: ins(%src, %tmp {isBinary = ...}: type, type) + if (parser.parseOperand(tmp)) + return failure(); + hasTmp = true; + + // Parse attributes (isBinary) + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + // Parse types: : type, type + if (parser.parseColonType(srcTy) || parser.parseComma() || parser.parseType(tmpTy)) + return failure(); + } else { + // Format 1: ins(%src : type) + if (parser.parseColonType(srcTy)) + return failure(); + } + + if (parser.parseRParen()) + return 
failure(); + + // Parse: outs(%dst : type) + if (parser.parseKeyword("outs") || parser.parseLParen() || + parser.parseOperand(dst) || parser.parseColonType(dstTy) || + parser.parseRParen()) + return failure(); + + // Parse any remaining attributes (for format 1) + if (!hasTmp) { + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + } + + // Resolve operands + if (parser.resolveOperand(src, srcTy, result.operands)) + return failure(); + + if (hasTmp) { + if (parser.resolveOperand(tmp, tmpTy, result.operands)) + return failure(); + } + + if (parser.resolveOperand(dst, dstTy, result.operands)) + return failure(); + + return success(); +} + +void mlir::pto::TColSumOp::print(OpAsmPrinter &p) { + if (getTmp()) { + // Format 2: ins(%src, %tmp {isBinary = ...}: type, type) outs(%dst : type) + p << " ins(" << getSrc() << ", " << getTmp(); + // Print isBinary attribute if present + SmallVector elidedAttrs; + if (!getIsBinaryAttr() || getIsBinaryAttr().getValue() == false) { + elidedAttrs.push_back("isBinary"); + } + p.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + p << " : " << getSrc().getType() << ", " << getTmp().getType() << ")"; + } else { + // Format 1: ins(%src : type) outs(%dst : type) + p << " ins(" << getSrc() << " : " << getSrc().getType() << ")"; + } + + p << " outs(" << getDst() << " : " << getDst().getType() << ")"; + + // Print remaining attributes for format 1 (excluding isBinary) + if (!getTmp()) { + SmallVector elidedAttrs = {"isBinary"}; + p.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + } +} + +LogicalResult pto::TColSumOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyNDStyleVecTile(*this, srcTy, "src")) || + failed(verifyNDStyleVecTile(*this, dstTy, "dst"))) + return failure(); + bool hasTmp = (bool)getTmp(); + bool hasIsBinary = (bool)getIsBinaryAttr(); + if (hasTmp != hasIsBinary) { + if (hasTmp) + return 
emitOpError("tmp operand requires isBinary attribute"); + return emitOpError("isBinary attribute requires tmp operand"); + } + if (getTmp()) { + Type tmpTy = getTmp().getType(); + if (failed(verifyNDStyleVecTile(*this, tmpTy, "tmp"))) + return failure(); + if (getElemTy(srcTy) != getElemTy(dstTy) || getElemTy(srcTy) != getElemTy(tmpTy)) + return emitOpError("expects src/tmp/dst element types to match"); + } + if (getElemTy(srcTy) != getElemTy(dstTy)) + return emitOpError("expects src/dst element types to match"); + if (failed(verifyColReductionValidRegion(*this, srcTy, dstTy, + /*requireNonZeroSrc=*/false))) + return failure(); + Type elem = getElemTy(srcTy); + if (!(elem.isF16() || elem.isF32() || elem.isInteger(16) || elem.isInteger(32))) + return emitOpError("expects A2/A3 tcolsum element type to be f16/f32/i16/i32"); + return success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyNDStyleVecTile(*this, srcTy, "src")) || + failed(verifyNDStyleVecTile(*this, dstTy, "dst"))) + return failure(); + bool hasTmp = (bool)getTmp(); + bool hasIsBinary = (bool)getIsBinaryAttr(); + if (hasTmp != hasIsBinary) { + if (hasTmp) + return emitOpError("tmp operand requires isBinary attribute"); + return emitOpError("isBinary attribute requires tmp operand"); + } + if (getTmp()) { + Type tmpTy = getTmp().getType(); + if (failed(verifyNDStyleVecTile(*this, tmpTy, "tmp"))) + return failure(); + if (getElemTy(srcTy) != getElemTy(dstTy) || getElemTy(srcTy) != getElemTy(tmpTy)) + return emitOpError("expects src/tmp/dst element types to match"); + } + if (getElemTy(srcTy) != getElemTy(dstTy)) + return emitOpError("expects src/dst element types to match"); + if (failed(verifyColReductionValidRegion(*this, srcTy, dstTy, + /*requireNonZeroSrc=*/true))) + return failure(); + Type elem = getElemTy(srcTy); + if (!(elem.isF16() || elem.isF32() || elem.isBF16() || elem.isInteger(8) || + 
elem.isInteger(16) || elem.isInteger(32))) + return emitOpError("expects A5 tcolsum element type to be i8/i16/i32/f16/bf16/f32"); + return success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +LogicalResult pto::TColProdOp::verify() { + return verifyTColReductionOpWithArchDispatch( + getOperation(), getSrc().getType(), getDst().getType(), + /*requireNonZeroSrcOnA2A3=*/false, /*requireNonZeroSrcOnA5=*/false, + /*allowInt8OnA5=*/false, /*allowBf16OnA5=*/true, + "expects A2/A3 tcolprod element type to be f16/f32/i16/i32", + "expects A5 tcolprod element type to be i16/ui16/i32/ui32/f16/bf16/f32"); +} + +llvm::LogicalResult mlir::pto::TCvtOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyTileBufCommon(*this, srcTy, "src", /*allowLowPrecision=*/true)) || + failed(verifyTileBufCommon(*this, dstTy, "dst", /*allowLowPrecision=*/true))) + return failure(); + if (failed(verifyTileBufSameLogicalExtent(*this, srcTy, dstTy, "src", "dst", + /*compareValidShape=*/false))) + return failure(); + if (failed(verifyTileBufSameLogicalExtent(*this, srcTy, dstTy, "src", "dst", + /*compareValidShape=*/true))) + return failure(); + Type srcElem = getElemTy(srcTy); + Type dstElem = getElemTy(dstTy); + auto verifyA2A3 = [&]() -> LogicalResult { + if (isPTOLowPrecisionType(srcElem) || isPTOLowPrecisionType(dstElem)) + return emitOpError("expects A2/A3 tcvt low-precision element types to be unsupported"); + return success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + if (!isA5SupportedTCvtPair(srcElem, dstElem)) + return emitOpError("expects A5 tcvt low-precision type pairs to match PTO-ISA support"); + return success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +llvm::LogicalResult mlir::pto::TRandomOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + return 
emitOpError("trandom is only supported for A5 targets"); + }; + auto verifyA5 = [&]() -> LogicalResult { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + + Type dstTy = getDst().getType(); + if (failed(verifyTileBufCommon(*this, dstTy, "dst"))) + return failure(); + if (!isRowMajorTileBuf(dstTy)) + return emitOpError("expects dst to use row-major layout"); + + Type elemTy = getElemTy(dstTy); + if (!elemTy.isInteger(32)) + return emitOpError("expects dst element type to be i32 or ui32"); + + auto checkWord = [&](Value v, StringRef name) -> LogicalResult { + auto ty = dyn_cast(v.getType()); + if (!ty || ty.getWidth() != 32) + return emitOpError() << "expects " << name << " to be i32/ui32"; + return success(); + }; + if (failed(checkWord(getKey0(), "key0")) || + failed(checkWord(getKey1(), "key1")) || + failed(checkWord(getCounter0(), "counter0")) || + failed(checkWord(getCounter1(), "counter1")) || + failed(checkWord(getCounter2(), "counter2")) || + failed(checkWord(getCounter3(), "counter3"))) + return failure(); + + int32_t rounds = getRounds(); + if (rounds != 7 && rounds != 10) + return emitOpError("expects rounds to be 7 or 10"); + + return success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +LogicalResult mlir::pto::TDivOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + FailureOr elemOr = verifyMatchingRowMajorBinaryTileOpCommon( + getOperation(), getSrc0().getType(), getSrc1().getType(), + getDst().getType()); + if (failed(elemOr)) + return failure(); + auto elem0 = *elemOr; + if (!(elem0.isF16() || elem0.isF32())) + return emitOpError("expects A2/A3 tdiv element type to be f16 or f32"); + return success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + FailureOr elemOr = verifyMatchingRowMajorBinaryTileOpCommon( + getOperation(), getSrc0().getType(), getSrc1().getType(), + getDst().getType()); + if (failed(elemOr)) + return failure(); + auto elem0 = *elemOr; + if 
(!(elem0.isF16() || elem0.isF32() || elem0.isInteger(16) || elem0.isInteger(32))) + return emitOpError("expects A5 tdiv element type to be i32/i16/f16/f32"); + return success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TDivSOp::verify() { + auto isTileLike = [](Type ty) -> bool { + return isa(ty); + }; + auto isScalarLike = [](Type ty) -> bool { + return mlir::isa(ty); + }; + + auto verifyByArch = [&](PTOArch targetArch) -> LogicalResult { + Type srcTy = getSrc().getType(); + Type rhsTy = getScalar().getType(); + Type dstTy = getDst().getType(); + + bool srcTile = isTileLike(srcTy); + bool rhsTile = isTileLike(rhsTy); + bool srcScalar = isScalarLike(srcTy); + bool rhsScalar = isScalarLike(rhsTy); + + if (!(srcTile && rhsScalar) && !(srcScalar && rhsTile)) + return emitOpError("expects one tile-like operand and one scalar operand in ins(...)"); + + Type tileTy = srcTile ? srcTy : rhsTy; + Type scalarTy = srcTile ? rhsTy : srcTy; + + if (failed(verifyScalarTileOp(*this, tileTy, dstTy, "src", "dst", + /*requireValidRowsEqual=*/true, + /*requireValidColsEqual=*/true))) + return failure(); + if (!mlir::isa(scalarTy)) + return emitOpError("scalar must be a scalar type (integer/float)"); + Type elem = getElemTy(tileTy); + if (targetArch == PTOArch::A3 && + !(elem.isInteger(32) || elem.isInteger(16) || elem.isF16() || + elem.isF32())) + return emitOpError("expects A2/A3 tdivs element type to be i32/i16/f16/f32"); + if (targetArch == PTOArch::A5 && + !(elem.isInteger(32) || elem.isInteger(16) || elem.isInteger(8) || + elem.isF16() || elem.isF32())) + return emitOpError("expects A5 tdivs element type to be i32/i16/i8/f16/f32"); + return success(); + }; + auto verifyA2A3 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A3); }; + auto verifyA5 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A5); }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + 
+mlir::LogicalResult mlir::pto::TExpOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyVecTileUnaryOp(*this, srcTy, dstTy, "src", "dst", + /*allowBf16=*/false, /*allowInt8=*/false))) + return failure(); + if (failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) + return failure(); + Type srcElem = getElemTy(srcTy); + if (!srcElem.isF16() && !srcElem.isF32()) + return emitOpError("expects element type to be f16 or f32"); + return mlir::success(); + }; + auto verifyA5 = [&]() -> LogicalResult { return verifyA2A3(); }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TExpandsOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + Type dstTy = getDst().getType(); + if (failed(verifyTileBufCommon(*this, dstTy, "dst"))) + return failure(); + auto dstSpace = getPTOMemorySpaceEnum(dstTy); + if (!dstSpace || (*dstSpace != pto::AddressSpace::VEC && + *dstSpace != pto::AddressSpace::MAT)) + return emitOpError("expects dst to be in the vec or mat address space"); + Type dstElem = getElemTy(dstTy); + Type scalarTy = getScalar().getType(); + if (scalarTy != dstElem) + return emitOpError("expects scalar type == dst element type"); + if (*dstSpace == pto::AddressSpace::VEC && !isRowMajorTileBuf(dstTy)) + return emitOpError("expects vec dst to use row-major layout on A2/A3"); + if (dstElem.isF16() || dstElem.isBF16() || dstElem.isF32()) + return mlir::success(); + if (auto it = mlir::dyn_cast(dstElem)) { + unsigned w = it.getWidth(); + if (w == 16 || w == 32) + return mlir::success(); + } + return emitOpError("expects A2/A3 texpands dst element type to be i16/i32/f16/bf16/f32"); + }; + auto verifyA5 = [&]() -> LogicalResult { + Type dstTy = getDst().getType(); + if (failed(verifyTileBufCommon(*this, dstTy, "dst"))) + return failure(); + auto dstSpace = getPTOMemorySpaceEnum(dstTy); + if (!dstSpace || 
(*dstSpace != pto::AddressSpace::VEC && + *dstSpace != pto::AddressSpace::MAT)) + return emitOpError("expects dst to be in the vec or mat address space"); + Type dstElem = getElemTy(dstTy); + Type scalarTy = getScalar().getType(); + if (scalarTy != dstElem) + return emitOpError("expects scalar type == dst element type"); + if (dstElem.isF16() || dstElem.isBF16() || dstElem.isF32()) + return mlir::success(); + if (auto it = mlir::dyn_cast(dstElem)) { + unsigned w = it.getWidth(); + if (w == 8 || w == 16 || w == 32) + return mlir::success(); + } + return emitOpError("expects A5 texpands dst element type to be i8/i16/i32/f16/bf16/f32"); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TExtractOp::verify() { + auto hasMatExtractSourceLayoutA2A3 = [&](pto::TileBufType srcTy) -> bool { + int32_t bl = srcTy.getBLayoutValueI32(); + int32_t sl = srcTy.getSLayoutValueI32(); + return bl == static_cast(pto::BLayout::RowMajor) || + (bl != static_cast(pto::BLayout::RowMajor) && + sl == static_cast(pto::SLayout::RowMajor)); + }; + auto hasMatExtractSourceLayoutA5 = [&](pto::TileBufType srcTy, + pto::AddressSpace dstSpace) -> bool { + int32_t bl = srcTy.getBLayoutValueI32(); + int32_t sl = srcTy.getSLayoutValueI32(); + if (dstSpace == pto::AddressSpace::LEFT) { + return (bl == static_cast(pto::BLayout::RowMajor) && + sl == static_cast(pto::SLayout::ColMajor)) || + (bl != static_cast(pto::BLayout::RowMajor) && + sl == static_cast(pto::SLayout::RowMajor)) || + bl == static_cast(pto::BLayout::RowMajor); + } + return (bl == static_cast(pto::BLayout::RowMajor) && + sl == static_cast(pto::SLayout::ColMajor)) || + (bl != static_cast(pto::BLayout::RowMajor) && + sl == static_cast(pto::SLayout::RowMajor)); + }; + auto isA2A3ExtractElemType = [&](Type ty) -> bool { + return ty.isInteger(8) || ty.isF16() || ty.isBF16() || ty.isF32(); + }; + auto isA5ExtractElemType = [&](Type ty) -> bool { + if (auto it = dyn_cast(ty)) + return 
it.getWidth() == 8; + if (auto ft = dyn_cast(ty)) + return ft.getWidth() == 8 || ft.isF16() || ft.isBF16() || ft.isF32(); + return false; + }; + auto isRowMajorNoneBoxND = [&](pto::TileBufType ty) -> bool { + return ty.getBLayoutValueI32() == static_cast(pto::BLayout::RowMajor) && + ty.getSLayoutValueI32() == static_cast(pto::SLayout::NoneBox); + }; + auto verifyCommon = [&]() -> FailureOr, + std::optional>> { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + auto srcTb = dyn_cast(srcTy); + auto dstTb = dyn_cast(dstTy); + if (!srcTb || !dstTb) + return emitOpError("expects src and dst to be !pto.tile_buf"); + if (failed(verifyTileBufCommon(*this, srcTy, "src")) || + failed(verifyTileBufCommon(*this, dstTy, "dst")) || + failed(verifyNonNegativeIndexRowCol( + *getOperation(), getIndexRow(), getIndexCol(), + /*includeIndexAndIntOpsInConstFold=*/false)) || + failed(verifyExtractStaticBoundsCommon( + *getOperation(), getIndexRow(), getIndexCol(), srcTy, dstTy, + /*includeIndexAndIntOpsInConstFold=*/false))) + return failure(); + Type srcElem = getElemTy(srcTy); + Type dstElem = getElemTy(dstTy); + if (!srcElem || !dstElem || srcElem != dstElem) + return emitOpError("expects src and dst to have the same element type"); + auto srcSpace = getPTOMemorySpaceEnum(srcTy); + auto dstSpace = getPTOMemorySpaceEnum(dstTy); + return std::make_tuple(srcTy, dstTy, srcTb, dstTb, srcElem, dstElem, + srcSpace, dstSpace); + }; + auto verifyA2A3 = [&]() -> LogicalResult { + auto common = verifyCommon(); + if (failed(common)) + return failure(); + auto [srcTy, dstTy, srcTb, dstTb, srcElem, dstElem, srcSpace, dstSpace] = + *common; + (void)srcTy; + (void)dstTy; + (void)srcElem; + if (!isA2A3ExtractElemType(dstElem)) + return emitOpError("expects A2/A3 textract element type to be i8/f16/bf16/f32"); + if (srcSpace && dstSpace && *srcSpace == pto::AddressSpace::VEC && + *dstSpace == pto::AddressSpace::VEC) + return mlir::success(); + if (!srcSpace || *srcSpace != 
pto::AddressSpace::MAT) + return emitOpError("expects A2/A3 textract src to use loc=mat or vec"); + if (!dstSpace || (*dstSpace != pto::AddressSpace::LEFT && + *dstSpace != pto::AddressSpace::RIGHT)) + return emitOpError("expects A2/A3 textract dst to use loc=left, loc=right, or loc=vec"); + if (!hasMatExtractSourceLayoutA2A3(srcTb)) + return emitOpError("expects A2/A3 textract src to use a supported mat blayout/slayout combination"); + if (*dstSpace == pto::AddressSpace::LEFT) { + if (dstTb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor) || + dstTb.getSLayoutValueI32() != static_cast(pto::SLayout::RowMajor)) + return emitOpError("expects A2/A3 left dst to use row_major blayout and row_major slayout"); + } else { + if (dstTb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor) || + dstTb.getSLayoutValueI32() != static_cast(pto::SLayout::ColMajor)) + return emitOpError("expects A2/A3 right dst to use row_major blayout and col_major slayout"); + } + return mlir::success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + auto common = verifyCommon(); + if (failed(common)) + return failure(); + auto [srcTy, dstTy, srcTb, dstTb, srcElem, dstElem, srcSpace, dstSpace] = + *common; + (void)srcTy; + (void)dstTy; + (void)srcElem; + if (!isA5ExtractElemType(dstElem)) + return emitOpError("expects A5 textract element type to be an fp8/f16/bf16/f32 or int8 family type"); + if (!srcSpace || !dstSpace) + return emitOpError("expects src and dst to have explicit loc"); + bool okPair = + (*srcSpace == pto::AddressSpace::MAT && + (*dstSpace == pto::AddressSpace::LEFT || + *dstSpace == pto::AddressSpace::RIGHT || + *dstSpace == pto::AddressSpace::SCALING)) || + (*srcSpace == pto::AddressSpace::VEC && + (*dstSpace == pto::AddressSpace::MAT || + *dstSpace == pto::AddressSpace::VEC)); + if (!okPair) + return emitOpError("expects A5 textract to use a supported src/dst loc pair"); + if (*srcSpace == pto::AddressSpace::MAT) { + if (!hasMatExtractSourceLayoutA5(srcTb, 
*dstSpace)) + return emitOpError("expects A5 textract src to use a supported mat blayout/slayout combination"); + if (*dstSpace == pto::AddressSpace::LEFT) { + if (dstTb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor) || + dstTb.getSLayoutValueI32() != static_cast(pto::SLayout::RowMajor)) + return emitOpError("expects A5 left dst to use col_major blayout and row_major slayout"); + } else if (*dstSpace == pto::AddressSpace::RIGHT) { + if (dstTb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor) || + dstTb.getSLayoutValueI32() != static_cast(pto::SLayout::ColMajor)) + return emitOpError("expects A5 right dst to use row_major blayout and col_major slayout"); + } + } else if (*srcSpace == pto::AddressSpace::VEC && + *dstSpace == pto::AddressSpace::VEC) { + if (!isRowMajorNoneBoxND(srcTb) || !isRowMajorNoneBoxND(dstTb)) + return emitOpError( + "expects A5 vec->vec textract src/dst to use ND layout " + "(blayout=row_major, slayout=none_box)"); + } + return success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} +mlir::LogicalResult mlir::pto::TInsertOp::verify() { + auto isColMajorRowMajorNZ = [&](pto::TileBufType ty) -> bool { + return ty.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor) && + ty.getSLayoutValueI32() == static_cast(pto::SLayout::RowMajor); + }; + auto isRowMajorNoneBoxND = [&](pto::TileBufType ty) -> bool { + return ty.getBLayoutValueI32() == static_cast(pto::BLayout::RowMajor) && + ty.getSLayoutValueI32() == static_cast(pto::SLayout::NoneBox); + }; + auto isA5SupportedVecElemType = [&](Type ty) -> bool { + if (auto it = dyn_cast(ty)) + return it.getWidth() == 8 || it.getWidth() == 32; + if (auto ft = dyn_cast(ty)) + return ft.getWidth() == 8 || ft.isF16() || ft.isBF16() || ft.isF32(); + return false; + }; + auto isA2A3VecInsertElemType = [&](Type ty) -> bool { + return ty.isInteger(8) || ty.isF16() || ty.isBF16() || ty.isF32(); + }; + auto verifyCommon = [&]() -> FailureOr, + 
std::optional>> { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + auto srcTb = dyn_cast(srcTy); + auto dstTb = dyn_cast(dstTy); + if (!srcTb || !dstTb) + return emitOpError("expects src and dst to be !pto.tile_buf"); + if (failed(verifyTileBufCommon(*this, srcTy, "src")) || + failed(verifyTileBufCommon(*this, dstTy, "dst")) || + failed(verifyNonNegativeIndexRowCol( + *getOperation(), getIndexRow(), getIndexCol(), + /*includeIndexAndIntOpsInConstFold=*/true)) || + failed(verifyInsertStaticBoundsCommon( + *getOperation(), getIndexRow(), getIndexCol(), srcTy, dstTy, + /*includeIndexAndIntOpsInConstFold=*/true))) + return failure(); + Type srcElem = getElemTy(srcTy); + Type dstElem = getElemTy(dstTy); + auto srcSpace = getPTOMemorySpaceEnum(srcTy); + auto dstSpace = getPTOMemorySpaceEnum(dstTy); + return std::make_tuple(srcTy, dstTy, srcTb, dstTb, srcElem, dstElem, + srcSpace, dstSpace); + }; + auto verifyA2A3 = [&]() -> LogicalResult { + auto common = verifyCommon(); + if (failed(common)) + return failure(); + auto [srcTy, dstTy, srcTb, dstTb, srcElem, dstElem, srcSpace, dstSpace] = + *common; + if (srcSpace && dstSpace && *srcSpace == pto::AddressSpace::VEC && + *dstSpace == pto::AddressSpace::VEC) { + if (srcElem != dstElem || !isA2A3VecInsertElemType(srcElem)) + return emitOpError( + "expects A2/A3 vec->vec tinsert src/dst to have same supported dtype " + "(i8/f16/bf16/f32)"); + return success(); + } + if (!srcSpace || !dstSpace || *srcSpace != pto::AddressSpace::ACC || + *dstSpace != pto::AddressSpace::MAT) + return emitOpError("expects A2/A3 tinsert to use acc->mat or vec->vec"); + + if (!isColMajorRowMajorNZ(srcTb)) + return emitOpError("expects A2/A3 tinsert src to use blayout=col_major and slayout=row_major"); + if (!isColMajorRowMajorNZ(dstTb)) + return emitOpError("expects A2/A3 tinsert dst to use blayout=col_major and slayout=row_major"); + if (dstTb.getSFractalSizeI32() != 512) + return emitOpError("expects A2/A3 tinsert dst fractal 
size to be 512"); + + if (!(srcElem.isF32() && (dstElem.isF16() || dstElem.isBF16()))) + return emitOpError("expects A2/A3 tinsert element types to be src=f32, dst=f16/bf16"); + return success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + auto common = verifyCommon(); + if (failed(common)) + return failure(); + auto [srcTy, dstTy, srcTb, dstTb, srcElem, dstElem, srcSpace, dstSpace] = + *common; + if (!srcSpace || !dstSpace) + return emitOpError("expects A5 tinsert src/dst to have explicit loc"); + + // A5 regular acc->mat path. + if (*srcSpace == pto::AddressSpace::ACC && *dstSpace == pto::AddressSpace::MAT) { + if (!isColMajorRowMajorNZ(srcTb)) + return emitOpError("expects A5 acc->mat tinsert src to use blayout=col_major and slayout=row_major"); + if (!isColMajorRowMajorNZ(dstTb)) + return emitOpError("expects A5 acc->mat tinsert dst to use blayout=col_major and slayout=row_major"); + bool okTypes = (srcElem.isF32() && + (dstElem.isF16() || dstElem.isBF16() || dstElem.isF32())) || + (srcElem.isInteger(32) && dstElem.isInteger(32)); + if (!okTypes) + return emitOpError( + "expects A5 acc->mat tinsert element types to be " + "(src=f32,dst=f16/bf16/f32) or (src=i32,dst=i32)"); + return success(); + } + + // A5 vec->mat path (ND/NZ modes in pto-isa). + if (*srcSpace == pto::AddressSpace::VEC && *dstSpace == pto::AddressSpace::MAT) { + if (!isColMajorRowMajorNZ(dstTb)) + return emitOpError("expects A5 vec->mat tinsert dst to use blayout=col_major and slayout=row_major"); + bool srcIsND = isRowMajorNoneBoxND(srcTb); + bool srcIsNZ = isColMajorRowMajorNZ(srcTb); + if (!srcIsND && !srcIsNZ) + return emitOpError( + "expects A5 vec->mat tinsert src to use ND(row_major/none_box) or NZ(col_major/row_major) layout"); + if (srcElem != dstElem || !isA5SupportedVecElemType(srcElem)) + return emitOpError( + "expects A5 vec->mat tinsert src/dst to have same supported dtype " + "(fp8/f16/bf16/f32/i8/i32)"); + return success(); + } + + // A5 vec->vec path (PR561 ND_VEC). 
+ if (*srcSpace == pto::AddressSpace::VEC && *dstSpace == pto::AddressSpace::VEC) { + if (!isRowMajorNoneBoxND(srcTb) || !isRowMajorNoneBoxND(dstTb)) + return emitOpError( + "expects A5 vec->vec tinsert src/dst to use ND layout " + "(blayout=row_major, slayout=none_box)"); + if (srcElem != dstElem || !isA5SupportedVecElemType(srcElem)) + return emitOpError( + "expects A5 vec->vec tinsert src/dst to have same supported dtype " + "(fp8/f16/bf16/f32/i8/i32)"); + return success(); + } + + return emitOpError( + "expects A5 tinsert to use a supported src/dst loc pair: " + "acc->mat, vec->mat, or vec->vec"); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +static bool isColMajorRowMajorNZTileBuf(pto::TileBufType ty) { + return ty.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor) && + ty.getSLayoutValueI32() == static_cast(pto::SLayout::RowMajor); +} + +static bool isA2A3VectorPreQuantTypePair(Type srcElem, Type dstElem) { + if (srcElem.isF32()) + return dstElem.isInteger(8); + if (srcElem.isInteger(32)) + return dstElem.isInteger(8) || dstElem.isF16() || dstElem.isInteger(16); + return false; +} + +static bool isA5Fp8LikeType(Type ty) { + if (auto ft = dyn_cast(ty)) + return ft.getWidth() == 8; + return false; +} + +static bool isA5MxInputType(Type ty) { + return isA5Fp8LikeType(ty); +} + +static LogicalResult verifyA5MxTypeTriple(Operation *op, Type lhsTy, Type rhsTy, + Type dstTy, StringRef lhsName, + StringRef rhsName, StringRef dstName) { + Type lhsElem = getElemTy(lhsTy); + Type rhsElem = getElemTy(rhsTy); + Type dstElem = getElemTy(dstTy); + + if (!isA5MxInputType(lhsElem) || !isA5MxInputType(rhsElem)) + return op->emitOpError() + << "expects A5 mx operands " << lhsName << " and " << rhsName + << " to use fp8 element types"; + + if (!dstElem.isF32()) + return op->emitOpError() + << "expects A5 mx result " << dstName << " to use f32 element type"; + + return success(); +} + +static bool isA5VectorPreQuantTypePair(Type 
srcElem, Type dstElem) { + if (srcElem.isF32()) + return dstElem.isInteger(8) || isA5Fp8LikeType(dstElem) || dstElem.isF16() || + dstElem.isBF16() || dstElem.isF32(); + if (srcElem.isInteger(32)) + return dstElem.isInteger(8) || dstElem.isF16() || dstElem.isBF16(); + return false; +} + +mlir::LogicalResult mlir::pto::TExtractFPOp::verify() { + auto verifyCommon = [&]() -> FailureOr> { + Type srcTy = getSrc().getType(); + Type fpTy = getFp().getType(); + Type dstTy = getDst().getType(); + auto srcTb = dyn_cast(srcTy); + auto fpTb = dyn_cast(fpTy); + auto dstTb = dyn_cast(dstTy); + if (!srcTb || !fpTb || !dstTb) + return emitOpError("expects src, fp, and dst to be !pto.tile_buf"); + if (failed(verifyTileBufCommon(*this, srcTy, "src")) || + failed(verifyTileBufCommon(*this, fpTy, "fp")) || + failed(verifyTileBufCommon(*this, dstTy, "dst")) || + failed(verifyNonNegativeIndexRowCol( + *getOperation(), getIndexRow(), getIndexCol(), + /*includeIndexAndIntOpsInConstFold=*/true)) || + failed(verifyExtractStaticBoundsCommon( + *getOperation(), getIndexRow(), getIndexCol(), srcTy, dstTy, + /*includeIndexAndIntOpsInConstFold=*/true))) + return failure(); + auto srcSpace = getPTOMemorySpaceEnum(srcTy); + auto fpSpace = getPTOMemorySpaceEnum(fpTy); + auto dstSpace = getPTOMemorySpaceEnum(dstTy); + if (!srcSpace || !fpSpace || !dstSpace) + return emitOpError("expects src, fp, and dst to have explicit loc"); + if (*srcSpace != pto::AddressSpace::ACC) + return emitOpError("expects src to use loc=acc"); + if (*fpSpace != pto::AddressSpace::SCALING) + return emitOpError("expects fp to use loc=scaling"); + if (*dstSpace != pto::AddressSpace::MAT) + return emitOpError("expects dst to use loc=mat"); + if (!isColMajorRowMajorNZTileBuf(srcTb)) + return emitOpError("expects src to use blayout=col_major and slayout=row_major"); + if (!isColMajorRowMajorNZTileBuf(dstTb)) + return emitOpError("expects dst to use blayout=col_major and slayout=row_major"); + return std::make_tuple(srcTy, fpTy, 
dstTy, srcTb, fpTb, dstTb, *srcSpace, + *fpSpace, *dstSpace); + }; + auto verifyA2A3 = [&]() -> LogicalResult { + auto common = verifyCommon(); + if (failed(common)) + return failure(); + auto [srcTy, fpTy, dstTy, srcTb, fpTb, dstTb, srcSpace, fpSpace, dstSpace] = + *common; + (void)fpTy; + (void)srcSpace; + (void)fpSpace; + (void)dstSpace; + if (dstTb.getSFractalSizeI32() != 512) + return emitOpError("expects dst fractal size to be 512"); + Type srcElem = getElemTy(srcTy); + Type dstElem = getElemTy(dstTy); + if (!isA2A3VectorPreQuantTypePair(srcElem, dstElem)) + return emitOpError( + "expects A2/A3 textract_fp element types to be (src=f32,dst=i8) " + "or (src=i32,dst=i8/f16/i16)"); + return success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + auto common = verifyCommon(); + if (failed(common)) + return failure(); + auto [srcTy, fpTy, dstTy, srcTb, fpTb, dstTb, srcSpace, fpSpace, dstSpace] = + *common; + (void)fpTy; + (void)srcTb; + (void)fpTb; + (void)dstTb; + (void)srcSpace; + (void)fpSpace; + (void)dstSpace; + Type srcElem = getElemTy(srcTy); + Type dstElem = getElemTy(dstTy); + if (!isA5VectorPreQuantTypePair(srcElem, dstElem)) + return emitOpError( + "expects A5 textract_fp element types to be (src=f32,dst=i8/fp8/f16/bf16/f32) " + "or (src=i32,dst=i8/f16/bf16)"); + return success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TInsertFPOp::verify() { + auto verifyCommon = [&]() -> FailureOr> { + Type srcTy = getSrc().getType(); + Type fpTy = getFp().getType(); + Type dstTy = getDst().getType(); + auto srcTb = dyn_cast(srcTy); + auto fpTb = dyn_cast(fpTy); + auto dstTb = dyn_cast(dstTy); + if (!srcTb || !fpTb || !dstTb) + return emitOpError("expects src, fp, and dst to be !pto.tile_buf"); + if (failed(verifyTileBufCommon(*this, srcTy, "src")) || + failed(verifyTileBufCommon(*this, fpTy, "fp")) || + failed(verifyTileBufCommon(*this, dstTy, "dst")) || + 
failed(verifyNonNegativeIndexRowCol( + *getOperation(), getIndexRow(), getIndexCol(), + /*includeIndexAndIntOpsInConstFold=*/true)) || + failed(verifyInsertStaticBoundsCommon( + *getOperation(), getIndexRow(), getIndexCol(), srcTy, dstTy, + /*includeIndexAndIntOpsInConstFold=*/true))) + return failure(); + auto srcSpace = getPTOMemorySpaceEnum(srcTy); + auto fpSpace = getPTOMemorySpaceEnum(fpTy); + auto dstSpace = getPTOMemorySpaceEnum(dstTy); + if (!srcSpace || !fpSpace || !dstSpace) + return emitOpError("expects src, fp, and dst to have explicit loc"); + if (*srcSpace != pto::AddressSpace::ACC) + return emitOpError("expects src to use loc=acc"); + if (*fpSpace != pto::AddressSpace::SCALING) + return emitOpError("expects fp to use loc=scaling"); + if (*dstSpace != pto::AddressSpace::MAT) + return emitOpError("expects dst to use loc=mat"); + if (!isColMajorRowMajorNZTileBuf(srcTb)) + return emitOpError("expects src to use blayout=col_major and slayout=row_major"); + if (!isColMajorRowMajorNZTileBuf(dstTb)) + return emitOpError("expects dst to use blayout=col_major and slayout=row_major"); + return std::make_tuple(srcTy, fpTy, dstTy, srcTb, fpTb, dstTb, *srcSpace, + *fpSpace, *dstSpace); + }; + auto verifyA2A3 = [&]() -> LogicalResult { + auto common = verifyCommon(); + if (failed(common)) + return failure(); + auto [srcTy, fpTy, dstTy, srcTb, fpTb, dstTb, srcSpace, fpSpace, dstSpace] = + *common; + (void)fpTy; + (void)srcTb; + (void)fpTb; + (void)srcSpace; + (void)fpSpace; + (void)dstSpace; + if (dstTb.getSFractalSizeI32() != 512) + return emitOpError("expects dst fractal size to be 512"); + Type srcElem = getElemTy(srcTy); + Type dstElem = getElemTy(dstTy); + if (!isA2A3VectorPreQuantTypePair(srcElem, dstElem)) + return emitOpError( + "expects A2/A3 tinsert_fp element types to be (src=f32,dst=i8) " + "or (src=i32,dst=i8/f16/i16)"); + return success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + auto common = verifyCommon(); + if (failed(common)) + return 
failure(); + auto [srcTy, fpTy, dstTy, srcTb, fpTb, dstTb, srcSpace, fpSpace, dstSpace] = + *common; + (void)fpTy; + (void)srcTb; + (void)fpTb; + (void)dstTb; + (void)srcSpace; + (void)fpSpace; + (void)dstSpace; + Type srcElem = getElemTy(srcTy); + Type dstElem = getElemTy(dstTy); + if (!isA5VectorPreQuantTypePair(srcElem, dstElem)) + return emitOpError( + "expects A5 tinsert_fp element types to be (src=f32,dst=i8/fp8/f16/bf16/f32) " + "or (src=i32,dst=i8/f16/bf16)"); + return success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +static mlir::LogicalResult verifyTFillPadLike(Operation *op, Type srcTy, Type dstTy, + bool allowDstExpand, + llvm::StringRef opName) { + if (!isPTOShapedLike(srcTy) || !isPTOShapedLike(dstTy)) + return op->emitError("expects src/dst to be PTO shaped-like types"); + + auto srcShape = getShapeVec(srcTy); + auto dstShape = getShapeVec(dstTy); + if (srcShape.size() != 2 || dstShape.size() != 2) + return op->emitError("expects rank-2 shaped types for src/dst"); + + auto srcElem = getElemTy(srcTy); + auto dstElem = getElemTy(dstTy); + + auto getElemBytes = [](mlir::Type t) -> int64_t { + unsigned elemBytes = getPTOStorageElemByteSize(t); + return elemBytes == 0 ? -1 : static_cast(elemBytes); + }; + + int64_t srcB = getElemBytes(srcElem); + int64_t dstB = getElemBytes(dstElem); + if (srcB < 0 || dstB < 0) + return op->emitError("unsupported element type (expects int/float element types)"); + if (srcB != dstB) + return op->emitError("expects sizeof(src element) == sizeof(dst element)"); + if (!(srcB == 1 || srcB == 2 || srcB == 4)) + return op->emitError("expects element size to be 1, 2, or 4 bytes"); + + // pto.tfillpad lowers to TFILLPAD(dst, src). For loc=mat, pto-isa only + // exposes the homogeneous overload, so src/dst must use the same Tile<...> + // specialization (including valid_shape and pad). 
+ // Note: tfillpad_expand is intentionally not covered here because its + // cross-layer ABI contract for loc=mat heterogeneous shape expansion is not + // finalized yet. + if (opName == "tfillpad") { + auto srcTb = mlir::dyn_cast(srcTy); + auto dstTb = mlir::dyn_cast(dstTy); + auto srcSpace = getPTOMemorySpaceEnum(srcTy); + auto dstSpace = getPTOMemorySpaceEnum(dstTy); + if (srcTb && dstTb && srcSpace && dstSpace && + *srcSpace == mlir::pto::AddressSpace::MAT && + *dstSpace == mlir::pto::AddressSpace::MAT && srcTb != dstTb) { + auto dimToStr = [](int64_t dim) -> std::string { + return dim == ShapedType::kDynamic ? "?" : std::to_string(dim); + }; + SmallVector mismatchFields; + auto srcValid = getValidShapeVec(srcTy); + auto dstValid = getValidShapeVec(dstTy); + if (srcValid.size() == 2 && dstValid.size() == 2) { + if (srcValid[0] != dstValid[0]) + mismatchFields.push_back("v_row (" + dimToStr(srcValid[0]) + " vs " + + dimToStr(dstValid[0]) + ")"); + if (srcValid[1] != dstValid[1]) + mismatchFields.push_back("v_col (" + dimToStr(srcValid[1]) + " vs " + + dimToStr(dstValid[1]) + ")"); + } + if (srcTb.getPadValueI32() != dstTb.getPadValueI32()) + mismatchFields.push_back("pad (" + std::to_string(srcTb.getPadValueI32()) + + " vs " + std::to_string(dstTb.getPadValueI32()) + + ")"); + + auto diag = op->emitError() + << "expects src/dst tile types to be lowerable to TFILLPAD " + "for loc=mat"; + if (!mismatchFields.empty()) + diag << "; mismatching fields: " << llvm::join(mismatchFields, ", "); + diag << "\n src: " << srcTy; + diag << "\n dst: " << dstTy; + diag << "\n note: heterogeneous TFILLPAD overload is only available for loc=vec"; + return failure(); + } + } + + if (auto dstTileTy = mlir::dyn_cast(dstTy)) { + auto padAttr = mlir::dyn_cast(dstTileTy.getPadValueAttr()); + if (!padAttr || padAttr.getValue() == mlir::pto::PadValue::Null) + return op->emitError() << "expects dst PadVal != Null for " << opName; + } + + if (!allowDstExpand) { + if (srcShape != dstShape) 
+ return op->emitError() + << "expects src and dst to have the same static shape for " << opName; + return mlir::success(); + } + + if (srcShape[0] > dstShape[0] || srcShape[1] > dstShape[1]) { + return op->emitError() + << "expects dst static shape to be >= src static shape for " << opName; + } + + return mlir::success(); +} + +mlir::LogicalResult mlir::pto::TFillPadOp::verify() { + return verifyTFillPadLike(getOperation(), getSrc().getType(), getDst().getType(), + /*allowDstExpand=*/false, "tfillpad"); +} + +mlir::LogicalResult mlir::pto::TFillPadExpandOp::verify() { + return verifyTFillPadLike(getOperation(), getSrc().getType(), getDst().getType(), + /*allowDstExpand=*/true, "tfillpad_expand"); +} + +mlir::LogicalResult mlir::pto::TFillPadInplaceOp::verify() { + return verifyTFillPadLike(getOperation(), getSrc().getType(), getDst().getType(), + /*allowDstExpand=*/false, "tfillpad_inplace"); +} + + +llvm::LogicalResult mlir::pto::TGatherOp::verify() { + auto isSupportedGatherElemTypeA5Index = [&](Type ty) -> bool { + if (ty.isF16() || ty.isF32()) + return true; + if (auto it = dyn_cast(ty)) { + unsigned width = it.getWidth(); + return width == 8 || width == 16 || width == 32; + } + return false; + }; + + auto verifyMaskForm = [&](bool allowA5MaskTypes) -> LogicalResult { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyTileBufCommon(*this, srcTy, "src")) || + failed(verifyTileBufCommon(*this, dstTy, "dst"))) + return failure(); + + Type srcElem = getElemTy(srcTy); + Type dstElem = getElemTy(dstTy); + if (!srcElem || !dstElem) + return emitOpError("failed to get element type for src/dst"); + if (!isRowMajorTileBuf(srcTy) || !isRowMajorTileBuf(dstTy)) + return emitOpError("expects src and dst to use row-major layout"); + auto srcSpace = getPTOMemorySpaceEnum(srcTy); + auto dstSpace = getPTOMemorySpaceEnum(dstTy); + if (!srcSpace || !dstSpace || *srcSpace != pto::AddressSpace::VEC || + *dstSpace != pto::AddressSpace::VEC) + 
return emitOpError("expects src and dst to be in the vec address space"); + unsigned srcElemBytes = getPTOStorageElemByteSize(srcElem); + unsigned dstElemBytes = getPTOStorageElemByteSize(dstElem); + if (srcElemBytes == 0 || dstElemBytes == 0) + return emitOpError("failed to get element size for src/dst"); + if (srcElemBytes != dstElemBytes) + return emitOpError("expects src and dst element sizes to match"); + + auto dstValid = getValidShapeVec(dstTy); + auto dstShape = getShapeVec(dstTy); + if (dstValid.size() == 2 && dstShape.size() == 2 && + dstValid[1] != ShapedType::kDynamic && dstShape[1] != ShapedType::kDynamic && + dstValid[1] != dstShape[1]) { + return emitOpError("expects dst valid_shape[1] to equal dst cols"); + } + + if (allowA5MaskTypes) { + if (!(srcElemBytes == 1 || srcElemBytes == 2 || srcElemBytes == 4)) + return emitOpError("expects A5 mask-pattern gather element size to be 1, 2, or 4 bytes"); + if (!isSupportedGatherElemTypeA5(srcElem) || !isSupportedGatherElemTypeA5(dstElem)) + return emitOpError( + "expects A5 mask-pattern gather src/dst element type to be i8/i16/i32/f16/bf16/f32/fp8-like"); + } else { + if (!(srcElemBytes == 2 || srcElemBytes == 4)) + return emitOpError("expects A2/A3 mask-pattern gather element size to be 2 or 4 bytes"); + } + return success(); + }; + + auto verifyIndexForm = [&](bool allow16BitIndices, bool allowA5ElemTypes) -> LogicalResult { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + Type idxTy = getIndices().getType(); + Type tmpTy = getTmp().getType(); + if (failed(verifyTileBufCommon(*this, srcTy, "src")) || + failed(verifyTileBufCommon(*this, dstTy, "dst")) || + failed(verifyTileBufCommon(*this, idxTy, "indices")) || + failed(verifyTileBufCommon(*this, tmpTy, "tmp"))) + return failure(); + + Type srcElem = getElemTy(srcTy); + Type dstElem = getElemTy(dstTy); + if (!srcElem || !dstElem) + return emitOpError("failed to get element type for src/dst"); + if (srcElem != dstElem) + return 
emitOpError("expects src and dst to have the same element type"); + if (allowA5ElemTypes) { + if (!isSupportedGatherElemTypeA5Index(srcElem) || + !isSupportedGatherElemTypeA5Index(dstElem)) + return emitOpError( + "expects A5 gather src/dst element type to be i8/i16/i32/f16/f32"); + } else if (!isSupportedGatherElemTypeA2A3(srcElem) || + !isSupportedGatherElemTypeA2A3(dstElem)) { + return emitOpError("expects gather src/dst element type to be i16/i32/f16/f32"); + } + + auto idxElem = dyn_cast(getElemTy(idxTy)); + if (!idxElem) + return emitOpError("indices element type must be integer"); + unsigned width = idxElem.getWidth(); + if (!(width == 32 || (allow16BitIndices && width == 16))) { + return emitOpError() << "expects indices element type to be i32" + << (allow16BitIndices ? " or i16" : ""); + } + + auto dstValid = getValidShapeVec(dstTy); + auto dstShape = getShapeVec(dstTy); + if (dstValid.size() == 2 && dstShape.size() == 2 && + dstValid[1] != ShapedType::kDynamic && dstShape[1] != ShapedType::kDynamic && + dstValid[1] != dstShape[1]) { + return emitOpError("expects dst valid_shape[1] to equal dst cols"); + } + + auto idxValid = getValidShapeVec(idxTy); + auto idxShape = getShapeVec(idxTy); + if (idxValid.size() == 2 && idxShape.size() == 2 && + idxValid[1] != ShapedType::kDynamic && idxShape[1] != ShapedType::kDynamic && + idxValid[1] != idxShape[1]) { + return emitOpError("expects indices valid_shape[1] to equal indices cols"); + } + + if (!allowA5ElemTypes) { + Type tmpElem = getElemTy(tmpTy); + if (tmpElem != idxElem) + return emitOpError("expects tmp and indices to have the same element type"); + if (failed(verifyTileBufSameValidShape(*this, idxTy, tmpTy, "indices", "tmp"))) + return failure(); + } + return success(); + }; + + auto verifyCompareForm = [&](bool allowA5SrcTypes) -> LogicalResult { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + Type cdstTy = getCdst().getType(); + Type tmpTy = getTmp().getType(); + if 
(failed(verifyTileBufCommon(*this, srcTy, "src")) || + failed(verifyTileBufCommon(*this, dstTy, "dst")) || + failed(verifyTileBufCommon(*this, cdstTy, "cdst")) || + failed(verifyTileBufCommon(*this, tmpTy, "tmp"))) + return failure(); + + Type srcElem = getElemTy(srcTy); + Type dstElem = getElemTy(dstTy); + Type cdstElem = getElemTy(cdstTy); + if (!srcElem || !dstElem || !cdstElem) + return emitOpError("failed to get element type for src/dst/cdst"); + auto dstInt = dyn_cast(dstElem); + if (!dstInt || dstInt.getWidth() != 32) + return emitOpError("expects dst element type to be i32"); + if (cdstElem != dstElem) + return emitOpError("expects cdst to have the same element type as dst"); + if (getKValue().getType() != srcElem) + return emitOpError("expects kValue to have the same type as src element type"); + + auto cmpAttr = getCmpModeAttr(); + auto cmpMode = cmpAttr ? cmpAttr.getValue() : pto::CmpMode::EQ; + if (cmpMode != pto::CmpMode::EQ && cmpMode != pto::CmpMode::GT) + return emitOpError("expects compare-form tgather cmpMode to be eq or gt"); + + if (allowA5SrcTypes) { + if (!(srcElem.isF16() || srcElem.isF32() || srcElem.isInteger(16) || + srcElem.isInteger(32))) { + return emitOpError( + "expects A5 compare-form tgather src element type to be i16/i32/f16/f32"); + } + } else { + if (!(srcElem.isF16() || srcElem.isF32() || + (srcElem.isInteger(32) && cmpMode == pto::CmpMode::EQ))) { + return emitOpError( + "expects A2/A3 compare-form tgather src element type to be f16/f32, or i32 when cmpMode=eq"); + } + } + + if (failed(verifyVecTileCommonA2A3(*this, srcTy, "src")) || + failed(verifyVecTileCommonA2A3(*this, dstTy, "dst")) || + failed(verifyVecTileCommonA2A3(*this, cdstTy, "cdst")) || + failed(verifyVecTileCommonA2A3(*this, tmpTy, "tmp"))) + return failure(); + return success(); + }; + + auto verifyA2A3 = [&]() -> LogicalResult { + if (getMaskPatternAttr()) { + if (getCdst() || getIndices() || getTmp() || getKValue()) + return emitOpError("mask-pattern tgather 
only allows src and dst operands"); + return verifyMaskForm(/*allowA5MaskTypes=*/false); + } + if (getCdst() || getKValue()) { + if (!getCdst() || !getKValue() || !getTmp()) + return emitOpError("compare-form tgather expects dst, cdst, kValue, and tmp"); + if (getIndices()) + return emitOpError("compare-form tgather does not take indices"); + return verifyCompareForm(/*allowA5SrcTypes=*/false); + } + if (!getIndices() || !getTmp()) + return emitOpError("index-form tgather expects both indices and tmp"); + return verifyIndexForm(/*allow16BitIndices=*/false, /*allowA5ElemTypes=*/false); + }; + + auto verifyA5 = [&]() -> LogicalResult { + if (getMaskPatternAttr()) { + if (getCdst() || getIndices() || getTmp() || getKValue()) + return emitOpError("mask-pattern tgather only allows src and dst operands"); + return verifyMaskForm(/*allowA5MaskTypes=*/true); + } + if (getCdst() || getKValue()) { + if (!getCdst() || !getKValue() || !getTmp()) + return emitOpError("compare-form tgather expects dst, cdst, kValue, and tmp"); + if (getIndices()) + return emitOpError("compare-form tgather does not take indices"); + return verifyCompareForm(/*allowA5SrcTypes=*/true); + } + if (!getIndices() || !getTmp()) + return emitOpError("index-form tgather expects both indices and tmp"); + return verifyIndexForm(/*allow16BitIndices=*/true, /*allowA5ElemTypes=*/true); + }; + + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} +mlir::LogicalResult mlir::pto::TGatherBOp::verify() { + auto verifyCommon = [&]() -> FailureOr> { + Type srcTy = getSrc().getType(); + Type offTy = getOffsets().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyTileBufCommon(*this, srcTy, "src")) || + failed(verifyTileBufCommon(*this, offTy, "offsets")) || + failed(verifyTileBufCommon(*this, dstTy, "dst"))) + return failure(); + auto srcElemTy = getElemTy(srcTy); + auto dstElemTy = getElemTy(dstTy); + if (!srcElemTy || !dstElemTy) + return emitOpError() << "failed to get element type 
for src/dst"; + return std::make_pair(srcElemTy, dstElemTy); + }; + + auto getElemBytes = [](Type ty) -> std::optional { + unsigned elemBytes = getPTOStorageElemByteSize(ty); + if (elemBytes == 0) + return std::nullopt; + return elemBytes; + }; + + auto verifyA2A3 = [&]() -> LogicalResult { + FailureOr> elems = verifyCommon(); + if (failed(elems)) + return failure(); + Type dstTy = getDst().getType(); + Type dstElemTy = elems->second; + if (!isRowMajorTileBuf(dstTy)) + return emitOpError() << "expects dst to use row-major layout"; + auto dstBytes = getElemBytes(dstElemTy); + if (!dstBytes || (*dstBytes != 1 && *dstBytes != 2 && *dstBytes != 4)) + return emitOpError() << "expects dst element size to be 1, 2, or 4 bytes"; + return mlir::success(); + }; + + auto verifyA5 = [&]() -> LogicalResult { + FailureOr> elems = verifyCommon(); + if (failed(elems)) + return failure(); + Type dstElemTy = elems->second; + auto dstBytes = getElemBytes(dstElemTy); + if (!dstBytes || (*dstBytes != 1 && *dstBytes != 2 && *dstBytes != 4)) + return emitOpError() << "expects dst element size to be 1, 2, or 4 bytes"; + return mlir::success(); + }; + + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TLogOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyVecTileUnaryOp(*this, srcTy, dstTy, "src", "dst", + /*allowBf16=*/false, /*allowInt8=*/false))) + return failure(); + if (failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) + return failure(); + auto elemTy = getElemTy(srcTy); + if (!(elemTy.isF16() || elemTy.isF32())) + return emitOpError() << "expects element type to be f16 or f32"; + return mlir::success(); +} + +mlir::LogicalResult mlir::pto::TLReluOp::verify() { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + auto verifyA2A3 = [&]() -> LogicalResult { + if 
(failed(verifyVecTileStorage(*this, srcTy, "src")) || + failed(verifyVecTileStorage(*this, dstTy, "dst"))) + return failure(); + if (failed(verifyTileBufSameElemType(*this, srcTy, dstTy, "src", "dst")) || + failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) + return failure(); + auto valid = getValidShapeVec(srcTy); + if (valid.size() != 2) + return emitOpError("expects src to have rank-2 valid_shape"); + if (valid[0] != ShapedType::kDynamic && valid[0] <= 0) + return emitOpError("expects src valid_shape[0] to be positive"); + if (valid[1] != ShapedType::kDynamic && valid[1] <= 0) + return emitOpError("expects src valid_shape[1] to be positive"); + Type elemTy = getElemTy(srcTy); + if (!(elemTy.isF16() || elemTy.isF32())) + return emitOpError() << "expects A2/A3 tlrelu element type to be f16 or f32"; + return success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + if (failed(verifyVecTileStorage(*this, srcTy, "src")) || + failed(verifyVecTileStorage(*this, dstTy, "dst"))) + return failure(); + if (failed(verifyTileBufSameElemType(*this, srcTy, dstTy, "src", "dst")) || + failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) + return failure(); + Type elemTy = getElemTy(srcTy); + if (!(elemTy.isF16() || elemTy.isF32())) + return emitOpError() << "expects A5 tlrelu element type to be f16 or f32"; + if (!getSlope().getType().isF32()) + return emitOpError() << "expects slope to have type f32"; + return success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TMaxOp::verify() { + return verifyArithmeticBinaryTileOpWithArchDispatch( + getOperation(), getSrc0().getType(), getSrc1().getType(), getDst().getType(), + /*allowInt8OnA5=*/true, /*allowBf16OnA5=*/false, + "expects A2/A3 tmax element type to be i32/i16/f16/f32", + "expects A5 tmax element type to be i32/i16/i8/f16/f32"); +} + +mlir::LogicalResult mlir::pto::TMaxSOp::verify() { + return 
verifyArithmeticScalarTileOpWithArchDispatch( + getOperation(), getSrc().getType(), getDst().getType(), getScalar().getType(), + /*allowInt8OnA5=*/true, /*allowBf16OnA5=*/true, + "expects A2/A3 tmaxs element type to be i32/i16/f16/f32", + "expects A5 tmaxs element type to be i32/i16/i8/f16/bf16/f32", + /*requireValidRowsEqualOnA2A3=*/true, + /*requireValidRowsEqualOnA5=*/true); +} + +mlir::LogicalResult mlir::pto::TMinOp::verify() { + return verifyArithmeticBinaryTileOpWithArchDispatch( + getOperation(), getSrc0().getType(), getSrc1().getType(), getDst().getType(), + /*allowInt8OnA5=*/true, /*allowBf16OnA5=*/true, + "expects A2/A3 tmin element type to be i32/i16/f16/f32", + "expects A5 tmin element type to be i32/i16/i8/f16/bf16/f32"); +} + +mlir::LogicalResult mlir::pto::TMinSOp::verify() { + return verifyArithmeticScalarTileOpWithArchDispatch( + getOperation(), getSrc().getType(), getDst().getType(), getScalar().getType(), + /*allowInt8OnA5=*/true, /*allowBf16OnA5=*/true, + "expects A2/A3 tmins element type to be i32/i16/f16/f32", + "expects A5 tmins element type to be i32/i16/i8/f16/bf16/f32", + /*requireValidRowsEqualOnA2A3=*/true, + /*requireValidRowsEqualOnA5=*/false); +} + +mlir::LogicalResult mlir::pto::TMovOp::verify() { + auto verifyImpl = [&](bool isA5) -> LogicalResult { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + Value fp = getFp(); + Value preQuantScalar = getPreQuantScalar(); + auto accToVecModeAttr = getAccToVecModeAttr(); + auto reluMode = getReluPreMode(); + const bool hasFp = static_cast(fp); + const bool hasPreQuantScalar = static_cast(preQuantScalar); + + if (failed(verifyTileBufCommon(*this, srcTy, "src")) || + failed(verifyTileBufCommon(*this, dstTy, "dst"))) + return failure(); + if (hasFp && failed(verifyTileBufCommon(*this, fp.getType(), "fp"))) + return failure(); + if (hasFp && hasPreQuantScalar) + return emitOpError() << "expects fp and preQuantScalar forms to be mutually exclusive"; + + auto srcSpace = 
getPTOMemorySpaceEnum(srcTy); + auto dstSpace = getPTOMemorySpaceEnum(dstTy); + if (!srcSpace || !dstSpace) + return emitOpError() << "expects src and dst to have explicit address spaces"; + + auto srcShape = getShapeVec(srcTy); + auto dstShape = getShapeVec(dstTy); + if (*srcSpace == pto::AddressSpace::MAT && srcShape != dstShape) + return emitOpError() << "expects mat-source tmov to use matching src/dst shapes"; + if (!isA5 && *srcSpace != pto::AddressSpace::MAT && srcShape != dstShape) + return emitOpError() << "expects A2/A3 non-mat tmov to use matching src/dst shapes"; + + const bool isMatToTile = + *srcSpace == pto::AddressSpace::MAT && + (*dstSpace == pto::AddressSpace::LEFT || + *dstSpace == pto::AddressSpace::RIGHT || + *dstSpace == pto::AddressSpace::BIAS || + *dstSpace == pto::AddressSpace::SCALING); + const bool isVecToVec = + *srcSpace == pto::AddressSpace::VEC && + *dstSpace == pto::AddressSpace::VEC; + const bool isVecToMat = + *srcSpace == pto::AddressSpace::VEC && + *dstSpace == pto::AddressSpace::MAT; + const bool isAccToMat = + *srcSpace == pto::AddressSpace::ACC && + *dstSpace == pto::AddressSpace::MAT; + const bool isAccToVec = + *srcSpace == pto::AddressSpace::ACC && + *dstSpace == pto::AddressSpace::VEC; + + bool okPair = isMatToTile || isVecToVec || isAccToMat || isAccToVec; + if (isA5) + okPair = okPair || isVecToMat; + if (!okPair) + return emitOpError() + << "expects a supported tmov address-space pair for this target"; + + if (accToVecModeAttr && !isAccToVec) + return emitOpError() + << "expects accToVecMode to be used only for acc-to-vec tmov"; + + if (reluMode != pto::ReluPreMode::NoRelu && !(isAccToMat || isAccToVec)) + return emitOpError() + << "expects reluPreMode form to use loc=acc src"; + + if (hasPreQuantScalar && !(isAccToMat || isAccToVec)) + return emitOpError() + << "expects preQuantScalar form to use loc=acc src"; + + if (hasFp) { + auto fpTy = fp.getType(); + auto fpSpace = getPTOMemorySpaceEnum(fpTy); + if (!fpSpace || 
*fpSpace != pto::AddressSpace::SCALING) + return emitOpError() << "expects fp to be in the scaling address space"; + auto srcElemTy = getElemTy(srcTy); + auto srcIntTy = dyn_cast(srcElemTy); + if (!(srcElemTy.isF32() || (srcIntTy && srcIntTy.getWidth() == 32))) + return emitOpError() + << "expects fp form src to have element type f32, i32"; + if (!(isAccToMat || isAccToVec)) + return emitOpError() << "expects fp form to use loc=acc src"; + } + + if ((hasFp || hasPreQuantScalar) && accToVecModeAttr) { + switch (accToVecModeAttr.getValue()) { + case pto::AccToVecMode::SingleModeVec0: + case pto::AccToVecMode::SingleModeVec1: + break; + case pto::AccToVecMode::DualModeSplitM: + case pto::AccToVecMode::DualModeSplitN: + return emitOpError() + << "expects fp/preQuantScalar acc-to-vec forms to use single-mode accToVecMode"; + } + } + + auto srcTb = dyn_cast(srcTy); + auto dstTb = dyn_cast(dstTy); + if (srcTb && *srcSpace == pto::AddressSpace::ACC && + (hasFp || reluMode != pto::ReluPreMode::NoRelu)) { + if (srcTb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor) || + srcTb.getSLayoutValueI32() != static_cast(pto::SLayout::RowMajor)) + return emitOpError() + << "expects acc-source fp/relu tmov src to use blayout=col_major and slayout=row_major"; + } + if (srcTb && dstTb && isAccToMat && !isA5 && + dstTb.getSFractalSizeI32() != 512) + return emitOpError() << "expects A2/A3 acc-to-mat tmov destination fractal to be 512"; + + return success(); + }; + auto verifyA2A3 = [&]() -> LogicalResult { return verifyImpl(/*isA5=*/false); }; + auto verifyA5 = [&]() -> LogicalResult { return verifyImpl(/*isA5=*/true); }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TMovFPOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type fpTy = getFp().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyTileBufCommon(*this, srcTy, "src")) || + 
failed(verifyTileBufCommon(*this, fpTy, "fp")) || + failed(verifyTileBufCommon(*this, dstTy, "dst"))) + return failure(); + auto srcElemTy = getElemTy(srcTy); + auto srcIntTy = dyn_cast(srcElemTy); + if (!(srcElemTy.isF32() || + (srcIntTy && srcIntTy.getWidth() == 32))) + return emitOpError() + << "expects src to have element type f32, i32"; + auto fpSpace = getPTOMemorySpaceEnum(fpTy); + if (!fpSpace || *fpSpace != mlir::pto::AddressSpace::SCALING) + return emitOpError() << "expects fp to be in the scaling address space"; + auto srcSpace = getPTOMemorySpaceEnum(srcTy); + if (!srcSpace || *srcSpace != mlir::pto::AddressSpace::ACC) + return emitOpError() << "expects src to be in the acc address space"; + auto dstSpace = getPTOMemorySpaceEnum(dstTy); + if (!dstSpace || *dstSpace != mlir::pto::AddressSpace::MAT) + return emitOpError() << "expects dst to be in the mat address space"; + auto srcTb = dyn_cast(srcTy); + auto dstTb = dyn_cast(dstTy); + if (srcTb && + (srcTb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor) || + srcTb.getSLayoutValueI32() != static_cast(pto::SLayout::RowMajor))) + return emitOpError() + << "expects src to use blayout=col_major and slayout=row_major"; + if (dstTb && + (dstTb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor) || + dstTb.getSLayoutValueI32() != static_cast(pto::SLayout::RowMajor))) + return emitOpError() + << "expects dst to use blayout=col_major and slayout=row_major"; + if (dstTb && dstTb.getSFractalSizeI32() != 512) + return emitOpError() << "expects dst to use fractal size 512"; + return mlir::success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type fpTy = getFp().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyTileBufCommon(*this, srcTy, "src")) || + failed(verifyTileBufCommon(*this, fpTy, "fp")) || + failed(verifyTileBufCommon(*this, dstTy, "dst"))) + return failure(); + auto srcElemTy = getElemTy(srcTy); + auto srcIntTy = 
dyn_cast(srcElemTy); + if (!(srcElemTy.isF32() || + (srcIntTy && srcIntTy.getWidth() == 32))) + return emitOpError() + << "expects src to have element type f32, i32"; + auto fpSpace = getPTOMemorySpaceEnum(fpTy); + if (!fpSpace || *fpSpace != mlir::pto::AddressSpace::SCALING) + return emitOpError() << "expects fp to be in the scaling address space"; + auto srcTb = dyn_cast(srcTy); + if (srcTb && + (srcTb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor) || + srcTb.getSLayoutValueI32() != static_cast(pto::SLayout::RowMajor))) + return emitOpError() + << "expects src to use blayout=col_major and slayout=row_major"; + return success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} +// 辅助函数:获取 Rank,支持 ShapedType 和 PTO TileTypes +static int64_t getRankHelper(Type t) { + if (auto s = dyn_cast(t)) return s.getRank(); + if (auto tile = dyn_cast(t)) return tile.getRank(); + if (auto view = dyn_cast(t)) return view.getRank(); + return -1; +} + +static LogicalResult verifyMatmulLike(Operation *op, Type aTy, Type bTy, Type dstTy, bool checkRank = true) { + // 1. 
检查类型 (ShapedType 或 Tile 类型) + bool aValid = isa(aTy); + bool bValid = isa(bTy); + bool dValid = isa(dstTy); + + if (!aValid || !bValid || !dValid) + return op->emitOpError("expects inputs/outputs to be shaped types or PTO tile types"); + + if (checkRank) { + int64_t aRank = getRankHelper(aTy); + int64_t bRank = getRankHelper(bTy); + int64_t dRank = getRankHelper(dstTy); + + // 检查 Rank 一致性 + if (aRank != -1 && dRank != -1 && aRank != dRank) + return op->emitOpError("expects a and dst to have the same rank"); + if (bRank != -1 && dRank != -1 && bRank != dRank) + return op->emitOpError("expects b and dst to have the same rank"); + } + + return success(); +} + +// ---- LoadScalarOp ---- +LogicalResult LoadScalarOp::verify() { + Type ptrTy = getPtr().getType(); + Type elemTy; + if (auto pty = dyn_cast(ptrTy)) { + elemTy = pty.getElementType(); + } else if (auto memTy = dyn_cast(ptrTy)) { + elemTy = memTy.getElementType(); + if (!isGmAddressSpaceAttr(memTy.getMemorySpace())) + return emitOpError() << "scalar load only supports GM address space pointers"; + } else { + return emitOpError("expects ptr to be !pto.ptr or memref type"); + } + + if (getValue().getType() != elemTy) + return emitOpError("expects result type to match ptr element type"); + + return success(); +} +// ---- StoreScalarOp ---- +LogicalResult StoreScalarOp::verify() { + Type ptrTy = getPtr().getType(); + Type elemTy; + if (auto pty = dyn_cast(ptrTy)) { + elemTy = pty.getElementType(); + } else if (auto memTy = dyn_cast(ptrTy)) { + elemTy = memTy.getElementType(); + if (!isGmAddressSpaceAttr(memTy.getMemorySpace())) + return emitOpError() << "scalar store only supports GM address space pointers"; + } else { + return emitOpError("expects ptr to be !pto.ptr or memref type"); + } + + if (getValue().getType() != elemTy) + return emitOpError("expects value type to match ptr element type"); + + return success(); +} + +// ---- GetBufOp / RlsBufOp ---- +static LogicalResult verifyBufSyncOp(Operation *op, 
Attribute opTypeAttr, + IntegerAttr bufIdAttr, IntegerAttr modeAttr) { + if (!opTypeAttr) + return op->emitOpError("expects 'op_type' attribute"); + + auto opTypeOr = parseSyncOpTypeLikeAttr(opTypeAttr); + if (failed(opTypeOr)) { + auto diag = + op->emitOpError("expects 'op_type' to be pipe_event_type/sync_op_type, got "); + diag << opTypeAttr; + return failure(); + } + pto::PIPE pipe = mapSyncOpTypeToPipe(*opTypeOr); + if (!isConcreteSyncPipe(pipe)) + return op->emitOpError("expects 'op_type' to map to a concrete pipe, not PIPE_ALL/PIPE_UNASSIGNED"); + + if (!bufIdAttr) + return op->emitOpError("expects 'buf_id' attribute"); + int64_t bufId = bufIdAttr.getInt(); + if (bufId < 0 || bufId > 31) + return op->emitOpError("expects 'buf_id' in range [0, 31]"); + + if (modeAttr) { + int64_t mode = modeAttr.getInt(); + if (mode < 0) + return op->emitOpError("expects 'mode' to be non-negative"); + } + + return success(); +} + +LogicalResult GetBufOp::verify() { + return verifyBufSyncOp(getOperation(), getOpTypeAttr(), getBufIdAttr(), + getModeAttr()); +} + +LogicalResult RlsBufOp::verify() { + return verifyBufSyncOp(getOperation(), getOpTypeAttr(), getBufIdAttr(), + getModeAttr()); +} +// ---- TOp ---- +LogicalResult TGemvBiasOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + if (failed(verifyGemvTileOperands(*this, getA().getType(), getB().getType(), + getDst().getType())) || + failed(verifyMatBiasTile(*this, getBias().getType(), getDst().getType()))) + return failure(); + if (failed(verifyMatmulTypeTriple(*this, getElemTy(getA().getType()), + getElemTy(getB().getType()), + getElemTy(getDst().getType())))) + return failure(); + return verifyMatmulLike(*this, getA().getType(), getB().getType(), + getDst().getType()); + }; + auto verifyA5 = [&]() -> LogicalResult { return verifyA2A3(); }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +LogicalResult TGemvMxOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + return 
emitOpError("tgemv.mx is only supported on A5 targets"); + }; + auto verifyA5 = [&]() -> LogicalResult { + if (failed(verifyScaleTileMatchesOperand(*this, getAScale().getType(), + getA().getType(), "a_scale", "a")) || + failed(verifyScaleTileMatchesOperand(*this, getBScale().getType(), + getB().getType(), "b_scale", "b")) || + failed(verifyGemvTileOperands(*this, getA().getType(), getB().getType(), + getDst().getType()))) + return failure(); + if (failed(verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), + getDst().getType(), "a", "b", "dst"))) + return failure(); + return verifyMatmulLike(*this, getA().getType(), getB().getType(), + getDst().getType()); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +LogicalResult TGemvMxAccOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + return emitOpError("tgemv.mx.acc is only supported on A5 targets"); + }; + auto verifyA5 = [&]() -> LogicalResult { + if (failed(verifyAccTileCommon(*this, getCIn().getType(), "c_in")) || + failed(verifyScaleTileMatchesOperand(*this, getAScale().getType(), + getA().getType(), "a_scale", "a")) || + failed(verifyScaleTileMatchesOperand(*this, getBScale().getType(), + getB().getType(), "b_scale", "b")) || + failed(verifyGemvTileOperands(*this, getA().getType(), getB().getType(), + getDst().getType()))) + return failure(); + if (failed(verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), + getDst().getType(), "a", "b", "dst"))) + return failure(); + if (failed(verifyTileBufSameElemType(*this, getCIn().getType(), + getDst().getType(), "c_in", "dst")) || + failed(verifyTileBufSameValidShape(*this, getCIn().getType(), + getDst().getType(), "c_in", "dst"))) + return failure(); + return verifyMatmulLike(*this, getA().getType(), getB().getType(), + getDst().getType()); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +LogicalResult TGemvMxBiasOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + 
return emitOpError("tgemv.mx.bias is only supported on A5 targets"); + }; + auto verifyA5 = [&]() -> LogicalResult { + if (failed(verifyScaleTileMatchesOperand(*this, getAScale().getType(), + getA().getType(), "a_scale", "a")) || + failed(verifyScaleTileMatchesOperand(*this, getBScale().getType(), + getB().getType(), "b_scale", "b")) || + failed(verifyGemvTileOperands(*this, getA().getType(), getB().getType(), + getDst().getType())) || + failed(verifyMatBiasTile(*this, getBias().getType(), getDst().getType(), + /*requireFloatBias=*/true))) + return failure(); + if (failed(verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), + getDst().getType(), "a", "b", "dst"))) + return failure(); + auto biasShape = getShapeVec(getBias().getType()); + auto dstShape = getShapeVec(getDst().getType()); + if (biasShape.size() != 2 || dstShape.size() != 2) + return emitOpError("expects bias and dst to be rank-2 for tgemv.mx.bias"); + if (biasShape[1] != ShapedType::kDynamic && dstShape[1] != ShapedType::kDynamic && + biasShape[1] != dstShape[1]) + return emitOpError("expects bias and dst to have the same column shape"); + if (failed(verifyTileBufSameValidShape(*this, getBias().getType(), + getDst().getType(), "bias", "dst"))) + return failure(); + return verifyMatmulLike(*this, getA().getType(), getB().getType(), + getDst().getType()); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +LogicalResult TMatmulBiasOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + if (failed(verifyMatTileOperands(*this, getA().getType(), getB().getType(), + getDst().getType())) || + failed(verifyMatBiasTile(*this, getBias().getType(), getDst().getType()))) + return failure(); + if (failed(verifyMatmulTypeTriple(*this, getElemTy(getA().getType()), + getElemTy(getB().getType()), + getElemTy(getDst().getType())))) + return failure(); + return verifyMatmulLike(*this, getA().getType(), getB().getType(), + getDst().getType()); + }; + auto verifyA5 = [&]() -> 
LogicalResult { return verifyA2A3(); }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +LogicalResult TMatmulMxOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + if (failed(verifyTileBufCommon(*this, getAScale().getType(), "a_scale")) || + failed(verifyTileBufCommon(*this, getBScale().getType(), "b_scale"))) + return failure(); + return verifyMatmulLike(*this, getA().getType(), getB().getType(), + getDst().getType()); + }; + auto verifyA5 = [&]() -> LogicalResult { + if (failed(verifyA2A3())) + return failure(); + return verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), + getDst().getType(), "a", "b", "dst"); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +LogicalResult TMatmulMxAccOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + if (failed(verifyAccTileCommon(*this, getCIn().getType(), "c_in")) || + failed(verifyTileBufCommon(*this, getAScale().getType(), "a_scale")) || + failed(verifyTileBufCommon(*this, getBScale().getType(), "b_scale"))) + return failure(); + return success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + if (failed(verifyA2A3())) + return failure(); + if (failed(verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), + getDst().getType(), "a", "b", "dst"))) + return failure(); + if (failed(verifyTileBufSameElemType(*this, getCIn().getType(), + getDst().getType(), "c_in", "dst")) || + failed(verifyTileBufSameValidShape(*this, getCIn().getType(), + getDst().getType(), "c_in", "dst"))) + return failure(); + return success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} +LogicalResult TMatmulMxBiasOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + if (failed(verifyTileBufCommon(*this, getAScale().getType(), "a_scale")) || + failed(verifyTileBufCommon(*this, getBScale().getType(), "b_scale")) || + failed(verifyMatTileOperands(*this, getA().getType(), getB().getType(), + getDst().getType())) || 
+ failed(verifyMatBiasTile(*this, getBias().getType(), getDst().getType(), + /*requireFloatBias=*/true))) + return failure(); + return verifyMatmulLike(*this, getA().getType(), getB().getType(), + getDst().getType()); + }; + auto verifyA5 = [&]() -> LogicalResult { + if (failed(verifyA2A3())) + return failure(); + return verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), + getDst().getType(), "a", "b", "dst"); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} +// ---- TSetValOp ---- +LogicalResult TSetValOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + // dst can be tile/tensor/tilebuf (PTODpsType). Keep checks minimal. + if (auto shaped = dyn_cast(getDst().getType())) { + if (shaped.getElementType() != getVal().getType()) + return emitOpError("expects val type to match dst element type"); + } + return success(); +} +// ---- TGetValOp ---- +LogicalResult TGetValOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + Type srcTy = getSrc().getType(); + if (!mlir::isa(srcTy)) + return emitOpError("expects src to be tile_buf or memref type"); + + // Memory space must be vec (Ascend does not support getval from MAT etc.). + Attribute memSpace = + isa(srcTy) + ? 
cast(srcTy).getMemorySpace() + : cast(srcTy).getMemorySpace(); + auto addrSpaceAttr = dyn_cast_or_null(memSpace); + if (!addrSpaceAttr || + addrSpaceAttr.getAddressSpace() != pto::AddressSpace::VEC) { + if (addrSpaceAttr && + addrSpaceAttr.getAddressSpace() == pto::AddressSpace::MAT) + return emitOpError( + "Ascend hardware does not support reading from Mat tile_buf to Scalar unit"); + return emitOpError("expects src memory space to be vec"); + } + + if (getElemTy(srcTy) != getDst().getType()) + return emitOpError("expects dst type to match src element type"); + return success(); +} + +LogicalResult THistogramOp::verify() { + auto isIntegerWidth = [](Type ty, unsigned width) { + auto it = dyn_cast(ty); + return it && it.getWidth() == width; + }; + + auto verifyA2A3 = [&]() -> LogicalResult { + return emitOpError("thistogram is only supported on A5"); + }; + auto verifyA5 = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type idxTy = getIdx().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyTileBufCommon(*this, srcTy, "src")) || + failed(verifyTileBufCommon(*this, idxTy, "idx")) || + failed(verifyTileBufCommon(*this, dstTy, "dst"))) + return failure(); + + auto srcSpace = getPTOMemorySpaceEnum(srcTy); + auto idxSpace = getPTOMemorySpaceEnum(idxTy); + auto dstSpace = getPTOMemorySpaceEnum(dstTy); + if (!srcSpace || *srcSpace != pto::AddressSpace::VEC) + return emitOpError("expects src to be in the vec address space"); + if (!idxSpace || *idxSpace != pto::AddressSpace::VEC) + return emitOpError("expects idx to be in the vec address space"); + if (!dstSpace || *dstSpace != pto::AddressSpace::VEC) + return emitOpError("expects dst to be in the vec address space"); + + auto srcTB = dyn_cast(srcTy); + auto idxTB = dyn_cast(idxTy); + auto dstTB = dyn_cast(dstTy); + if (!srcTB || !idxTB || !dstTB) + return emitOpError("expects src, idx, and dst to be tile_buf types"); + + if (srcTB.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor) || 
+ srcTB.getSLayoutValueI32() != static_cast(pto::SLayout::NoneBox)) + return emitOpError("expects src to use row_major + none_box layout"); + if (dstTB.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor) || + dstTB.getSLayoutValueI32() != static_cast(pto::SLayout::NoneBox)) + return emitOpError("expects dst to use row_major + none_box layout"); + if (idxTB.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor) || + idxTB.getSLayoutValueI32() != static_cast(pto::SLayout::NoneBox)) + return emitOpError( + "expects idx to use DN layout (col_major + none_box)"); + + if (!isIntegerWidth(getElemTy(srcTy), 16)) + return emitOpError("expects src element type to be ui16"); + if (!isIntegerWidth(getElemTy(idxTy), 8)) + return emitOpError("expects idx element type to be ui8"); + if (!isIntegerWidth(getElemTy(dstTy), 32)) + return emitOpError("expects dst element type to be ui32"); + + auto srcShape = getShapeVec(srcTy); + auto idxShape = getShapeVec(idxTy); + auto dstShape = getShapeVec(dstTy); + auto srcValid = getValidShapeVec(srcTy); + auto idxValid = getValidShapeVec(idxTy); + auto dstValid = getValidShapeVec(dstTy); + if (srcShape.size() != 2 || idxShape.size() != 2 || dstShape.size() != 2 || + srcValid.size() != 2 || idxValid.size() != 2 || dstValid.size() != 2) + return emitOpError( + "expects src, idx, and dst to have rank-2 shape and valid_shape"); + + if (!hasCompatibleKnownExtent(srcShape[0], idxShape[0]) || + !hasCompatibleKnownExtent(srcValid[0], idxValid[0])) + return emitOpError("expects idx rows and valid rows to match src"); + if (!hasCompatibleKnownExtent(srcShape[0], dstShape[0]) || + !hasCompatibleKnownExtent(srcValid[0], dstValid[0])) + return emitOpError("expects dst rows and valid rows to match src"); + + if (!isKnownUnitExtent(idxShape[1]) || !isKnownUnitExtent(idxValid[1])) + return emitOpError("expects idx to have exactly one column"); + if (dstShape[1] != ShapedType::kDynamic && dstShape[1] < 256) + return emitOpError("expects dst 
shape[1] to be at least 256"); + if (dstValid[1] != ShapedType::kDynamic && dstValid[1] < 256) + return emitOpError("expects dst valid_shape[1] to be at least 256"); + return success(); + }; + + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +LogicalResult TGetScaleAddrOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + return emitOpError("tget_scale_addr is only supported on A5"); + }; + auto verifyA5 = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyTileBufCommon(*this, srcTy, "src"))) + return failure(); + if (failed(verifyScaleTileMatchesOperand(*this, dstTy, srcTy, "dst", "src"))) + return failure(); + return success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +// ---- MScatterOp ---- +LogicalResult MScatterOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + + if (!isTargetArchA5(getOperation())) + return emitOpError("pto.mscatter is only supported on A5 targets"); + + Type srcTy = getSrc().getType(); + Type idxTy = getIdx().getType(); + Type memTy = getMem().getType(); + + if (getPTOTypeRank(srcTy) == -1 || getPTOTypeRank(idxTy) == -1 || + getPTOTypeRank(memTy) == -1) + return emitOpError("expects src, idx, and mem to use supported PTO shapes"); + + if (failed(verifyNDStyleVecTile(*this, srcTy, "src")) || + failed(verifyMGatherMScatterIdxTile(getOperation(), idxTy, "idx"))) + return failure(); + + Type srcElem = getElemTy(srcTy); + Type idxElem = getElemTy(idxTy); + if (!srcElem || !idxElem) + return emitOpError("failed to resolve element types for src or idx"); + + if (!isSupportedMGatherMScatterPayloadElemType(getOperation(), srcElem)) + return emitOpError( + "expects src element type to be i8/ui8/i16/ui16/i32/ui32/f16/bf16/f32 " + "(and on A5 targets also float8_e4m3/float8_e5m2 family types)"); + + if (!isSupportedMGatherMScatterIndexElemType(idxElem)) + return 
emitOpError("expects idx element type to be signless i32"); + + if (failed(verifyMGatherMScatterMemOperand(getOperation(), getMem(), srcElem, + "src"))) + return failure(); + + if (getScatterAtomicOp() != pto::ScatterAtomicOp::None || + getScatterOob() != pto::ScatterOOB::Undefined) { + if (!isTargetArchA5(getOperation())) + return emitOpError( + "expects non-default scatterAtomicOp/scatterOob only on A5 targets"); + } + + if (!isSupportedMScatterAtomicPayloadElemType(srcElem, getScatterAtomicOp())) + return emitOpError( + "expects scatterAtomicOp-compatible src element type: add supports " + "i32/ui32/f16/f32, max/min support signless i32/f32"); + + if (failed(verifyMGatherMScatterTileShape(getOperation(), srcTy, idxTy, "src"))) + return failure(); + + return success(); +} + +// ---- MGatherOp ---- +LogicalResult MGatherOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + + if (!isTargetArchA5(getOperation())) + return emitOpError("pto.mgather is only supported on A5 targets"); + + Type memTy = getMem().getType(); + Type idxTy = getIdx().getType(); + Type dstTy = getDst().getType(); + + if (getPTOTypeRank(memTy) == -1 || getPTOTypeRank(idxTy) == -1 || + getPTOTypeRank(dstTy) == -1) + return emitOpError("expects mem, idx, and dst to use supported PTO shapes"); + + if (failed(verifyNDStyleVecTile(*this, dstTy, "dst")) || + failed(verifyMGatherMScatterIdxTile(getOperation(), idxTy, "idx"))) + return failure(); + + Type dstElem = getElemTy(dstTy); + Type idxElem = getElemTy(idxTy); + if (!dstElem || !idxElem) + return emitOpError("failed to resolve element types for dst or idx"); + + if (!isSupportedMGatherMScatterPayloadElemType(getOperation(), dstElem)) + return emitOpError( + "expects dst element type to be i8/ui8/i16/ui16/i32/ui32/f16/bf16/f32 " + "(and on A5 targets also float8_e4m3/float8_e5m2 family types)"); + + if (!isSupportedMGatherMScatterIndexElemType(idxElem)) + return emitOpError("expects idx element type to be 
signless i32"); + + if (failed(verifyMGatherMScatterMemOperand(getOperation(), getMem(), dstElem, + "dst"))) + return failure(); + + if (getGatherOob() != pto::GatherOOB::Undefined && + !isTargetArchA5(getOperation())) + return emitOpError( + "expects non-default gatherOob only on A5 targets"); + + if (failed(verifyMGatherMScatterTileShape(getOperation(), dstTy, idxTy, "dst"))) + return failure(); + + return success(); +} + +void mlir::pto::TCvtOp::print(OpAsmPrinter &p) { + p << " ins(" << getSrc(); + Builder builder(getContext()); + NamedAttrList attrs; + for (auto attr : (*this)->getAttrs()) { + if (attr.getName() == "sat_mode") { + attrs.set(builder.getStringAttr("satmode"), attr.getValue()); + continue; + } + attrs.set(attr.getName(), attr.getValue()); + } + p.printOptionalAttrDict(attrs.getAttrs()); + p << " : " << getSrc().getType(); + p << ") outs(" << getDst() << " : " << getDst().getType() << ")"; +} + +ParseResult mlir::pto::TCvtOp::parse(OpAsmParser &parser, OperationState &result) { + OpAsmParser::UnresolvedOperand src, dst; + Type srcTy, dstTy; + + if (parser.parseKeyword("ins") || parser.parseLParen() || parser.parseOperand(src)) + return failure(); + NamedAttrList attrs; + if (parser.parseOptionalAttrDict(attrs) || parser.parseColonType(srcTy)) + return failure(); + if (auto satmode = attrs.get("satmode")) { + attrs.erase("satmode"); + if (attrs.get("sat_mode")) + return parser.emitError(parser.getCurrentLocation(), + "cannot specify both satmode and sat_mode"); + attrs.set("sat_mode", satmode); + } + result.attributes = attrs; + if (parser.parseRParen() || parser.parseKeyword("outs") || parser.parseLParen() || + parser.parseOperand(dst) || parser.parseColonType(dstTy) || parser.parseRParen()) + return failure(); + + if (parser.resolveOperand(src, srcTy, result.operands) || + parser.resolveOperand(dst, dstTy, result.operands)) + return failure(); + return success(); +} + +void mlir::pto::TMrgSortOp::print(OpAsmPrinter &p) { + if (isFormat1()) { + p 
<< " ins(" << getSrc() << ", " << getBlockLen() << " : " << getSrc().getType() + << ", " << getBlockLen().getType() << ") outs(" << getDst() << " : " + << getDst().getType() << ")"; + } else if (isFormat2()) { + p << " ins("; + llvm::interleaveComma(getSrcs(), p, [&](Value src) { p << src; }); + p << ", " << getTmp(); + p << " {exhausted = " << (getExhausted() ? "true" : "false") << "} : "; + llvm::interleaveComma(getSrcs().getTypes(), p, [&](Type ty) { p << ty; }); + p << ", " << getTmp().getType(); + p << ") outs(" << getDst() << ", " << getExcuted() + << " : " << getDst().getType() << ", " << getExcuted().getType() << ")"; + } else { + llvm::report_fatal_error("TMrgSortOp print expects format1 or format2"); + } + p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"operandSegmentSizes", "exhausted"}); +} + +ParseResult mlir::pto::TMrgSortOp::parse(OpAsmParser &parser, OperationState &result) { + if (parser.parseKeyword("ins") || parser.parseLParen()) + return failure(); + OpAsmParser::UnresolvedOperand first, second; + if (parser.parseOperand(first) || parser.parseComma() || parser.parseOperand(second)) + return failure(); + + if (parser.parseOptionalColon().succeeded()) { + Type srcTy, blockLenTy, dstTy; + if (parser.parseType(srcTy) || parser.parseComma() || parser.parseType(blockLenTy) || + parser.parseRParen() || parser.parseKeyword("outs") || parser.parseLParen()) + return failure(); + OpAsmParser::UnresolvedOperand dstOp; + if (parser.parseOperand(dstOp) || parser.parseColon() || parser.parseType(dstTy) || + parser.parseRParen()) + return failure(); + result.addAttribute("operandSegmentSizes", + parser.getBuilder().getDenseI32ArrayAttr({1, 1, 1, 0, 0})); + if (parser.resolveOperand(first, srcTy, result.operands) || + parser.resolveOperand(second, blockLenTy, result.operands) || + parser.resolveOperand(dstOp, dstTy, result.operands)) + return failure(); + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + if 
(!result.attributes.get("exhausted")) + result.addAttribute("exhausted", parser.getBuilder().getBoolAttr(false)); + return success(); + } + + SmallVector srcs = {first, second}; + while (parser.parseOptionalComma().succeeded()) { + OpAsmParser::UnresolvedOperand next; + if (parser.parseOperand(next)) + return failure(); + srcs.push_back(next); + } + if (srcs.size() < 3 || srcs.size() > 5) + return parser.emitError(parser.getCurrentLocation(), + "tmrgsort format2 expects 2 to 4 src operands plus one tmp operand"); + OpAsmParser::UnresolvedOperand tmpOp = srcs.pop_back_val(); + bool exhaustedVal = false; + if (parser.parseOptionalLBrace().succeeded()) { + if (parser.parseKeyword("exhausted") || parser.parseEqual()) + return failure(); + StringRef kw; + if (parser.parseKeyword(&kw) || parser.parseRBrace()) + return failure(); + exhaustedVal = (kw == "true"); + } + SmallVector srcTypes; + srcTypes.reserve(srcs.size()); + if (parser.parseColon()) + return failure(); + Type firstSrcTy; + if (parser.parseType(firstSrcTy)) + return failure(); + srcTypes.push_back(firstSrcTy); + while (parser.parseOptionalComma().succeeded()) { + Type nextTy; + if (parser.parseType(nextTy)) + return failure(); + srcTypes.push_back(nextTy); + } + if (srcTypes.size() != srcs.size() + 1 || parser.parseRParen() || + parser.parseKeyword("outs") || parser.parseLParen()) + return failure(); + Type tmpTy = srcTypes.pop_back_val(); + OpAsmParser::UnresolvedOperand dstOp, excutedOp; + Type dstTy, excutedTy; + if (parser.parseOperand(dstOp) || parser.parseComma() || parser.parseOperand(excutedOp) || + parser.parseColon() || parser.parseType(dstTy) || parser.parseComma() || + parser.parseType(excutedTy) || parser.parseRParen()) + return failure(); + result.addAttribute("operandSegmentSizes", + parser.getBuilder().getDenseI32ArrayAttr( + {static_cast(srcs.size()), 0, 1, 1, 1})); + if (parser.resolveOperands(srcs, srcTypes, parser.getCurrentLocation(), result.operands) || + parser.resolveOperand(dstOp, 
dstTy, result.operands) || + parser.resolveOperand(tmpOp, tmpTy, result.operands) || + parser.resolveOperand(excutedOp, excutedTy, result.operands)) + return failure(); + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + if (!result.attributes.get("exhausted")) + result.addAttribute("exhausted", parser.getBuilder().getBoolAttr(exhaustedVal)); + return success(); +} + +mlir::LogicalResult mlir::pto::TMrgSortOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + if (isFormat1()) { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (!isPTOShapedLike(srcTy) || !isPTOShapedLike(dstTy)) + return emitOpError() << "format1 expects PTO shaped-like types for src/dst"; + if (getElemTy(srcTy) != getElemTy(dstTy)) + return emitOpError() << "expects src/dst to have the same element type"; + if (!getElemTy(srcTy).isF16() && !getElemTy(srcTy).isF32()) + return emitOpError() << "expects element type to be f16 or f32"; + auto ss = getShapeVec(srcTy); + auto ds = getShapeVec(dstTy); + if (ss.size() != 2 || ds.size() != 2) + return emitOpError() << "expects src/dst to be rank-2 tile-shaped"; + if (ss[0] != mlir::ShapedType::kDynamic && ss[0] != 1) + return emitOpError() << "expects src rows == 1"; + if (ds[0] != mlir::ShapedType::kDynamic && ds[0] != 1) + return emitOpError() << "expects dst rows == 1"; + if (ss[1] != mlir::ShapedType::kDynamic && ds[1] != mlir::ShapedType::kDynamic && ss[1] != ds[1]) + return emitOpError() << "expects src/dst cols to match"; + if (getBlockLen()) { + if (auto cstOp = getBlockLen().getDefiningOp()) { + if (auto intAttr = mlir::dyn_cast(cstOp.getValue())) { + int64_t v = intAttr.getValue().getSExtValue(); + if (v <= 0 || (v % 64) != 0) + return emitOpError() << "expects blockLen > 0 and multiple of 64"; + } + } + } + return mlir::success(); + } + if (isFormat2()) { + for (Value v : getSrcs()) + if (!isPTOShapedLike(v.getType())) + return emitOpError() << "format2 
expects PTO shaped-like type for each src"; + if (getSrcs().size() < 2u || getSrcs().size() > 4u) + return emitOpError() << "format2 expects 2 to 4 srcs"; + if (getDsts().size() != 1u || !getTmp() || !getExcuted()) + return emitOpError() << "format2 expects ins(srcs..., tmp), outs(dst), and excuted=vector"; + Type dstTy = getDst().getType(); + Type tmpTy = getTmp().getType(); + if (!isPTOShapedLike(dstTy) || !isPTOShapedLike(tmpTy)) + return emitOpError() << "format2 dst/tmp must be PTO shaped-like"; + auto excutedTy = mlir::dyn_cast(getExcuted().getType()); + if (!excutedTy || excutedTy.getRank() != 1 || excutedTy.getNumElements() != 4 || + !excutedTy.getElementType().isInteger(16)) + return emitOpError() << "format2 excuted must be vector<4xi16>"; + Type elemTy = getElemTy(dstTy); + if (elemTy != getElemTy(tmpTy)) + return emitOpError() << "format2 expects dst/tmp element types to match"; + auto dstShape = getShapeVec(dstTy); + auto tmpShape = getShapeVec(tmpTy); + if (dstShape.size() != 2 || tmpShape.size() != 2) + return emitOpError() << "format2 expects dst/tmp to be rank-2 tile-shaped"; + if ((dstShape[0] != mlir::ShapedType::kDynamic && dstShape[0] != 1) || + (tmpShape[0] != mlir::ShapedType::kDynamic && tmpShape[0] != 1)) + return emitOpError() << "format2 expects dst/tmp rows == 1"; + if (dstShape[1] != mlir::ShapedType::kDynamic && + tmpShape[1] != mlir::ShapedType::kDynamic && + tmpShape[1] < dstShape[1]) + return emitOpError() << "format2 expects tmp.cols >= dst.cols"; + for (Value src : getSrcs()) { + Type srcTy = src.getType(); + auto srcShape = getShapeVec(srcTy); + if (srcShape.size() != 2) + return emitOpError() << "format2 expects src to be rank-2 tile-shaped"; + if (srcShape[0] != mlir::ShapedType::kDynamic && srcShape[0] != 1) + return emitOpError() << "format2 expects src rows == 1"; + if (getElemTy(srcTy) != elemTy) + return emitOpError() << "format2 expects src/dst/tmp element types to match"; + } + return mlir::success(); + } + return 
emitOpError() << "tmrgsort expects format1 (1 src + blockLen + 1 dst) or " + "format2 (2 to 4 srcs + tmp, outs dst, excuted)"; +} + +mlir::LogicalResult mlir::pto::TMulOp::verify() { + return verifyArithmeticBinaryTileOpWithArchDispatch( + getOperation(), getSrc0().getType(), getSrc1().getType(), getDst().getType(), + /*allowInt8OnA5=*/false, /*allowBf16OnA5=*/false, + "expects A2/A3 tmul element type to be i32/i16/f16/f32", + "expects A5 tmul element type to be i32/i16/f16/f32"); +} + +mlir::LogicalResult mlir::pto::TMulSOp::verify() { + return verifyArithmeticScalarTileOpWithArchDispatch( + getOperation(), getSrc0().getType(), getDst().getType(), + getScalar().getType(), /*allowInt8OnA5=*/true, /*allowBf16OnA5=*/true, + "expects A2/A3 tmuls element type to be i32/i16/f16/f32", + "expects A5 tmuls element type to be i32/i16/i8/f16/bf16/f32", + /*requireValidRowsEqualOnA2A3=*/true, + /*requireValidRowsEqualOnA5=*/false); +} + +mlir::LogicalResult mlir::pto::TShlSOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyTileBufCommon(*this, srcTy, "src")) || + failed(verifyTileBufCommon(*this, dstTy, "dst"))) + return failure(); + + Type srcElem = getElemTy(srcTy); + Type dstElem = getElemTy(dstTy); + if (!srcElem || !dstElem) + return emitOpError() << "failed to get element type for src/dst"; + if (srcElem != dstElem) + return emitOpError() << "expects src and dst to have the same element type"; + if (!mlir::isa(srcElem)) + return emitOpError() << "expects integral element types"; + if (auto scalarValue = getConstantIntegerValue(getScalar()); scalarValue && *scalarValue < 0) + return emitOpError("expects tshls scalar to be non-negative"); + return mlir::success(); +} + +mlir::LogicalResult mlir::pto::TShrSOp::verify() { + auto verifyCommon = [&]() -> FailureOr { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if 
(failed(verifyVecTileCommon(*this, srcTy, "src")) || + failed(verifyVecTileCommon(*this, dstTy, "dst"))) + return failure(); + if (failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) + return failure(); + + Type srcElem = getElemTy(srcTy); + Type dstElem = getElemTy(dstTy); + if (!srcElem || !dstElem) { + emitOpError("failed to get element type for src/dst"); + return failure(); + } + if (srcElem != dstElem) { + emitOpError("expects src and dst to have the same element type"); + return failure(); + } + return srcElem; + }; + + auto verifyA2A3 = [&]() -> LogicalResult { + FailureOr elemOr = verifyCommon(); + if (failed(elemOr)) + return failure(); + auto it = mlir::dyn_cast(*elemOr); + if (!it || (it.getWidth() != 16 && it.getWidth() != 32)) + return emitOpError( + "expects A2/A3 tshrs src and dst element type to be i16/i32"); + return success(); + }; + + auto verifyA5 = [&]() -> LogicalResult { + FailureOr elemOr = verifyCommon(); + if (failed(elemOr)) + return failure(); + auto it = mlir::dyn_cast(*elemOr); + if (!it || (it.getWidth() != 8 && it.getWidth() != 16 && + it.getWidth() != 32)) + return emitOpError( + "expects A5 tshrs src and dst element type to be i8/i16/i32"); + return success(); + }; + + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TNegOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyVecTileStorage(*this, srcTy, "src")) || + failed(verifyVecTileStorage(*this, dstTy, "dst"))) + return failure(); + if (failed(verifyTileBufSameElemType(*this, srcTy, dstTy, "src", "dst")) || + failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) + return failure(); + + Type elemTy = getElemTy(srcTy); + if (!(elemTy.isInteger(16) || elemTy.isInteger(32) || elemTy.isF16() || + elemTy.isF32())) + return emitOpError() + << "expects A2/A3 tneg element type to be 
i16/i32/f16/f32"; + return success(); + }; + + auto verifyA5 = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyVecTileStorage(*this, srcTy, "src")) || + failed(verifyVecTileStorage(*this, dstTy, "dst"))) + return failure(); + if (failed(verifyTileBufSameElemType(*this, srcTy, dstTy, "src", "dst"))) + return failure(); + + auto srcValid = getValidShapeVec(srcTy); + auto dstValid = getValidShapeVec(dstTy); + if (srcValid.size() != 2 || dstValid.size() != 2) + return emitOpError() << "expects src and dst to have rank-2 valid_shape"; + if (srcValid[1] != ShapedType::kDynamic && + dstValid[1] != ShapedType::kDynamic && + srcValid[1] != dstValid[1]) + return emitOpError() + << "expects src and dst to have the same valid_shape[1]"; + + Type elemTy = getElemTy(srcTy); + if (!(elemTy.isInteger(8) || elemTy.isInteger(16) || elemTy.isInteger(32) || + elemTy.isF16() || elemTy.isF32() || elemTy.isBF16())) + return emitOpError() + << "expects A5 tneg element type to be i8/i16/i32/f16/f32/bf16"; + return success(); + }; + + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TNotOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyVecTileCommon(*this, srcTy, "src")) || + failed(verifyVecTileCommon(*this, dstTy, "dst"))) + return failure(); + if (failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) + return failure(); + auto elemTy = getElemTy(srcTy); + if (elemTy != getElemTy(dstTy)) + return emitOpError() << "expects src and dst to have the same element type"; + if (!elemTy.isInteger(16)) + return emitOpError() << "expects A2/A3 tnot element type to be i16"; + return success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyVecTileCommon(*this, srcTy, "src")) 
|| + failed(verifyVecTileCommon(*this, dstTy, "dst"))) + return failure(); + if (failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) + return failure(); + auto elemTy = getElemTy(srcTy); + if (elemTy != getElemTy(dstTy)) + return emitOpError() << "expects src and dst to have the same element type"; + if (!(elemTy.isInteger(8) || elemTy.isInteger(16) || elemTy.isInteger(32))) + return emitOpError() << "expects A5 tnot element type to be i8/i16/i32"; + return success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TOrOp::verify() { + auto verifyCommon = [&]() -> FailureOr { + return verifyMatchingRowMajorBinaryTileOpCommon( + getOperation(), getSrc0().getType(), getSrc1().getType(), + getDst().getType()); + }; + + auto verifyA2A3 = [&]() -> LogicalResult { + FailureOr elemOr = verifyCommon(); + if (failed(elemOr)) + return failure(); + auto it = mlir::dyn_cast(*elemOr); + if (!it || (it.getWidth() != 8 && it.getWidth() != 16)) + return emitOpError( + "expects A2/A3 tor src0, src1, and dst element type to be i8/i16"); + return success(); + }; + + auto verifyA5 = [&]() -> LogicalResult { + FailureOr elemOr = verifyCommon(); + if (failed(elemOr)) + return failure(); + auto it = mlir::dyn_cast(*elemOr); + if (!it || (it.getWidth() != 8 && it.getWidth() != 16 && + it.getWidth() != 32)) + return emitOpError( + "expects A5 tor src0, src1, and dst element type to be i8/i16/i32"); + return success(); + }; + + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TOrSOp::verify() { + auto verifyCommon = [&]() -> FailureOr { + return verifyDistinctRowMajorUnaryTileOpCommon(getOperation(), getSrc(), + getDst(), "src", "dst"); + }; + + auto verifyA2A3 = [&]() -> LogicalResult { + FailureOr elemOr = verifyCommon(); + if (failed(elemOr)) + return failure(); + auto it = mlir::dyn_cast(*elemOr); + if (!it || (it.getWidth() != 8 && it.getWidth() 
!= 16)) + return emitOpError( + "expects A2/A3 tors src and dst element type to be i8/i16"); + return success(); + }; + + auto verifyA5 = [&]() -> LogicalResult { + FailureOr elemOr = verifyCommon(); + if (failed(elemOr)) + return failure(); + auto it = mlir::dyn_cast(*elemOr); + if (!it || (it.getWidth() != 8 && it.getWidth() != 16 && + it.getWidth() != 32)) + return emitOpError( + "expects A5 tors src and dst element type to be i8/i16/i32"); + return success(); + }; + + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +static FailureOr verifyPTOShapedBinarySameElemAndShape(Operation *op, + Type src0Ty, + Type src1Ty, + Type dstTy) { + if (!isPTOShapedLike(src0Ty) || !isPTOShapedLike(src1Ty) || + !isPTOShapedLike(dstTy)) + return op->emitOpError( + "expects src0/src1/dst to be memref/tensor/tile_buf/tile_view types"), + failure(); + Type e0 = getElemTy(src0Ty), e1 = getElemTy(src1Ty), ed = getElemTy(dstTy); + if (!e0 || !e1 || !ed) + return op->emitOpError("failed to get element type for operands"), failure(); + if (e0 != e1 || e0 != ed) + return op->emitOpError("expects src0/src1/dst to have the same element type"), + failure(); + auto s0 = getShapeVec(src0Ty), s1 = getShapeVec(src1Ty), sd = getShapeVec(dstTy); + if (s0 != s1 || s0 != sd) + return op->emitOpError("expects src0/src1/dst to have the same shape"), + failure(); + return e0; +} + +mlir::LogicalResult mlir::pto::TPartAddOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + Type src0Ty = getSrc0().getType(); + Type src1Ty = getSrc1().getType(); + Type dstTy = getDst().getType(); + if (!isPTOShapedLike(src0Ty) || !isPTOShapedLike(src1Ty) || + !isPTOShapedLike(dstTy)) + return emitOpError() << "expects PTO shaped-like src0/src1/dst"; + if (getElemTy(src0Ty) != getElemTy(src1Ty) || + getElemTy(src0Ty) != getElemTy(dstTy)) + return emitOpError() << "expects src0/src1/dst to have the same element type"; + auto s0 = getShapeVec(src0Ty); + auto s1 = getShapeVec(src1Ty); + auto 
d = getShapeVec(dstTy); + if (s0.size() != 2 || s1.size() != 2 || d.size() != 2) + return emitOpError() << "expects src0/src1/dst to be rank-2 (tile-shaped)"; + if (failed(verifyPartialValidPattern(*this, src0Ty, src1Ty, dstTy))) + return failure(); + Type elem = getElemTy(src0Ty); + if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isF16() || elem.isF32())) + return emitOpError("expects A2/A3 tpartadd element type to be i32/i16/f16/f32"); + return mlir::success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + Type src0Ty = getSrc0().getType(); + Type src1Ty = getSrc1().getType(); + Type dstTy = getDst().getType(); + if (!isPTOShapedLike(src0Ty) || !isPTOShapedLike(src1Ty) || + !isPTOShapedLike(dstTy)) + return emitOpError() << "expects PTO shaped-like src0/src1/dst"; + if (getElemTy(src0Ty) != getElemTy(src1Ty) || + getElemTy(src0Ty) != getElemTy(dstTy)) + return emitOpError() << "expects src0/src1/dst to have the same element type"; + Type elem = getElemTy(src0Ty); + if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isInteger(8) || + elem.isF16() || elem.isBF16() || elem.isF32())) + return emitOpError("expects A5 tpartadd element type to be i32/i16/i8/f16/bf16/f32"); + auto s0 = getShapeVec(src0Ty); + auto s1 = getShapeVec(src1Ty); + auto d = getShapeVec(dstTy); + if (s0.size() != 2 || s1.size() != 2 || d.size() != 2) + return emitOpError() << "expects src0/src1/dst to be rank-2 (tile-shaped)"; + return mlir::success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TPartMaxOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + Type t0 = getSrc0().getType(); + Type t1 = getSrc1().getType(); + Type td = getDst().getType(); + FailureOr elemOr = + verifyPTOShapedBinarySameElemAndShape(getOperation(), t0, t1, td); + if (failed(elemOr)) + return failure(); + if (failed(verifyPartialValidPattern(*this, t0, t1, td))) + return failure(); + Type e0 = *elemOr; + if (!(e0.isInteger(32) || 
e0.isInteger(16) || e0.isF16() || e0.isF32())) + return emitOpError("expects A2/A3 tpartmax element type to be i32/i16/f16/f32"); + return mlir::success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + Type t0 = getSrc0().getType(); + Type t1 = getSrc1().getType(); + Type td = getDst().getType(); + FailureOr elemOr = + verifyPTOShapedBinarySameElemAndShape(getOperation(), t0, t1, td); + if (failed(elemOr)) + return failure(); + Type e0 = *elemOr; + if (!(e0.isInteger(32) || e0.isInteger(16) || e0.isInteger(8) || + e0.isF16() || e0.isBF16() || e0.isF32())) + return emitOpError("expects A5 tpartmax element type to be i32/i16/i8/f16/bf16/f32"); + return mlir::success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TPartMinOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + Type t0 = getSrc0().getType(); + Type t1 = getSrc1().getType(); + Type td = getDst().getType(); + FailureOr elemOr = + verifyPTOShapedBinarySameElemAndShape(getOperation(), t0, t1, td); + if (failed(elemOr)) + return failure(); + if (failed(verifyPartialValidPattern(*this, t0, t1, td))) + return failure(); + Type e0 = *elemOr; + if (!(e0.isInteger(32) || e0.isInteger(16) || e0.isF16() || e0.isF32())) + return emitOpError("expects A2/A3 tpartmin element type to be i32/i16/f16/f32"); + return mlir::success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + Type t0 = getSrc0().getType(); + Type t1 = getSrc1().getType(); + Type td = getDst().getType(); + FailureOr elemOr = + verifyPTOShapedBinarySameElemAndShape(getOperation(), t0, t1, td); + if (failed(elemOr)) + return failure(); + Type e0 = *elemOr; + if (!(e0.isInteger(32) || e0.isInteger(16) || e0.isInteger(8) || + e0.isF16() || e0.isBF16() || e0.isF32())) + return emitOpError("expects A5 tpartmin element type to be i32/i16/i8/f16/bf16/f32"); + return mlir::success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +static 
LogicalResult verifyTPartArgOpCommon(Operation *op, Type src0Ty, + Type src1Ty, Type src0IdxTy, + Type src1IdxTy, Type dstTy, + Type dstIdxTy, StringRef opName) { + FailureOr dataElemOr = + verifyPTOShapedBinarySameElemAndShape(op, src0Ty, src1Ty, dstTy); + if (failed(dataElemOr)) + return failure(); + if (failed(verifyPartialValidPattern(op, src0Ty, src1Ty, dstTy))) + return failure(); + + if (!isPTOShapedLike(src0IdxTy) || !isPTOShapedLike(src1IdxTy) || + !isPTOShapedLike(dstIdxTy)) + return op->emitOpError("expects PTO shaped-like src0Idx/src1Idx/dstIdx"); + Type idxElem = getElemTy(src0IdxTy); + if (!idxElem || idxElem != getElemTy(src1IdxTy) || + idxElem != getElemTy(dstIdxTy)) + return op->emitOpError( + "expects src0Idx/src1Idx/dstIdx to have the same element type"); + auto idxInt = dyn_cast(idxElem); + if (!idxInt || idxInt.getWidth() != 32) + return op->emitOpError( + "expects src0Idx/src1Idx/dstIdx element type to be i32 or ui32"); + + auto dataShape = getShapeVec(src0Ty); + if (dataShape != getShapeVec(src0IdxTy) || + dataShape != getShapeVec(src1IdxTy) || + dataShape != getShapeVec(dstIdxTy)) + return op->emitOpError( + "expects data and index operands to have the same shape"); + if (getValidShapeVec(src0Ty) != getValidShapeVec(src0IdxTy) || + getValidShapeVec(src1Ty) != getValidShapeVec(src1IdxTy) || + getValidShapeVec(dstTy) != getValidShapeVec(dstIdxTy)) + return op->emitOpError( + "expects each data operand and its index operand to have the same valid_shape"); + + Type elem = *dataElemOr; + PTOArch arch = getTargetArch(op); + if (arch == PTOArch::A5) { + if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isInteger(8) || + elem.isF16() || elem.isBF16() || elem.isF32())) + return op->emitOpError() << "expects A5 " << opName + << " element type to be i32/i16/i8/f16/bf16/f32"; + } else { + if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isF16() || + elem.isF32())) + return op->emitOpError() << "expects A2/A3 " << opName + << " element type to 
be i32/i16/f16/f32"; + } + return success(); +} + +mlir::LogicalResult mlir::pto::TPartArgMaxOp::verify() { + auto verifyByArch = [&]() -> LogicalResult { + return verifyTPartArgOpCommon( + getOperation(), getSrc0().getType(), getSrc1().getType(), + getSrc0Idx().getType(), getSrc1Idx().getType(), getDst().getType(), + getDstIdx().getType(), "tpartargmax"); + }; + return dispatchVerifierByArch(getOperation(), verifyByArch, verifyByArch); +} + +mlir::LogicalResult mlir::pto::TPartArgMinOp::verify() { + auto verifyByArch = [&]() -> LogicalResult { + return verifyTPartArgOpCommon( + getOperation(), getSrc0().getType(), getSrc1().getType(), + getSrc0Idx().getType(), getSrc1Idx().getType(), getDst().getType(), + getDstIdx().getType(), "tpartargmin"); + }; + return dispatchVerifierByArch(getOperation(), verifyByArch, verifyByArch); +} + +mlir::LogicalResult mlir::pto::TPartMulOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + Type src0Ty = getSrc0().getType(); + Type src1Ty = getSrc1().getType(); + Type dstTy = getDst().getType(); + if (!isPTOShapedLike(src0Ty) || !isPTOShapedLike(src1Ty) || + !isPTOShapedLike(dstTy)) + return emitOpError() << "expects PTO shaped-like src0/src1/dst"; + if (getElemTy(src0Ty) != getElemTy(src1Ty) || + getElemTy(src0Ty) != getElemTy(dstTy)) + return emitOpError() + << "expects src0/src1/dst to have the same element type"; + auto s0 = getShapeVec(src0Ty); + auto s1 = getShapeVec(src1Ty); + auto d = getShapeVec(dstTy); + if (s0.size() != 2 || s1.size() != 2 || d.size() != 2) + return emitOpError() + << "expects src0/src1/dst to be rank-2 (tile-shaped)"; + if (failed(verifyPartialValidPattern(*this, src0Ty, src1Ty, dstTy))) + return failure(); + Type elem = getElemTy(src0Ty); + if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isF16() || + elem.isF32())) + return emitOpError( + "expects A2/A3 tpartmul element type to be i32/i16/f16/f32"); + return mlir::success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + Type src0Ty = 
getSrc0().getType(); + Type src1Ty = getSrc1().getType(); + Type dstTy = getDst().getType(); + if (!isPTOShapedLike(src0Ty) || !isPTOShapedLike(src1Ty) || + !isPTOShapedLike(dstTy)) + return emitOpError() << "expects PTO shaped-like src0/src1/dst"; + if (getElemTy(src0Ty) != getElemTy(src1Ty) || + getElemTy(src0Ty) != getElemTy(dstTy)) + return emitOpError() + << "expects src0/src1/dst to have the same element type"; + Type elem = getElemTy(src0Ty); + if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isInteger(8) || + elem.isF16() || elem.isBF16() || elem.isF32())) + return emitOpError( + "expects A5 tpartmul element type to be i32/i16/i8/f16/bf16/f32"); + auto s0 = getShapeVec(src0Ty); + auto s1 = getShapeVec(src1Ty); + auto d = getShapeVec(dstTy); + if (s0.size() != 2 || s1.size() != 2 || d.size() != 2) + return emitOpError() + << "expects src0/src1/dst to be rank-2 (tile-shaped)"; + return mlir::success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TPReluOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + auto verifyCommon = [&]() -> FailureOr> { + Type t0 = getSrc0().getType(); + Type t1 = getSrc1().getType(); + Type tt = getTmp().getType(); + Type td = getDst().getType(); + if (failed(verifyTileBufCommon(*this, t0, "src0")) || + failed(verifyTileBufCommon(*this, t1, "src1")) || + failed(verifyTileBufCommon(*this, tt, "tmp")) || + failed(verifyTileBufCommon(*this, td, "dst"))) + return failure(); + + Type e0 = getElemTy(t0), e1 = getElemTy(t1), et = getElemTy(tt), ed = getElemTy(td); + if (!e0 || !e1 || !et || !ed) { + emitOpError("failed to get element type for operands"); + return failure(); + } + if (e0 != e1 || e0 != ed) { + emitOpError("expects dst/src0/src1 to have the same element type"); + return failure(); + } + if (!(e0.isF16() || e0.isF32())) { + emitOpError("expects dst/src0/src1 element type to be f16 or f32"); + return failure(); + 
} + if (!isRowMajorTileBuf(t0) || !isRowMajorTileBuf(t1) || !isRowMajorTileBuf(td)) { + emitOpError("expects src0, src1, and dst to use row-major layout"); + return failure(); + } + if (failed(verifyTileBufSameValidShape(*this, t0, td, "src0", "dst")) || + failed(verifyTileBufSameValidShape(*this, t1, td, "src1", "dst"))) + return failure(); + + auto s0 = getShapeVec(t0), s1 = getShapeVec(t1), st = getShapeVec(tt), sd = getShapeVec(td); + if (s0 != s1 || s0 != st || s0 != sd) { + emitOpError("expects src0/src1/tmp/dst to have the same shape"); + return failure(); + } + return std::make_tuple(t0, t1, tt, td); + }; + + auto verifyA2A3 = [&]() -> LogicalResult { + auto tysOr = verifyCommon(); + if (failed(tysOr)) + return failure(); + auto [t0, t1, tt, td] = *tysOr; + Type tmpElem = getElemTy(tt); + auto tmpIntTy = mlir::dyn_cast(tmpElem); + if (!tmpIntTy || tmpIntTy.getWidth() != 8) + return emitOpError("expects A2/A3 tmp element type to be u8"); + if (!isRowMajorTileBuf(tt)) + return emitOpError("expects tmp to use row-major layout"); + if (auto arch = getVerifierArchName(getOperation()); + arch && arch->equals_insensitive("a3")) { + if (getSrc0() == getSrc1() || getSrc0() == getTmp() || getSrc0() == getDst() || + getSrc1() == getTmp() || getSrc1() == getDst() || getTmp() == getDst()) + return emitOpError( + "expects A3 src0, src1, tmp, and dst to use different storage"); + } + return success(); + }; + + auto verifyA5 = [&]() -> LogicalResult { + auto tysOr = verifyCommon(); + if (failed(tysOr)) + return failure(); + return success(); + }; + + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TQuantOp::verify() { + // Structural checks: always run regardless of operand representation + // (applies both before and after PTOViewToMemref lowering). + auto verifyStructural = [&]() -> LogicalResult { + // dst elem type and offset presence must be consistent with quant_type. 
+ Type dstTy = getDst().getType(); + Type dstElemTy = getElemTy(dstTy); + auto dstIntTy = dyn_cast(dstElemTy); + if (getQuantType() == mlir::pto::QuantType::INT8_SYM) { + if (!dstIntTy || dstIntTy.getWidth() != 8) + return emitOpError() + << "expects dst element type i8/ui8 for INT8_SYM quantization"; + if (getOffset()) + return emitOpError() + << "INT8_SYM quantization must not have an offset operand"; + } else { + // INT8_ASYM + if (!dstIntTy || dstIntTy.getWidth() != 8) + return emitOpError() + << "expects dst element type i8/ui8 for INT8_ASYM quantization"; + if (!getOffset()) + return emitOpError() + << "INT8_ASYM quantization requires an offset operand"; + } + return success(); + }; + + if (failed(verifyStructural())) + return failure(); + + // Layout/tile-buffer checks: only meaningful for pre-lowering tile types. + // Skip when operands are already plain MemRefs (post PTOViewToMemref). + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + + auto verifyCommon = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type fpTy = getFp().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyTileBufCommon(*this, srcTy, "src")) || + failed(verifyTileBufCommon(*this, fpTy, "fp")) || + failed(verifyTileBufCommon(*this, dstTy, "dst"))) + return failure(); + // src must be f32 (ISA static_assert) + if (!getElemTy(srcTy).isF32()) + return emitOpError() << "expects src to have element type f32"; + if (getOffset()) { + Type offsetTy = getOffset().getType(); + if (failed(verifyTileBufCommon(*this, offsetTy, "offset"))) + return failure(); + if (!getElemTy(offsetTy).isF32()) + return emitOpError() << "expects offset to have element type f32"; + } + return success(); + }; + + auto verifyA2A3 = [&]() -> LogicalResult { + if (failed(verifyCommon())) + return failure(); + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (!isRowMajorTileBuf(srcTy) || !isRowMajorTileBuf(dstTy)) + return emitOpError() << "expects 
A2/A3 src and dst to use row-major layout"; + return success(); + }; + + auto verifyA5 = [&]() -> LogicalResult { + return verifyCommon(); + }; + + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TDequantOp::verify() { + // Structural checks: src must be i8 or i16, dst/scale/offset must be f32. + auto verifyStructural = [&]() -> LogicalResult { + Type srcElemTy = getElemTy(getSrc().getType()); + auto srcIntTy = dyn_cast(srcElemTy); + if (!srcIntTy || !(srcIntTy.getWidth() == 8 || srcIntTy.getWidth() == 16)) + return emitOpError() + << "expects src element type i8 or i16"; + if (!getElemTy(getDst().getType()).isF32()) + return emitOpError() << "expects dst element type f32"; + if (!getElemTy(getScale().getType()).isF32()) + return emitOpError() << "expects scale element type f32"; + if (!getElemTy(getOffset().getType()).isF32()) + return emitOpError() << "expects offset element type f32"; + return success(); + }; + + if (failed(verifyStructural())) + return failure(); + + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + + auto verifyCommon = [&]() -> LogicalResult { + if (failed(verifyTileBufCommon(*this, getSrc().getType(), "src")) || + failed(verifyTileBufCommon(*this, getScale().getType(), "scale")) || + failed(verifyTileBufCommon(*this, getOffset().getType(), "offset")) || + failed(verifyTileBufCommon(*this, getDst().getType(), "dst"))) + return failure(); + return success(); + }; + + auto verifyA2A3 = [&]() -> LogicalResult { + if (failed(verifyCommon())) + return failure(); + if (!isRowMajorTileBuf(getSrc().getType()) || + !isRowMajorTileBuf(getDst().getType())) + return emitOpError() + << "expects A2/A3 src and dst to use row-major layout"; + return success(); + }; + + auto verifyA5 = [&]() -> LogicalResult { return verifyCommon(); }; + + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TRecipOp::verify() { + if 
(shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + Type ts = getSrc().getType(); + Type td = getDst().getType(); + if (failed(verifyVecTileUnaryOp(*this, ts, td, "src", "dst", + /*allowBf16=*/false, /*allowInt8=*/false))) + return failure(); + if (failed(verifyTileBufSameValidShape(*this, ts, td, "src", "dst"))) + return failure(); + Type elemTy = getElemTy(ts); + if (!(elemTy.isF16() || elemTy.isF32())) + return emitOpError() << "expects element type to be f16 or f32"; + if (auto arch = getVerifierArchName(getOperation()); + arch && arch->equals_insensitive("a3") && getSrc() == getDst()) + return emitOpError("expects A3 trecip src and dst to use different storage"); + return mlir::success(); +} + +mlir::LogicalResult mlir::pto::TReluOp::verify() { + auto verifyByArch = [&](StringRef errorMessage) -> LogicalResult { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyVecTileCommon(*this, srcTy, "src")) || + failed(verifyVecTileCommon(*this, dstTy, "dst"))) + return failure(); + if (failed(verifyTileBufSameElemType(*this, srcTy, dstTy, "src", "dst")) || + failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) + return failure(); + Type elemTy = getElemTy(srcTy); + if (!(elemTy.isInteger(32) || elemTy.isF16() || elemTy.isF32())) + return emitOpError() << errorMessage; + return success(); + }; + auto verifyA2A3 = [&]() -> LogicalResult { + return verifyByArch("expects A2/A3 trelu element type to be i32/f16/f32"); + }; + auto verifyA5 = [&]() -> LogicalResult { + return verifyByArch("expects A5 trelu element type to be i32/f16/f32"); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + + +mlir::LogicalResult mlir::pto::TRemOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + + Type src0Ty = getSrc0().getType(); + Type src1Ty = getSrc1().getType(); + Type tmpTy = getTmp().getType(); + Type dstTy = getDst().getType(); + if 
(failed(verifyTileBufCommon(*this, src0Ty, "src0")) || + failed(verifyTileBufCommon(*this, src1Ty, "src1")) || + failed(verifyTileBufCommon(*this, tmpTy, "tmp")) || + failed(verifyTileBufCommon(*this, dstTy, "dst"))) + return failure(); + if (failed(verifyTileBufSameElemType(*this, src0Ty, src1Ty, "src0", "src1")) || + failed(verifyTileBufSameElemType(*this, src0Ty, dstTy, "src0", "dst")) || + failed(verifyTileBufSameValidShape(*this, src0Ty, src1Ty, "src0", "src1")) || + failed(verifyTileBufSameValidShape(*this, src0Ty, dstTy, "src0", "dst"))) + return failure(); + if (getElemTy(tmpTy) != getElemTy(dstTy)) + return emitOpError("expects tmp and dst to have the same element type"); + if (!isRowMajorTileBuf(src0Ty) || !isRowMajorTileBuf(src1Ty) || + !isRowMajorTileBuf(tmpTy) || !isRowMajorTileBuf(dstTy)) + return emitOpError("expects src0, src1, tmp, and dst to use row-major layout"); + auto dstValid = getValidShapeVec(dstTy); + auto tmpValid = getValidShapeVec(tmpTy); + if (dstValid.size() != 2 || tmpValid.size() != 2) + return emitOpError("expects tmp and dst to be rank-2 tiles"); + if (tmpValid[0] != ShapedType::kDynamic && tmpValid[0] < 1) + return emitOpError("expects tmp to have at least 1 valid row"); + if (dstValid[1] != ShapedType::kDynamic && tmpValid[1] != ShapedType::kDynamic && + tmpValid[1] < dstValid[1]) + return emitOpError("expects tmp valid columns to cover dst valid columns"); + + Type elem = getElemTy(src0Ty); + auto verifyA2A3 = [&]() -> LogicalResult { + if (!(elem.isInteger(32) || elem.isF32())) + return emitOpError("expects A2/A3 trem element type to be i32/f32"); + return success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isF16() || elem.isF32())) + return emitOpError("expects A5 trem element type to be i32/i16/f16/f32"); + return success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TFModOp::verify() { + 
return verifyArithmeticBinaryTileOpWithArchDispatch( + getOperation(), getSrc0().getType(), getSrc1().getType(), getDst().getType(), + /*allowInt8OnA5=*/false, /*allowBf16OnA5=*/false, + "expects A2/A3 tfmod element type to be i32/i16/f16/f32", + "expects A5 tfmod element type to be i32/i16/f16/f32"); +} + +mlir::LogicalResult mlir::pto::TRemSOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + Type ts = getSrc().getType(); + Type tt = getTmp().getType(); + Type td = getDst().getType(); + Type scalarTy = getScalar().getType(); + if (failed(verifyTileBufCommon(*this, ts, "src")) || + failed(verifyTileBufCommon(*this, tt, "tmp")) || + failed(verifyTileBufCommon(*this, td, "dst"))) + return failure(); + if (failed(verifyTileBufSameElemType(*this, ts, td, "src", "dst")) || + failed(verifyTileBufSameValidShape(*this, ts, td, "src", "dst"))) + return failure(); + if (getElemTy(tt) != getElemTy(td)) + return emitOpError("expects tmp and dst to have the same element type"); + if (!isRowMajorTileBuf(ts) || !isRowMajorTileBuf(tt) || !isRowMajorTileBuf(td)) + return emitOpError("expects src, tmp, and dst to use row-major layout"); + Type elem = getElemTy(ts); + if (scalarTy != elem) + return emitOpError("expects scalar type to match the tile element type"); + auto dstValid = getValidShapeVec(td); + auto tmpValid = getValidShapeVec(tt); + if (dstValid.size() != 2 || tmpValid.size() != 2) + return emitOpError("expects tmp and dst to be rank-2 tiles"); + if (tmpValid[0] != ShapedType::kDynamic && tmpValid[0] < 1) + return emitOpError("expects tmp to have at least 1 valid row"); + if (dstValid[1] != ShapedType::kDynamic && tmpValid[1] != ShapedType::kDynamic && + tmpValid[1] < dstValid[1]) + return emitOpError("expects tmp valid columns to cover dst valid columns"); + auto verifyA2A3 = [&]() -> LogicalResult { + if (!(elem.isInteger(32) || elem.isF32())) + return emitOpError("expects A2/A3 trems element type to be i32/f32"); + return 
success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isF16() || elem.isF32())) + return emitOpError("expects A5 trems element type to be i32/i16/f16/f32"); + return success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TFModSOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + Type scalarTy = getScalar().getType(); + if (failed(verifyTileBufCommon(*this, srcTy, "src")) || + failed(verifyTileBufCommon(*this, dstTy, "dst"))) + return failure(); + if (failed(verifyTileBufSameElemType(*this, srcTy, dstTy, "src", "dst")) || + failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) + return failure(); + if (!isRowMajorTileBuf(srcTy) || !isRowMajorTileBuf(dstTy)) + return emitOpError("expects src and dst to use row-major layout"); + + Type elem = getElemTy(srcTy); + if (scalarTy != elem) + return emitOpError("expects scalar type to match the tile element type"); + + auto verifyA2A3 = [&]() -> LogicalResult { + if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isF16() || elem.isF32())) + return emitOpError("expects A2/A3 tfmods element type to be i32/i16/f16/f32"); + return success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isF16() || elem.isF32())) + return emitOpError("expects A5 tfmods element type to be i32/i16/f16/f32"); + return success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + + +static std::optional getStaticNumElements(ArrayRef shape) { + int64_t numel = 1; + for (int64_t d : shape) { + if (d == ShapedType::kDynamic) + return std::nullopt; + if (d < 0) + return std::nullopt; + numel *= d; + } + return numel; +} + +static std::optional getElemBytes(Type elemTy) { + if (!elemTy) + return 
std::nullopt; + if (auto ft = dyn_cast(elemTy)) { + if (ft.isF16() || ft.isBF16()) + return 2; + if (ft.isF32()) + return 4; + if (ft.isF64()) + return 8; + return std::nullopt; + } + if (auto it = dyn_cast(elemTy)) { + int64_t bits = it.getWidth(); + if (bits <= 0) + return std::nullopt; + return std::max(1, bits / 8); + } + return std::nullopt; +} + +[[maybe_unused]] static bool isTileBufOrMemref(Type ty) { + return mlir::isa(ty); +} + +static constexpr llvm::StringLiteral kLoweredSetValidShapeAttrName = + "__pto.lowered_set_validshape"; + +static bool isLocallyBoundTileSource(Value value) { + if (!value || isa(value)) + return false; + + if (isa( + value.getDefiningOp())) + return true; + + if (auto bitcast = value.getDefiningOp()) + return isLocallyBoundTileSource(bitcast.getSrc()); + if (auto reshape = value.getDefiningOp()) + return isLocallyBoundTileSource(reshape.getSrc()); + + return false; +} + +static std::optional getConstIndexLike(Value v) { + if (auto cOp = v.getDefiningOp()) + return cOp.value(); + if (auto cInt = v.getDefiningOp()) + return cInt.value(); + if (auto cOp = v.getDefiningOp()) { + if (auto ia = dyn_cast(cOp.getValue())) + return ia.getInt(); + } + if (auto castOp = v.getDefiningOp()) + return getConstIndexLike(castOp.getIn()); + if (auto extOp = v.getDefiningOp()) + return getConstIndexLike(extOp.getIn()); + if (auto extOp = v.getDefiningOp()) + return getConstIndexLike(extOp.getIn()); + if (auto truncOp = v.getDefiningOp()) + return getConstIndexLike(truncOp.getIn()); + return std::nullopt; +} + +mlir::LogicalResult mlir::pto::SetValidShapeOp::verify() { + SmallVector shape; + if (auto srcTy = llvm::dyn_cast(getSource().getType())) { + if (srcTy.getRank() != 2) + return emitOpError("expects rank-2 tile_buf source"); + + ArrayRef validShape = srcTy.getValidShape(); + if (validShape.size() != 2) + return emitOpError("expects source validShape to be rank-2"); + if (!srcTy.hasDynamicValid()) + return emitOpError("expects source tile_buf to 
have dynamic validShape (?, ?)"); + + shape.assign(srcTy.getShape().begin(), srcTy.getShape().end()); + + if (!isLocallyBoundTileSource(getSource())) + return emitOpError( + "requires a locally bound tile source; function arguments/results " + "are unsupported"); + } else if (auto srcTy = llvm::dyn_cast(getSource().getType())) { + if (!(*this)->hasAttr(kLoweredSetValidShapeAttrName)) + return emitOpError( + "expects tile_buf source; memref source is only valid for the internal lowered form"); + if (srcTy.getRank() != 2) + return emitOpError("expects rank-2 memref source after tile lowering"); + shape.assign(srcTy.getShape().begin(), srcTy.getShape().end()); + } else { + return emitOpError("expects tile_buf source (or lowered memref source)"); + } + + auto checkDim = [&](Value operand, unsigned dimIdx, + StringRef dimName) -> LogicalResult { + int64_t maxStatic = shape[dimIdx]; + + auto constVal = getConstIndexLike(operand); + if (!constVal) + return success(); + + if (*constVal < 0) + return emitOpError() << "expects " << dimName << " operand to be non-negative"; + if (maxStatic != ShapedType::kDynamic && *constVal > maxStatic) + return emitOpError() << "expects " << dimName << " operand <= shape dim (" + << maxStatic << ")"; + return success(); + }; + + if (failed(checkDim(getValidRow(), /*dimIdx=*/0, "row"))) + return failure(); + if (failed(checkDim(getValidCol(), /*dimIdx=*/1, "col"))) + return failure(); + + return success(); +} + +mlir::LogicalResult mlir::pto::GetValidShapeOp::verify() { + if (auto srcTy = llvm::dyn_cast(getSource().getType())) { + if (srcTy.getRank() != 2) + return emitOpError("expects rank-2 tile_buf source"); + if (srcTy.getValidShape().size() != 2) + return emitOpError("expects source validShape to be rank-2"); + return success(); + } + if (auto srcTy = llvm::dyn_cast(getSource().getType())) { + if (srcTy.getRank() != 2) + return emitOpError("expects rank-2 memref source after tile lowering"); + return success(); + } + return 
emitOpError("expects tile_buf source (or lowered memref source)"); +} + + +mlir::LogicalResult mlir::pto::TReshapeOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + Type ts = getSrc().getType(); + Type tr = getResult().getType(); + auto srcTb = dyn_cast(ts); + auto dstTb = dyn_cast(tr); + if (!srcTb || !dstTb) + return emitOpError("expects src/result to be !pto.tile_buf types"); + + if (failed(verifyTileBufCommon(*this, ts, "src")) || + failed(verifyTileBufCommon(*this, tr, "dst"))) + return failure(); + + if (srcTb.getMemorySpace() != dstTb.getMemorySpace()) + return emitOpError("expects src and dst to use the same loc"); + + Type srcElem = srcTb.getElementType(); + Type dstElem = dstTb.getElementType(); + auto srcElemBytes = getElemBytes(srcElem); + auto dstElemBytes = getElemBytes(dstElem); + if (!srcElem || !dstElem || !srcElemBytes.has_value() || !dstElemBytes.has_value()) + return emitOpError("failed to get element byte width for src/dst"); + + auto srcNumel = getStaticNumElements(getShapeVec(ts)); + auto dstNumel = getStaticNumElements(getShapeVec(tr)); + if (!srcNumel.has_value() || !dstNumel.has_value()) + return emitOpError("expects static shapes for treshape"); + + if (srcElemBytes.value() * srcNumel.value() != + dstElemBytes.value() * dstNumel.value()) + return emitOpError("expects src and dst to have the same total byte size"); + + bool srcBoxed = + srcTb.getSLayoutValueI32() != static_cast(pto::SLayout::NoneBox); + bool dstBoxed = + dstTb.getSLayoutValueI32() != static_cast(pto::SLayout::NoneBox); + if (srcBoxed != dstBoxed) + return emitOpError("cannot reshape between boxed and non-boxed tile layouts"); + + return success(); +} + +mlir::LogicalResult mlir::pto::BitcastOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + auto srcTy = llvm::dyn_cast(getSrc().getType()); + auto dstTy = llvm::dyn_cast(getResult().getType()); + if (!srcTy || !dstTy) + return 
emitOpError("expects tile_buf src and tile_buf result"); + + if (srcTy.getMemorySpace() != dstTy.getMemorySpace()) + return emitOpError("expects src/result to have the same memorySpace"); + + if (srcTy.getElementType() == dstTy.getElementType()) + return emitOpError( + "expects src/result to have different element types; use " + "pto.treshape for shape/config changes"); + + if (srcTy.getShape() != dstTy.getShape()) + return emitOpError("expects src/result to have the same shape; use pto.treshape for shape changes"); + + if (srcTy.getValidShape() != dstTy.getValidShape()) + return emitOpError("expects src/result to have the same validShape"); + + auto srcCfg = srcTy.getConfigAttr(); + auto dstCfg = dstTy.getConfigAttr(); + if (srcCfg != dstCfg) + return emitOpError("expects src/result to have the same tile config"); + + auto numel = getStaticNumElements(srcTy.getShape()); + if (!numel.has_value()) + return emitOpError("expects static shapes for bitcast"); + + auto srcBytes = getElemBytes(srcTy.getElementType()); + auto dstBytes = getElemBytes(dstTy.getElementType()); + if (!srcBytes.has_value() || !dstBytes.has_value()) + return emitOpError("unsupported element type for bitcast"); + + int64_t srcTotalBytes = numel.value() * srcBytes.value(); + int64_t dstTotalBytes = numel.value() * dstBytes.value(); + if (dstTotalBytes > srcTotalBytes) + return emitOpError("bitcast result requires more bytes than source storage"); + + return success(); +} + + +mlir::LogicalResult mlir::pto::TRowExpandOp::verify() { + auto verifyCommon = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyTileBufCommon(*this, srcTy, "src")) || + failed(verifyNDStyleVecTile(*this, dstTy, "dst"))) + return failure(); + auto srcSpace = getPTOMemorySpaceEnum(srcTy); + if (!srcSpace || *srcSpace != pto::AddressSpace::VEC) + return emitOpError("expects src to be in the vec address space"); + if (auto srcTb = dyn_cast(srcTy)) { + if 
(srcTb.getSLayoutValueI32() != static_cast(pto::SLayout::NoneBox)) + return emitOpError("expects src to use the none_box slayout"); + } + if (getElemTy(srcTy) != getElemTy(dstTy)) + return emitOpError("expects src and dst to have the same element type"); + if (!isSupportedVecElemType(getElemTy(srcTy), /*allowBf16=*/true, + /*allowInt8=*/true)) + return emitOpError("expects trowexpand element type to be supported"); + auto srcValid = getValidShapeVec(getSrc()); + auto dstValid = getValidShapeVec(getDst()); + if (srcValid.size() != 2 || dstValid.size() != 2) + return emitOpError("expects src and dst to have rank-2 valid_shape"); + if (srcValid[0] != ShapedType::kDynamic && dstValid[0] != ShapedType::kDynamic && + srcValid[0] != dstValid[0]) + return emitOpError("expects src and dst to have the same valid_shape[0]"); + if (srcValid[0] != ShapedType::kDynamic && srcValid[0] == 0) + return emitOpError("expects src valid_shape[0] to be non-zero"); + if (srcValid[1] != ShapedType::kDynamic && srcValid[1] == 0) + return emitOpError("expects src valid_shape[1] to be non-zero"); + if (dstValid[0] != ShapedType::kDynamic && dstValid[0] == 0) + return emitOpError("expects dst valid_shape[0] to be non-zero"); + if (dstValid[1] != ShapedType::kDynamic && dstValid[1] == 0) + return emitOpError("expects dst valid_shape[1] to be non-zero"); + return success(); + }; + auto verifyA2A3 = [&]() -> LogicalResult { + return verifyCommon(); + }; + auto verifyA5 = [&]() -> LogicalResult { + return verifyCommon(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + + +ParseResult mlir::pto::TSort32Op::parse(OpAsmParser &parser, OperationState &result) { + OpAsmParser::UnresolvedOperand src, idx, tmp, dst; + Type srcTy, dstTy, idxTy, tmpTy; + bool hasTmp = false; + + if (parser.parseKeyword("ins") || parser.parseLParen() || parser.parseOperand(src)) + return failure(); + if (succeeded(parser.parseOptionalComma())) { + if (parser.parseOperand(idx)) + return 
failure(); + if (succeeded(parser.parseOptionalComma())) { + if (parser.parseOperand(tmp)) + return failure(); + hasTmp = true; + } + } else { + return failure(); + } + if (parser.parseColonType(srcTy) || parser.parseComma() || parser.parseType(idxTy)) + return failure(); + if (hasTmp) { + if (parser.parseComma() || parser.parseType(tmpTy)) + return failure(); + } + if (parser.parseRParen()) + return failure(); + + if (parser.parseKeyword("outs") || parser.parseLParen() || + parser.parseOperand(dst) || parser.parseColonType(dstTy) || + parser.parseRParen()) + return failure(); + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + if (parser.resolveOperand(src, srcTy, result.operands) || + parser.resolveOperand(idx, idxTy, result.operands)) + return failure(); + if (hasTmp) { + if (parser.resolveOperand(tmp, tmpTy, result.operands)) + return failure(); + } + if (parser.resolveOperand(dst, dstTy, result.operands)) + return failure(); + + result.addAttribute( + "operandSegmentSizes", + parser.getBuilder().getDenseI32ArrayAttr({1, 1, hasTmp ? 
1 : 0, 1})); + return success(); +} + +void mlir::pto::TSort32Op::print(OpAsmPrinter &p) { + p << " ins(" << getSrc() << ", " << getIdx(); + if (getTmp()) { + p << ", " << getTmp(); + p << " : " << getSrc().getType() << ", " << getIdx().getType() + << ", " << getTmp().getType() << ")"; + } else { + p << " : " << getSrc().getType() << ", " << getIdx().getType() << ")"; + } + p << " outs(" << getDst() << " : " << getDst().getType() << ")"; + p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"operandSegmentSizes"}); +} + +ParseResult mlir::pto::TRsqrtOp::parse(OpAsmParser &parser, OperationState &result) { + OpAsmParser::UnresolvedOperand src, tmp, dst; + Type srcTy, tmpTy, dstTy; + bool hasTmp = false; + + if (parser.parseKeyword("ins") || parser.parseLParen() || parser.parseOperand(src)) + return failure(); + if (succeeded(parser.parseOptionalComma())) { + if (parser.parseOperand(tmp)) + return failure(); + hasTmp = true; + } + if (parser.parseColonType(srcTy)) + return failure(); + if (hasTmp) { + if (parser.parseComma() || parser.parseType(tmpTy)) + return failure(); + } + if (parser.parseRParen()) + return failure(); + + if (parser.parseKeyword("outs") || parser.parseLParen() || + parser.parseOperand(dst) || parser.parseColonType(dstTy) || + parser.parseRParen()) + return failure(); + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + if (parser.resolveOperand(src, srcTy, result.operands) || + parser.resolveOperand(dst, dstTy, result.operands)) + return failure(); + if (hasTmp && parser.resolveOperand(tmp, tmpTy, result.operands)) + return failure(); + + return success(); +} + +void mlir::pto::TRsqrtOp::print(OpAsmPrinter &p) { + p << " ins(" << getSrc(); + if (getTmp()) + p << ", " << getTmp(); + p << " : " << getSrc().getType(); + if (getTmp()) + p << ", " << getTmp().getType(); + p << ")"; + p << " outs(" << getDst() << " : " << getDst().getType() << ")"; + p.printOptionalAttrDict((*this)->getAttrs()); +} + +static 
ParseResult parseTRowExpandBinaryLikeOp(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::UnresolvedOperand src0, src1, tmp, dst; + Type src0Ty, src1Ty, tmpTy, dstTy; + bool hasTmp = false; + + if (parser.parseKeyword("ins") || parser.parseLParen() || + parser.parseOperand(src0) || parser.parseComma() || parser.parseOperand(src1)) + return failure(); + if (succeeded(parser.parseOptionalComma())) { + if (parser.parseOperand(tmp)) + return failure(); + hasTmp = true; + } + if (parser.parseColon()) + return failure(); + if (parser.parseType(src0Ty) || parser.parseComma() || parser.parseType(src1Ty)) + return failure(); + if (hasTmp) { + if (parser.parseComma() || parser.parseType(tmpTy)) + return failure(); + } + if (parser.parseRParen()) + return failure(); + if (parser.parseKeyword("outs") || parser.parseLParen() || + parser.parseOperand(dst) || parser.parseColonType(dstTy) || + parser.parseRParen()) + return failure(); + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + if (parser.resolveOperand(src0, src0Ty, result.operands) || + parser.resolveOperand(src1, src1Ty, result.operands)) + return failure(); + if (hasTmp) { + if (parser.resolveOperand(tmp, tmpTy, result.operands)) + return failure(); + } + if (parser.resolveOperand(dst, dstTy, result.operands)) + return failure(); + + result.addAttribute( + "operandSegmentSizes", + parser.getBuilder().getDenseI32ArrayAttr({1, 1, hasTmp ? 
1 : 0, 1})); + return success(); +} + +static void printTRowExpandBinaryLikeOp(OpAsmPrinter &p, Operation *op, Value src0, + Value src1, Value tmp, Value dst) { + p << " ins(" << src0 << ", " << src1; + if (tmp) { + p << ", " << tmp; + p << " : " << src0.getType() << ", " << src1.getType() << ", " + << tmp.getType() << ")"; + } else { + p << " : " << src0.getType() << ", " << src1.getType() << ")"; + } + p << " outs(" << dst << " : " << dst.getType() << ")"; + p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"operandSegmentSizes"}); +} + +ParseResult mlir::pto::TRowExpandDivOp::parse(OpAsmParser &parser, OperationState &result) { + return parseTRowExpandBinaryLikeOp(parser, result); +} + +void mlir::pto::TRowExpandDivOp::print(OpAsmPrinter &p) { + printTRowExpandBinaryLikeOp(p, getOperation(), getSrc0(), getSrc1(), getTmp(), + getDst()); +} + +ParseResult mlir::pto::TRowExpandMulOp::parse(OpAsmParser &parser, OperationState &result) { + return parseTRowExpandBinaryLikeOp(parser, result); +} + +void mlir::pto::TRowExpandMulOp::print(OpAsmPrinter &p) { + printTRowExpandBinaryLikeOp(p, getOperation(), getSrc0(), getSrc1(), getTmp(), + getDst()); +} + +ParseResult mlir::pto::TRowExpandSubOp::parse(OpAsmParser &parser, OperationState &result) { + return parseTRowExpandBinaryLikeOp(parser, result); +} + +void mlir::pto::TRowExpandSubOp::print(OpAsmPrinter &p) { + printTRowExpandBinaryLikeOp(p, getOperation(), getSrc0(), getSrc1(), getTmp(), + getDst()); +} + +ParseResult mlir::pto::TRowExpandExpdifOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseTRowExpandBinaryLikeOp(parser, result); +} + +void mlir::pto::TRowExpandExpdifOp::print(OpAsmPrinter &p) { + printTRowExpandBinaryLikeOp(p, getOperation(), getSrc0(), getSrc1(), getTmp(), + getDst()); +} + +ParseResult mlir::pto::TRowExpandMaxOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseTRowExpandBinaryLikeOp(parser, result); +} + +void 
mlir::pto::TRowExpandMaxOp::print(OpAsmPrinter &p) { + printTRowExpandBinaryLikeOp(p, getOperation(), getSrc0(), getSrc1(), getTmp(), + getDst()); +} + +ParseResult mlir::pto::TRowExpandMinOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseTRowExpandBinaryLikeOp(parser, result); +} + +void mlir::pto::TRowExpandMinOp::print(OpAsmPrinter &p) { + printTRowExpandBinaryLikeOp(p, getOperation(), getSrc0(), getSrc1(), getTmp(), + getDst()); +} + +static FailureOr verifyTRowExpandBinaryCore(Operation *op, Type src0Ty, + Type src1Ty, Type dstTy, + Type tmpTy, bool hasTmp) { + if (failed(verifyTileBufCommon(op, src0Ty, "src0")) || + failed(verifyTileBufCommon(op, src1Ty, "src1")) || + failed(verifyTileBufCommon(op, dstTy, "dst"))) + return failure(); + if (hasTmp && failed(verifyTileBufCommon(op, tmpTy, "tmp"))) + return failure(); + if (failed(verifyTileBufSameElemType(op, src0Ty, dstTy, "src0", "dst"))) + return failure(); + if (getElemTy(src0Ty) != getElemTy(src1Ty)) { + op->emitOpError("expects src0 and src1 to have the same element type"); + return failure(); + } + if (!isRowMajorTileBuf(dstTy)) { + op->emitOpError("expects dst to use row-major layout"); + return failure(); + } + return getElemTy(src0Ty); +} + +mlir::LogicalResult mlir::pto::TRowExpandDivOp::verify() { + auto verifyByArch = [&](PTOArch targetArch) -> LogicalResult { + Type src0Ty = getSrc0().getType(); + Type src1Ty = getSrc1().getType(); + Type dstTy = getDst().getType(); + FailureOr elemOr = verifyTRowExpandBinaryCore( + *this, src0Ty, src1Ty, dstTy, getTmp() ? 
getTmp().getType() : Type{}, + static_cast(getTmp())); + if (failed(elemOr)) + return failure(); + Type elem = *elemOr; + bool supported = + elem.isF16() || elem.isF32() || + (targetArch == PTOArch::A5 && + (elem.isInteger(8) || elem.isInteger(16) || elem.isInteger(32))); + if (!supported) { + if (targetArch == PTOArch::A5) + return emitOpError( + "expects A5 trowexpanddiv element type to be i8/i16/i32/f16/f32"); + return emitOpError("expects element type to be f16 or f32"); + } + return mlir::success(); + }; + auto verifyA2A3 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A3); }; + auto verifyA5 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A5); }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + + +mlir::LogicalResult mlir::pto::TRowExpandMulOp::verify() { + auto verifyByArch = [&](PTOArch targetArch) -> LogicalResult { + Type src0Ty = getSrc0().getType(); + Type src1Ty = getSrc1().getType(); + Type dstTy = getDst().getType(); + FailureOr elemOr = verifyTRowExpandBinaryCore( + *this, src0Ty, src1Ty, dstTy, getTmp() ? 
getTmp().getType() : Type{}, + static_cast(getTmp())); + if (failed(elemOr)) + return failure(); + Type elem = *elemOr; + bool supported = elem.isF16() || elem.isF32() || elem.isInteger(16) || + elem.isInteger(32) || + (targetArch == PTOArch::A5 && elem.isInteger(8)); + if (!supported) { + if (targetArch == PTOArch::A5) + return emitOpError( + "expects A5 trowexpandmul element type to be i8/i16/i32/f16/f32"); + return emitOpError( + "expects A2/A3 trowexpandmul element type to be i16/i32/f16/f32"); + } + return mlir::success(); + }; + auto verifyA2A3 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A3); }; + auto verifyA5 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A5); }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + + +mlir::LogicalResult mlir::pto::TRowExpandSubOp::verify() { + auto verifyByArch = [&](PTOArch targetArch) -> LogicalResult { + Type src0Ty = getSrc0().getType(); + Type src1Ty = getSrc1().getType(); + Type dstTy = getDst().getType(); + FailureOr elemOr = verifyTRowExpandBinaryCore( + *this, src0Ty, src1Ty, dstTy, getTmp() ? 
getTmp().getType() : Type{}, + static_cast(getTmp())); + if (failed(elemOr)) + return failure(); + Type elem = *elemOr; + bool supported = elem.isF16() || elem.isF32() || elem.isInteger(16) || + elem.isInteger(32) || + (targetArch == PTOArch::A5 && elem.isInteger(8)); + if (!supported) { + if (targetArch == PTOArch::A5) + return emitOpError( + "expects A5 trowexpandsub element type to be i8/i16/i32/f16/f32"); + return emitOpError( + "expects A2/A3 trowexpandsub element type to be i16/i32/f16/f32"); + } + return mlir::success(); + }; + auto verifyA2A3 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A3); }; + auto verifyA5 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A5); }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TRowExpandAddOp::verify() { + auto verifyByArch = [&](PTOArch targetArch) -> LogicalResult { + Type src0Ty = getSrc0().getType(); + Type src1Ty = getSrc1().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyTileBufCommon(*this, src0Ty, "src0")) || + failed(verifyTileBufCommon(*this, src1Ty, "src1")) || + failed(verifyTileBufCommon(*this, dstTy, "dst"))) + return failure(); + if (failed(verifyTileBufSameElemType(*this, src0Ty, dstTy, "src0", "dst"))) + return failure(); + if (failed(verifyTileBufSameValidShape(*this, src0Ty, dstTy, "src0", "dst"))) + return failure(); + if (getElemTy(src0Ty) != getElemTy(src1Ty)) + return emitOpError("expects src0 and src1 to have the same element type"); + if (!isRowMajorTileBuf(src0Ty)) + return emitOpError("expects src0 to use row-major layout"); + if (!isRowMajorTileBuf(dstTy)) + return emitOpError("expects dst to use row-major layout"); + Type elem = getElemTy(src0Ty); + bool supported = elem.isF16() || elem.isF32() || elem.isInteger(16) || + elem.isInteger(32) || + (targetArch == PTOArch::A5 && elem.isInteger(8)); + if (!supported) { + if (targetArch == PTOArch::A5) + return emitOpError( + "expects A5 trowexpandadd 
element type to be i8/i16/i32/f16/f32"); + return emitOpError( + "expects A2/A3 trowexpandadd element type to be i16/i32/f16/f32"); + } + auto src1Valid = getValidShapeVec(src1Ty); + auto dstValid = getValidShapeVec(dstTy); + if (src1Valid.size() != 2 || dstValid.size() != 2) + return emitOpError("expects src1 and dst to have rank-2 valid_shape"); + if (src1Valid[0] != ShapedType::kDynamic && dstValid[0] != ShapedType::kDynamic && + src1Valid[0] != dstValid[0]) + return emitOpError("expects src1 valid_shape[0] to equal dst valid_shape[0]"); + bool src1IsRowMajor = isRowMajorTileBuf(src1Ty); + int64_t expectedCol = elem.isInteger(8) + ? 32 + : ((elem.isF16() || elem.isInteger(16)) ? 16 : 8); + int64_t src1Col = src1Valid[1]; + if (src1IsRowMajor) { + if (src1Col != ShapedType::kDynamic && src1Col != expectedCol) + return emitOpError("expects row-major src1 valid_shape[1] to be 32/sizeof(dtype)"); + } else { + if (src1Col != ShapedType::kDynamic && src1Col != 1) + return emitOpError("expects non-row-major src1 valid_shape[1] to be 1"); + } + return mlir::success(); + }; + auto verifyA2A3 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A3); }; + auto verifyA5 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A5); }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +static LogicalResult verifyTRowExpandReduceLikeOp(Operation *op, Type src0Ty, + Type src1Ty, Type dstTy, + Type tmpTy, bool hasTmp, + PTOArch targetArch, + StringRef opName, + bool allowIntegerTypes) { + if (failed(verifyTileBufCommon(op, src0Ty, "src0")) || + failed(verifyTileBufCommon(op, src1Ty, "src1")) || + failed(verifyTileBufCommon(op, dstTy, "dst"))) + return failure(); + if (hasTmp) { + if (failed(verifyTileBufCommon(op, tmpTy, "tmp"))) + return failure(); + if (getElemTy(tmpTy) != getElemTy(dstTy)) + return op->emitOpError() << "expects tmp and dst to have the same element type"; + } + + Type elem = getElemTy(dstTy); + if (!elem || getElemTy(src0Ty) != elem 
|| getElemTy(src1Ty) != elem) + return op->emitOpError("expects src0, src1, and dst to have the same element type"); + bool supported = elem.isF16() || elem.isF32() || + (allowIntegerTypes && + (elem.isInteger(16) || elem.isInteger(32) || + (targetArch == PTOArch::A5 && elem.isInteger(8)))); + if (!supported) { + if (!allowIntegerTypes) + return op->emitOpError() << "expects " << opName + << " element type to be f16 or f32"; + if (targetArch == PTOArch::A5) + return op->emitOpError() << "expects A5 " << opName + << " element type to be i8/i16/i32/f16/f32"; + return op->emitOpError() << "expects A2/A3 " << opName + << " element type to be i16/i32/f16/f32"; + } + + if (!isRowMajorTileBuf(dstTy)) + return op->emitOpError("expects dst to use row-major layout"); + + auto src0Valid = getValidShapeVec(src0Ty); + auto src1Valid = getValidShapeVec(src1Ty); + auto dstValid = getValidShapeVec(dstTy); + if (src0Valid.size() != 2 || src1Valid.size() != 2 || dstValid.size() != 2) + return op->emitOpError("expects src0, src1, and dst to have rank-2 valid_shape"); + + if (dstValid[0] != ShapedType::kDynamic && dstValid[0] == 0) + return op->emitOpError("expects dst valid_shape[0] to be non-zero"); + if (dstValid[1] != ShapedType::kDynamic && dstValid[1] == 0) + return op->emitOpError("expects dst valid_shape[1] to be non-zero"); + + auto validShapeMatches = [](ArrayRef lhs, + ArrayRef rhs) -> bool { + if (lhs.size() != rhs.size()) + return false; + for (auto [l, r] : llvm::zip(lhs, rhs)) { + if (l != ShapedType::kDynamic && r != ShapedType::kDynamic && l != r) + return false; + } + return true; + }; + + const bool src0MatchesDst = validShapeMatches(src0Valid, dstValid); + const bool src1MatchesDst = validShapeMatches(src1Valid, dstValid); + + auto checkBroadcastOperand = [&](Type operandTy, ArrayRef operandValid, + StringRef operandName, + bool requireNonRowMajor) -> LogicalResult { + if (operandValid[0] != ShapedType::kDynamic && dstValid[0] != ShapedType::kDynamic && + 
operandValid[0] != dstValid[0]) { + return op->emitOpError() << "expects " << operandName + << " valid_shape[0] to equal dst valid_shape[0]"; + } + int64_t expectedCol = elem.isInteger(8) ? 32 : ((elem.isF16() || elem.isInteger(16)) ? 16 : 8); + int64_t operandCol = operandValid[1]; + bool operandIsRowMajor = isRowMajorTileBuf(operandTy); + if (requireNonRowMajor && operandIsRowMajor) { + return op->emitOpError() << "expects " << operandName + << " to use a non-row-major layout when tmp is present"; + } + if (operandIsRowMajor) { + if (operandCol != ShapedType::kDynamic && operandCol != expectedCol) { + return op->emitOpError() + << "expects row-major " << operandName + << " valid_shape[1] to be 32/sizeof(dtype)"; + } + return success(); + } + if (operandCol != ShapedType::kDynamic && operandCol != 1) { + return op->emitOpError() << "expects non-row-major " << operandName + << " valid_shape[1] to be 1"; + } + return success(); + }; + + auto checkFullAndBroadcast = [&](Type fullTy, ArrayRef fullValid, + StringRef fullName, Type broadcastTy, + ArrayRef broadcastValid, + StringRef broadcastName) -> LogicalResult { + if (!isRowMajorTileBuf(fullTy)) + return op->emitOpError() << "expects " << fullName + << " to use row-major layout when it matches dst"; + if (fullValid[0] != ShapedType::kDynamic && dstValid[0] != ShapedType::kDynamic && + fullValid[0] != dstValid[0]) + return op->emitOpError() << "expects " << fullName + << " valid_shape[0] to equal dst valid_shape[0]"; + if (fullValid[1] != ShapedType::kDynamic && dstValid[1] != ShapedType::kDynamic && + fullValid[1] != dstValid[1]) + return op->emitOpError() << "expects " << fullName + << " valid_shape[1] to equal dst valid_shape[1]"; + return checkBroadcastOperand(broadcastTy, broadcastValid, broadcastName, + /*requireNonRowMajor=*/hasTmp && + targetArch == PTOArch::A3); + }; + + if (hasTmp && targetArch == PTOArch::A5) + return op->emitOpError("expects A5 form to omit tmp"); + + if (src0MatchesDst) { + if 
(succeeded(checkFullAndBroadcast(src0Ty, src0Valid, "src0", src1Ty, + src1Valid, "src1"))) + return success(); + } + if (src1MatchesDst) { + if (succeeded(checkFullAndBroadcast(src1Ty, src1Valid, "src1", src0Ty, + src0Valid, "src0"))) + return success(); + } + + return op->emitOpError() << "expects one of src0/src1 to match dst valid_shape" + << " and the other to be a per-row scalar vector"; +} + +mlir::LogicalResult mlir::pto::TRowExpandExpdifOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + return verifyTRowExpandReduceLikeOp(getOperation(), getSrc0().getType(), + getSrc1().getType(), getDst().getType(), + getTmp() ? getTmp().getType() : Type{}, + (bool)getTmp(), PTOArch::A3, + "trowexpandexpdif", + /*allowIntegerTypes=*/false); + }; + auto verifyA5 = [&]() -> LogicalResult { + return verifyTRowExpandReduceLikeOp(getOperation(), getSrc0().getType(), + getSrc1().getType(), getDst().getType(), + getTmp() ? getTmp().getType() : Type{}, + (bool)getTmp(), PTOArch::A5, + "trowexpandexpdif", + /*allowIntegerTypes=*/false); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +mlir::LogicalResult mlir::pto::TRowExpandMaxOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + return verifyTRowExpandReduceLikeOp(getOperation(), getSrc0().getType(), + getSrc1().getType(), getDst().getType(), + getTmp() ? getTmp().getType() : Type{}, + (bool)getTmp(), PTOArch::A3, + "trowexpandmax", + /*allowIntegerTypes=*/true); + }; + auto verifyA5 = [&]() -> LogicalResult { + return verifyTRowExpandReduceLikeOp(getOperation(), getSrc0().getType(), + getSrc1().getType(), getDst().getType(), + getTmp() ? 
getTmp().getType() : Type{}, + (bool)getTmp(), PTOArch::A5, + "trowexpandmax", + /*allowIntegerTypes=*/true); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} +/* Verifies TRowExpandMinOp: both arch callbacks forward the operand types (tmp is optional) to verifyTRowExpandReduceLikeOp with integer types allowed; only the PTOArch argument differs (A3 vs A5). */ +mlir::LogicalResult mlir::pto::TRowExpandMinOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + return verifyTRowExpandReduceLikeOp(getOperation(), getSrc0().getType(), + getSrc1().getType(), getDst().getType(), + getTmp() ? getTmp().getType() : Type{}, + (bool)getTmp(), PTOArch::A3, + "trowexpandmin", + /*allowIntegerTypes=*/true); + }; + auto verifyA5 = [&]() -> LogicalResult { + return verifyTRowExpandReduceLikeOp(getOperation(), getSrc0().getType(), + getSrc1().getType(), getDst().getType(), + getTmp() ? getTmp().getType() : Type{}, + (bool)getTmp(), PTOArch::A5, + "trowexpandmin", + /*allowIntegerTypes=*/true); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +/* Verifies TRowMaxOp through verifyTRowReductionNoTmpCommon (src and dst only); the same lambda is passed for both arch callbacks. */ +mlir::LogicalResult mlir::pto::TRowMaxOp::verify() { + auto verifyByArch = [&]() -> LogicalResult { + return verifyTRowReductionNoTmpCommon(*this, getSrc().getType(), + getDst().getType(), + "expects element type to be i16/i32/f16/f32"); + }; + return dispatchVerifierByArch(getOperation(), verifyByArch, verifyByArch); +} +/* Verifies TRowArgMaxOp through verifyTRowArgReductionCommon (src, tmp, dst); the same lambda is passed for both arch callbacks. */ +mlir::LogicalResult mlir::pto::TRowArgMaxOp::verify() { + auto verifyByArch = [&]() -> LogicalResult { + return verifyTRowArgReductionCommon(*this, getSrc().getType(), + getTmp().getType(), getDst().getType()); + }; + + return dispatchVerifierByArch(getOperation(), verifyByArch, verifyByArch); +} + +/* Verifies TRowMinOp through verifyTRowReductionWithTmpCommon (src, tmp, dst); the same lambda is passed for both arch callbacks. */ +mlir::LogicalResult mlir::pto::TRowMinOp::verify() { + auto verifyByArch = [&]() -> LogicalResult { + return verifyTRowReductionWithTmpCommon( + *this, getSrc().getType(), getTmp().getType(), getDst().getType(), + "expects element type to be i16/i32/f16/f32"); + }; + return dispatchVerifierByArch(getOperation(), verifyByArch, verifyByArch); +} +/* Verifies TRowArgMinOp through verifyTRowArgReductionCommon (src, tmp, dst); the same lambda is passed for both arch callbacks. */ +mlir::LogicalResult mlir::pto::TRowArgMinOp::verify() { + auto verifyByArch = [&]() -> LogicalResult { + return 
verifyTRowArgReductionCommon(*this, getSrc().getType(), + getTmp().getType(), getDst().getType()); + }; + + return dispatchVerifierByArch(getOperation(), verifyByArch, verifyByArch); +} + +/* Verifies TRowSumOp through verifyTRowReductionNoTmpCommon (src and dst only); the same lambda is passed for both arch callbacks. */ +mlir::LogicalResult mlir::pto::TRowSumOp::verify() { + auto verifyByArch = [&]() -> LogicalResult { + return verifyTRowReductionNoTmpCommon(*this, getSrc().getType(), + getDst().getType(), + "expects element type to be i16/i32/f16/f32"); + }; + return dispatchVerifierByArch(getOperation(), verifyByArch, verifyByArch); +} +/* Verifies TRowProdOp: both arch callbacks run verifyTRowReductionWithTmpCommon on the same operands and differ only in the diagnostic text. */ +mlir::LogicalResult mlir::pto::TRowProdOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + return verifyTRowReductionWithTmpCommon( + *this, getSrc().getType(), getTmp().getType(), getDst().getType(), + "expects A2/A3 trowprod element type to be i16/i32/f16/f32"); + }; + auto verifyA5 = [&]() -> LogicalResult { + return verifyTRowReductionWithTmpCommon( + *this, getSrc().getType(), getTmp().getType(), getDst().getType(), + "expects A5 trowprod element type to be i16/i32/f16/f32"); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +/* Verifies TRsqrtOp (skipped when shouldBypassDecodedMemrefVerifier returns true): src/dst must pass verifyVecTileUnaryOp with bf16 and int8 disallowed, share a valid shape, and use f16 or f32 elements; an optional tmp must be a vec tile with a static size of at least 32 bytes. */ +mlir::LogicalResult mlir::pto::TRsqrtOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + Type ts = getSrc().getType(); + Type td = getDst().getType(); + if (failed(verifyVecTileUnaryOp(*this, ts, td, "src", "dst", + /*allowBf16=*/false, /*allowInt8=*/false))) + return failure(); + if (failed(verifyTileBufSameValidShape(*this, ts, td, "src", "dst"))) + return failure(); + auto ft = mlir::dyn_cast(getElemTy(ts)); + if (!ft || (!ft.isF16() && !ft.isF32())) + return emitOpError("expects element type to be f16 or f32"); + if (auto tmp = getTmp()) { + Type tt = tmp.getType(); + if (failed(verifyVecTileCommon(*this, tt, "tmp"))) + return failure(); + + auto tmpElemTy = getElemTy(tt); + auto tmpElemBytes = getElemBytes(tmpElemTy); + auto tmpNumel = getStaticNumElements(getShapeVec(tt)); + if (!tmpElemBytes.has_value() || !tmpNumel.has_value()) + return 
emitOpError("expects tmp to have a static, byte-addressable tile type"); + if (tmpElemBytes.value() * tmpNumel.value() < 32) + return emitOpError("expects tmp to be at least 32 bytes when provided"); + } + return mlir::success(); +} + +/* Verifies TScatterOp, which takes exactly one of an indexes operand (indexed form) or a maskPattern attribute (mask form). The A2/A3 callback accepts either form; the A5 callback rejects the mask form. */ +mlir::LogicalResult mlir::pto::TScatterOp::verify() { + const bool hasIndexes = static_cast(getIndexes()); + const bool hasMaskPattern = static_cast(getMaskPatternAttr()); + if (hasIndexes == hasMaskPattern) { + return emitOpError( + "expects exactly one of indexes operand or maskPattern attribute"); + } + + auto isAllowedDataElem = [&](mlir::Type t) -> bool { + if (t.isF16() || t.isF32() || t.isBF16()) return true; + if (auto it = mlir::dyn_cast(t)) + return (it.getWidth() == 8 || it.getWidth() == 16 || it.getWidth() == 32); + return false; + }; + auto isAllowedIndexElem = [&](mlir::Type t) -> bool { + if (auto it = mlir::dyn_cast(t)) + return (it.getWidth() == 16 || it.getWidth() == 32); + return false; + }; + auto getMaskScatterTimes = [&](mlir::pto::MaskPatternAttr mp) -> unsigned { /* factor used in the mask-form column check: P1111 gives 1, P0101 and P1010 give 2, every other pattern gives 4 */ + switch (mp.getValue()) { + case mlir::pto::MaskPattern::P1111: + return 1; + case mlir::pto::MaskPattern::P0101: + case mlir::pto::MaskPattern::P1010: + return 2; + default: + return 4; + } + }; + + auto verifyIndexedForm = [&]() -> LogicalResult { + Type ts = getSrc().getType(); + Type ti = getIndexes().getType(); + Type td = getDst().getType(); + if (failed(verifyVecTileStorage(*this, ts, "src")) || + failed(verifyVecTileStorage(*this, ti, "indexes")) || + failed(verifyVecTileStorage(*this, td, "dst"))) + return failure(); + + Type srcElem = getElemTy(ts), dstElem = getElemTy(td), idxElem = getElemTy(ti); + if (!srcElem || !dstElem || !idxElem) + return emitOpError("failed to get element type for operands"); + if (srcElem != dstElem) + return emitOpError("expects src/dst to have the same element type"); + + if (!isAllowedDataElem(srcElem)) + return emitOpError("expects src/dst element type to be i8/i16/i32/f16/bf16/f32"); + if 
(!isAllowedIndexElem(idxElem)) + return emitOpError("expects indexes element type to be i16/i32"); + + auto bwData = getPTOStorageElemBitWidth(srcElem); + auto bwIdx = getPTOStorageElemBitWidth(idxElem); + if (bwData != 8 && bwData != 16 && bwData != 32) + return emitOpError("unexpected src/dst element bitwidth"); + + unsigned dataBytes = bwData / 8; + unsigned idxBytes = bwIdx / 8; + unsigned expectedIdxBytes = (dataBytes == 1) ? 2 : dataBytes; /* index elements must be as wide as the data elements, except that 1-byte data requires 2-byte indexes */ + if (idxBytes != expectedIdxBytes) + return emitOpError("expects indexes element size to match the documented scatter rule"); + return mlir::success(); + }; + + auto verifyMaskForm = [&]() -> LogicalResult { + Type ts = getSrc().getType(); + Type td = getDst().getType(); + if (failed(verifyVecTileCommon(*this, ts, "src")) || + failed(verifyVecTileCommon(*this, td, "dst"))) + return failure(); + + auto srcTB = dyn_cast(ts); + auto dstTB = dyn_cast(td); + if (!srcTB || !dstTB) + return emitOpError("expects src and dst to be tile_buf types"); + + if (getElemTy(ts) != getElemTy(td)) + return emitOpError("expects src and dst to have the same element type"); + if (!isAllowedDataElem(getElemTy(ts))) + return emitOpError("expects src/dst element type to be i8/i16/i32/f16/bf16/f32"); + + auto srcValid = getValidShapeVec(ts); + auto dstValid = getValidShapeVec(td); + if (srcValid.size() != 2 || dstValid.size() != 2) + return emitOpError("expects src and dst to have rank-2 valid_shape"); + + auto mp = getMaskPatternAttr(); + if (!mp) + return emitOpError("expects mask-pattern tscatter to provide maskPattern"); + const unsigned times = getMaskScatterTimes(mp); + if (srcValid[0] != ShapedType::kDynamic && dstValid[0] != ShapedType::kDynamic && + srcValid[0] != dstValid[0]) + return emitOpError("expects src and dst to have the same valid rows"); + if (srcValid[1] != ShapedType::kDynamic && dstValid[1] != ShapedType::kDynamic && + srcValid[1] != static_cast(dstValid[1] * times)) + return emitOpError("expects src valid cols to equal dst 
valid cols times the mask expansion factor"); + + if (srcTB.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor) || + dstTB.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor)) + return emitOpError("expects mask-pattern tscatter to use row_major blayout"); + return mlir::success(); + }; + + auto verifyA2A3 = [&]() -> LogicalResult { + if (hasMaskPattern) + return verifyMaskForm(); + return verifyIndexedForm(); + }; + auto verifyA5 = [&]() -> LogicalResult { + if (hasMaskPattern) + return emitOpError("mask-pattern tscatter is not supported on A5 yet"); + return verifyIndexedForm(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +/* Verifies TSelOp: src0, src1 and dst must pass verifyTileBufCommon, share one element type and be row-major. Floats f16/bf16/f32 are accepted on both paths; integers must be 16/32-bit on A2/A3 and 8/16/32-bit on A5. */ +mlir::LogicalResult mlir::pto::TSelOp::verify() { + auto verifyCommon = [&]() -> FailureOr { + Type t0 = getSrc0().getType(); + Type t1 = getSrc1().getType(); + Type td = getDst().getType(); + if (failed(verifyTileBufCommon(*this, t0, "src0")) || + failed(verifyTileBufCommon(*this, t1, "src1")) || + failed(verifyTileBufCommon(*this, td, "dst"))) + return failure(); + + Type srcElem = getElemTy(t0); + Type src1Elem = getElemTy(t1); + Type dstElem = getElemTy(td); + if (!srcElem || !src1Elem || !dstElem) { + emitOpError("failed to get element type for operands"); + return failure(); + } + if (srcElem != src1Elem || srcElem != dstElem) { + emitOpError("expects src0, src1, and dst to have the same element type"); + return failure(); + } + + if (!isRowMajorTileBuf(t0) || !isRowMajorTileBuf(t1) || + !isRowMajorTileBuf(td)) { + emitOpError( + "expects src0, src1, and dst to use row-major layout"); + return failure(); + } + return srcElem; + }; + + auto verifyA2A3 = [&]() -> LogicalResult { + FailureOr srcElem = verifyCommon(); + if (failed(srcElem)) + return failure(); + Type elem = *srcElem; + bool ok = elem.isF16() || elem.isBF16() || elem.isF32(); + if (auto it = dyn_cast(elem)) + ok = it.getWidth() == 16 || it.getWidth() == 32; + if (!ok) + return emitOpError( + "expects A2/A3 tsel src0, src1, 
and dst element type to be i16/i32/f16/bf16/f32"); + return success(); + }; + + auto verifyA5 = [&]() -> LogicalResult { + FailureOr srcElem = verifyCommon(); + if (failed(srcElem)) + return failure(); + Type elem = *srcElem; + bool ok = elem.isF16() || elem.isBF16() || elem.isF32(); + if (auto it = dyn_cast(elem)) + ok = it.getWidth() == 8 || it.getWidth() == 16 || it.getWidth() == 32; + if (!ok) + return emitOpError( + "expects A5 tsel src0, src1, and dst element type to be i8/i16/i32/f16/bf16/f32"); + return success(); + }; + + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +/* Verifies TSelSOp. mask and tmp only need to pass verifyTileBufCommon and yield an element type; the element-type, valid-shape and layout rules listed below apply to src and dst. */ +mlir::LogicalResult mlir::pto::TSelSOp::verify() { + // Constraints & Verification per PTO_IR_manual.md pto.tsels: + // - src and dst same element type; A2A3: i16/i32/f16/f32; A5: i8/i16/i32/f16/f32 + // - src and dst row-major; src and dst same valid region + auto verifyCommon = [&]() -> FailureOr { + Type tMask = getMask().getType(); + Type tSrc = getSrc().getType(); + Type tTmp = getTmp().getType(); + Type tDst = getDst().getType(); + if (failed(verifyTileBufCommon(*this, tMask, "mask")) || + failed(verifyTileBufCommon(*this, tSrc, "src")) || + failed(verifyTileBufCommon(*this, tTmp, "tmp")) || + failed(verifyTileBufCommon(*this, tDst, "dst"))) + return failure(); + Type eMask = getElemTy(tMask), eSrc = getElemTy(tSrc); + Type eTmp = getElemTy(tTmp), eDst = getElemTy(tDst); + if (!eMask || !eSrc || !eTmp || !eDst) { + emitOpError("failed to get element type for operands"); + return failure(); + } + if (eSrc != eDst) + return emitOpError("expects src and dst to have the same element type"); + if (failed(verifyTileBufSameValidShape(*this, tSrc, tDst, "src", "dst"))) + return failure(); + return eDst; + }; + + auto verifyA2A3 = [&]() -> LogicalResult { + FailureOr elemOr = verifyCommon(); + if (failed(elemOr)) + return failure(); + Type tSrc = getSrc().getType(); + Type tDst = getDst().getType(); + if (!isRowMajorTileBuf(tSrc) || !isRowMajorTileBuf(tDst)) + 
return emitOpError("expects src and dst to use row-major layout"); + Type elem = *elemOr; + bool ok = elem.isF16() || elem.isF32(); + if (auto it = mlir::dyn_cast(elem)) + ok = (it.getWidth() == 16 || it.getWidth() == 32); + if (!ok) + return emitOpError( + "expects A2/A3 tsels src and dst element type to be i16, i32, f16, or f32"); + return success(); + }; + + auto verifyA5 = [&]() -> LogicalResult { + FailureOr elemOr = verifyCommon(); + if (failed(elemOr)) + return failure(); + Type tSrc = getSrc().getType(); + Type tDst = getDst().getType(); + if (!isRowMajorTileBuf(tSrc) || !isRowMajorTileBuf(tDst)) + return emitOpError("expects src and dst to use row-major layout"); + Type elem = *elemOr; + bool ok = elem.isF16() || elem.isF32(); + if (auto it = mlir::dyn_cast(elem)) + ok = (it.getWidth() == 8 || it.getWidth() == 16 || it.getWidth() == 32); + if (!ok) + return emitOpError( + "expects A5 tsels src and dst element type to be i8, i16, i32, f16, or f32"); + return success(); + }; + + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +/* Verifies TShlOp: runs verifyShiftLikeBinaryTileOpCommon, then requires the element type it returns to be an 8/16/32-bit integer; the same lambda is passed for both arch callbacks. */ +mlir::LogicalResult mlir::pto::TShlOp::verify() { + auto verify = [&]() -> LogicalResult { + FailureOr elemOr = verifyShiftLikeBinaryTileOpCommon( + *this, getSrc0().getType(), getSrc1().getType(), getDst().getType()); + if (failed(elemOr)) + return failure(); + auto it = mlir::dyn_cast(*elemOr); + if (!it || (it.getWidth() != 8 && it.getWidth() != 16 && + it.getWidth() != 32)) + return emitOpError( + "expects tshl src0 and src1 element type to be i8/i16/i32"); + return success(); + }; + + return dispatchVerifierByArch(getOperation(), verify, verify); +} + +/* Verifies TShrOp: runs verifyShiftLikeBinaryTileOpCommon, then requires the element type it returns to be an 8/16/32-bit integer; the same lambda is passed for both arch callbacks. */ +mlir::LogicalResult mlir::pto::TShrOp::verify() { + auto verify = [&]() -> LogicalResult { + FailureOr elemOr = verifyShiftLikeBinaryTileOpCommon( + *this, getSrc0().getType(), getSrc1().getType(), getDst().getType()); + if (failed(elemOr)) + return failure(); + auto it = mlir::dyn_cast(*elemOr); + if (!it || (it.getWidth() != 8 && it.getWidth() != 16 
&& + it.getWidth() != 32)) + return emitOpError( + "expects tshr src0 and src1 element type to be i8/i16/i32"); + return success(); + }; + + return dispatchVerifierByArch(getOperation(), verify, verify); +} + +/* Verifies TSort32Op (skipped when shouldBypassDecodedMemrefVerifier returns true): src, dst, idx and an optional tmp must pass verifyVecTileCommon; src and dst share an f16 or f32 element type; idx elements must be 32 bits wide. */ +mlir::LogicalResult mlir::pto::TSort32Op::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + Type idxTy = getIdx().getType(); + if (failed(verifyVecTileCommon(*this, srcTy, "src")) || + failed(verifyVecTileCommon(*this, dstTy, "dst")) || + failed(verifyVecTileCommon(*this, idxTy, "idx"))) + return failure(); + if (getTmp() && + failed(verifyVecTileCommon(*this, getTmp().getType(), "tmp"))) + return failure(); + + auto srcElem = getElemTy(srcTy); + auto dstElem = getElemTy(dstTy); + if (!srcElem || !dstElem || srcElem != dstElem) + return emitOpError() << "expects src and dst to have the same element type"; + if (!(srcElem.isF16() || srcElem.isF32())) + return emitOpError() << "expects src and dst element type to be f16 or f32"; + + auto idxElem = getElemTy(idxTy); + auto idxInt = dyn_cast(idxElem); + if (!idxInt || idxInt.getWidth() != 32) + return emitOpError() << "expects idx element type to be i32/u32"; + return mlir::success(); +} + +/* Verifies TSqrtOp (skipped when shouldBypassDecodedMemrefVerifier returns true): src/dst must pass verifyVecTileUnaryOp with bf16 and int8 disallowed and share a valid shape; the src element type must be one of the two accepted float kinds (float or half per the diagnostic). */ +mlir::LogicalResult mlir::pto::TSqrtOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyVecTileUnaryOp(*this, srcTy, dstTy, "src", "dst", + /*allowBf16=*/false, /*allowInt8=*/false))) + return failure(); + if (failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) + return failure(); + + auto srcElem = getElemTy(srcTy); + if (!(mlir::isa(srcElem) || mlir::isa(srcElem))) + return emitOpError() << "expects src and dst element type to be float or half"; + + return mlir::success(); +} + + +/* Verifies TStoreFPOp. Nothing is checked when shouldBypassDecoded() holds. Otherwise both arch paths require src and fp to be tile_buf values passing verifyTileBufCommon, dst to pass verifyDstType, and src to be in the ACC address space; the A2/A3 path additionally requires f32 or 32-bit integer src elements, rank 2, and src cols and valid_shape[1] within [1, 4095] when static. */ +mlir::LogicalResult mlir::pto::TStoreFPOp::verify() { + auto shouldBypassDecoded = 
[&]() -> bool { + Value src = getSrc(); + Value fp = getFp(); + return isa(src.getType()) || isa(fp.getType()) || + src.getDefiningOp() || + fp.getDefiningOp(); + }; + + auto verifyDstType = [&]() -> LogicalResult { /* dst must be one of the accepted types; for the dstPart case every static dimension must be positive (dynamic dimensions are skipped) */ + Type dstTy = getDst().getType(); + if (!isa(dstTy)) + return emitOpError() + << "expects dst to be a memref or !pto.partition_tensor_view"; + if (auto dstPart = dyn_cast(dstTy)) { + for (auto [idx, dim] : llvm::enumerate(dstPart.getShape())) { + if (dim != ShapedType::kDynamic && dim <= 0) + return emitOpError() + << "expects dst shape[" << idx << "] to be positive"; + } + } + return success(); + }; + + auto verifyA2A3 = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type fpTy = getFp().getType(); + if (!isa(srcTy)) + return emitOpError() << "expects src to be a !pto.tile_buf"; + if (!isa(fpTy)) + return emitOpError() << "expects fp to be a !pto.tile_buf"; + if (failed(verifyTileBufCommon(*this, srcTy, "src")) || + failed(verifyTileBufCommon(*this, fpTy, "fp"))) + return failure(); + if (failed(verifyDstType())) + return failure(); + auto srcSpace = getPTOMemorySpaceEnum(srcTy); + if (!srcSpace || *srcSpace != pto::AddressSpace::ACC) + return emitOpError() << "expects src to be in the acc address space"; + auto srcElemTy = getElemTy(srcTy); + auto srcIntTy = dyn_cast(srcElemTy); + if (!(srcElemTy.isF32() || + (srcIntTy && srcIntTy.getWidth() == 32))) + return emitOpError() + << "expects src to have element type f32, i32"; + auto srcShape = getShapeVec(srcTy); + if (srcShape.size() != 2) + return emitOpError() << "expects src to have rank 2"; + if (srcShape[1] != ShapedType::kDynamic && + (srcShape[1] < 1 || srcShape[1] > 4095)) + return emitOpError() << "expects src.cols to be in the range [1, 4095]"; + auto srcValid = getValidShapeVec(srcTy); + if (srcValid.size() != 2) + return emitOpError() << "expects src to have a rank-2 valid_shape"; + if (srcValid[1] != ShapedType::kDynamic && + (srcValid[1] < 1 || srcValid[1] > 4095)) + return 
emitOpError() + << "expects src.valid_shape[1] to be in the range [1, 4095]"; + return mlir::success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type fpTy = getFp().getType(); + if (!isa(srcTy)) + return emitOpError() << "expects src to be a !pto.tile_buf"; + if (!isa(fpTy)) + return emitOpError() << "expects fp to be a !pto.tile_buf"; + if (failed(verifyTileBufCommon(*this, srcTy, "src")) || + failed(verifyTileBufCommon(*this, fpTy, "fp"))) + return failure(); + if (failed(verifyDstType())) + return failure(); + auto srcSpace = getPTOMemorySpaceEnum(srcTy); + if (!srcSpace || *srcSpace != pto::AddressSpace::ACC) + return emitOpError() << "expects src to be in the acc address space"; + return mlir::success(); + }; + if (shouldBypassDecoded()) + return success(); + switch (getVerifierTargetArch(getOperation())) { + case VerifierTargetArch::A2A3: + return verifyA2A3(); + case VerifierTargetArch::A5: + return verifyA5(); + } + return failure(); /* reached only when the switch above matches no case */ +} + +/* Verifies TSubOp by delegating to verifyArithmeticBinaryTileOpWithArchDispatch (int8 allowed on A5, bf16 not allowed on A5). */ +mlir::LogicalResult mlir::pto::TSubOp::verify() { + return verifyArithmeticBinaryTileOpWithArchDispatch( + getOperation(), getSrc0().getType(), getSrc1().getType(), getDst().getType(), + /*allowInt8OnA5=*/true, /*allowBf16OnA5=*/false, + "expects A2/A3 tsub element type to be i32/i16/f16/f32", + "expects A5 tsub element type to be i32/i16/i8/f16/f32"); +} + +/* Verifies TSubCOp (skipped when shouldBypassDecodedMemrefVerifier returns true): src0, src1, src2 and dst must be PTO shaped-like and have equal rank; element types and extents are not compared here. */ +mlir::LogicalResult mlir::pto::TSubCOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + Type src0Ty = getSrc0().getType(); + Type src1Ty = getSrc1().getType(); + Type src2Ty = getSrc2().getType(); + Type dstTy = getDst().getType(); + if (!isPTOShapedLike(src0Ty) || !isPTOShapedLike(src1Ty) || !isPTOShapedLike(src2Ty) || !isPTOShapedLike(dstTy)) + return emitOpError() << "expects PTO shaped-like src0, src1, src2, and dst"; + + auto d = getShapeVec(dstTy); + if (getShapeVec(src0Ty).size() != d.size() || getShapeVec(src1Ty).size() != d.size() || getShapeVec(src2Ty).size() != 
d.size()) + return emitOpError() << "expects all tensors to have the same rank"; + return mlir::success(); +} + +/* Verifies TSubSOp by delegating to verifyArithmeticScalarTileOpWithArchDispatch (int8 and bf16 allowed on A5; equal valid rows required on A2/A3 only). */ +mlir::LogicalResult mlir::pto::TSubSOp::verify() { + return verifyArithmeticScalarTileOpWithArchDispatch( + getOperation(), getSrc().getType(), getDst().getType(), getScalar().getType(), + /*allowInt8OnA5=*/true, /*allowBf16OnA5=*/true, + "expects A2/A3 tsubs element type to be i32/i16/f16/f32", + "expects A5 tsubs element type to be i32/i16/i8/f16/bf16/f32", + /*requireValidRowsEqualOnA2A3=*/true, + /*requireValidRowsEqualOnA5=*/false); +} + +/* Verifies TSubSCOp (skipped when shouldBypassDecodedMemrefVerifier returns true): src0, src1 and dst must be PTO shaped-like and have equal rank; element types and extents are not compared here. */ +mlir::LogicalResult mlir::pto::TSubSCOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + Type src0Ty = getSrc0().getType(); + Type src1Ty = getSrc1().getType(); + Type dstTy = getDst().getType(); + if (!isPTOShapedLike(src0Ty) || !isPTOShapedLike(src1Ty) || !isPTOShapedLike(dstTy)) + return emitOpError() << "expects PTO shaped-like src0, src1, and dst"; + + auto d = getShapeVec(dstTy); + if (getShapeVec(src0Ty).size() != d.size() || getShapeVec(src1Ty).size() != d.size()) + return emitOpError() << "expects src0, src1, and dst to have the same rank"; + return mlir::success(); +} +mlir::LogicalResult mlir::pto::TTransOp::verify() { /* Verifies TTransOp. Both arch paths require src, tmp and dst to pass verifyTileBufCommon and share one element type whose storage size is 1, 2 or 4 bytes (i8; i16/f16/bf16; i32/f32). A2/A3 additionally requires a row_major src blayout; A5 additionally requires the major dimension of src and dst to be 32-byte aligned. */ + auto verifyA2A3 = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type tmpTy = getTmp().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyTileBufCommon(*this, srcTy, "src")) || + failed(verifyTileBufCommon(*this, tmpTy, "tmp")) || + failed(verifyTileBufCommon(*this, dstTy, "dst"))) + return failure(); + Type srcElem = getElemTy(srcTy); + Type tmpElem = getElemTy(tmpTy); + Type dstElem = getElemTy(dstTy); + if (!srcElem || !tmpElem || !dstElem || srcElem != dstElem || srcElem != tmpElem) /* NOTE(review): this condition also covers tmp, yet the A2/A3 diagnostic names only src and dst (the A5 one names all three) */ + return emitOpError() << "expects src and dst to have the same element type"; + if (auto srcTb = dyn_cast(srcTy)) { + if (srcTb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor)) + return emitOpError() << "expects A2/A3 
transpose src to use the row_major blayout"; + } + unsigned elemBytes = getPTOStorageElemByteSize(srcElem); + if (elemBytes == 0) + return emitOpError() << "failed to get transpose element size"; + if (elemBytes != 1 && elemBytes != 2 && elemBytes != 4) + return emitOpError() << "expects transpose element size to be 1, 2, or 4 bytes"; + auto isAllowedWidthType = [&](Type ty) { /* accepted types per storage size: 4 bytes i32/f32, 2 bytes i16/f16/bf16, 1 byte i8 */ + if (elemBytes == 4) + return ty.isInteger(32) || ty.isF32(); + if (elemBytes == 2) + return ty.isInteger(16) || ty.isF16() || ty.isBF16(); + return ty.isInteger(8); + }; + if (!isAllowedWidthType(srcElem)) + return emitOpError() << "expects transpose element type to match the supported set for its width"; + return mlir::success(); + }; + auto verifyA5 = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type tmpTy = getTmp().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyTileBufCommon(*this, srcTy, "src")) || + failed(verifyTileBufCommon(*this, tmpTy, "tmp")) || + failed(verifyTileBufCommon(*this, dstTy, "dst"))) + return failure(); + Type srcElem = getElemTy(srcTy); + Type tmpElem = getElemTy(tmpTy); + Type dstElem = getElemTy(dstTy); + if (!srcElem || !tmpElem || !dstElem || srcElem != dstElem || srcElem != tmpElem) + return emitOpError() << "expects src, tmp, and dst to have the same element type"; + unsigned elemBytes = getPTOStorageElemByteSize(srcElem); + if (elemBytes == 0) + return emitOpError() << "failed to get transpose element size"; + if (elemBytes != 1 && elemBytes != 2 && elemBytes != 4) + return emitOpError() << "expects transpose element size to be 1, 2, or 4 bytes"; + auto isAllowedWidthType = [&](Type ty) { + if (elemBytes == 4) + return ty.isInteger(32) || ty.isF32(); + if (elemBytes == 2) + return ty.isInteger(16) || ty.isF16() || ty.isBF16(); + return ty.isInteger(8); + }; + if (!isAllowedWidthType(srcElem)) + return emitOpError() << "expects transpose element type to match the supported set for its width"; + auto checkAlignedMajor = 
[&](Type ty, StringRef name) -> LogicalResult { + auto tb = mlir::dyn_cast(ty); + if (!tb) + return success(); + auto shape = getShapeVec(ty); + if (shape.size() != 2) + return success(); + bool rowMajor = tb.getBLayoutValueI32() == static_cast(pto::BLayout::RowMajor); + int64_t major = rowMajor ? shape[1] : shape[0]; /* row-major tiles are checked on shape[1], all other layouts on shape[0] */ + if (major != ShapedType::kDynamic && (major * static_cast(elemBytes)) % 32 != 0) + return emitOpError() << "expects " << name << " major dimension times element size to be 32-byte aligned on A5"; + return success(); + }; + if (failed(checkAlignedMajor(srcTy, "src")) || failed(checkAlignedMajor(dstTy, "dst"))) + return failure(); + return mlir::success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} +/* Verifies TXorOp: both paths run verifyMatchingRowMajorBinaryTileOpCommon. A2/A3 additionally validates tmp (same element type, row-major, same valid shape as dst) and accepts i8/i16; A5 does not inspect tmp and accepts i8/i16/i32. */ +mlir::LogicalResult mlir::pto::TXorOp::verify() { + auto verifyBase = [&]() -> FailureOr { + return verifyMatchingRowMajorBinaryTileOpCommon( + getOperation(), getSrc0().getType(), getSrc1().getType(), + getDst().getType()); + }; + + auto verifyA2A3 = [&]() -> LogicalResult { + FailureOr elemOr = verifyBase(); + if (failed(elemOr)) + return failure(); + Type tmpTy = getTmp().getType(); + if (failed(verifyTileBufCommon(*this, tmpTy, "tmp"))) + return failure(); + Type elem = *elemOr; + if (getElemTy(tmpTy) != elem) + return emitOpError("expects tmp to have the same element type as src0, src1, and dst"); + if (!isRowMajorTileBuf(tmpTy)) + return emitOpError("expects tmp to use row-major layout"); + if (failed(verifyTileBufSameValidShape(*this, tmpTy, getDst().getType(), "tmp", "dst"))) + return failure(); + auto it = mlir::dyn_cast(elem); + if (!it || (it.getWidth() != 8 && it.getWidth() != 16)) + return emitOpError( + "expects A2/A3 txor src0, src1, tmp, and dst element type to be i8/i16"); + return success(); + }; + + auto verifyA5 = [&]() -> LogicalResult { + FailureOr elemOr = verifyBase(); + if (failed(elemOr)) + return failure(); + auto it = mlir::dyn_cast(*elemOr); + if (!it || (it.getWidth() != 8 && it.getWidth() 
!= 16 && + it.getWidth() != 32)) + return emitOpError( + "expects A5 txor src0, src1, and dst element type to be i8/i16/i32"); + return success(); + }; + + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +/* Verifies TXorSOp: both paths run verifyDistinctRowMajorUnaryTileOpCommon on src and dst; A2/A3 accepts i8/i16 elements, A5 accepts i8/i16/i32. */ +mlir::LogicalResult mlir::pto::TXorSOp::verify() { + auto verifyCommon = [&]() -> FailureOr { + return verifyDistinctRowMajorUnaryTileOpCommon(getOperation(), getSrc(), + getDst(), "src", "dst"); + }; + + auto verifyA2A3 = [&]() -> LogicalResult { + FailureOr elemOr = verifyCommon(); + if (failed(elemOr)) + return failure(); + auto it = mlir::dyn_cast(*elemOr); + if (!it || (it.getWidth() != 8 && it.getWidth() != 16)) + return emitOpError( + "expects A2/A3 txors src and dst element type to be i8/i16"); + return success(); + }; + + auto verifyA5 = [&]() -> LogicalResult { + FailureOr elemOr = verifyCommon(); + if (failed(elemOr)) + return failure(); + auto it = mlir::dyn_cast(*elemOr); + if (!it || (it.getWidth() != 8 && it.getWidth() != 16 && + it.getWidth() != 32)) + return emitOpError( + "expects A5 txors src and dst element type to be i8/i16/i32"); + return success(); + }; + + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} +mlir::LogicalResult mlir::pto::TPrintOp::verify() { /* Verifies TPrintOp (skipped when shouldBypassDecodedMemrefVerifier returns true): a tile src must have f16/f32/i8/i16/i32 elements and be in the VEC address space; the two other accepted src kinds (memref and partition_tensor_view per the diagnostic) get no further checks. */ + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + auto srcType = getSrc().getType(); + if (auto tb = mlir::dyn_cast(srcType)) { + auto elem = tb.getElementType(); + if (!(elem.isF16() || elem.isF32() || + elem.isInteger(8) || elem.isInteger(16) || elem.isInteger(32))) + return emitOpError() << "expects printable tile element type"; + auto space = getPTOMemorySpaceEnum(srcType); + if (!space || *space != pto::AddressSpace::VEC) + return emitOpError() << "expects printable tile_buf to be in vec address space"; + return success(); + } + if (mlir::dyn_cast(srcType) || + mlir::dyn_cast(srcType)) + return mlir::success(); + return emitOpError() << "expects tile_buf, memref, or partition_tensor_view for src"; +} + + 
+/* Shared matmul operand check (currently [[maybe_unused]]). Case A handles shaped lhs/rhs: both must be ranked with equal element types, and a present bias must be ranked with that element type. Case B handles tiles: equal element types, rank 2, lhs dim1 equal to rhs dim0 (compared directly, with no special case for dynamic extents), and a present bias must be a tile of that element type. In both cases a non-null maybeDstElemTy or maybeResultElemTy must equal the lhs element type. */ +[[maybe_unused]] static LogicalResult verifyMatmulCommon(Operation *op, Value lhs, Value rhs, + Value biasOpt, Type maybeDstElemTy, + Type maybeResultElemTy) { + // ---- case A: tensor/memref (ShapedType) ---- + if (auto lhsTy = dyn_cast(lhs.getType())) { + auto rhsTy = dyn_cast(rhs.getType()); + if (!rhsTy || !lhsTy.hasRank() || !rhsTy.hasRank()) + return op->emitOpError("expects lhs and rhs to be ranked tensors or memrefs"); + + if (lhsTy.getElementType() != rhsTy.getElementType()) + return op->emitOpError() + << "expects lhs and rhs to have the same element type, but got lhs=" + << lhsTy.getElementType() << " rhs=" << rhsTy.getElementType(); + + if (biasOpt) { + auto biasTy = dyn_cast(biasOpt.getType()); + if (!biasTy || !biasTy.hasRank()) + return op->emitOpError("expects bias to be a ranked tensor or memref"); + if (biasTy.getElementType() != lhsTy.getElementType()) + return op->emitOpError() + << "expects bias to have the same element type as lhs and rhs, but got bias=" + << biasTy.getElementType() << " vs " << lhsTy.getElementType(); + } + + if (maybeDstElemTy && maybeDstElemTy != lhsTy.getElementType()) + return op->emitOpError() + << "expects dst to have the same element type as lhs and rhs, but got dst=" + << maybeDstElemTy << " vs " << lhsTy.getElementType(); + + if (maybeResultElemTy && maybeResultElemTy != lhsTy.getElementType()) + return op->emitOpError() + << "expects result to have the same element type as lhs and rhs, but got result=" + << maybeResultElemTy << " vs " << lhsTy.getElementType(); + + return success(); + } + + // ---- case B: tile ---- + auto lhsTile = dyn_cast(lhs.getType()); + auto rhsTile = dyn_cast(rhs.getType()); + if (!lhsTile || !rhsTile) + return op->emitOpError("expects lhs and rhs to be ranked tensors, memrefs, or !pto.tile"); + + if (lhsTile.getElementType() != rhsTile.getElementType()) + return op->emitOpError() << "expects lhs and rhs tiles to have the same element type, but got lhs=" + << lhsTile.getElementType() << " 
rhs=" << rhsTile.getElementType(); + + if ((int64_t)lhsTile.getShape().size() != 2 || (int64_t)rhsTile.getShape().size() != 2) + return op->emitOpError("expects lhs and rhs tiles to be 2D"); + + if (lhsTile.getShape()[1] != rhsTile.getShape()[0]) + return op->emitOpError() << "expects lhs dim1 to equal rhs dim0, but got " + << lhsTile.getShape()[1] << " vs " << rhsTile.getShape()[0]; + + if (biasOpt) { + auto biasTile = dyn_cast(biasOpt.getType()); + if (!biasTile) + return op->emitOpError("expects bias to be !pto.tile when lhs and rhs are !pto.tile"); + if (biasTile.getElementType() != lhsTile.getElementType()) + return op->emitOpError("expects bias to have the same element type as lhs and rhs"); + } + + if (maybeDstElemTy && maybeDstElemTy != lhsTile.getElementType()) + return op->emitOpError() << "expects dst to have the same element type as lhs and rhs"; + + if (maybeResultElemTy && maybeResultElemTy != lhsTile.getElementType()) + return op->emitOpError() << "expects result to have the same element type as lhs and rhs"; + + return success(); +} +/* Verifies TMatmulOp with verifyMatTileOperands, verifyMatmulTypeTriple and verifyMatmulLike; the A5 callback simply reuses the A2/A3 one. */ +LogicalResult mlir::pto::TMatmulOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + if (failed(verifyMatTileOperands(*this, getLhs().getType(), getRhs().getType(), + getDst().getType()))) + return failure(); + if (failed(verifyMatmulTypeTriple(*this, getElemTy(getLhs().getType()), + getElemTy(getRhs().getType()), + getElemTy(getDst().getType())))) + return failure(); + return verifyMatmulLike(*this, getLhs().getType(), getRhs().getType(), + getDst().getType()); + }; + auto verifyA5 = [&]() -> LogicalResult { return verifyA2A3(); }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} +/* Verifies TGemvOp with verifyGemvTileOperands, verifyMatmulTypeTriple and verifyMatmulLike; the A5 callback simply reuses the A2/A3 one. */ +LogicalResult mlir::pto::TGemvOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + if (failed(verifyGemvTileOperands(*this, getLhs().getType(), getRhs().getType(), + getDst().getType()))) + return failure(); + if (failed(verifyMatmulTypeTriple(*this, getElemTy(getLhs().getType()), + 
getElemTy(getRhs().getType()), + getElemTy(getDst().getType())))) + return failure(); + return verifyMatmulLike(*this, getLhs().getType(), getRhs().getType(), + getDst().getType()); + }; + auto verifyA5 = [&]() -> LogicalResult { return verifyA2A3(); }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} +/* Verifies TMatmulAccOp (skipped when shouldBypassDecodedMemrefVerifier returns true): acc_in must pass verifyAccTileCommon and lhs/rhs/dst must pass verifyMatTileOperands. */ +LogicalResult mlir::pto::TMatmulAccOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + if (failed(verifyAccTileCommon(*this, getAccIn().getType(), "acc_in")) || + failed(verifyMatTileOperands(*this, getLhs().getType(), getRhs().getType(), + getDst().getType()))) + return failure(); + return success(); +} +/* Verifies TGemvAccOp (skipped when shouldBypassDecodedMemrefVerifier returns true): acc_in must pass verifyAccTileCommon and lhs/rhs/dst must pass verifyGemvTileOperands. */ +LogicalResult mlir::pto::TGemvAccOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + if (failed(verifyAccTileCommon(*this, getAccIn().getType(), "acc_in")) || + failed(verifyGemvTileOperands(*this, getLhs().getType(), getRhs().getType(), + getDst().getType()))) + return failure(); + return success(); +} + +//===----------------------------------------------------------------------===// +// inferReturnTypes() for matmul ops (keep your existing code) +//===----------------------------------------------------------------------===// +[[maybe_unused]] static mlir::Type inferMatmulTileResult2DFromAB(MLIRContext *context, ValueRange operands) { /* Returns a null Type when fewer than two operands are given or lhs/rhs are not tiles. Otherwise the result tile uses the lhs element type and takes the shape of operands[2] when that is a tile, else {lhs dim 0, rhs dim 1} when both have rank of at least 2. */ + if (operands.size() < 2) + return mlir::Type(); + + auto lhsTile = dyn_cast(operands[0].getType()); + auto rhsTile = dyn_cast(operands[1].getType()); + if (!lhsTile || !rhsTile) + return mlir::Type(); + + Type elemTy = lhsTile.getElementType(); + + if (operands.size() >= 3) { + if (auto biasTile = dyn_cast(operands[2].getType())) { + return mlir::pto::TileType::get(context, biasTile.getShape(), elemTy); + } + } + + auto lhsShape = lhsTile.getShape(); + auto rhsShape = rhsTile.getShape(); + if (lhsShape.size() >= 2 && rhsShape.size() >= 2) { + int64_t M = lhsShape[0]; + int64_t N = rhsShape[1]; + llvm::SmallVector outShape 
= {M, N}; + return mlir::pto::TileType::get(context, outShape, elemTy); + } + + return mlir::Type(); +} +/* Returns a null RankedTensorType when fewer than two operands are given or lhs/rhs are not ranked. Otherwise the result uses the lhs element type and takes the bias shape from operands[2] (biasRT case, or biasMR case when its shape is static), else {lhs dim 0, rhs dim 1} when both have rank of at least 2. */ +[[maybe_unused]] static RankedTensorType inferMatmulResult2DFromAB(ValueRange operands) { + if (operands.size() < 2) + return RankedTensorType(); + + auto lhsTy = dyn_cast(operands[0].getType()); + auto rhsTy = dyn_cast(operands[1].getType()); + if (!lhsTy || !rhsTy || !lhsTy.hasRank() || !rhsTy.hasRank()) + return RankedTensorType(); + + Type elemTy = lhsTy.getElementType(); + + if (operands.size() >= 3) { + if (auto biasRT = dyn_cast(operands[2].getType())) + return RankedTensorType::get(biasRT.getShape(), elemTy); + if (auto biasMR = dyn_cast(operands[2].getType())) { + if (biasMR.hasStaticShape()) + return RankedTensorType::get(biasMR.getShape(), elemTy); + } + } + + if (lhsTy.getRank() >= 2 && rhsTy.getRank() >= 2) { + int64_t M = lhsTy.getDimSize(0); + int64_t N = rhsTy.getDimSize(1); + return RankedTensorType::get({M, N}, elemTy); + } + + return RankedTensorType(); +} +/* Returns the type of operands[0] when the accRT cast succeeds, otherwise a null RankedTensorType. */ +[[maybe_unused]] static RankedTensorType inferAccReturnFromAccIn(ValueRange operands) { + if (operands.empty()) + return RankedTensorType(); + if (auto accRT = dyn_cast(operands[0].getType())) + return accRT; + return RankedTensorType(); +} + +namespace mlir { +namespace pto { +/* Parses a less-than sign, a dimension list, an element type and a greater-than sign into shape and elementType; allowDynamic is forwarded to parseDimensionList. */ +static LogicalResult parseShapeAndElem(AsmParser &parser, + SmallVectorImpl &shape, + Type &elementType, + bool allowDynamic) { + if (parser.parseLess()) + return failure(); + + if (parser.parseDimensionList(shape, allowDynamic)) + return failure(); + + if (parser.parseType(elementType)) + return failure(); + + if (parser.parseGreater()) + return failure(); + + return success(); +} +/* Prints the inverse form: each extent followed by "x" (with "?" for ShapedType::kDynamic), then the element type, all wrapped in angle brackets. */ +static void printShapeAndElem(AsmPrinter &printer, + ArrayRef shape, + Type elementType) { + printer << "<"; + for (auto d : shape) { + if (d == ShapedType::kDynamic) + printer << "?"; + else + printer << d; + printer << "x"; + } + printer.printType(elementType); + printer << ">"; +} + +// 
============================================================================= +// PartitionTensorViewType Implementation +// ============================================================================= +/* Parses the shape-and-element body with dynamic dimensions allowed; returns a null Type on parse failure. */ +Type PartitionTensorViewType::parse(AsmParser &parser) { + SmallVector shape; + Type elemTy; + if (failed(parseShapeAndElem(parser, shape, elemTy, /*allowDynamic=*/true))) + return Type(); + + return PartitionTensorViewType::get(parser.getContext(), shape, elemTy); +} +/* Prints the shape-and-element body via printShapeAndElem. */ +void PartitionTensorViewType::print(AsmPrinter &printer) const { + printShapeAndElem(printer, getShape(), getElementType()); +} + +// ---- TileType ---- (same textual body as PartitionTensorViewType; dynamic dimensions allowed) +Type TileType::parse(AsmParser &parser) { + SmallVector shape; + Type elemTy; + if (failed(parseShapeAndElem(parser, shape, elemTy, /*allowDynamic=*/true))) + return Type(); + return TileType::get(parser.getContext(), shape, elemTy); +} +/* Prints the shape-and-element body via printShapeAndElem. */ +void TileType::print(AsmPrinter &printer) const { + printShapeAndElem(printer, getShape(), getElementType()); +} + +// ---- LocalArrayType ---- +// Asm form: !pto.local_array followed by the angle-bracketed body handled by parseShapeAndElem (dimensions, then element type) +// Static shape only (no '?'). Element type must be a scalar; this is enforced +// by the type verifier below. 
+Type LocalArrayType::parse(AsmParser &parser) { + SmallVector shape; + Type elemTy; + if (failed(parseShapeAndElem(parser, shape, elemTy, /*allowDynamic=*/false))) + return Type(); + return LocalArrayType::getChecked( + [&]() { return parser.emitError(parser.getNameLoc()); }, + parser.getContext(), shape, elemTy); +} + +void LocalArrayType::print(AsmPrinter &printer) const { + printShapeAndElem(printer, getShape(), getElementType()); +} + +LogicalResult LocalArrayType::verify( + llvm::function_ref emitError, + llvm::ArrayRef shape, Type elementType) { + if (shape.empty()) + return emitError() << "'!pto.local_array' requires at least one dimension"; + for (auto [i, d] : llvm::enumerate(shape)) { + if (d <= 0) + return emitError() + << "'!pto.local_array' dimension " << i + << " must be a positive static size, got " << d; + } + if (!elementType.isIntOrFloat()) + return emitError() + << "'!pto.local_array' element type must be a scalar integer or " + "float, got " + << elementType; + return success(); +} + +// ============================================================================= +// Decompose Helper (Reverse Engineering AffineMap -> Strides) +// ============================================================================= + +// Helper: 递归地将 Add 表达式拆解为单独的项列表 +static void flattenAddExpr(AffineExpr expr, SmallVectorImpl &terms) { + if (auto add = llvm::dyn_cast(expr)) { + if (add.getKind() == AffineExprKind::Add) { + flattenAddExpr(add.getLHS(), terms); + flattenAddExpr(add.getRHS(), terms); + return; + } + } + terms.push_back(expr); +} + +// Helper: 从 AffineMap 中提取 Strides +static void decomposeStridedLayout(AffineMap map, SmallVectorImpl &strides) { + // 1. 初始化 + strides.assign(map.getNumDims(), 0); + + if (map.getNumResults() != 1) return; + + // 2. 摊平表达式 + SmallVector terms; + flattenAddExpr(map.getResult(0), terms); + + // 3. 
分析每一项 + for (auto term : terms) { + // 情况 A: dN * Const 或 Const * dN + if (auto mul = llvm::dyn_cast(term)) { + if (mul.getKind() == AffineExprKind::Mul) { + AffineExpr lhs = mul.getLHS(); + AffineExpr rhs = mul.getRHS(); + + // 尝试匹配 LHS=Dim, RHS=Const + if (auto dim = llvm::dyn_cast(lhs)) { + if (auto cst = llvm::dyn_cast(rhs)) { + strides[dim.getPosition()] = cst.getValue(); + continue; + } + } + + // 尝试匹配 LHS=Const, RHS=Dim (乘法交换律) + if (auto dim = llvm::dyn_cast(rhs)) { + if (auto cst = llvm::dyn_cast(lhs)) { + strides[dim.getPosition()] = cst.getValue(); + continue; + } + } + } + } + // 情况 B: 单独的 dN (隐含 Stride = 1) + else if (auto dim = llvm::dyn_cast(term)) { + strides[dim.getPosition()] = 1; + } + } +} + +// ============================================================================= +// [Critical] Strict Alignment Protocol Helper +// ============================================================================= +// This function is the SINGLE source of truth for building the AffineMap. +// Both the Parser and the Op Inference MUST use this exact function. +// It ensures that the order of AffineExpr addition is: +// 0 + (d0*str0 + d1*str1...) + (s0*str0 + s1*str1...) +// This guarantees bitwise-identical AffineMaps for verification. +static AffineMap buildStrictBitwiseAffineMap(MLIRContext *ctx, + ArrayRef strides, + bool isMultiDimSymbol) { + unsigned rank = strides.size(); + + // Step 1: Initialize with Constant(0) + AffineExpr totalExpr = getAffineConstantExpr(0, ctx); + + // Step 2: Add Dimensions (d0*str0 + d1*str1...) + // Strictly in order: 0, 1, 2... + for (unsigned i = 0; i < rank; ++i) { + auto dim = getAffineDimExpr(i, ctx); + auto str = getAffineConstantExpr(strides[i], ctx); + totalExpr = totalExpr + (dim * str); + } + + // Step 3: Add Symbols (s0*str0 + s1*str1...) + // Strictly in order: 0, 1, 2... 
+ if (isMultiDimSymbol) { + for (unsigned i = 0; i < rank; ++i) { + auto sym = getAffineSymbolExpr(i, ctx); + auto str = getAffineConstantExpr(strides[i], ctx); + totalExpr = totalExpr + (sym * str); + } + } + // (Optional: handle single dynamic offset case if needed, omitted for clarity) + + // numSymbols is rank if multi-dim (for offsets), else 0 + unsigned numSymbols = isMultiDimSymbol ? rank : 0; + return AffineMap::get(rank, numSymbols, totalExpr); +} + + +// ============================================================================= +// Parser Implementation +// ============================================================================= + +// Helper for parsing [64, 1] +static ParseResult parseStrideList(AsmParser &parser, SmallVectorImpl &strides) { + if (parser.parseLSquare()) return failure(); + do { + int64_t stride; + if (parser.parseInteger(stride)) return failure(); + strides.push_back(stride); + } while (succeeded(parser.parseOptionalComma())); + if (parser.parseRSquare()) return failure(); + return success(); +} + +// The custom attribute parser for: strided<[64, 1], offset: [?, ?]> +[[maybe_unused]] static ParseResult parseStridedLayout(AsmParser &parser, Attribute &layout) { + if (parser.parseLess()) return failure(); + + // 1. Parse Strides + SmallVector strides; + if (parseStrideList(parser, strides)) return failure(); + + bool isMultiDim = false; + unsigned numSymbols = 0; + + // 2. Parse Offset + if (succeeded(parser.parseOptionalComma())) { + if (parser.parseKeyword("offset") || parser.parseColon()) return failure(); + + // Check for multi-dim syntax: [?, ?] + if (succeeded(parser.parseOptionalLSquare())) { + isMultiDim = true; + do { + if (parser.parseQuestion()) return failure(); + numSymbols++; + } while (succeeded(parser.parseOptionalComma())); + if (parser.parseRSquare()) return failure(); + } else { + // Fallback for old scalar syntax '?' 
+ if (parser.parseOptionalQuestion()) { /* handle single scalar */ } + } + } + + if (parser.parseGreater()) return failure(); + + // 3. Validation + if (isMultiDim && numSymbols != strides.size()) { + return parser.emitError(parser.getCurrentLocation(), + "Number of offset symbols must match rank"); + } + + // 4. [CALL SHARED BUILDER] + // Delegate to the strict builder + MLIRContext *ctx = parser.getContext(); + AffineMap map = buildStrictBitwiseAffineMap(ctx, strides, isMultiDim); + + layout = AffineMapAttr::get(map); + return success(); +} + +// ============================================================================= +// Printer Implementation +// ============================================================================= + +[[maybe_unused]] static void printLayout(AsmPrinter &printer, Attribute layoutAttr) { + if (!layoutAttr) return; + auto mapAttr = llvm::dyn_cast(layoutAttr); + if (!mapAttr) { printer << ", " << layoutAttr; return; } + + AffineMap map = mapAttr.getValue(); + if (map.isIdentity()) return; + + // 1. [核心修改] 反解 Strides + SmallVector strides; + decomposeStridedLayout(map, strides); + + printer << ", strided<["; + // 2. 打印真实的 strides + llvm::interleaveComma(strides, printer); + printer << "]"; + + // Print Offset: [?, ?] 
+ unsigned numSyms = map.getNumSymbols(); + if (numSyms > 0) { + printer << ", offset: ["; + for (unsigned i = 0; i < numSyms; ++i) { + printer << "?"; + if (i < numSyms - 1) printer << ", "; + } + printer << "]"; + } + printer << ">"; +} + +// ---- TileBuf --- + + +// Tile subview 相关实现 + +// ============================================================================= +// Op Interface Implementation: SubViewOp +// ============================================================================= + +ParseResult mlir::pto::SubViewOp::parse(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::UnresolvedOperand source; + SmallVector offsets; + SmallVector valids; + Type sourceTy; + Type resultTy; + bool hasExplicitResultTy = false; + + if (parser.parseOperand(source) || parser.parseLSquare() || + parser.parseOperandList(offsets) || parser.parseRSquare() || + parser.parseKeyword("sizes")) + return failure(); + + ArrayAttr sizesAttr; + if (parser.parseAttribute(sizesAttr, "sizes", result.attributes)) + return failure(); + + if (succeeded(parser.parseOptionalKeyword("valid"))) { + OpAsmParser::UnresolvedOperand vrow, vcol; + if (parser.parseLSquare() || parser.parseOperand(vrow) || parser.parseComma() || + parser.parseOperand(vcol) || parser.parseRSquare()) + return failure(); + valids.push_back(vrow); + valids.push_back(vcol); + } + + if (parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(sourceTy)) + return failure(); + + if (succeeded(parser.parseOptionalArrow())) { + if (parser.parseType(resultTy)) + return failure(); + hasExplicitResultTy = true; + } + + if (parser.resolveOperand(source, sourceTy, result.operands)) + return failure(); + + Type indexTy = parser.getBuilder().getIndexType(); + if (parser.resolveOperands(offsets, indexTy, result.operands)) + return failure(); + if (!valids.empty() && + parser.resolveOperands(valids, indexTy, result.operands)) + return failure(); + + int32_t hasValid = valids.empty() ? 
0 : 1; + result.addAttribute( + "operandSegmentSizes", + parser.getBuilder().getDenseI32ArrayAttr( + {1, static_cast(offsets.size()), hasValid, hasValid})); + + if (hasExplicitResultTy) { + result.addTypes(resultTy); + return success(); + } + + SmallVector inferredReturnTypes; + DictionaryAttr attrs = result.attributes.getDictionary(parser.getContext()); + if (failed(SubViewOp::inferReturnTypes( + parser.getContext(), std::nullopt, result.operands, attrs, nullptr, + RegionRange(), inferredReturnTypes))) { + return parser.emitError(parser.getCurrentLocation(), + "failed to infer pto.subview result type"); + } + result.addTypes(inferredReturnTypes); + return success(); +} + +void mlir::pto::SubViewOp::print(OpAsmPrinter &printer) { + printer << " " << getSource() << "["; + printer.printOperands(getOffsets()); + printer << "] sizes " << getSizes(); + if (getValidRow()) { + printer << " valid [" << getValidRow() << ", " << getValidCol() << "]"; + } + printer.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"operandSegmentSizes", + "sizes"}); + printer << " : " << getSource().getType() << " -> " << getResult().getType(); +} + +LogicalResult SubViewOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + + // 1. 获取 Source Type + if (operands.empty()) return failure(); + auto sourceType = llvm::dyn_cast(operands[0].getType()); + if (!sourceType) return failure(); + + // 2. 
获取 subview 逻辑窗口(sizes) + ArrayAttr sizeAttr; + if (properties) { + const auto *prop = properties.as(); + if (prop) sizeAttr = prop->sizes; + } + if (!sizeAttr && attributes) { + sizeAttr = attributes.getAs("sizes"); + } + if (!sizeAttr) return failure(); + + SmallVector subviewShape; + for (auto attr : sizeAttr) { + int64_t dim = llvm::cast(attr).getInt(); + subviewShape.push_back(dim); + } + + // Design: subview 的结果 tile 类型显式表达逻辑子窗口 shape(sizes)。 + ArrayRef parentShape = sourceType.getShape(); + if (subviewShape.size() != parentShape.size()) + return failure(); + + // Derive valid shape from explicit valid_row/valid_col when provided. + // Otherwise default to subview shape (no parent valid-shape inheritance). + SmallVector validShape; + constexpr int64_t kDynamicValidDim = -1; + int64_t rank = static_cast(subviewShape.size()); + Value explicitVRow; + Value explicitVCol; + + // Robustly decode optional valid operands using AttrSizedOperandSegments: + // [source, offsets..., valid_row?, valid_col?] + if (attributes) { + if (auto segAttr = + attributes.getAs("operandSegmentSizes")) { + ArrayRef segs = segAttr.asArrayRef(); + if (segs.size() == 4) { + int32_t srcSeg = segs[0]; + int32_t offSeg = segs[1]; + int32_t vRowSeg = segs[2]; + int32_t vColSeg = segs[3]; + if (srcSeg == 1 && offSeg >= 0 && (vRowSeg == 0 || vRowSeg == 1) && + (vColSeg == 0 || vColSeg == 1)) { + size_t idx = static_cast(srcSeg + offSeg); + if (vRowSeg == 1 && idx < operands.size()) + explicitVRow = operands[idx++]; + if (vColSeg == 1 && idx < operands.size()) + explicitVCol = operands[idx]; + } + } + } + } + + // Fallback for legacy callers that may not provide operandSegmentSizes. 
+ if (!explicitVRow && !explicitVCol && rank == 2) { + size_t expectedWithoutValid = static_cast(1 + rank); + if (operands.size() >= expectedWithoutValid + 2) { + explicitVRow = operands[expectedWithoutValid]; + explicitVCol = operands[expectedWithoutValid + 1]; + } + } + + for (size_t i = 0, e = subviewShape.size(); i < e; ++i) { + int64_t vdim = subviewShape[i]; + Value explicitV = (i == 0) ? explicitVRow : (i == 1 ? explicitVCol : Value()); + if (explicitV) { + auto cst = getConstIndexValue(explicitV); + vdim = cst ? std::min(*cst, subviewShape[i]) : kDynamicValidDim; + } + validShape.push_back(vdim); + } + + // 3. 继承 Config (若为空使用默认) + auto cfg = sourceType.getConfigAttr(); + if (!cfg) cfg = TileBufConfigAttr::getDefault(context); + + // 4. 构建 Result Type + auto canonicalValidShape = canonicalizeTileBufValidShape(validShape); + auto resultType = TileBufType::get( + context, subviewShape, sourceType.getElementType(), + sourceType.getMemorySpace(), canonicalValidShape, cfg); + + inferredReturnTypes.push_back(resultType); + return success(); +} + +// ============================================================================= +// SubViewOp verifier +// ============================================================================= +static bool getConstIndex(Value v, int64_t &out) { + if (auto cOp = v.getDefiningOp()) { + out = cOp.value(); + return true; + } + if (auto cInt = v.getDefiningOp()) { + out = cInt.value(); + return true; + } + if (auto cOp = v.getDefiningOp()) { + if (auto ia = dyn_cast(cOp.getValue())) { + out = ia.getInt(); + return true; + } + } + if (auto castOp = v.getDefiningOp()) + return getConstIndex(castOp.getIn(), out); + if (auto extOp = v.getDefiningOp()) + return getConstIndex(extOp.getIn(), out); + if (auto extOp = v.getDefiningOp()) + return getConstIndex(extOp.getIn(), out); + if (auto truncOp = v.getDefiningOp()) + return getConstIndex(truncOp.getIn(), out); + return false; +} + +static LogicalResult computeInnerShape(TileBufConfigAttr 
cfg, Type elemTy, + int64_t &innerRows, int64_t &innerCols, + bool &boxed, int32_t &bl, int32_t &sl) { + auto readBLayoutI32 = [](Attribute attr, int32_t &out) -> bool { + if (auto a = dyn_cast(attr)) { + out = (int32_t)a.getValue(); + return true; + } + if (auto a = dyn_cast(attr)) { + out = (int32_t)a.getInt(); + return true; + } + return false; + }; + auto readSLayoutI32 = [](Attribute attr, int32_t &out) -> bool { + if (auto a = dyn_cast(attr)) { + out = (int32_t)a.getValue(); + return true; + } + if (auto a = dyn_cast(attr)) { + out = (int32_t)a.getInt(); + return true; + } + return false; + }; + bl = 0; + sl = 0; + int32_t fr = 512; + (void)readBLayoutI32(cfg.getBLayout(), bl); + (void)readSLayoutI32(cfg.getSLayout(), sl); + if (auto attr = dyn_cast(cfg.getSFractalSize())) fr = (int32_t)attr.getInt(); + + boxed = (sl != 0); + if (!boxed) { + innerRows = 1; + innerCols = 1; + return success(); + } + + int64_t elemBytes = static_cast(getElemByteSize(elemTy)); + if (elemBytes <= 0) return failure(); + + if (fr == 1024) { + innerRows = 16; + innerCols = 16; + return success(); + } + if (fr == 32) { + innerRows = 16; + innerCols = 2; + return success(); + } + if (fr == 512) { + if (sl == 1) { + innerRows = 16; + innerCols = 32 / elemBytes; + return success(); + } + if (sl == 2) { + innerRows = 32 / elemBytes; + innerCols = 16; + return success(); + } + } + return failure(); +} + +mlir::LogicalResult mlir::pto::SubViewOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + auto srcTy = llvm::dyn_cast(getSource().getType()); + auto dstTy = llvm::dyn_cast(getResult().getType()); + if (!srcTy || !dstTy) + return emitOpError("expects tile_buf src and tile_buf result"); + if (srcTy.getRank() != 2 || dstTy.getRank() != 2) + return emitOpError("expects rank-2 tilebuf for src/dst"); + + auto sizesAttr = getSizes(); + if (!sizesAttr || sizesAttr.size() != 2) + return emitOpError("subview expects 2D sizes"); + int64_t sizeR = 
cast(sizesAttr[0]).getInt(); + int64_t sizeC = cast(sizesAttr[1]).getInt(); + if (sizeR <= 0 || sizeC <= 0) + return emitOpError("subview sizes must be positive"); + if (getOffsets().size() != 2) + return emitOpError("subview expects 2D offsets"); + + int64_t offR = 0, offC = 0; + bool offRConst = getConstIndex(getOffsets()[0], offR); + bool offCConst = getConstIndex(getOffsets()[1], offC); + if (offRConst && offR < 0) + return emitOpError("subview offsets must be non-negative"); + if (offCConst && offC < 0) + return emitOpError("subview offsets must be non-negative"); + + bool hasValidRow = static_cast(getValidRow()); + bool hasValidCol = static_cast(getValidCol()); + if (hasValidRow != hasValidCol) + return emitOpError( + "subview expects valid_row and valid_col to be both present or both absent"); + + if (hasValidRow) { + int64_t vRow = 0, vCol = 0; + if (getConstIndex(getValidRow(), vRow)) { + if (vRow <= 0) + return emitOpError("valid_row must be positive when constant"); + if (vRow > sizeR) + return emitOpError("valid_row must be <= subview row size"); + } + if (getConstIndex(getValidCol(), vCol)) { + if (vCol <= 0) + return emitOpError("valid_col must be positive when constant"); + if (vCol > sizeC) + return emitOpError("valid_col must be <= subview col size"); + } + } + + auto dstShape = dstTy.getShape(); + if (dstShape.size() != 2) + return emitOpError("expects result to be rank-2"); + auto srcShape = srcTy.getShape(); + if (srcShape.size() != 2) + return emitOpError("expects source to be rank-2"); + if (dstShape[0] != sizeR || dstShape[1] != sizeC) + return emitOpError("expects result shape to match subview sizes"); + + if (dstTy.getElementType() != srcTy.getElementType()) + return emitOpError("expects result element type to match source"); + if (dstTy.getMemorySpace() != srcTy.getMemorySpace()) + return emitOpError("expects result address space to match source"); + auto srcCfg = srcTy.getConfigAttr(); + if (!srcCfg) srcCfg = 
TileBufConfigAttr::getDefault(getContext()); + auto dstCfg = dstTy.getConfigAttr(); + if (!dstCfg) dstCfg = TileBufConfigAttr::getDefault(getContext()); + if (dstCfg != srcCfg) + return emitOpError("expects result tile config to match source"); + + // Design choice: when valid[...] is omitted, infer result valid_shape from + // subview sizes directly. We intentionally do not constrain it by source + // valid_shape to allow user-controlled subview semantics. + + auto expectedValidDim = [&](Value explicitValid, int64_t defaultSize) { + if (!explicitValid) + return defaultSize; + int64_t c = 0; + if (getConstIndex(explicitValid, c)) + return std::min(c, defaultSize); + return ShapedType::kDynamic; + }; + int64_t expectedVRow = expectedValidDim(getValidRow(), sizeR); + int64_t expectedVCol = expectedValidDim(getValidCol(), sizeC); + auto dstValid = dstTy.getValidShape(); + if (dstValid.size() != 2) + return emitOpError("expects result to have rank-2 valid_shape"); + if (dstValid[0] != expectedVRow) + return emitOpError("expects result valid_shape[0] to match inferred/explicit valid_row"); + if (dstValid[1] != expectedVCol) + return emitOpError("expects result valid_shape[1] to match inferred/explicit valid_col"); + + auto cfg = srcTy.getConfigAttr(); + if (!cfg) cfg = TileBufConfigAttr::getDefault(getContext()); + + int64_t innerRows = 1, innerCols = 1; + bool boxed = false; + int32_t bl = 0, sl = 0; + if (failed(computeInnerShape(cfg, srcTy.getElementType(), innerRows, innerCols, + boxed, bl, sl))) + return emitOpError("unsupported tile layout for subview"); + + if (!boxed) + return success(); + + // Boxed layout: require static 2D sizes with inner alignment. Offsets may be + // dynamic, but static offsets must be aligned. 
+ if (sizeR % innerRows != 0 || sizeC % innerCols != 0) + return emitOpError("boxed layout subview sizes must be multiples of inner shape"); + + if (offRConst) { + if (offR % innerRows != 0) + return emitOpError("boxed layout subview offsets must be multiples of inner shape"); + } + if (offCConst) { + if (offC % innerCols != 0) + return emitOpError("boxed layout subview offsets must be multiples of inner shape"); + } + + (void)bl; + if (srcShape.size() != 2 || + srcShape[0] == ShapedType::kDynamic || + srcShape[1] == ShapedType::kDynamic) { + return emitOpError("boxed layout subview requires static source shape"); + } + + return success(); +} + +} // namespace pto +} // namespace mlir + +using namespace mlir; +using namespace mlir::pto; + +// ============================================================================= +// Helper Functions +// ============================================================================= + +[[maybe_unused]] static AddressSpace getAddressSpace(Value val) { + auto type = llvm::dyn_cast(val.getType()); + if (!type) return AddressSpace::Zero; // Default + + // 假设你的 AddressSpaceAttr 存储在 MemRef 的 memorySpace 中 + // 需要根据你的 getPTOAddressSpaceAttr 实现来调整 + auto attr = llvm::dyn_cast_or_null(type.getMemorySpace()); + if (attr) return attr.getAddressSpace(); + return AddressSpace::Zero; +} + +// ============================================================================= +// Side Effects Implementation +// ============================================================================= + +// [Fix] 辅助函数:重载以支持 OpOperand* 和 OpResult,避免直接传 Value + +// 针对操作数 (Operand) 的重载 +static void addEffect( + SmallVectorImpl> &effects, + OpOperand *operand, MemoryEffects::Effect *effect) { + if (operand) + effects.emplace_back(effect, operand, SideEffects::DefaultResource::get()); +} + +// 针对结果 (Result) 的重载 +static void addEffect( + SmallVectorImpl> &effects, + OpResult result, MemoryEffects::Effect *effect) { + if (result) + effects.emplace_back(effect, result, 
SideEffects::DefaultResource::get()); +} + +// === TLoadOp === +// Read: src, Write: dst +// 针对 OpOperand* 的重载 +void TLoadOp::getEffects(SmallVectorImpl> &effects) { + // [Fix] 单个操作数,直接取地址 + addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); +} + +void TPrefetchOp::getEffects( + SmallVectorImpl> &effects) { + addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); +} + +// === TAbsOp === +// Read: src, Write: dst +void TAbsOp::getEffects( + SmallVectorImpl> &effects) { + addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); +} + +// === TStoreOp === +// Read: src, Write: dst (GM) +void TStoreOp::getEffects(SmallVectorImpl> &effects) { + addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); + auto preQuantRange = getPreQuantScalarMutable(); + if (!preQuantRange.empty()) + addEffect(effects, &*preQuantRange.begin(), MemoryEffects::Read::get()); + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); +} + +// === TMovOp === +// Read: src, Write: dst +void TMovOp::getEffects(SmallVectorImpl> &effects) { + addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); + auto fpRange = getFpMutable(); + if (!fpRange.empty()) + addEffect(effects, &*fpRange.begin(), MemoryEffects::Read::get()); + auto preQuantRange = getPreQuantScalarMutable(); + if (!preQuantRange.empty()) + addEffect(effects, &*preQuantRange.begin(), MemoryEffects::Read::get()); + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); +} + +#define PTO_ADD_READ(operand) addEffect(effects, &(operand), MemoryEffects::Read::get()) +#define PTO_ADD_WRITE(operand) addEffect(effects, &(operand), MemoryEffects::Write::get()) + +#define PTO_DEFINE_UNARY_EFFECTS(OpClass, srcOperand, dstOperand) \ + void OpClass::getEffects( \ + 
SmallVectorImpl> &effects) { \ + PTO_ADD_READ(srcOperand); \ + PTO_ADD_WRITE(dstOperand); \ + } + +#define PTO_DEFINE_BINARY_EFFECTS(OpClass, lhsOperand, rhsOperand, dstOperand) \ + void OpClass::getEffects( \ + SmallVectorImpl> &effects) { \ + PTO_ADD_READ(lhsOperand); \ + PTO_ADD_READ(rhsOperand); \ + PTO_ADD_WRITE(dstOperand); \ + } + +#define PTO_DEFINE_TERNARY_EFFECTS(OpClass, op0, op1, op2, dstOperand) \ + void OpClass::getEffects( \ + SmallVectorImpl> &effects) { \ + PTO_ADD_READ(op0); \ + PTO_ADD_READ(op1); \ + PTO_ADD_READ(op2); \ + PTO_ADD_WRITE(dstOperand); \ + } + +#define PTO_DEFINE_QUATERNARY_EFFECTS(OpClass, op0, op1, op2, op3, dstOperand) \ + void OpClass::getEffects( \ + SmallVectorImpl> &effects) { \ + PTO_ADD_READ(op0); \ + PTO_ADD_READ(op1); \ + PTO_ADD_READ(op2); \ + PTO_ADD_READ(op3); \ + PTO_ADD_WRITE(dstOperand); \ + } + +void LoadScalarOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getPtrMutable()); +} + +void StoreScalarOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_WRITE(getPtrMutable()); +} + +// === Tile/Device ops added for InsertSync === + +// MGATHER: Read(mem, idx) -> Write(dst) +void MGatherOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getMemMutable()); + PTO_ADD_READ(getIdxMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +// MSCATTER: Read(src, idx) -> Write(mem) +void MScatterOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_READ(getIdxMutable()); + PTO_ADD_WRITE(getMemMutable()); +} + +// TGETVAL: Read(src) -> scalar result +void TGetValOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); +} + +void THistogramOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_READ(getIdxMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +void TGetScaleAddrOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +// TSETVAL: Write(dst) 
(single element update) +void TSetValOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_WRITE(getDstMutable()); +} + +// SET_VALIDSHAPE: update runtime valid row/col metadata on source tile in-place. +void SetValidShapeOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_WRITE(getSourceMutable()); +} + +// GET_VALIDSHAPE: read runtime valid row/col metadata from source tile. +void GetValidShapeOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSourceMutable()); +} + +// Elementwise + reductions: mostly PIPE_V tilebuf ops +PTO_DEFINE_BINARY_EFFECTS(TAddOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_TERNARY_EFFECTS(TAddCOp, getSrc0Mutable(), getSrc1Mutable(), getSrc2Mutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TAddSOp, getSrcMutable(), getDstMutable()) +PTO_DEFINE_BINARY_EFFECTS(TAddSCOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +void TAxpyOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_READ(getScalarMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +PTO_DEFINE_BINARY_EFFECTS(TAndOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_BINARY_EFFECTS(TConcatOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_QUATERNARY_EFFECTS(TConcatidxOp, getSrc0Mutable(), getSrc1Mutable(), getSrc0IdxMutable(), getSrc1IdxMutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TAndSOp, getSrcMutable(), getDstMutable()) + +// TCI: Write(dst) (generates sequence) +void TCIOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_WRITE(getDstMutable()); +} + +// TTRI: Write(dst) (generates triangular mask) +void TTriOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_WRITE(getDstMutable()); +} + +PTO_DEFINE_BINARY_EFFECTS(TCmpOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TCmpSOp, getSrcMutable(), getDstMutable()) + +PTO_DEFINE_UNARY_EFFECTS(TColExpandOp, getSrcMutable(), getDstMutable()) 
+PTO_DEFINE_BINARY_EFFECTS(TColExpandAddOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_BINARY_EFFECTS(TColExpandMulOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_BINARY_EFFECTS(TColExpandDivOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_BINARY_EFFECTS(TColExpandSubOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_BINARY_EFFECTS(TColExpandExpdifOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_BINARY_EFFECTS(TColExpandMaxOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_BINARY_EFFECTS(TColExpandMinOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TColMaxOp, getSrcMutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TColMinOp, getSrcMutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TColProdOp, getSrcMutable(), getDstMutable()) + +void TColArgMaxOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_WRITE(getTmpMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +void TColArgMinOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_WRITE(getTmpMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +void TColSumOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + auto tmp = getTmpMutable(); + if (!tmp.empty()) { + PTO_ADD_WRITE(tmp[0]); + } + PTO_ADD_WRITE(getDstMutable()); +} + +void TCvtOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_WRITE(getDstMutable()); +} +void TRandomOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_WRITE(getDstMutable()); +} +PTO_DEFINE_BINARY_EFFECTS(TDivOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) + +// TDIVS has custom assembly format; conservatively treat first 2 operands as reads. 
+void TDivSOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_READ(getScalarMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +PTO_DEFINE_UNARY_EFFECTS(TExpOp, getSrcMutable(), getDstMutable()) + +// TEXPANDS: Write(dst) (broadcast scalar) +void TExpandsOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_WRITE(getDstMutable()); +} + +// TEXTRACT: Read(src) -> Write(dst) +void TExtractOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +// TINSERT: Read(src) -> Write(dst) +void TInsertOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +// TEXTRACT_FP: Read(src), Read(fp) -> Write(dst) +void TExtractFPOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_READ(getFpMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +// TINSERT_FP: Read(src), Read(fp) -> Write(dst) +void TInsertFPOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_READ(getFpMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +PTO_DEFINE_UNARY_EFFECTS(TFillPadOp, getSrcMutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TFillPadExpandOp, getSrcMutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TFillPadInplaceOp, getSrcMutable(), getDstMutable()) + +void TGatherOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + if (auto cdst = getCdstMutable(); !cdst.empty()) + PTO_ADD_WRITE(cdst[0]); + if (auto indices = getIndicesMutable(); !indices.empty()) + PTO_ADD_READ(indices[0]); + if (auto tmp = getTmpMutable(); !tmp.empty()) + PTO_ADD_READ(tmp[0]); + PTO_ADD_WRITE(getDstMutable()); +} + +PTO_DEFINE_BINARY_EFFECTS(TGatherBOp, getSrcMutable(), getOffsetsMutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TLogOp, getSrcMutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TLReluOp, getSrcMutable(), getDstMutable()) + 
+PTO_DEFINE_BINARY_EFFECTS(TMaxOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TMaxSOp, getSrcMutable(), getDstMutable()) +PTO_DEFINE_BINARY_EFFECTS(TMinOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TMinSOp, getSrcMutable(), getDstMutable()) + +PTO_DEFINE_BINARY_EFFECTS(TMovFPOp, getSrcMutable(), getFpMutable(), getDstMutable()) + +void TMrgSortOp::getEffects( + SmallVectorImpl> &effects) { /* TMRGSORT: Read(every srcs operand) -> Write(every dsts operand), plus Write(tmp) and Write(excuted) when those optional operands are present */ + for (auto &opnd : getSrcsMutable()) { + PTO_ADD_READ(opnd); + } + auto tmp = getTmpMutable(); + if (!tmp.empty()) + PTO_ADD_WRITE(tmp[0]); + for (auto &opnd : getDstsMutable()) { + PTO_ADD_WRITE(opnd); + } + auto executed = getExcutedMutable(); + if (!executed.empty()) { + PTO_ADD_WRITE(executed[0]); + } +} + +PTO_DEFINE_BINARY_EFFECTS(TMulOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TMulSOp, getSrc0Mutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TNegOp, getSrcMutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TNotOp, getSrcMutable(), getDstMutable()) +PTO_DEFINE_BINARY_EFFECTS(TOrOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TOrSOp, getSrcMutable(), getDstMutable()) + +PTO_DEFINE_BINARY_EFFECTS(TPartAddOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_BINARY_EFFECTS(TPartMaxOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_BINARY_EFFECTS(TPartMinOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +void TPartArgMaxOp::getEffects( + SmallVectorImpl> &effects) { /* TPARTARGMAX: Read(src0, src1, src0Idx, src1Idx) -> Write(dst, dstIdx) */ + PTO_ADD_READ(getSrc0Mutable()); + PTO_ADD_READ(getSrc1Mutable()); + PTO_ADD_READ(getSrc0IdxMutable()); + PTO_ADD_READ(getSrc1IdxMutable()); + PTO_ADD_WRITE(getDstMutable()); + PTO_ADD_WRITE(getDstIdxMutable()); +} +void TPartArgMinOp::getEffects( + SmallVectorImpl> &effects) { /* TPARTARGMIN: Read(src0, src1, src0Idx, src1Idx) -> Write(dst, dstIdx) */ + PTO_ADD_READ(getSrc0Mutable()); + PTO_ADD_READ(getSrc1Mutable()); + PTO_ADD_READ(getSrc0IdxMutable()); + PTO_ADD_READ(getSrc1IdxMutable()); +
PTO_ADD_WRITE(getDstMutable()); + PTO_ADD_WRITE(getDstIdxMutable()); +} +PTO_DEFINE_BINARY_EFFECTS(TPartMulOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +// TPRELU: Read(src0, src1) -> Write(dst), plus Write(tmp) on every target except A5 +void TPReluOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrc0Mutable()); + PTO_ADD_READ(getSrc1Mutable()); + // A5 pto-isa TPRELU implementation does not consume tmp; modeling tmp as a + // write-only scratch on A5 incorrectly inflates local-memory planning and + // can trigger false vec-overflow diagnostics. + if (getTargetArch(getOperation()) != PTOArch::A5) + PTO_ADD_WRITE(getTmpMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +void TQuantOp::getEffects( + SmallVectorImpl> &effects) { /* TQUANT: Read(src, fp) and Read(offset) when the optional offset is present -> Write(dst) */ + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_READ(getFpMutable()); + auto offsetRange = getOffsetMutable(); + if (!offsetRange.empty()) + PTO_ADD_READ(offsetRange[0]); + PTO_ADD_WRITE(getDstMutable()); +} +PTO_DEFINE_TERNARY_EFFECTS(TDequantOp, getSrcMutable(), getScaleMutable(), + getOffsetMutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TRecipOp, getSrcMutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TReluOp, getSrcMutable(), getDstMutable()) +PTO_DEFINE_BINARY_EFFECTS(TFModOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TFModSOp, getSrcMutable(), getDstMutable()) +void TRemOp::getEffects( + SmallVectorImpl> &effects) { /* TREM: Read(src0, src1) -> Write(tmp, dst) */ + PTO_ADD_READ(getSrc0Mutable()); + PTO_ADD_READ(getSrc1Mutable()); + PTO_ADD_WRITE(getTmpMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +void TRemSOp::getEffects( + SmallVectorImpl> &effects) { /* TREMS: Read(src) -> Write(tmp, dst) */ + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_WRITE(getTmpMutable()); + PTO_ADD_WRITE(getDstMutable()); +} +PTO_DEFINE_UNARY_EFFECTS(TRowExpandOp, getSrcMutable(), getDstMutable()) + +void TRowExpandDivOp::getEffects( + SmallVectorImpl> &effects) { /* TROWEXPANDDIV: Read(src0, src1) -> Write(dst), plus Write(tmp) when the optional tmp operand is present */ + PTO_ADD_READ(getSrc0Mutable()); + PTO_ADD_READ(getSrc1Mutable()); + auto tmp = getTmpMutable(); + if (!tmp.empty()) + PTO_ADD_WRITE(tmp[0]); +
PTO_ADD_WRITE(getDstMutable()); +} + +void TRowExpandMulOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrc0Mutable()); + PTO_ADD_READ(getSrc1Mutable()); + auto tmp = getTmpMutable(); + if (!tmp.empty()) + PTO_ADD_WRITE(tmp[0]); + PTO_ADD_WRITE(getDstMutable()); +} + +void TRowExpandSubOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrc0Mutable()); + PTO_ADD_READ(getSrc1Mutable()); + auto tmp = getTmpMutable(); + if (!tmp.empty()) + PTO_ADD_WRITE(tmp[0]); + PTO_ADD_WRITE(getDstMutable()); +} + +PTO_DEFINE_BINARY_EFFECTS(TRowExpandAddOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) + +void TRowExpandExpdifOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrc0Mutable()); + PTO_ADD_READ(getSrc1Mutable()); + auto tmp = getTmpMutable(); + if (!tmp.empty()) + PTO_ADD_WRITE(tmp[0]); + PTO_ADD_WRITE(getDstMutable()); +} + +void TRowExpandMaxOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrc0Mutable()); + PTO_ADD_READ(getSrc1Mutable()); + auto tmp = getTmpMutable(); + if (!tmp.empty()) + PTO_ADD_WRITE(tmp[0]); + PTO_ADD_WRITE(getDstMutable()); +} + +void TRowExpandMinOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrc0Mutable()); + PTO_ADD_READ(getSrc1Mutable()); + auto tmp = getTmpMutable(); + if (!tmp.empty()) + PTO_ADD_WRITE(tmp[0]); + PTO_ADD_WRITE(getDstMutable()); +} + +// Row reductions use tmp scratch tile. +void TRowMaxOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_WRITE(getTmpMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +void TRowArgMaxOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + // A5 lowering does not consume tmp for TROWARGMAX; modeling tmp as a + // scratch write inflates local-memory planning and can trigger false + // vec-overflow diagnostics, mirroring the fixed A5 TPRELU issue.
+ if (getTargetArch(getOperation()) != PTOArch::A5) + PTO_ADD_WRITE(getTmpMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +void TRowMinOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_WRITE(getTmpMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +void TRowArgMinOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + // A5 lowering does not consume tmp for TROWARGMIN; modeling tmp as a + // scratch write inflates local-memory planning and can trigger false + // vec-overflow diagnostics, mirroring the fixed A5 TPRELU issue. + if (getTargetArch(getOperation()) != PTOArch::A5) + PTO_ADD_WRITE(getTmpMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +void TRowSumOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_WRITE(getTmpMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +void TRowProdOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_WRITE(getTmpMutable()); + PTO_ADD_WRITE(getDstMutable()); +} +void TRsqrtOp::getEffects( + SmallVectorImpl> &effects) { /* TRSQRT: Read(src) -> Write(dst), plus Write(tmp) when the optional tmp operand is present */ + PTO_ADD_READ(getSrcMutable()); + auto tmp = getTmpMutable(); + if (!tmp.empty()) + PTO_ADD_WRITE(tmp[0]); + PTO_ADD_WRITE(getDstMutable()); +} + +void TScatterOp::getEffects( + SmallVectorImpl> &effects) { /* TSCATTER: Read(src) and Read(indexes) when the optional indexes operand is present -> Write(dst) */ + PTO_ADD_READ(getSrcMutable()); + if (getIndexes()) { + auto idx = getIndexesMutable(); + if (!idx.empty()) + PTO_ADD_READ(idx[0]); + } + PTO_ADD_WRITE(getDstMutable()); +} + +// Select: Read(mask, src0, src1) -> Write(tmp, dst) +void TSelOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getMaskMutable()); + PTO_ADD_READ(getSrc0Mutable()); + PTO_ADD_READ(getSrc1Mutable()); + PTO_ADD_WRITE(getTmpMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +// TSELS: Read(mask, src) -> Write(tmp, dst) +void TSelSOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getMaskMutable()); + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_WRITE(getTmpMutable()); +
PTO_ADD_WRITE(getDstMutable()); +} + +PTO_DEFINE_BINARY_EFFECTS(TShlOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_BINARY_EFFECTS(TShrOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TShlSOp, getSrcMutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TShrSOp, getSrcMutable(), getDstMutable()) + +// TSORT32: Read(src, idx) -> Write(dst [, tmp]) +void TSort32Op::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_READ(getIdxMutable()); + auto tmp = getTmpMutable(); + if (!tmp.empty()) + PTO_ADD_WRITE(tmp[0]); + PTO_ADD_WRITE(getDstMutable()); +} + +PTO_DEFINE_UNARY_EFFECTS(TSqrtOp, getSrcMutable(), getDstMutable()) +PTO_DEFINE_BINARY_EFFECTS(TSubOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +PTO_DEFINE_TERNARY_EFFECTS(TSubCOp, getSrc0Mutable(), getSrc1Mutable(), getSrc2Mutable(), getDstMutable()) +PTO_DEFINE_UNARY_EFFECTS(TSubSOp, getSrcMutable(), getDstMutable()) +PTO_DEFINE_BINARY_EFFECTS(TSubSCOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) + +// TXORS: Read(src) -> Write(tmp, dst) +void TXorSOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_WRITE(getTmpMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +// TXOR: Read(src0, src1) -> Write(tmp, dst); tmp is written unconditionally +void TXorOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrc0Mutable()); + PTO_ADD_READ(getSrc1Mutable()); + PTO_ADD_WRITE(getTmpMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +// TTRANS: Read(src) -> Write(tmp, dst) +void TTransOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_WRITE(getTmpMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +void TPrintOp::getEffects( + SmallVectorImpl> &effects) { /* TPRINT: records both Read(src) and Write(src). NOTE(review): presumably the write keeps the print from being treated as a pure read that could be reordered or dropped - confirm. */ + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_WRITE(getSrcMutable()); +} + +#undef PTO_DEFINE_TERNARY_EFFECTS +#undef PTO_DEFINE_BINARY_EFFECTS +#undef PTO_DEFINE_UNARY_EFFECTS +#undef PTO_ADD_WRITE +#undef PTO_ADD_READ
+ +// === TMatmulOp === +// Read: lhs, rhs, Write: dst +void TMatmulOp::getEffects(SmallVectorImpl> &effects) { + // Singleton operands -> take their address directly + addEffect(effects, &getLhsMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getRhsMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); +} + +// === TMatmulAccOp === +// Read: acc_in, lhs, rhs, Write: dst +void TMatmulAccOp::getEffects(SmallVectorImpl> &effects) { + addEffect(effects, &getAccInMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getLhsMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getRhsMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); +} + +// === TMatmulBiasOp === +// Read: a, b, bias, Write: dst +void TMatmulBiasOp::getEffects(SmallVectorImpl> &effects) { + addEffect(effects, &getAMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getBMutable(), MemoryEffects::Read::get()); + // bias here is the mandatory AnyType:$bias operand, so it is a singleton + addEffect(effects, &getBiasMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); +} + +// === TGemvOp === +// Read: lhs, rhs, Write: dst +void TGemvOp::getEffects(SmallVectorImpl> &effects) { + addEffect(effects, &getLhsMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getRhsMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); +} + +// === TGemvAccOp === +// Read: acc_in, lhs, rhs, Write: dst +void TGemvAccOp::getEffects(SmallVectorImpl> &effects) { + addEffect(effects, &getAccInMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getLhsMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getRhsMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); +} + +// === TGemvBiasOp === +// Read: a, b, bias, Write: dst +void
TGemvBiasOp::getEffects(SmallVectorImpl> &effects) { + addEffect(effects, &getAMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getBMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getBiasMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); +} + +// === TGemvMxOp === +// Read: a, a_scale, b, b_scale, Write: dst +void TGemvMxOp::getEffects(SmallVectorImpl> &effects) { + addEffect(effects, &getAMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getAScaleMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getBMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getBScaleMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); +} + +// === TGemvMxAccOp === +// Read: c_in, a, a_scale, b, b_scale, Write: dst +void TGemvMxAccOp::getEffects(SmallVectorImpl> &effects) { + addEffect(effects, &getCInMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getAMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getAScaleMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getBMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getBScaleMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); +} + +// === TGemvMxBiasOp === +// Read: a, a_scale, b, b_scale, bias, Write: dst +void TGemvMxBiasOp::getEffects(SmallVectorImpl> &effects) { + addEffect(effects, &getAMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getAScaleMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getBMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getBScaleMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getBiasMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); +} + +// === TMatmulMxOp === Read: a, a_scale, b, b_scale, Write: dst +void TMatmulMxOp::getEffects(SmallVectorImpl> &effects) { +
addEffect(effects, &getAMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getAScaleMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getBMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getBScaleMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); +} + +// === TMatmulMxAccOp === +// Read: c_in, a, a_scale, b, b_scale, Write: dst +void TMatmulMxAccOp::getEffects(SmallVectorImpl> &effects) { + addEffect(effects, &getCInMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getAMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getAScaleMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getBMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getBScaleMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); +} + +// === TMatmulMxBiasOp === +// Read: a, a_scale, b, b_scale, bias, Write: dst +void TMatmulMxBiasOp::getEffects(SmallVectorImpl> &effects) { + addEffect(effects, &getAMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getAScaleMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getBMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getBScaleMutable(), MemoryEffects::Read::get()); + // bias here is the mandatory AnyType:$bias operand, so it is a singleton + addEffect(effects, &getBiasMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); +} + +static bool isInsideSectionCube(Operation *op) { + return op->getParentOfType() != nullptr; +} + +static bool isInsideSectionVector(Operation *op) { + return op->getParentOfType() != nullptr; +} + +static std::optional +getEnclosingFunctionKernelKind(Operation *op) { /* Returns the kernel kind stored under FunctionKernelKindAttr::name on the enclosing funcOp, or std::nullopt when there is no enclosing funcOp or it lacks that attribute. */ + auto funcOp = op->getParentOfType(); + if (!funcOp) + return std::nullopt; + + auto kernelKindAttr = + funcOp->getAttrOfType( + FunctionKernelKindAttr::name); + if (!kernelKindAttr) + return std::nullopt; + + return kernelKindAttr.getKernelKind(); +}
+ +static bool isInsideSectionOrAttributedKernel(Operation *op) { /* True when op sits inside a cube or vector section, or its enclosing function carries a kernel-kind attribute. */ + return isInsideSectionCube(op) || isInsideSectionVector(op) || + getEnclosingFunctionKernelKind(op).has_value(); +} + +static LogicalResult verifySplitAttr(Operation *op, int64_t split) { /* Accepts only split values 0, 1 and 2. */ + if (split < 0 || split > 2) + return op->emitOpError("expects 'split' to be 0, 1, or 2"); + return success(); +} + +static LogicalResult verifyFrontendKernelKind(Operation *op, + FunctionKernelKind expected, + StringRef kernelName) { /* Fails unless the enclosing function's kernel kind equals 'expected'; kernelName is used only in the diagnostic text. */ + auto kernelKind = getEnclosingFunctionKernelKind(op); + if (!kernelKind || *kernelKind != expected) { + return op->emitOpError("must be inside a ") + << kernelName << " kernel function"; + } + return success(); +} + +static ParseResult parseFrontendInitializePipeOp(OpAsmParser &parser, + OperationState &result) { /* Custom assembly parser: a braced list of 'keyword = attribute' clauses, then a parenthesised list of 'keyword = operand : type' clauses, then an optional attribute dictionary. Each keyword may appear at most once. */ + NamedAttrList attrs; + bool sawId = false; + bool sawDirMask = false; + bool sawSlotSize = false; + bool sawLocalSlotNum = false; + bool sawNoSplit = false; + + if (parser.parseLBrace()) + return failure(); + + while (failed(parser.parseOptionalRBrace())) { + StringRef keyword; + if (parser.parseKeyword(&keyword) || parser.parseEqual()) + return failure(); + + if (keyword == "id") { + if (sawId) + return parser.emitError(parser.getCurrentLocation(), + "duplicate 'id' clause"); + IntegerAttr idAttr; + if (parser.parseAttribute(idAttr, parser.getBuilder().getI32Type(), "id", + attrs)) + return failure(); + sawId = true; + } else if (keyword == "dir_mask") { + if (sawDirMask) + return parser.emitError(parser.getCurrentLocation(), + "duplicate 'dir_mask' clause"); + IntegerAttr dirMaskAttr; + if (parser.parseAttribute(dirMaskAttr, parser.getBuilder().getI8Type(), + "dir_mask", attrs)) + return failure(); + sawDirMask = true; + } else if (keyword == "slot_size") { + if (sawSlotSize) + return parser.emitError(parser.getCurrentLocation(), + "duplicate 'slot_size' clause"); + IntegerAttr slotSizeAttr; + if (parser.parseAttribute(slotSizeAttr, parser.getBuilder().getI32Type(), +
"slot_size", attrs)) + return failure(); + sawSlotSize = true; + } else if (keyword == "local_slot_num") { + if (sawLocalSlotNum) + return parser.emitError(parser.getCurrentLocation(), + "duplicate 'local_slot_num' clause"); + IntegerAttr localSlotNumAttr; + if (parser.parseAttribute(localSlotNumAttr, parser.getBuilder().getI32Type(), + "local_slot_num", attrs)) + return failure(); + sawLocalSlotNum = true; + } else if (keyword == "nosplit") { + if (sawNoSplit) + return parser.emitError(parser.getCurrentLocation(), + "duplicate 'nosplit' clause"); + BoolAttr noSplitAttr; + if (parser.parseAttribute(noSplitAttr, "nosplit", attrs)) + return failure(); + sawNoSplit = true; + } else { + return parser.emitError(parser.getCurrentLocation()) + << "unexpected keyword '" << keyword << "'"; + } + + if (succeeded(parser.parseOptionalRBrace())) + break; + if (parser.parseComma()) + return failure(); + } + + if (!sawDirMask) /* 'dir_mask' and 'slot_size' are mandatory; a missing 'id' defaults to 0 */ + return parser.emitError(parser.getNameLoc(), "expected 'dir_mask' clause"); + if (!sawSlotSize) + return parser.emitError(parser.getNameLoc(), "expected 'slot_size' clause"); + if (!sawId) + attrs.set("id", parser.getBuilder().getI32IntegerAttr(0)); + + OpAsmParser::UnresolvedOperand gmSlotBuffer; + OpAsmParser::UnresolvedOperand gmSlotTensor; + OpAsmParser::UnresolvedOperand c2vConsumerBuf; + OpAsmParser::UnresolvedOperand v2cConsumerBuf; + Type gmSlotBufferTy; + Type gmSlotTensorTy; + Type c2vConsumerBufTy; + Type v2cConsumerBufTy; + bool hasGmSlotBuffer = false; + bool hasGmSlotTensor = false; + bool hasC2vConsumerBuf = false; + bool hasV2cConsumerBuf = false; + + if (parser.parseLParen()) + return failure(); + while (failed(parser.parseOptionalRParen())) { + StringRef keyword; + if (parser.parseKeyword(&keyword) || parser.parseEqual()) + return failure(); + + if (keyword == "gm_slot_buffer") { + if (hasGmSlotBuffer) + return parser.emitError(parser.getCurrentLocation(), + "duplicate 'gm_slot_buffer' operand"); + if (parser.parseOperand(gmSlotBuffer)
|| + parser.parseColonType(gmSlotBufferTy)) + return failure(); + hasGmSlotBuffer = true; + } else if (keyword == "gm_slot_tensor") { + if (hasGmSlotTensor) + return parser.emitError(parser.getCurrentLocation(), + "duplicate 'gm_slot_tensor' operand"); + if (parser.parseOperand(gmSlotTensor) || + parser.parseColonType(gmSlotTensorTy)) + return failure(); + hasGmSlotTensor = true; + } else if (keyword == "c2v_consumer_buf") { + if (hasC2vConsumerBuf) + return parser.emitError(parser.getCurrentLocation(), + "duplicate 'c2v_consumer_buf' operand"); + if (parser.parseOperand(c2vConsumerBuf) || + parser.parseColonType(c2vConsumerBufTy)) + return failure(); + hasC2vConsumerBuf = true; + } else if (keyword == "v2c_consumer_buf") { + if (hasV2cConsumerBuf) + return parser.emitError(parser.getCurrentLocation(), + "duplicate 'v2c_consumer_buf' operand"); + if (parser.parseOperand(v2cConsumerBuf) || + parser.parseColonType(v2cConsumerBufTy)) + return failure(); + hasV2cConsumerBuf = true; + } else { + return parser.emitError(parser.getCurrentLocation()) + << "unexpected initialize_pipe operand '" << keyword << "'"; + } + + if (succeeded(parser.parseOptionalRParen())) + break; + if (parser.parseComma()) + return failure(); + } + + if (parser.parseOptionalAttrDict(attrs)) + return failure(); + + result.addAttributes(attrs); /* segment sizes and operand resolution below follow the fixed order gm_slot_buffer, gm_slot_tensor, c2v_consumer_buf, v2c_consumer_buf, regardless of the order the clauses were written in */ + result.addAttribute("operandSegmentSizes", + parser.getBuilder().getDenseI32ArrayAttr( + {hasGmSlotBuffer ? 1 : 0, hasGmSlotTensor ? 1 : 0, + hasC2vConsumerBuf ? 1 : 0, + hasV2cConsumerBuf ?
1 : 0})); + if (hasGmSlotBuffer && + parser.resolveOperand(gmSlotBuffer, gmSlotBufferTy, result.operands)) + return failure(); + if (hasGmSlotTensor && + parser.resolveOperand(gmSlotTensor, gmSlotTensorTy, result.operands)) + return failure(); + if (hasC2vConsumerBuf && + parser.resolveOperand(c2vConsumerBuf, c2vConsumerBufTy, result.operands)) + return failure(); + if (hasV2cConsumerBuf && + parser.resolveOperand(v2cConsumerBuf, v2cConsumerBufTy, result.operands)) + return failure(); + return success(); +} + +template +static void printFrontendInitializePipeOp(InitOpT op, OpAsmPrinter &p) { /* Prints the clause syntax accepted by parseFrontendInitializePipeOp: 'id' is omitted when it is 0, optional attributes and operands are omitted when absent, and the remaining attributes follow as a dictionary. */ + p << " {"; + bool needsComma = false; + auto printClause = [&](StringRef keyword, auto value) { + if (needsComma) + p << ", "; + p << keyword << " = " << value; + needsComma = true; + }; + + if (op.getId() != 0) + printClause("id", op.getId()); + printClause("dir_mask", static_cast(op.getDirMask())); + printClause("slot_size", op.getSlotSize()); + if (auto localSlotNumAttr = op.getLocalSlotNumAttr()) + printClause("local_slot_num", localSlotNumAttr.getInt()); + if (auto noSplitAttr = op.getNosplitAttr()) + printClause("nosplit", noSplitAttr.getValue() ?
"true" : "false"); + p << "}"; + + p << "("; + bool needsOperandComma = false; + auto printOperandClause = [&](StringRef keyword, Value value) { + if (needsOperandComma) + p << ", "; + p << keyword << " = " << value << " : " << value.getType(); + needsOperandComma = true; + }; + if (op.getGmSlotBuffer()) { + printOperandClause("gm_slot_buffer", op.getGmSlotBuffer()); + } + if (op.getGmSlotTensor()) + printOperandClause("gm_slot_tensor", op.getGmSlotTensor()); + if (op.getC2vConsumerBuf()) + printOperandClause("c2v_consumer_buf", op.getC2vConsumerBuf()); + if (op.getV2cConsumerBuf()) + printOperandClause("v2c_consumer_buf", op.getV2cConsumerBuf()); + p << ")"; + p.printOptionalAttrDict( + op->getAttrs(), + /*elidedAttrs=*/{"id", "dir_mask", "slot_size", "local_slot_num", + "nosplit", "operandSegmentSizes"}); +} + +static std::optional +getStaticElementCount(ArrayRef shape) { /* Product of all dimensions, or std::nullopt if any dimension is dynamic or negative; an empty shape yields 1. NOTE(review): the multiplication is not checked for overflow. */ + uint64_t count = 1; + for (int64_t dim : shape) { + if (dim == ShapedType::kDynamic || dim < 0) + return std::nullopt; + count *= static_cast(dim); + } + return count; +} + +static bool isSameOrHalfSlotByteSize(uint64_t tensorBytes, uint64_t slotBytes) { /* True when slotBytes equals tensorBytes or exactly twice tensorBytes. */ + return tensorBytes == slotBytes || tensorBytes * 2 == slotBytes; +} + +static LogicalResult verifyFrontendGlobalSlotTensor(Operation *op, Value tensor, + int8_t dirMask, + int32_t slotSize) { /* Requires a tensor-view typed value with a non-empty shape; when the shape is fully static and the element byte size is non-zero, slot_size must be one or two times the tensor byte size. dirMask is currently unused. */ + (void)dirMask; + auto tvTy = dyn_cast(tensor.getType()); + if (!tvTy) + return op->emitOpError("expects 'gm_slot_tensor' to be !pto.tensor_view"); + + ArrayRef shape = tvTy.getShape(); + if (shape.empty()) + return op->emitOpError( + "expects 'gm_slot_tensor' to describe one slot entry tensor"); + + if (auto elemCount = getStaticElementCount(shape)) { + uint64_t elemBytes = getElemByteSize(tvTy.getElementType()); + if (elemBytes != 0) { + uint64_t tensorBytes = *elemCount * elemBytes; + if (!isSameOrHalfSlotByteSize(tensorBytes, + static_cast(slotSize))) { + return op->emitOpError() + << "expects 'slot_size' to equal gm_slot_tensor byte size " + "or
twice gm_slot_tensor byte size for split GlobalTensor " + "entries (got slot_size = " + << slotSize << ", gm_slot_tensor byte size = " << tensorBytes + << ")"; + } + } + } + + return success(); +} + +template +static LogicalResult verifyFrontendInitCommon(InitOpT op, + FunctionKernelKind expected, + StringRef kernelName) { /* Shared verifier for frontend initialize_pipe ops: kernel kind, non-negative id that is unique in the function, dir_mask in {1, 2, 3}, positive slot_size, then either the globaltensor form (gm_slot_tensor only, rejected on A5) or the local form (both consumer buffers, optional bounded local_slot_num). */ + if (failed(verifyFrontendKernelKind(op.getOperation(), expected, kernelName))) + return failure(); + + auto funcOp = op->template getParentOfType(); + if (!funcOp) + return op.emitOpError("must be nested under a func.func"); + + if (op.getId() < 0) + return op.emitOpError("expects 'id' to be non-negative"); + + unsigned sameIdInitCount = 0; /* counts every init op with this id in the function, including op itself, hence the '> 1' test below */ + funcOp.walk([&](Operation *candidate) { + if (auto aic = dyn_cast(candidate)) { + if (aic.getId() == op.getId()) + ++sameIdInitCount; + return; + } + if (auto aiv = dyn_cast(candidate)) + if (aiv.getId() == op.getId()) + ++sameIdInitCount; + }); + if (sameIdInitCount > 1) { + return op.emitOpError( + "requires 'id' to be unique across frontend initialize_pipe ops in the function"); + } + + int8_t dirMask = op.getDirMask(); + if (dirMask != 1 && dirMask != 2 && dirMask != 3) + return op.emitOpError("expects 'dir_mask' to be 1, 2, or 3"); + if (op.getSlotSize() <= 0) + return op.emitOpError("expects 'slot_size' to be greater than 0"); + + bool hasGlobalSlotTensor = static_cast(op.getGmSlotTensor()); + bool hasC2vConsumerBuf = static_cast(op.getC2vConsumerBuf()); + bool hasV2cConsumerBuf = static_cast(op.getV2cConsumerBuf()); + if (hasGlobalSlotTensor) { + if (op.getGmSlotBuffer() || hasC2vConsumerBuf || hasV2cConsumerBuf) { + return op.emitOpError( + "globaltensor pipe init expects only 'gm_slot_tensor' and no " + "'gm_slot_buffer', 'c2v_consumer_buf', or 'v2c_consumer_buf'"); + } + if (op.getLocalSlotNumAttr()) + return op.emitOpError( + "globaltensor pipe init does not use 'local_slot_num'"); + if (getTargetArch(op.getOperation()) == PTOArch::A5) { + return op.emitOpError( + "globaltensor pipe entries are
supported for a2/a3 l2g2l pipes"); + } + return verifyFrontendGlobalSlotTensor( + op.getOperation(), op.getGmSlotTensor(), dirMask, op.getSlotSize()); + } + + if (hasC2vConsumerBuf != hasV2cConsumerBuf) { + return op.emitOpError( + "expects 'c2v_consumer_buf' and 'v2c_consumer_buf' to be provided together"); + } + if (!hasC2vConsumerBuf) { + return op.emitOpError( + "expects local pipe init to provide 'c2v_consumer_buf' and " + "'v2c_consumer_buf'; use 'gm_slot_tensor' for globaltensor pipe entries"); + } + + if (auto localSlotNumAttr = op.getLocalSlotNumAttr()) { + int32_t localSlotNum = localSlotNumAttr.getInt(); + if (localSlotNum <= 0) + return op.emitOpError("expects 'local_slot_num' to be greater than 0"); + int32_t loweredSlotNum = dirMask == 3 ? 4 : 8; /* upper bound for local_slot_num: 4 when dir_mask is 3, otherwise 8 */ + if (localSlotNum > loweredSlotNum) { + return op.emitOpError() + << "expects 'local_slot_num' to be less than or equal to " + << loweredSlotNum << " for dir_mask = " << static_cast(dirMask); + } + } + + return success(); +} + +ParseResult AicInitializePipeOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseFrontendInitializePipeOp(parser, result); +} + +void AicInitializePipeOp::print(OpAsmPrinter &p) { + printFrontendInitializePipeOp(*this, p); +} + +ParseResult AivInitializePipeOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseFrontendInitializePipeOp(parser, result); +} + +void AivInitializePipeOp::print(OpAsmPrinter &p) { + printFrontendInitializePipeOp(*this, p); +} + +static ReserveBufferOp findReserveBufferByName(func::FuncOp funcOp, + StringRef name) { /* Returns the first reserve_buffer op found in funcOp whose name matches, or a null op handle when none matches. */ + ReserveBufferOp found; + funcOp.walk([&](ReserveBufferOp reserveOp) { + if (reserveOp.getName() != name) + return WalkResult::advance(); + found = reserveOp; + return WalkResult::interrupt(); + }); + return found; +} + +LogicalResult ReserveBufferOp::verify() { /* Requires a func.func ancestor, size greater than 0, a location in one of the two accepted address spaces (VEC or MAT), 'base' whenever 'auto' is false, a non-negative 'base' when present, and a name that is unique within the function. */ + auto funcOp = getOperation()->getParentOfType(); + if (!funcOp) + return emitOpError("must be nested under a func.func"); + + if (getSize() <= 0)
+ return emitOpError("expects 'size' to be greater than 0"); + + auto location = getLocation().getAddressSpace(); + if (location != AddressSpace::VEC && location != AddressSpace::MAT) + return emitOpError("expects 'location' to be #pto.address_space or #pto.address_space"); + + if (!getAutoAlloc() && !getBaseAttr()) + return emitOpError("expects 'base' when 'auto' is false"); + + if (auto baseAttr = getBaseAttr(); baseAttr && baseAttr.getInt() < 0) + return emitOpError("expects 'base' to be non-negative when present"); + + unsigned sameNameCount = 0; + funcOp.walk([&](ReserveBufferOp reserveOp) { + if (reserveOp.getName() == getName()) + ++sameNameCount; + }); + if (sameNameCount > 1) + return emitOpError("requires 'name' to be unique within the function"); + + return success(); +} + +LogicalResult ImportReservedBufferOp::verify() { /* Requires a func.func ancestor, a peer_func symbol that resolves, a (name, peer_func) pair that is unique within the function, and a reserve_buffer with the same name inside the peer function. */ + auto funcOp = getOperation()->getParentOfType(); + if (!funcOp) + return emitOpError("must be nested under a func.func"); + + auto peerFunc = SymbolTable::lookupNearestSymbolFrom( + getOperation(), getPeerFuncAttr()); + if (!peerFunc) + return emitOpError("expects 'peer_func' to reference an existing func.func"); + + unsigned sameImportCount = 0; + funcOp.walk([&](ImportReservedBufferOp importOp) { + if (importOp.getName() == getName() && + importOp.getPeerFuncAttr() == getPeerFuncAttr()) { + ++sameImportCount; + } + }); + if (sameImportCount > 1) { + return emitOpError( + "requires (name, peer_func) to be unique within the function"); + } + + if (!findReserveBufferByName(peerFunc, getName())) + return emitOpError("expects matching peer reserve_buffer to exist"); + + return success(); +} + +static FailureOr lookupFrontendInitOpById(Operation *op, + func::FuncOp funcOp, + int32_t id) { /* Finds the single frontend initialize_pipe op in funcOp whose id equals 'id'; emits an error on 'op' and fails when there is no match or more than one. */ + Operation *matchedInit = nullptr; + unsigned matchedInitCount = 0; + funcOp.walk([&](Operation *candidate) { + if (auto aic = dyn_cast(candidate)) { + if (aic.getId() == static_cast(id)) { + matchedInit = candidate; + ++matchedInitCount; + } + return
WalkResult::advance(); + } + if (auto aiv = dyn_cast(candidate)) { + if (aiv.getId() == static_cast(id)) { + matchedInit = candidate; + ++matchedInitCount; + } + return WalkResult::advance(); + } + return WalkResult::advance(); + }); + + if (matchedInitCount == 0) { + op->emitOpError() << "expects 'id' = " << id + << " to match a frontend initialize_pipe op in the same function"; + return failure(); + } + if (matchedInitCount > 1) { + op->emitOpError() << "expects 'id' = " << id + << " to match exactly one frontend initialize_pipe op in the same function"; + return failure(); + } + return matchedInit; +} + +static LogicalResult verifyFrontendSplitOp(Operation *op, + FunctionKernelKind expected, + StringRef kernelName, + int32_t id, + int64_t split) { /* Checks the enclosing kernel kind, that id is non-negative, and that split is 0, 1 or 2. */ + if (failed(verifyFrontendKernelKind(op, expected, kernelName))) + return failure(); + if (id < 0) + return op->emitOpError("expects 'id' to be non-negative"); + return verifySplitAttr(op, split); +} + +static FailureOr lookupFrontendInitDirMaskById(Operation *op, + func::FuncOp funcOp, + int32_t id) { /* dir_mask of the initialize_pipe op selected by lookupFrontendInitOpById; propagates its failure. */ + auto initOr = lookupFrontendInitOpById(op, funcOp, id); + if (failed(initOr)) + return failure(); + if (auto aic = dyn_cast(*initOr)) + return aic.getDirMask(); + return cast(*initOr).getDirMask(); +} + +static LogicalResult verifyFrontendDataOpDirection(Operation *op, int32_t id, + bool expectC2V) { /* When expectC2V is true the referenced init op must have dir_mask 1 or 3; otherwise it must have dir_mask 2 or 3. */ + auto funcOp = op->getParentOfType(); + if (!funcOp) + return op->emitOpError("must be nested under a func.func"); + + auto dirMaskOr = lookupFrontendInitDirMaskById(op, funcOp, id); + if (failed(dirMaskOr)) + return failure(); + + int8_t dirMask = *dirMaskOr; + if (expectC2V && dirMask != 1 && dirMask != 3) { + return op->emitOpError() + << "expects 'id' = " << id + << " to reference initialize_pipe with dir_mask = 1 or 3"; + } + if (!expectC2V && dirMask != 2 && dirMask != 3) { + return op->emitOpError() + << "expects 'id' = " << id + << " to reference initialize_pipe with dir_mask = 2 or 3"; + } + return success(); +}
+ +static Value getFrontendInitGmSlotTensor(Operation *initOp) { + if (auto aic = dyn_cast(initOp)) + return aic.getGmSlotTensor(); + return cast(initOp).getGmSlotTensor(); +} + +static LogicalResult verifyFrontendTensorEntryMatchesInit(Operation *op, + int32_t id, + Type entryTy) { + auto entryViewTy = dyn_cast(entryTy); + if (!entryViewTy) + return success(); + + auto funcOp = op->getParentOfType(); + if (!funcOp) + return op->emitOpError("must be nested under a func.func"); + + auto initOr = lookupFrontendInitOpById(op, funcOp, id); + if (failed(initOr)) + return failure(); + Value gmSlotTensor = getFrontendInitGmSlotTensor(*initOr); + if (!gmSlotTensor) { + return op->emitOpError() + << "expects 'id' = " << id + << " to reference initialize_pipe with 'gm_slot_tensor' when the " + "pipe entry is !pto.tensor_view"; + } + + auto slotTensorTy = dyn_cast(gmSlotTensor.getType()); + if (!slotTensorTy) + return op->emitOpError("expects 'gm_slot_tensor' to be !pto.tensor_view"); + if (slotTensorTy.getElementType() != entryViewTy.getElementType()) { + return op->emitOpError() + << "expects pipe entry element type to match gm_slot_tensor element type"; + } + if (slotTensorTy.getRank() != entryViewTy.getRank()) { + return op->emitOpError() + << "expects pipe entry rank to match gm_slot_tensor rank"; + } + + ArrayRef slotShape = slotTensorTy.getShape(); + ArrayRef entryShape = entryViewTy.getShape(); + for (auto [idx, entryDim] : llvm::enumerate(entryShape)) { + int64_t slotDim = slotShape[idx]; + if (slotDim == ShapedType::kDynamic || + entryDim == ShapedType::kDynamic || slotDim == entryDim) + continue; + return op->emitOpError() + << "expects pipe entry dimension " << idx + << " to match gm_slot_tensor dimension " << slotDim; + } + return success(); +} + +template +static LogicalResult verifyFrontendPopOp(FrontendPopOpT op, + FunctionKernelKind expected, + StringRef kernelName, + bool expectC2V) { + if (failed(verifyFrontendSplitOp(op.getOperation(), expected, 
kernelName, + op.getId(), + op.getSplit()))) + return failure(); + if (failed(verifyFrontendDataOpDirection(op.getOperation(), op.getId(), + expectC2V))) + return failure(); + if (failed(verifyFrontendTensorEntryMatchesInit(op.getOperation(), op.getId(), + op.getTile().getType()))) + return failure(); + + bool hasValidRow = static_cast(op.getValidRow()); + bool hasValidCol = static_cast(op.getValidCol()); + if (hasValidRow != hasValidCol) + return op.emitOpError( + "expects valid_row and valid_col operands to be provided together"); + if (!hasValidRow) + return success(); + + if (isa(op.getTile().getType())) + return op.emitOpError( + "does not accept valid_row/valid_col when result is !pto.tensor_view"); + + auto tileTy = dyn_cast(op.getTile().getType()); + if (!tileTy) + return op.emitOpError( + "expects tile result to be !pto.tile_buf when valid_row/valid_col operands are provided"); + if (!tileTy.hasDynamicValid()) + return op.emitOpError( + "expects tile result to have dynamic validShape (?, ?) when valid_row/valid_col operands are provided"); + return success(); +} + +static LogicalResult verifyPipeShape(Operation *op, int8_t dirMask, int32_t slotSize, + int32_t slotNum, + std::optional flagBase) { + constexpr int32_t kMaxHardwareFlagIds = 16; + if (dirMask != 1 && dirMask != 2 && dirMask != 3) + return op->emitOpError("expects 'dir_mask' to be 1, 2, or 3"); + if (slotSize <= 0) + return op->emitOpError("expects 'slot_size' to be greater than 0"); + if (slotNum != 4 && slotNum != 8) + return op->emitOpError("expects 'slot_num' to be 4 or 8"); + if (flagBase && *flagBase < 0) + return op->emitOpError("expects 'flag_base' to be non-negative when present"); + if (flagBase) { + int32_t flagWidth = dirMask == 3 ? 
4 : 2; + if (*flagBase + flagWidth > kMaxHardwareFlagIds) { + return op->emitOpError() + << "requires 'flag_base' and dir_mask to fit within " + << kMaxHardwareFlagIds << " hardware flag ids"; + } + } + + return success(); +} + +static LogicalResult verifyPipeHandleProducer(Operation *op, Value pipeHandle) { + if (!isa(pipeHandle.getType())) + return op->emitOpError("expects pipe operand type !pto.pipe"); + if (!pipeHandle.getDefiningOp() && + !pipeHandle.getDefiningOp()) { + return op->emitOpError( + "pipe_handle must be produced by pto.initialize_l2l_pipe or " + "pto.initialize_l2g2l_pipe"); + } + return success(); +} + +static bool getTensorLikeElementAndShape(Type ty, Type &elementType, + ArrayRef &shape) { + if (auto tvTy = dyn_cast(ty)) { + elementType = tvTy.getElementType(); + shape = tvTy.getShape(); + return true; + } + if (auto memrefTy = dyn_cast(ty)) { + elementType = memrefTy.getElementType(); + shape = memrefTy.getShape(); + return true; + } + return false; +} + +static LogicalResult verifyTensorEntryMatchesInternalPipeInit(Operation *op, + Value pipeHandle, + Type entryTy) { + auto entryViewTy = dyn_cast(entryTy); + if (!entryViewTy) + return success(); + + auto initOp = pipeHandle.getDefiningOp(); + if (!initOp) { + return op->emitOpError() + << "expects !pto.tensor_view pipe entry to use a pipe produced by " + "pto.initialize_l2g2l_pipe"; + } + if (initOp.getLocalAddr()) { + return op->emitOpError() + << "expects !pto.tensor_view pipe entry to use global-only " + "pto.initialize_l2g2l_pipe without local_addr"; + } + + Type slotElementType; + ArrayRef slotShape; + if (!getTensorLikeElementAndShape(initOp.getGmAddr().getType(), + slotElementType, slotShape)) { + return op->emitOpError() + << "expects !pto.tensor_view pipe entry to use " + "pto.initialize_l2g2l_pipe gm_addr with tensor/memref slot type"; + } + + if (slotElementType != entryViewTy.getElementType()) { + return op->emitOpError() + << "expects pipe entry element type to match 
initialize_l2g2l_pipe " + "gm_addr element type"; + } + if (slotShape.size() != static_cast(entryViewTy.getRank())) { + return op->emitOpError() + << "expects pipe entry rank to match initialize_l2g2l_pipe gm_addr " + "rank"; + } + + ArrayRef entryShape = entryViewTy.getShape(); + for (auto [idx, entryDim] : llvm::enumerate(entryShape)) { + int64_t slotDim = slotShape[idx]; + if (slotDim == ShapedType::kDynamic || + entryDim == ShapedType::kDynamic || slotDim == entryDim) + continue; + return op->emitOpError() + << "expects pipe entry dimension " << idx + << " to match initialize_l2g2l_pipe gm_addr dimension " + << slotDim; + } + + if (auto entryElemCount = getStaticElementCount(entryShape)) { + uint64_t elemBytes = getElemByteSize(entryViewTy.getElementType()); + uint64_t entryBytes = *entryElemCount * elemBytes; + if (elemBytes != 0) { + int8_t split = 0; + if (auto alloc = dyn_cast(op)) + split = alloc.getSplit(); + else if (auto push = dyn_cast(op)) + split = push.getSplit(); + else if (auto pop = dyn_cast(op)) + split = pop.getSplit(); + else if (auto free = dyn_cast(op)) + split = free.getSplit(); + + uint64_t slotBytes = static_cast(initOp.getSlotSize()); + bool isSplitEntry = split != 0; + bool byteSizeMatches = + entryBytes == slotBytes || (isSplitEntry && entryBytes * 2 == slotBytes); + if (!byteSizeMatches) { + return op->emitOpError() + << "expects pipe entry byte size to match initialize_l2g2l_pipe " + "slot_size" + << (isSplitEntry ? 
" or half slot_size for split entries" : "") + << " (got entry byte size = " << entryBytes + << ", slot_size = " << initOp.getSlotSize() << ")"; + } + } + } + + return success(); +} + +LogicalResult BuildAsyncSessionOp::verify() { + Type scratchTy = getScratch().getType(); + if (!isa(scratchTy)) + return emitOpError("expects scratch to be tile_buf or memref type"); + + auto scratchSpace = getPTOMemorySpaceEnum(scratchTy); + if (!scratchSpace || *scratchSpace != pto::AddressSpace::VEC) + return emitOpError("expects scratch to be in vec address space"); + + auto scratchShape = getShapeVec(scratchTy); + if (scratchShape.empty() || scratchShape.size() > 2) + return emitOpError("expects scratch to be rank-1 or rank-2"); + for (int64_t dim : scratchShape) { + if (dim == ShapedType::kDynamic) + return emitOpError("expects scratch to have a static shape"); + } + + auto scratchBytes = getStaticByteSize(scratchTy); + if (!scratchBytes) + return emitOpError("expects scratch byte size to be statically known"); + if (*scratchBytes < sizeof(uint64_t)) + return emitOpError("expects scratch to provide at least 8 bytes"); + + Type workspaceElemTy; + Type workspaceTy = getWorkspace().getType(); + if (auto ptrTy = dyn_cast(workspaceTy)) { + workspaceElemTy = ptrTy.getElementType(); + } else if (auto memTy = dyn_cast(workspaceTy)) { + workspaceElemTy = memTy.getElementType(); + if (!isGmAddressSpaceAttr(memTy.getMemorySpace())) + return emitOpError("expects workspace to be in GM address space"); + } else { + return emitOpError("expects workspace to be !pto.ptr or memref type"); + } + if (!isByteIntegerType(workspaceElemTy)) + return emitOpError("expects workspace element type to be an 8-bit integer"); + + if (auto syncIdAttr = getSyncIdAttr()) { + int64_t syncId = syncIdAttr.getInt(); + if (syncId < 0 || syncId > 7) + return emitOpError("expects sync_id in range [0, 7]"); + } + if (auto blockBytesAttr = getBlockBytesAttr()) { + if (blockBytesAttr.getInt() <= 0) + return 
emitOpError("expects block_bytes to be greater than 0"); + } + if (auto commBlockOffsetAttr = getCommBlockOffsetAttr()) { + if (commBlockOffsetAttr.getInt() < 0) + return emitOpError("expects comm_block_offset to be non-negative"); + } + if (auto queueNumAttr = getQueueNumAttr()) { + if (queueNumAttr.getInt() <= 0) + return emitOpError("expects queue_num to be greater than 0"); + } + if (auto channelGroupIdxAttr = getChannelGroupIdxAttr()) { + APInt value = channelGroupIdxAttr.getValue(); + if (value.isNegative()) + return emitOpError("expects channel_group_idx to be non-negative"); + if (value.ugt(UINT32_MAX)) + return emitOpError("expects channel_group_idx to fit in uint32"); + } + + return success(); +} + +static LogicalResult verifyAsyncTransferOp(Operation *op, Value dst, Value src) { + Type dstElemTy = getElemTy(dst.getType()); + Type srcElemTy = getElemTy(src.getType()); + if (!dstElemTy || !srcElemTy) + return op->emitOpError("expects src and dst to have element types"); + if (dstElemTy != srcElemTy) + return op->emitOpError("expects src and dst to have the same element type"); + if (failed(verifyAsyncFlatContiguous1DGMViewLike(op, dst, "dst")) || + failed(verifyAsyncFlatContiguous1DGMViewLike(op, src, "src"))) + return failure(); + if (getShapeVec(dst.getType()) != getShapeVec(src.getType())) + return op->emitOpError("expects src and dst to have the same static shape"); + return success(); +} + +LogicalResult TPutAsyncOp::verify() { + return verifyAsyncTransferOp(getOperation(), getDst(), getSrc()); +} + +LogicalResult TGetAsyncOp::verify() { + return verifyAsyncTransferOp(getOperation(), getDst(), getSrc()); +} + +LogicalResult TPutOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + if (failed(verifyCommGlobalLike(*this, getDst(), "dst")) || + failed(verifyCommGlobalLike(*this, getSrc(), "src")) || + failed(verifyCommStagingTileLike(*this, getPing(), "ping")) || + failed(verifyCommPingPongSameType(*this, 
getPing(), getPong(), "ping", + "pong"))) + return failure(); + if (getElemTy(getDst().getType()) != getElemTy(getSrc().getType())) + return emitOpError("expects src and dst to have the same element type"); + if (getShapeVec(getDst().getType()) != getShapeVec(getSrc().getType())) + return emitOpError("expects src and dst to have the same static shape"); + if (getElemTy(getPing().getType()) != getElemTy(getSrc().getType())) + return emitOpError("expects staging tile element type to match src/dst"); + return success(); +} + +LogicalResult TGetOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + if (failed(verifyCommGlobalLike(*this, getDst(), "dst")) || + failed(verifyCommGlobalLike(*this, getSrc(), "src")) || + failed(verifyCommStagingTileLike(*this, getPing(), "ping")) || + failed(verifyCommPingPongSameType(*this, getPing(), getPong(), "ping", + "pong"))) + return failure(); + if (getElemTy(getDst().getType()) != getElemTy(getSrc().getType())) + return emitOpError("expects src and dst to have the same element type"); + if (getShapeVec(getDst().getType()) != getShapeVec(getSrc().getType())) + return emitOpError("expects src and dst to have the same static shape"); + if (getElemTy(getPing().getType()) != getElemTy(getSrc().getType())) + return emitOpError("expects staging tile element type to match src/dst"); + return success(); +} + +LogicalResult TNotifyOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + if (failed(verifyCommSignalLike(*this, getSignal(), "signal"))) + return failure(); + auto valueTy = dyn_cast(getValue().getType()); + if (!valueTy || valueTy.getWidth() != 32) + return emitOpError("expects value to be i32"); + return success(); +} + +LogicalResult TWaitOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + if (failed(verifyCommSignalLike(*this, getSignal(), "signal"))) + return failure(); + auto cmpTy = 
dyn_cast(getCmpValue().getType()); + if (!cmpTy || cmpTy.getWidth() != 32) + return emitOpError("expects cmp_value to be i32"); + return success(); +} + +LogicalResult TTestOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + if (failed(verifyCommSignalLike(*this, getSignal(), "signal"))) + return failure(); + auto cmpTy = dyn_cast(getCmpValue().getType()); + if (!cmpTy || cmpTy.getWidth() != 32) + return emitOpError("expects cmp_value to be i32"); + return success(); +} + +static LogicalResult verifySyncAllGmWorkspace(Operation *op, Value workspace, + StringRef name) { + Type ty = workspace.getType(); + if (!isa(ty)) + return op->emitOpError() << "expects " << name + << " to be a GM memref/tensor_view/partition_view"; + + if (auto memTy = dyn_cast(ty)) { + if (!memTy.hasRank()) + return op->emitOpError() << "expects " << name << " to be ranked"; + if (!isGmAddressSpaceAttr(memTy.getMemorySpace())) + return op->emitOpError() << "expects " << name + << " to be in GM address space"; + } + + auto elemTy = dyn_cast(getElemTy(ty)); + if (!elemTy || elemTy.getWidth() != 32) + return op->emitOpError() << "expects " << name + << " element type to be i32"; + + SmallVector shape = getShapeVec(ty); + if (shape.empty()) + return op->emitOpError() << "expects " << name << " to have rank >= 1"; + for (int64_t dim : shape) { + if (dim != ShapedType::kDynamic && dim <= 0) + return op->emitOpError() << "expects " << name + << " shape to be positive"; + } + return success(); +} + +static LogicalResult verifySyncAllTileWorkspace(Operation *op, Value workspace, + StringRef name, + pto::AddressSpace expectedSpace) { + Type ty = workspace.getType(); + if (!isa(ty)) + return op->emitOpError() << "expects " << name + << " to be tile_buf or memref type"; + + if (isa(ty) && failed(verifyTileBufCommon(op, ty, name))) + return failure(); + + auto as = getPTOMemorySpaceEnum(ty); + if (!as || *as != expectedSpace) + return op->emitOpError() << "expects " 
<< name << " to be in " + << (expectedSpace == pto::AddressSpace::VEC + ? "vec" + : "mat") + << " address space"; + + Type elemTy = getElemTy(ty); + auto intTy = dyn_cast_or_null(elemTy); + if (!intTy || intTy.getWidth() != 32) + return op->emitOpError() << "expects " << name + << " element type to be i32"; + + auto shape = getShapeVec(ty); + if (shape.empty() || shape.size() > 2) + return op->emitOpError() << "expects " << name + << " to be rank-1 or rank-2"; + for (int64_t dim : shape) { + if (dim != ShapedType::kDynamic && dim <= 0) + return op->emitOpError() << "expects " << name + << " shape to be positive"; + } + return success(); +} + +LogicalResult SyncAllOp::verify() { + bool hasGm = static_cast(getGmWorkspace()); + bool hasUb = static_cast(getUbWorkspace()); + bool hasL1 = static_cast(getL1Workspace()); + auto mode = getMode().getValue(); + auto coreType = getCoreType().getValue(); + + if (mode == pto::SyncAllMode::Hard) { + if (hasGm || hasUb || hasL1 || getUsedCores()) + return emitOpError( + "expects hard syncall to have no workspace operands or used_cores"); + return success(); + } + + if (!hasGm) + return emitOpError("expects soft syncall to provide gm_workspace"); + if (failed(verifySyncAllGmWorkspace(getOperation(), getGmWorkspace(), + "gm_workspace"))) + return failure(); + + if (auto used = getUsedCores()) { + auto intTy = dyn_cast(used.getType()); + if (!intTy || intTy.getWidth() != 32) + return emitOpError("expects used_cores to be i32"); + } + + switch (coreType) { + case pto::SyncCoreType::AIVOnly: + if (!hasUb || hasL1) + return emitOpError("expects soft AIV-only syncall to use gm_workspace " + "+ ub_workspace only"); + return verifySyncAllTileWorkspace(getOperation(), getUbWorkspace(), + "ub_workspace", + pto::AddressSpace::VEC); + case pto::SyncCoreType::AICOnly: + if (hasUb || !hasL1) + return emitOpError("expects soft AIC-only syncall to use gm_workspace " + "+ l1_workspace only"); + return verifySyncAllTileWorkspace(getOperation(), 
getL1Workspace(), + "l1_workspace", + pto::AddressSpace::MAT); + case pto::SyncCoreType::Mix: + if (!hasUb || !hasL1) + return emitOpError("expects soft mixed syncall to use gm_workspace + " + "ub_workspace + l1_workspace"); + if (failed(verifySyncAllTileWorkspace(getOperation(), getUbWorkspace(), + "ub_workspace", + pto::AddressSpace::VEC))) + return failure(); + return verifySyncAllTileWorkspace(getOperation(), getL1Workspace(), + "l1_workspace", + pto::AddressSpace::MAT); + } + + llvm_unreachable("unhandled SyncCoreType"); +} + +LogicalResult TBroadcastOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + if (failed(verifyCommGlobalLike(*this, getSrc(), "src")) || + failed(verifyCommStagingTileLike(*this, getPing(), "ping")) || + failed(verifyCommPingPongSameType(*this, getPing(), getPong(), "ping", + "pong")) || + failed(verifyCommGlobalGroup(*this, getGroup(), "group"))) + return failure(); + if (getRoot() >= static_cast(getGroup().size())) + return emitOpError("expects root to index into group operands"); + if (getSrc().getType() != getGroup().front().getType()) + return emitOpError("expects src type to match group member type"); + if (getElemTy(getPing().getType()) != getElemTy(getSrc().getType())) + return emitOpError("expects staging tile element type to match src"); + return success(); +} + +LogicalResult CommTGatherOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + if (failed(verifyCommGlobalLike(*this, getDst(), "dst")) || + failed(verifyCommStagingTileLike(*this, getPing(), "ping")) || + failed(verifyCommPingPongSameType(*this, getPing(), getPong(), "ping", + "pong")) || + failed(verifyCommGlobalGroup(*this, getGroup(), "group"))) + return failure(); + if (getRoot() >= static_cast(getGroup().size())) + return emitOpError("expects root to index into group operands"); + if (getElemTy(getDst().getType()) != getElemTy(getGroup().front().getType())) + return 
emitOpError("expects dst element type to match group member type"); + if (getElemTy(getPing().getType()) != getElemTy(getDst().getType())) + return emitOpError("expects staging tile element type to match dst"); + return success(); +} + +LogicalResult CommTScatterOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + if (failed(verifyCommGlobalLike(*this, getSrc(), "src")) || + failed(verifyCommStagingTileLike(*this, getPing(), "ping")) || + failed(verifyCommPingPongSameType(*this, getPing(), getPong(), "ping", + "pong")) || + failed(verifyCommGlobalGroup(*this, getGroup(), "group"))) + return failure(); + if (getRoot() >= static_cast(getGroup().size())) + return emitOpError("expects root to index into group operands"); + if (getElemTy(getSrc().getType()) != getElemTy(getGroup().front().getType())) + return emitOpError("expects src element type to match group member type"); + if (getElemTy(getPing().getType()) != getElemTy(getSrc().getType())) + return emitOpError("expects staging tile element type to match src"); + return success(); +} + +LogicalResult TReduceOp::verify() { + if (shouldBypassDecodedMemrefVerifier(getOperation())) + return success(); + if (failed(verifyCommGlobalLike(*this, getDst(), "dst")) || + failed(verifyCommStagingTileLike(*this, getAcc(), "acc")) || + failed(verifyCommStagingTileLike(*this, getRecvPing(), "recv_ping")) || + failed(verifyCommPingPongSameType(*this, getRecvPing(), getRecvPong(), + "recv_ping", "recv_pong")) || + failed(verifyCommGlobalGroup(*this, getGroup(), "group"))) + return failure(); + if (getRoot() >= static_cast(getGroup().size())) + return emitOpError("expects root to index into group operands"); + if (getElemTy(getDst().getType()) != getElemTy(getGroup().front().getType())) + return emitOpError("expects dst element type to match group member type"); + if (getAcc().getType() != getRecvPing().getType()) + return emitOpError("expects acc and recv_ping to have identical types"); + if 
(getElemTy(getAcc().getType()) != getElemTy(getDst().getType())) + return emitOpError("expects accumulator/receive tiles to match dst element type"); + return success(); +} + +LogicalResult AicInitializePipeOp::verify() { + return verifyFrontendInitCommon(*this, FunctionKernelKind::Cube, "cube"); +} + +LogicalResult AivInitializePipeOp::verify() { + return verifyFrontendInitCommon(*this, FunctionKernelKind::Vector, "vector"); +} + +LogicalResult TAllocToAivOp::verify() { + if (failed(verifyFrontendSplitOp(getOperation(), FunctionKernelKind::Cube, + "cube", getId(), getSplit()))) + return failure(); + if (failed(verifyFrontendDataOpDirection(getOperation(), getId(), + /*expectC2V=*/true))) + return failure(); + return verifyFrontendTensorEntryMatchesInit(getOperation(), getId(), + getEntry().getType()); +} + +LogicalResult TAllocToAicOp::verify() { + if (failed(verifyFrontendSplitOp(getOperation(), FunctionKernelKind::Vector, + "vector", getId(), getSplit()))) + return failure(); + if (failed(verifyFrontendDataOpDirection(getOperation(), getId(), + /*expectC2V=*/false))) + return failure(); + return verifyFrontendTensorEntryMatchesInit(getOperation(), getId(), + getEntry().getType()); +} + +LogicalResult TPushToAivOp::verify() { + if (failed(verifyFrontendSplitOp(getOperation(), FunctionKernelKind::Cube, + "cube", getId(), getSplit()))) + return failure(); + if (failed(verifyFrontendDataOpDirection(getOperation(), getId(), + /*expectC2V=*/true))) + return failure(); + return verifyFrontendTensorEntryMatchesInit(getOperation(), getId(), + getTile().getType()); +} + +LogicalResult TPushToAicOp::verify() { + if (failed(verifyFrontendSplitOp(getOperation(), FunctionKernelKind::Vector, + "vector", getId(), getSplit()))) + return failure(); + if (failed(verifyFrontendDataOpDirection(getOperation(), getId(), + /*expectC2V=*/false))) + return failure(); + return verifyFrontendTensorEntryMatchesInit(getOperation(), getId(), + getTile().getType()); +} + +LogicalResult 
TPopFromAicOp::verify() { + return verifyFrontendPopOp(*this, FunctionKernelKind::Vector, "vector", + /*expectC2V=*/true); +} + +LogicalResult TPopFromAivOp::verify() { + return verifyFrontendPopOp(*this, FunctionKernelKind::Cube, "cube", + /*expectC2V=*/false); +} + +LogicalResult TFreeFromAicOp::verify() { + if (failed(verifyFrontendSplitOp(getOperation(), FunctionKernelKind::Vector, + "vector", getId(), getSplit()))) + return failure(); + if (failed(verifyFrontendDataOpDirection(getOperation(), getId(), + /*expectC2V=*/true))) + return failure(); + if (getEntry()) + return verifyFrontendTensorEntryMatchesInit(getOperation(), getId(), + getEntry().getType()); + return success(); +} + +LogicalResult TFreeFromAivOp::verify() { + if (failed(verifyFrontendSplitOp(getOperation(), FunctionKernelKind::Cube, + "cube", getId(), getSplit()))) + return failure(); + if (failed(verifyFrontendDataOpDirection(getOperation(), getId(), + /*expectC2V=*/false))) + return failure(); + if (getEntry()) + return verifyFrontendTensorEntryMatchesInit(getOperation(), getId(), + getEntry().getType()); + return success(); +} + +LogicalResult InitializeL2G2LPipeOp::verify() { + if (failed(verifyPipeShape(getOperation(), getDirMask(), getSlotSize(), + getSlotNum(), + getFlagBaseAttr() + ? 
std::optional(getFlagBaseAttr().getInt()) + : std::nullopt))) + return failure(); + + if (!getLocalAddr()) { + if (getPeerLocalAddr()) + return emitOpError("'peer_local_addr' requires 'local_addr'"); + if (getLocalSlotNumAttr()) + return emitOpError( + "'local_slot_num' is only allowed when 'local_addr' is present"); + return success(); + } + + if (auto localSlotNumAttr = getLocalSlotNumAttr()) { + int32_t localSlotNum = localSlotNumAttr.getInt(); + if (localSlotNum <= 0) + return emitOpError("expects 'local_slot_num' to be greater than 0"); + if (static_cast(localSlotNum) > getSlotNum()) + return emitOpError( + "expects 'local_slot_num' to be less than or equal to slot_num"); + } + + if (getDirMask() == 3 && !getPeerLocalAddr()) + return emitOpError("expects 'peer_local_addr' when dir_mask is 3"); + if (getDirMask() != 3 && getPeerLocalAddr()) + return emitOpError("'peer_local_addr' is only allowed when dir_mask is 3"); + return success(); +} + +LogicalResult InitializeL2LPipeOp::verify() { + if (failed(verifyPipeShape(getOperation(), getDirMask(), getSlotSize(), + getSlotNum(), + getFlagBaseAttr() + ? 
std::optional(getFlagBaseAttr().getInt()) + : std::nullopt))) + return failure(); + + if (getDirMask() == 3 && !getPeerLocalAddr()) + return emitOpError("expects 'peer_local_addr' when dir_mask is 3"); + if (getDirMask() != 3 && getPeerLocalAddr()) + return emitOpError("'peer_local_addr' is only allowed when dir_mask is 3"); + return success(); +} + +LogicalResult TPushOp::verify() { + if (!isInsideSectionOrAttributedKernel(getOperation())) + return emitOpError("must be inside pto.section.cube/vector or a kernel_kind function"); + if (failed(verifyPipeHandleProducer(getOperation(), getPipeHandle()))) + return failure(); + if (failed(verifySplitAttr(getOperation(), getSplit()))) + return failure(); + if (failed(verifyTensorEntryMatchesInternalPipeInit( + getOperation(), getPipeHandle(), getTile().getType()))) + return failure(); + if (!isa(getTile().getType()) && + getPipe() == pto::PIPE::PIPE_UNASSIGNED) + return emitOpError("tile type must map to a supported producer pipe"); + return success(); +} + +LogicalResult TAllocOp::verify() { + if (!isInsideSectionOrAttributedKernel(getOperation())) + return emitOpError("must be inside pto.section.cube/vector or a kernel_kind function"); + if (failed(verifyPipeHandleProducer(getOperation(), getPipeHandle()))) + return failure(); + if (failed(verifyTensorEntryMatchesInternalPipeInit( + getOperation(), getPipeHandle(), getEntry().getType()))) + return failure(); + return verifySplitAttr(getOperation(), getSplit()); +} + +LogicalResult TPopOp::verify() { + if (!isInsideSectionOrAttributedKernel(getOperation())) + return emitOpError("must be inside pto.section.cube/vector or a kernel_kind function"); + if (failed(verifyPipeHandleProducer(getOperation(), getPipeHandle()))) + return failure(); + if (failed(verifySplitAttr(getOperation(), getSplit()))) + return failure(); + if (failed(verifyTensorEntryMatchesInternalPipeInit( + getOperation(), getPipeHandle(), getTile().getType()))) + return failure(); + if 
(!isa(getTile().getType()) && + getPipe() == pto::PIPE::PIPE_UNASSIGNED) + return emitOpError( + "tile type and target arch must map to a supported consumer pipe"); + return success(); +} + +LogicalResult TFreeOp::verify() { + if (!isInsideSectionOrAttributedKernel(getOperation())) + return emitOpError("must be inside pto.section.cube/vector or a kernel_kind function"); + if (failed(verifyPipeHandleProducer(getOperation(), getPipeHandle()))) + return failure(); + if (getEntry() && + failed(verifyTensorEntryMatchesInternalPipeInit( + getOperation(), getPipeHandle(), getEntry().getType()))) + return failure(); + return verifySplitAttr(getOperation(), getSplit()); +} + +ParseResult TFreeOp::parse(OpAsmParser &parser, OperationState &result) { + OpAsmParser::UnresolvedOperand first; + OpAsmParser::UnresolvedOperand pipe; + Type firstTy; + Type pipeTy; + bool hasEntry = false; + + if (parser.parseLParen() || parser.parseOperand(first)) + return failure(); + + if (succeeded(parser.parseOptionalComma())) { + hasEntry = true; + if (parser.parseOperand(pipe) || parser.parseColonType(firstTy) || + parser.parseComma() || parser.parseType(pipeTy) || parser.parseRParen()) + return failure(); + } else { + if (parser.parseColonType(pipeTy) || parser.parseRParen()) + return failure(); + pipe = first; + } + + NamedAttrList attrs; + if (parser.parseLBrace() || parser.parseKeyword("split") || + parser.parseEqual()) + return failure(); + IntegerAttr splitAttr; + if (parser.parseAttribute(splitAttr, parser.getBuilder().getI8Type(), + "split", attrs) || + parser.parseRBrace() || parser.parseOptionalAttrDict(attrs)) + return failure(); + + result.addAttributes(attrs); + if (hasEntry && + parser.resolveOperand(first, firstTy, result.operands)) + return failure(); + if (parser.resolveOperand(pipe, pipeTy, result.operands)) + return failure(); + return success(); +} + +void TFreeOp::print(OpAsmPrinter &p) { + p << "("; + if (getEntry()) { + p << getEntry() << ", " << getPipeHandle() << " : 
" + << getEntry().getType() << ", " << getPipeHandle().getType(); + } else { + p << getPipeHandle() << " : " << getPipeHandle().getType(); + } + p << ") {split = " << static_cast(getSplit()) << "}"; + p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"split"}); +} + +void BuildAsyncSessionOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getScratchMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getWorkspaceMutable(), MemoryEffects::Read::get()); + addEffect(effects, getOperation()->getOpResult(0), MemoryEffects::Write::get()); +} + +void TPutAsyncOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); + addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getSessionMutable(), MemoryEffects::Read::get()); + addEffect(effects, getOperation()->getOpResult(0), MemoryEffects::Write::get()); +} + +void TGetAsyncOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); + addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getSessionMutable(), MemoryEffects::Read::get()); + addEffect(effects, getOperation()->getOpResult(0), MemoryEffects::Write::get()); +} + +void TPutOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); + addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getPingMutable(), MemoryEffects::Read::get()); +} + +void TGetOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); + addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getPingMutable(), MemoryEffects::Read::get()); +} + +void TNotifyOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getSignalMutable(), MemoryEffects::Write::get()); + 
addEffect(effects, &getValueMutable(), MemoryEffects::Read::get()); +} + +void TWaitOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getSignalMutable(), MemoryEffects::Write::get()); + addEffect(effects, &getCmpValueMutable(), MemoryEffects::Read::get()); +} + +void TTestOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getSignalMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getCmpValueMutable(), MemoryEffects::Read::get()); + addEffect(effects, getOperation()->getOpResult(0), MemoryEffects::Write::get()); +} + +void TBroadcastOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getPingMutable(), MemoryEffects::Write::get()); + if (getPong()) { + auto pongRange = getPongMutable(); + if (auto it = pongRange.begin(); it != pongRange.end()) + addEffect(effects, &*it, MemoryEffects::Write::get()); + } +} + +void CommTGatherOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); + addEffect(effects, &getPingMutable(), MemoryEffects::Read::get()); +} + +void CommTScatterOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getPingMutable(), MemoryEffects::Write::get()); + if (getPong()) { + auto pongRange = getPongMutable(); + if (auto it = pongRange.begin(); it != pongRange.end()) + addEffect(effects, &*it, MemoryEffects::Write::get()); + } +} + +void TReduceOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); + addEffect(effects, &getAccMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getAccMutable(), MemoryEffects::Write::get()); + addEffect(effects, &getRecvPingMutable(), MemoryEffects::Write::get()); +} + +void WaitAsyncEventOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, 
&getEventMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getSessionMutable(), MemoryEffects::Read::get()); + addEffect(effects, getOperation()->getOpResult(0), MemoryEffects::Write::get()); +} + +void TestAsyncEventOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getEventMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getSessionMutable(), MemoryEffects::Read::get()); + addEffect(effects, getOperation()->getOpResult(0), MemoryEffects::Write::get()); +} + +void InitializeL2G2LPipeOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getGmAddrMutable(), MemoryEffects::Read::get()); + auto localAddr = getLocalAddrMutable(); + if (!localAddr.empty()) + addEffect(effects, &*localAddr.begin(), MemoryEffects::Read::get()); + auto peerLocalAddr = getPeerLocalAddrMutable(); + if (!peerLocalAddr.empty()) + addEffect(effects, &*peerLocalAddr.begin(), MemoryEffects::Read::get()); + addEffect(effects, getOperation()->getOpResult(0), MemoryEffects::Write::get()); +} + +void InitializeL2LPipeOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getLocalAddrMutable(), MemoryEffects::Read::get()); + addEffect(effects, getOperation()->getOpResult(0), MemoryEffects::Write::get()); +} + +void TPushOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getTileMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getPipeHandleMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getPipeHandleMutable(), MemoryEffects::Write::get()); +} + +void TAllocOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getEntryMutable(), MemoryEffects::Write::get()); + addEffect(effects, &getPipeHandleMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getPipeHandleMutable(), MemoryEffects::Write::get()); +} + +void TPopOp::getEffects( + SmallVectorImpl> + &effects) { + addEffect(effects, &getPipeHandleMutable(), MemoryEffects::Read::get()); + addEffect(effects, 
&getPipeHandleMutable(), MemoryEffects::Write::get()); + addEffect(effects, &getTileMutable(), MemoryEffects::Write::get()); +} + +void TFreeOp::getEffects( + SmallVectorImpl> + &effects) { + auto entry = getEntryMutable(); + if (!entry.empty()) + addEffect(effects, &*entry.begin(), MemoryEffects::Read::get()); + addEffect(effects, &getPipeHandleMutable(), MemoryEffects::Read::get()); + addEffect(effects, &getPipeHandleMutable(), MemoryEffects::Write::get()); +} + +// [Include 必须放在最后] +#include "PTO/IR/PTOInterfaces.cpp.inc" +#define GET_OP_CLASSES +#include "PTO/IR/PTOOps.cpp.inc" diff --git a/lib/PTO/IR/PTO.def b/lib/PTO/IR/PTO.def deleted file mode 100644 index 376b9c017..000000000 --- a/lib/PTO/IR/PTO.def +++ /dev/null @@ -1,12933 +0,0 @@ -// Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. 
- -//===- PTO.cpp - PTO Dialect ----------------------------------------------===// -//===----------------------------------------------------------------------===// - -#include "PTO/IR/PTO.h" -#include "PTO/IR/PTOTypeUtils.h" -#include "PTO/IR/PTOSyncUtils.h" - -#include "mlir/AsmParser/AsmParser.h" -#include "mlir/Conversion/LLVMCommon/TypeConverter.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/LLVMIR/LLVMTypes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/DialectImplementation.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/IR/Value.h" -#include "mlir/IR/AsmState.h" -#include "mlir/IR/SymbolTable.h" -#include "mlir/IR/Types.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Parser/Parser.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/TypeSwitch.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/AffineMap.h" -#include "mlir/Dialect/Utils/StaticValueUtils.h" -#include "llvm/Support/ErrorHandling.h" - -#include -#include -#include -#include -#include - -using namespace mlir; -using namespace mlir::pto; - -// Forward declarations for custom shape/type printers used by tensor_view and -// partition_tensor_view. 
-namespace mlir { -namespace pto { -static LogicalResult parseShapeAndElem(AsmParser &parser, - SmallVectorImpl &shape, - Type &elementType, - bool allowDynamic = true); -static void printShapeAndElem(AsmPrinter &printer, - ArrayRef shape, - Type elementType); -} // namespace pto -} // namespace mlir - -// ============================================================================= -// TileBufType 的自定义 Shape 解析与打印函数 -// ============================================================================= - -// 解析逻辑:解析形如 "32x32" 的维度列表 -[[maybe_unused]] static ParseResult parseShape(AsmParser &parser, SmallVectorImpl &shape) { - // parseDimensionList 会解析 "dim x dim x ...", 遇到无法解析为维度的字符停止 - // 参数 allowDynamic=true (允许 ?), withTrailingX=false (不吞掉末尾的 x) - if (parser.parseDimensionList(shape, /*allowDynamic=*/true, /*withTrailingX=*/false)) - return failure(); - return success(); -} - -// 打印逻辑:打印形如 "32x32" 的维度列表 -[[maybe_unused]] static void printShape(AsmPrinter &printer, ArrayRef shape) { - for (auto it = shape.begin(); it != shape.end(); ++it) { - if (it != shape.begin()) printer << "x"; // 维度间的分隔符 - if (*it == ShapedType::kDynamic) - printer << "?"; - else - printer << *it; - } - // 注意:我们不在这里打印末尾的 'x',因为 assemblyFormat 中已经写了 `x` $elementType -} - -static std::optional getPTOMemorySpaceEnum(Type ty); -enum class VerifierTargetArch { - A2A3, - A5, -}; -static VerifierTargetArch getVerifierTargetArch(Operation *op); -static std::optional getVerifierArchName(Operation *op); -static bool isSupportedVecElemType(Type ty, bool allowBf16 = true, - bool allowInt8 = true); -static bool isSupportedLoadStoreElemTypeA2A3(Type ty); -static bool isSupportedGatherElemTypeA2A3(Type ty); -static bool isSupportedGatherElemTypeA5(Type ty); -static bool isA5TLoadStoreTransferElemType(Type ty); -static bool isA5AccStorePreQuantDstType(Type srcElem, Type dstElem); -static bool isA5LowPrecisionTCvtPair(Type srcElem, Type dstElem); -static bool isA5SupportedTCvtPair(Type srcElem, Type dstElem); 
-static ParseResult parseSyncEventOpCommon(OpAsmParser &parser, - OperationState &result, - StringAttr pipeAttrName, - StringAttr eventIdAttrName); -static void printSyncEventOpCommon(OpAsmPrinter &p, Operation *op, - PipeAttr pipeAttr, IntegerAttr eventAttr, - Value eventDyn, StringRef pipeAttrName, - StringRef eventIdAttrName); -static bool isTileLikeType(Type ty); -static SmallVector getShapeVec(Type ty); -static SmallVector getValidShapeVec(Type ty); -static SmallVector getValidShapeVec(Value value); -static bool isByteIntegerType(Type ty); -static LogicalResult verifyTileBufCommon(Operation *op, Type ty, StringRef name, - bool allowLowPrecision = false); -static LogicalResult verifyTileBufSameElemType(Operation *op, Type lhs, Type rhs, - StringRef lhsName, - StringRef rhsName); -static LogicalResult verifyTileBufSameLogicalExtent(Operation *op, Type lhs, - Type rhs, StringRef lhsName, - StringRef rhsName, - bool compareValidShape); - -static LogicalResult verifyTileBufSameValidShape(Operation *op, Type lhs, Type rhs, - StringRef lhsName, StringRef rhsName); -static LogicalResult verifyVecTileCommon(Operation *op, Type ty, StringRef name); -static LogicalResult verifyVecTileCommonA2A3(Operation *op, Type ty, - StringRef name); -static LogicalResult verifyVecTileCommonA5(Operation *op, Type ty, - StringRef name); -static LogicalResult verifyVecTileUnaryOp(Operation *op, Type srcTy, Type dstTy, - StringRef srcName = "src", - StringRef dstName = "dst", - bool allowBf16 = true, - bool allowInt8 = true); -static LogicalResult verifyAccTileCommon(Operation *op, Type ty, StringRef name); -static LogicalResult verifyAccTileCommonA2A3(Operation *op, Type ty, - StringRef name); -static LogicalResult verifyAccTileCommonA5(Operation *op, Type ty, - StringRef name); -static LogicalResult verifyMatTileOperands(Operation *op, Type lhsTy, Type rhsTy, - Type dstTy); -static LogicalResult verifyMatTileOperandsA2A3(Operation *op, Type lhsTy, - Type rhsTy, Type dstTy); -static 
LogicalResult verifyMatTileOperandsA5(Operation *op, Type lhsTy, - Type rhsTy, Type dstTy); -static LogicalResult verifyGemvTileOperands(Operation *op, Type lhsTy, Type rhsTy, - Type dstTy); -static LogicalResult verifyAsyncFlatContiguous1DGMViewLike(Operation *op, - Value value, - StringRef name); -static LogicalResult verifyGemvTileOperandsA2A3(Operation *op, Type lhsTy, - Type rhsTy, Type dstTy); -static LogicalResult verifyGemvTileOperandsA5(Operation *op, Type lhsTy, - Type rhsTy, Type dstTy); -static LogicalResult verifyMatBiasTile(Operation *op, Type biasTy, Type dstTy, - bool requireFloatBias = false); -static LogicalResult verifyMatBiasTileA2A3(Operation *op, Type biasTy, Type dstTy, - bool requireFloatBias = false); -static LogicalResult verifyMatBiasTileA5(Operation *op, Type biasTy, Type dstTy, - bool requireFloatBias = false); -static LogicalResult verifyMatmulTypeTriple(Operation *op, Type lhsElemTy, - Type rhsElemTy, Type dstElemTy); -static std::optional getLogicalViewLayout(Value value); -static std::optional getTileBufLogicalLayout(pto::TileBufType type); -static std::optional getConstantIntegerValue(Value value); -static LogicalResult verifyPartialValidPattern(Operation *op, Type src0Ty, - Type src1Ty, Type dstTy); -static Type getElemTy(Type ty); -static FailureOr -verifyMatchingRowMajorBinaryTileOpCommon(Operation *op, Type src0Ty, - Type src1Ty, Type dstTy); -static FailureOr -verifyNumericScalarTileOpCommon(Operation *op, Type srcTy, Type dstTy, - Type scalarTy, bool requireValidRowsEqual); -static FailureOr -verifyShiftLikeBinaryTileOpCommon(Operation *op, Type src0Ty, Type src1Ty, - Type dstTy); -static LogicalResult verifyArithmeticElemTypeForArch( - Operation *op, Type elemTy, PTOArch targetArch, bool allowInt8OnA5, - bool allowBf16OnA5, StringRef a2a3Error, StringRef a5Error); -static bool isRowMajorTileBuf(Type ty); - -#define GET_ENUM_CLASSES -#include "PTO/IR/PTOEnums.cpp.inc" - -#define GET_TYPEDEF_CLASSES -#include 
"PTO/IR/PTOTypeDefs.cpp.inc" - -#define GET_ATTRDEF_CLASSES -#include "PTO/IR/PTOAttrs.cpp.inc" - -#include "PTO/IR/PTODialect.cpp.inc" - -[[maybe_unused]] static LogicalResult parseShapeAndElemStable(mlir::AsmParser &parser, - llvm::SmallVectorImpl &shape, - mlir::Type &elementType) { - if (failed(parser.parseLess())) - return failure(); - - if (failed(parser.parseDimensionList(shape, /*allowDynamic=*/true))) - return failure(); - - if (failed(parser.parseType(elementType))) - return failure(); - - if (failed(parser.parseGreater())) - return failure(); - - return success(); -} - -static int64_t getPTOTypeRank(Type type) { - // 1. 处理标准的 MLIR 类型 (MemRef, Tensor, Vector) - if (auto shapedTy = dyn_cast(type)) { - if (shapedTy.hasRank()) - return shapedTy.getRank(); - return -1; // Unranked type - } - - // 2. 处理 PTO 自定义类型 - if (auto tvTy = dyn_cast(type)) - return tvTy.getRank(); - - if (auto tileTy = dyn_cast(type)) - return tileTy.getRank(); - - if (auto tileViewTy = dyn_cast(type)) - return tileViewTy.getRank(); - - if (auto tileBufTy = dyn_cast(type)) - return tileBufTy.getRank(); - - // 3. 
不支持的类型 - return -1; -} - -static bool isGmAddressSpaceAttr(Attribute memorySpace) { - if (!memorySpace) - return true; - if (auto addr = mlir::dyn_cast(memorySpace)) - return addr.getAddressSpace() == pto::AddressSpace::GM; - if (auto intAttr = mlir::dyn_cast(memorySpace)) - return intAttr.getInt() == 0; - return false; -} - -PTOArch mlir::pto::getTargetArch(ModuleOp module) { - if (!module) - return PTOArch::A3; - - auto arch = module->getAttrOfType(kPTOTargetArchAttrName); - if (arch && arch.getValue().equals_insensitive("a5")) - return PTOArch::A5; - return PTOArch::A3; -} - -PTOArch mlir::pto::getTargetArch(Operation *op) { - if (!op) - return PTOArch::A3; - return getTargetArch(op->getParentOfType()); -} - -bool mlir::pto::isTargetArchA3(ModuleOp module) { - return getTargetArch(module) == PTOArch::A3; -} - -bool mlir::pto::isTargetArchA5(ModuleOp module) { - return getTargetArch(module) == PTOArch::A5; -} - -bool mlir::pto::isTargetArchA3(Operation *op) { - return getTargetArch(op) == PTOArch::A3; -} - -bool mlir::pto::isTargetArchA5(Operation *op) { - return getTargetArch(op) == PTOArch::A5; -} - -static llvm::TypeSize getOneByteTypeSize() { - return llvm::TypeSize::getFixed(8); -} - -llvm::TypeSize mlir::pto::HiF8Type::getTypeSizeInBits( - const DataLayout &, DataLayoutEntryListRef) const { - return getOneByteTypeSize(); -} - -uint64_t mlir::pto::HiF8Type::getABIAlignment(const DataLayout &, - DataLayoutEntryListRef) const { - return 1; -} - -uint64_t mlir::pto::HiF8Type::getPreferredAlignment( - const DataLayout &, DataLayoutEntryListRef) const { - return 1; -} - -llvm::TypeSize mlir::pto::F4E1M2x2Type::getTypeSizeInBits( - const DataLayout &, DataLayoutEntryListRef) const { - return getOneByteTypeSize(); -} - -uint64_t mlir::pto::F4E1M2x2Type::getABIAlignment( - const DataLayout &, DataLayoutEntryListRef) const { - return 1; -} - -uint64_t mlir::pto::F4E1M2x2Type::getPreferredAlignment( - const DataLayout &, DataLayoutEntryListRef) const { - return 1; -} 
- -llvm::TypeSize mlir::pto::F4E2M1x2Type::getTypeSizeInBits( - const DataLayout &, DataLayoutEntryListRef) const { - return getOneByteTypeSize(); -} - -uint64_t mlir::pto::F4E2M1x2Type::getABIAlignment( - const DataLayout &, DataLayoutEntryListRef) const { - return 1; -} - -uint64_t mlir::pto::F4E2M1x2Type::getPreferredAlignment( - const DataLayout &, DataLayoutEntryListRef) const { - return 1; -} - -static VerifierTargetArch getVerifierTargetArch(Operation *op) { - if (auto archName = getVerifierArchName(op)) { - return archName->equals_insensitive("a5") ? VerifierTargetArch::A5 - : VerifierTargetArch::A2A3; - } - - switch (getPTOParserTargetArch(op ? op->getContext() : nullptr)) { - case PTOParserTargetArch::A5: - return VerifierTargetArch::A5; - case PTOParserTargetArch::A3: - case PTOParserTargetArch::Unspecified: - return VerifierTargetArch::A2A3; - } - - return VerifierTargetArch::A2A3; -} - -static std::optional getVerifierArchName(Operation *op) { - auto module = op ? op->getParentOfType() : ModuleOp(); - if (!module) - return std::nullopt; - if (auto arch = module->getAttrOfType(kPTOTargetArchAttrName)) - return arch.getValue(); - return std::nullopt; -} - -static bool shouldBypassDecodedMemrefVerifier(Operation *op) { - if (!op) - return false; - for (Value operand : op->getOperands()) { - if (isa(operand.getType())) - return true; - if (operand.getDefiningOp()) - return true; - } - return false; -} - -static SmallVector canonicalizeTileBufValidShape(ArrayRef validShape) { - SmallVector canonical; - canonical.reserve(validShape.size()); - for (int64_t dim : validShape) - canonical.push_back(dim < 0 ? 
ShapedType::kDynamic : dim); - return canonical; -} - -template -static LogicalResult dispatchVerifierByArch(Operation *op, FnA2A3 &&verifyA2A3, - FnA5 &&verifyA5) { - if (shouldBypassDecodedMemrefVerifier(op)) - return success(); - switch (getVerifierTargetArch(op)) { - case VerifierTargetArch::A2A3: - return verifyA2A3(); - case VerifierTargetArch::A5: - return verifyA5(); - } - return failure(); -} - -static ParseResult parseSyncEventOpCommon(OpAsmParser &parser, - OperationState &result, - StringAttr pipeAttrName, - StringAttr eventIdAttrName) { - PipeAttr pipeAttr; - if (succeeded(parser.parseOptionalLess())) { - StringRef pipeTok; - if (parser.parseKeyword(&pipeTok) || parser.parseGreater()) - return failure(); - auto pipeOr = symbolizePIPE(pipeTok); - if (!pipeOr) - return parser.emitError(parser.getCurrentLocation()) - << "unknown pipe token: " << pipeTok; - pipeAttr = PipeAttr::get(parser.getContext(), *pipeOr); - result.addAttribute(pipeAttrName, pipeAttr); - } else if (parser.parseAttribute(pipeAttr, pipeAttrName, - result.attributes)) { - return failure(); - } - if (parser.parseComma()) - return failure(); - - OpAsmParser::UnresolvedOperand eventOperand; - OptionalParseResult parseEventOperand = - parser.parseOptionalOperand(eventOperand); - if (parseEventOperand.has_value()) { - if (failed(*parseEventOperand)) - return failure(); - if (parser.resolveOperand(eventOperand, parser.getBuilder().getIndexType(), - result.operands)) - return failure(); - } else { - IntegerAttr eventAttr; - if (parser.parseAttribute(eventAttr, parser.getBuilder().getI32Type(), - eventIdAttrName, result.attributes)) - return failure(); - } - - if (parser.parseOptionalAttrDict(result.attributes)) - return failure(); - return success(); -} - -static void printSyncEventOpCommon(OpAsmPrinter &p, Operation *op, - PipeAttr pipeAttr, IntegerAttr eventAttr, - Value eventDyn, StringRef pipeAttrName, - StringRef eventIdAttrName) { - p << " <" << stringifyPIPE(pipeAttr.getPipe()) << ">, 
"; - if (eventAttr) - p << eventAttr.getInt(); - else - p << eventDyn; - p.printOptionalAttrDict(op->getAttrs(), {pipeAttrName, eventIdAttrName}); -} - -[[maybe_unused]] static mlir::Type parsePTOTypeAllowNoBang(mlir::OpAsmParser &parser) { - mlir::Type ty; - - mlir::OptionalParseResult opt = parser.parseOptionalType(ty); - - if (opt.has_value()) { - if (failed(*opt)) - return mlir::Type(); - return ty; - } - - - llvm::StringRef head; - if (failed(parser.parseKeyword(&head))) - return mlir::Type(); - - mlir::MLIRContext *ctx = parser.getContext(); - - auto parseShapeElemForOpParser = - [&](llvm::SmallVectorImpl &shape, mlir::Type &elem) -> mlir::LogicalResult { - if (failed(parser.parseLess())) - return failure(); - if (failed(parser.parseDimensionList(shape, /*allowDynamic=*/true))) - return failure(); - if (failed(parser.parseType(elem))) - return failure(); - if (failed(parser.parseGreater())) - return failure(); - return success(); - }; - - if (head == "pto.tile_view") { - llvm::SmallVector shape; - mlir::Type elem; - if (failed(parseShapeElemForOpParser(shape, elem))) - return mlir::Type(); - return mlir::pto::PartitionTensorViewType::get(ctx, shape, elem); - } - - if (head == "pto.tile") { - llvm::SmallVector shape; - mlir::Type elem; - if (failed(parseShapeElemForOpParser(shape, elem))) - return mlir::Type(); - return mlir::pto::TileType::get(ctx, shape, elem); - } - - if (head == "pto.ptr") { - if (failed(parser.parseLess())) - return mlir::Type(); - mlir::Type elem; - if (failed(parser.parseType(elem))) - return mlir::Type(); - if (succeeded(parser.parseOptionalComma())) { - // ptr no longer accepts an address space; consume the attr for recovery. 
- mlir::Attribute memorySpace; - (void)parser.parseAttribute(memorySpace); - parser.emitError(parser.getCurrentLocation(), - "!pto.ptr no longer accepts address space; use !pto.ptr"); - return mlir::Type(); - } - if (failed(parser.parseGreater())) - return mlir::Type(); - return mlir::pto::PtrType::get(ctx, elem); - } - - if (head == "pto.tensor_view") { - llvm::SmallVector shape; - mlir::Type elem; - if (failed(parseShapeElemForOpParser(shape, elem))) - return mlir::Type(); - return mlir::pto::TensorViewType::get(ctx, shape, elem); - } - - return mlir::Type(); -} - -mlir::Type TensorViewType::parse(::mlir::AsmParser &parser) { - SmallVector shape; - Type elementType; - if (failed(parseShapeAndElem(parser, shape, elementType, /*allowDynamic=*/true))) - return Type(); - return TensorViewType::get(parser.getContext(), shape, elementType); -} - -void TensorViewType::print(::mlir::AsmPrinter &printer) const { - printShapeAndElem(printer, getShape(), getElementType()); -} - -//===----------------------------------------------------------------------===// -// pto.tdivs custom asm to support both: -// pto.tdivs ins(%src, %scalar : !pto.tile_buf<...>, f32) outs(%dst : !pto.tile_buf<...>) -// pto.tdivs ins(%scalar, %src : f32, !pto.tile_buf<...>) outs(%dst : !pto.tile_buf<...>) -// The operand order in the op follows textual input order. 
-//===----------------------------------------------------------------------===// - -ParseResult mlir::pto::TDivSOp::parse(OpAsmParser &parser, OperationState &result) { - OpAsmParser::UnresolvedOperand op0, op1, dst; - Type ty0, ty1, dstTy; - - if (parser.parseKeyword("ins") || parser.parseLParen() || - parser.parseOperand(op0) || parser.parseComma() || - parser.parseOperand(op1) || parser.parseColonType(ty0) || - parser.parseComma() || parser.parseType(ty1) || parser.parseRParen()) - return failure(); - - if (parser.parseKeyword("outs") || parser.parseLParen() || - parser.parseOperand(dst) || parser.parseColonType(dstTy) || - parser.parseRParen()) - return failure(); - - NamedAttrList attrs; - if (parser.parseOptionalAttrDict(attrs)) - return failure(); - - auto tile0 = dyn_cast(ty0); - auto tile1 = dyn_cast(ty1); - if ((tile0 && tile1) || (!tile0 && !tile1)) - return parser.emitError(parser.getCurrentLocation(), - "expected exactly one tile_buf operand and one scalar operand"); - - if (!dyn_cast(dstTy)) - return parser.emitError(parser.getCurrentLocation(), - "expected outs type to be !pto.tile_buf<...>"); - - // Keep textual order so later lowering can distinguish the two APIs by the - // first ins operand type. 
- if (parser.resolveOperand(op0, ty0, result.operands) || - parser.resolveOperand(op1, ty1, result.operands)) - return failure(); - - if (parser.resolveOperand(dst, dstTy, result.operands)) - return failure(); - - result.addAttributes(attrs); - return success(); -} - -void mlir::pto::TDivSOp::print(OpAsmPrinter &p) { - p << " ins("; - p << getSrc() << ", " << getScalar() << " : " - << getSrc().getType() << ", " << getScalar().getType(); - p << ") outs(" << getDst() << " : " << getDst().getType() << ")"; - - p.printOptionalAttrDict((*this)->getAttrs()); -} - - -//===----------------------------------------------------------------------===// -// pto.tgather custom asm supports three PTO-ISA forms: -// 1) index+tmp : ins(%src, %indices, %tmp : srcTy, indicesTy, tmpTy) outs(%dst : dstTy) -// 2) compare+tmp : ins(%src, %kValue, %tmp : srcTy, scalarTy, tmpTy) -// outs(%dst, %cdst : dstTy, cdstTy) {cmpMode = #pto.cmp, offset = 7} -// 3) mask : ins(%src, {maskPattern = #pto.mask_pattern} : srcTy) outs(%dst : dstTy) -//===----------------------------------------------------------------------===// - -ParseResult mlir::pto::TGatherOp::parse(OpAsmParser &parser, OperationState &result) { - OpAsmParser::UnresolvedOperand src, dst, cdst; - SmallVector insOps; - SmallVector insTypes; - Type srcTy, dstTy, cdstTy; - bool hasCdst = false; - bool hasMask = false; - bool hasIndices = false; - bool hasTmp = false; - bool hasKValue = false; - - if (parser.parseKeyword("ins") || parser.parseLParen() || parser.parseOperand(src)) - return failure(); - - if (!succeeded(parser.parseOptionalComma())) { - return parser.emitError(parser.getCurrentLocation(), - "expected ',' after src operand in ins(...)"); - } - - if (succeeded(parser.parseOptionalLBrace())) { - if (parser.parseKeyword("maskPattern") || parser.parseEqual()) - return failure(); - - Attribute rawMaskAttr; - if (parser.parseAttribute(rawMaskAttr) || parser.parseRBrace()) - return failure(); - - auto mp = 
llvm::dyn_cast(rawMaskAttr); - if (!mp) { - return parser.emitError(parser.getCurrentLocation(), - "expected #pto.mask_pattern for maskPattern"); - } - - result.addAttribute("maskPattern", mp); - hasMask = true; - - if (parser.parseColonType(srcTy) || parser.parseRParen()) - return failure(); - } else { - OpAsmParser::UnresolvedOperand extra; - if (parser.parseOperand(extra)) - return failure(); - insOps.push_back(extra); - while (succeeded(parser.parseOptionalComma())) { - if (insOps.size() == 3) { - return parser.emitError(parser.getCurrentLocation(), - "expected at most 3 extra operands in tgather ins(...)"); - } - if (parser.parseOperand(extra)) - return failure(); - insOps.push_back(extra); - } - - if (parser.parseColon() || parser.parseType(srcTy)) - return failure(); - for (size_t i = 0; i < insOps.size(); ++i) { - Type ty; - if (parser.parseComma() || parser.parseType(ty)) - return failure(); - insTypes.push_back(ty); - } - if (parser.parseRParen()) - return failure(); - } - - if (parser.parseKeyword("outs") || parser.parseLParen() || parser.parseOperand(dst)) - return failure(); - if (succeeded(parser.parseOptionalComma())) { - if (parser.parseOperand(cdst)) - return failure(); - hasCdst = true; - } - if (parser.parseColonType(dstTy)) - return failure(); - if (hasCdst && (parser.parseComma() || parser.parseType(cdstTy))) - return failure(); - if (parser.parseRParen()) - return failure(); - - if (succeeded(parser.parseOptionalKeyword("maskPattern"))) { - if (hasMask) - return parser.emitError(parser.getCurrentLocation(), - "maskPattern may only be specified once"); - if (parser.parseEqual()) - return failure(); - Attribute rawMaskAttr; - if (parser.parseAttribute(rawMaskAttr)) - return failure(); - auto mp = llvm::dyn_cast(rawMaskAttr); - if (!mp) { - return parser.emitError(parser.getCurrentLocation(), - "expected #pto.mask_pattern for maskPattern"); - } - result.addAttribute("maskPattern", mp); - hasMask = true; - } - - if 
(parser.parseOptionalAttrDict(result.attributes)) - return failure(); - - if (hasMask) { - if (!insOps.empty()) - return parser.emitError(parser.getCurrentLocation(), - "mask-pattern tgather does not take extra ins operands"); - if (hasCdst) - return parser.emitError(parser.getCurrentLocation(), - "mask-pattern tgather expects a single outs operand"); - } else if (hasCdst) { - if (insOps.empty() || - !(mlir::isa(insTypes.front()) || - mlir::isa(insTypes.front()))) - return parser.emitError(parser.getCurrentLocation(), - "compare-form tgather expects a scalar kValue operand"); - hasKValue = true; - if (insOps.size() >= 2) { - if (!isTileLikeType(insTypes[1])) - return parser.emitError(parser.getCurrentLocation(), - "compare-form tgather tmp must be tile-like"); - hasTmp = true; - } - if (insOps.size() == 3) { - return parser.emitError(parser.getCurrentLocation(), - "compare-form tgather expects at most src, kValue, tmp in ins(...)"); - } - } else { - if (!insOps.empty() && !isTileLikeType(insTypes.front())) { - return parser.emitError(parser.getCurrentLocation(), - "index-form tgather expects tile-like indices; " - "compare-form must use outs(dst, cdst)"); - } - if (!insOps.empty()) { - hasIndices = true; - if (insOps.size() >= 2) { - if (!isTileLikeType(insTypes[1])) - return parser.emitError(parser.getCurrentLocation(), - "index-form tgather tmp must be tile-like"); - hasTmp = true; - } - } - if (insOps.size() == 3) { - return parser.emitError(parser.getCurrentLocation(), - "index-form tgather expects at most src, indices, tmp in ins(...)"); - } - } - - if (parser.resolveOperand(src, srcTy, result.operands) || - parser.resolveOperand(dst, dstTy, result.operands)) - return failure(); - if (hasCdst && parser.resolveOperand(cdst, cdstTy, result.operands)) - return failure(); - if (hasIndices && parser.resolveOperand(insOps[0], insTypes[0], result.operands)) - return failure(); - if (hasTmp && parser.resolveOperand(insOps[hasIndices ? 
1 : 1], insTypes[1], result.operands)) - return failure(); - if (hasKValue && parser.resolveOperand(insOps[0], insTypes[0], result.operands)) - return failure(); - - result.addAttribute("operandSegmentSizes", - parser.getBuilder().getDenseI32ArrayAttr( - {1, 1, hasCdst ? 1 : 0, hasIndices ? 1 : 0, - hasTmp ? 1 : 0, hasKValue ? 1 : 0})); - return success(); -} - -void mlir::pto::TGatherOp::print(OpAsmPrinter &p) { - p << " ins(" << getSrc() << ", "; - if (auto mp = getMaskPatternAttr()) { - p << "{maskPattern = " << mp << "} : " << getSrc().getType(); - } else if (getCdst()) { - p << getKValue(); - if (getTmp()) { - p << ", " << getTmp(); - p << " : " << getSrc().getType() << ", " << getKValue().getType() - << ", " << getTmp().getType(); - } else { - p << " : " << getSrc().getType() << ", " << getKValue().getType(); - } - } else { - p << getIndices(); - if (getTmp()) { - p << ", " << getTmp(); - p << " : " << getSrc().getType() << ", " << getIndices().getType() - << ", " << getTmp().getType(); - } else { - p << " : " << getSrc().getType() << ", " << getIndices().getType(); - } - } - p << ") outs(" << getDst(); - if (getCdst()) - p << ", " << getCdst(); - p << " : " << getDst().getType(); - if (getCdst()) - p << ", " << getCdst().getType(); - p << ")"; - - if (getMaskPatternAttr()) { - p.printOptionalAttrDict((*this)->getAttrs(), - /*elidedAttrs=*/{"maskPattern", "operandSegmentSizes"}); - } else { - p.printOptionalAttrDict((*this)->getAttrs(), - /*elidedAttrs=*/{"operandSegmentSizes"}); - } -} - -ParseResult mlir::pto::TScatterOp::parse(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::UnresolvedOperand src, indexes, dst; - Type srcTy, idxTy, dstTy; - bool hasMask = false; - bool hasIndexes = false; - - if (parser.parseKeyword("ins") || parser.parseLParen() || - parser.parseOperand(src)) - return failure(); - - if (!succeeded(parser.parseOptionalComma())) - return parser.emitError(parser.getCurrentLocation(), - "expected ',' after src operand in 
ins(...)"); - - if (succeeded(parser.parseOptionalLBrace())) { - if (parser.parseKeyword("maskPattern") || parser.parseEqual()) - return failure(); - Attribute rawMaskAttr; - if (parser.parseAttribute(rawMaskAttr) || parser.parseRBrace()) - return failure(); - auto mp = llvm::dyn_cast(rawMaskAttr); - if (!mp) - return parser.emitError(parser.getCurrentLocation(), - "expected #pto.mask_pattern for maskPattern"); - result.addAttribute("maskPattern", mp); - hasMask = true; - if (parser.parseColonType(srcTy) || parser.parseRParen()) - return failure(); - } else { - if (parser.parseOperand(indexes)) - return failure(); - hasIndexes = true; - if (parser.parseColon() || parser.parseType(srcTy) || parser.parseComma() || - parser.parseType(idxTy) || parser.parseRParen()) - return failure(); - } - - if (parser.parseKeyword("outs") || parser.parseLParen() || - parser.parseOperand(dst) || parser.parseColonType(dstTy) || - parser.parseRParen()) - return failure(); - - if (parser.parseOptionalAttrDict(result.attributes)) - return failure(); - - if (result.attributes.get("maskPattern")) - hasMask = true; - - if (hasMask && hasIndexes) - return parser.emitError(parser.getCurrentLocation(), - "mask-pattern tscatter does not take indexes"); - if (!hasMask && !hasIndexes) - return parser.emitError(parser.getCurrentLocation(), - "expected indexes operand or maskPattern for tscatter"); - - if (parser.resolveOperand(src, srcTy, result.operands) || - parser.resolveOperand(dst, dstTy, result.operands) || - (hasIndexes && parser.resolveOperand(indexes, idxTy, result.operands))) - return failure(); - return success(); -} - -void mlir::pto::TScatterOp::print(OpAsmPrinter &p) { - p << " ins(" << getSrc() << ", "; - if (getMaskPatternAttr()) { - p << "{maskPattern = " << getMaskPatternAttr() << "} : " - << getSrc().getType(); - } else { - p << getIndexes() << " : " << getSrc().getType() << ", " - << getIndexes().getType(); - } - p << ") outs(" << getDst() << " : " << getDst().getType() << ")"; 
- p.printOptionalAttrDict((*this)->getAttrs(), - /*elidedAttrs=*/{"maskPattern"}); -} - -namespace { - -struct CommRecvClause { - OpAsmParser::UnresolvedOperand ping; - std::optional pong; - Type pingTy; - Type pongTy; -}; - -static ParseResult parseCommRecvClause(OpAsmParser &parser, - CommRecvClause &recvClause) { - if (parser.parseKeyword("recv") || parser.parseLParen() || - parser.parseOperand(recvClause.ping)) - return failure(); - if (succeeded(parser.parseOptionalComma())) { - OpAsmParser::UnresolvedOperand pong; - if (parser.parseOperand(pong)) - return failure(); - recvClause.pong = pong; - } - return parser.parseRParen(); -} - -static ParseResult parseCommCollectiveTail( - OpAsmParser &parser, OperationState &result, - ArrayRef fixedOperands, - SmallVectorImpl &fixedTypes, CommRecvClause &recvClause, - SmallVectorImpl &groupOps, - SmallVectorImpl &groupTypes, ArrayRef operandSegmentsPrefix, - ArrayRef requiredAttrs) { - if (parser.parseComma() || parser.parseKeyword("group") || parser.parseLParen()) - return failure(); - - OpAsmParser::UnresolvedOperand group; - if (parser.parseOperand(group)) - return failure(); - groupOps.push_back(group); - while (succeeded(parser.parseOptionalComma())) { - if (parser.parseOperand(group)) - return failure(); - groupOps.push_back(group); - } - - if (parser.parseRParen()) - return failure(); - - if (parser.parseColon()) - return failure(); - - for (size_t i = 0; i < fixedTypes.size(); ++i) { - if (i != 0 && parser.parseComma()) - return failure(); - if (parser.parseType(fixedTypes[i])) - return failure(); - } - if (parser.parseComma() || parser.parseType(recvClause.pingTy)) - return failure(); - if (recvClause.pong) { - if (parser.parseComma() || parser.parseType(recvClause.pongTy)) - return failure(); - } - for (size_t i = 0; i < groupOps.size(); ++i) { - Type groupTy; - if (parser.parseComma() || parser.parseType(groupTy)) - return failure(); - groupTypes.push_back(groupTy); - } - if (parser.parseRParen()) - return 
failure(); - - NamedAttrList attrs; - if (parser.parseOptionalAttrDict(attrs)) - return failure(); - for (StringRef attrName : requiredAttrs) { - if (!attrs.get(attrName)) { - return parser.emitError(parser.getCurrentLocation()) - << "expected '" << attrName << "' attribute"; - } - } - result.addAttributes(attrs); - - for (auto [operand, type] : llvm::zip_equal(fixedOperands, fixedTypes)) { - if (parser.resolveOperand(operand, type, result.operands)) - return failure(); - } - if (parser.resolveOperand(recvClause.ping, recvClause.pingTy, result.operands)) - return failure(); - if (recvClause.pong && - parser.resolveOperand(*recvClause.pong, recvClause.pongTy, result.operands)) - return failure(); - if (parser.resolveOperands(groupOps, groupTypes, parser.getCurrentLocation(), - result.operands)) - return failure(); - - SmallVector segmentSizes(operandSegmentsPrefix.begin(), - operandSegmentsPrefix.end()); - segmentSizes.push_back(static_cast(groupOps.size())); - result.addAttribute("operandSegmentSizes", - parser.getBuilder().getDenseI32ArrayAttr(segmentSizes)); - return success(); -} - -static void printCommRecvClause(OpAsmPrinter &p, Value ping, Value pong) { - p << "recv(" << ping; - if (pong) - p << ", " << pong; - p << ")"; -} - -static void printCommGroupTypes(OpAsmPrinter &p, ValueRange group) { - for (Value groupValue : group) - p << ", " << groupValue.getType(); -} - -static void printCommGroupClause(OpAsmPrinter &p, ValueRange group) { - p << "group("; - p.printOperands(group); - p << ")"; -} - -} // namespace - -ParseResult mlir::pto::TBroadcastOp::parse(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::UnresolvedOperand src; - CommRecvClause recvClause; - SmallVector groupOps; - SmallVector groupTypes; - - if (parser.parseLParen() || parser.parseOperand(src) || parser.parseComma()) - return failure(); - if (failed(parseCommRecvClause(parser, recvClause))) - return failure(); - - SmallVector fixedOperands{src}; - SmallVector fixedTypes(1); - 
if (failed(parseCommCollectiveTail(parser, result, fixedOperands, fixedTypes, - recvClause, groupOps, groupTypes, - {1, 1, recvClause.pong ? 1 : 0}, {"root"}))) - return failure(); - return success(); -} - -void mlir::pto::TBroadcastOp::print(OpAsmPrinter &p) { - p << "(" << getSrc() << ", "; - printCommRecvClause(p, getPing(), getPong()); - p << ", "; - printCommGroupClause(p, getGroup()); - p << " : " << getSrc().getType() << ", " << getPing().getType(); - if (getPong()) - p << ", " << getPong().getType(); - printCommGroupTypes(p, getGroup()); - p << ")"; - p.printOptionalAttrDict((*this)->getAttrs(), - /*elidedAttrs=*/{"operandSegmentSizes"}); -} - -ParseResult mlir::pto::CommTGatherOp::parse(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::UnresolvedOperand dst; - CommRecvClause recvClause; - SmallVector groupOps; - SmallVector groupTypes; - - if (parser.parseLParen() || parser.parseOperand(dst) || parser.parseComma()) - return failure(); - if (failed(parseCommRecvClause(parser, recvClause))) - return failure(); - - SmallVector fixedOperands{dst}; - SmallVector fixedTypes(1); - if (failed(parseCommCollectiveTail( - parser, result, fixedOperands, fixedTypes, recvClause, groupOps, - groupTypes, {1, 1, recvClause.pong ? 
1 : 0}, - {"root"}))) - return failure(); - return success(); -} - -void mlir::pto::CommTGatherOp::print(OpAsmPrinter &p) { - p << "(" << getDst() << ", "; - printCommRecvClause(p, getPing(), getPong()); - p << ", "; - printCommGroupClause(p, getGroup()); - p << " : " << getDst().getType() << ", " << getPing().getType(); - if (getPong()) - p << ", " << getPong().getType(); - printCommGroupTypes(p, getGroup()); - p << ")"; - p.printOptionalAttrDict((*this)->getAttrs(), - /*elidedAttrs=*/{"operandSegmentSizes"}); -} - -ParseResult mlir::pto::CommTScatterOp::parse(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::UnresolvedOperand src; - CommRecvClause recvClause; - SmallVector groupOps; - SmallVector groupTypes; - - if (parser.parseLParen() || parser.parseOperand(src) || parser.parseComma()) - return failure(); - if (failed(parseCommRecvClause(parser, recvClause))) - return failure(); - - SmallVector fixedOperands{src}; - SmallVector fixedTypes(1); - if (failed(parseCommCollectiveTail( - parser, result, fixedOperands, fixedTypes, recvClause, groupOps, - groupTypes, {1, 1, recvClause.pong ? 
1 : 0}, - {"root"}))) - return failure(); - return success(); -} - -void mlir::pto::CommTScatterOp::print(OpAsmPrinter &p) { - p << "(" << getSrc() << ", "; - printCommRecvClause(p, getPing(), getPong()); - p << ", "; - printCommGroupClause(p, getGroup()); - p << " : " << getSrc().getType() << ", " << getPing().getType(); - if (getPong()) - p << ", " << getPong().getType(); - printCommGroupTypes(p, getGroup()); - p << ")"; - p.printOptionalAttrDict((*this)->getAttrs(), - /*elidedAttrs=*/{"operandSegmentSizes"}); -} - -ParseResult mlir::pto::TReduceOp::parse(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::UnresolvedOperand dst, acc; - CommRecvClause recvClause; - SmallVector groupOps; - SmallVector groupTypes; - - if (parser.parseLParen() || parser.parseOperand(dst) || parser.parseComma() || - parser.parseOperand(acc) || parser.parseComma()) - return failure(); - if (failed(parseCommRecvClause(parser, recvClause))) - return failure(); - - SmallVector fixedOperands{dst, acc}; - SmallVector fixedTypes(2); - if (failed(parseCommCollectiveTail( - parser, result, fixedOperands, fixedTypes, recvClause, groupOps, - groupTypes, {1, 1, 1, recvClause.pong ? 
1 : 0}, - {"reduceOp", "root"}))) - return failure(); - return success(); -} - -void mlir::pto::TReduceOp::print(OpAsmPrinter &p) { - p << "(" << getDst() << ", " << getAcc() << ", "; - printCommRecvClause(p, getRecvPing(), getRecvPong()); - p << ", "; - printCommGroupClause(p, getGroup()); - p << " : " << getDst().getType() << ", " << getAcc().getType() << ", " - << getRecvPing().getType(); - if (getRecvPong()) - p << ", " << getRecvPong().getType(); - printCommGroupTypes(p, getGroup()); - p << ")"; - p.printOptionalAttrDict((*this)->getAttrs(), - /*elidedAttrs=*/{"operandSegmentSizes"}); -} - -ParseResult mlir::pto::MakeTensorViewOp::parse(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::UnresolvedOperand ptr; - SmallVector shapeOps; - SmallVector strideOps; - - Type resultTy; - - // %ptr - if (parser.parseOperand(ptr)) - return failure(); - - // , shape = [ ... ] - if (parser.parseComma() || parser.parseKeyword("shape") || parser.parseEqual() || - parser.parseLSquare() || - parser.parseOperandList(shapeOps) || - parser.parseRSquare()) - return failure(); - - // strides = [ ... 
] - if (parser.parseComma() || parser.parseKeyword("strides") || parser.parseEqual() || - parser.parseLSquare() || - parser.parseOperandList(strideOps) || - parser.parseRSquare()) - return failure(); - - // attr-dict - if (parser.parseOptionalAttrDict(result.attributes)) - return failure(); - - // : result-type - if (parser.parseColonType(resultTy)) - return failure(); - result.addTypes(resultTy); - - auto tvTy = llvm::dyn_cast(resultTy); - if (!tvTy) - return parser.emitError(parser.getCurrentLocation(), - "expected result type pto.tensor_view<...>"); - - Type elemTy = tvTy.getElementType(); - - Type ptrTy = mlir::pto::PtrType::get(parser.getContext(), elemTy); - - // resolve %ptr - if (parser.resolveOperand(ptr, ptrTy, result.operands)) - return failure(); - - // resolve shape/strides 为 index - Type indexTy = parser.getBuilder().getIndexType(); - if (parser.resolveOperands(shapeOps, indexTy, result.operands)) - return failure(); - if (parser.resolveOperands(strideOps, indexTy, result.operands)) - return failure(); - - auto segAttr = parser.getBuilder().getDenseI32ArrayAttr( - {1, (int32_t)shapeOps.size(), (int32_t)strideOps.size()}); - result.addAttribute("operandSegmentSizes", segAttr); - - return success(); -} - -void mlir::pto::MakeTensorViewOp::print(OpAsmPrinter &p) { - p << " " << getPtr(); - - p << ", shape = ["; - p.printOperands(getShape()); - p << "]"; - - p << ", strides = ["; - p.printOperands(getStrides()); - p << "]"; - - p.printOptionalAttrDict((*this)->getAttrs(), - /*elidedAttrs=*/{"operandSegmentSizes"}); - - p << " : " << getResult().getType(); -} - -// Layout inference helpers for make_tensor_view -static std::optional getConstIndexValue(Value v) { - if (auto c = v.getDefiningOp()) - return c.value(); - if (auto c = v.getDefiningOp()) { - if (auto ia = dyn_cast(c.getValue())) - return ia.getInt(); - } - return std::nullopt; -} - -static FailureOr -inferPartitionViewResultTypeFromSizes(mlir::pto::TensorViewType sourceType, - ValueRange sizes) { 
- if (!sourceType) - return failure(); - - if ((int64_t)sizes.size() != sourceType.getRank()) - return failure(); - - SmallVector shape; - shape.reserve(sizes.size()); - for (Value size : sizes) { - auto constSize = getConstIndexValue(size); - if (constSize && *constSize >= 0) - shape.push_back(*constSize); - else - shape.push_back(ShapedType::kDynamic); - } - - return mlir::pto::PartitionTensorViewType::get( - sourceType.getContext(), shape, sourceType.getElementType()); -} - -ParseResult mlir::pto::PartitionViewOp::parse(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::UnresolvedOperand source; - SmallVector offsets; - SmallVector sizes; - Type sourceTy; - Type resultTy; - bool hasExplicitResultTy = false; - - if (parser.parseOperand(source) || parser.parseComma() || - parser.parseKeyword("offsets") || parser.parseEqual() || - parser.parseLSquare() || parser.parseOperandList(offsets) || - parser.parseRSquare() || parser.parseComma() || - parser.parseKeyword("sizes") || parser.parseEqual() || - parser.parseLSquare() || parser.parseOperandList(sizes) || - parser.parseRSquare() || parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(sourceTy)) - return failure(); - - if (succeeded(parser.parseOptionalArrow())) { - if (parser.parseType(resultTy)) - return failure(); - hasExplicitResultTy = true; - } - - if (parser.resolveOperand(source, sourceTy, result.operands)) - return failure(); - - Type indexTy = parser.getBuilder().getIndexType(); - if (parser.resolveOperands(offsets, indexTy, result.operands) || - parser.resolveOperands(sizes, indexTy, result.operands)) - return failure(); - - auto &properties = result.getOrAddProperties(); - llvm::copy(ArrayRef( - {1, static_cast(offsets.size()), - static_cast(sizes.size())}), - properties.operandSegmentSizes.begin()); - - if (hasExplicitResultTy) { - result.addTypes(resultTy); - return success(); - } - - ValueRange allOperands(result.operands); - ValueRange sizeOperands = - 
allOperands.slice(1 + offsets.size(), sizes.size()); - auto inferredResultType = inferPartitionViewResultTypeFromSizes( - dyn_cast(sourceTy), sizeOperands); - if (failed(inferredResultType)) { - return parser.emitError(parser.getCurrentLocation(), - "failed to infer pto.partition_view result type"); - } - - result.addTypes(*inferredResultType); - return success(); -} - -void mlir::pto::PartitionViewOp::print(OpAsmPrinter &printer) { - printer << " " << getSource() << ", offsets = ["; - printer.printOperands(getOffsets()); - printer << "], sizes = ["; - printer.printOperands(getSizes()); - printer << "]"; - printer.printOptionalAttrDict((*this)->getAttrs(), - /*elidedAttrs=*/{"operandSegmentSizes"}); - printer << " : " << getSource().getType(); - - auto inferredResultType = inferPartitionViewResultTypeFromSizes( - dyn_cast(getSource().getType()), getSizes()); - if (succeeded(inferredResultType) && *inferredResultType == getResult().getType()) - return; - - printer << " -> " << getResult().getType(); -} - -static std::optional getConstantIntegerValueEx( - Value v, bool includeIndexAndIntOpsInConstFold) { - if (includeIndexAndIntOpsInConstFold) { - if (auto c = v.getDefiningOp()) - return c.value(); - if (auto c = v.getDefiningOp()) - return c.value(); - } - if (auto c = v.getDefiningOp()) { - if (auto ia = dyn_cast(c.getValue())) - return ia.getInt(); - } - return std::nullopt; -} - -static LogicalResult verifyNonNegativeIndexRowCol( - Operation &op, Value indexRow, Value indexCol, - bool includeIndexAndIntOpsInConstFold) { - if (!indexRow.getType().isIndex() || !indexCol.getType().isIndex()) - return op.emitOpError("expects indexRow and indexCol to be index type"); - auto row = - getConstantIntegerValueEx(indexRow, includeIndexAndIntOpsInConstFold); - auto col = - getConstantIntegerValueEx(indexCol, includeIndexAndIntOpsInConstFold); - if (row && *row < 0) - return op.emitOpError("expects indexRow to be non-negative"); - if (col && *col < 0) - return 
op.emitOpError("expects indexCol to be non-negative"); - return success(); -} - -static LogicalResult verifyExtractStaticBoundsCommon( - Operation &op, Value indexRow, Value indexCol, Type srcTy, Type dstTy, - bool includeIndexAndIntOpsInConstFold) { - auto row = - getConstantIntegerValueEx(indexRow, includeIndexAndIntOpsInConstFold); - auto col = - getConstantIntegerValueEx(indexCol, includeIndexAndIntOpsInConstFold); - auto srcShape = getShapeVec(srcTy); - auto dstShape = getShapeVec(dstTy); - if (srcShape.size() != 2 || dstShape.size() != 2) - return op.emitOpError("expects src and dst to be rank-2 tile_buf"); - if (row && srcShape[0] != ShapedType::kDynamic && - dstShape[0] != ShapedType::kDynamic && - *row + dstShape[0] > srcShape[0]) - return op.emitOpError("expects indexRow + dst.rows <= src.rows"); - if (col && srcShape[1] != ShapedType::kDynamic && - dstShape[1] != ShapedType::kDynamic && - *col + dstShape[1] > srcShape[1]) - return op.emitOpError("expects indexCol + dst.cols <= src.cols"); - return success(); -} - -static LogicalResult verifyInsertStaticBoundsCommon( - Operation &op, Value indexRow, Value indexCol, Type srcTy, Type dstTy, - bool includeIndexAndIntOpsInConstFold) { - auto row = - getConstantIntegerValueEx(indexRow, includeIndexAndIntOpsInConstFold); - auto col = - getConstantIntegerValueEx(indexCol, includeIndexAndIntOpsInConstFold); - auto srcShape = getValidShapeVec(srcTy); - auto dstShape = getShapeVec(dstTy); - if (srcShape.size() != 2 || dstShape.size() != 2) - return op.emitOpError("expects src and dst to be rank-2 tile_buf"); - if (row && srcShape[0] != ShapedType::kDynamic && - dstShape[0] != ShapedType::kDynamic && - *row + srcShape[0] > dstShape[0]) - return op.emitOpError("expects indexRow + src.rows <= dst.rows"); - if (col && srcShape[1] != ShapedType::kDynamic && - dstShape[1] != ShapedType::kDynamic && - *col + srcShape[1] > dstShape[1]) - return op.emitOpError("expects indexCol + src.cols <= dst.cols"); - return success(); 
-} - -static unsigned getElemByteSize(Type ty) { - return getPTOStorageElemByteSize(ty); -} - -static LogicalResult verifyTileBufLayoutConstraints(Operation *op, - pto::TileBufType tb, - StringRef name) { - auto shape = tb.getShape(); - if (shape.size() != 2) - return op->emitOpError() << "expects " << name << " to be rank-2"; - - int64_t rows = shape[0]; - int64_t cols = shape[1]; - if (rows != ShapedType::kDynamic && rows <= 0) - return op->emitOpError() << "expects " << name << " rows to be positive"; - if (cols != ShapedType::kDynamic && cols <= 0) - return op->emitOpError() << "expects " << name << " cols to be positive"; - - unsigned elemBytes = getElemByteSize(tb.getElementType()); - if (elemBytes == 0) - return op->emitOpError() << "expects " << name - << " element type to have a byte size"; - - auto cfg = tb.getConfigAttr(); - if (!cfg) - cfg = TileBufConfigAttr::getDefault(tb.getContext()); - auto readBLayout = [](Attribute attr, int32_t &out) -> bool { - if (auto layout = dyn_cast_or_null(attr)) { - out = static_cast(layout.getValue()); - return true; - } - if (auto value = dyn_cast_or_null(attr)) { - out = static_cast(value.getInt()); - return true; - } - return false; - }; - auto readSLayout = [](Attribute attr, int32_t &out) -> bool { - if (auto layout = dyn_cast_or_null(attr)) { - out = static_cast(layout.getValue()); - return true; - } - if (auto value = dyn_cast_or_null(attr)) { - out = static_cast(value.getInt()); - return true; - } - return false; - }; - int32_t blayout = 0; - int32_t slayout = 0; - if (!readBLayout(cfg.getBLayout(), blayout) || - !readSLayout(cfg.getSLayout(), slayout)) - return op->emitOpError() << "expects " << name - << " to have concrete tile layout attributes"; - constexpr int64_t kAlignedBytes = 32; - - auto checkByteAlignment = [&](int64_t dim, StringRef layoutName, - StringRef byteExpr) -> LogicalResult { - if (dim == ShapedType::kDynamic) - return success(); - int64_t bytes = dim * static_cast(elemBytes); - if (bytes % 
kAlignedBytes == 0) - return success(); - return op->emitOpError() - << "expects " << name << " " << layoutName - << " none_box tile " << byteExpr - << " to be 32-byte aligned, but got " << bytes << " bytes"; - }; - - if (slayout == static_cast(SLayout::NoneBox)) { - if (blayout == static_cast(BLayout::RowMajor)) - return checkByteAlignment(cols, "row-major", - "row byte size (cols * sizeof(dtype))"); - return checkByteAlignment(rows, "col-major", - "column byte size (rows * sizeof(dtype))"); - } - - int64_t innerRows = 0; - int64_t innerCols = 0; - int32_t fractal = static_cast(cfg.getSFractalSize().getInt()); - switch (fractal) { - case 1024: - innerRows = 16; - innerCols = 16; - break; - case 32: - innerRows = 16; - innerCols = 2; - break; - case 512: - if (kAlignedBytes % elemBytes != 0) - return op->emitOpError() << "expects " << name - << " element byte size to divide 32 for boxed " - "fractal-512 tile layout"; - if (slayout == static_cast(SLayout::RowMajor)) { - innerRows = 16; - innerCols = kAlignedBytes / static_cast(elemBytes); - } else if (slayout == static_cast(SLayout::ColMajor)) { - innerRows = kAlignedBytes / static_cast(elemBytes); - innerCols = 16; - } - break; - default: - break; - } - if (innerRows <= 0 || innerCols <= 0) - return op->emitOpError() << "expects " << name - << " to use a supported boxed tile layout"; - - auto loc = getPTOMemorySpaceEnum(tb); - bool allowUnalignedRows = - (loc && *loc == pto::AddressSpace::VEC) || fractal == 32 || rows == 1; - if (!allowUnalignedRows && rows != ShapedType::kDynamic && - rows % innerRows != 0) - return op->emitOpError() - << "expects " << name - << " boxed tile rows to be a multiple of innerRows (" << innerRows - << "), but got " << rows; - if (cols != ShapedType::kDynamic && cols % innerCols != 0) - return op->emitOpError() - << "expects " << name - << " boxed tile cols to be a multiple of innerCols (" << innerCols - << "), but got " << cols; - - return success(); -} - -[[maybe_unused]] static bool 
isSupportedLoadStoreElemTypeA2A3(Type ty) { - if (ty.isF16() || ty.isBF16() || ty.isF32()) - return true; - if (auto it = dyn_cast(ty)) { - unsigned width = it.getWidth(); - return width == 8 || width == 16 || width == 32 || width == 64; - } - return false; -} - -static bool isSupportedGatherElemTypeA2A3(Type ty) { - if (ty.isF16() || ty.isF32()) - return true; - if (auto it = dyn_cast(ty)) { - unsigned width = it.getWidth(); - return width == 16 || width == 32; - } - return false; -} - -static bool isSupportedGatherElemTypeA5(Type ty) { - if (isSupportedGatherElemTypeA2A3(ty) || ty.isBF16()) - return true; - if (auto ft = dyn_cast(ty)) { - unsigned width = ft.getWidth(); - return width == 8; - } - if (auto it = dyn_cast(ty)) - return it.getWidth() == 8 || it.getWidth() == 16 || it.getWidth() == 32; - return false; -} - -static std::optional -inferLayout(ArrayRef shape, ArrayRef strides, - unsigned elemBytes) { - if (shape.size() != strides.size() || elemBytes == 0) - return std::nullopt; - - // NZ / fractal: rank>=5, check middle dims (sh3/sh4/sh5 per spec) - if (shape.size() >= 5) { - int64_t sh3 = shape[2], sh4 = shape[3], sh5 = shape[4]; - int64_t st4 = strides[3], st5 = strides[4]; - bool alignMatch = (sh3 == 16) && (sh3 * sh4 * elemBytes == 512); - bool strideMatch = (st5 == 1) && (st4 == sh5); - if (alignMatch && strideMatch) - return mlir::pto::Layout::NZ; - } - - // ND: row-major contiguous - bool isRowMajor = true; - for (int i = 0, e = (int)shape.size() - 1; i < e; ++i) { - if (strides[i] != strides[i + 1] * shape[i + 1]) { - isRowMajor = false; - break; - } - } - if (isRowMajor && strides.back() == 1) - return mlir::pto::Layout::ND; - - // DN: col-major - bool isColMajor = true; - for (int i = 0, e = (int)shape.size() - 1; i < e; ++i) { - if (strides[i + 1] != strides[i] * shape[i]) { - isColMajor = false; - break; - } - } - if (isColMajor && strides.front() == 1) - return mlir::pto::Layout::DN; - - return mlir::pto::Layout::ND; // fallback -} - -static 
std::optional getLogicalViewLayout(Value value) { - if (!value) - return std::nullopt; - if (auto part = value.getDefiningOp()) - return getLogicalViewLayout(part.getSource()); - if (auto make = value.getDefiningOp()) { - auto tvTy = dyn_cast(make.getResult().getType()); - if (!tvTy) - return std::nullopt; - SmallVector shape(tvTy.getShape().begin(), tvTy.getShape().end()); - SmallVector strides; - strides.reserve(make.getStrides().size()); - for (Value stride : make.getStrides()) { - auto cst = getConstIndexValue(stride); - if (!cst) - return std::nullopt; - strides.push_back(*cst); - } - return inferLayout(shape, strides, getElemByteSize(tvTy.getElementType())); - } - return std::nullopt; -} - -static std::optional getTileBufLogicalLayout(pto::TileBufType type) { - if (!type) - return std::nullopt; - int32_t sl = type.getSLayoutValueI32(); - int32_t bl = type.getBLayoutValueI32(); - if (sl != static_cast(pto::SLayout::NoneBox)) - return pto::Layout::NZ; - if (bl == static_cast(pto::BLayout::RowMajor)) - return pto::Layout::ND; - if (bl == static_cast(pto::BLayout::ColMajor)) - return pto::Layout::DN; - return std::nullopt; -} - -static bool isRowMajorTileBuf(Type ty) { - auto tb = mlir::dyn_cast(ty); - return tb && tb.getBLayoutValueI32() == static_cast(pto::BLayout::RowMajor); -} - -static LogicalResult verifyRowReductionSrcLayout(Operation *op, Type ty, - StringRef name) { - if (failed(verifyTileBufCommon(op, ty, name))) - return failure(); - auto as = getPTOMemorySpaceEnum(ty); - if (!as || *as != pto::AddressSpace::VEC) - return op->emitOpError() << "expects " << name << " to be in the vec address space"; - if (auto tb = dyn_cast(ty)) { - if (tb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor)) - return op->emitOpError() << "expects " << name << " to use the row_major blayout"; - } - if (auto mr = dyn_cast(ty)) - (void)mr; - if (auto tb = dyn_cast(ty)) { - if (tb.getSLayoutValueI32() != static_cast(pto::SLayout::NoneBox)) - return op->emitOpError() 
<< "expects " << name - << " to use the none_box slayout"; - } - if (auto tb = dyn_cast(ty)) { - auto layout = getTileBufLogicalLayout(tb); - if (layout && *layout != pto::Layout::ND) - return op->emitOpError() << "expects " << name - << " to use an ND-style tile layout"; - } - return success(); -} - -static LogicalResult verifyRowReductionDstLayout(Operation *op, Type ty, - StringRef name) { - if (failed(verifyTileBufCommon(op, ty, name))) - return failure(); - auto as = getPTOMemorySpaceEnum(ty); - if (!as || *as != pto::AddressSpace::VEC) - return op->emitOpError() << "expects " << name << " to be in the vec address space"; - if (auto tb = dyn_cast(ty)) { - if (tb.getSLayoutValueI32() != static_cast(pto::SLayout::NoneBox)) - return op->emitOpError() << "expects " << name - << " to use the none_box slayout"; - } - if (auto tb = dyn_cast(ty)) { - if (tb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor) && - tb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor)) - return op->emitOpError() << "expects " << name - << " to use the row_major or col_major blayout"; - } - if (auto mr = dyn_cast(ty)) - (void)mr; - if (auto tb = dyn_cast(ty)) { - auto layout = getTileBufLogicalLayout(tb); - if (layout && *layout == pto::Layout::DN) { - auto shape = getShapeVec(ty); - if (shape.size() == 2 && shape[1] != ShapedType::kDynamic && shape[1] != 1) - return op->emitOpError() << "expects DN-style " << name - << " to have shape[1] == 1"; - return success(); - } - if (layout && *layout == pto::Layout::ND) - return success(); - if (layout) - return op->emitOpError() << "expects " << name - << " to use a DN-style column vector tile or legacy ND-style tile"; - } - return success(); - auto valid = getValidShapeVec(ty); - if (valid.size() != 2) - return op->emitOpError() << "expects " << name << " to have rank-2 valid_shape"; - if (valid[1] != ShapedType::kDynamic && valid[1] != 1) - return op->emitOpError() << "expects " << name << " valid_shape[1] to be 1"; - return 
success(); -} - -static LogicalResult verifyRowReductionValidRegion(Operation *op, Type srcTy, - Type dstTy) { - auto srcValid = getValidShapeVec(srcTy); - auto dstValid = getValidShapeVec(dstTy); - if (srcValid.size() != 2 || dstValid.size() != 2) - return op->emitOpError("expects src and dst to have rank-2 valid_shape"); - if (srcValid[0] != ShapedType::kDynamic && srcValid[0] == 0) - return op->emitOpError("expects src valid_shape[0] to be non-zero"); - if (srcValid[1] != ShapedType::kDynamic && srcValid[1] == 0) - return op->emitOpError("expects src valid_shape[1] to be non-zero"); - if (srcValid[0] != ShapedType::kDynamic && dstValid[0] != ShapedType::kDynamic && - srcValid[0] != dstValid[0]) - return op->emitOpError("expects src and dst to have the same valid_shape[0]"); - if (dstValid[1] != ShapedType::kDynamic && dstValid[1] != 1) - return op->emitOpError("expects dst valid_shape[1] to be 1"); - return success(); -} - -static bool isSupportedRowReductionElemType(Type elem) { - return elem.isInteger(16) || elem.isInteger(32) || elem.isF16() || - elem.isF32(); -} - -static LogicalResult verifyTRowReductionNoTmpCommon(Operation *op, Type srcTy, - Type dstTy, - StringRef elemTypeError) { - if (failed(verifyRowReductionSrcLayout(op, srcTy, "src")) || - failed(verifyRowReductionDstLayout(op, dstTy, "dst"))) - return failure(); - if (getElemTy(srcTy) != getElemTy(dstTy)) - return op->emitOpError("expects src and dst to have the same element type"); - if (failed(verifyRowReductionValidRegion(op, srcTy, dstTy))) - return failure(); - if (!isSupportedRowReductionElemType(getElemTy(srcTy))) - return op->emitOpError(elemTypeError); - return success(); -} - -static LogicalResult verifyTRowReductionWithTmpCommon(Operation *op, Type srcTy, - Type tmpTy, Type dstTy, - StringRef elemTypeError) { - if (failed(verifyRowReductionSrcLayout(op, srcTy, "src")) || - failed(verifyVecTileCommon(op, tmpTy, "tmp")) || - failed(verifyRowReductionDstLayout(op, dstTy, "dst"))) - return 
failure(); - if (failed(verifyTileBufSameElemType(op, srcTy, tmpTy, "src", "tmp")) || - failed(verifyTileBufSameValidShape(op, srcTy, tmpTy, "src", "tmp"))) - return failure(); - if (getElemTy(srcTy) != getElemTy(dstTy)) - return op->emitOpError("expects src and dst to have the same element type"); - if (failed(verifyRowReductionValidRegion(op, srcTy, dstTy))) - return failure(); - if (!isSupportedRowReductionElemType(getElemTy(srcTy))) - return op->emitOpError(elemTypeError); - return success(); -} - -static LogicalResult verifyTRowArgReductionCommon(Operation *op, Type srcTy, - Type tmpTy, Type dstTy) { - if (failed(verifyRowReductionSrcLayout(op, srcTy, "src")) || - failed(verifyVecTileCommon(op, tmpTy, "tmp")) || - failed(verifyRowReductionDstLayout(op, dstTy, "dst"))) - return failure(); - if (failed(verifyTileBufSameElemType(op, srcTy, tmpTy, "src", "tmp")) || - failed(verifyTileBufSameValidShape(op, srcTy, tmpTy, "src", "tmp"))) - return failure(); - if (failed(verifyRowReductionValidRegion(op, srcTy, dstTy))) - return failure(); - Type srcElem = getElemTy(srcTy); - if (!isSupportedRowReductionElemType(srcElem)) - return op->emitOpError("expects src element type to be i16/i32/f16/f32"); - auto dstInt = dyn_cast(getElemTy(dstTy)); - if (!dstInt || dstInt.getWidth() != 32) - return op->emitOpError("expects dst element type to be i32 or ui32"); - return success(); -} - -static LogicalResult verifyNDStyleVecTile(Operation *op, Type ty, StringRef name) { - if (failed(verifyTileBufCommon(op, ty, name))) - return failure(); - auto as = getPTOMemorySpaceEnum(ty); - if (!as || *as != pto::AddressSpace::VEC) - return op->emitOpError() << "expects " << name << " to be in the vec address space"; - if (auto tb = dyn_cast(ty)) { - if (tb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor)) - return op->emitOpError() << "expects " << name << " to use the row_major blayout"; - if (tb.getSLayoutValueI32() != static_cast(pto::SLayout::NoneBox)) - return 
op->emitOpError() << "expects " << name << " to use the none_box slayout"; - } - return success(); -} - -static LogicalResult verifyColReductionValidRegion(Operation *op, Type srcTy, - Type dstTy, - bool requireNonZeroSrc) { - auto srcValid = getValidShapeVec(srcTy); - auto dstValid = getValidShapeVec(dstTy); - if (srcValid.size() != 2 || dstValid.size() != 2) - return op->emitOpError("expects src and dst to have rank-2 valid_shape"); - if (requireNonZeroSrc) { - if (srcValid[0] != ShapedType::kDynamic && srcValid[0] == 0) - return op->emitOpError("expects src valid_shape[0] to be non-zero"); - if (srcValid[1] != ShapedType::kDynamic && srcValid[1] == 0) - return op->emitOpError("expects src valid_shape[1] to be non-zero"); - } - if (srcValid[1] != ShapedType::kDynamic && dstValid[1] != ShapedType::kDynamic && - srcValid[1] != dstValid[1]) - return op->emitOpError("expects src and dst to have the same valid_shape[1]"); - return success(); -} - -static LogicalResult verifyColArgReductionDstLayout(Operation *op, Type ty, - StringRef name) { - if (failed(verifyNDStyleVecTile(op, ty, name))) - return failure(); - auto valid = getValidShapeVec(ty); - if (valid.size() != 2) - return op->emitOpError() << "expects " << name - << " to have rank-2 valid_shape"; - if (valid[0] != ShapedType::kDynamic && valid[0] != 1) - return op->emitOpError() << "expects " << name - << " valid_shape[0] to be 1"; - return success(); -} - -static std::optional getConstantIntegerValue(Value value) { - if (!value) - return std::nullopt; - if (auto arithCst = value.getDefiningOp()) { - if (auto intAttr = dyn_cast(arithCst.getValue())) - return intAttr.getInt(); - } - return std::nullopt; -} - -LogicalResult mlir::pto::MakeTensorViewOp::verify() { - auto tvTy = dyn_cast(getResult().getType()); - if (!tvTy) - return emitOpError("result must be pto.tensor_view<...>"); - - auto pty = dyn_cast(getPtr().getType()); - if (!pty) - return emitOpError("ptr operand must be !pto.ptr<...>"); - - if 
(pty.getElementType() != tvTy.getElementType()) - return emitOpError() << "ptr element type must match tensor_view element type, but got ptr=" - << pty.getElementType() << " view=" << tvTy.getElementType(); - - int64_t rank = tvTy.getRank(); - - if ((int64_t)getShape().size() != rank || (int64_t)getStrides().size() != rank) - return emitOpError() << "shape/strides operand counts must match tensor_view rank=" - << rank; - - // Detect dynamic shape/stride. - bool hasDynamicShape = llvm::any_of(tvTy.getShape(), [](int64_t v) { - return v == ShapedType::kDynamic; - }); - bool hasDynamicStride = llvm::any_of(getStrides(), [](Value s) { - return !getConstIndexValue(s).has_value(); - }); - - auto layoutAttr = getLayoutAttr(); - - // 1) Dynamic shape/stride without explicit layout: warn and keep going. - if ((hasDynamicShape || hasDynamicStride) && !layoutAttr) { - return success(); - } - - // 2) Static shape/stride with explicit layout: verify correctness. - bool allStaticStride = true; - SmallVector strideInts; - strideInts.reserve(getStrides().size()); - for (Value s : getStrides()) { - auto val = getConstIndexValue(s); - if (!val) { - allStaticStride = false; - break; - } - strideInts.push_back(*val); - } - - bool allStaticShape = - llvm::none_of(tvTy.getShape(), [](int64_t v) { return v == ShapedType::kDynamic; }); - - if (layoutAttr && allStaticShape && allStaticStride) { - SmallVector shapeInts(tvTy.getShape().begin(), tvTy.getShape().end()); - if (auto inferred = inferLayout(shapeInts, strideInts, - getElemByteSize(tvTy.getElementType()))) { - (void)inferred; - } - } - - return success(); -} - -LogicalResult mlir::pto::PartitionViewOp::verify() { - auto srcTy = dyn_cast(getSource().getType()); - auto resTy = dyn_cast(getResult().getType()); - if (!srcTy || !resTy) - return emitOpError("expects tensor_view source and partition_tensor_view result"); - - if (srcTy.getElementType() != resTy.getElementType()) - return emitOpError() << "element type mismatch between 
source and result: src=" - << srcTy.getElementType() << " result=" - << resTy.getElementType(); - - int64_t srcRank = srcTy.getRank(); - if ((int64_t)getOffsets().size() != srcRank) - return emitOpError() << "offset count (" << getOffsets().size() - << ") must match source rank (" << srcRank << ")"; - - if ((int64_t)getSizes().size() != srcRank) - return emitOpError() << "size count (" << getSizes().size() - << ") must match source rank (" << srcRank << ")"; - - ArrayRef srcShape = srcTy.getShape(); - ArrayRef resShape = resTy.getShape(); - bool sameRank = resTy.getRank() == srcRank; - - for (int64_t i = 0; i < srcRank; ++i) { - auto offVal = getConstIndexValue(getOffsets()[i]); - auto sizeVal = getConstIndexValue(getSizes()[i]); - - if (offVal && *offVal < 0) - return emitOpError() << "offset at dim " << i - << " must be non-negative, got " << *offVal; - - if (sizeVal && *sizeVal <= 0) - return emitOpError() << "size at dim " << i - << " must be positive, got " << *sizeVal; - - if (sameRank && sizeVal) { - int64_t resDim = resShape[i]; - if (resDim != ShapedType::kDynamic && *sizeVal != resDim) - return emitOpError() << "size/result mismatch at dim " << i - << ": size operand=" << *sizeVal - << " result type dim=" << resDim; - } - - int64_t srcDim = srcShape[i]; - if (srcDim == ShapedType::kDynamic) - continue; - - if (sizeVal && *sizeVal > srcDim) - return emitOpError() << "size at dim " << i << " (" << *sizeVal - << ") exceeds static source dim (" << srcDim << ")"; - - if (offVal && sizeVal && (*offVal + *sizeVal > srcDim)) - return emitOpError() << "offset+size at dim " << i << " (" - << (*offVal + *sizeVal) - << ") exceeds static source dim (" << srcDim << ")"; - } - - return success(); -} - -LogicalResult mlir::pto::AddPtrOp::verify() { - Value ptr = getOperation()->getOperand(0); - Value result = getOperation()->getResult(0); - - auto ptrTy = dyn_cast(ptr.getType()); - if (!ptrTy) - return emitOpError("ptr operand must be !pto.ptr<...>"); - - auto resTy = 
dyn_cast(result.getType()); - if (!resTy) - return emitOpError("result must be !pto.ptr<...>"); - - if (ptrTy != resTy) - return emitOpError("result type must match ptr operand type"); - - return success(); -} - -static LogicalResult verifyPtrLikeForAddressCast(Operation *op, Type type, - StringRef name) { - if (isa(type)) - return success(); - - auto memTy = dyn_cast(type); - if (!memTy) - return op->emitOpError() - << "expects " << name << " to be !pto.ptr<...> or a GM memref"; - - if (memTy.getRank() != 1) - return op->emitOpError() - << "expects lowered memref " << name << " to be rank-1"; - - if (!isGmAddressSpaceAttr(memTy.getMemorySpace())) - return op->emitOpError() - << "expects lowered memref " << name << " to use GM address space"; - - return success(); -} - -static Type getPointerLikeElementType(Type type) { - if (auto ptrTy = dyn_cast(type)) - return ptrTy.getElementType(); - if (auto memTy = dyn_cast(type)) - return memTy.getElementType(); - return Type(); -} - -static bool isEmitCSupportedScalarType(Type type) { - if (!type) - return false; - if (type.isF16() || type.isBF16() || type.isF32() || type.isF64()) - return true; - if (auto intTy = dyn_cast(type)) - return intTy.getWidth() == 8 || intTy.getWidth() == 16 || - intTy.getWidth() == 32 || intTy.getWidth() == 64; - if (mlir::pto::isPTOFloat8Type(type)) - return true; - if (isa(type)) - return true; - return false; -} - -LogicalResult mlir::pto::PtrToIntOp::verify() { - Type resultTy = getResult().getType(); - auto intTy = dyn_cast(resultTy); - if (!intTy || intTy.getWidth() != 64) - return emitOpError("result must be i64"); - - return verifyPtrLikeForAddressCast(getOperation(), getPtr().getType(), - "ptr operand"); -} - -LogicalResult mlir::pto::IntToPtrOp::verify() { - auto addrTy = dyn_cast(getAddr().getType()); - if (!addrTy || addrTy.getWidth() != 64) - return emitOpError("address operand must be i64"); - - if (failed(verifyPtrLikeForAddressCast(getOperation(), getResult().getType(), - 
"result"))) - return failure(); - - Type dstElem = getPointerLikeElementType(getResult().getType()); - if (!isEmitCSupportedScalarType(dstElem)) - return emitOpError("result element type is not supported by EmitC: ") - << dstElem; - - return success(); -} - -LogicalResult mlir::pto::LocalArrayGetOp::verify() { - auto arrayTy = getArray().getType(); - int64_t rank = arrayTy.getRank(); - int64_t numIdx = static_cast(getIndices().size()); - if (numIdx != rank) - return emitOpError() << "expects " << rank - << " indices for !pto.local_array of rank " << rank - << ", got " << numIdx; - if (getResult().getType() != arrayTy.getElementType()) - return emitOpError() - << "result type " << getResult().getType() - << " does not match array element type " - << arrayTy.getElementType(); - return success(); -} - -LogicalResult mlir::pto::LocalArraySetOp::verify() { - auto arrayTy = getArray().getType(); - int64_t rank = arrayTy.getRank(); - int64_t numIdx = static_cast(getIndices().size()); - if (numIdx != rank) - return emitOpError() << "expects " << rank - << " indices for !pto.local_array of rank " << rank - << ", got " << numIdx; - if (getValue().getType() != arrayTy.getElementType()) - return emitOpError() << "value type " << getValue().getType() - << " does not match array element type " - << arrayTy.getElementType(); - return success(); -} - - - - -void PTODialect::initialize() { - addTypes< -#define GET_TYPEDEF_LIST -#include "PTO/IR/PTOTypeDefs.cpp.inc" - >(); - - addOperations< -#define GET_OP_LIST -#include "PTO/IR/PTOOps.cpp.inc" - >(); - - addAttributes< -#define GET_ATTRDEF_LIST -#include "PTO/IR/PTOAttrs.cpp.inc" - >(); -} - - -AddressSpaceAttr mlir::pto::getPTOAddressSpaceAttr(Type type) { - auto memRefType = dyn_cast(type); - if (!memRefType) - return {}; - auto scopeAttr = dyn_cast(memRefType.getMemorySpace()); - if (!scopeAttr) - return {}; - return scopeAttr; -} - -bool mlir::pto::isScalarPtrOrMemRef(Type type) { - if (auto pty = dyn_cast(type)) - return 
true; - if (auto memTy = dyn_cast(type)) - return isGmAddressSpaceAttr(memTy.getMemorySpace()); - return false; -} - -bool mlir::pto::hasExplicitPTOEntryAttr(func::FuncOp func) { - return func && (func->hasAttrOfType(kPTOEntryAttrName) || - func->hasAttrOfType(kLegacyHACCEntryAttrName)); -} - -static constexpr StringLiteral kEffectivePTOEntryAttrName = - "pto.internal.entry"; - -static SmallVector getPTOFunctionDefinitions(ModuleOp module) { - SmallVector defs; - if (!module) - return defs; - for (auto func : module.getOps()) { - if (!func.isDeclaration()) - defs.push_back(func); - } - return defs; -} - -bool mlir::pto::isPTOEntryFunction(func::FuncOp func) { - if (!func || func.isDeclaration()) - return false; - if (auto attr = func->getAttrOfType(kEffectivePTOEntryAttrName)) - return attr.getValue(); - if (hasExplicitPTOEntryAttr(func)) - return true; - - ModuleOp module = func->getParentOfType(); - if (!module) - return false; - SmallVector defs = getPTOFunctionDefinitions(module); - return defs.size() == 1 && defs.front() == func; -} - -LogicalResult mlir::pto::validatePTOEntryFunctions(ModuleOp module) { - if (!module) - return success(); - - for (auto func : module.getOps()) { - if (!hasExplicitPTOEntryAttr(func)) - continue; - if (func.isDeclaration()) { - return func.emitOpError() - << "`" << kPTOEntryAttrName - << "` is only valid on function definitions"; - } - } - - for (auto func : module.getOps()) { - if (!isPTOEntryFunction(func)) - continue; - if (func.getFunctionType().getNumResults() != 0) { - return func.emitOpError() - << "PTO entry functions must return void"; - } - } - return success(); -} - -void mlir::pto::annotatePTOEntryFunctions(ModuleOp module) { - if (!module) - return; - - SmallVector defs = getPTOFunctionDefinitions(module); - for (auto func : module.getOps()) - func->removeAttr(kEffectivePTOEntryAttrName); - - if (defs.empty()) - return; - if (defs.size() == 1) { - defs.front()->setAttr(kEffectivePTOEntryAttrName, - 
BoolAttr::get(module.getContext(), true)); - return; - } - - for (auto func : defs) { - func->setAttr(kEffectivePTOEntryAttrName, - BoolAttr::get(module.getContext(), - hasExplicitPTOEntryAttr(func))); - } -} - -//===----------------------------------------------------------------------===// -// PTO Load/Store/Addf (non-DPS polymorphic) verification + inference. -// - If operands are memref/tensor: verify strictly. -// - Otherwise (tile_view/tile etc): accept (so old IR can still parse). -//===----------------------------------------------------------------------===// - -[[maybe_unused]] static LogicalResult verifyMemrefToTensorLoad(Operation *op, Value src, Value res) { - auto mr = dyn_cast(src.getType()); - auto rt = dyn_cast(res.getType()); - if (!mr) - return success(); // non-memref case: don't block old IR - if (!rt) - return op->emitOpError("when src is memref, result must be ranked tensor"); - - if (mr.getElementType() != rt.getElementType()) - return op->emitOpError() << "memref/tensor element type mismatch: memref=" - << mr.getElementType() << " tensor=" << rt.getElementType(); - - if (mr.getRank() != rt.getRank()) - return op->emitOpError() << "rank mismatch: memref rank=" << mr.getRank() - << " tensor rank=" << rt.getRank(); - - if (mr.hasStaticShape()) { - if (!rt.hasStaticShape()) - return op->emitOpError("memref has static shape but result tensor is not static"); - if (mr.getShape() != rt.getShape()) - return op->emitOpError() << "shape mismatch: memref=" << mr << " tensor=" << rt; - } else { - // For dynamic memref dims: if tensor dim is static, allow it; if it's dynamic too, also fine. - // We only reject when a memref static dim conflicts with tensor static dim. 
- for (int64_t i = 0; i < mr.getRank(); ++i) { - int64_t md = mr.getDimSize(i); - int64_t td = rt.getDimSize(i); - if (md != ShapedType::kDynamic && td != ShapedType::kDynamic && md != td) - return op->emitOpError() << "dim mismatch at " << i << ": memref=" << md << " tensor=" << td; - } - } - return success(); -} - -[[maybe_unused]] static LogicalResult verifyMemrefTensorStore(Operation *op, Value dst, Value src) { - auto mr = dyn_cast(dst.getType()); - if (!mr) - return success(); // non-memref case: old tile IR allowed - auto rt = dyn_cast(src.getType()); - if (!rt) - return op->emitOpError("when dst is memref, src must be ranked tensor"); - - if (mr.getElementType() != rt.getElementType()) - return op->emitOpError() << "memref/tensor element type mismatch: memref=" - << mr.getElementType() << " tensor=" << rt.getElementType(); - - if (mr.getRank() != rt.getRank()) - return op->emitOpError() << "rank mismatch: memref rank=" << mr.getRank() - << " tensor rank=" << rt.getRank(); - - for (int64_t i = 0; i < mr.getRank(); ++i) { - int64_t md = mr.getDimSize(i); - int64_t td = rt.getDimSize(i); - if (md != ShapedType::kDynamic && td != ShapedType::kDynamic && md != td) - return op->emitOpError() << "dim mismatch at " << i << ": memref=" << md << " tensor=" << td; - } - return success(); -} - -LogicalResult AllocTileOp::verify() { - auto ty = getResult().getType(); // TileBufType - - if (failed(verifyTileBufLayoutConstraints(*this, ty, "result"))) - return failure(); - - // op 上有没有传 operands - bool hasVR = getValidRow() != nullptr; - bool hasVC = getValidCol() != nullptr; - - // type 上的 validShape - auto vs = ty.getValidShape(); - if (vs.size() != 2) - return emitOpError("result tile_buf must have rank-2 validShape"); - - // TileBuf valid dims use a negative sentinel (e.g. '?' / -1). Be robust to - // any negative value (some code may materialize MLIR dynamic sentinels). - bool needVR = (vs[0] < 0); - bool needVC = (vs[1] < 0); - - // 你要求的:v_row=?, v_col=? 
时必须同时给两个 - // (这条规则由下面两句自然实现) - if (hasVR != needVR) - return emitOpError() << "valid_row operand " - << (needVR ? "is required" : "must be absent") - << " because result type v_row is " - << (needVR ? "?" : std::to_string(vs[0])); - - if (hasVC != needVC) - return emitOpError() << "valid_col operand " - << (needVC ? "is required" : "must be absent") - << " because result type v_col is " - << (needVC ? "?" : std::to_string(vs[1])); - - return success(); -} - -LogicalResult MaterializeTileOp::verify() { - auto sourceTy = cast(getSource().getType()); - auto resultTy = cast(getResult().getType()); - - if (sourceTy.getRank() != 2) - return emitOpError("source memref must be rank-2 to materialize a tile handle"); - if (resultTy.getRank() != 2) - return emitOpError("result tile_buf must be rank-2"); - if (failed(verifyTileBufLayoutConstraints(*this, resultTy, "result"))) - return failure(); - - auto viewSemantics = (*this)->getAttrOfType("pto.view_semantics"); - bool isSubview = viewSemantics && viewSemantics.getValue() == "subview"; - if (!isSubview && sourceTy.getShape() != resultTy.getShape()) - return emitOpError() << "source/result shape mismatch: source=" - << sourceTy << " result=" << resultTy; - - if (sourceTy.getElementType() != resultTy.getElementType()) - return emitOpError() << "source/result element type mismatch: source=" - << sourceTy.getElementType() - << " result=" << resultTy.getElementType(); - - if (sourceTy.getMemorySpace() != resultTy.getMemorySpace()) - return emitOpError() << "source/result memory space mismatch"; - - if (getConfig() != resultTy.getConfigAttr()) - return emitOpError("config attribute must match the result tile_buf config"); - - auto shape = resultTy.getShape(); - auto validShape = resultTy.getValidShape(); - if (validShape.size() != 2) - return emitOpError("result tile_buf must have rank-2 validShape"); - for (unsigned i = 0; i < 2; ++i) { - if (shape[i] != ShapedType::kDynamic && - validShape[i] != ShapedType::kDynamic && 
validShape[i] > shape[i]) { - return emitOpError() << "valid_shape[" << i << "] must be <= shape[" - << i << "]"; - } - } - - return success(); -} - -LogicalResult TAssignOp::verify() { - if (getTile().getType() != getResult().getType()) { - return emitOpError("result type must match tile operand type"); - } - return success(); -} - -LogicalResult TLoadOp::verify() { - auto verifyCommon = - [&](bool allowLowPrecision) - -> FailureOr> { - auto srcPart = dyn_cast(getSrc().getType()); - auto dstTile = dyn_cast(getDst().getType()); - if (!srcPart || !dstTile) { - emitOpError("expects src to be !pto.partition_tensor_view and dst to be !pto.tile_buf"); - return failure(); - } - if (failed(verifyTileBufCommon(*this, dstTile, "dst", allowLowPrecision))) - return failure(); - - auto srcShape = srcPart.getShape(); - for (unsigned i = 0; i < srcShape.size(); ++i) { - if (srcShape[i] != ShapedType::kDynamic && srcShape[i] <= 0) { - emitOpError() << "expects src shape[" << i << "] to be positive"; - return failure(); - } - } - auto dstValid = dstTile.getValidShape(); - for (unsigned i = 0; i < dstValid.size(); ++i) { - if (dstValid[i] != ShapedType::kDynamic && dstValid[i] <= 0) { - emitOpError() << "expects dst valid_shape[" << i << "] to be positive"; - return failure(); - } - } - return std::make_pair(srcPart, dstTile); - }; - - auto verifyA2A3 = [&]() -> LogicalResult { - auto common = verifyCommon(/*allowLowPrecision=*/false); - if (failed(common)) - return failure(); - auto [srcPart, dstTile] = *common; - - Type srcElem = srcPart.getElementType(); - Type dstElem = dstTile.getElementType(); - if (isPTOLowPrecisionType(srcElem) || isPTOLowPrecisionType(dstElem)) - return emitOpError("expects A2/A3 tload low-precision element types to be unsupported"); - if (!(dstElem.isInteger(8) || dstElem.isInteger(16) || dstElem.isInteger(32) || - dstElem.isInteger(64) || dstElem.isF16() || dstElem.isBF16() || dstElem.isF32())) - return emitOpError("expects A2/A3 tload dst element type 
to be i8/i16/i32/i64/u64/f16/bf16/f32"); - - auto dstSpace = getPTOMemorySpaceEnum(dstTile); - if (!dstSpace || (*dstSpace != pto::AddressSpace::VEC && - *dstSpace != pto::AddressSpace::MAT)) - return emitOpError("expects A2/A3 tload dst to use loc=vec or loc=mat"); - - if (getElemByteSize(srcElem) != getElemByteSize(dstElem)) - return emitOpError("expects src and dst element types to have the same bitwidth"); - return success(); - }; - - auto verifyA5 = [&]() -> LogicalResult { - auto common = verifyCommon(/*allowLowPrecision=*/true); - if (failed(common)) - return failure(); - auto [srcPart, dstTile] = *common; - - Type srcElem = srcPart.getElementType(); - Type dstElem = dstTile.getElementType(); - unsigned srcBytes = getElemByteSize(srcElem); - unsigned dstBytes = getElemByteSize(dstElem); - if (srcBytes != dstBytes) - return emitOpError("expects src and dst element types to have the same element size"); - if (!(dstBytes == 1 || dstBytes == 2 || dstBytes == 4 || dstBytes == 8)) - return emitOpError("expects A5 tload dst element size to be 1, 2, 4, or 8 bytes"); - if (!isA5TLoadStoreTransferElemType(srcElem)) - return emitOpError("expects A5 tload src element type to be i8/i16/i32/i64/f16/bf16/f32/f8/hif8/fp4"); - if (!isA5TLoadStoreTransferElemType(dstElem)) - return emitOpError("expects A5 tload dst element type to be i8/i16/i32/i64/f16/bf16/f32/f8/hif8/fp4"); - - if (dstElem.isInteger(64)) { - auto pad = dstTile.getPadValueI32(); - if (pad != static_cast(pto::PadValue::Null) && - pad != static_cast(pto::PadValue::Zero)) - return emitOpError("expects A5 i64/u64 tload dst pad to be null or zero"); - } - return success(); - }; - - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -LogicalResult TPrefetchOp::verify() { - auto verifyImpl = [&](bool allowLowPrecision) -> LogicalResult { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - - Type srcElem; - Type dstElem; - - if (auto srcPart = dyn_cast(srcTy)) { - auto 
srcShape = srcPart.getShape(); - for (unsigned i = 0; i < srcShape.size(); ++i) { - if (srcShape[i] != ShapedType::kDynamic && srcShape[i] <= 0) - return emitOpError() << "expects src shape[" << i << "] to be positive"; - } - srcElem = srcPart.getElementType(); - } else if (auto srcMr = dyn_cast(srcTy)) { - if (!srcMr.hasRank()) - return emitOpError("expects src memref to be ranked"); - for (int64_t dim : srcMr.getShape()) { - if (dim != ShapedType::kDynamic && dim <= 0) - return emitOpError("expects src memref shape to be positive"); - } - srcElem = srcMr.getElementType(); - } else { - return emitOpError("expects src to be !pto.partition_tensor_view or memref"); - } - - if (auto dstTile = dyn_cast(dstTy)) { - if (failed(verifyTileBufCommon(*this, dstTile, "dst", allowLowPrecision))) - return failure(); - auto dstValid = dstTile.getValidShape(); - for (unsigned i = 0; i < dstValid.size(); ++i) { - if (dstValid[i] != ShapedType::kDynamic && dstValid[i] <= 0) - return emitOpError() << "expects dst valid_shape[" << i << "] to be positive"; - } - auto dstSpace = getPTOMemorySpaceEnum(dstTile); - if (!dstSpace || (*dstSpace != pto::AddressSpace::VEC && - *dstSpace != pto::AddressSpace::MAT)) - return emitOpError("expects dst to use loc=vec or loc=mat"); - dstElem = dstTile.getElementType(); - } else if (auto dstMr = dyn_cast(dstTy)) { - auto dstSpace = getPTOMemorySpaceEnum(dstMr); - if (!dstSpace || (*dstSpace != pto::AddressSpace::VEC && - *dstSpace != pto::AddressSpace::MAT)) - return emitOpError("expects dst memref to use loc=vec or loc=mat"); - if (!dstMr.hasRank()) - return emitOpError("expects dst memref to be ranked"); - if (failed(verifyTileBufCommon(*this, dstMr, "dst", allowLowPrecision))) - return failure(); - dstElem = dstMr.getElementType(); - } else { - return emitOpError("expects dst to be !pto.tile_buf or memref"); - } - - if (getElemByteSize(srcElem) != getElemByteSize(dstElem)) - return emitOpError("expects src and dst element types to have the same 
element size"); - if (!allowLowPrecision && - (isPTOLowPrecisionType(srcElem) || isPTOLowPrecisionType(dstElem))) - return emitOpError("expects A2/A3 tprefetch low-precision element types to be unsupported"); - if (allowLowPrecision && - (!isA5TLoadStoreTransferElemType(srcElem) || - !isA5TLoadStoreTransferElemType(dstElem))) - return emitOpError("expects A5 tprefetch element types to be i8/i16/i32/i64/f16/bf16/f32/f8/hif8/fp4"); - return success(); - }; - - auto verifyA2A3 = [&]() -> LogicalResult { - return verifyImpl(/*allowLowPrecision=*/false); - }; - auto verifyA5 = [&]() -> LogicalResult { - return verifyImpl(/*allowLowPrecision=*/true); - }; - switch (getVerifierTargetArch(getOperation())) { - case VerifierTargetArch::A2A3: - return verifyA2A3(); - case VerifierTargetArch::A5: - return verifyA5(); - } - return failure(); -} - -LogicalResult MakePrefetchAsyncContextOp::verify() { - Type workspaceTy = getWorkspace().getType(); - Type elemTy = nullptr; - if (auto ptrTy = dyn_cast(workspaceTy)) { - elemTy = ptrTy.getElementType(); - } else if (auto memTy = dyn_cast(workspaceTy)) { - if (!isGmAddressSpaceAttr(memTy.getMemorySpace())) - return emitOpError("expects workspace memref to be in GM address space"); - elemTy = memTy.getElementType(); - } else { - return emitOpError("expects workspace to be !pto.ptr or GM memref"); - } - if (!isByteIntegerType(elemTy)) - return emitOpError("expects workspace element type to be an 8-bit integer"); - return success(); -} - -LogicalResult TPrefetchAsyncOp::verify() { - if (failed(verifyAsyncFlatContiguous1DGMViewLike(getOperation(), getSrc(), - "src"))) - return failure(); - return success(); -} - -LogicalResult mlir::pto::SetFFTsOp::verify() { - auto mr = llvm::dyn_cast(getFfts().getType()); - if (!mr) - return emitOpError("expects a memref operand"); - - if (!mr.getElementType().isInteger(64) && !mr.getElementType().isInteger(8)) - return emitOpError("expects element type i64 (or i8)"); - - return mlir::success(); -} - 
-ParseResult mlir::pto::SyncSetOp::parse(OpAsmParser &parser, - OperationState &result) { - return parseSyncEventOpCommon(parser, result, - SyncSetOp::getPipeAttrName(result.name), - SyncSetOp::getEventIdAttrName(result.name)); -} - -void mlir::pto::SyncSetOp::print(OpAsmPrinter &p) { - printSyncEventOpCommon(p, getOperation(), getPipe(), getEventIdAttr(), - getEventIdDyn(), getPipeAttrName().getValue(), - getEventIdAttrName().getValue()); -} - -LogicalResult mlir::pto::SyncSetOp::verify() { - bool hasStatic = getEventIdAttr() != nullptr; - bool hasDynamic = static_cast(getEventIdDyn()); - if (hasStatic == hasDynamic) - return emitOpError() - << "expects exactly one event-id form: static attr or dynamic index operand"; - if (IntegerAttr fftsModeAttr = getFftsModeAttr()) { - int64_t fftsMode = fftsModeAttr.getInt(); - if (fftsMode < 0 || fftsMode > 2) - return emitOpError() << "requires ffts_mode in range [0, 2], but got " - << fftsMode; - } - - auto verifyA2A3 = [&]() -> LogicalResult { return success(); }; - auto verifyA5 = [&]() -> LogicalResult { - switch (getPipe().getPipe()) { - case PIPE::PIPE_FIX: - case PIPE::PIPE_MTE3: - return success(); - default: - return emitOpError() - << "A5 sync.set expects pipe to be one of , "; - } - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -ParseResult mlir::pto::SyncWaitOp::parse(OpAsmParser &parser, - OperationState &result) { - return parseSyncEventOpCommon(parser, result, - SyncWaitOp::getPipeAttrName(result.name), - SyncWaitOp::getEventIdAttrName(result.name)); -} - -void mlir::pto::SyncWaitOp::print(OpAsmPrinter &p) { - printSyncEventOpCommon(p, getOperation(), getPipe(), getEventIdAttr(), - getEventIdDyn(), getPipeAttrName().getValue(), - getEventIdAttrName().getValue()); -} - -ParseResult mlir::pto::SyncAllOp::parse(OpAsmParser &parser, - OperationState &result) { - SmallVector operands; - SmallVector operandTypes; - Attribute modeAttr; - Attribute coreTypeAttr; - - if 
(parser.parseLParen()) - return failure(); - - if (failed(parser.parseOptionalRParen())) { - if (parser.parseOperandList(operands) || parser.parseColonTypeList(operandTypes) || - parser.parseRParen()) - return failure(); - if (operands.size() != operandTypes.size()) - return parser.emitError(parser.getCurrentLocation()) - << "expects the same number of operands and operand types"; - } - - if (parser.parseKeyword("mode") || parser.parseEqual() || - parser.parseAttribute(modeAttr) || parser.parseComma() || - parser.parseKeyword("core_type") || parser.parseEqual() || - parser.parseAttribute(coreTypeAttr)) - return failure(); - - auto mode = dyn_cast(modeAttr); - if (!mode) - return parser.emitError(parser.getCurrentLocation()) - << "expects mode to be #pto.sync_all_mode<...>"; - - auto coreType = dyn_cast(coreTypeAttr); - if (!coreType) - return parser.emitError(parser.getCurrentLocation()) - << "expects core_type to be #pto.sync_core_type<...>"; - - result.addAttribute("mode", mode); - result.addAttribute("core_type", coreType); - - if (parser.parseOptionalAttrDict(result.attributes)) - return failure(); - - auto addSegmentSizes = [&](int32_t gm, int32_t ub, int32_t l1, - int32_t used) { - result.addAttribute("operandSegmentSizes", - parser.getBuilder().getDenseI32ArrayAttr( - {gm, ub, l1, used})); - }; - - switch (mode.getValue()) { - case pto::SyncAllMode::Hard: - if (!operands.empty()) - return parser.emitError(parser.getCurrentLocation()) - << "expects hard syncall to have no operands"; - addSegmentSizes(0, 0, 0, 0); - return success(); - case pto::SyncAllMode::Soft: - break; - } - - switch (coreType.getValue()) { - case pto::SyncCoreType::AIVOnly: - if (operands.size() != 2 && operands.size() != 3) - return parser.emitError(parser.getCurrentLocation()) - << "expects soft AIV-only syncall to have gm_workspace, " - "ub_workspace, and optional used_cores"; - if (parser.resolveOperand(operands[0], operandTypes[0], result.operands) || - 
parser.resolveOperand(operands[1], operandTypes[1], result.operands)) - return failure(); - if (operands.size() == 3 && - parser.resolveOperand(operands[2], operandTypes[2], result.operands)) - return failure(); - addSegmentSizes(1, 1, 0, operands.size() == 3 ? 1 : 0); - return success(); - case pto::SyncCoreType::AICOnly: - if (operands.size() != 2 && operands.size() != 3) - return parser.emitError(parser.getCurrentLocation()) - << "expects soft AIC-only syncall to have gm_workspace, " - "l1_workspace, and optional used_cores"; - if (parser.resolveOperand(operands[0], operandTypes[0], result.operands) || - parser.resolveOperand(operands[1], operandTypes[1], result.operands)) - return failure(); - if (operands.size() == 3 && - parser.resolveOperand(operands[2], operandTypes[2], result.operands)) - return failure(); - addSegmentSizes(1, 0, 1, operands.size() == 3 ? 1 : 0); - return success(); - case pto::SyncCoreType::Mix: - if (operands.size() != 3 && operands.size() != 4) - return parser.emitError(parser.getCurrentLocation()) - << "expects soft mixed syncall to have gm_workspace, " - "ub_workspace, l1_workspace, and optional used_cores"; - if (parser.resolveOperand(operands[0], operandTypes[0], result.operands) || - parser.resolveOperand(operands[1], operandTypes[1], result.operands) || - parser.resolveOperand(operands[2], operandTypes[2], result.operands)) - return failure(); - if (operands.size() == 4 && - parser.resolveOperand(operands[3], operandTypes[3], result.operands)) - return failure(); - addSegmentSizes(1, 1, 1, operands.size() == 4 ? 
1 : 0); - return success(); - } - - llvm_unreachable("unhandled SyncCoreType"); -} - -void mlir::pto::SyncAllOp::print(OpAsmPrinter &p) { - SmallVector operands; - if (getGmWorkspace()) - operands.push_back(getGmWorkspace()); - if (getUbWorkspace()) - operands.push_back(getUbWorkspace()); - if (getL1Workspace()) - operands.push_back(getL1Workspace()); - if (getUsedCores()) - operands.push_back(getUsedCores()); - - p << "("; - if (!operands.empty()) { - p.printOperands(operands); - p << " : "; - llvm::interleaveComma(operands, p, - [&](Value operand) { p.printType(operand.getType()); }); - } - p << ") mode = " << getMode() << ", core_type = " << getCoreType(); - p.printOptionalAttrDict((*this)->getAttrs(), - /*elidedAttrs=*/{"operandSegmentSizes", "mode", - "core_type"}); -} - -LogicalResult mlir::pto::SyncWaitOp::verify() { - bool hasStatic = getEventIdAttr() != nullptr; - bool hasDynamic = static_cast(getEventIdDyn()); - if (hasStatic == hasDynamic) - return emitOpError() - << "expects exactly one event-id form: static attr or dynamic index operand"; - - auto verifyA2A3 = [&]() -> LogicalResult { return success(); }; - auto verifyA5 = [&]() -> LogicalResult { - switch (getPipe().getPipe()) { - case PIPE::PIPE_FIX: - case PIPE::PIPE_MTE1: - case PIPE::PIPE_MTE2: - case PIPE::PIPE_MTE3: - case PIPE::PIPE_V: - return success(); - default: - return emitOpError() << "A5 sync.wait expects pipe to be one of " - ", , , " - ", "; - } - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -LogicalResult TStoreOp::verify() { - auto verifyCommon = - [&](bool allowLowPrecision) - -> FailureOr> { - auto srcTile = dyn_cast(getSrc().getType()); - auto dstPart = dyn_cast(getDst().getType()); - if (!srcTile || !dstPart) { - emitOpError("expects src to be !pto.tile_buf and dst to be !pto.partition_tensor_view"); - return failure(); - } - if (failed(verifyTileBufCommon(*this, srcTile, "src", allowLowPrecision))) - return failure(); - for (auto [idx, dim] : 
llvm::enumerate(dstPart.getShape())) { - if (dim != ShapedType::kDynamic && dim <= 0) { - emitOpError() << "expects dst shape[" << idx << "] to be positive"; - return failure(); - } - } - auto srcValid = srcTile.getValidShape(); - for (auto [idx, dim] : llvm::enumerate(srcValid)) { - if (dim != ShapedType::kDynamic && dim <= 0) { - emitOpError() << "expects src valid_shape[" << idx << "] to be positive"; - return failure(); - } - } - - // Keep TSTORE contract explicit while preserving existing legal layout - // reinterpretation paths (e.g. 1x1024 <-> 32x32, 5D partition views). - // When both sides are fully static, require equal element counts between - // dst shape and src valid_shape. - auto getStaticElemCount = [](ArrayRef shape) -> std::optional { - int64_t total = 1; - for (int64_t dim : shape) { - if (dim == ShapedType::kDynamic) - return std::nullopt; - if (dim <= 0) - return std::nullopt; - if (total > std::numeric_limits::max() / dim) - return std::nullopt; - total *= dim; - } - return total; - }; - - auto dstElemCount = getStaticElemCount(dstPart.getShape()); - auto srcValidElemCount = getStaticElemCount(srcValid); - if (dstElemCount && srcValidElemCount && *dstElemCount != *srcValidElemCount) { - emitOpError() << "expects dst static element count (" << *dstElemCount - << ") to match src valid_shape static element count (" - << *srcValidElemCount << ")"; - return failure(); - } - return std::make_pair(srcTile, dstPart); - }; - - auto isLoadStoreElemType = [&](Type ty) -> bool { - return ty.isInteger(8) || ty.isInteger(16) || ty.isInteger(32) || - ty.isInteger(64) || ty.isF16() || ty.isBF16() || ty.isF32(); - }; - auto isI8Like = [&](Type ty) -> bool { return ty.isInteger(8); }; - bool hasPreQuant = static_cast(getPreQuantScalar()); - auto reluMode = getReluPreMode(); - - auto verifyA2A3 = [&]() -> LogicalResult { - auto common = verifyCommon(/*allowLowPrecision=*/false); - if (failed(common)) - return failure(); - auto [srcTile, dstPart] = *common; - 
auto srcSpace = getPTOMemorySpaceEnum(srcTile); - if (!srcSpace || (*srcSpace != pto::AddressSpace::VEC && - *srcSpace != pto::AddressSpace::MAT && - *srcSpace != pto::AddressSpace::ACC)) - return emitOpError("expects A2/A3 tstore src to use loc=vec, loc=mat, or loc=acc"); - if (hasPreQuant && *srcSpace != pto::AddressSpace::ACC) - return emitOpError("expects preQuantScalar form to use loc=acc src"); - if (reluMode != pto::ReluPreMode::NoRelu && *srcSpace != pto::AddressSpace::ACC) - return emitOpError("expects reluPreMode form to use loc=acc src"); - - Type srcElem = srcTile.getElementType(); - Type dstElem = dstPart.getElementType(); - if (*srcSpace == pto::AddressSpace::VEC || *srcSpace == pto::AddressSpace::MAT) { - if (hasPreQuant) - return emitOpError("expects preQuantScalar form to use loc=acc src"); - if (isPTOLowPrecisionType(dstElem)) - return emitOpError("expects A2/A3 vec/mat tstore low-precision dst element types to be unsupported"); - if (!isLoadStoreElemType(srcElem)) - return emitOpError("expects A2/A3 vec/mat tstore src element type to be i8/i16/i32/i64/u64/f16/bf16/f32"); - if (getElemByteSize(srcElem) != getElemByteSize(dstElem)) - return emitOpError("expects A2/A3 vec/mat tstore src and dst element types to have the same bitwidth"); - return success(); - } - - if (!(srcElem.isInteger(32) || srcElem.isF32())) - return emitOpError("expects A2/A3 acc tstore src element type to be i32 or f32"); - if (hasPreQuant) { - if (srcElem.isInteger(32)) { - if (!(isI8Like(dstElem) || dstElem.isF16())) - return emitOpError("expects A2/A3 acc preQuantScalar tstore dst type to be i8/ui8/f16"); - } else if (srcElem.isF32()) { - if (!isI8Like(dstElem)) - return emitOpError("expects A2/A3 acc preQuantScalar tstore dst type to be i8/ui8"); - } - } else { - if (!(dstElem.isInteger(32) || dstElem.isF32() || dstElem.isF16() || - dstElem.isBF16())) - return emitOpError("expects A2/A3 acc tstore dst element type to be i32/f32/f16/bf16"); - } - - auto srcShape = 
srcTile.getShape(); - if (srcShape[1] != ShapedType::kDynamic && - (srcShape[1] < 1 || srcShape[1] > 4095)) - return emitOpError("expects A2/A3 acc tstore src cols to be in [1, 4095]"); - auto srcValid = srcTile.getValidShape(); - if (srcValid[1] != ShapedType::kDynamic && - (srcValid[1] < 1 || srcValid[1] > 4095)) - return emitOpError("expects A2/A3 acc tstore src valid_shape[1] to be in [1, 4095]"); - return success(); - }; - - auto verifyA5 = [&]() -> LogicalResult { - auto common = verifyCommon(/*allowLowPrecision=*/true); - if (failed(common)) - return failure(); - auto [srcTile, dstPart] = *common; - auto srcSpace = getPTOMemorySpaceEnum(srcTile); - if (!srcSpace || (*srcSpace != pto::AddressSpace::VEC && - *srcSpace != pto::AddressSpace::ACC)) - return emitOpError("expects A5 tstore src to use loc=vec or loc=acc"); - if (hasPreQuant && *srcSpace != pto::AddressSpace::ACC) - return emitOpError("expects preQuantScalar form to use loc=acc src"); - if (reluMode != pto::ReluPreMode::NoRelu && *srcSpace != pto::AddressSpace::ACC) - return emitOpError("expects reluPreMode form to use loc=acc src"); - - Type srcElem = srcTile.getElementType(); - Type dstElem = dstPart.getElementType(); - if (*srcSpace == pto::AddressSpace::VEC) { - if (hasPreQuant) - return emitOpError("expects preQuantScalar form to use loc=acc src"); - if (!isA5TLoadStoreTransferElemType(srcElem)) - return emitOpError("expects A5 vec tstore src element type to be i8/i16/i32/i64/f16/bf16/f32/f8/hif8/fp4"); - if (getElemByteSize(srcElem) != getElemByteSize(dstElem)) - return emitOpError("expects A5 vec tstore src and dst element types to have the same bitwidth"); - return success(); - } - - if (!(srcElem.isInteger(32) || srcElem.isF32())) - return emitOpError("expects A5 acc tstore src element type to be i32 or f32"); - if (hasPreQuant) { - if (!isA5AccStorePreQuantDstType(srcElem, dstElem)) - return emitOpError("expects A5 acc preQuantScalar tstore dst type to be i8/ui8/f16/bf16/f32/hif8/f8E4M3"); 
- } else { - if (!(dstElem.isInteger(32) || dstElem.isF32() || dstElem.isF16() || - dstElem.isBF16())) - return emitOpError("expects A5 acc tstore dst element type to be i32/f32/f16/bf16"); - } - return success(); - }; - - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -LogicalResult pto::TAbsOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyVecTileCommon(*this, srcTy, "src")) || - failed(verifyVecTileCommon(*this, dstTy, "dst"))) - return failure(); - if (failed(verifyTileBufSameElemType(*this, srcTy, dstTy, "src", "dst")) || - failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) - return failure(); - - Type elemTy; - if (auto tb = dyn_cast(srcTy)) - elemTy = tb.getElementType(); - else if (auto mr = dyn_cast(srcTy)) - elemTy = mr.getElementType(); - if (!(elemTy.isF16() || elemTy.isF32())) - return emitOpError() << "expects element type to be f16 or f32"; - - return success(); -} -// PTO.cpp - -static bool isPTOShapedLike(Type ty) { - return mlir::isa(ty); -} - -static bool isTileLikeType(Type ty) { - return isa(ty); -} - -static Type getElemTy(Type ty) { - if (auto mr = mlir::dyn_cast(ty)) return mr.getElementType(); - if (auto tt = mlir::dyn_cast(ty)) return tt.getElementType(); - if (auto tv = mlir::dyn_cast(ty)) return tv.getElementType(); - if (auto tb = mlir::dyn_cast(ty)) return tb.getElementType(); - if (auto tv = mlir::dyn_cast(ty)) return tv.getElementType(); - return Type(); -} - -static SmallVector getShapeVec(Type ty) { - SmallVector s; - if (auto mr = mlir::dyn_cast(ty)) - return SmallVector(mr.getShape().begin(), mr.getShape().end()); - if (auto tt = mlir::dyn_cast(ty)) - return SmallVector(tt.getShape().begin(), tt.getShape().end()); - if (auto tv = mlir::dyn_cast(ty)) - return SmallVector(tv.getShape().begin(), tv.getShape().end()); - if (auto tb = mlir::dyn_cast(ty)) - 
return SmallVector(tb.getShape().begin(), tb.getShape().end()); - if (auto tv = mlir::dyn_cast(ty)) - return SmallVector(tv.getShape().begin(), tv.getShape().end()); - return {}; -} - -static SmallVector getValidShapeVec(Type ty) { - if (auto tb = dyn_cast(ty)) - return SmallVector(tb.getValidShape().begin(), tb.getValidShape().end()); - return getShapeVec(ty); -} - -static int64_t getLogicalTileDim(int64_t rawDim, Type elemTy, - std::optional blayout, - unsigned dimIdx) { - if (rawDim == ShapedType::kDynamic || !isPTOFloat4PackedType(elemTy)) - return rawDim; - pto::BLayout layout = blayout.value_or(pto::BLayout::RowMajor); - unsigned packedDim = layout == pto::BLayout::ColMajor ? 0 : 1; - return dimIdx == packedDim ? rawDim * 2 : rawDim; -} - -static std::optional getTileBufBLayout(Type ty) { - if (auto tb = dyn_cast(ty)) - return static_cast(tb.getBLayoutValueI32()); - return std::nullopt; -} - -static SmallVector getLogicalTileExtentVec(Type ty, - bool useValidShape) { - SmallVector dims = - useValidShape ? 
getValidShapeVec(ty) : getShapeVec(ty); - if (!isTileLikeType(ty) || dims.size() != 2) - return dims; - - Type elemTy = getElemTy(ty); - auto blayout = getTileBufBLayout(ty); - for (unsigned i = 0; i < dims.size(); ++i) - dims[i] = getLogicalTileDim(dims[i], elemTy, blayout, i); - return dims; -} - -static int64_t getConstantIndexOrDynamic(Value value) { - if (!value) - return ShapedType::kDynamic; - if (auto cst = value.getDefiningOp()) - return cst.value(); - if (auto cst = value.getDefiningOp()) - return cst.value(); - return ShapedType::kDynamic; -} - -static SmallVector getValidShapeVec(Value value) { - if (!value) - return {}; - auto valid = getValidShapeVec(value.getType()); - if (auto bind = value.getDefiningOp()) { - if (valid.size() >= 1 && bind.getValidRow()) - valid[0] = getConstantIndexOrDynamic(bind.getValidRow()); - if (valid.size() >= 2 && bind.getValidCol()) - valid[1] = getConstantIndexOrDynamic(bind.getValidCol()); - } - return valid; -} - -static SmallVector getMatmulLogicalShapeVec(Type ty) { - auto shape = getShapeVec(ty); - auto valid = getValidShapeVec(ty); - if (!isa(ty) || shape.size() != valid.size()) - return shape; - - for (size_t i = 0, e = shape.size(); i < e; ++i) { - if (valid[i] != ShapedType::kDynamic) - shape[i] = valid[i]; - } - return shape; -} - -static bool isByteIntegerType(Type ty) { - auto intTy = dyn_cast(ty); - return intTy && intTy.getWidth() == 8; -} - -static LogicalResult verifyAsyncFlatContiguous1DGMMemRef(Operation *op, - Value value, - StringRef name) { - auto memTy = dyn_cast(value.getType()); - if (!memTy) - return op->emitOpError() << "expects " << name << " to be a memref"; - if (!memTy.hasRank()) - return op->emitOpError() << "expects " << name << " to be a ranked memref"; - if (!isGmAddressSpaceAttr(memTy.getMemorySpace())) - return op->emitOpError() << "expects " << name - << " to be in GM address space"; - - ArrayRef shape = memTy.getShape(); - if (shape.empty()) - return op->emitOpError() << "expects " << 
name - << " to have rank >= 1"; - for (int64_t dim : shape) { - if (dim == ShapedType::kDynamic) - return op->emitOpError() << "expects " << name - << " to have a static shape"; - } - - SmallVector strides; - int64_t offset = 0; - if (failed(getStridesAndOffset(memTy, strides, offset))) - return op->emitOpError() << "expects " << name - << " to be a strided memref with a known layout"; - - bool hasDynamicLayout = - offset == ShapedType::kDynamic || - llvm::any_of(strides, [](int64_t stride) { - return stride == ShapedType::kDynamic; - }); - if (hasDynamicLayout) - return success(); - - bool packed = !strides.empty() && strides.back() == 1; - for (int i = static_cast(shape.size()) - 2; i >= 0 && packed; --i) - packed &= strides[i] == strides[i + 1] * shape[i + 1]; - if (!packed) - return op->emitOpError() - << "expects " << name - << " to be a static flat contiguous logical 1D GM memref"; - - bool logical1D = true; - for (int i = 0, e = static_cast(shape.size()) - 1; i < e; ++i) - logical1D &= shape[i] == 1; - if (!logical1D) - return op->emitOpError() - << "expects " << name - << " to be a static flat contiguous logical 1D GM memref"; - - return success(); -} - -static LogicalResult verifyAsyncFlatContiguous1DGMViewLike(Operation *op, - Value value, - StringRef name) { - Type ty = value.getType(); - if (isa(ty)) - return verifyAsyncFlatContiguous1DGMMemRef(op, value, name); - - if (!isa(ty)) - return op->emitOpError() << "expects " << name - << " to be a memref/tensor_view/partition_view"; - - SmallVector shape = getShapeVec(ty); - if (shape.empty()) - return op->emitOpError() << "expects " << name << " to have rank >= 1"; - for (int64_t dim : shape) { - if (dim == ShapedType::kDynamic) - return op->emitOpError() << "expects " << name - << " to have a static shape"; - } - - bool logical1D = true; - for (int i = 0, e = static_cast(shape.size()) - 1; i < e; ++i) - logical1D &= shape[i] == 1; - if (!logical1D) - return op->emitOpError() - << "expects " << name - << " 
to be a static flat contiguous logical 1D GM view"; - - return success(); -} - -static bool isCommGlobalLikeType(Type ty) { - if (auto memTy = dyn_cast(ty)) - return isGmAddressSpaceAttr(memTy.getMemorySpace()); - return isa(ty); -} - -static LogicalResult verifyCommGlobalLike(Operation *op, Value value, - StringRef name) { - Type ty = value.getType(); - if (!isCommGlobalLikeType(ty)) - return op->emitOpError() << "expects " << name - << " to be a GM memref/tensor_view/partition_view"; - - SmallVector shape = getShapeVec(ty); - if (shape.empty()) - return op->emitOpError() << "expects " << name << " to have rank >= 1"; - for (int64_t dim : shape) { - if (dim == ShapedType::kDynamic || dim <= 0) - return op->emitOpError() << "expects " << name - << " to have a positive static shape"; - } - return success(); -} - -static LogicalResult verifyCommSignalLike(Operation *op, Value value, - StringRef name) { - if (failed(verifyCommGlobalLike(op, value, name))) - return failure(); - Type elemTy = getElemTy(value.getType()); - if (!elemTy || !elemTy.isSignlessInteger(32)) - return op->emitOpError() << "expects " << name - << " element type to be i32"; - return success(); -} - -static LogicalResult verifyCommStagingTileLike(Operation *op, Value value, - StringRef name) { - Type ty = value.getType(); - if (!isa(ty)) - return op->emitOpError() << "expects " << name - << " to be a tile_buf or memref tile"; - auto as = getPTOMemorySpaceEnum(ty); - if (!as || *as != pto::AddressSpace::VEC) - return op->emitOpError() << "expects " << name - << " to be in vec address space"; - SmallVector shape = getShapeVec(ty); - if (shape.empty()) - return op->emitOpError() << "expects " << name << " to have rank >= 1"; - for (int64_t dim : shape) { - if (dim == ShapedType::kDynamic || dim <= 0) - return op->emitOpError() << "expects " << name - << " to have a positive static shape"; - } - return success(); -} - -static LogicalResult verifyCommGlobalGroup(Operation *op, ValueRange group, - 
StringRef name) { - if (group.empty()) - return op->emitOpError() << "expects at least one " << name << " operand"; - Type groupTy = group.front().getType(); - for (auto it : llvm::enumerate(group)) { - if (failed(verifyCommGlobalLike(op, it.value(), - (name + "[" + Twine(it.index()) + "]").str()))) - return failure(); - if (it.value().getType() != groupTy) - return op->emitOpError() << "expects all " << name - << " operands to have identical types"; - } - return success(); -} - -static LogicalResult verifyCommPingPongSameType(Operation *op, Value ping, - Value pong, StringRef pingName, - StringRef pongName) { - if (!pong) - return success(); - if (failed(verifyCommStagingTileLike(op, ping, pingName)) || - failed(verifyCommStagingTileLike(op, pong, pongName))) - return failure(); - if (ping.getType() != pong.getType()) - return op->emitOpError() << "expects " << pingName << " and " << pongName - << " to have identical types"; - return success(); -} - -static std::optional getStaticByteSize(Type ty) { - SmallVector shape = getShapeVec(ty); - if (shape.empty()) - return std::nullopt; - for (int64_t dim : shape) { - if (dim == ShapedType::kDynamic || dim < 0) - return std::nullopt; - } - - Type elemTy = getElemTy(ty); - uint64_t elemBytes = getElemByteSize(elemTy); - if (elemBytes == 0) - return std::nullopt; - - uint64_t total = elemBytes; - for (int64_t dim : shape) { - total *= static_cast(dim); - } - return total; -} - -static std::optional getPTOMemorySpaceEnum(Type ty) { - if (auto tb = dyn_cast(ty)) { - if (auto as = dyn_cast_or_null(tb.getMemorySpace())) - return as.getAddressSpace(); - return std::nullopt; - } - if (auto mr = dyn_cast(ty)) { - if (auto as = dyn_cast_or_null(mr.getMemorySpace())) - return as.getAddressSpace(); - if (!mr.getMemorySpace()) - return pto::AddressSpace::GM; - } - return std::nullopt; -} - -[[maybe_unused]] static bool isRank2TileBuf(Type ty) { - auto tb = dyn_cast(ty); - return tb && tb.getRank() == 2 && tb.getValidShape().size() 
== 2; -} - -static bool isSupportedVecElemType(Type ty, bool allowBf16, - bool allowInt8) { - if (ty.isF16() || ty.isF32()) - return true; - if (allowBf16 && ty.isBF16()) - return true; - if (auto it = dyn_cast(ty)) { - switch (it.getWidth()) { - case 32: - case 16: - return true; - case 8: - return allowInt8; - default: - return false; - } - } - return false; -} - -static bool isSupportedMGatherMScatterIndexElemType(Type ty) { - auto it = dyn_cast(ty); - if (!it || it.getWidth() != 32) - return false; - return true; -} - -static bool isSupportedMGatherMScatterPayloadElemType(Operation *op, Type ty) { - if (isSupportedVecElemType(ty, /*allowBf16=*/true, /*allowInt8=*/true)) - return true; - if (!isTargetArchA5(op)) - return false; - return ty.isFloat8E4M3() || ty.isFloat8E4M3FN() || ty.isFloat8E4M3FNUZ() || - ty.isFloat8E4M3B11FNUZ() || ty.isFloat8E5M2() || ty.isFloat8E5M2FNUZ(); -} - -static bool isSupportedMScatterAtomicPayloadElemType(Type ty, - pto::ScatterAtomicOp atomic) { - auto intTy = dyn_cast(ty); - switch (atomic) { - case pto::ScatterAtomicOp::None: - return true; - case pto::ScatterAtomicOp::Add: - return ty.isF16() || ty.isF32() || - (intTy && intTy.getWidth() == 32); - case pto::ScatterAtomicOp::Max: - case pto::ScatterAtomicOp::Min: - return ty.isF32() || - (intTy && intTy.getWidth() == 32); - } - llvm_unreachable("Unknown ScatterAtomicOp"); -} - -static LogicalResult verifyMGatherMScatterMemOperand(Operation *op, - Value memValue, - Type dataElemTy, - StringRef dataOperandLabel) { - Type memTy = memValue.getType(); - Type memElem = getElemTy(memTy); - if (!memElem || memElem != dataElemTy) - return op->emitOpError() << "expects mem element type to match " - << dataOperandLabel << " element type"; - - if (isa(memTy)) { - if (auto layout = getLogicalViewLayout(memValue)) { - if (*layout != pto::Layout::ND) - return op->emitOpError( - "expects mem partition view to use ND logical layout when layout " - "can be inferred"); - } - return success(); - } - 
- if (auto mr = dyn_cast(memTy)) { - auto as = getPTOMemorySpaceEnum(mr); - if (!as || (*as != pto::AddressSpace::GM && - *as != pto::AddressSpace::Zero)) - return op->emitOpError( - "expects mem memref to use GM or zero address space"); - if (mr.getRank() == 5) { - auto shape = mr.getShape(); - bool allStatic = true; - for (int64_t d : shape) - if (d == ShapedType::kDynamic) - allStatic = false; - if (allStatic && (shape[0] != 1 || shape[1] != 1 || shape[2] != 1)) - return op->emitOpError( - "expects rank-5 GM memref leading dimensions to be [1,1,1,...] " - "(GlobalTensor table shape)"); - } - return success(); - } - - return op->emitOpError( - "expects mem to be !pto.partition_tensor_view or a GM/ZERO memref"); -} - -static bool hasCompatibleKnownExtent(int64_t lhs, int64_t rhs); -static bool isKnownUnitExtent(int64_t value); - -static LogicalResult verifyMGatherMScatterTileShape(Operation *op, Type dataTy, - Type idxTy, - StringRef dataName) { - auto dataValid = getValidShapeVec(dataTy); - auto idxValid = getValidShapeVec(idxTy); - if (dataValid.size() != 2 || idxValid.size() != 2) - return op->emitOpError() << "expects " << dataName - << " and idx to have rank-2 valid_shape"; - - auto idxTile = dyn_cast(idxTy); - if (!idxTile) - return op->emitOpError("expects idx to be a tile_buf type"); - - const bool idxRowMajor = - idxTile.getBLayoutValueI32() == - static_cast(pto::BLayout::RowMajor); - const bool idxColMajor = - idxTile.getBLayoutValueI32() == - static_cast(pto::BLayout::ColMajor); - - const bool rowCoalesce1xR = - idxRowMajor && isKnownUnitExtent(idxValid[0]) && - hasCompatibleKnownExtent(idxValid[1], dataValid[0]); - const bool rowCoalesceRx1 = - idxColMajor && hasCompatibleKnownExtent(idxValid[0], dataValid[0]) && - isKnownUnitExtent(idxValid[1]); - const bool elemCoalesce = - hasCompatibleKnownExtent(idxValid[0], dataValid[0]) && - hasCompatibleKnownExtent(idxValid[1], dataValid[1]); - - if (!(rowCoalesce1xR || rowCoalesceRx1 || elemCoalesce)) - return 
op->emitOpError() - << "expects idx valid_shape to be [1, " << dataName - << ".valid_row], [" << dataName - << ".valid_row, 1], or match " << dataName << " valid_shape"; - - return success(); -} - -static LogicalResult verifyMGatherMScatterIdxTile(Operation *op, Type ty, - StringRef name) { - if (failed(verifyTileBufCommon(op, ty, name))) - return failure(); - auto as = getPTOMemorySpaceEnum(ty); - if (!as || *as != pto::AddressSpace::VEC) - return op->emitOpError() << "expects " << name - << " to be in the vec address space"; - auto tb = dyn_cast(ty); - if (!tb) - return op->emitOpError() << "expects " << name << " to be a tile_buf type"; - int32_t blayout = tb.getBLayoutValueI32(); - if (blayout != static_cast(pto::BLayout::RowMajor) && - blayout != static_cast(pto::BLayout::ColMajor)) - return op->emitOpError() << "expects " << name - << " to use row_major or col_major blayout"; - if (tb.getSLayoutValueI32() != static_cast(pto::SLayout::NoneBox)) - return op->emitOpError() << "expects " << name - << " to use the none_box slayout"; - return success(); -} - -static bool isA5TLoadStoreTransferElemType(Type ty) { - return ty.isInteger(8) || ty.isInteger(16) || ty.isInteger(32) || - ty.isInteger(64) || ty.isF16() || ty.isBF16() || ty.isF32() || - isPTOLowPrecisionType(ty); -} - -static bool isA5AccStorePreQuantDstType(Type srcElem, Type dstElem) { - if (srcElem.isInteger(32)) - return dstElem.isInteger(8) || dstElem.isF16() || dstElem.isBF16(); - if (!srcElem.isF32()) - return false; - return dstElem.isInteger(8) || dstElem.isF16() || dstElem.isBF16() || - dstElem.isF32() || isPTOHiFloat8Type(dstElem) || - dstElem.isFloat8E4M3() || dstElem.isFloat8E4M3FN() || - dstElem.isFloat8E4M3FNUZ() || dstElem.isFloat8E4M3B11FNUZ(); -} - -static bool isA5LowPrecisionTCvtPair(Type srcElem, Type dstElem) { - if (srcElem.isF32()) - return isPTOFloat8Type(dstElem) || isPTOHiFloat8Type(dstElem); - if (srcElem.isF16()) - return isPTOHiFloat8Type(dstElem); - if (srcElem.isBF16()) - 
return isPTOFloat4PackedType(dstElem); - if (isPTOFloat4PackedType(srcElem)) - return dstElem.isBF16(); - if (isPTOFloat8Type(srcElem) || isPTOHiFloat8Type(srcElem)) - return dstElem.isF32(); - return false; -} - -static bool isA5SupportedTCvtPair(Type srcElem, Type dstElem) { - if (isPTOLowPrecisionType(srcElem) || isPTOLowPrecisionType(dstElem)) - return isA5LowPrecisionTCvtPair(srcElem, dstElem); - return true; -} - -static LogicalResult verifyTileBufCommon(Operation *op, Type ty, StringRef name, - bool allowLowPrecision) { - auto tb = dyn_cast(ty); - if (tb) { - if (tb.getRank() != 2) - return op->emitOpError() << "expects " << name << " to be a rank-2 tile_buf"; - Type elemTy = tb.getElementType(); - if (!allowLowPrecision && isPTOLowPrecisionType(elemTy)) - return op->emitOpError() << name << ": dtype " << elemTy - << " is not supported by this op yet"; - } else if (auto mr = dyn_cast(ty)) { - if (mr.getRank() != 2) - return op->emitOpError() << "expects " << name << " to be a rank-2 memref"; - if (!allowLowPrecision && isPTOLowPrecisionType(mr.getElementType())) - return op->emitOpError() << name << ": dtype " << mr.getElementType() - << " is not supported by this op yet"; - } else { - return op->emitOpError() << "expects " << name << " to be a !pto.tile_buf or rank-2 memref"; - } - - auto validShape = getValidShapeVec(ty); - if (validShape.size() != 2) - return op->emitOpError() << "expects " << name << " to have a rank-2 valid_shape"; - auto shape = getShapeVec(ty); - for (unsigned i = 0; i < 2; ++i) { - if (shape[i] != ShapedType::kDynamic && validShape[i] != ShapedType::kDynamic && - validShape[i] > shape[i]) - return op->emitOpError() << "expects " << name << " to satisfy valid_shape[" << i - << "] <= shape[" << i << "]"; - } - return success(); -} - -static LogicalResult verifyTileBufSameElemType(Operation *op, Type lhs, Type rhs, - StringRef lhsName, - StringRef rhsName) { - if (!isTileLikeType(lhs) || !isTileLikeType(rhs)) - return op->emitOpError() 
<< "expects " << lhsName << " and " << rhsName - << " to be !pto.tile_buf or memref"; - if (getElemTy(lhs) != getElemTy(rhs)) - return op->emitOpError() << "expects " << lhsName << " and " << rhsName - << " to have the same element type"; - return success(); -} - -static LogicalResult verifyTileBufSameValidShape(Operation *op, Type lhs, Type rhs, - StringRef lhsName, StringRef rhsName) { - if (!isTileLikeType(lhs) || !isTileLikeType(rhs)) - return success(); - auto lhsValid = getValidShapeVec(lhs); - auto rhsValid = getValidShapeVec(rhs); - for (size_t i = 0; i < lhsValid.size() && i < rhsValid.size(); ++i) { - if (lhsValid[i] != ShapedType::kDynamic && rhsValid[i] != ShapedType::kDynamic && - lhsValid[i] != rhsValid[i]) - return op->emitOpError() << "expects " << lhsName << " and " << rhsName - << " to have the same valid_shape"; - } - if (lhsValid.size() != rhsValid.size()) - return op->emitOpError() << "expects " << lhsName << " and " << rhsName - << " to have the same valid_shape"; - return success(); -} - -static LogicalResult verifyTileBufSameLogicalExtent(Operation *op, Type lhs, - Type rhs, StringRef lhsName, - StringRef rhsName, - bool compareValidShape) { - if (!isTileLikeType(lhs) || !isTileLikeType(rhs)) - return success(); - - auto lhsExtent = getLogicalTileExtentVec(lhs, compareValidShape); - auto rhsExtent = getLogicalTileExtentVec(rhs, compareValidShape); - auto emitMismatch = [&]() -> LogicalResult { - if (compareValidShape) - return op->emitOpError() << "expects " << lhsName << " and " << rhsName - << " to have the same valid_shape"; - return op->emitOpError() << "expects " << lhsName << " and " << rhsName - << " to have compatible shapes"; - }; - if (lhsExtent.size() != rhsExtent.size()) - return emitMismatch(); - - for (size_t i = 0, e = lhsExtent.size(); i < e; ++i) { - if (lhsExtent[i] != ShapedType::kDynamic && - rhsExtent[i] != ShapedType::kDynamic && lhsExtent[i] != rhsExtent[i]) - return emitMismatch(); - } - return success(); -} - -static 
LogicalResult verifyScaleTileMatchesOperand(Operation *op, Type scaleTy, - Type operandTy, - StringRef scaleName, - StringRef operandName) { - if (failed(verifyTileBufCommon(op, scaleTy, scaleName))) - return failure(); - auto scaleSpace = getPTOMemorySpaceEnum(scaleTy); - if (!scaleSpace || *scaleSpace != pto::AddressSpace::SCALING) - return op->emitOpError() << "expects " << scaleName - << " to be in the scaling address space"; - - auto scaleShape = getShapeVec(scaleTy); - auto operandShape = getShapeVec(operandTy); - if (scaleShape.size() != operandShape.size()) - return op->emitOpError() << "expects " << scaleName << " and " << operandName - << " to have the same rank"; - for (size_t i = 0; i < scaleShape.size(); ++i) { - if (scaleShape[i] != ShapedType::kDynamic && - operandShape[i] != ShapedType::kDynamic && - scaleShape[i] != operandShape[i]) - return op->emitOpError() << "expects " << scaleName << " and " << operandName - << " to have the same shape"; - } - - auto scaleValid = getValidShapeVec(scaleTy); - auto operandValid = getValidShapeVec(operandTy); - if (scaleValid.size() != operandValid.size()) - return op->emitOpError() << "expects " << scaleName << " and " << operandName - << " to have the same valid_shape"; - for (size_t i = 0; i < scaleValid.size(); ++i) { - if (scaleValid[i] != ShapedType::kDynamic && - operandValid[i] != ShapedType::kDynamic && - scaleValid[i] != operandValid[i]) - return op->emitOpError() << "expects " << scaleName << " and " << operandName - << " to have the same valid_shape"; - } - return success(); -} - -static LogicalResult verifyPartialValidPattern(Operation *op, Type src0Ty, - Type src1Ty, Type dstTy) { - auto src0Valid = getValidShapeVec(src0Ty); - auto src1Valid = getValidShapeVec(src1Ty); - auto dstValid = getValidShapeVec(dstTy); - if (src0Valid.size() != 2 || src1Valid.size() != 2 || dstValid.size() != 2) - return op->emitOpError("expects src0, src1, and dst to have rank-2 valid_shape"); - - auto lessEqualKnown = 
[](int64_t lhs, int64_t rhs) { - return lhs == ShapedType::kDynamic || rhs == ShapedType::kDynamic || lhs <= rhs; - }; - auto equalsKnown = [](ArrayRef lhs, ArrayRef rhs) { - for (auto [a, b] : llvm::zip(lhs, rhs)) { - if (a != ShapedType::kDynamic && b != ShapedType::kDynamic && a != b) - return false; - } - return true; - }; - - for (unsigned i = 0; i < 2; ++i) { - if (!lessEqualKnown(src0Valid[i], dstValid[i]) || - !lessEqualKnown(src1Valid[i], dstValid[i])) - return op->emitOpError( - "expects src0/src1 valid_shape to be less than or equal to dst valid_shape"); - } - if (!equalsKnown(src0Valid, dstValid) && !equalsKnown(src1Valid, dstValid)) - return op->emitOpError( - "expects at least one of src0/src1 valid_shape to match dst valid_shape"); - return success(); -} - -[[maybe_unused]] static bool hasKnownZeroValidRegion(Type ty) { - auto valid = getValidShapeVec(ty); - if (valid.size() != 2) - return false; - return valid[0] == 0 || valid[1] == 0; -} - -static LogicalResult verifyScalarTileOp(Operation *op, Type srcTy, Type dstTy, - StringRef srcName, StringRef dstName, - bool requireValidRowsEqual, - bool requireValidColsEqual) { - if (failed(verifyTileBufCommon(op, srcTy, srcName)) || - failed(verifyTileBufCommon(op, dstTy, dstName))) - return failure(); - auto srcSpace = getPTOMemorySpaceEnum(srcTy); - auto dstSpace = getPTOMemorySpaceEnum(dstTy); - if (!srcSpace || *srcSpace != pto::AddressSpace::VEC) - return op->emitOpError() << "expects " << srcName - << " to be in the vec address space"; - if (!dstSpace || *dstSpace != pto::AddressSpace::VEC) - return op->emitOpError() << "expects " << dstName - << " to be in the vec address space"; - if (failed(verifyTileBufSameElemType(op, srcTy, dstTy, srcName, dstName))) - return failure(); - - auto srcValid = getValidShapeVec(srcTy); - auto dstValid = getValidShapeVec(dstTy); - if (srcValid.size() != 2 || dstValid.size() != 2) - return op->emitOpError() - << "expects " << srcName << " and " << dstName - << " to 
have rank-2 valid_shape"; - if (requireValidRowsEqual && - srcValid[0] != ShapedType::kDynamic && dstValid[0] != ShapedType::kDynamic && - srcValid[0] != dstValid[0]) - return op->emitOpError() - << "expects " << srcName << " and " << dstName - << " to have the same valid_shape[0]"; - if (requireValidColsEqual && - srcValid[1] != ShapedType::kDynamic && dstValid[1] != ShapedType::kDynamic && - srcValid[1] != dstValid[1]) - return op->emitOpError() - << "expects " << srcName << " and " << dstName - << " to have the same valid_shape[1]"; - return success(); -} - -static FailureOr -verifyMatchingRowMajorBinaryTileOpCommon(Operation *op, Type src0Ty, Type src1Ty, - Type dstTy) { - if (failed(verifyTileBufCommon(op, src0Ty, "src0")) || - failed(verifyTileBufCommon(op, src1Ty, "src1")) || - failed(verifyTileBufCommon(op, dstTy, "dst"))) - return failure(); - if (failed(verifyTileBufSameElemType(op, src0Ty, src1Ty, "src0", "src1")) || - failed(verifyTileBufSameElemType(op, src0Ty, dstTy, "src0", "dst")) || - failed(verifyTileBufSameValidShape(op, src0Ty, src1Ty, "src0", "src1")) || - failed(verifyTileBufSameValidShape(op, src0Ty, dstTy, "src0", "dst"))) - return failure(); - if (!isRowMajorTileBuf(src0Ty) || !isRowMajorTileBuf(src1Ty) || - !isRowMajorTileBuf(dstTy)) { - op->emitOpError("expects src0, src1, and dst to use row-major layout"); - return failure(); - } - return getElemTy(src0Ty); -} - -static FailureOr -verifyNumericScalarTileOpCommon(Operation *op, Type srcTy, Type dstTy, - Type scalarTy, bool requireValidRowsEqual) { - if (failed(verifyScalarTileOp(op, srcTy, dstTy, "src", "dst", - requireValidRowsEqual, - /*requireValidColsEqual=*/true))) - return failure(); - if (!mlir::isa(scalarTy)) { - op->emitOpError("scalar must be a scalar type (integer/float)"); - return failure(); - } - return getElemTy(srcTy); -} - -static FailureOr -verifyShiftLikeBinaryTileOpCommon(Operation *op, Type src0Ty, Type src1Ty, - Type dstTy) { - if (failed(verifyTileBufCommon(op, 
src0Ty, "src0")) || - failed(verifyTileBufCommon(op, src1Ty, "src1")) || - failed(verifyTileBufCommon(op, dstTy, "dst"))) - return failure(); - Type e0 = getElemTy(src0Ty); - Type e1 = getElemTy(src1Ty); - if (!e0 || !e1) { - op->emitOpError("failed to get element type for operands"); - return failure(); - } - if (e0 != e1) { - op->emitOpError("expects src0 and src1 to have the same element type"); - return failure(); - } - if (!isRowMajorTileBuf(src0Ty) || !isRowMajorTileBuf(src1Ty) || - !isRowMajorTileBuf(dstTy)) { - op->emitOpError("expects src0, src1, and dst to use row-major layout"); - return failure(); - } - if (failed(verifyTileBufSameValidShape(op, src0Ty, dstTy, "src0", "dst")) || - failed(verifyTileBufSameValidShape(op, src1Ty, dstTy, "src1", "dst"))) - return failure(); - return e0; -} - -static FailureOr verifyDistinctRowMajorUnaryTileOpCommon( - Operation *op, Value src, Value dst, StringRef srcName = "src", - StringRef dstName = "dst") { - if (src == dst) { - op->emitOpError("expects src and dst to use different storage"); - return failure(); - } - Type srcTy = src.getType(); - Type dstTy = dst.getType(); - if (failed(verifyTileBufCommon(op, srcTy, srcName)) || - failed(verifyTileBufCommon(op, dstTy, dstName))) - return failure(); - - Type srcElem = getElemTy(srcTy); - Type dstElem = getElemTy(dstTy); - if (!srcElem || !dstElem) { - op->emitOpError("failed to get element type for src/dst"); - return failure(); - } - if (srcElem != dstElem) { - op->emitOpError("expects src and dst to have the same element type"); - return failure(); - } - if (!isRowMajorTileBuf(srcTy) || !isRowMajorTileBuf(dstTy)) { - op->emitOpError("expects src and dst to use row-major layout"); - return failure(); - } - if (failed(verifyTileBufSameValidShape(op, srcTy, dstTy, srcName, dstName))) - return failure(); - return srcElem; -} - -static LogicalResult verifyArithmeticElemTypeForArch( - Operation *op, Type elemTy, PTOArch targetArch, bool allowInt8OnA5, - bool allowBf16OnA5, 
StringRef a2a3Error, StringRef a5Error) { - bool supported = elemTy.isInteger(32) || elemTy.isInteger(16) || - elemTy.isF16() || elemTy.isF32(); - if (targetArch == PTOArch::A5) - supported = supported || (allowInt8OnA5 && elemTy.isInteger(8)) || - (allowBf16OnA5 && elemTy.isBF16()); - if (supported) - return success(); - return op->emitOpError(targetArch == PTOArch::A5 ? a5Error : a2a3Error); -} - -static LogicalResult verifyArithmeticBinaryTileOpWithArchDispatch( - Operation *op, Type src0Ty, Type src1Ty, Type dstTy, bool allowInt8OnA5, - bool allowBf16OnA5, StringRef a2a3Error, StringRef a5Error) { - auto verifyByArch = [&](PTOArch targetArch) -> LogicalResult { - FailureOr elemOr = - verifyMatchingRowMajorBinaryTileOpCommon(op, src0Ty, src1Ty, dstTy); - if (failed(elemOr)) - return failure(); - return verifyArithmeticElemTypeForArch(op, *elemOr, targetArch, - allowInt8OnA5, allowBf16OnA5, - a2a3Error, a5Error); - }; - auto verifyA2A3 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A3); }; - auto verifyA5 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A5); }; - return dispatchVerifierByArch(op, verifyA2A3, verifyA5); -} - -static LogicalResult verifyArithmeticScalarTileOpWithArchDispatch( - Operation *op, Type srcTy, Type dstTy, Type scalarTy, bool allowInt8OnA5, - bool allowBf16OnA5, StringRef a2a3Error, StringRef a5Error, - bool requireValidRowsEqualOnA2A3 = true, - bool requireValidRowsEqualOnA5 = false) { - auto verifyByArch = [&](PTOArch targetArch, - bool requireValidRowsEqual) -> LogicalResult { - FailureOr elemOr = verifyNumericScalarTileOpCommon( - op, srcTy, dstTy, scalarTy, requireValidRowsEqual); - if (failed(elemOr)) - return failure(); - return verifyArithmeticElemTypeForArch(op, *elemOr, targetArch, - allowInt8OnA5, allowBf16OnA5, - a2a3Error, a5Error); - }; - auto verifyA2A3 = [&]() -> LogicalResult { - return verifyByArch(PTOArch::A3, requireValidRowsEqualOnA2A3); - }; - auto verifyA5 = [&]() -> LogicalResult { - return 
verifyByArch(PTOArch::A5, requireValidRowsEqualOnA5); - }; - return dispatchVerifierByArch(op, verifyA2A3, verifyA5); -} - -static LogicalResult verifyTColReductionElemTypeForArch( - Operation *op, Type elemTy, PTOArch targetArch, bool allowInt8OnA5, - bool allowBf16OnA5, StringRef a2a3Error, StringRef a5Error) { - bool ok = elemTy.isF16() || elemTy.isF32() || elemTy.isInteger(16) || - elemTy.isInteger(32); - if (targetArch == PTOArch::A5) - ok = ok || (allowInt8OnA5 && elemTy.isInteger(8)) || - (allowBf16OnA5 && elemTy.isBF16()); - if (ok) - return success(); - return op->emitOpError(targetArch == PTOArch::A5 ? a5Error : a2a3Error); -} - -static LogicalResult verifyTColReductionOpWithArchDispatch( - Operation *op, Type srcTy, Type dstTy, bool requireNonZeroSrcOnA2A3, - bool requireNonZeroSrcOnA5, bool allowInt8OnA5, bool allowBf16OnA5, - StringRef a2a3Error, StringRef a5Error) { - auto verifyByArch = [&](PTOArch targetArch, - bool requireNonZeroSrc) -> LogicalResult { - if (failed(verifyNDStyleVecTile(op, srcTy, "src")) || - failed(verifyNDStyleVecTile(op, dstTy, "dst"))) - return failure(); - if (getElemTy(srcTy) != getElemTy(dstTy)) - return op->emitOpError("expects src and dst to have the same element type"); - if (failed(verifyColReductionValidRegion(op, srcTy, dstTy, requireNonZeroSrc))) - return failure(); - Type elem = getElemTy(srcTy); - return verifyTColReductionElemTypeForArch(op, elem, targetArch, allowInt8OnA5, - allowBf16OnA5, a2a3Error, a5Error); - }; - auto verifyA2A3 = [&]() -> LogicalResult { - return verifyByArch(PTOArch::A3, requireNonZeroSrcOnA2A3); - }; - auto verifyA5 = [&]() -> LogicalResult { - return verifyByArch(PTOArch::A5, requireNonZeroSrcOnA5); - }; - return dispatchVerifierByArch(op, verifyA2A3, verifyA5); -} - -static LogicalResult verifyTColArgReductionOpCommon(Operation *op, Type srcTy, - Type tmpTy, Type dstTy) { - if (failed(verifyNDStyleVecTile(op, srcTy, "src")) || - failed(verifyVecTileCommon(op, tmpTy, "tmp")) || - 
failed(verifyColArgReductionDstLayout(op, dstTy, "dst"))) - return failure(); - if (failed(verifyTileBufSameElemType(op, srcTy, tmpTy, "src", "tmp")) || - failed(verifyTileBufSameValidShape(op, srcTy, tmpTy, "src", "tmp"))) - return failure(); - if (failed(verifyColReductionValidRegion(op, srcTy, dstTy, - /*requireNonZeroSrc=*/true))) - return failure(); - Type srcElemTy = getElemTy(srcTy); - unsigned srcElemBits = srcElemTy ? getPTOStorageElemBitWidth(srcElemTy) : 0; - if (!(mlir::isa(srcElemTy) && - (srcElemBits == 8 || srcElemBits == 16 || srcElemBits == 32))) - return op->emitOpError( - "expects src/tmp element type to be 1, 2, or 4 bytes wide"); - auto dstInt = dyn_cast(getElemTy(dstTy)); - if (!dstInt || dstInt.getWidth() != 32) - return op->emitOpError("expects dst element type to be i32 or ui32"); - return success(); -} - -static bool hasCompatibleKnownExtent(int64_t lhs, int64_t rhs) { - return lhs == ShapedType::kDynamic || rhs == ShapedType::kDynamic || lhs == rhs; -} - -static bool isKnownUnitExtent(int64_t value) { - return value == ShapedType::kDynamic || value == 1; -} - -static LogicalResult verifyVecTileStorage(Operation *op, Type ty, StringRef name) { - if (failed(verifyTileBufCommon(op, ty, name))) - return failure(); - auto as = getPTOMemorySpaceEnum(ty); - if (!as || *as != pto::AddressSpace::VEC) - return op->emitOpError() << "expects " << name << " to be in the vec address space"; - return success(); -} - -static LogicalResult verifyVecTileCommonA2A3(Operation *op, Type ty, - StringRef name) { - if (failed(verifyTileBufCommon(op, ty, name))) - return failure(); - auto tb = dyn_cast(ty); - auto as = getPTOMemorySpaceEnum(ty); - if (as && *as != pto::AddressSpace::VEC) - return op->emitOpError() << "expects " << name << " to be in the vec address space"; - if (tb && tb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor)) - return op->emitOpError() << "expects " << name << " to use the row_major blayout"; - return success(); -} - -static 
LogicalResult verifyVecTileCommonA5(Operation *op, Type ty, - StringRef name) { - return verifyVecTileCommonA2A3(op, ty, name); -} - -static LogicalResult verifyVecTileCommon(Operation *op, Type ty, StringRef name) { - switch (getVerifierTargetArch(op)) { - case VerifierTargetArch::A2A3: - return verifyVecTileCommonA2A3(op, ty, name); - case VerifierTargetArch::A5: - return verifyVecTileCommonA5(op, ty, name); - } - return failure(); -} - -static LogicalResult verifyVecTileUnaryOp(Operation *op, Type srcTy, Type dstTy, - StringRef srcName, - StringRef dstName, - bool allowBf16, - bool allowInt8) { - if (failed(verifyVecTileCommon(op, srcTy, srcName)) || - failed(verifyVecTileCommon(op, dstTy, dstName))) - return failure(); - if (failed(verifyTileBufSameElemType(op, srcTy, dstTy, srcName, dstName))) - return failure(); - if (!isSupportedVecElemType(getElemTy(srcTy), allowBf16, allowInt8)) - return op->emitOpError() << "expects vec tile element types to be supported"; - return success(); -} - -static LogicalResult verifyAccTileCommonA2A3(Operation *op, Type ty, - StringRef name) { - if (failed(verifyTileBufCommon(op, ty, name))) - return failure(); - auto as = getPTOMemorySpaceEnum(ty); - if (!as || *as != pto::AddressSpace::ACC) - return op->emitOpError() << "expects " << name << " to be in the acc address space"; - return success(); -} - -static LogicalResult verifyAccTileCommonA5(Operation *op, Type ty, - StringRef name) { - return verifyAccTileCommonA2A3(op, ty, name); -} - -static LogicalResult verifyAccTileCommon(Operation *op, Type ty, StringRef name) { - switch (getVerifierTargetArch(op)) { - case VerifierTargetArch::A2A3: - return verifyAccTileCommonA2A3(op, ty, name); - case VerifierTargetArch::A5: - return verifyAccTileCommonA5(op, ty, name); - } - return failure(); -} - -static LogicalResult verifyMatTileOperandsA2A3(Operation *op, Type lhsTy, - Type rhsTy, Type dstTy) { - if (failed(verifyTileBufCommon(op, lhsTy, "lhs")) || - 
failed(verifyTileBufCommon(op, rhsTy, "rhs")) || - failed(verifyAccTileCommon(op, dstTy, "dst"))) - return failure(); - auto lhsSpace = getPTOMemorySpaceEnum(lhsTy); - auto rhsSpace = getPTOMemorySpaceEnum(rhsTy); - auto dstSpace = getPTOMemorySpaceEnum(dstTy); - if (!lhsSpace || !rhsSpace || !dstSpace) - return op->emitOpError("expects lhs, rhs, and dst to have explicit address spaces"); - if (*lhsSpace != pto::AddressSpace::LEFT || *rhsSpace != pto::AddressSpace::RIGHT || - *dstSpace != pto::AddressSpace::ACC) - return op->emitOpError( - "expects lhs, rhs, and dst to use the left, right, and acc address spaces"); - auto lhsShape = getMatmulLogicalShapeVec(lhsTy); - auto rhsShape = getMatmulLogicalShapeVec(rhsTy); - auto dstShape = getMatmulLogicalShapeVec(dstTy); - if ((lhsShape[0] != dstShape[0] || rhsShape[1] != dstShape[1] || lhsShape[1] != rhsShape[0])) - return op->emitOpError( - "expects static matmul tile shapes lhs[M,K], rhs[K,N], and dst[M,N]"); - auto lhsValid = getValidShapeVec(lhsTy); - auto rhsValid = getValidShapeVec(rhsTy); - if (lhsValid.size() == 2 && rhsValid.size() == 2) { - int64_t m = lhsValid[0]; - int64_t k = lhsValid[1]; - int64_t n = rhsValid[1]; - if ((m != ShapedType::kDynamic && (m < 1 || m > 4095)) || - (k != ShapedType::kDynamic && (k < 1 || k > 4095)) || - (n != ShapedType::kDynamic && (n < 1 || n > 4095))) - return op->emitOpError("expects m, k, and n valid sizes to be in [1, 4095]"); - } - return success(); -} - -static LogicalResult verifyMatTileOperandsA5(Operation *op, Type lhsTy, - Type rhsTy, Type dstTy) { - if (failed(verifyMatTileOperandsA2A3(op, lhsTy, rhsTy, dstTy))) - return failure(); - - auto lhsTb = mlir::dyn_cast(lhsTy); - auto rhsTb = mlir::dyn_cast(rhsTy); - auto dstTb = mlir::dyn_cast(dstTy); - if (!lhsTb || !rhsTb || !dstTb) - return success(); - - if (lhsTb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor)) - return op->emitOpError("expects lhs to use the col_major blayout on A5"); - if 
(rhsTb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor)) - return op->emitOpError("expects rhs to use the row_major blayout on A5"); - if (dstTb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor)) - return op->emitOpError("expects dst to use the col_major blayout on A5"); - - if (lhsTb.getSLayoutValueI32() != static_cast(pto::SLayout::RowMajor)) - return op->emitOpError("expects lhs to use the row_major slayout on A5"); - if (rhsTb.getSLayoutValueI32() != static_cast(pto::SLayout::ColMajor)) - return op->emitOpError("expects rhs to use the col_major slayout on A5"); - if (dstTb.getSLayoutValueI32() != static_cast(pto::SLayout::RowMajor)) - return op->emitOpError("expects dst to use the row_major slayout on A5"); - return success(); -} - -static LogicalResult verifyMatTileOperands(Operation *op, Type lhsTy, Type rhsTy, - Type dstTy) { - switch (getVerifierTargetArch(op)) { - case VerifierTargetArch::A2A3: - return verifyMatTileOperandsA2A3(op, lhsTy, rhsTy, dstTy); - case VerifierTargetArch::A5: - return verifyMatTileOperandsA5(op, lhsTy, rhsTy, dstTy); - } - return failure(); -} - -static LogicalResult verifyGemvTileOperandsA2A3(Operation *op, Type lhsTy, - Type rhsTy, Type dstTy) { - if (failed(verifyTileBufCommon(op, lhsTy, "lhs")) || - failed(verifyTileBufCommon(op, rhsTy, "rhs")) || - failed(verifyAccTileCommon(op, dstTy, "dst"))) - return failure(); - - auto lhsSpace = getPTOMemorySpaceEnum(lhsTy); - auto rhsSpace = getPTOMemorySpaceEnum(rhsTy); - if (!lhsSpace || !rhsSpace) - return op->emitOpError("expects lhs and rhs to have explicit address spaces"); - if (*lhsSpace != pto::AddressSpace::LEFT || *rhsSpace != pto::AddressSpace::RIGHT) - return op->emitOpError( - "expects lhs and rhs to use the left and right address spaces"); - - auto lhsValid = getValidShapeVec(lhsTy); - auto rhsValid = getValidShapeVec(rhsTy); - auto dstValid = getValidShapeVec(dstTy); - if (lhsValid[0] != ShapedType::kDynamic && lhsValid[0] != 1) - return 
op->emitOpError("expects lhs valid_shape[0] to be 1 for tgemv"); - if (isa(dstTy) && dstValid[0] != ShapedType::kDynamic && - dstValid[0] != 1) - return op->emitOpError("expects dst valid_shape[0] to be 1 for tgemv"); - if (lhsValid[1] != ShapedType::kDynamic && rhsValid[0] != ShapedType::kDynamic && - lhsValid[1] != rhsValid[0]) - return op->emitOpError() - << "expects lhs valid_shape[1] to equal rhs valid_shape[0], but got " - << lhsValid[1] << " vs " << rhsValid[0]; - if (rhsValid[1] != ShapedType::kDynamic && dstValid[1] != ShapedType::kDynamic && - rhsValid[1] != dstValid[1]) - return op->emitOpError() - << "expects rhs valid_shape[1] to equal dst valid_shape[1], but got " - << rhsValid[1] << " vs " << dstValid[1]; - return success(); -} - -static LogicalResult verifyGemvTileOperandsA5(Operation *op, Type lhsTy, - Type rhsTy, Type dstTy) { - if (failed(verifyGemvTileOperandsA2A3(op, lhsTy, rhsTy, dstTy))) - return failure(); - return verifyMatTileOperandsA5(op, lhsTy, rhsTy, dstTy); -} - -static LogicalResult verifyGemvTileOperands(Operation *op, Type lhsTy, Type rhsTy, - Type dstTy) { - switch (getVerifierTargetArch(op)) { - case VerifierTargetArch::A2A3: - return verifyGemvTileOperandsA2A3(op, lhsTy, rhsTy, dstTy); - case VerifierTargetArch::A5: - return verifyGemvTileOperandsA5(op, lhsTy, rhsTy, dstTy); - } - return failure(); -} - -static LogicalResult verifyMatBiasTileA2A3(Operation *op, Type biasTy, Type dstTy, - bool requireFloatBias) { - if (failed(verifyTileBufCommon(op, biasTy, "bias"))) - return failure(); - auto biasSpace = getPTOMemorySpaceEnum(biasTy); - if (!biasSpace || *biasSpace != pto::AddressSpace::BIAS) - return op->emitOpError("expects bias to be in the bias address space"); - auto biasShape = getShapeVec(biasTy); - if (biasShape[0] != ShapedType::kDynamic && biasShape[0] != 1) - return op->emitOpError("expects bias to have 1 row"); - if (requireFloatBias) { - if (!getElemTy(biasTy).isF32()) - return op->emitOpError("expects bias to have 
element type f32"); - } else if (getElemTy(biasTy) != getElemTy(dstTy)) { - return op->emitOpError("expects bias and dst to have the same element type"); - } - return success(); -} - -static LogicalResult verifyMatBiasTileA5(Operation *op, Type biasTy, Type dstTy, - bool requireFloatBias) { - if (failed(verifyMatBiasTileA2A3(op, biasTy, dstTy, requireFloatBias))) - return failure(); - if (auto biasTb = dyn_cast(biasTy)) { - if (biasTb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor)) - return op->emitOpError("expects bias to use the row_major blayout on A5"); - } - return success(); -} - -static LogicalResult verifyMatBiasTile(Operation *op, Type biasTy, Type dstTy, - bool requireFloatBias) { - switch (getVerifierTargetArch(op)) { - case VerifierTargetArch::A2A3: - return verifyMatBiasTileA2A3(op, biasTy, dstTy, requireFloatBias); - case VerifierTargetArch::A5: - return verifyMatBiasTileA5(op, biasTy, dstTy, requireFloatBias); - } - return failure(); -} - -static LogicalResult verifyMatmulTypeTriple(Operation *op, Type lhsElemTy, - Type rhsElemTy, Type dstElemTy) { - bool isA5 = getVerifierTargetArch(op) == VerifierTargetArch::A5; - auto isInt8 = [](Type ty) { - return ty.isInteger(8); - }; - if (dstElemTy.isInteger(32) && isInt8(lhsElemTy) && isInt8(rhsElemTy)) - return success(); - - auto isSupportedFpInput = [](Type ty) { - return ty.isF16() || ty.isBF16() || ty.isF32(); - }; - if (dstElemTy.isF32() && lhsElemTy == rhsElemTy && isSupportedFpInput(lhsElemTy)) - return success(); - - if (isA5 && dstElemTy.isF32() && lhsElemTy == rhsElemTy) { - if (auto ft = mlir::dyn_cast(lhsElemTy)) { - unsigned width = ft.getWidth(); - if (width == 8 || width == 16 || width == 32) - return success(); - } - } - - return op->emitOpError() - << "expects (dst, lhs, rhs) element types to match one of " - "(i32, i8, i8), (f32, f16, f16), (f32, bf16, bf16), (f32, f32, f32)" - << (isA5 ? 
", or an A5-supported fp8 pair" : ""); -} - -LogicalResult pto::TAddOp::verify() { - return verifyArithmeticBinaryTileOpWithArchDispatch( - getOperation(), getSrc0().getType(), getSrc1().getType(), getDst().getType(), - /*allowInt8OnA5=*/true, /*allowBf16OnA5=*/true, - "expects A2/A3 tadd element type to be i32/i16/f16/f32", - "expects A5 tadd element type to be i32/i16/i8/f16/bf16/f32"); -} - -LogicalResult pto::TAddCOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - Type t0 = getSrc0().getType(); - Type t1 = getSrc1().getType(); - Type t2 = getSrc2().getType(); - Type td = getDst().getType(); - - if (!isPTOShapedLike(t0) || !isPTOShapedLike(t1) || - !isPTOShapedLike(t2) || !isPTOShapedLike(td)) - return emitOpError("expects src0/src1/src2/dst to be memref/tile_buf types"); - - auto s0 = getShapeVec(t0); - auto s1 = getShapeVec(t1); - auto s2 = getShapeVec(t2); - auto sd = getShapeVec(td); - if (s0 != s1 || s0 != s2 || s0 != sd) - return emitOpError("expects src0/src1/src2/dst to have the same shape"); - return success(); -} -LogicalResult pto::TAddSOp::verify() { - return verifyArithmeticScalarTileOpWithArchDispatch( - getOperation(), getSrc().getType(), getDst().getType(), getScalar().getType(), - /*allowInt8OnA5=*/true, /*allowBf16OnA5=*/true, - "expects A2/A3 tadds element type to be i32/i16/f16/f32", - "expects A5 tadds element type to be i32/i16/i8/f16/bf16/f32", - /*requireValidRowsEqualOnA2A3=*/true, - /*requireValidRowsEqualOnA5=*/false); -} - -LogicalResult pto::TAxpyOp::verify() { - auto verifyCommon = [&]() -> LogicalResult { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyVecTileCommon(*this, srcTy, "src")) || - failed(verifyVecTileCommon(*this, dstTy, "dst"))) - return failure(); - if (failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) - return failure(); - - Type scalarTy = getScalar().getType(); - Type srcElem = getElemTy(srcTy); - if (scalarTy 
!= srcElem) - return emitOpError("expects scalar type to match src element type"); - if (getShapeVec(srcTy) != getShapeVec(dstTy)) - return emitOpError("expects src and dst to have the same shape"); - return success(); - }; - - auto verifyA2A3 = [&]() -> LogicalResult { - if (failed(verifyCommon())) - return failure(); - Type srcElem = getElemTy(getSrc().getType()); - Type dstElem = getElemTy(getDst().getType()); - bool sameType = srcElem == dstElem; - bool widenF16ToF32 = srcElem.isF16() && dstElem.isF32(); - if (!(sameType || widenF16ToF32)) - return emitOpError( - "expects dst/src element types to match, or dst=f32 and src=f16"); - if (!(dstElem.isF16() || dstElem.isF32())) - return emitOpError("expects A2/A3 taxpy dst element type to be f16/f32"); - if (!(srcElem.isF16() || srcElem.isF32())) - return emitOpError("expects A2/A3 taxpy src element type to be f16/f32"); - return success(); - }; - - auto verifyA5 = [&]() -> LogicalResult { - if (failed(verifyCommon())) - return failure(); - Type srcElem = getElemTy(getSrc().getType()); - Type dstElem = getElemTy(getDst().getType()); - bool sameType = srcElem == dstElem; - bool widenF16ToF32 = srcElem.isF16() && dstElem.isF32(); - if (!(sameType || widenF16ToF32)) - return emitOpError( - "expects dst/src element types to match, or dst=f32 and src=f16"); - if (!(dstElem.isF16() || dstElem.isF32() || dstElem.isBF16())) - return emitOpError("expects A5 taxpy dst element type to be f16/bf16/f32"); - if (!(srcElem.isF16() || srcElem.isF32() || srcElem.isBF16())) - return emitOpError("expects A5 taxpy src element type to be f16/bf16/f32"); - return success(); - }; - - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -LogicalResult pto::TAddSCOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - Type ts0 = getSrc0().getType(); - Type ts1 = getSrc1().getType(); - Type td = getDst().getType(); - if (!isPTOShapedLike(ts0) || !isPTOShapedLike(ts1) || 
!isPTOShapedLike(td)) - return emitOpError("expects src0/src1/dst to be PTO shaped-like types"); - - auto s0 = getShapeVec(ts0); - auto s1 = getShapeVec(ts1); - auto sd = getShapeVec(td); - if (s0 != s1 || s0 != sd) - return emitOpError("expects src0/src1/dst to have the same shape"); - return success(); -} - -LogicalResult pto::TAndOp::verify() { - auto verifyCommon = [&]() -> FailureOr { - return verifyMatchingRowMajorBinaryTileOpCommon( - getOperation(), getSrc0().getType(), getSrc1().getType(), - getDst().getType()); - }; - - auto verifyA2A3 = [&]() -> LogicalResult { - FailureOr elemOr = verifyCommon(); - if (failed(elemOr)) - return failure(); - auto it = mlir::dyn_cast(*elemOr); - if (!it || (it.getWidth() != 8 && it.getWidth() != 16)) - return emitOpError( - "expects A2/A3 tand src0, src1, and dst element type to be i8/i16"); - return success(); - }; - - auto verifyA5 = [&]() -> LogicalResult { - FailureOr elemOr = verifyCommon(); - if (failed(elemOr)) - return failure(); - auto it = mlir::dyn_cast(*elemOr); - if (!it || (it.getWidth() != 8 && it.getWidth() != 16 && - it.getWidth() != 32)) - return emitOpError( - "expects A5 tand src0, src1, and dst element type to be i8/i16/i32"); - return success(); - }; - - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TConcatOp::verify() { - auto verifyCommon = [&]() -> FailureOr { - Type t0 = getSrc0().getType(); - Type t1 = getSrc1().getType(); - Type td = getDst().getType(); - if (failed(verifyTileBufCommon(*this, t0, "src0")) || - failed(verifyTileBufCommon(*this, t1, "src1")) || - failed(verifyTileBufCommon(*this, td, "dst"))) - return failure(); - - Type e0 = getElemTy(t0); - Type e1 = getElemTy(t1); - Type ed = getElemTy(td); - if (!e0 || !e1 || !ed) { - emitOpError("failed to get element type for operands"); - return failure(); - } - if (e0 != e1 || e0 != ed) { - emitOpError("expects src0, src1, and dst to have the same element type"); - return 
failure(); - } - - auto v0 = getValidShapeVec(getSrc0()); - auto v1 = getValidShapeVec(getSrc1()); - auto vd = getValidShapeVec(getDst()); - if (v0.size() != 2 || v1.size() != 2 || vd.size() != 2) - return emitOpError("expects src0, src1, and dst to have rank-2 valid_shape"); - - // validRow must match dst (when known). - if (v0[0] != ShapedType::kDynamic && vd[0] != ShapedType::kDynamic && v0[0] != vd[0]) - return emitOpError("expects src0 valid row to match dst valid row"); - if (v1[0] != ShapedType::kDynamic && vd[0] != ShapedType::kDynamic && v1[0] != vd[0]) - return emitOpError("expects src1 valid row to match dst valid row"); - - // Total valid columns must fit within dst static cols (when known). - auto sd = getShapeVec(td); - if (sd.size() == 2 && sd[1] != ShapedType::kDynamic && - v0[1] != ShapedType::kDynamic && v1[1] != ShapedType::kDynamic) { - if (v0[1] + v1[1] > sd[1]) - return emitOpError("expects src0.valid_col + src1.valid_col <= dst.cols"); - } - - return e0; - }; - - auto verifyElemType = [&](Type elem) -> LogicalResult { - if (elem.isF16() || elem.isF32() || elem.isBF16()) - return success(); - auto it = mlir::dyn_cast(elem); - if (!it || - (it.getWidth() != 8 && it.getWidth() != 16 && it.getWidth() != 32)) - return emitOpError("expects element type to be i8, i16, i32, f16, f32, or bf16"); - return success(); - }; - - auto verifyLocVec = [&](Type ty, StringRef name) -> LogicalResult { - auto as = getPTOMemorySpaceEnum(ty); - if (!as || *as != pto::AddressSpace::VEC) - return emitOpError() << "expects " << name << " to use loc=vec"; - return success(); - }; - - auto verifyA2A3 = [&]() -> LogicalResult { - FailureOr elemOr = verifyCommon(); - if (failed(elemOr)) - return failure(); - if (failed(verifyLocVec(getSrc0().getType(), "src0")) || - failed(verifyLocVec(getSrc1().getType(), "src1")) || - failed(verifyLocVec(getDst().getType(), "dst"))) - return failure(); - return verifyElemType(*elemOr); - }; - - auto verifyA5 = [&]() -> LogicalResult { - 
FailureOr elemOr = verifyCommon(); - if (failed(elemOr)) - return failure(); - if (failed(verifyLocVec(getSrc0().getType(), "src0")) || - failed(verifyLocVec(getSrc1().getType(), "src1")) || - failed(verifyLocVec(getDst().getType(), "dst"))) - return failure(); - if (!isRowMajorTileBuf(getSrc0().getType()) || !isRowMajorTileBuf(getSrc1().getType()) || - !isRowMajorTileBuf(getDst().getType())) - return emitOpError("expects src0, src1, and dst to use row-major layout"); - return verifyElemType(*elemOr); - }; - - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -LogicalResult pto::TConcatidxOp::verify() { - auto verifyCommon = [&]() -> FailureOr> { - Type t0 = getSrc0().getType(); - Type t1 = getSrc1().getType(); - Type ti0 = getSrc0Idx().getType(); - Type ti1 = getSrc1Idx().getType(); - Type td = getDst().getType(); - if (failed(verifyTileBufCommon(*this, t0, "src0")) || - failed(verifyTileBufCommon(*this, t1, "src1")) || - failed(verifyTileBufCommon(*this, ti0, "src0Idx")) || - failed(verifyTileBufCommon(*this, ti1, "src1Idx")) || - failed(verifyTileBufCommon(*this, td, "dst"))) - return failure(); - - // Check data element type consistency. - Type e0 = getElemTy(t0); - Type e1 = getElemTy(t1); - Type ed = getElemTy(td); - if (!e0 || !e1 || !ed) { - emitOpError("failed to get element type for data operands"); - return failure(); - } - if (e0 != e1 || e0 != ed) { - emitOpError("expects src0, src1, and dst to have the same element type"); - return failure(); - } - - // Check index element type consistency. - Type ei0 = getElemTy(ti0); - Type ei1 = getElemTy(ti1); - if (!ei0 || !ei1) { - emitOpError("failed to get element type for index operands"); - return failure(); - } - if (ei0 != ei1) { - emitOpError("expects src0Idx and src1Idx to have the same element type"); - return failure(); - } - - // All five tiles must be rank-2. 
- auto v0 = getValidShapeVec(getSrc0()); - auto v1 = getValidShapeVec(getSrc1()); - auto vi0 = getValidShapeVec(getSrc0Idx()); - auto vi1 = getValidShapeVec(getSrc1Idx()); - auto vd = getValidShapeVec(getDst()); - if (v0.size() != 2 || v1.size() != 2 || vi0.size() != 2 || - vi1.size() != 2 || vd.size() != 2) - return emitOpError("expects all operands to have rank-2 valid_shape"); - - // validRow must match dst (when known). - auto checkValidRow = [&](const auto &v, StringRef name) -> LogicalResult { - if (v[0] != ShapedType::kDynamic && vd[0] != ShapedType::kDynamic && - v[0] != vd[0]) - return emitOpError("expects ") << name << " valid row to match dst valid row"; - return success(); - }; - if (failed(checkValidRow(v0, "src0")) || - failed(checkValidRow(v1, "src1")) || - failed(checkValidRow(vi0, "src0Idx")) || - failed(checkValidRow(vi1, "src1Idx"))) - return failure(); - - // Index tile must have cols >= 1 (when known). - if (vi0[1] != ShapedType::kDynamic && vi0[1] < 1) - return emitOpError("expects src0Idx valid_col >= 1"); - if (vi1[1] != ShapedType::kDynamic && vi1[1] < 1) - return emitOpError("expects src1Idx valid_col >= 1"); - - return std::make_pair(e0, ei0); - }; - - auto verifyElementTypes = [&](Type dataElem, Type idxElem) -> LogicalResult { - // Data element type: f16, f32, bf16, i8, i16, i32 (signless). - if (!dataElem.isF16() && !dataElem.isF32() && !dataElem.isBF16()) { - auto it = mlir::dyn_cast(dataElem); - if (!it || !it.isSignless() || - (it.getWidth() != 8 && it.getWidth() != 16 && it.getWidth() != 32)) - return emitOpError() - << "expects data element type to be i8, i16, i32, f16, f32, or bf16"; - } - - // Index element type: i8, i16, i32 (signless). 
- auto it = mlir::dyn_cast(idxElem); - if (!it || !it.isSignless() || - (it.getWidth() != 8 && it.getWidth() != 16 && it.getWidth() != 32)) - return emitOpError() - << "expects index element type to be i8, i16, or i32"; - return success(); - }; - - auto verifyLocVec = [&](Type ty, StringRef name) -> LogicalResult { - auto as = getPTOMemorySpaceEnum(ty); - if (!as || *as != pto::AddressSpace::VEC) - return emitOpError() << "expects " << name << " to use loc=vec"; - return success(); - }; - - auto verifyA2A3 = [&]() -> LogicalResult { - auto elemOr = verifyCommon(); - if (failed(elemOr)) - return failure(); - if (failed(verifyLocVec(getSrc0().getType(), "src0")) || - failed(verifyLocVec(getSrc1().getType(), "src1")) || - failed(verifyLocVec(getSrc0Idx().getType(), "src0Idx")) || - failed(verifyLocVec(getSrc1Idx().getType(), "src1Idx")) || - failed(verifyLocVec(getDst().getType(), "dst"))) - return failure(); - return verifyElementTypes(elemOr->first, elemOr->second); - }; - - auto verifyA5 = [&]() -> LogicalResult { - auto elemOr = verifyCommon(); - if (failed(elemOr)) - return failure(); - if (failed(verifyLocVec(getSrc0().getType(), "src0")) || - failed(verifyLocVec(getSrc1().getType(), "src1")) || - failed(verifyLocVec(getSrc0Idx().getType(), "src0Idx")) || - failed(verifyLocVec(getSrc1Idx().getType(), "src1Idx")) || - failed(verifyLocVec(getDst().getType(), "dst"))) - return failure(); - if (!isRowMajorTileBuf(getSrc0().getType()) || - !isRowMajorTileBuf(getSrc1().getType()) || - !isRowMajorTileBuf(getSrc0Idx().getType()) || - !isRowMajorTileBuf(getSrc1Idx().getType()) || - !isRowMajorTileBuf(getDst().getType())) - return emitOpError( - "expects all operands to use row-major layout"); - return verifyElementTypes(elemOr->first, elemOr->second); - }; - - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -LogicalResult pto::TAndSOp::verify() { - auto verifyCommon = [&]() -> FailureOr { - return 
verifyDistinctRowMajorUnaryTileOpCommon(getOperation(), getSrc(), - getDst(), "src", "dst"); - }; - - auto verifyA2A3 = [&]() -> LogicalResult { - FailureOr elemOr = verifyCommon(); - if (failed(elemOr)) - return failure(); - auto it = mlir::dyn_cast(*elemOr); - if (!it || (it.getWidth() != 8 && it.getWidth() != 16)) - return emitOpError( - "expects A2/A3 tands src, scalar, and dst element type to be i8/i16"); - return success(); - }; - - auto verifyA5 = [&]() -> LogicalResult { - FailureOr elemOr = verifyCommon(); - if (failed(elemOr)) - return failure(); - auto it = mlir::dyn_cast(*elemOr); - if (!it || (it.getWidth() != 8 && it.getWidth() != 16 && - it.getWidth() != 32)) - return emitOpError( - "expects A5 tands src, scalar, and dst element type to be i8/i16/i32"); - return success(); - }; - - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -LogicalResult pto::TCIOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - Type dstTy = getDst().getType(); - if (failed(verifyTileBufCommon(*this, dstTy, "dst"))) - return failure(); - - auto elemTy = mlir::dyn_cast(getElemTy(dstTy)); - if (!elemTy) - return emitOpError("expects dst element type to be integer"); - - unsigned bw = elemTy.getWidth(); - if (bw != 16 && bw != 32) - return emitOpError("expects dst element type to be i16/i32"); - - auto sTy = mlir::dyn_cast(getOperand(0).getType()); - if (!sTy) - return emitOpError("expects S to be integer"); - - if (sTy != elemTy) - return emitOpError("expects S and dst element type to be exactly the same type"); - auto shape = getShapeVec(dstTy); - if (shape.size() != 2) - return emitOpError("expects dst to be rank-2"); - if (shape[1] != ShapedType::kDynamic && shape[1] == 1) - return emitOpError("expects dst cols to be different from 1"); - - return success(); -} - -LogicalResult pto::TTriOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - - Type dstTy = 
getDst().getType(); - if (failed(verifyVecTileCommon(*this, dstTy, "dst"))) - return failure(); - - auto diagonalTy = mlir::dyn_cast(getDiagonal().getType()); - if (!diagonalTy) - return emitOpError("expects diagonal to be an integer operand"); - - int32_t upperOrLower = getUpperOrLower(); - if (upperOrLower != 0 && upperOrLower != 1) - return emitOpError("expects upperOrLower to be 0 (lower) or 1 (upper)"); - - Type elemTy = getElemTy(dstTy); - return dispatchVerifierByArch( - getOperation(), - [&]() -> LogicalResult { - if (!isSupportedVecElemType(elemTy, /*allowBf16=*/false, - /*allowInt8=*/false)) - return emitOpError() - << "expects A2/A3 dst element type to be f16/f32/i16/i32/u16/u32"; - return success(); - }, - [&]() -> LogicalResult { - if (!isSupportedVecElemType(elemTy, /*allowBf16=*/true, - /*allowInt8=*/true)) - return emitOpError() - << "expects A5 dst element type to be f16/f32/bf16/i8/i16/i32/u8/u16/u32"; - return success(); - }); -} - -LogicalResult pto::TCmpOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - Type t0 = getSrc0().getType(); - Type t1 = getSrc1().getType(); - Type td = getDst().getType(); - if (failed(verifyVecTileStorage(*this, t0, "src0")) || - failed(verifyVecTileStorage(*this, t1, "src1")) || - failed(verifyVecTileStorage(*this, td, "dst"))) - return failure(); - - Type e0 = getElemTy(t0); - Type e1 = getElemTy(t1); - Type ed = getElemTy(td); - if (!e0 || !e1 || !ed) - return emitOpError("failed to get element type for src0/src1/dst"); - if (e0 != e1) - return emitOpError("expects src0 and src1 to have the same element type"); - if (!(e0.isInteger(32) || e0.isF16() || e0.isF32())) - return emitOpError("expects A2/A3 tcmp input element type to be i32/f16/f32"); - if (!ed.isInteger(8)) - return emitOpError("expects dst element type to be i8"); - - auto valid0 = getValidShapeVec(t0); - auto valid1 = getValidShapeVec(t1); - auto validd = getValidShapeVec(td); - if (valid0.size() != 2 || valid1.size() != 2 || validd.size() != 
2) - return emitOpError("expects src0, src1, and dst to have rank-2 valid_shape"); - if (!hasCompatibleKnownExtent(valid0[0], valid1[0])) - return emitOpError("expects src0 and src1 to have the same valid row"); - if (!hasCompatibleKnownExtent(valid0[1], valid1[1])) - return emitOpError("expects src0 and src1 to have the same valid column"); - if (!hasCompatibleKnownExtent(valid0[0], validd[0])) - return emitOpError("expects src0 valid row to equal dst valid row"); - return success(); - }; - - auto verifyA5 = [&]() -> LogicalResult { - Type t0 = getSrc0().getType(); - Type t1 = getSrc1().getType(); - Type td = getDst().getType(); - if (failed(verifyTileBufCommon(*this, t0, "src0")) || - failed(verifyTileBufCommon(*this, t1, "src1")) || - failed(verifyTileBufCommon(*this, td, "dst"))) - return failure(); - - Type e0 = getElemTy(t0); - Type e1 = getElemTy(t1); - Type ed = getElemTy(td); - if (!e0 || !e1 || !ed) - return emitOpError("failed to get element type for src0/src1/dst"); - if (e0 != e1) - return emitOpError("expects src0 and src1 to have the same element type"); - bool inputOk = e0.isF16() || e0.isF32() || e0.isBF16() || - e0.isInteger(8) || e0.isInteger(16) || e0.isInteger(32); - if (!inputOk) - return emitOpError("expects A5 tcmp input element type to be i8/i16/i32/f16/bf16/f32"); - if (auto it = dyn_cast(ed)) { - if (it.getWidth() != 8) - return emitOpError("expects dst element type to be i8"); - } else { - return emitOpError("expects dst element type to be i8"); - } - - if (getShapeVec(t0) != getShapeVec(t1) || getShapeVec(t0) != getShapeVec(td)) - return emitOpError("expects src0, src1, and dst to have the same shape"); - return success(); - }; - - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -// ---- TCMPS verify ---- -LogicalResult pto::TCmpSOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyVecTileStorage(*this, srcTy, 
"src")) || - failed(verifyVecTileStorage(*this, dstTy, "dst"))) - return failure(); - - Type elemTy = getElemTy(srcTy); - if (!(elemTy.isInteger(16) || elemTy.isInteger(32) || - elemTy.isF16() || elemTy.isF32())) - return emitOpError("expects A2/A3 tcmps input element type to be i16/i32/f16/f32"); - - auto scalarTy = getScalar().getType(); - if (!(scalarTy.isIntOrIndexOrFloat())) - return emitOpError("expects scalar to be integer, index, or float"); - - auto srcValid = getValidShapeVec(srcTy); - auto dstValid = getValidShapeVec(dstTy); - if (srcValid.size() != 2 || dstValid.size() != 2) - return emitOpError("expects src and dst to have rank-2 valid_shape"); - if (srcValid[0] != ShapedType::kDynamic && dstValid[0] != ShapedType::kDynamic && - srcValid[0] != dstValid[0]) - return emitOpError("expects src and dst to have the same valid_shape[0]"); - return success(); - }; - - auto verifyA5 = [&]() -> LogicalResult { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyVecTileStorage(*this, srcTy, "src")) || - failed(verifyVecTileStorage(*this, dstTy, "dst"))) - return failure(); - - Type elemTy = getElemTy(srcTy); - if (!(elemTy.isInteger(8) || elemTy.isInteger(16) || elemTy.isInteger(32) || - elemTy.isF16() || elemTy.isF32())) - return emitOpError("expects A5 tcmps input element type to be i8/i16/i32/f16/f32"); - - auto scalarTy = getScalar().getType(); - if (!(scalarTy.isIntOrIndexOrFloat())) - return emitOpError("expects scalar to be integer, index, or float"); - - auto srcValid = getValidShapeVec(srcTy); - auto dstValid = getValidShapeVec(dstTy); - if (srcValid.size() != 2 || dstValid.size() != 2) - return emitOpError("expects src and dst to have rank-2 valid_shape"); - if (srcValid[0] != ShapedType::kDynamic && dstValid[0] != ShapedType::kDynamic && - srcValid[0] != dstValid[0]) - return emitOpError("expects src and dst to have the same valid_shape[0]"); - return success(); - }; - - return dispatchVerifierByArch(getOperation(), 
verifyA2A3, verifyA5); -} -LogicalResult pto::TColExpandOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyNDStyleVecTile(*this, srcTy, "src")) || - failed(verifyNDStyleVecTile(*this, dstTy, "dst"))) - return failure(); - if (getElemTy(srcTy) != getElemTy(dstTy)) - return emitOpError("expects src and dst to have the same element type"); - if (!isSupportedVecElemType(getElemTy(srcTy), /*allowBf16=*/true, - /*allowInt8=*/true)) - return emitOpError("expects tcolexpand element type to be supported"); - auto srcValid = getValidShapeVec(getSrc()); - auto dstValid = getValidShapeVec(getDst()); - if (srcValid.size() != 2 || dstValid.size() != 2) - return emitOpError("expects src and dst to have rank-2 valid_shape"); - if (srcValid[1] != ShapedType::kDynamic && dstValid[1] != ShapedType::kDynamic && - srcValid[1] != dstValid[1]) - return emitOpError("expects src and dst to have the same valid_shape[1]"); - return success(); -} -static LogicalResult verifyTColExpandBinaryLikeOp(Operation *op, Type t0, Type t1, - Type td, PTOArch targetArch, - StringRef opName, - bool allowIntegerTypes) { - if (!isPTOShapedLike(t0) || !isPTOShapedLike(t1) || !isPTOShapedLike(td)) - return op->emitOpError("expects src0/src1/dst to be PTO shaped-like types"); - - Type e0 = getElemTy(t0); - Type e1 = getElemTy(t1); - Type ed = getElemTy(td); - if (!e0 || !e1 || !ed) - return op->emitOpError("failed to get element type for src0/src1/dst"); - - auto isSupportedElem = [&](Type elemTy) { - if (elemTy.isF16() || elemTy.isF32()) - return true; - if (!allowIntegerTypes) - return false; - if (elemTy.isInteger(16) || elemTy.isInteger(32)) - return true; - return targetArch == PTOArch::A5 && elemTy.isInteger(8); - }; - if (!isSupportedElem(e0) || !isSupportedElem(e1) || !isSupportedElem(ed)) { - if (!allowIntegerTypes) - return op->emitOpError() << "expects " << opName - << " 
element type to be f16 or f32"; - if (targetArch == PTOArch::A5) - return op->emitOpError() << "expects A5 " << opName - << " element type to be i8/i16/i32/f16/f32"; - return op->emitOpError() << "expects A2/A3 " << opName - << " element type to be i16/i32/f16/f32"; - } - - if (getShapeVec(t0) != getShapeVec(td)) - return op->emitOpError("expects src0/dst to have same shape"); - if (failed(verifyTileBufSameValidShape(op, t0, td, "src0", "dst"))) - return failure(); - - if (auto src0TileTy = dyn_cast(t0)) { - if (src0TileTy.getBLayoutValueI32() != 0) - return op->emitOpError("expects src0 to use row-major layout"); - } - - if (auto src1TileTy = dyn_cast(t1)) { - if (src1TileTy.getBLayoutValueI32() != 0) - return op->emitOpError("expects src1 to use row-major layout"); - } - if (auto dstTileTy = dyn_cast(td)) { - if (dstTileTy.getBLayoutValueI32() != 0) - return op->emitOpError("expects dst to use row-major layout"); - } - - auto src1Valid = getValidShapeVec(t1); - auto dstValid = getValidShapeVec(td); - if (src1Valid.size() == 2 && dstValid.size() == 2 && - src1Valid[1] != ShapedType::kDynamic && dstValid[1] != ShapedType::kDynamic && - src1Valid[1] != dstValid[1]) - return op->emitOpError("expects src1 valid_shape[1] to equal dst valid_shape[1]"); - - return success(); -} -LogicalResult pto::TColExpandMulOp::verify() { - PTOArch arch = getTargetArch(getOperation()); - return verifyTColExpandBinaryLikeOp(getOperation(), getSrc0().getType(), - getSrc1().getType(), getDst().getType(), - arch, "tcolexpandmul", - /*allowIntegerTypes=*/true); -} -LogicalResult pto::TColExpandAddOp::verify() { - PTOArch arch = getTargetArch(getOperation()); - return verifyTColExpandBinaryLikeOp(getOperation(), getSrc0().getType(), - getSrc1().getType(), getDst().getType(), - arch, "tcolexpandadd", - /*allowIntegerTypes=*/true); -} -LogicalResult pto::TColExpandDivOp::verify() { - auto verifyByArch = [&](PTOArch targetArch) -> LogicalResult { - bool allowIntegerTypes = (targetArch == 
PTOArch::A5); - return verifyTColExpandBinaryLikeOp(getOperation(), getSrc0().getType(), - getSrc1().getType(), getDst().getType(), - targetArch, "tcolexpanddiv", - /*allowIntegerTypes=*/allowIntegerTypes); - }; - auto verifyA2A3 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A3); }; - auto verifyA5 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A5); }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} -LogicalResult pto::TColExpandSubOp::verify() { - PTOArch arch = getTargetArch(getOperation()); - return verifyTColExpandBinaryLikeOp(getOperation(), getSrc0().getType(), - getSrc1().getType(), getDst().getType(), - arch, "tcolexpandsub", - /*allowIntegerTypes=*/true); -} -LogicalResult pto::TColExpandExpdifOp::verify() { - PTOArch arch = getTargetArch(getOperation()); - return verifyTColExpandBinaryLikeOp(getOperation(), getSrc0().getType(), - getSrc1().getType(), getDst().getType(), - arch, "tcolexpandexpdif", - /*allowIntegerTypes=*/false); -} -LogicalResult pto::TColExpandMaxOp::verify() { - PTOArch arch = getTargetArch(getOperation()); - return verifyTColExpandBinaryLikeOp(getOperation(), getSrc0().getType(), - getSrc1().getType(), getDst().getType(), - arch, "tcolexpandmax", - /*allowIntegerTypes=*/true); -} -LogicalResult pto::TColExpandMinOp::verify() { - PTOArch arch = getTargetArch(getOperation()); - return verifyTColExpandBinaryLikeOp(getOperation(), getSrc0().getType(), - getSrc1().getType(), getDst().getType(), - arch, "tcolexpandmin", - /*allowIntegerTypes=*/true); -} -LogicalResult pto::TColMaxOp::verify() { - return verifyTColReductionOpWithArchDispatch( - getOperation(), getSrc().getType(), getDst().getType(), - /*requireNonZeroSrcOnA2A3=*/false, /*requireNonZeroSrcOnA5=*/true, - /*allowInt8OnA5=*/true, /*allowBf16OnA5=*/true, - "expects A2/A3 tcolmax element type to be f16/f32/i16/i32", - "expects A5 tcolmax element type to be i8/i16/i32/f16/bf16/f32"); -} - -LogicalResult pto::TColArgMaxOp::verify() { - 
auto verifyByArch = [&]() -> LogicalResult { - return verifyTColArgReductionOpCommon(*this, getSrc().getType(), - getTmp().getType(), getDst().getType()); - }; - - return dispatchVerifierByArch(getOperation(), verifyByArch, verifyByArch); -} - -LogicalResult pto::TColMinOp::verify() { - return verifyTColReductionOpWithArchDispatch( - getOperation(), getSrc().getType(), getDst().getType(), - /*requireNonZeroSrcOnA2A3=*/false, /*requireNonZeroSrcOnA5=*/true, - /*allowInt8OnA5=*/true, /*allowBf16OnA5=*/true, - "expects A2/A3 tcolmin element type to be f16/f32/i16/i32", - "expects A5 tcolmin element type to be i8/i16/i32/f16/bf16/f32"); -} - -LogicalResult pto::TColArgMinOp::verify() { - auto verifyByArch = [&]() -> LogicalResult { - return verifyTColArgReductionOpCommon(*this, getSrc().getType(), - getTmp().getType(), getDst().getType()); - }; - - return dispatchVerifierByArch(getOperation(), verifyByArch, verifyByArch); -} - - - -ParseResult mlir::pto::TColSumOp::parse(OpAsmParser &parser, OperationState &result) { - OpAsmParser::UnresolvedOperand src; - OpAsmParser::UnresolvedOperand tmp; - OpAsmParser::UnresolvedOperand dst; - Type srcTy, tmpTy, dstTy; - bool hasTmp = false; - - // Parse: ins(%src : type) or ins(%src, %tmp {isBinary = ...}: type, type) - if (parser.parseKeyword("ins") || parser.parseLParen() || parser.parseOperand(src)) - return failure(); - - // Check for optional tmp operand (format 2) - if (succeeded(parser.parseOptionalComma())) { - // Format 2: ins(%src, %tmp {isBinary = ...}: type, type) - if (parser.parseOperand(tmp)) - return failure(); - hasTmp = true; - - // Parse attributes (isBinary) - if (parser.parseOptionalAttrDict(result.attributes)) - return failure(); - - // Parse types: : type, type - if (parser.parseColonType(srcTy) || parser.parseComma() || parser.parseType(tmpTy)) - return failure(); - } else { - // Format 1: ins(%src : type) - if (parser.parseColonType(srcTy)) - return failure(); - } - - if (parser.parseRParen()) - return 
failure(); - - // Parse: outs(%dst : type) - if (parser.parseKeyword("outs") || parser.parseLParen() || - parser.parseOperand(dst) || parser.parseColonType(dstTy) || - parser.parseRParen()) - return failure(); - - // Parse any remaining attributes (for format 1) - if (!hasTmp) { - if (parser.parseOptionalAttrDict(result.attributes)) - return failure(); - } - - // Resolve operands - if (parser.resolveOperand(src, srcTy, result.operands)) - return failure(); - - if (hasTmp) { - if (parser.resolveOperand(tmp, tmpTy, result.operands)) - return failure(); - } - - if (parser.resolveOperand(dst, dstTy, result.operands)) - return failure(); - - return success(); -} - -void mlir::pto::TColSumOp::print(OpAsmPrinter &p) { - if (getTmp()) { - // Format 2: ins(%src, %tmp {isBinary = ...}: type, type) outs(%dst : type) - p << " ins(" << getSrc() << ", " << getTmp(); - // Print isBinary attribute if present - SmallVector elidedAttrs; - if (!getIsBinaryAttr() || getIsBinaryAttr().getValue() == false) { - elidedAttrs.push_back("isBinary"); - } - p.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); - p << " : " << getSrc().getType() << ", " << getTmp().getType() << ")"; - } else { - // Format 1: ins(%src : type) outs(%dst : type) - p << " ins(" << getSrc() << " : " << getSrc().getType() << ")"; - } - - p << " outs(" << getDst() << " : " << getDst().getType() << ")"; - - // Print remaining attributes for format 1 (excluding isBinary) - if (!getTmp()) { - SmallVector elidedAttrs = {"isBinary"}; - p.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); - } -} - -LogicalResult pto::TColSumOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyNDStyleVecTile(*this, srcTy, "src")) || - failed(verifyNDStyleVecTile(*this, dstTy, "dst"))) - return failure(); - bool hasTmp = (bool)getTmp(); - bool hasIsBinary = (bool)getIsBinaryAttr(); - if (hasTmp != hasIsBinary) { - if (hasTmp) - return 
emitOpError("tmp operand requires isBinary attribute"); - return emitOpError("isBinary attribute requires tmp operand"); - } - if (getTmp()) { - Type tmpTy = getTmp().getType(); - if (failed(verifyNDStyleVecTile(*this, tmpTy, "tmp"))) - return failure(); - if (getElemTy(srcTy) != getElemTy(dstTy) || getElemTy(srcTy) != getElemTy(tmpTy)) - return emitOpError("expects src/tmp/dst element types to match"); - } - if (getElemTy(srcTy) != getElemTy(dstTy)) - return emitOpError("expects src/dst element types to match"); - if (failed(verifyColReductionValidRegion(*this, srcTy, dstTy, - /*requireNonZeroSrc=*/false))) - return failure(); - Type elem = getElemTy(srcTy); - if (!(elem.isF16() || elem.isF32() || elem.isInteger(16) || elem.isInteger(32))) - return emitOpError("expects A2/A3 tcolsum element type to be f16/f32/i16/i32"); - return success(); - }; - auto verifyA5 = [&]() -> LogicalResult { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyNDStyleVecTile(*this, srcTy, "src")) || - failed(verifyNDStyleVecTile(*this, dstTy, "dst"))) - return failure(); - bool hasTmp = (bool)getTmp(); - bool hasIsBinary = (bool)getIsBinaryAttr(); - if (hasTmp != hasIsBinary) { - if (hasTmp) - return emitOpError("tmp operand requires isBinary attribute"); - return emitOpError("isBinary attribute requires tmp operand"); - } - if (getTmp()) { - Type tmpTy = getTmp().getType(); - if (failed(verifyNDStyleVecTile(*this, tmpTy, "tmp"))) - return failure(); - if (getElemTy(srcTy) != getElemTy(dstTy) || getElemTy(srcTy) != getElemTy(tmpTy)) - return emitOpError("expects src/tmp/dst element types to match"); - } - if (getElemTy(srcTy) != getElemTy(dstTy)) - return emitOpError("expects src/dst element types to match"); - if (failed(verifyColReductionValidRegion(*this, srcTy, dstTy, - /*requireNonZeroSrc=*/true))) - return failure(); - Type elem = getElemTy(srcTy); - if (!(elem.isF16() || elem.isF32() || elem.isBF16() || elem.isInteger(8) || - 
elem.isInteger(16) || elem.isInteger(32))) - return emitOpError("expects A5 tcolsum element type to be i8/i16/i32/f16/bf16/f32"); - return success(); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -LogicalResult pto::TColProdOp::verify() { - return verifyTColReductionOpWithArchDispatch( - getOperation(), getSrc().getType(), getDst().getType(), - /*requireNonZeroSrcOnA2A3=*/false, /*requireNonZeroSrcOnA5=*/false, - /*allowInt8OnA5=*/false, /*allowBf16OnA5=*/true, - "expects A2/A3 tcolprod element type to be f16/f32/i16/i32", - "expects A5 tcolprod element type to be i16/ui16/i32/ui32/f16/bf16/f32"); -} - -llvm::LogicalResult mlir::pto::TCvtOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyTileBufCommon(*this, srcTy, "src", /*allowLowPrecision=*/true)) || - failed(verifyTileBufCommon(*this, dstTy, "dst", /*allowLowPrecision=*/true))) - return failure(); - if (failed(verifyTileBufSameLogicalExtent(*this, srcTy, dstTy, "src", "dst", - /*compareValidShape=*/false))) - return failure(); - if (failed(verifyTileBufSameLogicalExtent(*this, srcTy, dstTy, "src", "dst", - /*compareValidShape=*/true))) - return failure(); - Type srcElem = getElemTy(srcTy); - Type dstElem = getElemTy(dstTy); - auto verifyA2A3 = [&]() -> LogicalResult { - if (isPTOLowPrecisionType(srcElem) || isPTOLowPrecisionType(dstElem)) - return emitOpError("expects A2/A3 tcvt low-precision element types to be unsupported"); - return success(); - }; - auto verifyA5 = [&]() -> LogicalResult { - if (!isA5SupportedTCvtPair(srcElem, dstElem)) - return emitOpError("expects A5 tcvt low-precision type pairs to match PTO-ISA support"); - return success(); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -llvm::LogicalResult mlir::pto::TRandomOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - return 
emitOpError("trandom is only supported for A5 targets"); - }; - auto verifyA5 = [&]() -> LogicalResult { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - - Type dstTy = getDst().getType(); - if (failed(verifyTileBufCommon(*this, dstTy, "dst"))) - return failure(); - if (!isRowMajorTileBuf(dstTy)) - return emitOpError("expects dst to use row-major layout"); - - Type elemTy = getElemTy(dstTy); - if (!elemTy.isInteger(32)) - return emitOpError("expects dst element type to be i32 or ui32"); - - auto checkWord = [&](Value v, StringRef name) -> LogicalResult { - auto ty = dyn_cast(v.getType()); - if (!ty || ty.getWidth() != 32) - return emitOpError() << "expects " << name << " to be i32/ui32"; - return success(); - }; - if (failed(checkWord(getKey0(), "key0")) || - failed(checkWord(getKey1(), "key1")) || - failed(checkWord(getCounter0(), "counter0")) || - failed(checkWord(getCounter1(), "counter1")) || - failed(checkWord(getCounter2(), "counter2")) || - failed(checkWord(getCounter3(), "counter3"))) - return failure(); - - int32_t rounds = getRounds(); - if (rounds != 7 && rounds != 10) - return emitOpError("expects rounds to be 7 or 10"); - - return success(); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -LogicalResult mlir::pto::TDivOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - FailureOr elemOr = verifyMatchingRowMajorBinaryTileOpCommon( - getOperation(), getSrc0().getType(), getSrc1().getType(), - getDst().getType()); - if (failed(elemOr)) - return failure(); - auto elem0 = *elemOr; - if (!(elem0.isF16() || elem0.isF32())) - return emitOpError("expects A2/A3 tdiv element type to be f16 or f32"); - return success(); - }; - auto verifyA5 = [&]() -> LogicalResult { - FailureOr elemOr = verifyMatchingRowMajorBinaryTileOpCommon( - getOperation(), getSrc0().getType(), getSrc1().getType(), - getDst().getType()); - if (failed(elemOr)) - return failure(); - auto elem0 = *elemOr; - if 
(!(elem0.isF16() || elem0.isF32() || elem0.isInteger(16) || elem0.isInteger(32))) - return emitOpError("expects A5 tdiv element type to be i32/i16/f16/f32"); - return success(); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TDivSOp::verify() { - auto isTileLike = [](Type ty) -> bool { - return isa(ty); - }; - auto isScalarLike = [](Type ty) -> bool { - return mlir::isa(ty); - }; - - auto verifyByArch = [&](PTOArch targetArch) -> LogicalResult { - Type srcTy = getSrc().getType(); - Type rhsTy = getScalar().getType(); - Type dstTy = getDst().getType(); - - bool srcTile = isTileLike(srcTy); - bool rhsTile = isTileLike(rhsTy); - bool srcScalar = isScalarLike(srcTy); - bool rhsScalar = isScalarLike(rhsTy); - - if (!(srcTile && rhsScalar) && !(srcScalar && rhsTile)) - return emitOpError("expects one tile-like operand and one scalar operand in ins(...)"); - - Type tileTy = srcTile ? srcTy : rhsTy; - Type scalarTy = srcTile ? rhsTy : srcTy; - - if (failed(verifyScalarTileOp(*this, tileTy, dstTy, "src", "dst", - /*requireValidRowsEqual=*/true, - /*requireValidColsEqual=*/true))) - return failure(); - if (!mlir::isa(scalarTy)) - return emitOpError("scalar must be a scalar type (integer/float)"); - Type elem = getElemTy(tileTy); - if (targetArch == PTOArch::A3 && - !(elem.isInteger(32) || elem.isInteger(16) || elem.isF16() || - elem.isF32())) - return emitOpError("expects A2/A3 tdivs element type to be i32/i16/f16/f32"); - if (targetArch == PTOArch::A5 && - !(elem.isInteger(32) || elem.isInteger(16) || elem.isInteger(8) || - elem.isF16() || elem.isF32())) - return emitOpError("expects A5 tdivs element type to be i32/i16/i8/f16/f32"); - return success(); - }; - auto verifyA2A3 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A3); }; - auto verifyA5 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A5); }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - 
-mlir::LogicalResult mlir::pto::TExpOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyVecTileUnaryOp(*this, srcTy, dstTy, "src", "dst", - /*allowBf16=*/false, /*allowInt8=*/false))) - return failure(); - if (failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) - return failure(); - Type srcElem = getElemTy(srcTy); - if (!srcElem.isF16() && !srcElem.isF32()) - return emitOpError("expects element type to be f16 or f32"); - return mlir::success(); - }; - auto verifyA5 = [&]() -> LogicalResult { return verifyA2A3(); }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TExpandsOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - Type dstTy = getDst().getType(); - if (failed(verifyTileBufCommon(*this, dstTy, "dst"))) - return failure(); - auto dstSpace = getPTOMemorySpaceEnum(dstTy); - if (!dstSpace || (*dstSpace != pto::AddressSpace::VEC && - *dstSpace != pto::AddressSpace::MAT)) - return emitOpError("expects dst to be in the vec or mat address space"); - Type dstElem = getElemTy(dstTy); - Type scalarTy = getScalar().getType(); - if (scalarTy != dstElem) - return emitOpError("expects scalar type == dst element type"); - if (*dstSpace == pto::AddressSpace::VEC && !isRowMajorTileBuf(dstTy)) - return emitOpError("expects vec dst to use row-major layout on A2/A3"); - if (dstElem.isF16() || dstElem.isBF16() || dstElem.isF32()) - return mlir::success(); - if (auto it = mlir::dyn_cast(dstElem)) { - unsigned w = it.getWidth(); - if (w == 16 || w == 32) - return mlir::success(); - } - return emitOpError("expects A2/A3 texpands dst element type to be i16/i32/f16/bf16/f32"); - }; - auto verifyA5 = [&]() -> LogicalResult { - Type dstTy = getDst().getType(); - if (failed(verifyTileBufCommon(*this, dstTy, "dst"))) - return failure(); - auto dstSpace = getPTOMemorySpaceEnum(dstTy); - if (!dstSpace || 
(*dstSpace != pto::AddressSpace::VEC && - *dstSpace != pto::AddressSpace::MAT)) - return emitOpError("expects dst to be in the vec or mat address space"); - Type dstElem = getElemTy(dstTy); - Type scalarTy = getScalar().getType(); - if (scalarTy != dstElem) - return emitOpError("expects scalar type == dst element type"); - if (dstElem.isF16() || dstElem.isBF16() || dstElem.isF32()) - return mlir::success(); - if (auto it = mlir::dyn_cast(dstElem)) { - unsigned w = it.getWidth(); - if (w == 8 || w == 16 || w == 32) - return mlir::success(); - } - return emitOpError("expects A5 texpands dst element type to be i8/i16/i32/f16/bf16/f32"); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TExtractOp::verify() { - auto hasMatExtractSourceLayoutA2A3 = [&](pto::TileBufType srcTy) -> bool { - int32_t bl = srcTy.getBLayoutValueI32(); - int32_t sl = srcTy.getSLayoutValueI32(); - return bl == static_cast(pto::BLayout::RowMajor) || - (bl != static_cast(pto::BLayout::RowMajor) && - sl == static_cast(pto::SLayout::RowMajor)); - }; - auto hasMatExtractSourceLayoutA5 = [&](pto::TileBufType srcTy, - pto::AddressSpace dstSpace) -> bool { - int32_t bl = srcTy.getBLayoutValueI32(); - int32_t sl = srcTy.getSLayoutValueI32(); - if (dstSpace == pto::AddressSpace::LEFT) { - return (bl == static_cast(pto::BLayout::RowMajor) && - sl == static_cast(pto::SLayout::ColMajor)) || - (bl != static_cast(pto::BLayout::RowMajor) && - sl == static_cast(pto::SLayout::RowMajor)) || - bl == static_cast(pto::BLayout::RowMajor); - } - return (bl == static_cast(pto::BLayout::RowMajor) && - sl == static_cast(pto::SLayout::ColMajor)) || - (bl != static_cast(pto::BLayout::RowMajor) && - sl == static_cast(pto::SLayout::RowMajor)); - }; - auto isA2A3ExtractElemType = [&](Type ty) -> bool { - return ty.isInteger(8) || ty.isF16() || ty.isBF16() || ty.isF32(); - }; - auto isA5ExtractElemType = [&](Type ty) -> bool { - if (auto it = dyn_cast(ty)) - return 
it.getWidth() == 8; - if (auto ft = dyn_cast(ty)) - return ft.getWidth() == 8 || ft.isF16() || ft.isBF16() || ft.isF32(); - return false; - }; - auto isRowMajorNoneBoxND = [&](pto::TileBufType ty) -> bool { - return ty.getBLayoutValueI32() == static_cast(pto::BLayout::RowMajor) && - ty.getSLayoutValueI32() == static_cast(pto::SLayout::NoneBox); - }; - auto verifyCommon = [&]() -> FailureOr, - std::optional>> { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - auto srcTb = dyn_cast(srcTy); - auto dstTb = dyn_cast(dstTy); - if (!srcTb || !dstTb) - return emitOpError("expects src and dst to be !pto.tile_buf"); - if (failed(verifyTileBufCommon(*this, srcTy, "src")) || - failed(verifyTileBufCommon(*this, dstTy, "dst")) || - failed(verifyNonNegativeIndexRowCol( - *getOperation(), getIndexRow(), getIndexCol(), - /*includeIndexAndIntOpsInConstFold=*/false)) || - failed(verifyExtractStaticBoundsCommon( - *getOperation(), getIndexRow(), getIndexCol(), srcTy, dstTy, - /*includeIndexAndIntOpsInConstFold=*/false))) - return failure(); - Type srcElem = getElemTy(srcTy); - Type dstElem = getElemTy(dstTy); - if (!srcElem || !dstElem || srcElem != dstElem) - return emitOpError("expects src and dst to have the same element type"); - auto srcSpace = getPTOMemorySpaceEnum(srcTy); - auto dstSpace = getPTOMemorySpaceEnum(dstTy); - return std::make_tuple(srcTy, dstTy, srcTb, dstTb, srcElem, dstElem, - srcSpace, dstSpace); - }; - auto verifyA2A3 = [&]() -> LogicalResult { - auto common = verifyCommon(); - if (failed(common)) - return failure(); - auto [srcTy, dstTy, srcTb, dstTb, srcElem, dstElem, srcSpace, dstSpace] = - *common; - (void)srcTy; - (void)dstTy; - (void)srcElem; - if (!isA2A3ExtractElemType(dstElem)) - return emitOpError("expects A2/A3 textract element type to be i8/f16/bf16/f32"); - if (srcSpace && dstSpace && *srcSpace == pto::AddressSpace::VEC && - *dstSpace == pto::AddressSpace::VEC) - return mlir::success(); - if (!srcSpace || *srcSpace != 
pto::AddressSpace::MAT) - return emitOpError("expects A2/A3 textract src to use loc=mat or vec"); - if (!dstSpace || (*dstSpace != pto::AddressSpace::LEFT && - *dstSpace != pto::AddressSpace::RIGHT)) - return emitOpError("expects A2/A3 textract dst to use loc=left, loc=right, or loc=vec"); - if (!hasMatExtractSourceLayoutA2A3(srcTb)) - return emitOpError("expects A2/A3 textract src to use a supported mat blayout/slayout combination"); - if (*dstSpace == pto::AddressSpace::LEFT) { - if (dstTb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor) || - dstTb.getSLayoutValueI32() != static_cast(pto::SLayout::RowMajor)) - return emitOpError("expects A2/A3 left dst to use row_major blayout and row_major slayout"); - } else { - if (dstTb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor) || - dstTb.getSLayoutValueI32() != static_cast(pto::SLayout::ColMajor)) - return emitOpError("expects A2/A3 right dst to use row_major blayout and col_major slayout"); - } - return mlir::success(); - }; - auto verifyA5 = [&]() -> LogicalResult { - auto common = verifyCommon(); - if (failed(common)) - return failure(); - auto [srcTy, dstTy, srcTb, dstTb, srcElem, dstElem, srcSpace, dstSpace] = - *common; - (void)srcTy; - (void)dstTy; - (void)srcElem; - if (!isA5ExtractElemType(dstElem)) - return emitOpError("expects A5 textract element type to be an fp8/f16/bf16/f32 or int8 family type"); - if (!srcSpace || !dstSpace) - return emitOpError("expects src and dst to have explicit loc"); - bool okPair = - (*srcSpace == pto::AddressSpace::MAT && - (*dstSpace == pto::AddressSpace::LEFT || - *dstSpace == pto::AddressSpace::RIGHT || - *dstSpace == pto::AddressSpace::SCALING)) || - (*srcSpace == pto::AddressSpace::VEC && - (*dstSpace == pto::AddressSpace::MAT || - *dstSpace == pto::AddressSpace::VEC)); - if (!okPair) - return emitOpError("expects A5 textract to use a supported src/dst loc pair"); - if (*srcSpace == pto::AddressSpace::MAT) { - if (!hasMatExtractSourceLayoutA5(srcTb, 
*dstSpace)) - return emitOpError("expects A5 textract src to use a supported mat blayout/slayout combination"); - if (*dstSpace == pto::AddressSpace::LEFT) { - if (dstTb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor) || - dstTb.getSLayoutValueI32() != static_cast(pto::SLayout::RowMajor)) - return emitOpError("expects A5 left dst to use col_major blayout and row_major slayout"); - } else if (*dstSpace == pto::AddressSpace::RIGHT) { - if (dstTb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor) || - dstTb.getSLayoutValueI32() != static_cast(pto::SLayout::ColMajor)) - return emitOpError("expects A5 right dst to use row_major blayout and col_major slayout"); - } - } else if (*srcSpace == pto::AddressSpace::VEC && - *dstSpace == pto::AddressSpace::VEC) { - if (!isRowMajorNoneBoxND(srcTb) || !isRowMajorNoneBoxND(dstTb)) - return emitOpError( - "expects A5 vec->vec textract src/dst to use ND layout " - "(blayout=row_major, slayout=none_box)"); - } - return success(); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} -mlir::LogicalResult mlir::pto::TInsertOp::verify() { - auto isColMajorRowMajorNZ = [&](pto::TileBufType ty) -> bool { - return ty.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor) && - ty.getSLayoutValueI32() == static_cast(pto::SLayout::RowMajor); - }; - auto isRowMajorNoneBoxND = [&](pto::TileBufType ty) -> bool { - return ty.getBLayoutValueI32() == static_cast(pto::BLayout::RowMajor) && - ty.getSLayoutValueI32() == static_cast(pto::SLayout::NoneBox); - }; - auto isA5SupportedVecElemType = [&](Type ty) -> bool { - if (auto it = dyn_cast(ty)) - return it.getWidth() == 8 || it.getWidth() == 32; - if (auto ft = dyn_cast(ty)) - return ft.getWidth() == 8 || ft.isF16() || ft.isBF16() || ft.isF32(); - return false; - }; - auto isA2A3VecInsertElemType = [&](Type ty) -> bool { - return ty.isInteger(8) || ty.isF16() || ty.isBF16() || ty.isF32(); - }; - auto verifyCommon = [&]() -> FailureOr, - 
std::optional>> { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - auto srcTb = dyn_cast(srcTy); - auto dstTb = dyn_cast(dstTy); - if (!srcTb || !dstTb) - return emitOpError("expects src and dst to be !pto.tile_buf"); - if (failed(verifyTileBufCommon(*this, srcTy, "src")) || - failed(verifyTileBufCommon(*this, dstTy, "dst")) || - failed(verifyNonNegativeIndexRowCol( - *getOperation(), getIndexRow(), getIndexCol(), - /*includeIndexAndIntOpsInConstFold=*/true)) || - failed(verifyInsertStaticBoundsCommon( - *getOperation(), getIndexRow(), getIndexCol(), srcTy, dstTy, - /*includeIndexAndIntOpsInConstFold=*/true))) - return failure(); - Type srcElem = getElemTy(srcTy); - Type dstElem = getElemTy(dstTy); - auto srcSpace = getPTOMemorySpaceEnum(srcTy); - auto dstSpace = getPTOMemorySpaceEnum(dstTy); - return std::make_tuple(srcTy, dstTy, srcTb, dstTb, srcElem, dstElem, - srcSpace, dstSpace); - }; - auto verifyA2A3 = [&]() -> LogicalResult { - auto common = verifyCommon(); - if (failed(common)) - return failure(); - auto [srcTy, dstTy, srcTb, dstTb, srcElem, dstElem, srcSpace, dstSpace] = - *common; - if (srcSpace && dstSpace && *srcSpace == pto::AddressSpace::VEC && - *dstSpace == pto::AddressSpace::VEC) { - if (srcElem != dstElem || !isA2A3VecInsertElemType(srcElem)) - return emitOpError( - "expects A2/A3 vec->vec tinsert src/dst to have same supported dtype " - "(i8/f16/bf16/f32)"); - return success(); - } - if (!srcSpace || !dstSpace || *srcSpace != pto::AddressSpace::ACC || - *dstSpace != pto::AddressSpace::MAT) - return emitOpError("expects A2/A3 tinsert to use acc->mat or vec->vec"); - - if (!isColMajorRowMajorNZ(srcTb)) - return emitOpError("expects A2/A3 tinsert src to use blayout=col_major and slayout=row_major"); - if (!isColMajorRowMajorNZ(dstTb)) - return emitOpError("expects A2/A3 tinsert dst to use blayout=col_major and slayout=row_major"); - if (dstTb.getSFractalSizeI32() != 512) - return emitOpError("expects A2/A3 tinsert dst fractal 
size to be 512"); - - if (!(srcElem.isF32() && (dstElem.isF16() || dstElem.isBF16()))) - return emitOpError("expects A2/A3 tinsert element types to be src=f32, dst=f16/bf16"); - return success(); - }; - auto verifyA5 = [&]() -> LogicalResult { - auto common = verifyCommon(); - if (failed(common)) - return failure(); - auto [srcTy, dstTy, srcTb, dstTb, srcElem, dstElem, srcSpace, dstSpace] = - *common; - if (!srcSpace || !dstSpace) - return emitOpError("expects A5 tinsert src/dst to have explicit loc"); - - // A5 regular acc->mat path. - if (*srcSpace == pto::AddressSpace::ACC && *dstSpace == pto::AddressSpace::MAT) { - if (!isColMajorRowMajorNZ(srcTb)) - return emitOpError("expects A5 acc->mat tinsert src to use blayout=col_major and slayout=row_major"); - if (!isColMajorRowMajorNZ(dstTb)) - return emitOpError("expects A5 acc->mat tinsert dst to use blayout=col_major and slayout=row_major"); - bool okTypes = (srcElem.isF32() && - (dstElem.isF16() || dstElem.isBF16() || dstElem.isF32())) || - (srcElem.isInteger(32) && dstElem.isInteger(32)); - if (!okTypes) - return emitOpError( - "expects A5 acc->mat tinsert element types to be " - "(src=f32,dst=f16/bf16/f32) or (src=i32,dst=i32)"); - return success(); - } - - // A5 vec->mat path (ND/NZ modes in pto-isa). - if (*srcSpace == pto::AddressSpace::VEC && *dstSpace == pto::AddressSpace::MAT) { - if (!isColMajorRowMajorNZ(dstTb)) - return emitOpError("expects A5 vec->mat tinsert dst to use blayout=col_major and slayout=row_major"); - bool srcIsND = isRowMajorNoneBoxND(srcTb); - bool srcIsNZ = isColMajorRowMajorNZ(srcTb); - if (!srcIsND && !srcIsNZ) - return emitOpError( - "expects A5 vec->mat tinsert src to use ND(row_major/none_box) or NZ(col_major/row_major) layout"); - if (srcElem != dstElem || !isA5SupportedVecElemType(srcElem)) - return emitOpError( - "expects A5 vec->mat tinsert src/dst to have same supported dtype " - "(fp8/f16/bf16/f32/i8/i32)"); - return success(); - } - - // A5 vec->vec path (PR561 ND_VEC). 
- if (*srcSpace == pto::AddressSpace::VEC && *dstSpace == pto::AddressSpace::VEC) { - if (!isRowMajorNoneBoxND(srcTb) || !isRowMajorNoneBoxND(dstTb)) - return emitOpError( - "expects A5 vec->vec tinsert src/dst to use ND layout " - "(blayout=row_major, slayout=none_box)"); - if (srcElem != dstElem || !isA5SupportedVecElemType(srcElem)) - return emitOpError( - "expects A5 vec->vec tinsert src/dst to have same supported dtype " - "(fp8/f16/bf16/f32/i8/i32)"); - return success(); - } - - return emitOpError( - "expects A5 tinsert to use a supported src/dst loc pair: " - "acc->mat, vec->mat, or vec->vec"); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -static bool isColMajorRowMajorNZTileBuf(pto::TileBufType ty) { - return ty.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor) && - ty.getSLayoutValueI32() == static_cast(pto::SLayout::RowMajor); -} - -static bool isA2A3VectorPreQuantTypePair(Type srcElem, Type dstElem) { - if (srcElem.isF32()) - return dstElem.isInteger(8); - if (srcElem.isInteger(32)) - return dstElem.isInteger(8) || dstElem.isF16() || dstElem.isInteger(16); - return false; -} - -static bool isA5Fp8LikeType(Type ty) { - if (auto ft = dyn_cast(ty)) - return ft.getWidth() == 8; - return false; -} - -static bool isA5MxInputType(Type ty) { - return isA5Fp8LikeType(ty); -} - -static LogicalResult verifyA5MxTypeTriple(Operation *op, Type lhsTy, Type rhsTy, - Type dstTy, StringRef lhsName, - StringRef rhsName, StringRef dstName) { - Type lhsElem = getElemTy(lhsTy); - Type rhsElem = getElemTy(rhsTy); - Type dstElem = getElemTy(dstTy); - - if (!isA5MxInputType(lhsElem) || !isA5MxInputType(rhsElem)) - return op->emitOpError() - << "expects A5 mx operands " << lhsName << " and " << rhsName - << " to use fp8 element types"; - - if (!dstElem.isF32()) - return op->emitOpError() - << "expects A5 mx result " << dstName << " to use f32 element type"; - - return success(); -} - -static bool isA5VectorPreQuantTypePair(Type 
srcElem, Type dstElem) { - if (srcElem.isF32()) - return dstElem.isInteger(8) || isA5Fp8LikeType(dstElem) || dstElem.isF16() || - dstElem.isBF16() || dstElem.isF32(); - if (srcElem.isInteger(32)) - return dstElem.isInteger(8) || dstElem.isF16() || dstElem.isBF16(); - return false; -} - -mlir::LogicalResult mlir::pto::TExtractFPOp::verify() { - auto verifyCommon = [&]() -> FailureOr> { - Type srcTy = getSrc().getType(); - Type fpTy = getFp().getType(); - Type dstTy = getDst().getType(); - auto srcTb = dyn_cast(srcTy); - auto fpTb = dyn_cast(fpTy); - auto dstTb = dyn_cast(dstTy); - if (!srcTb || !fpTb || !dstTb) - return emitOpError("expects src, fp, and dst to be !pto.tile_buf"); - if (failed(verifyTileBufCommon(*this, srcTy, "src")) || - failed(verifyTileBufCommon(*this, fpTy, "fp")) || - failed(verifyTileBufCommon(*this, dstTy, "dst")) || - failed(verifyNonNegativeIndexRowCol( - *getOperation(), getIndexRow(), getIndexCol(), - /*includeIndexAndIntOpsInConstFold=*/true)) || - failed(verifyExtractStaticBoundsCommon( - *getOperation(), getIndexRow(), getIndexCol(), srcTy, dstTy, - /*includeIndexAndIntOpsInConstFold=*/true))) - return failure(); - auto srcSpace = getPTOMemorySpaceEnum(srcTy); - auto fpSpace = getPTOMemorySpaceEnum(fpTy); - auto dstSpace = getPTOMemorySpaceEnum(dstTy); - if (!srcSpace || !fpSpace || !dstSpace) - return emitOpError("expects src, fp, and dst to have explicit loc"); - if (*srcSpace != pto::AddressSpace::ACC) - return emitOpError("expects src to use loc=acc"); - if (*fpSpace != pto::AddressSpace::SCALING) - return emitOpError("expects fp to use loc=scaling"); - if (*dstSpace != pto::AddressSpace::MAT) - return emitOpError("expects dst to use loc=mat"); - if (!isColMajorRowMajorNZTileBuf(srcTb)) - return emitOpError("expects src to use blayout=col_major and slayout=row_major"); - if (!isColMajorRowMajorNZTileBuf(dstTb)) - return emitOpError("expects dst to use blayout=col_major and slayout=row_major"); - return std::make_tuple(srcTy, fpTy, 
dstTy, srcTb, fpTb, dstTb, *srcSpace, - *fpSpace, *dstSpace); - }; - auto verifyA2A3 = [&]() -> LogicalResult { - auto common = verifyCommon(); - if (failed(common)) - return failure(); - auto [srcTy, fpTy, dstTy, srcTb, fpTb, dstTb, srcSpace, fpSpace, dstSpace] = - *common; - (void)fpTy; - (void)srcSpace; - (void)fpSpace; - (void)dstSpace; - if (dstTb.getSFractalSizeI32() != 512) - return emitOpError("expects dst fractal size to be 512"); - Type srcElem = getElemTy(srcTy); - Type dstElem = getElemTy(dstTy); - if (!isA2A3VectorPreQuantTypePair(srcElem, dstElem)) - return emitOpError( - "expects A2/A3 textract_fp element types to be (src=f32,dst=i8) " - "or (src=i32,dst=i8/f16/i16)"); - return success(); - }; - auto verifyA5 = [&]() -> LogicalResult { - auto common = verifyCommon(); - if (failed(common)) - return failure(); - auto [srcTy, fpTy, dstTy, srcTb, fpTb, dstTb, srcSpace, fpSpace, dstSpace] = - *common; - (void)fpTy; - (void)srcTb; - (void)fpTb; - (void)dstTb; - (void)srcSpace; - (void)fpSpace; - (void)dstSpace; - Type srcElem = getElemTy(srcTy); - Type dstElem = getElemTy(dstTy); - if (!isA5VectorPreQuantTypePair(srcElem, dstElem)) - return emitOpError( - "expects A5 textract_fp element types to be (src=f32,dst=i8/fp8/f16/bf16/f32) " - "or (src=i32,dst=i8/f16/bf16)"); - return success(); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TInsertFPOp::verify() { - auto verifyCommon = [&]() -> FailureOr> { - Type srcTy = getSrc().getType(); - Type fpTy = getFp().getType(); - Type dstTy = getDst().getType(); - auto srcTb = dyn_cast(srcTy); - auto fpTb = dyn_cast(fpTy); - auto dstTb = dyn_cast(dstTy); - if (!srcTb || !fpTb || !dstTb) - return emitOpError("expects src, fp, and dst to be !pto.tile_buf"); - if (failed(verifyTileBufCommon(*this, srcTy, "src")) || - failed(verifyTileBufCommon(*this, fpTy, "fp")) || - failed(verifyTileBufCommon(*this, dstTy, "dst")) || - 
failed(verifyNonNegativeIndexRowCol( - *getOperation(), getIndexRow(), getIndexCol(), - /*includeIndexAndIntOpsInConstFold=*/true)) || - failed(verifyInsertStaticBoundsCommon( - *getOperation(), getIndexRow(), getIndexCol(), srcTy, dstTy, - /*includeIndexAndIntOpsInConstFold=*/true))) - return failure(); - auto srcSpace = getPTOMemorySpaceEnum(srcTy); - auto fpSpace = getPTOMemorySpaceEnum(fpTy); - auto dstSpace = getPTOMemorySpaceEnum(dstTy); - if (!srcSpace || !fpSpace || !dstSpace) - return emitOpError("expects src, fp, and dst to have explicit loc"); - if (*srcSpace != pto::AddressSpace::ACC) - return emitOpError("expects src to use loc=acc"); - if (*fpSpace != pto::AddressSpace::SCALING) - return emitOpError("expects fp to use loc=scaling"); - if (*dstSpace != pto::AddressSpace::MAT) - return emitOpError("expects dst to use loc=mat"); - if (!isColMajorRowMajorNZTileBuf(srcTb)) - return emitOpError("expects src to use blayout=col_major and slayout=row_major"); - if (!isColMajorRowMajorNZTileBuf(dstTb)) - return emitOpError("expects dst to use blayout=col_major and slayout=row_major"); - return std::make_tuple(srcTy, fpTy, dstTy, srcTb, fpTb, dstTb, *srcSpace, - *fpSpace, *dstSpace); - }; - auto verifyA2A3 = [&]() -> LogicalResult { - auto common = verifyCommon(); - if (failed(common)) - return failure(); - auto [srcTy, fpTy, dstTy, srcTb, fpTb, dstTb, srcSpace, fpSpace, dstSpace] = - *common; - (void)fpTy; - (void)srcTb; - (void)fpTb; - (void)srcSpace; - (void)fpSpace; - (void)dstSpace; - if (dstTb.getSFractalSizeI32() != 512) - return emitOpError("expects dst fractal size to be 512"); - Type srcElem = getElemTy(srcTy); - Type dstElem = getElemTy(dstTy); - if (!isA2A3VectorPreQuantTypePair(srcElem, dstElem)) - return emitOpError( - "expects A2/A3 tinsert_fp element types to be (src=f32,dst=i8) " - "or (src=i32,dst=i8/f16/i16)"); - return success(); - }; - auto verifyA5 = [&]() -> LogicalResult { - auto common = verifyCommon(); - if (failed(common)) - return 
failure(); - auto [srcTy, fpTy, dstTy, srcTb, fpTb, dstTb, srcSpace, fpSpace, dstSpace] = - *common; - (void)fpTy; - (void)srcTb; - (void)fpTb; - (void)dstTb; - (void)srcSpace; - (void)fpSpace; - (void)dstSpace; - Type srcElem = getElemTy(srcTy); - Type dstElem = getElemTy(dstTy); - if (!isA5VectorPreQuantTypePair(srcElem, dstElem)) - return emitOpError( - "expects A5 tinsert_fp element types to be (src=f32,dst=i8/fp8/f16/bf16/f32) " - "or (src=i32,dst=i8/f16/bf16)"); - return success(); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -static mlir::LogicalResult verifyTFillPadLike(Operation *op, Type srcTy, Type dstTy, - bool allowDstExpand, - llvm::StringRef opName) { - if (!isPTOShapedLike(srcTy) || !isPTOShapedLike(dstTy)) - return op->emitError("expects src/dst to be PTO shaped-like types"); - - auto srcShape = getShapeVec(srcTy); - auto dstShape = getShapeVec(dstTy); - if (srcShape.size() != 2 || dstShape.size() != 2) - return op->emitError("expects rank-2 shaped types for src/dst"); - - auto srcElem = getElemTy(srcTy); - auto dstElem = getElemTy(dstTy); - - auto getElemBytes = [](mlir::Type t) -> int64_t { - unsigned elemBytes = getPTOStorageElemByteSize(t); - return elemBytes == 0 ? -1 : static_cast(elemBytes); - }; - - int64_t srcB = getElemBytes(srcElem); - int64_t dstB = getElemBytes(dstElem); - if (srcB < 0 || dstB < 0) - return op->emitError("unsupported element type (expects int/float element types)"); - if (srcB != dstB) - return op->emitError("expects sizeof(src element) == sizeof(dst element)"); - if (!(srcB == 1 || srcB == 2 || srcB == 4)) - return op->emitError("expects element size to be 1, 2, or 4 bytes"); - - // pto.tfillpad lowers to TFILLPAD(dst, src). For loc=mat, pto-isa only - // exposes the homogeneous overload, so src/dst must use the same Tile<...> - // specialization (including valid_shape and pad). 
- // Note: tfillpad_expand is intentionally not covered here because its - // cross-layer ABI contract for loc=mat heterogeneous shape expansion is not - // finalized yet. - if (opName == "tfillpad") { - auto srcTb = mlir::dyn_cast(srcTy); - auto dstTb = mlir::dyn_cast(dstTy); - auto srcSpace = getPTOMemorySpaceEnum(srcTy); - auto dstSpace = getPTOMemorySpaceEnum(dstTy); - if (srcTb && dstTb && srcSpace && dstSpace && - *srcSpace == mlir::pto::AddressSpace::MAT && - *dstSpace == mlir::pto::AddressSpace::MAT && srcTb != dstTb) { - auto dimToStr = [](int64_t dim) -> std::string { - return dim == ShapedType::kDynamic ? "?" : std::to_string(dim); - }; - SmallVector mismatchFields; - auto srcValid = getValidShapeVec(srcTy); - auto dstValid = getValidShapeVec(dstTy); - if (srcValid.size() == 2 && dstValid.size() == 2) { - if (srcValid[0] != dstValid[0]) - mismatchFields.push_back("v_row (" + dimToStr(srcValid[0]) + " vs " + - dimToStr(dstValid[0]) + ")"); - if (srcValid[1] != dstValid[1]) - mismatchFields.push_back("v_col (" + dimToStr(srcValid[1]) + " vs " + - dimToStr(dstValid[1]) + ")"); - } - if (srcTb.getPadValueI32() != dstTb.getPadValueI32()) - mismatchFields.push_back("pad (" + std::to_string(srcTb.getPadValueI32()) + - " vs " + std::to_string(dstTb.getPadValueI32()) + - ")"); - - auto diag = op->emitError() - << "expects src/dst tile types to be lowerable to TFILLPAD " - "for loc=mat"; - if (!mismatchFields.empty()) - diag << "; mismatching fields: " << llvm::join(mismatchFields, ", "); - diag << "\n src: " << srcTy; - diag << "\n dst: " << dstTy; - diag << "\n note: heterogeneous TFILLPAD overload is only available for loc=vec"; - return failure(); - } - } - - if (auto dstTileTy = mlir::dyn_cast(dstTy)) { - auto padAttr = mlir::dyn_cast(dstTileTy.getPadValueAttr()); - if (!padAttr || padAttr.getValue() == mlir::pto::PadValue::Null) - return op->emitError() << "expects dst PadVal != Null for " << opName; - } - - if (!allowDstExpand) { - if (srcShape != dstShape) 
- return op->emitError() - << "expects src and dst to have the same static shape for " << opName; - return mlir::success(); - } - - if (srcShape[0] > dstShape[0] || srcShape[1] > dstShape[1]) { - return op->emitError() - << "expects dst static shape to be >= src static shape for " << opName; - } - - return mlir::success(); -} - -mlir::LogicalResult mlir::pto::TFillPadOp::verify() { - return verifyTFillPadLike(getOperation(), getSrc().getType(), getDst().getType(), - /*allowDstExpand=*/false, "tfillpad"); -} - -mlir::LogicalResult mlir::pto::TFillPadExpandOp::verify() { - return verifyTFillPadLike(getOperation(), getSrc().getType(), getDst().getType(), - /*allowDstExpand=*/true, "tfillpad_expand"); -} - -mlir::LogicalResult mlir::pto::TFillPadInplaceOp::verify() { - return verifyTFillPadLike(getOperation(), getSrc().getType(), getDst().getType(), - /*allowDstExpand=*/false, "tfillpad_inplace"); -} - - -llvm::LogicalResult mlir::pto::TGatherOp::verify() { - auto isSupportedGatherElemTypeA5Index = [&](Type ty) -> bool { - if (ty.isF16() || ty.isF32()) - return true; - if (auto it = dyn_cast(ty)) { - unsigned width = it.getWidth(); - return width == 8 || width == 16 || width == 32; - } - return false; - }; - - auto verifyMaskForm = [&](bool allowA5MaskTypes) -> LogicalResult { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyTileBufCommon(*this, srcTy, "src")) || - failed(verifyTileBufCommon(*this, dstTy, "dst"))) - return failure(); - - Type srcElem = getElemTy(srcTy); - Type dstElem = getElemTy(dstTy); - if (!srcElem || !dstElem) - return emitOpError("failed to get element type for src/dst"); - if (!isRowMajorTileBuf(srcTy) || !isRowMajorTileBuf(dstTy)) - return emitOpError("expects src and dst to use row-major layout"); - auto srcSpace = getPTOMemorySpaceEnum(srcTy); - auto dstSpace = getPTOMemorySpaceEnum(dstTy); - if (!srcSpace || !dstSpace || *srcSpace != pto::AddressSpace::VEC || - *dstSpace != pto::AddressSpace::VEC) - 
return emitOpError("expects src and dst to be in the vec address space"); - unsigned srcElemBytes = getPTOStorageElemByteSize(srcElem); - unsigned dstElemBytes = getPTOStorageElemByteSize(dstElem); - if (srcElemBytes == 0 || dstElemBytes == 0) - return emitOpError("failed to get element size for src/dst"); - if (srcElemBytes != dstElemBytes) - return emitOpError("expects src and dst element sizes to match"); - - auto dstValid = getValidShapeVec(dstTy); - auto dstShape = getShapeVec(dstTy); - if (dstValid.size() == 2 && dstShape.size() == 2 && - dstValid[1] != ShapedType::kDynamic && dstShape[1] != ShapedType::kDynamic && - dstValid[1] != dstShape[1]) { - return emitOpError("expects dst valid_shape[1] to equal dst cols"); - } - - if (allowA5MaskTypes) { - if (!(srcElemBytes == 1 || srcElemBytes == 2 || srcElemBytes == 4)) - return emitOpError("expects A5 mask-pattern gather element size to be 1, 2, or 4 bytes"); - if (!isSupportedGatherElemTypeA5(srcElem) || !isSupportedGatherElemTypeA5(dstElem)) - return emitOpError( - "expects A5 mask-pattern gather src/dst element type to be i8/i16/i32/f16/bf16/f32/fp8-like"); - } else { - if (!(srcElemBytes == 2 || srcElemBytes == 4)) - return emitOpError("expects A2/A3 mask-pattern gather element size to be 2 or 4 bytes"); - } - return success(); - }; - - auto verifyIndexForm = [&](bool allow16BitIndices, bool allowA5ElemTypes) -> LogicalResult { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - Type idxTy = getIndices().getType(); - Type tmpTy = getTmp().getType(); - if (failed(verifyTileBufCommon(*this, srcTy, "src")) || - failed(verifyTileBufCommon(*this, dstTy, "dst")) || - failed(verifyTileBufCommon(*this, idxTy, "indices")) || - failed(verifyTileBufCommon(*this, tmpTy, "tmp"))) - return failure(); - - Type srcElem = getElemTy(srcTy); - Type dstElem = getElemTy(dstTy); - if (!srcElem || !dstElem) - return emitOpError("failed to get element type for src/dst"); - if (srcElem != dstElem) - return 
emitOpError("expects src and dst to have the same element type"); - if (allowA5ElemTypes) { - if (!isSupportedGatherElemTypeA5Index(srcElem) || - !isSupportedGatherElemTypeA5Index(dstElem)) - return emitOpError( - "expects A5 gather src/dst element type to be i8/i16/i32/f16/f32"); - } else if (!isSupportedGatherElemTypeA2A3(srcElem) || - !isSupportedGatherElemTypeA2A3(dstElem)) { - return emitOpError("expects gather src/dst element type to be i16/i32/f16/f32"); - } - - auto idxElem = dyn_cast(getElemTy(idxTy)); - if (!idxElem) - return emitOpError("indices element type must be integer"); - unsigned width = idxElem.getWidth(); - if (!(width == 32 || (allow16BitIndices && width == 16))) { - return emitOpError() << "expects indices element type to be i32" - << (allow16BitIndices ? " or i16" : ""); - } - - auto dstValid = getValidShapeVec(dstTy); - auto dstShape = getShapeVec(dstTy); - if (dstValid.size() == 2 && dstShape.size() == 2 && - dstValid[1] != ShapedType::kDynamic && dstShape[1] != ShapedType::kDynamic && - dstValid[1] != dstShape[1]) { - return emitOpError("expects dst valid_shape[1] to equal dst cols"); - } - - auto idxValid = getValidShapeVec(idxTy); - auto idxShape = getShapeVec(idxTy); - if (idxValid.size() == 2 && idxShape.size() == 2 && - idxValid[1] != ShapedType::kDynamic && idxShape[1] != ShapedType::kDynamic && - idxValid[1] != idxShape[1]) { - return emitOpError("expects indices valid_shape[1] to equal indices cols"); - } - - if (!allowA5ElemTypes) { - Type tmpElem = getElemTy(tmpTy); - if (tmpElem != idxElem) - return emitOpError("expects tmp and indices to have the same element type"); - if (failed(verifyTileBufSameValidShape(*this, idxTy, tmpTy, "indices", "tmp"))) - return failure(); - } - return success(); - }; - - auto verifyCompareForm = [&](bool allowA5SrcTypes) -> LogicalResult { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - Type cdstTy = getCdst().getType(); - Type tmpTy = getTmp().getType(); - if 
(failed(verifyTileBufCommon(*this, srcTy, "src")) || - failed(verifyTileBufCommon(*this, dstTy, "dst")) || - failed(verifyTileBufCommon(*this, cdstTy, "cdst")) || - failed(verifyTileBufCommon(*this, tmpTy, "tmp"))) - return failure(); - - Type srcElem = getElemTy(srcTy); - Type dstElem = getElemTy(dstTy); - Type cdstElem = getElemTy(cdstTy); - if (!srcElem || !dstElem || !cdstElem) - return emitOpError("failed to get element type for src/dst/cdst"); - auto dstInt = dyn_cast(dstElem); - if (!dstInt || dstInt.getWidth() != 32) - return emitOpError("expects dst element type to be i32"); - if (cdstElem != dstElem) - return emitOpError("expects cdst to have the same element type as dst"); - if (getKValue().getType() != srcElem) - return emitOpError("expects kValue to have the same type as src element type"); - - auto cmpAttr = getCmpModeAttr(); - auto cmpMode = cmpAttr ? cmpAttr.getValue() : pto::CmpMode::EQ; - if (cmpMode != pto::CmpMode::EQ && cmpMode != pto::CmpMode::GT) - return emitOpError("expects compare-form tgather cmpMode to be eq or gt"); - - if (allowA5SrcTypes) { - if (!(srcElem.isF16() || srcElem.isF32() || srcElem.isInteger(16) || - srcElem.isInteger(32))) { - return emitOpError( - "expects A5 compare-form tgather src element type to be i16/i32/f16/f32"); - } - } else { - if (!(srcElem.isF16() || srcElem.isF32() || - (srcElem.isInteger(32) && cmpMode == pto::CmpMode::EQ))) { - return emitOpError( - "expects A2/A3 compare-form tgather src element type to be f16/f32, or i32 when cmpMode=eq"); - } - } - - if (failed(verifyVecTileCommonA2A3(*this, srcTy, "src")) || - failed(verifyVecTileCommonA2A3(*this, dstTy, "dst")) || - failed(verifyVecTileCommonA2A3(*this, cdstTy, "cdst")) || - failed(verifyVecTileCommonA2A3(*this, tmpTy, "tmp"))) - return failure(); - return success(); - }; - - auto verifyA2A3 = [&]() -> LogicalResult { - if (getMaskPatternAttr()) { - if (getCdst() || getIndices() || getTmp() || getKValue()) - return emitOpError("mask-pattern tgather 
only allows src and dst operands"); - return verifyMaskForm(/*allowA5MaskTypes=*/false); - } - if (getCdst() || getKValue()) { - if (!getCdst() || !getKValue() || !getTmp()) - return emitOpError("compare-form tgather expects dst, cdst, kValue, and tmp"); - if (getIndices()) - return emitOpError("compare-form tgather does not take indices"); - return verifyCompareForm(/*allowA5SrcTypes=*/false); - } - if (!getIndices() || !getTmp()) - return emitOpError("index-form tgather expects both indices and tmp"); - return verifyIndexForm(/*allow16BitIndices=*/false, /*allowA5ElemTypes=*/false); - }; - - auto verifyA5 = [&]() -> LogicalResult { - if (getMaskPatternAttr()) { - if (getCdst() || getIndices() || getTmp() || getKValue()) - return emitOpError("mask-pattern tgather only allows src and dst operands"); - return verifyMaskForm(/*allowA5MaskTypes=*/true); - } - if (getCdst() || getKValue()) { - if (!getCdst() || !getKValue() || !getTmp()) - return emitOpError("compare-form tgather expects dst, cdst, kValue, and tmp"); - if (getIndices()) - return emitOpError("compare-form tgather does not take indices"); - return verifyCompareForm(/*allowA5SrcTypes=*/true); - } - if (!getIndices() || !getTmp()) - return emitOpError("index-form tgather expects both indices and tmp"); - return verifyIndexForm(/*allow16BitIndices=*/true, /*allowA5ElemTypes=*/true); - }; - - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} -mlir::LogicalResult mlir::pto::TGatherBOp::verify() { - auto verifyCommon = [&]() -> FailureOr> { - Type srcTy = getSrc().getType(); - Type offTy = getOffsets().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyTileBufCommon(*this, srcTy, "src")) || - failed(verifyTileBufCommon(*this, offTy, "offsets")) || - failed(verifyTileBufCommon(*this, dstTy, "dst"))) - return failure(); - auto srcElemTy = getElemTy(srcTy); - auto dstElemTy = getElemTy(dstTy); - if (!srcElemTy || !dstElemTy) - return emitOpError() << "failed to get element type 
for src/dst"; - return std::make_pair(srcElemTy, dstElemTy); - }; - - auto getElemBytes = [](Type ty) -> std::optional { - unsigned elemBytes = getPTOStorageElemByteSize(ty); - if (elemBytes == 0) - return std::nullopt; - return elemBytes; - }; - - auto verifyA2A3 = [&]() -> LogicalResult { - FailureOr> elems = verifyCommon(); - if (failed(elems)) - return failure(); - Type dstTy = getDst().getType(); - Type dstElemTy = elems->second; - if (!isRowMajorTileBuf(dstTy)) - return emitOpError() << "expects dst to use row-major layout"; - auto dstBytes = getElemBytes(dstElemTy); - if (!dstBytes || (*dstBytes != 1 && *dstBytes != 2 && *dstBytes != 4)) - return emitOpError() << "expects dst element size to be 1, 2, or 4 bytes"; - return mlir::success(); - }; - - auto verifyA5 = [&]() -> LogicalResult { - FailureOr> elems = verifyCommon(); - if (failed(elems)) - return failure(); - Type dstElemTy = elems->second; - auto dstBytes = getElemBytes(dstElemTy); - if (!dstBytes || (*dstBytes != 1 && *dstBytes != 2 && *dstBytes != 4)) - return emitOpError() << "expects dst element size to be 1, 2, or 4 bytes"; - return mlir::success(); - }; - - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TLogOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyVecTileUnaryOp(*this, srcTy, dstTy, "src", "dst", - /*allowBf16=*/false, /*allowInt8=*/false))) - return failure(); - if (failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) - return failure(); - auto elemTy = getElemTy(srcTy); - if (!(elemTy.isF16() || elemTy.isF32())) - return emitOpError() << "expects element type to be f16 or f32"; - return mlir::success(); -} - -mlir::LogicalResult mlir::pto::TLReluOp::verify() { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - auto verifyA2A3 = [&]() -> LogicalResult { - if 
(failed(verifyVecTileStorage(*this, srcTy, "src")) || - failed(verifyVecTileStorage(*this, dstTy, "dst"))) - return failure(); - if (failed(verifyTileBufSameElemType(*this, srcTy, dstTy, "src", "dst")) || - failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) - return failure(); - auto valid = getValidShapeVec(srcTy); - if (valid.size() != 2) - return emitOpError("expects src to have rank-2 valid_shape"); - if (valid[0] != ShapedType::kDynamic && valid[0] <= 0) - return emitOpError("expects src valid_shape[0] to be positive"); - if (valid[1] != ShapedType::kDynamic && valid[1] <= 0) - return emitOpError("expects src valid_shape[1] to be positive"); - Type elemTy = getElemTy(srcTy); - if (!(elemTy.isF16() || elemTy.isF32())) - return emitOpError() << "expects A2/A3 tlrelu element type to be f16 or f32"; - return success(); - }; - auto verifyA5 = [&]() -> LogicalResult { - if (failed(verifyVecTileStorage(*this, srcTy, "src")) || - failed(verifyVecTileStorage(*this, dstTy, "dst"))) - return failure(); - if (failed(verifyTileBufSameElemType(*this, srcTy, dstTy, "src", "dst")) || - failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) - return failure(); - Type elemTy = getElemTy(srcTy); - if (!(elemTy.isF16() || elemTy.isF32())) - return emitOpError() << "expects A5 tlrelu element type to be f16 or f32"; - if (!getSlope().getType().isF32()) - return emitOpError() << "expects slope to have type f32"; - return success(); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TMaxOp::verify() { - return verifyArithmeticBinaryTileOpWithArchDispatch( - getOperation(), getSrc0().getType(), getSrc1().getType(), getDst().getType(), - /*allowInt8OnA5=*/true, /*allowBf16OnA5=*/false, - "expects A2/A3 tmax element type to be i32/i16/f16/f32", - "expects A5 tmax element type to be i32/i16/i8/f16/f32"); -} - -mlir::LogicalResult mlir::pto::TMaxSOp::verify() { - return 
verifyArithmeticScalarTileOpWithArchDispatch( - getOperation(), getSrc().getType(), getDst().getType(), getScalar().getType(), - /*allowInt8OnA5=*/true, /*allowBf16OnA5=*/true, - "expects A2/A3 tmaxs element type to be i32/i16/f16/f32", - "expects A5 tmaxs element type to be i32/i16/i8/f16/bf16/f32", - /*requireValidRowsEqualOnA2A3=*/true, - /*requireValidRowsEqualOnA5=*/true); -} - -mlir::LogicalResult mlir::pto::TMinOp::verify() { - return verifyArithmeticBinaryTileOpWithArchDispatch( - getOperation(), getSrc0().getType(), getSrc1().getType(), getDst().getType(), - /*allowInt8OnA5=*/true, /*allowBf16OnA5=*/true, - "expects A2/A3 tmin element type to be i32/i16/f16/f32", - "expects A5 tmin element type to be i32/i16/i8/f16/bf16/f32"); -} - -mlir::LogicalResult mlir::pto::TMinSOp::verify() { - return verifyArithmeticScalarTileOpWithArchDispatch( - getOperation(), getSrc().getType(), getDst().getType(), getScalar().getType(), - /*allowInt8OnA5=*/true, /*allowBf16OnA5=*/true, - "expects A2/A3 tmins element type to be i32/i16/f16/f32", - "expects A5 tmins element type to be i32/i16/i8/f16/bf16/f32", - /*requireValidRowsEqualOnA2A3=*/true, - /*requireValidRowsEqualOnA5=*/false); -} - -mlir::LogicalResult mlir::pto::TMovOp::verify() { - auto verifyImpl = [&](bool isA5) -> LogicalResult { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - Value fp = getFp(); - Value preQuantScalar = getPreQuantScalar(); - auto accToVecModeAttr = getAccToVecModeAttr(); - auto reluMode = getReluPreMode(); - const bool hasFp = static_cast(fp); - const bool hasPreQuantScalar = static_cast(preQuantScalar); - - if (failed(verifyTileBufCommon(*this, srcTy, "src")) || - failed(verifyTileBufCommon(*this, dstTy, "dst"))) - return failure(); - if (hasFp && failed(verifyTileBufCommon(*this, fp.getType(), "fp"))) - return failure(); - if (hasFp && hasPreQuantScalar) - return emitOpError() << "expects fp and preQuantScalar forms to be mutually exclusive"; - - auto srcSpace = 
getPTOMemorySpaceEnum(srcTy); - auto dstSpace = getPTOMemorySpaceEnum(dstTy); - if (!srcSpace || !dstSpace) - return emitOpError() << "expects src and dst to have explicit address spaces"; - - auto srcShape = getShapeVec(srcTy); - auto dstShape = getShapeVec(dstTy); - if (*srcSpace == pto::AddressSpace::MAT && srcShape != dstShape) - return emitOpError() << "expects mat-source tmov to use matching src/dst shapes"; - if (!isA5 && *srcSpace != pto::AddressSpace::MAT && srcShape != dstShape) - return emitOpError() << "expects A2/A3 non-mat tmov to use matching src/dst shapes"; - - const bool isMatToTile = - *srcSpace == pto::AddressSpace::MAT && - (*dstSpace == pto::AddressSpace::LEFT || - *dstSpace == pto::AddressSpace::RIGHT || - *dstSpace == pto::AddressSpace::BIAS || - *dstSpace == pto::AddressSpace::SCALING); - const bool isVecToVec = - *srcSpace == pto::AddressSpace::VEC && - *dstSpace == pto::AddressSpace::VEC; - const bool isVecToMat = - *srcSpace == pto::AddressSpace::VEC && - *dstSpace == pto::AddressSpace::MAT; - const bool isAccToMat = - *srcSpace == pto::AddressSpace::ACC && - *dstSpace == pto::AddressSpace::MAT; - const bool isAccToVec = - *srcSpace == pto::AddressSpace::ACC && - *dstSpace == pto::AddressSpace::VEC; - - bool okPair = isMatToTile || isVecToVec || isAccToMat || isAccToVec; - if (isA5) - okPair = okPair || isVecToMat; - if (!okPair) - return emitOpError() - << "expects a supported tmov address-space pair for this target"; - - if (accToVecModeAttr && !isAccToVec) - return emitOpError() - << "expects accToVecMode to be used only for acc-to-vec tmov"; - - if (reluMode != pto::ReluPreMode::NoRelu && !(isAccToMat || isAccToVec)) - return emitOpError() - << "expects reluPreMode form to use loc=acc src"; - - if (hasPreQuantScalar && !(isAccToMat || isAccToVec)) - return emitOpError() - << "expects preQuantScalar form to use loc=acc src"; - - if (hasFp) { - auto fpTy = fp.getType(); - auto fpSpace = getPTOMemorySpaceEnum(fpTy); - if (!fpSpace || 
*fpSpace != pto::AddressSpace::SCALING) - return emitOpError() << "expects fp to be in the scaling address space"; - auto srcElemTy = getElemTy(srcTy); - auto srcIntTy = dyn_cast(srcElemTy); - if (!(srcElemTy.isF32() || (srcIntTy && srcIntTy.getWidth() == 32))) - return emitOpError() - << "expects fp form src to have element type f32, i32"; - if (!(isAccToMat || isAccToVec)) - return emitOpError() << "expects fp form to use loc=acc src"; - } - - if ((hasFp || hasPreQuantScalar) && accToVecModeAttr) { - switch (accToVecModeAttr.getValue()) { - case pto::AccToVecMode::SingleModeVec0: - case pto::AccToVecMode::SingleModeVec1: - break; - case pto::AccToVecMode::DualModeSplitM: - case pto::AccToVecMode::DualModeSplitN: - return emitOpError() - << "expects fp/preQuantScalar acc-to-vec forms to use single-mode accToVecMode"; - } - } - - auto srcTb = dyn_cast(srcTy); - auto dstTb = dyn_cast(dstTy); - if (srcTb && *srcSpace == pto::AddressSpace::ACC && - (hasFp || reluMode != pto::ReluPreMode::NoRelu)) { - if (srcTb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor) || - srcTb.getSLayoutValueI32() != static_cast(pto::SLayout::RowMajor)) - return emitOpError() - << "expects acc-source fp/relu tmov src to use blayout=col_major and slayout=row_major"; - } - if (srcTb && dstTb && isAccToMat && !isA5 && - dstTb.getSFractalSizeI32() != 512) - return emitOpError() << "expects A2/A3 acc-to-mat tmov destination fractal to be 512"; - - return success(); - }; - auto verifyA2A3 = [&]() -> LogicalResult { return verifyImpl(/*isA5=*/false); }; - auto verifyA5 = [&]() -> LogicalResult { return verifyImpl(/*isA5=*/true); }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TMovFPOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - Type srcTy = getSrc().getType(); - Type fpTy = getFp().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyTileBufCommon(*this, srcTy, "src")) || - 
failed(verifyTileBufCommon(*this, fpTy, "fp")) || - failed(verifyTileBufCommon(*this, dstTy, "dst"))) - return failure(); - auto srcElemTy = getElemTy(srcTy); - auto srcIntTy = dyn_cast(srcElemTy); - if (!(srcElemTy.isF32() || - (srcIntTy && srcIntTy.getWidth() == 32))) - return emitOpError() - << "expects src to have element type f32, i32"; - auto fpSpace = getPTOMemorySpaceEnum(fpTy); - if (!fpSpace || *fpSpace != mlir::pto::AddressSpace::SCALING) - return emitOpError() << "expects fp to be in the scaling address space"; - auto srcSpace = getPTOMemorySpaceEnum(srcTy); - if (!srcSpace || *srcSpace != mlir::pto::AddressSpace::ACC) - return emitOpError() << "expects src to be in the acc address space"; - auto dstSpace = getPTOMemorySpaceEnum(dstTy); - if (!dstSpace || *dstSpace != mlir::pto::AddressSpace::MAT) - return emitOpError() << "expects dst to be in the mat address space"; - auto srcTb = dyn_cast(srcTy); - auto dstTb = dyn_cast(dstTy); - if (srcTb && - (srcTb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor) || - srcTb.getSLayoutValueI32() != static_cast(pto::SLayout::RowMajor))) - return emitOpError() - << "expects src to use blayout=col_major and slayout=row_major"; - if (dstTb && - (dstTb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor) || - dstTb.getSLayoutValueI32() != static_cast(pto::SLayout::RowMajor))) - return emitOpError() - << "expects dst to use blayout=col_major and slayout=row_major"; - if (dstTb && dstTb.getSFractalSizeI32() != 512) - return emitOpError() << "expects dst to use fractal size 512"; - return mlir::success(); - }; - auto verifyA5 = [&]() -> LogicalResult { - Type srcTy = getSrc().getType(); - Type fpTy = getFp().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyTileBufCommon(*this, srcTy, "src")) || - failed(verifyTileBufCommon(*this, fpTy, "fp")) || - failed(verifyTileBufCommon(*this, dstTy, "dst"))) - return failure(); - auto srcElemTy = getElemTy(srcTy); - auto srcIntTy = 
dyn_cast(srcElemTy); - if (!(srcElemTy.isF32() || - (srcIntTy && srcIntTy.getWidth() == 32))) - return emitOpError() - << "expects src to have element type f32, i32"; - auto fpSpace = getPTOMemorySpaceEnum(fpTy); - if (!fpSpace || *fpSpace != mlir::pto::AddressSpace::SCALING) - return emitOpError() << "expects fp to be in the scaling address space"; - auto srcTb = dyn_cast(srcTy); - if (srcTb && - (srcTb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor) || - srcTb.getSLayoutValueI32() != static_cast(pto::SLayout::RowMajor))) - return emitOpError() - << "expects src to use blayout=col_major and slayout=row_major"; - return success(); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} -// 辅助函数:获取 Rank,支持 ShapedType 和 PTO TileTypes -static int64_t getRankHelper(Type t) { - if (auto s = dyn_cast(t)) return s.getRank(); - if (auto tile = dyn_cast(t)) return tile.getRank(); - if (auto view = dyn_cast(t)) return view.getRank(); - return -1; -} - -static LogicalResult verifyMatmulLike(Operation *op, Type aTy, Type bTy, Type dstTy, bool checkRank = true) { - // 1. 
检查类型 (ShapedType 或 Tile 类型) - bool aValid = isa(aTy); - bool bValid = isa(bTy); - bool dValid = isa(dstTy); - - if (!aValid || !bValid || !dValid) - return op->emitOpError("expects inputs/outputs to be shaped types or PTO tile types"); - - if (checkRank) { - int64_t aRank = getRankHelper(aTy); - int64_t bRank = getRankHelper(bTy); - int64_t dRank = getRankHelper(dstTy); - - // 检查 Rank 一致性 - if (aRank != -1 && dRank != -1 && aRank != dRank) - return op->emitOpError("expects a and dst to have the same rank"); - if (bRank != -1 && dRank != -1 && bRank != dRank) - return op->emitOpError("expects b and dst to have the same rank"); - } - - return success(); -} - -// ---- LoadScalarOp ---- -LogicalResult LoadScalarOp::verify() { - Type ptrTy = getPtr().getType(); - Type elemTy; - if (auto pty = dyn_cast(ptrTy)) { - elemTy = pty.getElementType(); - } else if (auto memTy = dyn_cast(ptrTy)) { - elemTy = memTy.getElementType(); - if (!isGmAddressSpaceAttr(memTy.getMemorySpace())) - return emitOpError() << "scalar load only supports GM address space pointers"; - } else { - return emitOpError("expects ptr to be !pto.ptr or memref type"); - } - - if (getValue().getType() != elemTy) - return emitOpError("expects result type to match ptr element type"); - - return success(); -} -// ---- StoreScalarOp ---- -LogicalResult StoreScalarOp::verify() { - Type ptrTy = getPtr().getType(); - Type elemTy; - if (auto pty = dyn_cast(ptrTy)) { - elemTy = pty.getElementType(); - } else if (auto memTy = dyn_cast(ptrTy)) { - elemTy = memTy.getElementType(); - if (!isGmAddressSpaceAttr(memTy.getMemorySpace())) - return emitOpError() << "scalar store only supports GM address space pointers"; - } else { - return emitOpError("expects ptr to be !pto.ptr or memref type"); - } - - if (getValue().getType() != elemTy) - return emitOpError("expects value type to match ptr element type"); - - return success(); -} - -// ---- GetBufOp / RlsBufOp ---- -static LogicalResult verifyBufSyncOp(Operation *op, 
Attribute opTypeAttr, - IntegerAttr bufIdAttr, IntegerAttr modeAttr) { - if (!opTypeAttr) - return op->emitOpError("expects 'op_type' attribute"); - - auto opTypeOr = parseSyncOpTypeLikeAttr(opTypeAttr); - if (failed(opTypeOr)) { - auto diag = - op->emitOpError("expects 'op_type' to be pipe_event_type/sync_op_type, got "); - diag << opTypeAttr; - return failure(); - } - pto::PIPE pipe = mapSyncOpTypeToPipe(*opTypeOr); - if (!isConcreteSyncPipe(pipe)) - return op->emitOpError("expects 'op_type' to map to a concrete pipe, not PIPE_ALL/PIPE_UNASSIGNED"); - - if (!bufIdAttr) - return op->emitOpError("expects 'buf_id' attribute"); - int64_t bufId = bufIdAttr.getInt(); - if (bufId < 0 || bufId > 31) - return op->emitOpError("expects 'buf_id' in range [0, 31]"); - - if (modeAttr) { - int64_t mode = modeAttr.getInt(); - if (mode < 0) - return op->emitOpError("expects 'mode' to be non-negative"); - } - - return success(); -} - -LogicalResult GetBufOp::verify() { - return verifyBufSyncOp(getOperation(), getOpTypeAttr(), getBufIdAttr(), - getModeAttr()); -} - -LogicalResult RlsBufOp::verify() { - return verifyBufSyncOp(getOperation(), getOpTypeAttr(), getBufIdAttr(), - getModeAttr()); -} -// ---- TOp ---- -LogicalResult TGemvBiasOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - if (failed(verifyGemvTileOperands(*this, getA().getType(), getB().getType(), - getDst().getType())) || - failed(verifyMatBiasTile(*this, getBias().getType(), getDst().getType()))) - return failure(); - if (failed(verifyMatmulTypeTriple(*this, getElemTy(getA().getType()), - getElemTy(getB().getType()), - getElemTy(getDst().getType())))) - return failure(); - return verifyMatmulLike(*this, getA().getType(), getB().getType(), - getDst().getType()); - }; - auto verifyA5 = [&]() -> LogicalResult { return verifyA2A3(); }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -LogicalResult TGemvMxOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - return 
emitOpError("tgemv.mx is only supported on A5 targets"); - }; - auto verifyA5 = [&]() -> LogicalResult { - if (failed(verifyScaleTileMatchesOperand(*this, getAScale().getType(), - getA().getType(), "a_scale", "a")) || - failed(verifyScaleTileMatchesOperand(*this, getBScale().getType(), - getB().getType(), "b_scale", "b")) || - failed(verifyGemvTileOperands(*this, getA().getType(), getB().getType(), - getDst().getType()))) - return failure(); - if (failed(verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), - getDst().getType(), "a", "b", "dst"))) - return failure(); - return verifyMatmulLike(*this, getA().getType(), getB().getType(), - getDst().getType()); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -LogicalResult TGemvMxAccOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - return emitOpError("tgemv.mx.acc is only supported on A5 targets"); - }; - auto verifyA5 = [&]() -> LogicalResult { - if (failed(verifyAccTileCommon(*this, getCIn().getType(), "c_in")) || - failed(verifyScaleTileMatchesOperand(*this, getAScale().getType(), - getA().getType(), "a_scale", "a")) || - failed(verifyScaleTileMatchesOperand(*this, getBScale().getType(), - getB().getType(), "b_scale", "b")) || - failed(verifyGemvTileOperands(*this, getA().getType(), getB().getType(), - getDst().getType()))) - return failure(); - if (failed(verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), - getDst().getType(), "a", "b", "dst"))) - return failure(); - if (failed(verifyTileBufSameElemType(*this, getCIn().getType(), - getDst().getType(), "c_in", "dst")) || - failed(verifyTileBufSameValidShape(*this, getCIn().getType(), - getDst().getType(), "c_in", "dst"))) - return failure(); - return verifyMatmulLike(*this, getA().getType(), getB().getType(), - getDst().getType()); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -LogicalResult TGemvMxBiasOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - 
return emitOpError("tgemv.mx.bias is only supported on A5 targets"); - }; - auto verifyA5 = [&]() -> LogicalResult { - if (failed(verifyScaleTileMatchesOperand(*this, getAScale().getType(), - getA().getType(), "a_scale", "a")) || - failed(verifyScaleTileMatchesOperand(*this, getBScale().getType(), - getB().getType(), "b_scale", "b")) || - failed(verifyGemvTileOperands(*this, getA().getType(), getB().getType(), - getDst().getType())) || - failed(verifyMatBiasTile(*this, getBias().getType(), getDst().getType(), - /*requireFloatBias=*/true))) - return failure(); - if (failed(verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), - getDst().getType(), "a", "b", "dst"))) - return failure(); - auto biasShape = getShapeVec(getBias().getType()); - auto dstShape = getShapeVec(getDst().getType()); - if (biasShape.size() != 2 || dstShape.size() != 2) - return emitOpError("expects bias and dst to be rank-2 for tgemv.mx.bias"); - if (biasShape[1] != ShapedType::kDynamic && dstShape[1] != ShapedType::kDynamic && - biasShape[1] != dstShape[1]) - return emitOpError("expects bias and dst to have the same column shape"); - if (failed(verifyTileBufSameValidShape(*this, getBias().getType(), - getDst().getType(), "bias", "dst"))) - return failure(); - return verifyMatmulLike(*this, getA().getType(), getB().getType(), - getDst().getType()); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -LogicalResult TMatmulBiasOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - if (failed(verifyMatTileOperands(*this, getA().getType(), getB().getType(), - getDst().getType())) || - failed(verifyMatBiasTile(*this, getBias().getType(), getDst().getType()))) - return failure(); - if (failed(verifyMatmulTypeTriple(*this, getElemTy(getA().getType()), - getElemTy(getB().getType()), - getElemTy(getDst().getType())))) - return failure(); - return verifyMatmulLike(*this, getA().getType(), getB().getType(), - getDst().getType()); - }; - auto verifyA5 = [&]() -> 
LogicalResult { return verifyA2A3(); }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -LogicalResult TMatmulMxOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - if (failed(verifyTileBufCommon(*this, getAScale().getType(), "a_scale")) || - failed(verifyTileBufCommon(*this, getBScale().getType(), "b_scale"))) - return failure(); - return verifyMatmulLike(*this, getA().getType(), getB().getType(), - getDst().getType()); - }; - auto verifyA5 = [&]() -> LogicalResult { - if (failed(verifyA2A3())) - return failure(); - return verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), - getDst().getType(), "a", "b", "dst"); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -LogicalResult TMatmulMxAccOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - if (failed(verifyAccTileCommon(*this, getCIn().getType(), "c_in")) || - failed(verifyTileBufCommon(*this, getAScale().getType(), "a_scale")) || - failed(verifyTileBufCommon(*this, getBScale().getType(), "b_scale"))) - return failure(); - return success(); - }; - auto verifyA5 = [&]() -> LogicalResult { - if (failed(verifyA2A3())) - return failure(); - if (failed(verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), - getDst().getType(), "a", "b", "dst"))) - return failure(); - if (failed(verifyTileBufSameElemType(*this, getCIn().getType(), - getDst().getType(), "c_in", "dst")) || - failed(verifyTileBufSameValidShape(*this, getCIn().getType(), - getDst().getType(), "c_in", "dst"))) - return failure(); - return success(); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} -LogicalResult TMatmulMxBiasOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - if (failed(verifyTileBufCommon(*this, getAScale().getType(), "a_scale")) || - failed(verifyTileBufCommon(*this, getBScale().getType(), "b_scale")) || - failed(verifyMatTileOperands(*this, getA().getType(), getB().getType(), - getDst().getType())) || 
- failed(verifyMatBiasTile(*this, getBias().getType(), getDst().getType(), - /*requireFloatBias=*/true))) - return failure(); - return verifyMatmulLike(*this, getA().getType(), getB().getType(), - getDst().getType()); - }; - auto verifyA5 = [&]() -> LogicalResult { - if (failed(verifyA2A3())) - return failure(); - return verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), - getDst().getType(), "a", "b", "dst"); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} -// ---- TSetValOp ---- -LogicalResult TSetValOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - // dst can be tile/tensor/tilebuf (PTODpsType). Keep checks minimal. - if (auto shaped = dyn_cast(getDst().getType())) { - if (shaped.getElementType() != getVal().getType()) - return emitOpError("expects val type to match dst element type"); - } - return success(); -} -// ---- TGetValOp ---- -LogicalResult TGetValOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - Type srcTy = getSrc().getType(); - if (!mlir::isa(srcTy)) - return emitOpError("expects src to be tile_buf or memref type"); - - // Memory space must be vec (Ascend does not support getval from MAT etc.). - Attribute memSpace = - isa(srcTy) - ? 
cast(srcTy).getMemorySpace() - : cast(srcTy).getMemorySpace(); - auto addrSpaceAttr = dyn_cast_or_null(memSpace); - if (!addrSpaceAttr || - addrSpaceAttr.getAddressSpace() != pto::AddressSpace::VEC) { - if (addrSpaceAttr && - addrSpaceAttr.getAddressSpace() == pto::AddressSpace::MAT) - return emitOpError( - "Ascend hardware does not support reading from Mat tile_buf to Scalar unit"); - return emitOpError("expects src memory space to be vec"); - } - - if (getElemTy(srcTy) != getDst().getType()) - return emitOpError("expects dst type to match src element type"); - return success(); -} - -LogicalResult THistogramOp::verify() { - auto isIntegerWidth = [](Type ty, unsigned width) { - auto it = dyn_cast(ty); - return it && it.getWidth() == width; - }; - - auto verifyA2A3 = [&]() -> LogicalResult { - return emitOpError("thistogram is only supported on A5"); - }; - auto verifyA5 = [&]() -> LogicalResult { - Type srcTy = getSrc().getType(); - Type idxTy = getIdx().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyTileBufCommon(*this, srcTy, "src")) || - failed(verifyTileBufCommon(*this, idxTy, "idx")) || - failed(verifyTileBufCommon(*this, dstTy, "dst"))) - return failure(); - - auto srcSpace = getPTOMemorySpaceEnum(srcTy); - auto idxSpace = getPTOMemorySpaceEnum(idxTy); - auto dstSpace = getPTOMemorySpaceEnum(dstTy); - if (!srcSpace || *srcSpace != pto::AddressSpace::VEC) - return emitOpError("expects src to be in the vec address space"); - if (!idxSpace || *idxSpace != pto::AddressSpace::VEC) - return emitOpError("expects idx to be in the vec address space"); - if (!dstSpace || *dstSpace != pto::AddressSpace::VEC) - return emitOpError("expects dst to be in the vec address space"); - - auto srcTB = dyn_cast(srcTy); - auto idxTB = dyn_cast(idxTy); - auto dstTB = dyn_cast(dstTy); - if (!srcTB || !idxTB || !dstTB) - return emitOpError("expects src, idx, and dst to be tile_buf types"); - - if (srcTB.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor) || 
- srcTB.getSLayoutValueI32() != static_cast(pto::SLayout::NoneBox)) - return emitOpError("expects src to use row_major + none_box layout"); - if (dstTB.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor) || - dstTB.getSLayoutValueI32() != static_cast(pto::SLayout::NoneBox)) - return emitOpError("expects dst to use row_major + none_box layout"); - if (idxTB.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor) || - idxTB.getSLayoutValueI32() != static_cast(pto::SLayout::NoneBox)) - return emitOpError( - "expects idx to use DN layout (col_major + none_box)"); - - if (!isIntegerWidth(getElemTy(srcTy), 16)) - return emitOpError("expects src element type to be ui16"); - if (!isIntegerWidth(getElemTy(idxTy), 8)) - return emitOpError("expects idx element type to be ui8"); - if (!isIntegerWidth(getElemTy(dstTy), 32)) - return emitOpError("expects dst element type to be ui32"); - - auto srcShape = getShapeVec(srcTy); - auto idxShape = getShapeVec(idxTy); - auto dstShape = getShapeVec(dstTy); - auto srcValid = getValidShapeVec(srcTy); - auto idxValid = getValidShapeVec(idxTy); - auto dstValid = getValidShapeVec(dstTy); - if (srcShape.size() != 2 || idxShape.size() != 2 || dstShape.size() != 2 || - srcValid.size() != 2 || idxValid.size() != 2 || dstValid.size() != 2) - return emitOpError( - "expects src, idx, and dst to have rank-2 shape and valid_shape"); - - if (!hasCompatibleKnownExtent(srcShape[0], idxShape[0]) || - !hasCompatibleKnownExtent(srcValid[0], idxValid[0])) - return emitOpError("expects idx rows and valid rows to match src"); - if (!hasCompatibleKnownExtent(srcShape[0], dstShape[0]) || - !hasCompatibleKnownExtent(srcValid[0], dstValid[0])) - return emitOpError("expects dst rows and valid rows to match src"); - - if (!isKnownUnitExtent(idxShape[1]) || !isKnownUnitExtent(idxValid[1])) - return emitOpError("expects idx to have exactly one column"); - if (dstShape[1] != ShapedType::kDynamic && dstShape[1] < 256) - return emitOpError("expects dst 
shape[1] to be at least 256"); - if (dstValid[1] != ShapedType::kDynamic && dstValid[1] < 256) - return emitOpError("expects dst valid_shape[1] to be at least 256"); - return success(); - }; - - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -LogicalResult TGetScaleAddrOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - return emitOpError("tget_scale_addr is only supported on A5"); - }; - auto verifyA5 = [&]() -> LogicalResult { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyTileBufCommon(*this, srcTy, "src"))) - return failure(); - if (failed(verifyScaleTileMatchesOperand(*this, dstTy, srcTy, "dst", "src"))) - return failure(); - return success(); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -// ---- MScatterOp ---- -LogicalResult MScatterOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - - if (!isTargetArchA5(getOperation())) - return emitOpError("pto.mscatter is only supported on A5 targets"); - - Type srcTy = getSrc().getType(); - Type idxTy = getIdx().getType(); - Type memTy = getMem().getType(); - - if (getPTOTypeRank(srcTy) == -1 || getPTOTypeRank(idxTy) == -1 || - getPTOTypeRank(memTy) == -1) - return emitOpError("expects src, idx, and mem to use supported PTO shapes"); - - if (failed(verifyNDStyleVecTile(*this, srcTy, "src")) || - failed(verifyMGatherMScatterIdxTile(getOperation(), idxTy, "idx"))) - return failure(); - - Type srcElem = getElemTy(srcTy); - Type idxElem = getElemTy(idxTy); - if (!srcElem || !idxElem) - return emitOpError("failed to resolve element types for src or idx"); - - if (!isSupportedMGatherMScatterPayloadElemType(getOperation(), srcElem)) - return emitOpError( - "expects src element type to be i8/ui8/i16/ui16/i32/ui32/f16/bf16/f32 " - "(and on A5 targets also float8_e4m3/float8_e5m2 family types)"); - - if (!isSupportedMGatherMScatterIndexElemType(idxElem)) - return 
emitOpError("expects idx element type to be signless i32"); - - if (failed(verifyMGatherMScatterMemOperand(getOperation(), getMem(), srcElem, - "src"))) - return failure(); - - if (getScatterAtomicOp() != pto::ScatterAtomicOp::None || - getScatterOob() != pto::ScatterOOB::Undefined) { - if (!isTargetArchA5(getOperation())) - return emitOpError( - "expects non-default scatterAtomicOp/scatterOob only on A5 targets"); - } - - if (!isSupportedMScatterAtomicPayloadElemType(srcElem, getScatterAtomicOp())) - return emitOpError( - "expects scatterAtomicOp-compatible src element type: add supports " - "i32/ui32/f16/f32, max/min support signless i32/f32"); - - if (failed(verifyMGatherMScatterTileShape(getOperation(), srcTy, idxTy, "src"))) - return failure(); - - return success(); -} - -// ---- MGatherOp ---- -LogicalResult MGatherOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - - if (!isTargetArchA5(getOperation())) - return emitOpError("pto.mgather is only supported on A5 targets"); - - Type memTy = getMem().getType(); - Type idxTy = getIdx().getType(); - Type dstTy = getDst().getType(); - - if (getPTOTypeRank(memTy) == -1 || getPTOTypeRank(idxTy) == -1 || - getPTOTypeRank(dstTy) == -1) - return emitOpError("expects mem, idx, and dst to use supported PTO shapes"); - - if (failed(verifyNDStyleVecTile(*this, dstTy, "dst")) || - failed(verifyMGatherMScatterIdxTile(getOperation(), idxTy, "idx"))) - return failure(); - - Type dstElem = getElemTy(dstTy); - Type idxElem = getElemTy(idxTy); - if (!dstElem || !idxElem) - return emitOpError("failed to resolve element types for dst or idx"); - - if (!isSupportedMGatherMScatterPayloadElemType(getOperation(), dstElem)) - return emitOpError( - "expects dst element type to be i8/ui8/i16/ui16/i32/ui32/f16/bf16/f32 " - "(and on A5 targets also float8_e4m3/float8_e5m2 family types)"); - - if (!isSupportedMGatherMScatterIndexElemType(idxElem)) - return emitOpError("expects idx element type to be 
signless i32"); - - if (failed(verifyMGatherMScatterMemOperand(getOperation(), getMem(), dstElem, - "dst"))) - return failure(); - - if (getGatherOob() != pto::GatherOOB::Undefined && - !isTargetArchA5(getOperation())) - return emitOpError( - "expects non-default gatherOob only on A5 targets"); - - if (failed(verifyMGatherMScatterTileShape(getOperation(), dstTy, idxTy, "dst"))) - return failure(); - - return success(); -} - -void mlir::pto::TCvtOp::print(OpAsmPrinter &p) { - p << " ins(" << getSrc(); - Builder builder(getContext()); - NamedAttrList attrs; - for (auto attr : (*this)->getAttrs()) { - if (attr.getName() == "sat_mode") { - attrs.set(builder.getStringAttr("satmode"), attr.getValue()); - continue; - } - attrs.set(attr.getName(), attr.getValue()); - } - p.printOptionalAttrDict(attrs.getAttrs()); - p << " : " << getSrc().getType(); - p << ") outs(" << getDst() << " : " << getDst().getType() << ")"; -} - -ParseResult mlir::pto::TCvtOp::parse(OpAsmParser &parser, OperationState &result) { - OpAsmParser::UnresolvedOperand src, dst; - Type srcTy, dstTy; - - if (parser.parseKeyword("ins") || parser.parseLParen() || parser.parseOperand(src)) - return failure(); - NamedAttrList attrs; - if (parser.parseOptionalAttrDict(attrs) || parser.parseColonType(srcTy)) - return failure(); - if (auto satmode = attrs.get("satmode")) { - attrs.erase("satmode"); - if (attrs.get("sat_mode")) - return parser.emitError(parser.getCurrentLocation(), - "cannot specify both satmode and sat_mode"); - attrs.set("sat_mode", satmode); - } - result.attributes = attrs; - if (parser.parseRParen() || parser.parseKeyword("outs") || parser.parseLParen() || - parser.parseOperand(dst) || parser.parseColonType(dstTy) || parser.parseRParen()) - return failure(); - - if (parser.resolveOperand(src, srcTy, result.operands) || - parser.resolveOperand(dst, dstTy, result.operands)) - return failure(); - return success(); -} - -void mlir::pto::TMrgSortOp::print(OpAsmPrinter &p) { - if (isFormat1()) { - p 
<< " ins(" << getSrc() << ", " << getBlockLen() << " : " << getSrc().getType() - << ", " << getBlockLen().getType() << ") outs(" << getDst() << " : " - << getDst().getType() << ")"; - } else if (isFormat2()) { - p << " ins("; - llvm::interleaveComma(getSrcs(), p, [&](Value src) { p << src; }); - p << ", " << getTmp(); - p << " {exhausted = " << (getExhausted() ? "true" : "false") << "} : "; - llvm::interleaveComma(getSrcs().getTypes(), p, [&](Type ty) { p << ty; }); - p << ", " << getTmp().getType(); - p << ") outs(" << getDst() << ", " << getExcuted() - << " : " << getDst().getType() << ", " << getExcuted().getType() << ")"; - } else { - llvm::report_fatal_error("TMrgSortOp print expects format1 or format2"); - } - p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"operandSegmentSizes", "exhausted"}); -} - -ParseResult mlir::pto::TMrgSortOp::parse(OpAsmParser &parser, OperationState &result) { - if (parser.parseKeyword("ins") || parser.parseLParen()) - return failure(); - OpAsmParser::UnresolvedOperand first, second; - if (parser.parseOperand(first) || parser.parseComma() || parser.parseOperand(second)) - return failure(); - - if (parser.parseOptionalColon().succeeded()) { - Type srcTy, blockLenTy, dstTy; - if (parser.parseType(srcTy) || parser.parseComma() || parser.parseType(blockLenTy) || - parser.parseRParen() || parser.parseKeyword("outs") || parser.parseLParen()) - return failure(); - OpAsmParser::UnresolvedOperand dstOp; - if (parser.parseOperand(dstOp) || parser.parseColon() || parser.parseType(dstTy) || - parser.parseRParen()) - return failure(); - result.addAttribute("operandSegmentSizes", - parser.getBuilder().getDenseI32ArrayAttr({1, 1, 1, 0, 0})); - if (parser.resolveOperand(first, srcTy, result.operands) || - parser.resolveOperand(second, blockLenTy, result.operands) || - parser.resolveOperand(dstOp, dstTy, result.operands)) - return failure(); - if (parser.parseOptionalAttrDict(result.attributes)) - return failure(); - if 
(!result.attributes.get("exhausted")) - result.addAttribute("exhausted", parser.getBuilder().getBoolAttr(false)); - return success(); - } - - SmallVector srcs = {first, second}; - while (parser.parseOptionalComma().succeeded()) { - OpAsmParser::UnresolvedOperand next; - if (parser.parseOperand(next)) - return failure(); - srcs.push_back(next); - } - if (srcs.size() < 3 || srcs.size() > 5) - return parser.emitError(parser.getCurrentLocation(), - "tmrgsort format2 expects 2 to 4 src operands plus one tmp operand"); - OpAsmParser::UnresolvedOperand tmpOp = srcs.pop_back_val(); - bool exhaustedVal = false; - if (parser.parseOptionalLBrace().succeeded()) { - if (parser.parseKeyword("exhausted") || parser.parseEqual()) - return failure(); - StringRef kw; - if (parser.parseKeyword(&kw) || parser.parseRBrace()) - return failure(); - exhaustedVal = (kw == "true"); - } - SmallVector srcTypes; - srcTypes.reserve(srcs.size()); - if (parser.parseColon()) - return failure(); - Type firstSrcTy; - if (parser.parseType(firstSrcTy)) - return failure(); - srcTypes.push_back(firstSrcTy); - while (parser.parseOptionalComma().succeeded()) { - Type nextTy; - if (parser.parseType(nextTy)) - return failure(); - srcTypes.push_back(nextTy); - } - if (srcTypes.size() != srcs.size() + 1 || parser.parseRParen() || - parser.parseKeyword("outs") || parser.parseLParen()) - return failure(); - Type tmpTy = srcTypes.pop_back_val(); - OpAsmParser::UnresolvedOperand dstOp, excutedOp; - Type dstTy, excutedTy; - if (parser.parseOperand(dstOp) || parser.parseComma() || parser.parseOperand(excutedOp) || - parser.parseColon() || parser.parseType(dstTy) || parser.parseComma() || - parser.parseType(excutedTy) || parser.parseRParen()) - return failure(); - result.addAttribute("operandSegmentSizes", - parser.getBuilder().getDenseI32ArrayAttr( - {static_cast(srcs.size()), 0, 1, 1, 1})); - if (parser.resolveOperands(srcs, srcTypes, parser.getCurrentLocation(), result.operands) || - parser.resolveOperand(dstOp, 
dstTy, result.operands) || - parser.resolveOperand(tmpOp, tmpTy, result.operands) || - parser.resolveOperand(excutedOp, excutedTy, result.operands)) - return failure(); - if (parser.parseOptionalAttrDict(result.attributes)) - return failure(); - if (!result.attributes.get("exhausted")) - result.addAttribute("exhausted", parser.getBuilder().getBoolAttr(exhaustedVal)); - return success(); -} - -mlir::LogicalResult mlir::pto::TMrgSortOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - if (isFormat1()) { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if (!isPTOShapedLike(srcTy) || !isPTOShapedLike(dstTy)) - return emitOpError() << "format1 expects PTO shaped-like types for src/dst"; - if (getElemTy(srcTy) != getElemTy(dstTy)) - return emitOpError() << "expects src/dst to have the same element type"; - if (!getElemTy(srcTy).isF16() && !getElemTy(srcTy).isF32()) - return emitOpError() << "expects element type to be f16 or f32"; - auto ss = getShapeVec(srcTy); - auto ds = getShapeVec(dstTy); - if (ss.size() != 2 || ds.size() != 2) - return emitOpError() << "expects src/dst to be rank-2 tile-shaped"; - if (ss[0] != mlir::ShapedType::kDynamic && ss[0] != 1) - return emitOpError() << "expects src rows == 1"; - if (ds[0] != mlir::ShapedType::kDynamic && ds[0] != 1) - return emitOpError() << "expects dst rows == 1"; - if (ss[1] != mlir::ShapedType::kDynamic && ds[1] != mlir::ShapedType::kDynamic && ss[1] != ds[1]) - return emitOpError() << "expects src/dst cols to match"; - if (getBlockLen()) { - if (auto cstOp = getBlockLen().getDefiningOp()) { - if (auto intAttr = mlir::dyn_cast(cstOp.getValue())) { - int64_t v = intAttr.getValue().getSExtValue(); - if (v <= 0 || (v % 64) != 0) - return emitOpError() << "expects blockLen > 0 and multiple of 64"; - } - } - } - return mlir::success(); - } - if (isFormat2()) { - for (Value v : getSrcs()) - if (!isPTOShapedLike(v.getType())) - return emitOpError() << "format2 
expects PTO shaped-like type for each src"; - if (getSrcs().size() < 2u || getSrcs().size() > 4u) - return emitOpError() << "format2 expects 2 to 4 srcs"; - if (getDsts().size() != 1u || !getTmp() || !getExcuted()) - return emitOpError() << "format2 expects ins(srcs..., tmp), outs(dst), and excuted=vector"; - Type dstTy = getDst().getType(); - Type tmpTy = getTmp().getType(); - if (!isPTOShapedLike(dstTy) || !isPTOShapedLike(tmpTy)) - return emitOpError() << "format2 dst/tmp must be PTO shaped-like"; - auto excutedTy = mlir::dyn_cast(getExcuted().getType()); - if (!excutedTy || excutedTy.getRank() != 1 || excutedTy.getNumElements() != 4 || - !excutedTy.getElementType().isInteger(16)) - return emitOpError() << "format2 excuted must be vector<4xi16>"; - Type elemTy = getElemTy(dstTy); - if (elemTy != getElemTy(tmpTy)) - return emitOpError() << "format2 expects dst/tmp element types to match"; - auto dstShape = getShapeVec(dstTy); - auto tmpShape = getShapeVec(tmpTy); - if (dstShape.size() != 2 || tmpShape.size() != 2) - return emitOpError() << "format2 expects dst/tmp to be rank-2 tile-shaped"; - if ((dstShape[0] != mlir::ShapedType::kDynamic && dstShape[0] != 1) || - (tmpShape[0] != mlir::ShapedType::kDynamic && tmpShape[0] != 1)) - return emitOpError() << "format2 expects dst/tmp rows == 1"; - if (dstShape[1] != mlir::ShapedType::kDynamic && - tmpShape[1] != mlir::ShapedType::kDynamic && - tmpShape[1] < dstShape[1]) - return emitOpError() << "format2 expects tmp.cols >= dst.cols"; - for (Value src : getSrcs()) { - Type srcTy = src.getType(); - auto srcShape = getShapeVec(srcTy); - if (srcShape.size() != 2) - return emitOpError() << "format2 expects src to be rank-2 tile-shaped"; - if (srcShape[0] != mlir::ShapedType::kDynamic && srcShape[0] != 1) - return emitOpError() << "format2 expects src rows == 1"; - if (getElemTy(srcTy) != elemTy) - return emitOpError() << "format2 expects src/dst/tmp element types to match"; - } - return mlir::success(); - } - return 
emitOpError() << "tmrgsort expects format1 (1 src + blockLen + 1 dst) or " - "format2 (2 to 4 srcs + tmp, outs dst, excuted)"; -} - -mlir::LogicalResult mlir::pto::TMulOp::verify() { - return verifyArithmeticBinaryTileOpWithArchDispatch( - getOperation(), getSrc0().getType(), getSrc1().getType(), getDst().getType(), - /*allowInt8OnA5=*/false, /*allowBf16OnA5=*/false, - "expects A2/A3 tmul element type to be i32/i16/f16/f32", - "expects A5 tmul element type to be i32/i16/f16/f32"); -} - -mlir::LogicalResult mlir::pto::TMulSOp::verify() { - return verifyArithmeticScalarTileOpWithArchDispatch( - getOperation(), getSrc0().getType(), getDst().getType(), - getScalar().getType(), /*allowInt8OnA5=*/true, /*allowBf16OnA5=*/true, - "expects A2/A3 tmuls element type to be i32/i16/f16/f32", - "expects A5 tmuls element type to be i32/i16/i8/f16/bf16/f32", - /*requireValidRowsEqualOnA2A3=*/true, - /*requireValidRowsEqualOnA5=*/false); -} - -mlir::LogicalResult mlir::pto::TShlSOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyTileBufCommon(*this, srcTy, "src")) || - failed(verifyTileBufCommon(*this, dstTy, "dst"))) - return failure(); - - Type srcElem = getElemTy(srcTy); - Type dstElem = getElemTy(dstTy); - if (!srcElem || !dstElem) - return emitOpError() << "failed to get element type for src/dst"; - if (srcElem != dstElem) - return emitOpError() << "expects src and dst to have the same element type"; - if (!mlir::isa(srcElem)) - return emitOpError() << "expects integral element types"; - if (auto scalarValue = getConstantIntegerValue(getScalar()); scalarValue && *scalarValue < 0) - return emitOpError("expects tshls scalar to be non-negative"); - return mlir::success(); -} - -mlir::LogicalResult mlir::pto::TShrSOp::verify() { - auto verifyCommon = [&]() -> FailureOr { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if 
(failed(verifyVecTileCommon(*this, srcTy, "src")) || - failed(verifyVecTileCommon(*this, dstTy, "dst"))) - return failure(); - if (failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) - return failure(); - - Type srcElem = getElemTy(srcTy); - Type dstElem = getElemTy(dstTy); - if (!srcElem || !dstElem) { - emitOpError("failed to get element type for src/dst"); - return failure(); - } - if (srcElem != dstElem) { - emitOpError("expects src and dst to have the same element type"); - return failure(); - } - return srcElem; - }; - - auto verifyA2A3 = [&]() -> LogicalResult { - FailureOr elemOr = verifyCommon(); - if (failed(elemOr)) - return failure(); - auto it = mlir::dyn_cast(*elemOr); - if (!it || (it.getWidth() != 16 && it.getWidth() != 32)) - return emitOpError( - "expects A2/A3 tshrs src and dst element type to be i16/i32"); - return success(); - }; - - auto verifyA5 = [&]() -> LogicalResult { - FailureOr elemOr = verifyCommon(); - if (failed(elemOr)) - return failure(); - auto it = mlir::dyn_cast(*elemOr); - if (!it || (it.getWidth() != 8 && it.getWidth() != 16 && - it.getWidth() != 32)) - return emitOpError( - "expects A5 tshrs src and dst element type to be i8/i16/i32"); - return success(); - }; - - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TNegOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyVecTileStorage(*this, srcTy, "src")) || - failed(verifyVecTileStorage(*this, dstTy, "dst"))) - return failure(); - if (failed(verifyTileBufSameElemType(*this, srcTy, dstTy, "src", "dst")) || - failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) - return failure(); - - Type elemTy = getElemTy(srcTy); - if (!(elemTy.isInteger(16) || elemTy.isInteger(32) || elemTy.isF16() || - elemTy.isF32())) - return emitOpError() - << "expects A2/A3 tneg element type to be 
i16/i32/f16/f32"; - return success(); - }; - - auto verifyA5 = [&]() -> LogicalResult { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyVecTileStorage(*this, srcTy, "src")) || - failed(verifyVecTileStorage(*this, dstTy, "dst"))) - return failure(); - if (failed(verifyTileBufSameElemType(*this, srcTy, dstTy, "src", "dst"))) - return failure(); - - auto srcValid = getValidShapeVec(srcTy); - auto dstValid = getValidShapeVec(dstTy); - if (srcValid.size() != 2 || dstValid.size() != 2) - return emitOpError() << "expects src and dst to have rank-2 valid_shape"; - if (srcValid[1] != ShapedType::kDynamic && - dstValid[1] != ShapedType::kDynamic && - srcValid[1] != dstValid[1]) - return emitOpError() - << "expects src and dst to have the same valid_shape[1]"; - - Type elemTy = getElemTy(srcTy); - if (!(elemTy.isInteger(8) || elemTy.isInteger(16) || elemTy.isInteger(32) || - elemTy.isF16() || elemTy.isF32() || elemTy.isBF16())) - return emitOpError() - << "expects A5 tneg element type to be i8/i16/i32/f16/f32/bf16"; - return success(); - }; - - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TNotOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyVecTileCommon(*this, srcTy, "src")) || - failed(verifyVecTileCommon(*this, dstTy, "dst"))) - return failure(); - if (failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) - return failure(); - auto elemTy = getElemTy(srcTy); - if (elemTy != getElemTy(dstTy)) - return emitOpError() << "expects src and dst to have the same element type"; - if (!elemTy.isInteger(16)) - return emitOpError() << "expects A2/A3 tnot element type to be i16"; - return success(); - }; - auto verifyA5 = [&]() -> LogicalResult { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyVecTileCommon(*this, srcTy, "src")) 
|| - failed(verifyVecTileCommon(*this, dstTy, "dst"))) - return failure(); - if (failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) - return failure(); - auto elemTy = getElemTy(srcTy); - if (elemTy != getElemTy(dstTy)) - return emitOpError() << "expects src and dst to have the same element type"; - if (!(elemTy.isInteger(8) || elemTy.isInteger(16) || elemTy.isInteger(32))) - return emitOpError() << "expects A5 tnot element type to be i8/i16/i32"; - return success(); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TOrOp::verify() { - auto verifyCommon = [&]() -> FailureOr { - return verifyMatchingRowMajorBinaryTileOpCommon( - getOperation(), getSrc0().getType(), getSrc1().getType(), - getDst().getType()); - }; - - auto verifyA2A3 = [&]() -> LogicalResult { - FailureOr elemOr = verifyCommon(); - if (failed(elemOr)) - return failure(); - auto it = mlir::dyn_cast(*elemOr); - if (!it || (it.getWidth() != 8 && it.getWidth() != 16)) - return emitOpError( - "expects A2/A3 tor src0, src1, and dst element type to be i8/i16"); - return success(); - }; - - auto verifyA5 = [&]() -> LogicalResult { - FailureOr elemOr = verifyCommon(); - if (failed(elemOr)) - return failure(); - auto it = mlir::dyn_cast(*elemOr); - if (!it || (it.getWidth() != 8 && it.getWidth() != 16 && - it.getWidth() != 32)) - return emitOpError( - "expects A5 tor src0, src1, and dst element type to be i8/i16/i32"); - return success(); - }; - - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TOrSOp::verify() { - auto verifyCommon = [&]() -> FailureOr { - return verifyDistinctRowMajorUnaryTileOpCommon(getOperation(), getSrc(), - getDst(), "src", "dst"); - }; - - auto verifyA2A3 = [&]() -> LogicalResult { - FailureOr elemOr = verifyCommon(); - if (failed(elemOr)) - return failure(); - auto it = mlir::dyn_cast(*elemOr); - if (!it || (it.getWidth() != 8 && it.getWidth() 
!= 16)) - return emitOpError( - "expects A2/A3 tors src and dst element type to be i8/i16"); - return success(); - }; - - auto verifyA5 = [&]() -> LogicalResult { - FailureOr elemOr = verifyCommon(); - if (failed(elemOr)) - return failure(); - auto it = mlir::dyn_cast(*elemOr); - if (!it || (it.getWidth() != 8 && it.getWidth() != 16 && - it.getWidth() != 32)) - return emitOpError( - "expects A5 tors src and dst element type to be i8/i16/i32"); - return success(); - }; - - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -static FailureOr verifyPTOShapedBinarySameElemAndShape(Operation *op, - Type src0Ty, - Type src1Ty, - Type dstTy) { - if (!isPTOShapedLike(src0Ty) || !isPTOShapedLike(src1Ty) || - !isPTOShapedLike(dstTy)) - return op->emitOpError( - "expects src0/src1/dst to be memref/tensor/tile_buf/tile_view types"), - failure(); - Type e0 = getElemTy(src0Ty), e1 = getElemTy(src1Ty), ed = getElemTy(dstTy); - if (!e0 || !e1 || !ed) - return op->emitOpError("failed to get element type for operands"), failure(); - if (e0 != e1 || e0 != ed) - return op->emitOpError("expects src0/src1/dst to have the same element type"), - failure(); - auto s0 = getShapeVec(src0Ty), s1 = getShapeVec(src1Ty), sd = getShapeVec(dstTy); - if (s0 != s1 || s0 != sd) - return op->emitOpError("expects src0/src1/dst to have the same shape"), - failure(); - return e0; -} - -mlir::LogicalResult mlir::pto::TPartAddOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - Type src0Ty = getSrc0().getType(); - Type src1Ty = getSrc1().getType(); - Type dstTy = getDst().getType(); - if (!isPTOShapedLike(src0Ty) || !isPTOShapedLike(src1Ty) || - !isPTOShapedLike(dstTy)) - return emitOpError() << "expects PTO shaped-like src0/src1/dst"; - if (getElemTy(src0Ty) != getElemTy(src1Ty) || - getElemTy(src0Ty) != getElemTy(dstTy)) - return emitOpError() << "expects src0/src1/dst to have the same element type"; - auto s0 = getShapeVec(src0Ty); - auto s1 = getShapeVec(src1Ty); - auto 
d = getShapeVec(dstTy); - if (s0.size() != 2 || s1.size() != 2 || d.size() != 2) - return emitOpError() << "expects src0/src1/dst to be rank-2 (tile-shaped)"; - if (failed(verifyPartialValidPattern(*this, src0Ty, src1Ty, dstTy))) - return failure(); - Type elem = getElemTy(src0Ty); - if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isF16() || elem.isF32())) - return emitOpError("expects A2/A3 tpartadd element type to be i32/i16/f16/f32"); - return mlir::success(); - }; - auto verifyA5 = [&]() -> LogicalResult { - Type src0Ty = getSrc0().getType(); - Type src1Ty = getSrc1().getType(); - Type dstTy = getDst().getType(); - if (!isPTOShapedLike(src0Ty) || !isPTOShapedLike(src1Ty) || - !isPTOShapedLike(dstTy)) - return emitOpError() << "expects PTO shaped-like src0/src1/dst"; - if (getElemTy(src0Ty) != getElemTy(src1Ty) || - getElemTy(src0Ty) != getElemTy(dstTy)) - return emitOpError() << "expects src0/src1/dst to have the same element type"; - Type elem = getElemTy(src0Ty); - if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isInteger(8) || - elem.isF16() || elem.isBF16() || elem.isF32())) - return emitOpError("expects A5 tpartadd element type to be i32/i16/i8/f16/bf16/f32"); - auto s0 = getShapeVec(src0Ty); - auto s1 = getShapeVec(src1Ty); - auto d = getShapeVec(dstTy); - if (s0.size() != 2 || s1.size() != 2 || d.size() != 2) - return emitOpError() << "expects src0/src1/dst to be rank-2 (tile-shaped)"; - return mlir::success(); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TPartMaxOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - Type t0 = getSrc0().getType(); - Type t1 = getSrc1().getType(); - Type td = getDst().getType(); - FailureOr elemOr = - verifyPTOShapedBinarySameElemAndShape(getOperation(), t0, t1, td); - if (failed(elemOr)) - return failure(); - if (failed(verifyPartialValidPattern(*this, t0, t1, td))) - return failure(); - Type e0 = *elemOr; - if (!(e0.isInteger(32) || 
e0.isInteger(16) || e0.isF16() || e0.isF32())) - return emitOpError("expects A2/A3 tpartmax element type to be i32/i16/f16/f32"); - return mlir::success(); - }; - auto verifyA5 = [&]() -> LogicalResult { - Type t0 = getSrc0().getType(); - Type t1 = getSrc1().getType(); - Type td = getDst().getType(); - FailureOr elemOr = - verifyPTOShapedBinarySameElemAndShape(getOperation(), t0, t1, td); - if (failed(elemOr)) - return failure(); - Type e0 = *elemOr; - if (!(e0.isInteger(32) || e0.isInteger(16) || e0.isInteger(8) || - e0.isF16() || e0.isBF16() || e0.isF32())) - return emitOpError("expects A5 tpartmax element type to be i32/i16/i8/f16/bf16/f32"); - return mlir::success(); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TPartMinOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - Type t0 = getSrc0().getType(); - Type t1 = getSrc1().getType(); - Type td = getDst().getType(); - FailureOr elemOr = - verifyPTOShapedBinarySameElemAndShape(getOperation(), t0, t1, td); - if (failed(elemOr)) - return failure(); - if (failed(verifyPartialValidPattern(*this, t0, t1, td))) - return failure(); - Type e0 = *elemOr; - if (!(e0.isInteger(32) || e0.isInteger(16) || e0.isF16() || e0.isF32())) - return emitOpError("expects A2/A3 tpartmin element type to be i32/i16/f16/f32"); - return mlir::success(); - }; - auto verifyA5 = [&]() -> LogicalResult { - Type t0 = getSrc0().getType(); - Type t1 = getSrc1().getType(); - Type td = getDst().getType(); - FailureOr elemOr = - verifyPTOShapedBinarySameElemAndShape(getOperation(), t0, t1, td); - if (failed(elemOr)) - return failure(); - Type e0 = *elemOr; - if (!(e0.isInteger(32) || e0.isInteger(16) || e0.isInteger(8) || - e0.isF16() || e0.isBF16() || e0.isF32())) - return emitOpError("expects A5 tpartmin element type to be i32/i16/i8/f16/bf16/f32"); - return mlir::success(); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -static 
LogicalResult verifyTPartArgOpCommon(Operation *op, Type src0Ty, - Type src1Ty, Type src0IdxTy, - Type src1IdxTy, Type dstTy, - Type dstIdxTy, StringRef opName) { - FailureOr dataElemOr = - verifyPTOShapedBinarySameElemAndShape(op, src0Ty, src1Ty, dstTy); - if (failed(dataElemOr)) - return failure(); - if (failed(verifyPartialValidPattern(op, src0Ty, src1Ty, dstTy))) - return failure(); - - if (!isPTOShapedLike(src0IdxTy) || !isPTOShapedLike(src1IdxTy) || - !isPTOShapedLike(dstIdxTy)) - return op->emitOpError("expects PTO shaped-like src0Idx/src1Idx/dstIdx"); - Type idxElem = getElemTy(src0IdxTy); - if (!idxElem || idxElem != getElemTy(src1IdxTy) || - idxElem != getElemTy(dstIdxTy)) - return op->emitOpError( - "expects src0Idx/src1Idx/dstIdx to have the same element type"); - auto idxInt = dyn_cast(idxElem); - if (!idxInt || idxInt.getWidth() != 32) - return op->emitOpError( - "expects src0Idx/src1Idx/dstIdx element type to be i32 or ui32"); - - auto dataShape = getShapeVec(src0Ty); - if (dataShape != getShapeVec(src0IdxTy) || - dataShape != getShapeVec(src1IdxTy) || - dataShape != getShapeVec(dstIdxTy)) - return op->emitOpError( - "expects data and index operands to have the same shape"); - if (getValidShapeVec(src0Ty) != getValidShapeVec(src0IdxTy) || - getValidShapeVec(src1Ty) != getValidShapeVec(src1IdxTy) || - getValidShapeVec(dstTy) != getValidShapeVec(dstIdxTy)) - return op->emitOpError( - "expects each data operand and its index operand to have the same valid_shape"); - - Type elem = *dataElemOr; - PTOArch arch = getTargetArch(op); - if (arch == PTOArch::A5) { - if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isInteger(8) || - elem.isF16() || elem.isBF16() || elem.isF32())) - return op->emitOpError() << "expects A5 " << opName - << " element type to be i32/i16/i8/f16/bf16/f32"; - } else { - if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isF16() || - elem.isF32())) - return op->emitOpError() << "expects A2/A3 " << opName - << " element type to 
be i32/i16/f16/f32"; - } - return success(); -} - -mlir::LogicalResult mlir::pto::TPartArgMaxOp::verify() { - auto verifyByArch = [&]() -> LogicalResult { - return verifyTPartArgOpCommon( - getOperation(), getSrc0().getType(), getSrc1().getType(), - getSrc0Idx().getType(), getSrc1Idx().getType(), getDst().getType(), - getDstIdx().getType(), "tpartargmax"); - }; - return dispatchVerifierByArch(getOperation(), verifyByArch, verifyByArch); -} - -mlir::LogicalResult mlir::pto::TPartArgMinOp::verify() { - auto verifyByArch = [&]() -> LogicalResult { - return verifyTPartArgOpCommon( - getOperation(), getSrc0().getType(), getSrc1().getType(), - getSrc0Idx().getType(), getSrc1Idx().getType(), getDst().getType(), - getDstIdx().getType(), "tpartargmin"); - }; - return dispatchVerifierByArch(getOperation(), verifyByArch, verifyByArch); -} - -mlir::LogicalResult mlir::pto::TPartMulOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - Type src0Ty = getSrc0().getType(); - Type src1Ty = getSrc1().getType(); - Type dstTy = getDst().getType(); - if (!isPTOShapedLike(src0Ty) || !isPTOShapedLike(src1Ty) || - !isPTOShapedLike(dstTy)) - return emitOpError() << "expects PTO shaped-like src0/src1/dst"; - if (getElemTy(src0Ty) != getElemTy(src1Ty) || - getElemTy(src0Ty) != getElemTy(dstTy)) - return emitOpError() - << "expects src0/src1/dst to have the same element type"; - auto s0 = getShapeVec(src0Ty); - auto s1 = getShapeVec(src1Ty); - auto d = getShapeVec(dstTy); - if (s0.size() != 2 || s1.size() != 2 || d.size() != 2) - return emitOpError() - << "expects src0/src1/dst to be rank-2 (tile-shaped)"; - if (failed(verifyPartialValidPattern(*this, src0Ty, src1Ty, dstTy))) - return failure(); - Type elem = getElemTy(src0Ty); - if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isF16() || - elem.isF32())) - return emitOpError( - "expects A2/A3 tpartmul element type to be i32/i16/f16/f32"); - return mlir::success(); - }; - auto verifyA5 = [&]() -> LogicalResult { - Type src0Ty = 
getSrc0().getType(); - Type src1Ty = getSrc1().getType(); - Type dstTy = getDst().getType(); - if (!isPTOShapedLike(src0Ty) || !isPTOShapedLike(src1Ty) || - !isPTOShapedLike(dstTy)) - return emitOpError() << "expects PTO shaped-like src0/src1/dst"; - if (getElemTy(src0Ty) != getElemTy(src1Ty) || - getElemTy(src0Ty) != getElemTy(dstTy)) - return emitOpError() - << "expects src0/src1/dst to have the same element type"; - Type elem = getElemTy(src0Ty); - if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isInteger(8) || - elem.isF16() || elem.isBF16() || elem.isF32())) - return emitOpError( - "expects A5 tpartmul element type to be i32/i16/i8/f16/bf16/f32"); - auto s0 = getShapeVec(src0Ty); - auto s1 = getShapeVec(src1Ty); - auto d = getShapeVec(dstTy); - if (s0.size() != 2 || s1.size() != 2 || d.size() != 2) - return emitOpError() - << "expects src0/src1/dst to be rank-2 (tile-shaped)"; - return mlir::success(); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TPReluOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - auto verifyCommon = [&]() -> FailureOr> { - Type t0 = getSrc0().getType(); - Type t1 = getSrc1().getType(); - Type tt = getTmp().getType(); - Type td = getDst().getType(); - if (failed(verifyTileBufCommon(*this, t0, "src0")) || - failed(verifyTileBufCommon(*this, t1, "src1")) || - failed(verifyTileBufCommon(*this, tt, "tmp")) || - failed(verifyTileBufCommon(*this, td, "dst"))) - return failure(); - - Type e0 = getElemTy(t0), e1 = getElemTy(t1), et = getElemTy(tt), ed = getElemTy(td); - if (!e0 || !e1 || !et || !ed) { - emitOpError("failed to get element type for operands"); - return failure(); - } - if (e0 != e1 || e0 != ed) { - emitOpError("expects dst/src0/src1 to have the same element type"); - return failure(); - } - if (!(e0.isF16() || e0.isF32())) { - emitOpError("expects dst/src0/src1 element type to be f16 or f32"); - return failure(); - 
} - if (!isRowMajorTileBuf(t0) || !isRowMajorTileBuf(t1) || !isRowMajorTileBuf(td)) { - emitOpError("expects src0, src1, and dst to use row-major layout"); - return failure(); - } - if (failed(verifyTileBufSameValidShape(*this, t0, td, "src0", "dst")) || - failed(verifyTileBufSameValidShape(*this, t1, td, "src1", "dst"))) - return failure(); - - auto s0 = getShapeVec(t0), s1 = getShapeVec(t1), st = getShapeVec(tt), sd = getShapeVec(td); - if (s0 != s1 || s0 != st || s0 != sd) { - emitOpError("expects src0/src1/tmp/dst to have the same shape"); - return failure(); - } - return std::make_tuple(t0, t1, tt, td); - }; - - auto verifyA2A3 = [&]() -> LogicalResult { - auto tysOr = verifyCommon(); - if (failed(tysOr)) - return failure(); - auto [t0, t1, tt, td] = *tysOr; - Type tmpElem = getElemTy(tt); - auto tmpIntTy = mlir::dyn_cast(tmpElem); - if (!tmpIntTy || tmpIntTy.getWidth() != 8) - return emitOpError("expects A2/A3 tmp element type to be u8"); - if (!isRowMajorTileBuf(tt)) - return emitOpError("expects tmp to use row-major layout"); - if (auto arch = getVerifierArchName(getOperation()); - arch && arch->equals_insensitive("a3")) { - if (getSrc0() == getSrc1() || getSrc0() == getTmp() || getSrc0() == getDst() || - getSrc1() == getTmp() || getSrc1() == getDst() || getTmp() == getDst()) - return emitOpError( - "expects A3 src0, src1, tmp, and dst to use different storage"); - } - return success(); - }; - - auto verifyA5 = [&]() -> LogicalResult { - auto tysOr = verifyCommon(); - if (failed(tysOr)) - return failure(); - return success(); - }; - - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TQuantOp::verify() { - // Structural checks: always run regardless of operand representation - // (applies both before and after PTOViewToMemref lowering). - auto verifyStructural = [&]() -> LogicalResult { - // dst elem type and offset presence must be consistent with quant_type. 
- Type dstTy = getDst().getType(); - Type dstElemTy = getElemTy(dstTy); - auto dstIntTy = dyn_cast(dstElemTy); - if (getQuantType() == mlir::pto::QuantType::INT8_SYM) { - if (!dstIntTy || dstIntTy.getWidth() != 8) - return emitOpError() - << "expects dst element type i8/ui8 for INT8_SYM quantization"; - if (getOffset()) - return emitOpError() - << "INT8_SYM quantization must not have an offset operand"; - } else { - // INT8_ASYM - if (!dstIntTy || dstIntTy.getWidth() != 8) - return emitOpError() - << "expects dst element type i8/ui8 for INT8_ASYM quantization"; - if (!getOffset()) - return emitOpError() - << "INT8_ASYM quantization requires an offset operand"; - } - return success(); - }; - - if (failed(verifyStructural())) - return failure(); - - // Layout/tile-buffer checks: only meaningful for pre-lowering tile types. - // Skip when operands are already plain MemRefs (post PTOViewToMemref). - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - - auto verifyCommon = [&]() -> LogicalResult { - Type srcTy = getSrc().getType(); - Type fpTy = getFp().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyTileBufCommon(*this, srcTy, "src")) || - failed(verifyTileBufCommon(*this, fpTy, "fp")) || - failed(verifyTileBufCommon(*this, dstTy, "dst"))) - return failure(); - // src must be f32 (ISA static_assert) - if (!getElemTy(srcTy).isF32()) - return emitOpError() << "expects src to have element type f32"; - if (getOffset()) { - Type offsetTy = getOffset().getType(); - if (failed(verifyTileBufCommon(*this, offsetTy, "offset"))) - return failure(); - if (!getElemTy(offsetTy).isF32()) - return emitOpError() << "expects offset to have element type f32"; - } - return success(); - }; - - auto verifyA2A3 = [&]() -> LogicalResult { - if (failed(verifyCommon())) - return failure(); - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if (!isRowMajorTileBuf(srcTy) || !isRowMajorTileBuf(dstTy)) - return emitOpError() << "expects 
A2/A3 src and dst to use row-major layout"; - return success(); - }; - - auto verifyA5 = [&]() -> LogicalResult { - return verifyCommon(); - }; - - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TDequantOp::verify() { - // Structural checks: src must be i8 or i16, dst/scale/offset must be f32. - auto verifyStructural = [&]() -> LogicalResult { - Type srcElemTy = getElemTy(getSrc().getType()); - auto srcIntTy = dyn_cast(srcElemTy); - if (!srcIntTy || !(srcIntTy.getWidth() == 8 || srcIntTy.getWidth() == 16)) - return emitOpError() - << "expects src element type i8 or i16"; - if (!getElemTy(getDst().getType()).isF32()) - return emitOpError() << "expects dst element type f32"; - if (!getElemTy(getScale().getType()).isF32()) - return emitOpError() << "expects scale element type f32"; - if (!getElemTy(getOffset().getType()).isF32()) - return emitOpError() << "expects offset element type f32"; - return success(); - }; - - if (failed(verifyStructural())) - return failure(); - - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - - auto verifyCommon = [&]() -> LogicalResult { - if (failed(verifyTileBufCommon(*this, getSrc().getType(), "src")) || - failed(verifyTileBufCommon(*this, getScale().getType(), "scale")) || - failed(verifyTileBufCommon(*this, getOffset().getType(), "offset")) || - failed(verifyTileBufCommon(*this, getDst().getType(), "dst"))) - return failure(); - return success(); - }; - - auto verifyA2A3 = [&]() -> LogicalResult { - if (failed(verifyCommon())) - return failure(); - if (!isRowMajorTileBuf(getSrc().getType()) || - !isRowMajorTileBuf(getDst().getType())) - return emitOpError() - << "expects A2/A3 src and dst to use row-major layout"; - return success(); - }; - - auto verifyA5 = [&]() -> LogicalResult { return verifyCommon(); }; - - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TRecipOp::verify() { - if 
(shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - Type ts = getSrc().getType(); - Type td = getDst().getType(); - if (failed(verifyVecTileUnaryOp(*this, ts, td, "src", "dst", - /*allowBf16=*/false, /*allowInt8=*/false))) - return failure(); - if (failed(verifyTileBufSameValidShape(*this, ts, td, "src", "dst"))) - return failure(); - Type elemTy = getElemTy(ts); - if (!(elemTy.isF16() || elemTy.isF32())) - return emitOpError() << "expects element type to be f16 or f32"; - if (auto arch = getVerifierArchName(getOperation()); - arch && arch->equals_insensitive("a3") && getSrc() == getDst()) - return emitOpError("expects A3 trecip src and dst to use different storage"); - return mlir::success(); -} - -mlir::LogicalResult mlir::pto::TReluOp::verify() { - auto verifyByArch = [&](StringRef errorMessage) -> LogicalResult { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyVecTileCommon(*this, srcTy, "src")) || - failed(verifyVecTileCommon(*this, dstTy, "dst"))) - return failure(); - if (failed(verifyTileBufSameElemType(*this, srcTy, dstTy, "src", "dst")) || - failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) - return failure(); - Type elemTy = getElemTy(srcTy); - if (!(elemTy.isInteger(32) || elemTy.isF16() || elemTy.isF32())) - return emitOpError() << errorMessage; - return success(); - }; - auto verifyA2A3 = [&]() -> LogicalResult { - return verifyByArch("expects A2/A3 trelu element type to be i32/f16/f32"); - }; - auto verifyA5 = [&]() -> LogicalResult { - return verifyByArch("expects A5 trelu element type to be i32/f16/f32"); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - - -mlir::LogicalResult mlir::pto::TRemOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - - Type src0Ty = getSrc0().getType(); - Type src1Ty = getSrc1().getType(); - Type tmpTy = getTmp().getType(); - Type dstTy = getDst().getType(); - if 
(failed(verifyTileBufCommon(*this, src0Ty, "src0")) || - failed(verifyTileBufCommon(*this, src1Ty, "src1")) || - failed(verifyTileBufCommon(*this, tmpTy, "tmp")) || - failed(verifyTileBufCommon(*this, dstTy, "dst"))) - return failure(); - if (failed(verifyTileBufSameElemType(*this, src0Ty, src1Ty, "src0", "src1")) || - failed(verifyTileBufSameElemType(*this, src0Ty, dstTy, "src0", "dst")) || - failed(verifyTileBufSameValidShape(*this, src0Ty, src1Ty, "src0", "src1")) || - failed(verifyTileBufSameValidShape(*this, src0Ty, dstTy, "src0", "dst"))) - return failure(); - if (getElemTy(tmpTy) != getElemTy(dstTy)) - return emitOpError("expects tmp and dst to have the same element type"); - if (!isRowMajorTileBuf(src0Ty) || !isRowMajorTileBuf(src1Ty) || - !isRowMajorTileBuf(tmpTy) || !isRowMajorTileBuf(dstTy)) - return emitOpError("expects src0, src1, tmp, and dst to use row-major layout"); - auto dstValid = getValidShapeVec(dstTy); - auto tmpValid = getValidShapeVec(tmpTy); - if (dstValid.size() != 2 || tmpValid.size() != 2) - return emitOpError("expects tmp and dst to be rank-2 tiles"); - if (tmpValid[0] != ShapedType::kDynamic && tmpValid[0] < 1) - return emitOpError("expects tmp to have at least 1 valid row"); - if (dstValid[1] != ShapedType::kDynamic && tmpValid[1] != ShapedType::kDynamic && - tmpValid[1] < dstValid[1]) - return emitOpError("expects tmp valid columns to cover dst valid columns"); - - Type elem = getElemTy(src0Ty); - auto verifyA2A3 = [&]() -> LogicalResult { - if (!(elem.isInteger(32) || elem.isF32())) - return emitOpError("expects A2/A3 trem element type to be i32/f32"); - return success(); - }; - auto verifyA5 = [&]() -> LogicalResult { - if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isF16() || elem.isF32())) - return emitOpError("expects A5 trem element type to be i32/i16/f16/f32"); - return success(); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TFModOp::verify() { - 
return verifyArithmeticBinaryTileOpWithArchDispatch( - getOperation(), getSrc0().getType(), getSrc1().getType(), getDst().getType(), - /*allowInt8OnA5=*/false, /*allowBf16OnA5=*/false, - "expects A2/A3 tfmod element type to be i32/i16/f16/f32", - "expects A5 tfmod element type to be i32/i16/f16/f32"); -} - -mlir::LogicalResult mlir::pto::TRemSOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - Type ts = getSrc().getType(); - Type tt = getTmp().getType(); - Type td = getDst().getType(); - Type scalarTy = getScalar().getType(); - if (failed(verifyTileBufCommon(*this, ts, "src")) || - failed(verifyTileBufCommon(*this, tt, "tmp")) || - failed(verifyTileBufCommon(*this, td, "dst"))) - return failure(); - if (failed(verifyTileBufSameElemType(*this, ts, td, "src", "dst")) || - failed(verifyTileBufSameValidShape(*this, ts, td, "src", "dst"))) - return failure(); - if (getElemTy(tt) != getElemTy(td)) - return emitOpError("expects tmp and dst to have the same element type"); - if (!isRowMajorTileBuf(ts) || !isRowMajorTileBuf(tt) || !isRowMajorTileBuf(td)) - return emitOpError("expects src, tmp, and dst to use row-major layout"); - Type elem = getElemTy(ts); - if (scalarTy != elem) - return emitOpError("expects scalar type to match the tile element type"); - auto dstValid = getValidShapeVec(td); - auto tmpValid = getValidShapeVec(tt); - if (dstValid.size() != 2 || tmpValid.size() != 2) - return emitOpError("expects tmp and dst to be rank-2 tiles"); - if (tmpValid[0] != ShapedType::kDynamic && tmpValid[0] < 1) - return emitOpError("expects tmp to have at least 1 valid row"); - if (dstValid[1] != ShapedType::kDynamic && tmpValid[1] != ShapedType::kDynamic && - tmpValid[1] < dstValid[1]) - return emitOpError("expects tmp valid columns to cover dst valid columns"); - auto verifyA2A3 = [&]() -> LogicalResult { - if (!(elem.isInteger(32) || elem.isF32())) - return emitOpError("expects A2/A3 trems element type to be i32/f32"); - return 
success(); - }; - auto verifyA5 = [&]() -> LogicalResult { - if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isF16() || elem.isF32())) - return emitOpError("expects A5 trems element type to be i32/i16/f16/f32"); - return success(); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TFModSOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - Type scalarTy = getScalar().getType(); - if (failed(verifyTileBufCommon(*this, srcTy, "src")) || - failed(verifyTileBufCommon(*this, dstTy, "dst"))) - return failure(); - if (failed(verifyTileBufSameElemType(*this, srcTy, dstTy, "src", "dst")) || - failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) - return failure(); - if (!isRowMajorTileBuf(srcTy) || !isRowMajorTileBuf(dstTy)) - return emitOpError("expects src and dst to use row-major layout"); - - Type elem = getElemTy(srcTy); - if (scalarTy != elem) - return emitOpError("expects scalar type to match the tile element type"); - - auto verifyA2A3 = [&]() -> LogicalResult { - if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isF16() || elem.isF32())) - return emitOpError("expects A2/A3 tfmods element type to be i32/i16/f16/f32"); - return success(); - }; - auto verifyA5 = [&]() -> LogicalResult { - if (!(elem.isInteger(32) || elem.isInteger(16) || elem.isF16() || elem.isF32())) - return emitOpError("expects A5 tfmods element type to be i32/i16/f16/f32"); - return success(); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - - -static std::optional getStaticNumElements(ArrayRef shape) { - int64_t numel = 1; - for (int64_t d : shape) { - if (d == ShapedType::kDynamic) - return std::nullopt; - if (d < 0) - return std::nullopt; - numel *= d; - } - return numel; -} - -static std::optional getElemBytes(Type elemTy) { - if (!elemTy) - return 
std::nullopt; - if (auto ft = dyn_cast(elemTy)) { - if (ft.isF16() || ft.isBF16()) - return 2; - if (ft.isF32()) - return 4; - if (ft.isF64()) - return 8; - return std::nullopt; - } - if (auto it = dyn_cast(elemTy)) { - int64_t bits = it.getWidth(); - if (bits <= 0) - return std::nullopt; - return std::max(1, bits / 8); - } - return std::nullopt; -} - -[[maybe_unused]] static bool isTileBufOrMemref(Type ty) { - return mlir::isa(ty); -} - -static constexpr llvm::StringLiteral kLoweredSetValidShapeAttrName = - "__pto.lowered_set_validshape"; - -static bool isLocallyBoundTileSource(Value value) { - if (!value || isa(value)) - return false; - - if (isa( - value.getDefiningOp())) - return true; - - if (auto bitcast = value.getDefiningOp()) - return isLocallyBoundTileSource(bitcast.getSrc()); - if (auto reshape = value.getDefiningOp()) - return isLocallyBoundTileSource(reshape.getSrc()); - - return false; -} - -static std::optional getConstIndexLike(Value v) { - if (auto cOp = v.getDefiningOp()) - return cOp.value(); - if (auto cInt = v.getDefiningOp()) - return cInt.value(); - if (auto cOp = v.getDefiningOp()) { - if (auto ia = dyn_cast(cOp.getValue())) - return ia.getInt(); - } - if (auto castOp = v.getDefiningOp()) - return getConstIndexLike(castOp.getIn()); - if (auto extOp = v.getDefiningOp()) - return getConstIndexLike(extOp.getIn()); - if (auto extOp = v.getDefiningOp()) - return getConstIndexLike(extOp.getIn()); - if (auto truncOp = v.getDefiningOp()) - return getConstIndexLike(truncOp.getIn()); - return std::nullopt; -} - -mlir::LogicalResult mlir::pto::SetValidShapeOp::verify() { - SmallVector shape; - if (auto srcTy = llvm::dyn_cast(getSource().getType())) { - if (srcTy.getRank() != 2) - return emitOpError("expects rank-2 tile_buf source"); - - ArrayRef validShape = srcTy.getValidShape(); - if (validShape.size() != 2) - return emitOpError("expects source validShape to be rank-2"); - if (!srcTy.hasDynamicValid()) - return emitOpError("expects source tile_buf to 
have dynamic validShape (?, ?)"); - - shape.assign(srcTy.getShape().begin(), srcTy.getShape().end()); - - if (!isLocallyBoundTileSource(getSource())) - return emitOpError( - "requires a locally bound tile source; function arguments/results " - "are unsupported"); - } else if (auto srcTy = llvm::dyn_cast(getSource().getType())) { - if (!(*this)->hasAttr(kLoweredSetValidShapeAttrName)) - return emitOpError( - "expects tile_buf source; memref source is only valid for the internal lowered form"); - if (srcTy.getRank() != 2) - return emitOpError("expects rank-2 memref source after tile lowering"); - shape.assign(srcTy.getShape().begin(), srcTy.getShape().end()); - } else { - return emitOpError("expects tile_buf source (or lowered memref source)"); - } - - auto checkDim = [&](Value operand, unsigned dimIdx, - StringRef dimName) -> LogicalResult { - int64_t maxStatic = shape[dimIdx]; - - auto constVal = getConstIndexLike(operand); - if (!constVal) - return success(); - - if (*constVal < 0) - return emitOpError() << "expects " << dimName << " operand to be non-negative"; - if (maxStatic != ShapedType::kDynamic && *constVal > maxStatic) - return emitOpError() << "expects " << dimName << " operand <= shape dim (" - << maxStatic << ")"; - return success(); - }; - - if (failed(checkDim(getValidRow(), /*dimIdx=*/0, "row"))) - return failure(); - if (failed(checkDim(getValidCol(), /*dimIdx=*/1, "col"))) - return failure(); - - return success(); -} - -mlir::LogicalResult mlir::pto::GetValidShapeOp::verify() { - if (auto srcTy = llvm::dyn_cast(getSource().getType())) { - if (srcTy.getRank() != 2) - return emitOpError("expects rank-2 tile_buf source"); - if (srcTy.getValidShape().size() != 2) - return emitOpError("expects source validShape to be rank-2"); - return success(); - } - if (auto srcTy = llvm::dyn_cast(getSource().getType())) { - if (srcTy.getRank() != 2) - return emitOpError("expects rank-2 memref source after tile lowering"); - return success(); - } - return 
emitOpError("expects tile_buf source (or lowered memref source)"); -} - - -mlir::LogicalResult mlir::pto::TReshapeOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - Type ts = getSrc().getType(); - Type tr = getResult().getType(); - auto srcTb = dyn_cast(ts); - auto dstTb = dyn_cast(tr); - if (!srcTb || !dstTb) - return emitOpError("expects src/result to be !pto.tile_buf types"); - - if (failed(verifyTileBufCommon(*this, ts, "src")) || - failed(verifyTileBufCommon(*this, tr, "dst"))) - return failure(); - - if (srcTb.getMemorySpace() != dstTb.getMemorySpace()) - return emitOpError("expects src and dst to use the same loc"); - - Type srcElem = srcTb.getElementType(); - Type dstElem = dstTb.getElementType(); - auto srcElemBytes = getElemBytes(srcElem); - auto dstElemBytes = getElemBytes(dstElem); - if (!srcElem || !dstElem || !srcElemBytes.has_value() || !dstElemBytes.has_value()) - return emitOpError("failed to get element byte width for src/dst"); - - auto srcNumel = getStaticNumElements(getShapeVec(ts)); - auto dstNumel = getStaticNumElements(getShapeVec(tr)); - if (!srcNumel.has_value() || !dstNumel.has_value()) - return emitOpError("expects static shapes for treshape"); - - if (srcElemBytes.value() * srcNumel.value() != - dstElemBytes.value() * dstNumel.value()) - return emitOpError("expects src and dst to have the same total byte size"); - - bool srcBoxed = - srcTb.getSLayoutValueI32() != static_cast(pto::SLayout::NoneBox); - bool dstBoxed = - dstTb.getSLayoutValueI32() != static_cast(pto::SLayout::NoneBox); - if (srcBoxed != dstBoxed) - return emitOpError("cannot reshape between boxed and non-boxed tile layouts"); - - return success(); -} - -mlir::LogicalResult mlir::pto::BitcastOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - auto srcTy = llvm::dyn_cast(getSrc().getType()); - auto dstTy = llvm::dyn_cast(getResult().getType()); - if (!srcTy || !dstTy) - return 
emitOpError("expects tile_buf src and tile_buf result"); - - if (srcTy.getMemorySpace() != dstTy.getMemorySpace()) - return emitOpError("expects src/result to have the same memorySpace"); - - if (srcTy.getElementType() == dstTy.getElementType()) - return emitOpError( - "expects src/result to have different element types; use " - "pto.treshape for shape/config changes"); - - if (srcTy.getShape() != dstTy.getShape()) - return emitOpError("expects src/result to have the same shape; use pto.treshape for shape changes"); - - if (srcTy.getValidShape() != dstTy.getValidShape()) - return emitOpError("expects src/result to have the same validShape"); - - auto srcCfg = srcTy.getConfigAttr(); - auto dstCfg = dstTy.getConfigAttr(); - if (srcCfg != dstCfg) - return emitOpError("expects src/result to have the same tile config"); - - auto numel = getStaticNumElements(srcTy.getShape()); - if (!numel.has_value()) - return emitOpError("expects static shapes for bitcast"); - - auto srcBytes = getElemBytes(srcTy.getElementType()); - auto dstBytes = getElemBytes(dstTy.getElementType()); - if (!srcBytes.has_value() || !dstBytes.has_value()) - return emitOpError("unsupported element type for bitcast"); - - int64_t srcTotalBytes = numel.value() * srcBytes.value(); - int64_t dstTotalBytes = numel.value() * dstBytes.value(); - if (dstTotalBytes > srcTotalBytes) - return emitOpError("bitcast result requires more bytes than source storage"); - - return success(); -} - - -mlir::LogicalResult mlir::pto::TRowExpandOp::verify() { - auto verifyCommon = [&]() -> LogicalResult { - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyTileBufCommon(*this, srcTy, "src")) || - failed(verifyNDStyleVecTile(*this, dstTy, "dst"))) - return failure(); - auto srcSpace = getPTOMemorySpaceEnum(srcTy); - if (!srcSpace || *srcSpace != pto::AddressSpace::VEC) - return emitOpError("expects src to be in the vec address space"); - if (auto srcTb = dyn_cast(srcTy)) { - if 
(srcTb.getSLayoutValueI32() != static_cast(pto::SLayout::NoneBox)) - return emitOpError("expects src to use the none_box slayout"); - } - if (getElemTy(srcTy) != getElemTy(dstTy)) - return emitOpError("expects src and dst to have the same element type"); - if (!isSupportedVecElemType(getElemTy(srcTy), /*allowBf16=*/true, - /*allowInt8=*/true)) - return emitOpError("expects trowexpand element type to be supported"); - auto srcValid = getValidShapeVec(getSrc()); - auto dstValid = getValidShapeVec(getDst()); - if (srcValid.size() != 2 || dstValid.size() != 2) - return emitOpError("expects src and dst to have rank-2 valid_shape"); - if (srcValid[0] != ShapedType::kDynamic && dstValid[0] != ShapedType::kDynamic && - srcValid[0] != dstValid[0]) - return emitOpError("expects src and dst to have the same valid_shape[0]"); - if (srcValid[0] != ShapedType::kDynamic && srcValid[0] == 0) - return emitOpError("expects src valid_shape[0] to be non-zero"); - if (srcValid[1] != ShapedType::kDynamic && srcValid[1] == 0) - return emitOpError("expects src valid_shape[1] to be non-zero"); - if (dstValid[0] != ShapedType::kDynamic && dstValid[0] == 0) - return emitOpError("expects dst valid_shape[0] to be non-zero"); - if (dstValid[1] != ShapedType::kDynamic && dstValid[1] == 0) - return emitOpError("expects dst valid_shape[1] to be non-zero"); - return success(); - }; - auto verifyA2A3 = [&]() -> LogicalResult { - return verifyCommon(); - }; - auto verifyA5 = [&]() -> LogicalResult { - return verifyCommon(); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - - -ParseResult mlir::pto::TSort32Op::parse(OpAsmParser &parser, OperationState &result) { - OpAsmParser::UnresolvedOperand src, idx, tmp, dst; - Type srcTy, dstTy, idxTy, tmpTy; - bool hasTmp = false; - - if (parser.parseKeyword("ins") || parser.parseLParen() || parser.parseOperand(src)) - return failure(); - if (succeeded(parser.parseOptionalComma())) { - if (parser.parseOperand(idx)) - return 
failure(); - if (succeeded(parser.parseOptionalComma())) { - if (parser.parseOperand(tmp)) - return failure(); - hasTmp = true; - } - } else { - return failure(); - } - if (parser.parseColonType(srcTy) || parser.parseComma() || parser.parseType(idxTy)) - return failure(); - if (hasTmp) { - if (parser.parseComma() || parser.parseType(tmpTy)) - return failure(); - } - if (parser.parseRParen()) - return failure(); - - if (parser.parseKeyword("outs") || parser.parseLParen() || - parser.parseOperand(dst) || parser.parseColonType(dstTy) || - parser.parseRParen()) - return failure(); - if (parser.parseOptionalAttrDict(result.attributes)) - return failure(); - - if (parser.resolveOperand(src, srcTy, result.operands) || - parser.resolveOperand(idx, idxTy, result.operands)) - return failure(); - if (hasTmp) { - if (parser.resolveOperand(tmp, tmpTy, result.operands)) - return failure(); - } - if (parser.resolveOperand(dst, dstTy, result.operands)) - return failure(); - - result.addAttribute( - "operandSegmentSizes", - parser.getBuilder().getDenseI32ArrayAttr({1, 1, hasTmp ? 
1 : 0, 1})); - return success(); -} - -void mlir::pto::TSort32Op::print(OpAsmPrinter &p) { - p << " ins(" << getSrc() << ", " << getIdx(); - if (getTmp()) { - p << ", " << getTmp(); - p << " : " << getSrc().getType() << ", " << getIdx().getType() - << ", " << getTmp().getType() << ")"; - } else { - p << " : " << getSrc().getType() << ", " << getIdx().getType() << ")"; - } - p << " outs(" << getDst() << " : " << getDst().getType() << ")"; - p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"operandSegmentSizes"}); -} - -ParseResult mlir::pto::TRsqrtOp::parse(OpAsmParser &parser, OperationState &result) { - OpAsmParser::UnresolvedOperand src, tmp, dst; - Type srcTy, tmpTy, dstTy; - bool hasTmp = false; - - if (parser.parseKeyword("ins") || parser.parseLParen() || parser.parseOperand(src)) - return failure(); - if (succeeded(parser.parseOptionalComma())) { - if (parser.parseOperand(tmp)) - return failure(); - hasTmp = true; - } - if (parser.parseColonType(srcTy)) - return failure(); - if (hasTmp) { - if (parser.parseComma() || parser.parseType(tmpTy)) - return failure(); - } - if (parser.parseRParen()) - return failure(); - - if (parser.parseKeyword("outs") || parser.parseLParen() || - parser.parseOperand(dst) || parser.parseColonType(dstTy) || - parser.parseRParen()) - return failure(); - if (parser.parseOptionalAttrDict(result.attributes)) - return failure(); - - if (parser.resolveOperand(src, srcTy, result.operands) || - parser.resolveOperand(dst, dstTy, result.operands)) - return failure(); - if (hasTmp && parser.resolveOperand(tmp, tmpTy, result.operands)) - return failure(); - - return success(); -} - -void mlir::pto::TRsqrtOp::print(OpAsmPrinter &p) { - p << " ins(" << getSrc(); - if (getTmp()) - p << ", " << getTmp(); - p << " : " << getSrc().getType(); - if (getTmp()) - p << ", " << getTmp().getType(); - p << ")"; - p << " outs(" << getDst() << " : " << getDst().getType() << ")"; - p.printOptionalAttrDict((*this)->getAttrs()); -} - -static 
ParseResult parseTRowExpandBinaryLikeOp(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::UnresolvedOperand src0, src1, tmp, dst; - Type src0Ty, src1Ty, tmpTy, dstTy; - bool hasTmp = false; - - if (parser.parseKeyword("ins") || parser.parseLParen() || - parser.parseOperand(src0) || parser.parseComma() || parser.parseOperand(src1)) - return failure(); - if (succeeded(parser.parseOptionalComma())) { - if (parser.parseOperand(tmp)) - return failure(); - hasTmp = true; - } - if (parser.parseColon()) - return failure(); - if (parser.parseType(src0Ty) || parser.parseComma() || parser.parseType(src1Ty)) - return failure(); - if (hasTmp) { - if (parser.parseComma() || parser.parseType(tmpTy)) - return failure(); - } - if (parser.parseRParen()) - return failure(); - if (parser.parseKeyword("outs") || parser.parseLParen() || - parser.parseOperand(dst) || parser.parseColonType(dstTy) || - parser.parseRParen()) - return failure(); - if (parser.parseOptionalAttrDict(result.attributes)) - return failure(); - - if (parser.resolveOperand(src0, src0Ty, result.operands) || - parser.resolveOperand(src1, src1Ty, result.operands)) - return failure(); - if (hasTmp) { - if (parser.resolveOperand(tmp, tmpTy, result.operands)) - return failure(); - } - if (parser.resolveOperand(dst, dstTy, result.operands)) - return failure(); - - result.addAttribute( - "operandSegmentSizes", - parser.getBuilder().getDenseI32ArrayAttr({1, 1, hasTmp ? 
1 : 0, 1})); - return success(); -} - -static void printTRowExpandBinaryLikeOp(OpAsmPrinter &p, Operation *op, Value src0, - Value src1, Value tmp, Value dst) { - p << " ins(" << src0 << ", " << src1; - if (tmp) { - p << ", " << tmp; - p << " : " << src0.getType() << ", " << src1.getType() << ", " - << tmp.getType() << ")"; - } else { - p << " : " << src0.getType() << ", " << src1.getType() << ")"; - } - p << " outs(" << dst << " : " << dst.getType() << ")"; - p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"operandSegmentSizes"}); -} - -ParseResult mlir::pto::TRowExpandDivOp::parse(OpAsmParser &parser, OperationState &result) { - return parseTRowExpandBinaryLikeOp(parser, result); -} - -void mlir::pto::TRowExpandDivOp::print(OpAsmPrinter &p) { - printTRowExpandBinaryLikeOp(p, getOperation(), getSrc0(), getSrc1(), getTmp(), - getDst()); -} - -ParseResult mlir::pto::TRowExpandMulOp::parse(OpAsmParser &parser, OperationState &result) { - return parseTRowExpandBinaryLikeOp(parser, result); -} - -void mlir::pto::TRowExpandMulOp::print(OpAsmPrinter &p) { - printTRowExpandBinaryLikeOp(p, getOperation(), getSrc0(), getSrc1(), getTmp(), - getDst()); -} - -ParseResult mlir::pto::TRowExpandSubOp::parse(OpAsmParser &parser, OperationState &result) { - return parseTRowExpandBinaryLikeOp(parser, result); -} - -void mlir::pto::TRowExpandSubOp::print(OpAsmPrinter &p) { - printTRowExpandBinaryLikeOp(p, getOperation(), getSrc0(), getSrc1(), getTmp(), - getDst()); -} - -ParseResult mlir::pto::TRowExpandExpdifOp::parse(OpAsmParser &parser, - OperationState &result) { - return parseTRowExpandBinaryLikeOp(parser, result); -} - -void mlir::pto::TRowExpandExpdifOp::print(OpAsmPrinter &p) { - printTRowExpandBinaryLikeOp(p, getOperation(), getSrc0(), getSrc1(), getTmp(), - getDst()); -} - -ParseResult mlir::pto::TRowExpandMaxOp::parse(OpAsmParser &parser, - OperationState &result) { - return parseTRowExpandBinaryLikeOp(parser, result); -} - -void 
mlir::pto::TRowExpandMaxOp::print(OpAsmPrinter &p) { - printTRowExpandBinaryLikeOp(p, getOperation(), getSrc0(), getSrc1(), getTmp(), - getDst()); -} - -ParseResult mlir::pto::TRowExpandMinOp::parse(OpAsmParser &parser, - OperationState &result) { - return parseTRowExpandBinaryLikeOp(parser, result); -} - -void mlir::pto::TRowExpandMinOp::print(OpAsmPrinter &p) { - printTRowExpandBinaryLikeOp(p, getOperation(), getSrc0(), getSrc1(), getTmp(), - getDst()); -} - -static FailureOr verifyTRowExpandBinaryCore(Operation *op, Type src0Ty, - Type src1Ty, Type dstTy, - Type tmpTy, bool hasTmp) { - if (failed(verifyTileBufCommon(op, src0Ty, "src0")) || - failed(verifyTileBufCommon(op, src1Ty, "src1")) || - failed(verifyTileBufCommon(op, dstTy, "dst"))) - return failure(); - if (hasTmp && failed(verifyTileBufCommon(op, tmpTy, "tmp"))) - return failure(); - if (failed(verifyTileBufSameElemType(op, src0Ty, dstTy, "src0", "dst"))) - return failure(); - if (getElemTy(src0Ty) != getElemTy(src1Ty)) { - op->emitOpError("expects src0 and src1 to have the same element type"); - return failure(); - } - if (!isRowMajorTileBuf(dstTy)) { - op->emitOpError("expects dst to use row-major layout"); - return failure(); - } - return getElemTy(src0Ty); -} - -mlir::LogicalResult mlir::pto::TRowExpandDivOp::verify() { - auto verifyByArch = [&](PTOArch targetArch) -> LogicalResult { - Type src0Ty = getSrc0().getType(); - Type src1Ty = getSrc1().getType(); - Type dstTy = getDst().getType(); - FailureOr elemOr = verifyTRowExpandBinaryCore( - *this, src0Ty, src1Ty, dstTy, getTmp() ? 
getTmp().getType() : Type{}, - static_cast(getTmp())); - if (failed(elemOr)) - return failure(); - Type elem = *elemOr; - bool supported = - elem.isF16() || elem.isF32() || - (targetArch == PTOArch::A5 && - (elem.isInteger(8) || elem.isInteger(16) || elem.isInteger(32))); - if (!supported) { - if (targetArch == PTOArch::A5) - return emitOpError( - "expects A5 trowexpanddiv element type to be i8/i16/i32/f16/f32"); - return emitOpError("expects element type to be f16 or f32"); - } - return mlir::success(); - }; - auto verifyA2A3 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A3); }; - auto verifyA5 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A5); }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - - -mlir::LogicalResult mlir::pto::TRowExpandMulOp::verify() { - auto verifyByArch = [&](PTOArch targetArch) -> LogicalResult { - Type src0Ty = getSrc0().getType(); - Type src1Ty = getSrc1().getType(); - Type dstTy = getDst().getType(); - FailureOr elemOr = verifyTRowExpandBinaryCore( - *this, src0Ty, src1Ty, dstTy, getTmp() ? 
getTmp().getType() : Type{}, - static_cast(getTmp())); - if (failed(elemOr)) - return failure(); - Type elem = *elemOr; - bool supported = elem.isF16() || elem.isF32() || elem.isInteger(16) || - elem.isInteger(32) || - (targetArch == PTOArch::A5 && elem.isInteger(8)); - if (!supported) { - if (targetArch == PTOArch::A5) - return emitOpError( - "expects A5 trowexpandmul element type to be i8/i16/i32/f16/f32"); - return emitOpError( - "expects A2/A3 trowexpandmul element type to be i16/i32/f16/f32"); - } - return mlir::success(); - }; - auto verifyA2A3 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A3); }; - auto verifyA5 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A5); }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - - -mlir::LogicalResult mlir::pto::TRowExpandSubOp::verify() { - auto verifyByArch = [&](PTOArch targetArch) -> LogicalResult { - Type src0Ty = getSrc0().getType(); - Type src1Ty = getSrc1().getType(); - Type dstTy = getDst().getType(); - FailureOr elemOr = verifyTRowExpandBinaryCore( - *this, src0Ty, src1Ty, dstTy, getTmp() ? 
getTmp().getType() : Type{}, - static_cast(getTmp())); - if (failed(elemOr)) - return failure(); - Type elem = *elemOr; - bool supported = elem.isF16() || elem.isF32() || elem.isInteger(16) || - elem.isInteger(32) || - (targetArch == PTOArch::A5 && elem.isInteger(8)); - if (!supported) { - if (targetArch == PTOArch::A5) - return emitOpError( - "expects A5 trowexpandsub element type to be i8/i16/i32/f16/f32"); - return emitOpError( - "expects A2/A3 trowexpandsub element type to be i16/i32/f16/f32"); - } - return mlir::success(); - }; - auto verifyA2A3 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A3); }; - auto verifyA5 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A5); }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TRowExpandAddOp::verify() { - auto verifyByArch = [&](PTOArch targetArch) -> LogicalResult { - Type src0Ty = getSrc0().getType(); - Type src1Ty = getSrc1().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyTileBufCommon(*this, src0Ty, "src0")) || - failed(verifyTileBufCommon(*this, src1Ty, "src1")) || - failed(verifyTileBufCommon(*this, dstTy, "dst"))) - return failure(); - if (failed(verifyTileBufSameElemType(*this, src0Ty, dstTy, "src0", "dst"))) - return failure(); - if (failed(verifyTileBufSameValidShape(*this, src0Ty, dstTy, "src0", "dst"))) - return failure(); - if (getElemTy(src0Ty) != getElemTy(src1Ty)) - return emitOpError("expects src0 and src1 to have the same element type"); - if (!isRowMajorTileBuf(src0Ty)) - return emitOpError("expects src0 to use row-major layout"); - if (!isRowMajorTileBuf(dstTy)) - return emitOpError("expects dst to use row-major layout"); - Type elem = getElemTy(src0Ty); - bool supported = elem.isF16() || elem.isF32() || elem.isInteger(16) || - elem.isInteger(32) || - (targetArch == PTOArch::A5 && elem.isInteger(8)); - if (!supported) { - if (targetArch == PTOArch::A5) - return emitOpError( - "expects A5 trowexpandadd 
element type to be i8/i16/i32/f16/f32"); - return emitOpError( - "expects A2/A3 trowexpandadd element type to be i16/i32/f16/f32"); - } - auto src1Valid = getValidShapeVec(src1Ty); - auto dstValid = getValidShapeVec(dstTy); - if (src1Valid.size() != 2 || dstValid.size() != 2) - return emitOpError("expects src1 and dst to have rank-2 valid_shape"); - if (src1Valid[0] != ShapedType::kDynamic && dstValid[0] != ShapedType::kDynamic && - src1Valid[0] != dstValid[0]) - return emitOpError("expects src1 valid_shape[0] to equal dst valid_shape[0]"); - bool src1IsRowMajor = isRowMajorTileBuf(src1Ty); - int64_t expectedCol = elem.isInteger(8) - ? 32 - : ((elem.isF16() || elem.isInteger(16)) ? 16 : 8); - int64_t src1Col = src1Valid[1]; - if (src1IsRowMajor) { - if (src1Col != ShapedType::kDynamic && src1Col != expectedCol) - return emitOpError("expects row-major src1 valid_shape[1] to be 32/sizeof(dtype)"); - } else { - if (src1Col != ShapedType::kDynamic && src1Col != 1) - return emitOpError("expects non-row-major src1 valid_shape[1] to be 1"); - } - return mlir::success(); - }; - auto verifyA2A3 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A3); }; - auto verifyA5 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A5); }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -static LogicalResult verifyTRowExpandReduceLikeOp(Operation *op, Type src0Ty, - Type src1Ty, Type dstTy, - Type tmpTy, bool hasTmp, - PTOArch targetArch, - StringRef opName, - bool allowIntegerTypes) { - if (failed(verifyTileBufCommon(op, src0Ty, "src0")) || - failed(verifyTileBufCommon(op, src1Ty, "src1")) || - failed(verifyTileBufCommon(op, dstTy, "dst"))) - return failure(); - if (hasTmp) { - if (failed(verifyTileBufCommon(op, tmpTy, "tmp"))) - return failure(); - if (getElemTy(tmpTy) != getElemTy(dstTy)) - return op->emitOpError() << "expects tmp and dst to have the same element type"; - } - - Type elem = getElemTy(dstTy); - if (!elem || getElemTy(src0Ty) != elem 
|| getElemTy(src1Ty) != elem) - return op->emitOpError("expects src0, src1, and dst to have the same element type"); - bool supported = elem.isF16() || elem.isF32() || - (allowIntegerTypes && - (elem.isInteger(16) || elem.isInteger(32) || - (targetArch == PTOArch::A5 && elem.isInteger(8)))); - if (!supported) { - if (!allowIntegerTypes) - return op->emitOpError() << "expects " << opName - << " element type to be f16 or f32"; - if (targetArch == PTOArch::A5) - return op->emitOpError() << "expects A5 " << opName - << " element type to be i8/i16/i32/f16/f32"; - return op->emitOpError() << "expects A2/A3 " << opName - << " element type to be i16/i32/f16/f32"; - } - - if (!isRowMajorTileBuf(dstTy)) - return op->emitOpError("expects dst to use row-major layout"); - - auto src0Valid = getValidShapeVec(src0Ty); - auto src1Valid = getValidShapeVec(src1Ty); - auto dstValid = getValidShapeVec(dstTy); - if (src0Valid.size() != 2 || src1Valid.size() != 2 || dstValid.size() != 2) - return op->emitOpError("expects src0, src1, and dst to have rank-2 valid_shape"); - - if (dstValid[0] != ShapedType::kDynamic && dstValid[0] == 0) - return op->emitOpError("expects dst valid_shape[0] to be non-zero"); - if (dstValid[1] != ShapedType::kDynamic && dstValid[1] == 0) - return op->emitOpError("expects dst valid_shape[1] to be non-zero"); - - auto validShapeMatches = [](ArrayRef lhs, - ArrayRef rhs) -> bool { - if (lhs.size() != rhs.size()) - return false; - for (auto [l, r] : llvm::zip(lhs, rhs)) { - if (l != ShapedType::kDynamic && r != ShapedType::kDynamic && l != r) - return false; - } - return true; - }; - - const bool src0MatchesDst = validShapeMatches(src0Valid, dstValid); - const bool src1MatchesDst = validShapeMatches(src1Valid, dstValid); - - auto checkBroadcastOperand = [&](Type operandTy, ArrayRef operandValid, - StringRef operandName, - bool requireNonRowMajor) -> LogicalResult { - if (operandValid[0] != ShapedType::kDynamic && dstValid[0] != ShapedType::kDynamic && - 
operandValid[0] != dstValid[0]) { - return op->emitOpError() << "expects " << operandName - << " valid_shape[0] to equal dst valid_shape[0]"; - } - int64_t expectedCol = elem.isInteger(8) ? 32 : ((elem.isF16() || elem.isInteger(16)) ? 16 : 8); - int64_t operandCol = operandValid[1]; - bool operandIsRowMajor = isRowMajorTileBuf(operandTy); - if (requireNonRowMajor && operandIsRowMajor) { - return op->emitOpError() << "expects " << operandName - << " to use a non-row-major layout when tmp is present"; - } - if (operandIsRowMajor) { - if (operandCol != ShapedType::kDynamic && operandCol != expectedCol) { - return op->emitOpError() - << "expects row-major " << operandName - << " valid_shape[1] to be 32/sizeof(dtype)"; - } - return success(); - } - if (operandCol != ShapedType::kDynamic && operandCol != 1) { - return op->emitOpError() << "expects non-row-major " << operandName - << " valid_shape[1] to be 1"; - } - return success(); - }; - - auto checkFullAndBroadcast = [&](Type fullTy, ArrayRef fullValid, - StringRef fullName, Type broadcastTy, - ArrayRef broadcastValid, - StringRef broadcastName) -> LogicalResult { - if (!isRowMajorTileBuf(fullTy)) - return op->emitOpError() << "expects " << fullName - << " to use row-major layout when it matches dst"; - if (fullValid[0] != ShapedType::kDynamic && dstValid[0] != ShapedType::kDynamic && - fullValid[0] != dstValid[0]) - return op->emitOpError() << "expects " << fullName - << " valid_shape[0] to equal dst valid_shape[0]"; - if (fullValid[1] != ShapedType::kDynamic && dstValid[1] != ShapedType::kDynamic && - fullValid[1] != dstValid[1]) - return op->emitOpError() << "expects " << fullName - << " valid_shape[1] to equal dst valid_shape[1]"; - return checkBroadcastOperand(broadcastTy, broadcastValid, broadcastName, - /*requireNonRowMajor=*/hasTmp && - targetArch == PTOArch::A3); - }; - - if (hasTmp && targetArch == PTOArch::A5) - return op->emitOpError("expects A5 form to omit tmp"); - - if (src0MatchesDst) { - if 
(succeeded(checkFullAndBroadcast(src0Ty, src0Valid, "src0", src1Ty, - src1Valid, "src1"))) - return success(); - } - if (src1MatchesDst) { - if (succeeded(checkFullAndBroadcast(src1Ty, src1Valid, "src1", src0Ty, - src0Valid, "src0"))) - return success(); - } - - return op->emitOpError() << "expects one of src0/src1 to match dst valid_shape" - << " and the other to be a per-row scalar vector"; -} - -mlir::LogicalResult mlir::pto::TRowExpandExpdifOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - return verifyTRowExpandReduceLikeOp(getOperation(), getSrc0().getType(), - getSrc1().getType(), getDst().getType(), - getTmp() ? getTmp().getType() : Type{}, - (bool)getTmp(), PTOArch::A3, - "trowexpandexpdif", - /*allowIntegerTypes=*/false); - }; - auto verifyA5 = [&]() -> LogicalResult { - return verifyTRowExpandReduceLikeOp(getOperation(), getSrc0().getType(), - getSrc1().getType(), getDst().getType(), - getTmp() ? getTmp().getType() : Type{}, - (bool)getTmp(), PTOArch::A5, - "trowexpandexpdif", - /*allowIntegerTypes=*/false); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TRowExpandMaxOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - return verifyTRowExpandReduceLikeOp(getOperation(), getSrc0().getType(), - getSrc1().getType(), getDst().getType(), - getTmp() ? getTmp().getType() : Type{}, - (bool)getTmp(), PTOArch::A3, - "trowexpandmax", - /*allowIntegerTypes=*/true); - }; - auto verifyA5 = [&]() -> LogicalResult { - return verifyTRowExpandReduceLikeOp(getOperation(), getSrc0().getType(), - getSrc1().getType(), getDst().getType(), - getTmp() ? 
getTmp().getType() : Type{}, - (bool)getTmp(), PTOArch::A5, - "trowexpandmax", - /*allowIntegerTypes=*/true); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TRowExpandMinOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - return verifyTRowExpandReduceLikeOp(getOperation(), getSrc0().getType(), - getSrc1().getType(), getDst().getType(), - getTmp() ? getTmp().getType() : Type{}, - (bool)getTmp(), PTOArch::A3, - "trowexpandmin", - /*allowIntegerTypes=*/true); - }; - auto verifyA5 = [&]() -> LogicalResult { - return verifyTRowExpandReduceLikeOp(getOperation(), getSrc0().getType(), - getSrc1().getType(), getDst().getType(), - getTmp() ? getTmp().getType() : Type{}, - (bool)getTmp(), PTOArch::A5, - "trowexpandmin", - /*allowIntegerTypes=*/true); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - - -mlir::LogicalResult mlir::pto::TRowMaxOp::verify() { - auto verifyByArch = [&]() -> LogicalResult { - return verifyTRowReductionNoTmpCommon(*this, getSrc().getType(), - getDst().getType(), - "expects element type to be i16/i32/f16/f32"); - }; - return dispatchVerifierByArch(getOperation(), verifyByArch, verifyByArch); -} - -mlir::LogicalResult mlir::pto::TRowArgMaxOp::verify() { - auto verifyByArch = [&]() -> LogicalResult { - return verifyTRowArgReductionCommon(*this, getSrc().getType(), - getTmp().getType(), getDst().getType()); - }; - - return dispatchVerifierByArch(getOperation(), verifyByArch, verifyByArch); -} - - -mlir::LogicalResult mlir::pto::TRowMinOp::verify() { - auto verifyByArch = [&]() -> LogicalResult { - return verifyTRowReductionWithTmpCommon( - *this, getSrc().getType(), getTmp().getType(), getDst().getType(), - "expects element type to be i16/i32/f16/f32"); - }; - return dispatchVerifierByArch(getOperation(), verifyByArch, verifyByArch); -} - -mlir::LogicalResult mlir::pto::TRowArgMinOp::verify() { - auto verifyByArch = [&]() -> LogicalResult { - return 
verifyTRowArgReductionCommon(*this, getSrc().getType(), - getTmp().getType(), getDst().getType()); - }; - - return dispatchVerifierByArch(getOperation(), verifyByArch, verifyByArch); -} - - -mlir::LogicalResult mlir::pto::TRowSumOp::verify() { - auto verifyByArch = [&]() -> LogicalResult { - return verifyTRowReductionNoTmpCommon(*this, getSrc().getType(), - getDst().getType(), - "expects element type to be i16/i32/f16/f32"); - }; - return dispatchVerifierByArch(getOperation(), verifyByArch, verifyByArch); -} - -mlir::LogicalResult mlir::pto::TRowProdOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - return verifyTRowReductionWithTmpCommon( - *this, getSrc().getType(), getTmp().getType(), getDst().getType(), - "expects A2/A3 trowprod element type to be i16/i32/f16/f32"); - }; - auto verifyA5 = [&]() -> LogicalResult { - return verifyTRowReductionWithTmpCommon( - *this, getSrc().getType(), getTmp().getType(), getDst().getType(), - "expects A5 trowprod element type to be i16/i32/f16/f32"); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - - -mlir::LogicalResult mlir::pto::TRsqrtOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - Type ts = getSrc().getType(); - Type td = getDst().getType(); - if (failed(verifyVecTileUnaryOp(*this, ts, td, "src", "dst", - /*allowBf16=*/false, /*allowInt8=*/false))) - return failure(); - if (failed(verifyTileBufSameValidShape(*this, ts, td, "src", "dst"))) - return failure(); - auto ft = mlir::dyn_cast(getElemTy(ts)); - if (!ft || (!ft.isF16() && !ft.isF32())) - return emitOpError("expects element type to be f16 or f32"); - if (auto tmp = getTmp()) { - Type tt = tmp.getType(); - if (failed(verifyVecTileCommon(*this, tt, "tmp"))) - return failure(); - - auto tmpElemTy = getElemTy(tt); - auto tmpElemBytes = getElemBytes(tmpElemTy); - auto tmpNumel = getStaticNumElements(getShapeVec(tt)); - if (!tmpElemBytes.has_value() || !tmpNumel.has_value()) - return 
emitOpError("expects tmp to have a static, byte-addressable tile type"); - if (tmpElemBytes.value() * tmpNumel.value() < 32) - return emitOpError("expects tmp to be at least 32 bytes when provided"); - } - return mlir::success(); -} - - -mlir::LogicalResult mlir::pto::TScatterOp::verify() { - const bool hasIndexes = static_cast(getIndexes()); - const bool hasMaskPattern = static_cast(getMaskPatternAttr()); - if (hasIndexes == hasMaskPattern) { - return emitOpError( - "expects exactly one of indexes operand or maskPattern attribute"); - } - - auto isAllowedDataElem = [&](mlir::Type t) -> bool { - if (t.isF16() || t.isF32() || t.isBF16()) return true; - if (auto it = mlir::dyn_cast(t)) - return (it.getWidth() == 8 || it.getWidth() == 16 || it.getWidth() == 32); - return false; - }; - auto isAllowedIndexElem = [&](mlir::Type t) -> bool { - if (auto it = mlir::dyn_cast(t)) - return (it.getWidth() == 16 || it.getWidth() == 32); - return false; - }; - auto getMaskScatterTimes = [&](mlir::pto::MaskPatternAttr mp) -> unsigned { - switch (mp.getValue()) { - case mlir::pto::MaskPattern::P1111: - return 1; - case mlir::pto::MaskPattern::P0101: - case mlir::pto::MaskPattern::P1010: - return 2; - default: - return 4; - } - }; - - auto verifyIndexedForm = [&]() -> LogicalResult { - Type ts = getSrc().getType(); - Type ti = getIndexes().getType(); - Type td = getDst().getType(); - if (failed(verifyVecTileStorage(*this, ts, "src")) || - failed(verifyVecTileStorage(*this, ti, "indexes")) || - failed(verifyVecTileStorage(*this, td, "dst"))) - return failure(); - - Type srcElem = getElemTy(ts), dstElem = getElemTy(td), idxElem = getElemTy(ti); - if (!srcElem || !dstElem || !idxElem) - return emitOpError("failed to get element type for operands"); - if (srcElem != dstElem) - return emitOpError("expects src/dst to have the same element type"); - - if (!isAllowedDataElem(srcElem)) - return emitOpError("expects src/dst element type to be i8/i16/i32/f16/bf16/f32"); - if 
(!isAllowedIndexElem(idxElem)) - return emitOpError("expects indexes element type to be i16/i32"); - - auto bwData = getPTOStorageElemBitWidth(srcElem); - auto bwIdx = getPTOStorageElemBitWidth(idxElem); - if (bwData != 8 && bwData != 16 && bwData != 32) - return emitOpError("unexpected src/dst element bitwidth"); - - unsigned dataBytes = bwData / 8; - unsigned idxBytes = bwIdx / 8; - unsigned expectedIdxBytes = (dataBytes == 1) ? 2 : dataBytes; - if (idxBytes != expectedIdxBytes) - return emitOpError("expects indexes element size to match the documented scatter rule"); - return mlir::success(); - }; - - auto verifyMaskForm = [&]() -> LogicalResult { - Type ts = getSrc().getType(); - Type td = getDst().getType(); - if (failed(verifyVecTileCommon(*this, ts, "src")) || - failed(verifyVecTileCommon(*this, td, "dst"))) - return failure(); - - auto srcTB = dyn_cast(ts); - auto dstTB = dyn_cast(td); - if (!srcTB || !dstTB) - return emitOpError("expects src and dst to be tile_buf types"); - - if (getElemTy(ts) != getElemTy(td)) - return emitOpError("expects src and dst to have the same element type"); - if (!isAllowedDataElem(getElemTy(ts))) - return emitOpError("expects src/dst element type to be i8/i16/i32/f16/bf16/f32"); - - auto srcValid = getValidShapeVec(ts); - auto dstValid = getValidShapeVec(td); - if (srcValid.size() != 2 || dstValid.size() != 2) - return emitOpError("expects src and dst to have rank-2 valid_shape"); - - auto mp = getMaskPatternAttr(); - if (!mp) - return emitOpError("expects mask-pattern tscatter to provide maskPattern"); - const unsigned times = getMaskScatterTimes(mp); - if (srcValid[0] != ShapedType::kDynamic && dstValid[0] != ShapedType::kDynamic && - srcValid[0] != dstValid[0]) - return emitOpError("expects src and dst to have the same valid rows"); - if (srcValid[1] != ShapedType::kDynamic && dstValid[1] != ShapedType::kDynamic && - srcValid[1] != static_cast(dstValid[1] * times)) - return emitOpError("expects src valid cols to equal dst 
valid cols times the mask expansion factor"); - - if (srcTB.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor) || - dstTB.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor)) - return emitOpError("expects mask-pattern tscatter to use row_major blayout"); - return mlir::success(); - }; - - auto verifyA2A3 = [&]() -> LogicalResult { - if (hasMaskPattern) - return verifyMaskForm(); - return verifyIndexedForm(); - }; - auto verifyA5 = [&]() -> LogicalResult { - if (hasMaskPattern) - return emitOpError("mask-pattern tscatter is not supported on A5 yet"); - return verifyIndexedForm(); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - - -mlir::LogicalResult mlir::pto::TSelOp::verify() { - auto verifyCommon = [&]() -> FailureOr { - Type t0 = getSrc0().getType(); - Type t1 = getSrc1().getType(); - Type td = getDst().getType(); - if (failed(verifyTileBufCommon(*this, t0, "src0")) || - failed(verifyTileBufCommon(*this, t1, "src1")) || - failed(verifyTileBufCommon(*this, td, "dst"))) - return failure(); - - Type srcElem = getElemTy(t0); - Type src1Elem = getElemTy(t1); - Type dstElem = getElemTy(td); - if (!srcElem || !src1Elem || !dstElem) { - emitOpError("failed to get element type for operands"); - return failure(); - } - if (srcElem != src1Elem || srcElem != dstElem) { - emitOpError("expects src0, src1, and dst to have the same element type"); - return failure(); - } - - if (!isRowMajorTileBuf(t0) || !isRowMajorTileBuf(t1) || - !isRowMajorTileBuf(td)) { - emitOpError( - "expects src0, src1, and dst to use row-major layout"); - return failure(); - } - return srcElem; - }; - - auto verifyA2A3 = [&]() -> LogicalResult { - FailureOr srcElem = verifyCommon(); - if (failed(srcElem)) - return failure(); - Type elem = *srcElem; - bool ok = elem.isF16() || elem.isBF16() || elem.isF32(); - if (auto it = dyn_cast(elem)) - ok = it.getWidth() == 16 || it.getWidth() == 32; - if (!ok) - return emitOpError( - "expects A2/A3 tsel src0, src1, 
and dst element type to be i16/i32/f16/bf16/f32"); - return success(); - }; - - auto verifyA5 = [&]() -> LogicalResult { - FailureOr srcElem = verifyCommon(); - if (failed(srcElem)) - return failure(); - Type elem = *srcElem; - bool ok = elem.isF16() || elem.isBF16() || elem.isF32(); - if (auto it = dyn_cast(elem)) - ok = it.getWidth() == 8 || it.getWidth() == 16 || it.getWidth() == 32; - if (!ok) - return emitOpError( - "expects A5 tsel src0, src1, and dst element type to be i8/i16/i32/f16/bf16/f32"); - return success(); - }; - - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - - -mlir::LogicalResult mlir::pto::TSelSOp::verify() { - // Constraints & Verification per PTO_IR_manual.md pto.tsels: - // - src and dst same element type; A2A3: i16/i32/f16/f32; A5: i8/i16/i32/f16/f32 - // - src and dst row-major; src and dst same valid region - auto verifyCommon = [&]() -> FailureOr { - Type tMask = getMask().getType(); - Type tSrc = getSrc().getType(); - Type tTmp = getTmp().getType(); - Type tDst = getDst().getType(); - if (failed(verifyTileBufCommon(*this, tMask, "mask")) || - failed(verifyTileBufCommon(*this, tSrc, "src")) || - failed(verifyTileBufCommon(*this, tTmp, "tmp")) || - failed(verifyTileBufCommon(*this, tDst, "dst"))) - return failure(); - Type eMask = getElemTy(tMask), eSrc = getElemTy(tSrc); - Type eTmp = getElemTy(tTmp), eDst = getElemTy(tDst); - if (!eMask || !eSrc || !eTmp || !eDst) { - emitOpError("failed to get element type for operands"); - return failure(); - } - if (eSrc != eDst) - return emitOpError("expects src and dst to have the same element type"); - if (failed(verifyTileBufSameValidShape(*this, tSrc, tDst, "src", "dst"))) - return failure(); - return eDst; - }; - - auto verifyA2A3 = [&]() -> LogicalResult { - FailureOr elemOr = verifyCommon(); - if (failed(elemOr)) - return failure(); - Type tSrc = getSrc().getType(); - Type tDst = getDst().getType(); - if (!isRowMajorTileBuf(tSrc) || !isRowMajorTileBuf(tDst)) - 
return emitOpError("expects src and dst to use row-major layout"); - Type elem = *elemOr; - bool ok = elem.isF16() || elem.isF32(); - if (auto it = mlir::dyn_cast(elem)) - ok = (it.getWidth() == 16 || it.getWidth() == 32); - if (!ok) - return emitOpError( - "expects A2/A3 tsels src and dst element type to be i16, i32, f16, or f32"); - return success(); - }; - - auto verifyA5 = [&]() -> LogicalResult { - FailureOr elemOr = verifyCommon(); - if (failed(elemOr)) - return failure(); - Type tSrc = getSrc().getType(); - Type tDst = getDst().getType(); - if (!isRowMajorTileBuf(tSrc) || !isRowMajorTileBuf(tDst)) - return emitOpError("expects src and dst to use row-major layout"); - Type elem = *elemOr; - bool ok = elem.isF16() || elem.isF32(); - if (auto it = mlir::dyn_cast(elem)) - ok = (it.getWidth() == 8 || it.getWidth() == 16 || it.getWidth() == 32); - if (!ok) - return emitOpError( - "expects A5 tsels src and dst element type to be i8, i16, i32, f16, or f32"); - return success(); - }; - - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - - -mlir::LogicalResult mlir::pto::TShlOp::verify() { - auto verify = [&]() -> LogicalResult { - FailureOr elemOr = verifyShiftLikeBinaryTileOpCommon( - *this, getSrc0().getType(), getSrc1().getType(), getDst().getType()); - if (failed(elemOr)) - return failure(); - auto it = mlir::dyn_cast(*elemOr); - if (!it || (it.getWidth() != 8 && it.getWidth() != 16 && - it.getWidth() != 32)) - return emitOpError( - "expects tshl src0 and src1 element type to be i8/i16/i32"); - return success(); - }; - - return dispatchVerifierByArch(getOperation(), verify, verify); -} - - -mlir::LogicalResult mlir::pto::TShrOp::verify() { - auto verify = [&]() -> LogicalResult { - FailureOr elemOr = verifyShiftLikeBinaryTileOpCommon( - *this, getSrc0().getType(), getSrc1().getType(), getDst().getType()); - if (failed(elemOr)) - return failure(); - auto it = mlir::dyn_cast(*elemOr); - if (!it || (it.getWidth() != 8 && it.getWidth() != 16 
&& - it.getWidth() != 32)) - return emitOpError( - "expects tshr src0 and src1 element type to be i8/i16/i32"); - return success(); - }; - - return dispatchVerifierByArch(getOperation(), verify, verify); -} - - -mlir::LogicalResult mlir::pto::TSort32Op::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - Type idxTy = getIdx().getType(); - if (failed(verifyVecTileCommon(*this, srcTy, "src")) || - failed(verifyVecTileCommon(*this, dstTy, "dst")) || - failed(verifyVecTileCommon(*this, idxTy, "idx"))) - return failure(); - if (getTmp() && - failed(verifyVecTileCommon(*this, getTmp().getType(), "tmp"))) - return failure(); - - auto srcElem = getElemTy(srcTy); - auto dstElem = getElemTy(dstTy); - if (!srcElem || !dstElem || srcElem != dstElem) - return emitOpError() << "expects src and dst to have the same element type"; - if (!(srcElem.isF16() || srcElem.isF32())) - return emitOpError() << "expects src and dst element type to be f16 or f32"; - - auto idxElem = getElemTy(idxTy); - auto idxInt = dyn_cast(idxElem); - if (!idxInt || idxInt.getWidth() != 32) - return emitOpError() << "expects idx element type to be i32/u32"; - return mlir::success(); -} - - -mlir::LogicalResult mlir::pto::TSqrtOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - Type srcTy = getSrc().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyVecTileUnaryOp(*this, srcTy, dstTy, "src", "dst", - /*allowBf16=*/false, /*allowInt8=*/false))) - return failure(); - if (failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) - return failure(); - - auto srcElem = getElemTy(srcTy); - if (!(mlir::isa(srcElem) || mlir::isa(srcElem))) - return emitOpError() << "expects src and dst element type to be float or half"; - - return mlir::success(); -} - - - -mlir::LogicalResult mlir::pto::TStoreFPOp::verify() { - auto shouldBypassDecoded = 
[&]() -> bool { - Value src = getSrc(); - Value fp = getFp(); - return isa(src.getType()) || isa(fp.getType()) || - src.getDefiningOp() || - fp.getDefiningOp(); - }; - - auto verifyDstType = [&]() -> LogicalResult { - Type dstTy = getDst().getType(); - if (!isa(dstTy)) - return emitOpError() - << "expects dst to be a memref or !pto.partition_tensor_view"; - if (auto dstPart = dyn_cast(dstTy)) { - for (auto [idx, dim] : llvm::enumerate(dstPart.getShape())) { - if (dim != ShapedType::kDynamic && dim <= 0) - return emitOpError() - << "expects dst shape[" << idx << "] to be positive"; - } - } - return success(); - }; - - auto verifyA2A3 = [&]() -> LogicalResult { - Type srcTy = getSrc().getType(); - Type fpTy = getFp().getType(); - if (!isa(srcTy)) - return emitOpError() << "expects src to be a !pto.tile_buf"; - if (!isa(fpTy)) - return emitOpError() << "expects fp to be a !pto.tile_buf"; - if (failed(verifyTileBufCommon(*this, srcTy, "src")) || - failed(verifyTileBufCommon(*this, fpTy, "fp"))) - return failure(); - if (failed(verifyDstType())) - return failure(); - auto srcSpace = getPTOMemorySpaceEnum(srcTy); - if (!srcSpace || *srcSpace != pto::AddressSpace::ACC) - return emitOpError() << "expects src to be in the acc address space"; - auto srcElemTy = getElemTy(srcTy); - auto srcIntTy = dyn_cast(srcElemTy); - if (!(srcElemTy.isF32() || - (srcIntTy && srcIntTy.getWidth() == 32))) - return emitOpError() - << "expects src to have element type f32, i32"; - auto srcShape = getShapeVec(srcTy); - if (srcShape.size() != 2) - return emitOpError() << "expects src to have rank 2"; - if (srcShape[1] != ShapedType::kDynamic && - (srcShape[1] < 1 || srcShape[1] > 4095)) - return emitOpError() << "expects src.cols to be in the range [1, 4095]"; - auto srcValid = getValidShapeVec(srcTy); - if (srcValid.size() != 2) - return emitOpError() << "expects src to have a rank-2 valid_shape"; - if (srcValid[1] != ShapedType::kDynamic && - (srcValid[1] < 1 || srcValid[1] > 4095)) - return 
emitOpError() - << "expects src.valid_shape[1] to be in the range [1, 4095]"; - return mlir::success(); - }; - auto verifyA5 = [&]() -> LogicalResult { - Type srcTy = getSrc().getType(); - Type fpTy = getFp().getType(); - if (!isa(srcTy)) - return emitOpError() << "expects src to be a !pto.tile_buf"; - if (!isa(fpTy)) - return emitOpError() << "expects fp to be a !pto.tile_buf"; - if (failed(verifyTileBufCommon(*this, srcTy, "src")) || - failed(verifyTileBufCommon(*this, fpTy, "fp"))) - return failure(); - if (failed(verifyDstType())) - return failure(); - auto srcSpace = getPTOMemorySpaceEnum(srcTy); - if (!srcSpace || *srcSpace != pto::AddressSpace::ACC) - return emitOpError() << "expects src to be in the acc address space"; - return mlir::success(); - }; - if (shouldBypassDecoded()) - return success(); - switch (getVerifierTargetArch(getOperation())) { - case VerifierTargetArch::A2A3: - return verifyA2A3(); - case VerifierTargetArch::A5: - return verifyA5(); - } - return failure(); -} - - -mlir::LogicalResult mlir::pto::TSubOp::verify() { - return verifyArithmeticBinaryTileOpWithArchDispatch( - getOperation(), getSrc0().getType(), getSrc1().getType(), getDst().getType(), - /*allowInt8OnA5=*/true, /*allowBf16OnA5=*/false, - "expects A2/A3 tsub element type to be i32/i16/f16/f32", - "expects A5 tsub element type to be i32/i16/i8/f16/f32"); -} - - -mlir::LogicalResult mlir::pto::TSubCOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - Type src0Ty = getSrc0().getType(); - Type src1Ty = getSrc1().getType(); - Type src2Ty = getSrc2().getType(); - Type dstTy = getDst().getType(); - if (!isPTOShapedLike(src0Ty) || !isPTOShapedLike(src1Ty) || !isPTOShapedLike(src2Ty) || !isPTOShapedLike(dstTy)) - return emitOpError() << "expects PTO shaped-like src0, src1, src2, and dst"; - - auto d = getShapeVec(dstTy); - if (getShapeVec(src0Ty).size() != d.size() || getShapeVec(src1Ty).size() != d.size() || getShapeVec(src2Ty).size() != 
d.size()) - return emitOpError() << "expects all tensors to have the same rank"; - return mlir::success(); -} - - -mlir::LogicalResult mlir::pto::TSubSOp::verify() { - return verifyArithmeticScalarTileOpWithArchDispatch( - getOperation(), getSrc().getType(), getDst().getType(), getScalar().getType(), - /*allowInt8OnA5=*/true, /*allowBf16OnA5=*/true, - "expects A2/A3 tsubs element type to be i32/i16/f16/f32", - "expects A5 tsubs element type to be i32/i16/i8/f16/bf16/f32", - /*requireValidRowsEqualOnA2A3=*/true, - /*requireValidRowsEqualOnA5=*/false); -} - - -mlir::LogicalResult mlir::pto::TSubSCOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - Type src0Ty = getSrc0().getType(); - Type src1Ty = getSrc1().getType(); - Type dstTy = getDst().getType(); - if (!isPTOShapedLike(src0Ty) || !isPTOShapedLike(src1Ty) || !isPTOShapedLike(dstTy)) - return emitOpError() << "expects PTO shaped-like src0, src1, and dst"; - - auto d = getShapeVec(dstTy); - if (getShapeVec(src0Ty).size() != d.size() || getShapeVec(src1Ty).size() != d.size()) - return emitOpError() << "expects src0, src1, and dst to have the same rank"; - return mlir::success(); -} -mlir::LogicalResult mlir::pto::TTransOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - Type srcTy = getSrc().getType(); - Type tmpTy = getTmp().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyTileBufCommon(*this, srcTy, "src")) || - failed(verifyTileBufCommon(*this, tmpTy, "tmp")) || - failed(verifyTileBufCommon(*this, dstTy, "dst"))) - return failure(); - Type srcElem = getElemTy(srcTy); - Type tmpElem = getElemTy(tmpTy); - Type dstElem = getElemTy(dstTy); - if (!srcElem || !tmpElem || !dstElem || srcElem != dstElem || srcElem != tmpElem) - return emitOpError() << "expects src and dst to have the same element type"; - if (auto srcTb = dyn_cast(srcTy)) { - if (srcTb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor)) - return emitOpError() << "expects A2/A3 
transpose src to use the row_major blayout"; - } - unsigned elemBytes = getPTOStorageElemByteSize(srcElem); - if (elemBytes == 0) - return emitOpError() << "failed to get transpose element size"; - if (elemBytes != 1 && elemBytes != 2 && elemBytes != 4) - return emitOpError() << "expects transpose element size to be 1, 2, or 4 bytes"; - auto isAllowedWidthType = [&](Type ty) { - if (elemBytes == 4) - return ty.isInteger(32) || ty.isF32(); - if (elemBytes == 2) - return ty.isInteger(16) || ty.isF16() || ty.isBF16(); - return ty.isInteger(8); - }; - if (!isAllowedWidthType(srcElem)) - return emitOpError() << "expects transpose element type to match the supported set for its width"; - return mlir::success(); - }; - auto verifyA5 = [&]() -> LogicalResult { - Type srcTy = getSrc().getType(); - Type tmpTy = getTmp().getType(); - Type dstTy = getDst().getType(); - if (failed(verifyTileBufCommon(*this, srcTy, "src")) || - failed(verifyTileBufCommon(*this, tmpTy, "tmp")) || - failed(verifyTileBufCommon(*this, dstTy, "dst"))) - return failure(); - Type srcElem = getElemTy(srcTy); - Type tmpElem = getElemTy(tmpTy); - Type dstElem = getElemTy(dstTy); - if (!srcElem || !tmpElem || !dstElem || srcElem != dstElem || srcElem != tmpElem) - return emitOpError() << "expects src, tmp, and dst to have the same element type"; - unsigned elemBytes = getPTOStorageElemByteSize(srcElem); - if (elemBytes == 0) - return emitOpError() << "failed to get transpose element size"; - if (elemBytes != 1 && elemBytes != 2 && elemBytes != 4) - return emitOpError() << "expects transpose element size to be 1, 2, or 4 bytes"; - auto isAllowedWidthType = [&](Type ty) { - if (elemBytes == 4) - return ty.isInteger(32) || ty.isF32(); - if (elemBytes == 2) - return ty.isInteger(16) || ty.isF16() || ty.isBF16(); - return ty.isInteger(8); - }; - if (!isAllowedWidthType(srcElem)) - return emitOpError() << "expects transpose element type to match the supported set for its width"; - auto checkAlignedMajor = 
[&](Type ty, StringRef name) -> LogicalResult { - auto tb = mlir::dyn_cast(ty); - if (!tb) - return success(); - auto shape = getShapeVec(ty); - if (shape.size() != 2) - return success(); - bool rowMajor = tb.getBLayoutValueI32() == static_cast(pto::BLayout::RowMajor); - int64_t major = rowMajor ? shape[1] : shape[0]; - if (major != ShapedType::kDynamic && (major * static_cast(elemBytes)) % 32 != 0) - return emitOpError() << "expects " << name << " major dimension times element size to be 32-byte aligned on A5"; - return success(); - }; - if (failed(checkAlignedMajor(srcTy, "src")) || failed(checkAlignedMajor(dstTy, "dst"))) - return failure(); - return mlir::success(); - }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -mlir::LogicalResult mlir::pto::TXorOp::verify() { - auto verifyBase = [&]() -> FailureOr { - return verifyMatchingRowMajorBinaryTileOpCommon( - getOperation(), getSrc0().getType(), getSrc1().getType(), - getDst().getType()); - }; - - auto verifyA2A3 = [&]() -> LogicalResult { - FailureOr elemOr = verifyBase(); - if (failed(elemOr)) - return failure(); - Type tmpTy = getTmp().getType(); - if (failed(verifyTileBufCommon(*this, tmpTy, "tmp"))) - return failure(); - Type elem = *elemOr; - if (getElemTy(tmpTy) != elem) - return emitOpError("expects tmp to have the same element type as src0, src1, and dst"); - if (!isRowMajorTileBuf(tmpTy)) - return emitOpError("expects tmp to use row-major layout"); - if (failed(verifyTileBufSameValidShape(*this, tmpTy, getDst().getType(), "tmp", "dst"))) - return failure(); - auto it = mlir::dyn_cast(elem); - if (!it || (it.getWidth() != 8 && it.getWidth() != 16)) - return emitOpError( - "expects A2/A3 txor src0, src1, tmp, and dst element type to be i8/i16"); - return success(); - }; - - auto verifyA5 = [&]() -> LogicalResult { - FailureOr elemOr = verifyBase(); - if (failed(elemOr)) - return failure(); - auto it = mlir::dyn_cast(*elemOr); - if (!it || (it.getWidth() != 8 && it.getWidth() 
!= 16 && - it.getWidth() != 32)) - return emitOpError( - "expects A5 txor src0, src1, and dst element type to be i8/i16/i32"); - return success(); - }; - - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - - -mlir::LogicalResult mlir::pto::TXorSOp::verify() { - auto verifyCommon = [&]() -> FailureOr { - return verifyDistinctRowMajorUnaryTileOpCommon(getOperation(), getSrc(), - getDst(), "src", "dst"); - }; - - auto verifyA2A3 = [&]() -> LogicalResult { - FailureOr elemOr = verifyCommon(); - if (failed(elemOr)) - return failure(); - auto it = mlir::dyn_cast(*elemOr); - if (!it || (it.getWidth() != 8 && it.getWidth() != 16)) - return emitOpError( - "expects A2/A3 txors src and dst element type to be i8/i16"); - return success(); - }; - - auto verifyA5 = [&]() -> LogicalResult { - FailureOr elemOr = verifyCommon(); - if (failed(elemOr)) - return failure(); - auto it = mlir::dyn_cast(*elemOr); - if (!it || (it.getWidth() != 8 && it.getWidth() != 16 && - it.getWidth() != 32)) - return emitOpError( - "expects A5 txors src and dst element type to be i8/i16/i32"); - return success(); - }; - - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} -mlir::LogicalResult mlir::pto::TPrintOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - auto srcType = getSrc().getType(); - if (auto tb = mlir::dyn_cast(srcType)) { - auto elem = tb.getElementType(); - if (!(elem.isF16() || elem.isF32() || - elem.isInteger(8) || elem.isInteger(16) || elem.isInteger(32))) - return emitOpError() << "expects printable tile element type"; - auto space = getPTOMemorySpaceEnum(srcType); - if (!space || *space != pto::AddressSpace::VEC) - return emitOpError() << "expects printable tile_buf to be in vec address space"; - return success(); - } - if (mlir::dyn_cast(srcType) || - mlir::dyn_cast(srcType)) - return mlir::success(); - return emitOpError() << "expects tile_buf, memref, or partition_tensor_view for src"; -} - - 
- -[[maybe_unused]] static LogicalResult verifyMatmulCommon(Operation *op, Value lhs, Value rhs, - Value biasOpt, Type maybeDstElemTy, - Type maybeResultElemTy) { - // ---- case A: tensor/memref (ShapedType) ---- - if (auto lhsTy = dyn_cast(lhs.getType())) { - auto rhsTy = dyn_cast(rhs.getType()); - if (!rhsTy || !lhsTy.hasRank() || !rhsTy.hasRank()) - return op->emitOpError("expects lhs and rhs to be ranked tensors or memrefs"); - - if (lhsTy.getElementType() != rhsTy.getElementType()) - return op->emitOpError() - << "expects lhs and rhs to have the same element type, but got lhs=" - << lhsTy.getElementType() << " rhs=" << rhsTy.getElementType(); - - if (biasOpt) { - auto biasTy = dyn_cast(biasOpt.getType()); - if (!biasTy || !biasTy.hasRank()) - return op->emitOpError("expects bias to be a ranked tensor or memref"); - if (biasTy.getElementType() != lhsTy.getElementType()) - return op->emitOpError() - << "expects bias to have the same element type as lhs and rhs, but got bias=" - << biasTy.getElementType() << " vs " << lhsTy.getElementType(); - } - - if (maybeDstElemTy && maybeDstElemTy != lhsTy.getElementType()) - return op->emitOpError() - << "expects dst to have the same element type as lhs and rhs, but got dst=" - << maybeDstElemTy << " vs " << lhsTy.getElementType(); - - if (maybeResultElemTy && maybeResultElemTy != lhsTy.getElementType()) - return op->emitOpError() - << "expects result to have the same element type as lhs and rhs, but got result=" - << maybeResultElemTy << " vs " << lhsTy.getElementType(); - - return success(); - } - - // ---- case B: tile ---- - auto lhsTile = dyn_cast(lhs.getType()); - auto rhsTile = dyn_cast(rhs.getType()); - if (!lhsTile || !rhsTile) - return op->emitOpError("expects lhs and rhs to be ranked tensors, memrefs, or !pto.tile"); - - if (lhsTile.getElementType() != rhsTile.getElementType()) - return op->emitOpError() << "expects lhs and rhs tiles to have the same element type, but got lhs=" - << lhsTile.getElementType() << " 
rhs=" << rhsTile.getElementType(); - - if ((int64_t)lhsTile.getShape().size() != 2 || (int64_t)rhsTile.getShape().size() != 2) - return op->emitOpError("expects lhs and rhs tiles to be 2D"); - - if (lhsTile.getShape()[1] != rhsTile.getShape()[0]) - return op->emitOpError() << "expects lhs dim1 to equal rhs dim0, but got " - << lhsTile.getShape()[1] << " vs " << rhsTile.getShape()[0]; - - if (biasOpt) { - auto biasTile = dyn_cast(biasOpt.getType()); - if (!biasTile) - return op->emitOpError("expects bias to be !pto.tile when lhs and rhs are !pto.tile"); - if (biasTile.getElementType() != lhsTile.getElementType()) - return op->emitOpError("expects bias to have the same element type as lhs and rhs"); - } - - if (maybeDstElemTy && maybeDstElemTy != lhsTile.getElementType()) - return op->emitOpError() << "expects dst to have the same element type as lhs and rhs"; - - if (maybeResultElemTy && maybeResultElemTy != lhsTile.getElementType()) - return op->emitOpError() << "expects result to have the same element type as lhs and rhs"; - - return success(); -} - -LogicalResult mlir::pto::TMatmulOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - if (failed(verifyMatTileOperands(*this, getLhs().getType(), getRhs().getType(), - getDst().getType()))) - return failure(); - if (failed(verifyMatmulTypeTriple(*this, getElemTy(getLhs().getType()), - getElemTy(getRhs().getType()), - getElemTy(getDst().getType())))) - return failure(); - return verifyMatmulLike(*this, getLhs().getType(), getRhs().getType(), - getDst().getType()); - }; - auto verifyA5 = [&]() -> LogicalResult { return verifyA2A3(); }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -LogicalResult mlir::pto::TGemvOp::verify() { - auto verifyA2A3 = [&]() -> LogicalResult { - if (failed(verifyGemvTileOperands(*this, getLhs().getType(), getRhs().getType(), - getDst().getType()))) - return failure(); - if (failed(verifyMatmulTypeTriple(*this, getElemTy(getLhs().getType()), - 
getElemTy(getRhs().getType()), - getElemTy(getDst().getType())))) - return failure(); - return verifyMatmulLike(*this, getLhs().getType(), getRhs().getType(), - getDst().getType()); - }; - auto verifyA5 = [&]() -> LogicalResult { return verifyA2A3(); }; - return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); -} - -LogicalResult mlir::pto::TMatmulAccOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - if (failed(verifyAccTileCommon(*this, getAccIn().getType(), "acc_in")) || - failed(verifyMatTileOperands(*this, getLhs().getType(), getRhs().getType(), - getDst().getType()))) - return failure(); - return success(); -} - -LogicalResult mlir::pto::TGemvAccOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - if (failed(verifyAccTileCommon(*this, getAccIn().getType(), "acc_in")) || - failed(verifyGemvTileOperands(*this, getLhs().getType(), getRhs().getType(), - getDst().getType()))) - return failure(); - return success(); -} - -//===----------------------------------------------------------------------===// -// inferReturnTypes() for matmul ops (keep your existing code) -//===----------------------------------------------------------------------=== -[[maybe_unused]] static mlir::Type inferMatmulTileResult2DFromAB(MLIRContext *context, ValueRange operands) { - if (operands.size() < 2) - return mlir::Type(); - - auto lhsTile = dyn_cast(operands[0].getType()); - auto rhsTile = dyn_cast(operands[1].getType()); - if (!lhsTile || !rhsTile) - return mlir::Type(); - - Type elemTy = lhsTile.getElementType(); - - if (operands.size() >= 3) { - if (auto biasTile = dyn_cast(operands[2].getType())) { - return mlir::pto::TileType::get(context, biasTile.getShape(), elemTy); - } - } - - auto lhsShape = lhsTile.getShape(); - auto rhsShape = rhsTile.getShape(); - if (lhsShape.size() >= 2 && rhsShape.size() >= 2) { - int64_t M = lhsShape[0]; - int64_t N = rhsShape[1]; - llvm::SmallVector outShape 
= {M, N}; - return mlir::pto::TileType::get(context, outShape, elemTy); - } - - return mlir::Type(); -} - -[[maybe_unused]] static RankedTensorType inferMatmulResult2DFromAB(ValueRange operands) { - if (operands.size() < 2) - return RankedTensorType(); - - auto lhsTy = dyn_cast(operands[0].getType()); - auto rhsTy = dyn_cast(operands[1].getType()); - if (!lhsTy || !rhsTy || !lhsTy.hasRank() || !rhsTy.hasRank()) - return RankedTensorType(); - - Type elemTy = lhsTy.getElementType(); - - if (operands.size() >= 3) { - if (auto biasRT = dyn_cast(operands[2].getType())) - return RankedTensorType::get(biasRT.getShape(), elemTy); - if (auto biasMR = dyn_cast(operands[2].getType())) { - if (biasMR.hasStaticShape()) - return RankedTensorType::get(biasMR.getShape(), elemTy); - } - } - - if (lhsTy.getRank() >= 2 && rhsTy.getRank() >= 2) { - int64_t M = lhsTy.getDimSize(0); - int64_t N = rhsTy.getDimSize(1); - return RankedTensorType::get({M, N}, elemTy); - } - - return RankedTensorType(); -} - -[[maybe_unused]] static RankedTensorType inferAccReturnFromAccIn(ValueRange operands) { - if (operands.empty()) - return RankedTensorType(); - if (auto accRT = dyn_cast(operands[0].getType())) - return accRT; - return RankedTensorType(); -} - -namespace mlir { -namespace pto { - -static LogicalResult parseShapeAndElem(AsmParser &parser, - SmallVectorImpl &shape, - Type &elementType, - bool allowDynamic) { - if (parser.parseLess()) - return failure(); - - if (parser.parseDimensionList(shape, allowDynamic)) - return failure(); - - if (parser.parseType(elementType)) - return failure(); - - if (parser.parseGreater()) - return failure(); - - return success(); -} - -static void printShapeAndElem(AsmPrinter &printer, - ArrayRef shape, - Type elementType) { - printer << "<"; - for (auto d : shape) { - if (d == ShapedType::kDynamic) - printer << "?"; - else - printer << d; - printer << "x"; - } - printer.printType(elementType); - printer << ">"; -} - -// 
============================================================================= -// PartitionTensorViewType Implementation -// ============================================================================= - -Type PartitionTensorViewType::parse(AsmParser &parser) { - SmallVector shape; - Type elemTy; - if (failed(parseShapeAndElem(parser, shape, elemTy, /*allowDynamic=*/true))) - return Type(); - - return PartitionTensorViewType::get(parser.getContext(), shape, elemTy); -} - -void PartitionTensorViewType::print(AsmPrinter &printer) const { - printShapeAndElem(printer, getShape(), getElementType()); -} - -// ---- TileType ---- -Type TileType::parse(AsmParser &parser) { - SmallVector shape; - Type elemTy; - if (failed(parseShapeAndElem(parser, shape, elemTy, /*allowDynamic=*/true))) - return Type(); - return TileType::get(parser.getContext(), shape, elemTy); -} - -void TileType::print(AsmPrinter &printer) const { - printShapeAndElem(printer, getShape(), getElementType()); -} - -// ---- LocalArrayType ---- -// Asm form: !pto.local_array -// Static shape only (no '?'). Element type must be a scalar; this is enforced -// by the type verifier below. 
-Type LocalArrayType::parse(AsmParser &parser) { - SmallVector shape; - Type elemTy; - if (failed(parseShapeAndElem(parser, shape, elemTy, /*allowDynamic=*/false))) - return Type(); - return LocalArrayType::getChecked( - [&]() { return parser.emitError(parser.getNameLoc()); }, - parser.getContext(), shape, elemTy); -} - -void LocalArrayType::print(AsmPrinter &printer) const { - printShapeAndElem(printer, getShape(), getElementType()); -} - -LogicalResult LocalArrayType::verify( - llvm::function_ref emitError, - llvm::ArrayRef shape, Type elementType) { - if (shape.empty()) - return emitError() << "'!pto.local_array' requires at least one dimension"; - for (auto [i, d] : llvm::enumerate(shape)) { - if (d <= 0) - return emitError() - << "'!pto.local_array' dimension " << i - << " must be a positive static size, got " << d; - } - if (!elementType.isIntOrFloat()) - return emitError() - << "'!pto.local_array' element type must be a scalar integer or " - "float, got " - << elementType; - return success(); -} - -// ============================================================================= -// Decompose Helper (Reverse Engineering AffineMap -> Strides) -// ============================================================================= - -// Helper: 递归地将 Add 表达式拆解为单独的项列表 -static void flattenAddExpr(AffineExpr expr, SmallVectorImpl &terms) { - if (auto add = llvm::dyn_cast(expr)) { - if (add.getKind() == AffineExprKind::Add) { - flattenAddExpr(add.getLHS(), terms); - flattenAddExpr(add.getRHS(), terms); - return; - } - } - terms.push_back(expr); -} - -// Helper: 从 AffineMap 中提取 Strides -static void decomposeStridedLayout(AffineMap map, SmallVectorImpl &strides) { - // 1. 初始化 - strides.assign(map.getNumDims(), 0); - - if (map.getNumResults() != 1) return; - - // 2. 摊平表达式 - SmallVector terms; - flattenAddExpr(map.getResult(0), terms); - - // 3. 
分析每一项 - for (auto term : terms) { - // 情况 A: dN * Const 或 Const * dN - if (auto mul = llvm::dyn_cast(term)) { - if (mul.getKind() == AffineExprKind::Mul) { - AffineExpr lhs = mul.getLHS(); - AffineExpr rhs = mul.getRHS(); - - // 尝试匹配 LHS=Dim, RHS=Const - if (auto dim = llvm::dyn_cast(lhs)) { - if (auto cst = llvm::dyn_cast(rhs)) { - strides[dim.getPosition()] = cst.getValue(); - continue; - } - } - - // 尝试匹配 LHS=Const, RHS=Dim (乘法交换律) - if (auto dim = llvm::dyn_cast(rhs)) { - if (auto cst = llvm::dyn_cast(lhs)) { - strides[dim.getPosition()] = cst.getValue(); - continue; - } - } - } - } - // 情况 B: 单独的 dN (隐含 Stride = 1) - else if (auto dim = llvm::dyn_cast(term)) { - strides[dim.getPosition()] = 1; - } - } -} - -// ============================================================================= -// [Critical] Strict Alignment Protocol Helper -// ============================================================================= -// This function is the SINGLE source of truth for building the AffineMap. -// Both the Parser and the Op Inference MUST use this exact function. -// It ensures that the order of AffineExpr addition is: -// 0 + (d0*str0 + d1*str1...) + (s0*str0 + s1*str1...) -// This guarantees bitwise-identical AffineMaps for verification. -static AffineMap buildStrictBitwiseAffineMap(MLIRContext *ctx, - ArrayRef strides, - bool isMultiDimSymbol) { - unsigned rank = strides.size(); - - // Step 1: Initialize with Constant(0) - AffineExpr totalExpr = getAffineConstantExpr(0, ctx); - - // Step 2: Add Dimensions (d0*str0 + d1*str1...) - // Strictly in order: 0, 1, 2... - for (unsigned i = 0; i < rank; ++i) { - auto dim = getAffineDimExpr(i, ctx); - auto str = getAffineConstantExpr(strides[i], ctx); - totalExpr = totalExpr + (dim * str); - } - - // Step 3: Add Symbols (s0*str0 + s1*str1...) - // Strictly in order: 0, 1, 2... 
- if (isMultiDimSymbol) { - for (unsigned i = 0; i < rank; ++i) { - auto sym = getAffineSymbolExpr(i, ctx); - auto str = getAffineConstantExpr(strides[i], ctx); - totalExpr = totalExpr + (sym * str); - } - } - // (Optional: handle single dynamic offset case if needed, omitted for clarity) - - // numSymbols is rank if multi-dim (for offsets), else 0 - unsigned numSymbols = isMultiDimSymbol ? rank : 0; - return AffineMap::get(rank, numSymbols, totalExpr); -} - - -// ============================================================================= -// Parser Implementation -// ============================================================================= - -// Helper for parsing [64, 1] -static ParseResult parseStrideList(AsmParser &parser, SmallVectorImpl &strides) { - if (parser.parseLSquare()) return failure(); - do { - int64_t stride; - if (parser.parseInteger(stride)) return failure(); - strides.push_back(stride); - } while (succeeded(parser.parseOptionalComma())); - if (parser.parseRSquare()) return failure(); - return success(); -} - -// The custom attribute parser for: strided<[64, 1], offset: [?, ?]> -[[maybe_unused]] static ParseResult parseStridedLayout(AsmParser &parser, Attribute &layout) { - if (parser.parseLess()) return failure(); - - // 1. Parse Strides - SmallVector strides; - if (parseStrideList(parser, strides)) return failure(); - - bool isMultiDim = false; - unsigned numSymbols = 0; - - // 2. Parse Offset - if (succeeded(parser.parseOptionalComma())) { - if (parser.parseKeyword("offset") || parser.parseColon()) return failure(); - - // Check for multi-dim syntax: [?, ?] - if (succeeded(parser.parseOptionalLSquare())) { - isMultiDim = true; - do { - if (parser.parseQuestion()) return failure(); - numSymbols++; - } while (succeeded(parser.parseOptionalComma())); - if (parser.parseRSquare()) return failure(); - } else { - // Fallback for old scalar syntax '?' 
- if (parser.parseOptionalQuestion()) { /* handle single scalar */ } - } - } - - if (parser.parseGreater()) return failure(); - - // 3. Validation - if (isMultiDim && numSymbols != strides.size()) { - return parser.emitError(parser.getCurrentLocation(), - "Number of offset symbols must match rank"); - } - - // 4. [CALL SHARED BUILDER] - // Delegate to the strict builder - MLIRContext *ctx = parser.getContext(); - AffineMap map = buildStrictBitwiseAffineMap(ctx, strides, isMultiDim); - - layout = AffineMapAttr::get(map); - return success(); -} - -// ============================================================================= -// Printer Implementation -// ============================================================================= - -[[maybe_unused]] static void printLayout(AsmPrinter &printer, Attribute layoutAttr) { - if (!layoutAttr) return; - auto mapAttr = llvm::dyn_cast(layoutAttr); - if (!mapAttr) { printer << ", " << layoutAttr; return; } - - AffineMap map = mapAttr.getValue(); - if (map.isIdentity()) return; - - // 1. [核心修改] 反解 Strides - SmallVector strides; - decomposeStridedLayout(map, strides); - - printer << ", strided<["; - // 2. 打印真实的 strides - llvm::interleaveComma(strides, printer); - printer << "]"; - - // Print Offset: [?, ?] 
- unsigned numSyms = map.getNumSymbols(); - if (numSyms > 0) { - printer << ", offset: ["; - for (unsigned i = 0; i < numSyms; ++i) { - printer << "?"; - if (i < numSyms - 1) printer << ", "; - } - printer << "]"; - } - printer << ">"; -} - -// ---- TileBuf --- - - -// Tile subview 相关实现 - -// ============================================================================= -// Op Interface Implementation: SubViewOp -// ============================================================================= - -ParseResult mlir::pto::SubViewOp::parse(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::UnresolvedOperand source; - SmallVector offsets; - SmallVector valids; - Type sourceTy; - Type resultTy; - bool hasExplicitResultTy = false; - - if (parser.parseOperand(source) || parser.parseLSquare() || - parser.parseOperandList(offsets) || parser.parseRSquare() || - parser.parseKeyword("sizes")) - return failure(); - - ArrayAttr sizesAttr; - if (parser.parseAttribute(sizesAttr, "sizes", result.attributes)) - return failure(); - - if (succeeded(parser.parseOptionalKeyword("valid"))) { - OpAsmParser::UnresolvedOperand vrow, vcol; - if (parser.parseLSquare() || parser.parseOperand(vrow) || parser.parseComma() || - parser.parseOperand(vcol) || parser.parseRSquare()) - return failure(); - valids.push_back(vrow); - valids.push_back(vcol); - } - - if (parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(sourceTy)) - return failure(); - - if (succeeded(parser.parseOptionalArrow())) { - if (parser.parseType(resultTy)) - return failure(); - hasExplicitResultTy = true; - } - - if (parser.resolveOperand(source, sourceTy, result.operands)) - return failure(); - - Type indexTy = parser.getBuilder().getIndexType(); - if (parser.resolveOperands(offsets, indexTy, result.operands)) - return failure(); - if (!valids.empty() && - parser.resolveOperands(valids, indexTy, result.operands)) - return failure(); - - int32_t hasValid = valids.empty() ? 
0 : 1; - result.addAttribute( - "operandSegmentSizes", - parser.getBuilder().getDenseI32ArrayAttr( - {1, static_cast(offsets.size()), hasValid, hasValid})); - - if (hasExplicitResultTy) { - result.addTypes(resultTy); - return success(); - } - - SmallVector inferredReturnTypes; - DictionaryAttr attrs = result.attributes.getDictionary(parser.getContext()); - if (failed(SubViewOp::inferReturnTypes( - parser.getContext(), std::nullopt, result.operands, attrs, nullptr, - RegionRange(), inferredReturnTypes))) { - return parser.emitError(parser.getCurrentLocation(), - "failed to infer pto.subview result type"); - } - result.addTypes(inferredReturnTypes); - return success(); -} - -void mlir::pto::SubViewOp::print(OpAsmPrinter &printer) { - printer << " " << getSource() << "["; - printer.printOperands(getOffsets()); - printer << "] sizes " << getSizes(); - if (getValidRow()) { - printer << " valid [" << getValidRow() << ", " << getValidCol() << "]"; - } - printer.printOptionalAttrDict((*this)->getAttrs(), - /*elidedAttrs=*/{"operandSegmentSizes", - "sizes"}); - printer << " : " << getSource().getType() << " -> " << getResult().getType(); -} - -LogicalResult SubViewOp::inferReturnTypes( - MLIRContext *context, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - - // 1. 获取 Source Type - if (operands.empty()) return failure(); - auto sourceType = llvm::dyn_cast(operands[0].getType()); - if (!sourceType) return failure(); - - // 2. 
获取 subview 逻辑窗口(sizes) - ArrayAttr sizeAttr; - if (properties) { - const auto *prop = properties.as(); - if (prop) sizeAttr = prop->sizes; - } - if (!sizeAttr && attributes) { - sizeAttr = attributes.getAs("sizes"); - } - if (!sizeAttr) return failure(); - - SmallVector subviewShape; - for (auto attr : sizeAttr) { - int64_t dim = llvm::cast(attr).getInt(); - subviewShape.push_back(dim); - } - - // Design: subview 的结果 tile 类型显式表达逻辑子窗口 shape(sizes)。 - ArrayRef parentShape = sourceType.getShape(); - if (subviewShape.size() != parentShape.size()) - return failure(); - - // Derive valid shape from explicit valid_row/valid_col when provided. - // Otherwise default to subview shape (no parent valid-shape inheritance). - SmallVector validShape; - constexpr int64_t kDynamicValidDim = -1; - int64_t rank = static_cast(subviewShape.size()); - Value explicitVRow; - Value explicitVCol; - - // Robustly decode optional valid operands using AttrSizedOperandSegments: - // [source, offsets..., valid_row?, valid_col?] - if (attributes) { - if (auto segAttr = - attributes.getAs("operandSegmentSizes")) { - ArrayRef segs = segAttr.asArrayRef(); - if (segs.size() == 4) { - int32_t srcSeg = segs[0]; - int32_t offSeg = segs[1]; - int32_t vRowSeg = segs[2]; - int32_t vColSeg = segs[3]; - if (srcSeg == 1 && offSeg >= 0 && (vRowSeg == 0 || vRowSeg == 1) && - (vColSeg == 0 || vColSeg == 1)) { - size_t idx = static_cast(srcSeg + offSeg); - if (vRowSeg == 1 && idx < operands.size()) - explicitVRow = operands[idx++]; - if (vColSeg == 1 && idx < operands.size()) - explicitVCol = operands[idx]; - } - } - } - } - - // Fallback for legacy callers that may not provide operandSegmentSizes. 
- if (!explicitVRow && !explicitVCol && rank == 2) { - size_t expectedWithoutValid = static_cast(1 + rank); - if (operands.size() >= expectedWithoutValid + 2) { - explicitVRow = operands[expectedWithoutValid]; - explicitVCol = operands[expectedWithoutValid + 1]; - } - } - - for (size_t i = 0, e = subviewShape.size(); i < e; ++i) { - int64_t vdim = subviewShape[i]; - Value explicitV = (i == 0) ? explicitVRow : (i == 1 ? explicitVCol : Value()); - if (explicitV) { - auto cst = getConstIndexValue(explicitV); - vdim = cst ? std::min(*cst, subviewShape[i]) : kDynamicValidDim; - } - validShape.push_back(vdim); - } - - // 3. 继承 Config (若为空使用默认) - auto cfg = sourceType.getConfigAttr(); - if (!cfg) cfg = TileBufConfigAttr::getDefault(context); - - // 4. 构建 Result Type - auto canonicalValidShape = canonicalizeTileBufValidShape(validShape); - auto resultType = TileBufType::get( - context, subviewShape, sourceType.getElementType(), - sourceType.getMemorySpace(), canonicalValidShape, cfg); - - inferredReturnTypes.push_back(resultType); - return success(); -} - -// ============================================================================= -// SubViewOp verifier -// ============================================================================= -static bool getConstIndex(Value v, int64_t &out) { - if (auto cOp = v.getDefiningOp()) { - out = cOp.value(); - return true; - } - if (auto cInt = v.getDefiningOp()) { - out = cInt.value(); - return true; - } - if (auto cOp = v.getDefiningOp()) { - if (auto ia = dyn_cast(cOp.getValue())) { - out = ia.getInt(); - return true; - } - } - if (auto castOp = v.getDefiningOp()) - return getConstIndex(castOp.getIn(), out); - if (auto extOp = v.getDefiningOp()) - return getConstIndex(extOp.getIn(), out); - if (auto extOp = v.getDefiningOp()) - return getConstIndex(extOp.getIn(), out); - if (auto truncOp = v.getDefiningOp()) - return getConstIndex(truncOp.getIn(), out); - return false; -} - -static LogicalResult computeInnerShape(TileBufConfigAttr 
cfg, Type elemTy, - int64_t &innerRows, int64_t &innerCols, - bool &boxed, int32_t &bl, int32_t &sl) { - auto readBLayoutI32 = [](Attribute attr, int32_t &out) -> bool { - if (auto a = dyn_cast(attr)) { - out = (int32_t)a.getValue(); - return true; - } - if (auto a = dyn_cast(attr)) { - out = (int32_t)a.getInt(); - return true; - } - return false; - }; - auto readSLayoutI32 = [](Attribute attr, int32_t &out) -> bool { - if (auto a = dyn_cast(attr)) { - out = (int32_t)a.getValue(); - return true; - } - if (auto a = dyn_cast(attr)) { - out = (int32_t)a.getInt(); - return true; - } - return false; - }; - bl = 0; - sl = 0; - int32_t fr = 512; - (void)readBLayoutI32(cfg.getBLayout(), bl); - (void)readSLayoutI32(cfg.getSLayout(), sl); - if (auto attr = dyn_cast(cfg.getSFractalSize())) fr = (int32_t)attr.getInt(); - - boxed = (sl != 0); - if (!boxed) { - innerRows = 1; - innerCols = 1; - return success(); - } - - int64_t elemBytes = static_cast(getElemByteSize(elemTy)); - if (elemBytes <= 0) return failure(); - - if (fr == 1024) { - innerRows = 16; - innerCols = 16; - return success(); - } - if (fr == 32) { - innerRows = 16; - innerCols = 2; - return success(); - } - if (fr == 512) { - if (sl == 1) { - innerRows = 16; - innerCols = 32 / elemBytes; - return success(); - } - if (sl == 2) { - innerRows = 32 / elemBytes; - innerCols = 16; - return success(); - } - } - return failure(); -} - -mlir::LogicalResult mlir::pto::SubViewOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - auto srcTy = llvm::dyn_cast(getSource().getType()); - auto dstTy = llvm::dyn_cast(getResult().getType()); - if (!srcTy || !dstTy) - return emitOpError("expects tile_buf src and tile_buf result"); - if (srcTy.getRank() != 2 || dstTy.getRank() != 2) - return emitOpError("expects rank-2 tilebuf for src/dst"); - - auto sizesAttr = getSizes(); - if (!sizesAttr || sizesAttr.size() != 2) - return emitOpError("subview expects 2D sizes"); - int64_t sizeR = 
cast(sizesAttr[0]).getInt(); - int64_t sizeC = cast(sizesAttr[1]).getInt(); - if (sizeR <= 0 || sizeC <= 0) - return emitOpError("subview sizes must be positive"); - if (getOffsets().size() != 2) - return emitOpError("subview expects 2D offsets"); - - int64_t offR = 0, offC = 0; - bool offRConst = getConstIndex(getOffsets()[0], offR); - bool offCConst = getConstIndex(getOffsets()[1], offC); - if (offRConst && offR < 0) - return emitOpError("subview offsets must be non-negative"); - if (offCConst && offC < 0) - return emitOpError("subview offsets must be non-negative"); - - bool hasValidRow = static_cast(getValidRow()); - bool hasValidCol = static_cast(getValidCol()); - if (hasValidRow != hasValidCol) - return emitOpError( - "subview expects valid_row and valid_col to be both present or both absent"); - - if (hasValidRow) { - int64_t vRow = 0, vCol = 0; - if (getConstIndex(getValidRow(), vRow)) { - if (vRow <= 0) - return emitOpError("valid_row must be positive when constant"); - if (vRow > sizeR) - return emitOpError("valid_row must be <= subview row size"); - } - if (getConstIndex(getValidCol(), vCol)) { - if (vCol <= 0) - return emitOpError("valid_col must be positive when constant"); - if (vCol > sizeC) - return emitOpError("valid_col must be <= subview col size"); - } - } - - auto dstShape = dstTy.getShape(); - if (dstShape.size() != 2) - return emitOpError("expects result to be rank-2"); - auto srcShape = srcTy.getShape(); - if (srcShape.size() != 2) - return emitOpError("expects source to be rank-2"); - if (dstShape[0] != sizeR || dstShape[1] != sizeC) - return emitOpError("expects result shape to match subview sizes"); - - if (dstTy.getElementType() != srcTy.getElementType()) - return emitOpError("expects result element type to match source"); - if (dstTy.getMemorySpace() != srcTy.getMemorySpace()) - return emitOpError("expects result address space to match source"); - auto srcCfg = srcTy.getConfigAttr(); - if (!srcCfg) srcCfg = 
TileBufConfigAttr::getDefault(getContext()); - auto dstCfg = dstTy.getConfigAttr(); - if (!dstCfg) dstCfg = TileBufConfigAttr::getDefault(getContext()); - if (dstCfg != srcCfg) - return emitOpError("expects result tile config to match source"); - - // Design choice: when valid[...] is omitted, infer result valid_shape from - // subview sizes directly. We intentionally do not constrain it by source - // valid_shape to allow user-controlled subview semantics. - - auto expectedValidDim = [&](Value explicitValid, int64_t defaultSize) { - if (!explicitValid) - return defaultSize; - int64_t c = 0; - if (getConstIndex(explicitValid, c)) - return std::min(c, defaultSize); - return ShapedType::kDynamic; - }; - int64_t expectedVRow = expectedValidDim(getValidRow(), sizeR); - int64_t expectedVCol = expectedValidDim(getValidCol(), sizeC); - auto dstValid = dstTy.getValidShape(); - if (dstValid.size() != 2) - return emitOpError("expects result to have rank-2 valid_shape"); - if (dstValid[0] != expectedVRow) - return emitOpError("expects result valid_shape[0] to match inferred/explicit valid_row"); - if (dstValid[1] != expectedVCol) - return emitOpError("expects result valid_shape[1] to match inferred/explicit valid_col"); - - auto cfg = srcTy.getConfigAttr(); - if (!cfg) cfg = TileBufConfigAttr::getDefault(getContext()); - - int64_t innerRows = 1, innerCols = 1; - bool boxed = false; - int32_t bl = 0, sl = 0; - if (failed(computeInnerShape(cfg, srcTy.getElementType(), innerRows, innerCols, - boxed, bl, sl))) - return emitOpError("unsupported tile layout for subview"); - - if (!boxed) - return success(); - - // Boxed layout: require static 2D sizes with inner alignment. Offsets may be - // dynamic, but static offsets must be aligned. 
- if (sizeR % innerRows != 0 || sizeC % innerCols != 0) - return emitOpError("boxed layout subview sizes must be multiples of inner shape"); - - if (offRConst) { - if (offR % innerRows != 0) - return emitOpError("boxed layout subview offsets must be multiples of inner shape"); - } - if (offCConst) { - if (offC % innerCols != 0) - return emitOpError("boxed layout subview offsets must be multiples of inner shape"); - } - - (void)bl; - if (srcShape.size() != 2 || - srcShape[0] == ShapedType::kDynamic || - srcShape[1] == ShapedType::kDynamic) { - return emitOpError("boxed layout subview requires static source shape"); - } - - return success(); -} - -} // namespace pto -} // namespace mlir - -using namespace mlir; -using namespace mlir::pto; - -// ============================================================================= -// Helper Functions -// ============================================================================= - -[[maybe_unused]] static AddressSpace getAddressSpace(Value val) { - auto type = llvm::dyn_cast(val.getType()); - if (!type) return AddressSpace::Zero; // Default - - // 假设你的 AddressSpaceAttr 存储在 MemRef 的 memorySpace 中 - // 需要根据你的 getPTOAddressSpaceAttr 实现来调整 - auto attr = llvm::dyn_cast_or_null(type.getMemorySpace()); - if (attr) return attr.getAddressSpace(); - return AddressSpace::Zero; -} - -// ============================================================================= -// Side Effects Implementation -// ============================================================================= - -// [Fix] 辅助函数:重载以支持 OpOperand* 和 OpResult,避免直接传 Value - -// 针对操作数 (Operand) 的重载 -static void addEffect( - SmallVectorImpl> &effects, - OpOperand *operand, MemoryEffects::Effect *effect) { - if (operand) - effects.emplace_back(effect, operand, SideEffects::DefaultResource::get()); -} - -// 针对结果 (Result) 的重载 -static void addEffect( - SmallVectorImpl> &effects, - OpResult result, MemoryEffects::Effect *effect) { - if (result) - effects.emplace_back(effect, result, 
SideEffects::DefaultResource::get()); -} - -// === TLoadOp === -// Read: src, Write: dst -// 针对 OpOperand* 的重载 -void TLoadOp::getEffects(SmallVectorImpl> &effects) { - // [Fix] 单个操作数,直接取地址 - addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); -} - -void TPrefetchOp::getEffects( - SmallVectorImpl> &effects) { - addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); -} - -// === TAbsOp === -// Read: src, Write: dst -void TAbsOp::getEffects( - SmallVectorImpl> &effects) { - addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); -} - -// === TStoreOp === -// Read: src, Write: dst (GM) -void TStoreOp::getEffects(SmallVectorImpl> &effects) { - addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); - auto preQuantRange = getPreQuantScalarMutable(); - if (!preQuantRange.empty()) - addEffect(effects, &*preQuantRange.begin(), MemoryEffects::Read::get()); - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); -} - -// === TMovOp === -// Read: src, Write: dst -void TMovOp::getEffects(SmallVectorImpl> &effects) { - addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); - auto fpRange = getFpMutable(); - if (!fpRange.empty()) - addEffect(effects, &*fpRange.begin(), MemoryEffects::Read::get()); - auto preQuantRange = getPreQuantScalarMutable(); - if (!preQuantRange.empty()) - addEffect(effects, &*preQuantRange.begin(), MemoryEffects::Read::get()); - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); -} - -#define PTO_ADD_READ(operand) addEffect(effects, &(operand), MemoryEffects::Read::get()) -#define PTO_ADD_WRITE(operand) addEffect(effects, &(operand), MemoryEffects::Write::get()) - -#define PTO_DEFINE_UNARY_EFFECTS(OpClass, srcOperand, dstOperand) \ - void OpClass::getEffects( \ - 
SmallVectorImpl> &effects) { \ - PTO_ADD_READ(srcOperand); \ - PTO_ADD_WRITE(dstOperand); \ - } - -#define PTO_DEFINE_BINARY_EFFECTS(OpClass, lhsOperand, rhsOperand, dstOperand) \ - void OpClass::getEffects( \ - SmallVectorImpl> &effects) { \ - PTO_ADD_READ(lhsOperand); \ - PTO_ADD_READ(rhsOperand); \ - PTO_ADD_WRITE(dstOperand); \ - } - -#define PTO_DEFINE_TERNARY_EFFECTS(OpClass, op0, op1, op2, dstOperand) \ - void OpClass::getEffects( \ - SmallVectorImpl> &effects) { \ - PTO_ADD_READ(op0); \ - PTO_ADD_READ(op1); \ - PTO_ADD_READ(op2); \ - PTO_ADD_WRITE(dstOperand); \ - } - -#define PTO_DEFINE_QUATERNARY_EFFECTS(OpClass, op0, op1, op2, op3, dstOperand) \ - void OpClass::getEffects( \ - SmallVectorImpl> &effects) { \ - PTO_ADD_READ(op0); \ - PTO_ADD_READ(op1); \ - PTO_ADD_READ(op2); \ - PTO_ADD_READ(op3); \ - PTO_ADD_WRITE(dstOperand); \ - } - -void LoadScalarOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getPtrMutable()); -} - -void StoreScalarOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_WRITE(getPtrMutable()); -} - -// === Tile/Device ops added for InsertSync === - -// MGATHER: Read(mem, idx) -> Write(dst) -void MGatherOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getMemMutable()); - PTO_ADD_READ(getIdxMutable()); - PTO_ADD_WRITE(getDstMutable()); -} - -// MSCATTER: Read(src, idx) -> Write(mem) -void MScatterOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_READ(getIdxMutable()); - PTO_ADD_WRITE(getMemMutable()); -} - -// TGETVAL: Read(src) -> scalar result -void TGetValOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); -} - -void THistogramOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_READ(getIdxMutable()); - PTO_ADD_WRITE(getDstMutable()); -} - -void TGetScaleAddrOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_WRITE(getDstMutable()); -} - -// TSETVAL: Write(dst) 
(single element update) -void TSetValOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_WRITE(getDstMutable()); -} - -// SET_VALIDSHAPE: update runtime valid row/col metadata on source tile in-place. -void SetValidShapeOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_WRITE(getSourceMutable()); -} - -// GET_VALIDSHAPE: read runtime valid row/col metadata from source tile. -void GetValidShapeOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSourceMutable()); -} - -// Elementwise + reductions: mostly PIPE_V tilebuf ops -PTO_DEFINE_BINARY_EFFECTS(TAddOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -PTO_DEFINE_TERNARY_EFFECTS(TAddCOp, getSrc0Mutable(), getSrc1Mutable(), getSrc2Mutable(), getDstMutable()) -PTO_DEFINE_UNARY_EFFECTS(TAddSOp, getSrcMutable(), getDstMutable()) -PTO_DEFINE_BINARY_EFFECTS(TAddSCOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -void TAxpyOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_READ(getScalarMutable()); - PTO_ADD_WRITE(getDstMutable()); -} - -PTO_DEFINE_BINARY_EFFECTS(TAndOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -PTO_DEFINE_BINARY_EFFECTS(TConcatOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -PTO_DEFINE_QUATERNARY_EFFECTS(TConcatidxOp, getSrc0Mutable(), getSrc1Mutable(), getSrc0IdxMutable(), getSrc1IdxMutable(), getDstMutable()) -PTO_DEFINE_UNARY_EFFECTS(TAndSOp, getSrcMutable(), getDstMutable()) - -// TCI: Write(dst) (generates sequence) -void TCIOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_WRITE(getDstMutable()); -} - -// TTRI: Write(dst) (generates triangular mask) -void TTriOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_WRITE(getDstMutable()); -} - -PTO_DEFINE_BINARY_EFFECTS(TCmpOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -PTO_DEFINE_UNARY_EFFECTS(TCmpSOp, getSrcMutable(), getDstMutable()) - -PTO_DEFINE_UNARY_EFFECTS(TColExpandOp, getSrcMutable(), getDstMutable()) 
-PTO_DEFINE_BINARY_EFFECTS(TColExpandAddOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -PTO_DEFINE_BINARY_EFFECTS(TColExpandMulOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -PTO_DEFINE_BINARY_EFFECTS(TColExpandDivOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -PTO_DEFINE_BINARY_EFFECTS(TColExpandSubOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -PTO_DEFINE_BINARY_EFFECTS(TColExpandExpdifOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -PTO_DEFINE_BINARY_EFFECTS(TColExpandMaxOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -PTO_DEFINE_BINARY_EFFECTS(TColExpandMinOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -PTO_DEFINE_UNARY_EFFECTS(TColMaxOp, getSrcMutable(), getDstMutable()) -PTO_DEFINE_UNARY_EFFECTS(TColMinOp, getSrcMutable(), getDstMutable()) -PTO_DEFINE_UNARY_EFFECTS(TColProdOp, getSrcMutable(), getDstMutable()) - -void TColArgMaxOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_WRITE(getTmpMutable()); - PTO_ADD_WRITE(getDstMutable()); -} - -void TColArgMinOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_WRITE(getTmpMutable()); - PTO_ADD_WRITE(getDstMutable()); -} - -void TColSumOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - auto tmp = getTmpMutable(); - if (!tmp.empty()) { - PTO_ADD_WRITE(tmp[0]); - } - PTO_ADD_WRITE(getDstMutable()); -} - -void TCvtOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_WRITE(getDstMutable()); -} -void TRandomOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_WRITE(getDstMutable()); -} -PTO_DEFINE_BINARY_EFFECTS(TDivOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) - -// TDIVS has custom assembly format; conservatively treat first 2 operands as reads. 
-void TDivSOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_READ(getScalarMutable()); - PTO_ADD_WRITE(getDstMutable()); -} - -PTO_DEFINE_UNARY_EFFECTS(TExpOp, getSrcMutable(), getDstMutable()) - -// TEXPANDS: Write(dst) (broadcast scalar) -void TExpandsOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_WRITE(getDstMutable()); -} - -// TEXTRACT: Read(src) -> Write(dst) -void TExtractOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_WRITE(getDstMutable()); -} - -// TINSERT: Read(src) -> Write(dst) -void TInsertOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_WRITE(getDstMutable()); -} - -// TEXTRACT_FP: Read(src), Read(fp) -> Write(dst) -void TExtractFPOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_READ(getFpMutable()); - PTO_ADD_WRITE(getDstMutable()); -} - -// TINSERT_FP: Read(src), Read(fp) -> Write(dst) -void TInsertFPOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_READ(getFpMutable()); - PTO_ADD_WRITE(getDstMutable()); -} - -PTO_DEFINE_UNARY_EFFECTS(TFillPadOp, getSrcMutable(), getDstMutable()) -PTO_DEFINE_UNARY_EFFECTS(TFillPadExpandOp, getSrcMutable(), getDstMutable()) -PTO_DEFINE_UNARY_EFFECTS(TFillPadInplaceOp, getSrcMutable(), getDstMutable()) - -void TGatherOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - if (auto cdst = getCdstMutable(); !cdst.empty()) - PTO_ADD_WRITE(cdst[0]); - if (auto indices = getIndicesMutable(); !indices.empty()) - PTO_ADD_READ(indices[0]); - if (auto tmp = getTmpMutable(); !tmp.empty()) - PTO_ADD_READ(tmp[0]); - PTO_ADD_WRITE(getDstMutable()); -} - -PTO_DEFINE_BINARY_EFFECTS(TGatherBOp, getSrcMutable(), getOffsetsMutable(), getDstMutable()) -PTO_DEFINE_UNARY_EFFECTS(TLogOp, getSrcMutable(), getDstMutable()) -PTO_DEFINE_UNARY_EFFECTS(TLReluOp, getSrcMutable(), getDstMutable()) - 
-PTO_DEFINE_BINARY_EFFECTS(TMaxOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -PTO_DEFINE_UNARY_EFFECTS(TMaxSOp, getSrcMutable(), getDstMutable()) -PTO_DEFINE_BINARY_EFFECTS(TMinOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -PTO_DEFINE_UNARY_EFFECTS(TMinSOp, getSrcMutable(), getDstMutable()) - -PTO_DEFINE_BINARY_EFFECTS(TMovFPOp, getSrcMutable(), getFpMutable(), getDstMutable()) - -void TMrgSortOp::getEffects( - SmallVectorImpl> &effects) { - for (auto &opnd : getSrcsMutable()) { - PTO_ADD_READ(opnd); - } - auto tmp = getTmpMutable(); - if (!tmp.empty()) - PTO_ADD_WRITE(tmp[0]); - for (auto &opnd : getDstsMutable()) { - PTO_ADD_WRITE(opnd); - } - auto executed = getExcutedMutable(); - if (!executed.empty()) { - PTO_ADD_WRITE(executed[0]); - } -} - -PTO_DEFINE_BINARY_EFFECTS(TMulOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -PTO_DEFINE_UNARY_EFFECTS(TMulSOp, getSrc0Mutable(), getDstMutable()) -PTO_DEFINE_UNARY_EFFECTS(TNegOp, getSrcMutable(), getDstMutable()) -PTO_DEFINE_UNARY_EFFECTS(TNotOp, getSrcMutable(), getDstMutable()) -PTO_DEFINE_BINARY_EFFECTS(TOrOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -PTO_DEFINE_UNARY_EFFECTS(TOrSOp, getSrcMutable(), getDstMutable()) - -PTO_DEFINE_BINARY_EFFECTS(TPartAddOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -PTO_DEFINE_BINARY_EFFECTS(TPartMaxOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -PTO_DEFINE_BINARY_EFFECTS(TPartMinOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -void TPartArgMaxOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrc0Mutable()); - PTO_ADD_READ(getSrc1Mutable()); - PTO_ADD_READ(getSrc0IdxMutable()); - PTO_ADD_READ(getSrc1IdxMutable()); - PTO_ADD_WRITE(getDstMutable()); - PTO_ADD_WRITE(getDstIdxMutable()); -} -void TPartArgMinOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrc0Mutable()); - PTO_ADD_READ(getSrc1Mutable()); - PTO_ADD_READ(getSrc0IdxMutable()); - PTO_ADD_READ(getSrc1IdxMutable()); - 
PTO_ADD_WRITE(getDstMutable()); - PTO_ADD_WRITE(getDstIdxMutable()); -} -PTO_DEFINE_BINARY_EFFECTS(TPartMulOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -// TPRELU: Read(src0, src1) -> Write(tmp, dst) -void TPReluOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrc0Mutable()); - PTO_ADD_READ(getSrc1Mutable()); - // A5 pto-isa TPRELU implementation does not consume tmp; modeling tmp as a - // write-only scratch on A5 incorrectly inflates local-memory planning and - // can trigger false vec-overflow diagnostics. - if (getTargetArch(getOperation()) != PTOArch::A5) - PTO_ADD_WRITE(getTmpMutable()); - PTO_ADD_WRITE(getDstMutable()); -} - -void TQuantOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_READ(getFpMutable()); - auto offsetRange = getOffsetMutable(); - if (!offsetRange.empty()) - PTO_ADD_READ(offsetRange[0]); - PTO_ADD_WRITE(getDstMutable()); -} -PTO_DEFINE_TERNARY_EFFECTS(TDequantOp, getSrcMutable(), getScaleMutable(), - getOffsetMutable(), getDstMutable()) -PTO_DEFINE_UNARY_EFFECTS(TRecipOp, getSrcMutable(), getDstMutable()) -PTO_DEFINE_UNARY_EFFECTS(TReluOp, getSrcMutable(), getDstMutable()) -PTO_DEFINE_BINARY_EFFECTS(TFModOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -PTO_DEFINE_UNARY_EFFECTS(TFModSOp, getSrcMutable(), getDstMutable()) -void TRemOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrc0Mutable()); - PTO_ADD_READ(getSrc1Mutable()); - PTO_ADD_WRITE(getTmpMutable()); - PTO_ADD_WRITE(getDstMutable()); -} - -void TRemSOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_WRITE(getTmpMutable()); - PTO_ADD_WRITE(getDstMutable()); -} -PTO_DEFINE_UNARY_EFFECTS(TRowExpandOp, getSrcMutable(), getDstMutable()) - -void TRowExpandDivOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrc0Mutable()); - PTO_ADD_READ(getSrc1Mutable()); - auto tmp = getTmpMutable(); - if (!tmp.empty()) - PTO_ADD_WRITE(tmp[0]); - 
PTO_ADD_WRITE(getDstMutable()); -} - -void TRowExpandMulOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrc0Mutable()); - PTO_ADD_READ(getSrc1Mutable()); - auto tmp = getTmpMutable(); - if (!tmp.empty()) - PTO_ADD_WRITE(tmp[0]); - PTO_ADD_WRITE(getDstMutable()); -} - -void TRowExpandSubOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrc0Mutable()); - PTO_ADD_READ(getSrc1Mutable()); - auto tmp = getTmpMutable(); - if (!tmp.empty()) - PTO_ADD_WRITE(tmp[0]); - PTO_ADD_WRITE(getDstMutable()); -} - -PTO_DEFINE_BINARY_EFFECTS(TRowExpandAddOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) - -void TRowExpandExpdifOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrc0Mutable()); - PTO_ADD_READ(getSrc1Mutable()); - auto tmp = getTmpMutable(); - if (!tmp.empty()) - PTO_ADD_WRITE(tmp[0]); - PTO_ADD_WRITE(getDstMutable()); -} - -void TRowExpandMaxOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrc0Mutable()); - PTO_ADD_READ(getSrc1Mutable()); - auto tmp = getTmpMutable(); - if (!tmp.empty()) - PTO_ADD_WRITE(tmp[0]); - PTO_ADD_WRITE(getDstMutable()); -} - -void TRowExpandMinOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrc0Mutable()); - PTO_ADD_READ(getSrc1Mutable()); - auto tmp = getTmpMutable(); - if (!tmp.empty()) - PTO_ADD_WRITE(tmp[0]); - PTO_ADD_WRITE(getDstMutable()); -} - -// Row reductions use tmp scratch tile. -void TRowMaxOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_WRITE(getTmpMutable()); - PTO_ADD_WRITE(getDstMutable()); -} - -void TRowArgMaxOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - // A5 lowering does not consume tmp for TROWARGMAX; modeling tmp as a - // scratch write inflates local-memory planning and can trigger false - // vec-overflow diagnostics, mirroring the fixed A5 TPRELU issue. 
- if (getTargetArch(getOperation()) != PTOArch::A5) - PTO_ADD_WRITE(getTmpMutable()); - PTO_ADD_WRITE(getDstMutable()); -} - -void TRowMinOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_WRITE(getTmpMutable()); - PTO_ADD_WRITE(getDstMutable()); -} - -void TRowArgMinOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - // A5 lowering does not consume tmp for TROWARGMIN; modeling tmp as a - // scratch write inflates local-memory planning and can trigger false - // vec-overflow diagnostics, mirroring the fixed A5 TPRELU issue. - if (getTargetArch(getOperation()) != PTOArch::A5) - PTO_ADD_WRITE(getTmpMutable()); - PTO_ADD_WRITE(getDstMutable()); -} - -void TRowSumOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_WRITE(getTmpMutable()); - PTO_ADD_WRITE(getDstMutable()); -} - -void TRowProdOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_WRITE(getTmpMutable()); - PTO_ADD_WRITE(getDstMutable()); -} -void TRsqrtOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - auto tmp = getTmpMutable(); - if (!tmp.empty()) - PTO_ADD_WRITE(tmp[0]); - PTO_ADD_WRITE(getDstMutable()); -} - -void TScatterOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - if (getIndexes()) { - auto idx = getIndexesMutable(); - if (!idx.empty()) - PTO_ADD_READ(idx[0]); - } - PTO_ADD_WRITE(getDstMutable()); -} - -// Select: Read(mask, src0, src1) -> Write(tmp, dst) -void TSelOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getMaskMutable()); - PTO_ADD_READ(getSrc0Mutable()); - PTO_ADD_READ(getSrc1Mutable()); - PTO_ADD_WRITE(getTmpMutable()); - PTO_ADD_WRITE(getDstMutable()); -} - -// TSELS: Read(src0, src1) -> Write(tmp, dst) -void TSelSOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getMaskMutable()); - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_WRITE(getTmpMutable()); - 
PTO_ADD_WRITE(getDstMutable()); -} - -PTO_DEFINE_BINARY_EFFECTS(TShlOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -PTO_DEFINE_BINARY_EFFECTS(TShrOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -PTO_DEFINE_UNARY_EFFECTS(TShlSOp, getSrcMutable(), getDstMutable()) -PTO_DEFINE_UNARY_EFFECTS(TShrSOp, getSrcMutable(), getDstMutable()) - -// TSORT32: Read(src, idx) -> Write(dst [, tmp]) -void TSort32Op::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_READ(getIdxMutable()); - auto tmp = getTmpMutable(); - if (!tmp.empty()) - PTO_ADD_WRITE(tmp[0]); - PTO_ADD_WRITE(getDstMutable()); -} - -PTO_DEFINE_UNARY_EFFECTS(TSqrtOp, getSrcMutable(), getDstMutable()) -PTO_DEFINE_BINARY_EFFECTS(TSubOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) -PTO_DEFINE_TERNARY_EFFECTS(TSubCOp, getSrc0Mutable(), getSrc1Mutable(), getSrc2Mutable(), getDstMutable()) -PTO_DEFINE_UNARY_EFFECTS(TSubSOp, getSrcMutable(), getDstMutable()) -PTO_DEFINE_BINARY_EFFECTS(TSubSCOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) - -// TXORS: Read(src) -> Write(tmp, dst) -void TXorSOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_WRITE(getTmpMutable()); - PTO_ADD_WRITE(getDstMutable()); -} - -// TXOR: Read(src0, src1) -> Write(tmp?, dst) -void TXorOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrc0Mutable()); - PTO_ADD_READ(getSrc1Mutable()); - PTO_ADD_WRITE(getTmpMutable()); - PTO_ADD_WRITE(getDstMutable()); -} - -// TTRANS: Read(src) -> Write(tmp, dst) -void TTransOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_WRITE(getTmpMutable()); - PTO_ADD_WRITE(getDstMutable()); -} - -void TPrintOp::getEffects( - SmallVectorImpl> &effects) { - PTO_ADD_READ(getSrcMutable()); - PTO_ADD_WRITE(getSrcMutable()); -} - -#undef PTO_DEFINE_TERNARY_EFFECTS -#undef PTO_DEFINE_BINARY_EFFECTS -#undef PTO_DEFINE_UNARY_EFFECTS -#undef PTO_ADD_WRITE -#undef PTO_ADD_READ 
- -// === TMatmulOp === -// Read: lhs, rhs, (bias), Write: dst -void TMatmulOp::getEffects(SmallVectorImpl> &effects) { - // Singleton -> 直接取地址 - addEffect(effects, &getLhsMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getRhsMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); -} - -// === TMatmulAccOp === -// Read: acc_in, lhs, rhs, Write: dst -void TMatmulAccOp::getEffects(SmallVectorImpl> &effects) { - addEffect(effects, &getAccInMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getLhsMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getRhsMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); -} - -// === TMatmulBiasOp === -// Read: a, b, bias, Write: dst -void TMatmulBiasOp::getEffects(SmallVectorImpl> &effects) { - addEffect(effects, &getAMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getBMutable(), MemoryEffects::Read::get()); - // 这里的 bias 是必选的 AnyType:$bias,所以是 Singleton - addEffect(effects, &getBiasMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); -} - -// === TGemvOp === -// Read: lhs, rhs, Write: dst -void TGemvOp::getEffects(SmallVectorImpl> &effects) { - addEffect(effects, &getLhsMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getRhsMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); -} - -// === TGemvAccOp === -// Read: acc_in, lhs, rhs, Write: dst -void TGemvAccOp::getEffects(SmallVectorImpl> &effects) { - addEffect(effects, &getAccInMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getLhsMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getRhsMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); -} - -// === TGemvBiasOp === -// Read: a, b, bias, Write: dst -void 
TGemvBiasOp::getEffects(SmallVectorImpl> &effects) { - addEffect(effects, &getAMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getBMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getBiasMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); -} - -// === TGemvMxOp === -// Read: a, a_scale, b, b_scale, Write: dst -void TGemvMxOp::getEffects(SmallVectorImpl> &effects) { - addEffect(effects, &getAMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getAScaleMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getBMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getBScaleMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); -} - -// === TGemvMxAccOp === -// Read: c_in, a, a_scale, b, b_scale, Write: dst -void TGemvMxAccOp::getEffects(SmallVectorImpl> &effects) { - addEffect(effects, &getCInMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getAMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getAScaleMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getBMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getBScaleMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); -} - -// === TGemvMxBiasOp === -// Read: a, a_scale, b, b_scale, bias, Write: dst -void TGemvMxBiasOp::getEffects(SmallVectorImpl> &effects) { - addEffect(effects, &getAMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getAScaleMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getBMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getBScaleMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getBiasMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); -} - -// === TMatmulOp === -void TMatmulMxOp::getEffects(SmallVectorImpl> &effects) { - 
addEffect(effects, &getAMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getAScaleMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getBMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getBScaleMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); -} - -// === TMatmulAccMxOp === -// Read: acc_in, lhs, rhs, Write: dst -void TMatmulMxAccOp::getEffects(SmallVectorImpl> &effects) { - addEffect(effects, &getCInMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getAMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getAScaleMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getBMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getBScaleMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); -} - -// === TMatmulBiasMxOp === -// Read: a, b, bias, Write: dst -void TMatmulMxBiasOp::getEffects(SmallVectorImpl> &effects) { - addEffect(effects, &getAMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getAScaleMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getBMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getBScaleMutable(), MemoryEffects::Read::get()); - // 这里的 bias 是必选的 AnyType:$bias,所以是 Singleton - addEffect(effects, &getBiasMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); -} - -static bool isInsideSectionCube(Operation *op) { - return op->getParentOfType() != nullptr; -} - -static bool isInsideSectionVector(Operation *op) { - return op->getParentOfType() != nullptr; -} - -static std::optional -getEnclosingFunctionKernelKind(Operation *op) { - auto funcOp = op->getParentOfType(); - if (!funcOp) - return std::nullopt; - - auto kernelKindAttr = - funcOp->getAttrOfType( - FunctionKernelKindAttr::name); - if (!kernelKindAttr) - return std::nullopt; - - return kernelKindAttr.getKernelKind(); -} 
- -static bool isInsideSectionOrAttributedKernel(Operation *op) { - return isInsideSectionCube(op) || isInsideSectionVector(op) || - getEnclosingFunctionKernelKind(op).has_value(); -} - -static LogicalResult verifySplitAttr(Operation *op, int64_t split) { - if (split < 0 || split > 2) - return op->emitOpError("expects 'split' to be 0, 1, or 2"); - return success(); -} - -static LogicalResult verifyFrontendKernelKind(Operation *op, - FunctionKernelKind expected, - StringRef kernelName) { - auto kernelKind = getEnclosingFunctionKernelKind(op); - if (!kernelKind || *kernelKind != expected) { - return op->emitOpError("must be inside a ") - << kernelName << " kernel function"; - } - return success(); -} - -static ParseResult parseFrontendInitializePipeOp(OpAsmParser &parser, - OperationState &result) { - NamedAttrList attrs; - bool sawId = false; - bool sawDirMask = false; - bool sawSlotSize = false; - bool sawLocalSlotNum = false; - bool sawNoSplit = false; - - if (parser.parseLBrace()) - return failure(); - - while (failed(parser.parseOptionalRBrace())) { - StringRef keyword; - if (parser.parseKeyword(&keyword) || parser.parseEqual()) - return failure(); - - if (keyword == "id") { - if (sawId) - return parser.emitError(parser.getCurrentLocation(), - "duplicate 'id' clause"); - IntegerAttr idAttr; - if (parser.parseAttribute(idAttr, parser.getBuilder().getI32Type(), "id", - attrs)) - return failure(); - sawId = true; - } else if (keyword == "dir_mask") { - if (sawDirMask) - return parser.emitError(parser.getCurrentLocation(), - "duplicate 'dir_mask' clause"); - IntegerAttr dirMaskAttr; - if (parser.parseAttribute(dirMaskAttr, parser.getBuilder().getI8Type(), - "dir_mask", attrs)) - return failure(); - sawDirMask = true; - } else if (keyword == "slot_size") { - if (sawSlotSize) - return parser.emitError(parser.getCurrentLocation(), - "duplicate 'slot_size' clause"); - IntegerAttr slotSizeAttr; - if (parser.parseAttribute(slotSizeAttr, parser.getBuilder().getI32Type(), - 
"slot_size", attrs)) - return failure(); - sawSlotSize = true; - } else if (keyword == "local_slot_num") { - if (sawLocalSlotNum) - return parser.emitError(parser.getCurrentLocation(), - "duplicate 'local_slot_num' clause"); - IntegerAttr localSlotNumAttr; - if (parser.parseAttribute(localSlotNumAttr, parser.getBuilder().getI32Type(), - "local_slot_num", attrs)) - return failure(); - sawLocalSlotNum = true; - } else if (keyword == "nosplit") { - if (sawNoSplit) - return parser.emitError(parser.getCurrentLocation(), - "duplicate 'nosplit' clause"); - BoolAttr noSplitAttr; - if (parser.parseAttribute(noSplitAttr, "nosplit", attrs)) - return failure(); - sawNoSplit = true; - } else { - return parser.emitError(parser.getCurrentLocation()) - << "unexpected keyword '" << keyword << "'"; - } - - if (succeeded(parser.parseOptionalRBrace())) - break; - if (parser.parseComma()) - return failure(); - } - - if (!sawDirMask) - return parser.emitError(parser.getNameLoc(), "expected 'dir_mask' clause"); - if (!sawSlotSize) - return parser.emitError(parser.getNameLoc(), "expected 'slot_size' clause"); - if (!sawId) - attrs.set("id", parser.getBuilder().getI32IntegerAttr(0)); - - OpAsmParser::UnresolvedOperand gmSlotBuffer; - OpAsmParser::UnresolvedOperand gmSlotTensor; - OpAsmParser::UnresolvedOperand c2vConsumerBuf; - OpAsmParser::UnresolvedOperand v2cConsumerBuf; - Type gmSlotBufferTy; - Type gmSlotTensorTy; - Type c2vConsumerBufTy; - Type v2cConsumerBufTy; - bool hasGmSlotBuffer = false; - bool hasGmSlotTensor = false; - bool hasC2vConsumerBuf = false; - bool hasV2cConsumerBuf = false; - - if (parser.parseLParen()) - return failure(); - while (failed(parser.parseOptionalRParen())) { - StringRef keyword; - if (parser.parseKeyword(&keyword) || parser.parseEqual()) - return failure(); - - if (keyword == "gm_slot_buffer") { - if (hasGmSlotBuffer) - return parser.emitError(parser.getCurrentLocation(), - "duplicate 'gm_slot_buffer' operand"); - if (parser.parseOperand(gmSlotBuffer) 
|| - parser.parseColonType(gmSlotBufferTy)) - return failure(); - hasGmSlotBuffer = true; - } else if (keyword == "gm_slot_tensor") { - if (hasGmSlotTensor) - return parser.emitError(parser.getCurrentLocation(), - "duplicate 'gm_slot_tensor' operand"); - if (parser.parseOperand(gmSlotTensor) || - parser.parseColonType(gmSlotTensorTy)) - return failure(); - hasGmSlotTensor = true; - } else if (keyword == "c2v_consumer_buf") { - if (hasC2vConsumerBuf) - return parser.emitError(parser.getCurrentLocation(), - "duplicate 'c2v_consumer_buf' operand"); - if (parser.parseOperand(c2vConsumerBuf) || - parser.parseColonType(c2vConsumerBufTy)) - return failure(); - hasC2vConsumerBuf = true; - } else if (keyword == "v2c_consumer_buf") { - if (hasV2cConsumerBuf) - return parser.emitError(parser.getCurrentLocation(), - "duplicate 'v2c_consumer_buf' operand"); - if (parser.parseOperand(v2cConsumerBuf) || - parser.parseColonType(v2cConsumerBufTy)) - return failure(); - hasV2cConsumerBuf = true; - } else { - return parser.emitError(parser.getCurrentLocation()) - << "unexpected initialize_pipe operand '" << keyword << "'"; - } - - if (succeeded(parser.parseOptionalRParen())) - break; - if (parser.parseComma()) - return failure(); - } - - if (parser.parseOptionalAttrDict(attrs)) - return failure(); - - result.addAttributes(attrs); - result.addAttribute("operandSegmentSizes", - parser.getBuilder().getDenseI32ArrayAttr( - {hasGmSlotBuffer ? 1 : 0, hasGmSlotTensor ? 1 : 0, - hasC2vConsumerBuf ? 1 : 0, - hasV2cConsumerBuf ? 
1 : 0})); - if (hasGmSlotBuffer && - parser.resolveOperand(gmSlotBuffer, gmSlotBufferTy, result.operands)) - return failure(); - if (hasGmSlotTensor && - parser.resolveOperand(gmSlotTensor, gmSlotTensorTy, result.operands)) - return failure(); - if (hasC2vConsumerBuf && - parser.resolveOperand(c2vConsumerBuf, c2vConsumerBufTy, result.operands)) - return failure(); - if (hasV2cConsumerBuf && - parser.resolveOperand(v2cConsumerBuf, v2cConsumerBufTy, result.operands)) - return failure(); - return success(); -} - -template -static void printFrontendInitializePipeOp(InitOpT op, OpAsmPrinter &p) { - p << " {"; - bool needsComma = false; - auto printClause = [&](StringRef keyword, auto value) { - if (needsComma) - p << ", "; - p << keyword << " = " << value; - needsComma = true; - }; - - if (op.getId() != 0) - printClause("id", op.getId()); - printClause("dir_mask", static_cast(op.getDirMask())); - printClause("slot_size", op.getSlotSize()); - if (auto localSlotNumAttr = op.getLocalSlotNumAttr()) - printClause("local_slot_num", localSlotNumAttr.getInt()); - if (auto noSplitAttr = op.getNosplitAttr()) - printClause("nosplit", noSplitAttr.getValue() ? 
"true" : "false"); - p << "}"; - - p << "("; - bool needsOperandComma = false; - auto printOperandClause = [&](StringRef keyword, Value value) { - if (needsOperandComma) - p << ", "; - p << keyword << " = " << value << " : " << value.getType(); - needsOperandComma = true; - }; - if (op.getGmSlotBuffer()) { - printOperandClause("gm_slot_buffer", op.getGmSlotBuffer()); - } - if (op.getGmSlotTensor()) - printOperandClause("gm_slot_tensor", op.getGmSlotTensor()); - if (op.getC2vConsumerBuf()) - printOperandClause("c2v_consumer_buf", op.getC2vConsumerBuf()); - if (op.getV2cConsumerBuf()) - printOperandClause("v2c_consumer_buf", op.getV2cConsumerBuf()); - p << ")"; - p.printOptionalAttrDict( - op->getAttrs(), - /*elidedAttrs=*/{"id", "dir_mask", "slot_size", "local_slot_num", - "nosplit", "operandSegmentSizes"}); -} - -static std::optional -getStaticElementCount(ArrayRef shape) { - uint64_t count = 1; - for (int64_t dim : shape) { - if (dim == ShapedType::kDynamic || dim < 0) - return std::nullopt; - count *= static_cast(dim); - } - return count; -} - -static bool isSameOrHalfSlotByteSize(uint64_t tensorBytes, uint64_t slotBytes) { - return tensorBytes == slotBytes || tensorBytes * 2 == slotBytes; -} - -static LogicalResult verifyFrontendGlobalSlotTensor(Operation *op, Value tensor, - int8_t dirMask, - int32_t slotSize) { - (void)dirMask; - auto tvTy = dyn_cast(tensor.getType()); - if (!tvTy) - return op->emitOpError("expects 'gm_slot_tensor' to be !pto.tensor_view"); - - ArrayRef shape = tvTy.getShape(); - if (shape.empty()) - return op->emitOpError( - "expects 'gm_slot_tensor' to describe one slot entry tensor"); - - if (auto elemCount = getStaticElementCount(shape)) { - uint64_t elemBytes = getElemByteSize(tvTy.getElementType()); - if (elemBytes != 0) { - uint64_t tensorBytes = *elemCount * elemBytes; - if (!isSameOrHalfSlotByteSize(tensorBytes, - static_cast(slotSize))) { - return op->emitOpError() - << "expects 'slot_size' to equal gm_slot_tensor byte size " - "or 
twice gm_slot_tensor byte size for split GlobalTensor " - "entries (got slot_size = " - << slotSize << ", gm_slot_tensor byte size = " << tensorBytes - << ")"; - } - } - } - - return success(); -} - -template -static LogicalResult verifyFrontendInitCommon(InitOpT op, - FunctionKernelKind expected, - StringRef kernelName) { - if (failed(verifyFrontendKernelKind(op.getOperation(), expected, kernelName))) - return failure(); - - auto funcOp = op->template getParentOfType(); - if (!funcOp) - return op.emitOpError("must be nested under a func.func"); - - if (op.getId() < 0) - return op.emitOpError("expects 'id' to be non-negative"); - - unsigned sameIdInitCount = 0; - funcOp.walk([&](Operation *candidate) { - if (auto aic = dyn_cast(candidate)) { - if (aic.getId() == op.getId()) - ++sameIdInitCount; - return; - } - if (auto aiv = dyn_cast(candidate)) - if (aiv.getId() == op.getId()) - ++sameIdInitCount; - }); - if (sameIdInitCount > 1) { - return op.emitOpError( - "requires 'id' to be unique across frontend initialize_pipe ops in the function"); - } - - int8_t dirMask = op.getDirMask(); - if (dirMask != 1 && dirMask != 2 && dirMask != 3) - return op.emitOpError("expects 'dir_mask' to be 1, 2, or 3"); - if (op.getSlotSize() <= 0) - return op.emitOpError("expects 'slot_size' to be greater than 0"); - - bool hasGlobalSlotTensor = static_cast(op.getGmSlotTensor()); - bool hasC2vConsumerBuf = static_cast(op.getC2vConsumerBuf()); - bool hasV2cConsumerBuf = static_cast(op.getV2cConsumerBuf()); - if (hasGlobalSlotTensor) { - if (op.getGmSlotBuffer() || hasC2vConsumerBuf || hasV2cConsumerBuf) { - return op.emitOpError( - "globaltensor pipe init expects only 'gm_slot_tensor' and no " - "'gm_slot_buffer', 'c2v_consumer_buf', or 'v2c_consumer_buf'"); - } - if (op.getLocalSlotNumAttr()) - return op.emitOpError( - "globaltensor pipe init does not use 'local_slot_num'"); - if (getTargetArch(op.getOperation()) == PTOArch::A5) { - return op.emitOpError( - "globaltensor pipe entries are 
supported for a2/a3 l2g2l pipes"); - } - return verifyFrontendGlobalSlotTensor( - op.getOperation(), op.getGmSlotTensor(), dirMask, op.getSlotSize()); - } - - if (hasC2vConsumerBuf != hasV2cConsumerBuf) { - return op.emitOpError( - "expects 'c2v_consumer_buf' and 'v2c_consumer_buf' to be provided together"); - } - if (!hasC2vConsumerBuf) { - return op.emitOpError( - "expects local pipe init to provide 'c2v_consumer_buf' and " - "'v2c_consumer_buf'; use 'gm_slot_tensor' for globaltensor pipe entries"); - } - - if (auto localSlotNumAttr = op.getLocalSlotNumAttr()) { - int32_t localSlotNum = localSlotNumAttr.getInt(); - if (localSlotNum <= 0) - return op.emitOpError("expects 'local_slot_num' to be greater than 0"); - int32_t loweredSlotNum = dirMask == 3 ? 4 : 8; - if (localSlotNum > loweredSlotNum) { - return op.emitOpError() - << "expects 'local_slot_num' to be less than or equal to " - << loweredSlotNum << " for dir_mask = " << static_cast(dirMask); - } - } - - return success(); -} - -ParseResult AicInitializePipeOp::parse(OpAsmParser &parser, - OperationState &result) { - return parseFrontendInitializePipeOp(parser, result); -} - -void AicInitializePipeOp::print(OpAsmPrinter &p) { - printFrontendInitializePipeOp(*this, p); -} - -ParseResult AivInitializePipeOp::parse(OpAsmParser &parser, - OperationState &result) { - return parseFrontendInitializePipeOp(parser, result); -} - -void AivInitializePipeOp::print(OpAsmPrinter &p) { - printFrontendInitializePipeOp(*this, p); -} - -static ReserveBufferOp findReserveBufferByName(func::FuncOp funcOp, - StringRef name) { - ReserveBufferOp found; - funcOp.walk([&](ReserveBufferOp reserveOp) { - if (reserveOp.getName() != name) - return WalkResult::advance(); - found = reserveOp; - return WalkResult::interrupt(); - }); - return found; -} - -LogicalResult ReserveBufferOp::verify() { - auto funcOp = getOperation()->getParentOfType(); - if (!funcOp) - return emitOpError("must be nested under a func.func"); - - if (getSize() <= 0) 
- return emitOpError("expects 'size' to be greater than 0"); - - auto location = getLocation().getAddressSpace(); - if (location != AddressSpace::VEC && location != AddressSpace::MAT) - return emitOpError("expects 'location' to be #pto.address_space or #pto.address_space"); - - if (!getAutoAlloc() && !getBaseAttr()) - return emitOpError("expects 'base' when 'auto' is false"); - - if (auto baseAttr = getBaseAttr(); baseAttr && baseAttr.getInt() < 0) - return emitOpError("expects 'base' to be non-negative when present"); - - unsigned sameNameCount = 0; - funcOp.walk([&](ReserveBufferOp reserveOp) { - if (reserveOp.getName() == getName()) - ++sameNameCount; - }); - if (sameNameCount > 1) - return emitOpError("requires 'name' to be unique within the function"); - - return success(); -} - -LogicalResult ImportReservedBufferOp::verify() { - auto funcOp = getOperation()->getParentOfType(); - if (!funcOp) - return emitOpError("must be nested under a func.func"); - - auto peerFunc = SymbolTable::lookupNearestSymbolFrom( - getOperation(), getPeerFuncAttr()); - if (!peerFunc) - return emitOpError("expects 'peer_func' to reference an existing func.func"); - - unsigned sameImportCount = 0; - funcOp.walk([&](ImportReservedBufferOp importOp) { - if (importOp.getName() == getName() && - importOp.getPeerFuncAttr() == getPeerFuncAttr()) { - ++sameImportCount; - } - }); - if (sameImportCount > 1) { - return emitOpError( - "requires (name, peer_func) to be unique within the function"); - } - - if (!findReserveBufferByName(peerFunc, getName())) - return emitOpError("expects matching peer reserve_buffer to exist"); - - return success(); -} - -static FailureOr lookupFrontendInitOpById(Operation *op, - func::FuncOp funcOp, - int32_t id) { - Operation *matchedInit = nullptr; - unsigned matchedInitCount = 0; - funcOp.walk([&](Operation *candidate) { - if (auto aic = dyn_cast(candidate)) { - if (aic.getId() == static_cast(id)) { - matchedInit = candidate; - ++matchedInitCount; - } - return 
WalkResult::advance(); - } - if (auto aiv = dyn_cast(candidate)) { - if (aiv.getId() == static_cast(id)) { - matchedInit = candidate; - ++matchedInitCount; - } - return WalkResult::advance(); - } - return WalkResult::advance(); - }); - - if (matchedInitCount == 0) { - op->emitOpError() << "expects 'id' = " << id - << " to match a frontend initialize_pipe op in the same function"; - return failure(); - } - if (matchedInitCount > 1) { - op->emitOpError() << "expects 'id' = " << id - << " to match exactly one frontend initialize_pipe op in the same function"; - return failure(); - } - return matchedInit; -} - -static LogicalResult verifyFrontendSplitOp(Operation *op, - FunctionKernelKind expected, - StringRef kernelName, - int32_t id, - int64_t split) { - if (failed(verifyFrontendKernelKind(op, expected, kernelName))) - return failure(); - if (id < 0) - return op->emitOpError("expects 'id' to be non-negative"); - return verifySplitAttr(op, split); -} - -static FailureOr lookupFrontendInitDirMaskById(Operation *op, - func::FuncOp funcOp, - int32_t id) { - auto initOr = lookupFrontendInitOpById(op, funcOp, id); - if (failed(initOr)) - return failure(); - if (auto aic = dyn_cast(*initOr)) - return aic.getDirMask(); - return cast(*initOr).getDirMask(); -} - -static LogicalResult verifyFrontendDataOpDirection(Operation *op, int32_t id, - bool expectC2V) { - auto funcOp = op->getParentOfType(); - if (!funcOp) - return op->emitOpError("must be nested under a func.func"); - - auto dirMaskOr = lookupFrontendInitDirMaskById(op, funcOp, id); - if (failed(dirMaskOr)) - return failure(); - - int8_t dirMask = *dirMaskOr; - if (expectC2V && dirMask != 1 && dirMask != 3) { - return op->emitOpError() - << "expects 'id' = " << id - << " to reference initialize_pipe with dir_mask = 1 or 3"; - } - if (!expectC2V && dirMask != 2 && dirMask != 3) { - return op->emitOpError() - << "expects 'id' = " << id - << " to reference initialize_pipe with dir_mask = 2 or 3"; - } - return success(); -} 
- -static Value getFrontendInitGmSlotTensor(Operation *initOp) { - if (auto aic = dyn_cast(initOp)) - return aic.getGmSlotTensor(); - return cast(initOp).getGmSlotTensor(); -} - -static LogicalResult verifyFrontendTensorEntryMatchesInit(Operation *op, - int32_t id, - Type entryTy) { - auto entryViewTy = dyn_cast(entryTy); - if (!entryViewTy) - return success(); - - auto funcOp = op->getParentOfType(); - if (!funcOp) - return op->emitOpError("must be nested under a func.func"); - - auto initOr = lookupFrontendInitOpById(op, funcOp, id); - if (failed(initOr)) - return failure(); - Value gmSlotTensor = getFrontendInitGmSlotTensor(*initOr); - if (!gmSlotTensor) { - return op->emitOpError() - << "expects 'id' = " << id - << " to reference initialize_pipe with 'gm_slot_tensor' when the " - "pipe entry is !pto.tensor_view"; - } - - auto slotTensorTy = dyn_cast(gmSlotTensor.getType()); - if (!slotTensorTy) - return op->emitOpError("expects 'gm_slot_tensor' to be !pto.tensor_view"); - if (slotTensorTy.getElementType() != entryViewTy.getElementType()) { - return op->emitOpError() - << "expects pipe entry element type to match gm_slot_tensor element type"; - } - if (slotTensorTy.getRank() != entryViewTy.getRank()) { - return op->emitOpError() - << "expects pipe entry rank to match gm_slot_tensor rank"; - } - - ArrayRef slotShape = slotTensorTy.getShape(); - ArrayRef entryShape = entryViewTy.getShape(); - for (auto [idx, entryDim] : llvm::enumerate(entryShape)) { - int64_t slotDim = slotShape[idx]; - if (slotDim == ShapedType::kDynamic || - entryDim == ShapedType::kDynamic || slotDim == entryDim) - continue; - return op->emitOpError() - << "expects pipe entry dimension " << idx - << " to match gm_slot_tensor dimension " << slotDim; - } - return success(); -} - -template -static LogicalResult verifyFrontendPopOp(FrontendPopOpT op, - FunctionKernelKind expected, - StringRef kernelName, - bool expectC2V) { - if (failed(verifyFrontendSplitOp(op.getOperation(), expected, 
kernelName, - op.getId(), - op.getSplit()))) - return failure(); - if (failed(verifyFrontendDataOpDirection(op.getOperation(), op.getId(), - expectC2V))) - return failure(); - if (failed(verifyFrontendTensorEntryMatchesInit(op.getOperation(), op.getId(), - op.getTile().getType()))) - return failure(); - - bool hasValidRow = static_cast(op.getValidRow()); - bool hasValidCol = static_cast(op.getValidCol()); - if (hasValidRow != hasValidCol) - return op.emitOpError( - "expects valid_row and valid_col operands to be provided together"); - if (!hasValidRow) - return success(); - - if (isa(op.getTile().getType())) - return op.emitOpError( - "does not accept valid_row/valid_col when result is !pto.tensor_view"); - - auto tileTy = dyn_cast(op.getTile().getType()); - if (!tileTy) - return op.emitOpError( - "expects tile result to be !pto.tile_buf when valid_row/valid_col operands are provided"); - if (!tileTy.hasDynamicValid()) - return op.emitOpError( - "expects tile result to have dynamic validShape (?, ?) when valid_row/valid_col operands are provided"); - return success(); -} - -static LogicalResult verifyPipeShape(Operation *op, int8_t dirMask, int32_t slotSize, - int32_t slotNum, - std::optional flagBase) { - constexpr int32_t kMaxHardwareFlagIds = 16; - if (dirMask != 1 && dirMask != 2 && dirMask != 3) - return op->emitOpError("expects 'dir_mask' to be 1, 2, or 3"); - if (slotSize <= 0) - return op->emitOpError("expects 'slot_size' to be greater than 0"); - if (slotNum != 4 && slotNum != 8) - return op->emitOpError("expects 'slot_num' to be 4 or 8"); - if (flagBase && *flagBase < 0) - return op->emitOpError("expects 'flag_base' to be non-negative when present"); - if (flagBase) { - int32_t flagWidth = dirMask == 3 ? 
4 : 2; - if (*flagBase + flagWidth > kMaxHardwareFlagIds) { - return op->emitOpError() - << "requires 'flag_base' and dir_mask to fit within " - << kMaxHardwareFlagIds << " hardware flag ids"; - } - } - - return success(); -} - -static LogicalResult verifyPipeHandleProducer(Operation *op, Value pipeHandle) { - if (!isa(pipeHandle.getType())) - return op->emitOpError("expects pipe operand type !pto.pipe"); - if (!pipeHandle.getDefiningOp() && - !pipeHandle.getDefiningOp()) { - return op->emitOpError( - "pipe_handle must be produced by pto.initialize_l2l_pipe or " - "pto.initialize_l2g2l_pipe"); - } - return success(); -} - -static bool getTensorLikeElementAndShape(Type ty, Type &elementType, - ArrayRef &shape) { - if (auto tvTy = dyn_cast(ty)) { - elementType = tvTy.getElementType(); - shape = tvTy.getShape(); - return true; - } - if (auto memrefTy = dyn_cast(ty)) { - elementType = memrefTy.getElementType(); - shape = memrefTy.getShape(); - return true; - } - return false; -} - -static LogicalResult verifyTensorEntryMatchesInternalPipeInit(Operation *op, - Value pipeHandle, - Type entryTy) { - auto entryViewTy = dyn_cast(entryTy); - if (!entryViewTy) - return success(); - - auto initOp = pipeHandle.getDefiningOp(); - if (!initOp) { - return op->emitOpError() - << "expects !pto.tensor_view pipe entry to use a pipe produced by " - "pto.initialize_l2g2l_pipe"; - } - if (initOp.getLocalAddr()) { - return op->emitOpError() - << "expects !pto.tensor_view pipe entry to use global-only " - "pto.initialize_l2g2l_pipe without local_addr"; - } - - Type slotElementType; - ArrayRef slotShape; - if (!getTensorLikeElementAndShape(initOp.getGmAddr().getType(), - slotElementType, slotShape)) { - return op->emitOpError() - << "expects !pto.tensor_view pipe entry to use " - "pto.initialize_l2g2l_pipe gm_addr with tensor/memref slot type"; - } - - if (slotElementType != entryViewTy.getElementType()) { - return op->emitOpError() - << "expects pipe entry element type to match 
initialize_l2g2l_pipe " - "gm_addr element type"; - } - if (slotShape.size() != static_cast(entryViewTy.getRank())) { - return op->emitOpError() - << "expects pipe entry rank to match initialize_l2g2l_pipe gm_addr " - "rank"; - } - - ArrayRef entryShape = entryViewTy.getShape(); - for (auto [idx, entryDim] : llvm::enumerate(entryShape)) { - int64_t slotDim = slotShape[idx]; - if (slotDim == ShapedType::kDynamic || - entryDim == ShapedType::kDynamic || slotDim == entryDim) - continue; - return op->emitOpError() - << "expects pipe entry dimension " << idx - << " to match initialize_l2g2l_pipe gm_addr dimension " - << slotDim; - } - - if (auto entryElemCount = getStaticElementCount(entryShape)) { - uint64_t elemBytes = getElemByteSize(entryViewTy.getElementType()); - uint64_t entryBytes = *entryElemCount * elemBytes; - if (elemBytes != 0) { - int8_t split = 0; - if (auto alloc = dyn_cast(op)) - split = alloc.getSplit(); - else if (auto push = dyn_cast(op)) - split = push.getSplit(); - else if (auto pop = dyn_cast(op)) - split = pop.getSplit(); - else if (auto free = dyn_cast(op)) - split = free.getSplit(); - - uint64_t slotBytes = static_cast(initOp.getSlotSize()); - bool isSplitEntry = split != 0; - bool byteSizeMatches = - entryBytes == slotBytes || (isSplitEntry && entryBytes * 2 == slotBytes); - if (!byteSizeMatches) { - return op->emitOpError() - << "expects pipe entry byte size to match initialize_l2g2l_pipe " - "slot_size" - << (isSplitEntry ? 
" or half slot_size for split entries" : "") - << " (got entry byte size = " << entryBytes - << ", slot_size = " << initOp.getSlotSize() << ")"; - } - } - } - - return success(); -} - -LogicalResult BuildAsyncSessionOp::verify() { - Type scratchTy = getScratch().getType(); - if (!isa(scratchTy)) - return emitOpError("expects scratch to be tile_buf or memref type"); - - auto scratchSpace = getPTOMemorySpaceEnum(scratchTy); - if (!scratchSpace || *scratchSpace != pto::AddressSpace::VEC) - return emitOpError("expects scratch to be in vec address space"); - - auto scratchShape = getShapeVec(scratchTy); - if (scratchShape.empty() || scratchShape.size() > 2) - return emitOpError("expects scratch to be rank-1 or rank-2"); - for (int64_t dim : scratchShape) { - if (dim == ShapedType::kDynamic) - return emitOpError("expects scratch to have a static shape"); - } - - auto scratchBytes = getStaticByteSize(scratchTy); - if (!scratchBytes) - return emitOpError("expects scratch byte size to be statically known"); - if (*scratchBytes < sizeof(uint64_t)) - return emitOpError("expects scratch to provide at least 8 bytes"); - - Type workspaceElemTy; - Type workspaceTy = getWorkspace().getType(); - if (auto ptrTy = dyn_cast(workspaceTy)) { - workspaceElemTy = ptrTy.getElementType(); - } else if (auto memTy = dyn_cast(workspaceTy)) { - workspaceElemTy = memTy.getElementType(); - if (!isGmAddressSpaceAttr(memTy.getMemorySpace())) - return emitOpError("expects workspace to be in GM address space"); - } else { - return emitOpError("expects workspace to be !pto.ptr or memref type"); - } - if (!isByteIntegerType(workspaceElemTy)) - return emitOpError("expects workspace element type to be an 8-bit integer"); - - if (auto syncIdAttr = getSyncIdAttr()) { - int64_t syncId = syncIdAttr.getInt(); - if (syncId < 0 || syncId > 7) - return emitOpError("expects sync_id in range [0, 7]"); - } - if (auto blockBytesAttr = getBlockBytesAttr()) { - if (blockBytesAttr.getInt() <= 0) - return 
emitOpError("expects block_bytes to be greater than 0"); - } - if (auto commBlockOffsetAttr = getCommBlockOffsetAttr()) { - if (commBlockOffsetAttr.getInt() < 0) - return emitOpError("expects comm_block_offset to be non-negative"); - } - if (auto queueNumAttr = getQueueNumAttr()) { - if (queueNumAttr.getInt() <= 0) - return emitOpError("expects queue_num to be greater than 0"); - } - if (auto channelGroupIdxAttr = getChannelGroupIdxAttr()) { - APInt value = channelGroupIdxAttr.getValue(); - if (value.isNegative()) - return emitOpError("expects channel_group_idx to be non-negative"); - if (value.ugt(UINT32_MAX)) - return emitOpError("expects channel_group_idx to fit in uint32"); - } - - return success(); -} - -static LogicalResult verifyAsyncTransferOp(Operation *op, Value dst, Value src) { - Type dstElemTy = getElemTy(dst.getType()); - Type srcElemTy = getElemTy(src.getType()); - if (!dstElemTy || !srcElemTy) - return op->emitOpError("expects src and dst to have element types"); - if (dstElemTy != srcElemTy) - return op->emitOpError("expects src and dst to have the same element type"); - if (failed(verifyAsyncFlatContiguous1DGMViewLike(op, dst, "dst")) || - failed(verifyAsyncFlatContiguous1DGMViewLike(op, src, "src"))) - return failure(); - if (getShapeVec(dst.getType()) != getShapeVec(src.getType())) - return op->emitOpError("expects src and dst to have the same static shape"); - return success(); -} - -LogicalResult TPutAsyncOp::verify() { - return verifyAsyncTransferOp(getOperation(), getDst(), getSrc()); -} - -LogicalResult TGetAsyncOp::verify() { - return verifyAsyncTransferOp(getOperation(), getDst(), getSrc()); -} - -LogicalResult TPutOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - if (failed(verifyCommGlobalLike(*this, getDst(), "dst")) || - failed(verifyCommGlobalLike(*this, getSrc(), "src")) || - failed(verifyCommStagingTileLike(*this, getPing(), "ping")) || - failed(verifyCommPingPongSameType(*this, 
getPing(), getPong(), "ping", - "pong"))) - return failure(); - if (getElemTy(getDst().getType()) != getElemTy(getSrc().getType())) - return emitOpError("expects src and dst to have the same element type"); - if (getShapeVec(getDst().getType()) != getShapeVec(getSrc().getType())) - return emitOpError("expects src and dst to have the same static shape"); - if (getElemTy(getPing().getType()) != getElemTy(getSrc().getType())) - return emitOpError("expects staging tile element type to match src/dst"); - return success(); -} - -LogicalResult TGetOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - if (failed(verifyCommGlobalLike(*this, getDst(), "dst")) || - failed(verifyCommGlobalLike(*this, getSrc(), "src")) || - failed(verifyCommStagingTileLike(*this, getPing(), "ping")) || - failed(verifyCommPingPongSameType(*this, getPing(), getPong(), "ping", - "pong"))) - return failure(); - if (getElemTy(getDst().getType()) != getElemTy(getSrc().getType())) - return emitOpError("expects src and dst to have the same element type"); - if (getShapeVec(getDst().getType()) != getShapeVec(getSrc().getType())) - return emitOpError("expects src and dst to have the same static shape"); - if (getElemTy(getPing().getType()) != getElemTy(getSrc().getType())) - return emitOpError("expects staging tile element type to match src/dst"); - return success(); -} - -LogicalResult TNotifyOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - if (failed(verifyCommSignalLike(*this, getSignal(), "signal"))) - return failure(); - auto valueTy = dyn_cast(getValue().getType()); - if (!valueTy || valueTy.getWidth() != 32) - return emitOpError("expects value to be i32"); - return success(); -} - -LogicalResult TWaitOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - if (failed(verifyCommSignalLike(*this, getSignal(), "signal"))) - return failure(); - auto cmpTy = 
dyn_cast(getCmpValue().getType()); - if (!cmpTy || cmpTy.getWidth() != 32) - return emitOpError("expects cmp_value to be i32"); - return success(); -} - -LogicalResult TTestOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - if (failed(verifyCommSignalLike(*this, getSignal(), "signal"))) - return failure(); - auto cmpTy = dyn_cast(getCmpValue().getType()); - if (!cmpTy || cmpTy.getWidth() != 32) - return emitOpError("expects cmp_value to be i32"); - return success(); -} - -static LogicalResult verifySyncAllGmWorkspace(Operation *op, Value workspace, - StringRef name) { - Type ty = workspace.getType(); - if (!isa(ty)) - return op->emitOpError() << "expects " << name - << " to be a GM memref/tensor_view/partition_view"; - - if (auto memTy = dyn_cast(ty)) { - if (!memTy.hasRank()) - return op->emitOpError() << "expects " << name << " to be ranked"; - if (!isGmAddressSpaceAttr(memTy.getMemorySpace())) - return op->emitOpError() << "expects " << name - << " to be in GM address space"; - } - - auto elemTy = dyn_cast(getElemTy(ty)); - if (!elemTy || elemTy.getWidth() != 32) - return op->emitOpError() << "expects " << name - << " element type to be i32"; - - SmallVector shape = getShapeVec(ty); - if (shape.empty()) - return op->emitOpError() << "expects " << name << " to have rank >= 1"; - for (int64_t dim : shape) { - if (dim != ShapedType::kDynamic && dim <= 0) - return op->emitOpError() << "expects " << name - << " shape to be positive"; - } - return success(); -} - -static LogicalResult verifySyncAllTileWorkspace(Operation *op, Value workspace, - StringRef name, - pto::AddressSpace expectedSpace) { - Type ty = workspace.getType(); - if (!isa(ty)) - return op->emitOpError() << "expects " << name - << " to be tile_buf or memref type"; - - if (isa(ty) && failed(verifyTileBufCommon(op, ty, name))) - return failure(); - - auto as = getPTOMemorySpaceEnum(ty); - if (!as || *as != expectedSpace) - return op->emitOpError() << "expects " 
<< name << " to be in " - << (expectedSpace == pto::AddressSpace::VEC - ? "vec" - : "mat") - << " address space"; - - Type elemTy = getElemTy(ty); - auto intTy = dyn_cast_or_null(elemTy); - if (!intTy || intTy.getWidth() != 32) - return op->emitOpError() << "expects " << name - << " element type to be i32"; - - auto shape = getShapeVec(ty); - if (shape.empty() || shape.size() > 2) - return op->emitOpError() << "expects " << name - << " to be rank-1 or rank-2"; - for (int64_t dim : shape) { - if (dim != ShapedType::kDynamic && dim <= 0) - return op->emitOpError() << "expects " << name - << " shape to be positive"; - } - return success(); -} - -LogicalResult SyncAllOp::verify() { - bool hasGm = static_cast(getGmWorkspace()); - bool hasUb = static_cast(getUbWorkspace()); - bool hasL1 = static_cast(getL1Workspace()); - auto mode = getMode().getValue(); - auto coreType = getCoreType().getValue(); - - if (mode == pto::SyncAllMode::Hard) { - if (hasGm || hasUb || hasL1 || getUsedCores()) - return emitOpError( - "expects hard syncall to have no workspace operands or used_cores"); - return success(); - } - - if (!hasGm) - return emitOpError("expects soft syncall to provide gm_workspace"); - if (failed(verifySyncAllGmWorkspace(getOperation(), getGmWorkspace(), - "gm_workspace"))) - return failure(); - - if (auto used = getUsedCores()) { - auto intTy = dyn_cast(used.getType()); - if (!intTy || intTy.getWidth() != 32) - return emitOpError("expects used_cores to be i32"); - } - - switch (coreType) { - case pto::SyncCoreType::AIVOnly: - if (!hasUb || hasL1) - return emitOpError("expects soft AIV-only syncall to use gm_workspace " - "+ ub_workspace only"); - return verifySyncAllTileWorkspace(getOperation(), getUbWorkspace(), - "ub_workspace", - pto::AddressSpace::VEC); - case pto::SyncCoreType::AICOnly: - if (hasUb || !hasL1) - return emitOpError("expects soft AIC-only syncall to use gm_workspace " - "+ l1_workspace only"); - return verifySyncAllTileWorkspace(getOperation(), 
getL1Workspace(), - "l1_workspace", - pto::AddressSpace::MAT); - case pto::SyncCoreType::Mix: - if (!hasUb || !hasL1) - return emitOpError("expects soft mixed syncall to use gm_workspace + " - "ub_workspace + l1_workspace"); - if (failed(verifySyncAllTileWorkspace(getOperation(), getUbWorkspace(), - "ub_workspace", - pto::AddressSpace::VEC))) - return failure(); - return verifySyncAllTileWorkspace(getOperation(), getL1Workspace(), - "l1_workspace", - pto::AddressSpace::MAT); - } - - llvm_unreachable("unhandled SyncCoreType"); -} - -LogicalResult TBroadcastOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - if (failed(verifyCommGlobalLike(*this, getSrc(), "src")) || - failed(verifyCommStagingTileLike(*this, getPing(), "ping")) || - failed(verifyCommPingPongSameType(*this, getPing(), getPong(), "ping", - "pong")) || - failed(verifyCommGlobalGroup(*this, getGroup(), "group"))) - return failure(); - if (getRoot() >= static_cast(getGroup().size())) - return emitOpError("expects root to index into group operands"); - if (getSrc().getType() != getGroup().front().getType()) - return emitOpError("expects src type to match group member type"); - if (getElemTy(getPing().getType()) != getElemTy(getSrc().getType())) - return emitOpError("expects staging tile element type to match src"); - return success(); -} - -LogicalResult CommTGatherOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - if (failed(verifyCommGlobalLike(*this, getDst(), "dst")) || - failed(verifyCommStagingTileLike(*this, getPing(), "ping")) || - failed(verifyCommPingPongSameType(*this, getPing(), getPong(), "ping", - "pong")) || - failed(verifyCommGlobalGroup(*this, getGroup(), "group"))) - return failure(); - if (getRoot() >= static_cast(getGroup().size())) - return emitOpError("expects root to index into group operands"); - if (getElemTy(getDst().getType()) != getElemTy(getGroup().front().getType())) - return 
emitOpError("expects dst element type to match group member type"); - if (getElemTy(getPing().getType()) != getElemTy(getDst().getType())) - return emitOpError("expects staging tile element type to match dst"); - return success(); -} - -LogicalResult CommTScatterOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - if (failed(verifyCommGlobalLike(*this, getSrc(), "src")) || - failed(verifyCommStagingTileLike(*this, getPing(), "ping")) || - failed(verifyCommPingPongSameType(*this, getPing(), getPong(), "ping", - "pong")) || - failed(verifyCommGlobalGroup(*this, getGroup(), "group"))) - return failure(); - if (getRoot() >= static_cast(getGroup().size())) - return emitOpError("expects root to index into group operands"); - if (getElemTy(getSrc().getType()) != getElemTy(getGroup().front().getType())) - return emitOpError("expects src element type to match group member type"); - if (getElemTy(getPing().getType()) != getElemTy(getSrc().getType())) - return emitOpError("expects staging tile element type to match src"); - return success(); -} - -LogicalResult TReduceOp::verify() { - if (shouldBypassDecodedMemrefVerifier(getOperation())) - return success(); - if (failed(verifyCommGlobalLike(*this, getDst(), "dst")) || - failed(verifyCommStagingTileLike(*this, getAcc(), "acc")) || - failed(verifyCommStagingTileLike(*this, getRecvPing(), "recv_ping")) || - failed(verifyCommPingPongSameType(*this, getRecvPing(), getRecvPong(), - "recv_ping", "recv_pong")) || - failed(verifyCommGlobalGroup(*this, getGroup(), "group"))) - return failure(); - if (getRoot() >= static_cast(getGroup().size())) - return emitOpError("expects root to index into group operands"); - if (getElemTy(getDst().getType()) != getElemTy(getGroup().front().getType())) - return emitOpError("expects dst element type to match group member type"); - if (getAcc().getType() != getRecvPing().getType()) - return emitOpError("expects acc and recv_ping to have identical types"); - if 
(getElemTy(getAcc().getType()) != getElemTy(getDst().getType())) - return emitOpError("expects accumulator/receive tiles to match dst element type"); - return success(); -} - -LogicalResult AicInitializePipeOp::verify() { - return verifyFrontendInitCommon(*this, FunctionKernelKind::Cube, "cube"); -} - -LogicalResult AivInitializePipeOp::verify() { - return verifyFrontendInitCommon(*this, FunctionKernelKind::Vector, "vector"); -} - -LogicalResult TAllocToAivOp::verify() { - if (failed(verifyFrontendSplitOp(getOperation(), FunctionKernelKind::Cube, - "cube", getId(), getSplit()))) - return failure(); - if (failed(verifyFrontendDataOpDirection(getOperation(), getId(), - /*expectC2V=*/true))) - return failure(); - return verifyFrontendTensorEntryMatchesInit(getOperation(), getId(), - getEntry().getType()); -} - -LogicalResult TAllocToAicOp::verify() { - if (failed(verifyFrontendSplitOp(getOperation(), FunctionKernelKind::Vector, - "vector", getId(), getSplit()))) - return failure(); - if (failed(verifyFrontendDataOpDirection(getOperation(), getId(), - /*expectC2V=*/false))) - return failure(); - return verifyFrontendTensorEntryMatchesInit(getOperation(), getId(), - getEntry().getType()); -} - -LogicalResult TPushToAivOp::verify() { - if (failed(verifyFrontendSplitOp(getOperation(), FunctionKernelKind::Cube, - "cube", getId(), getSplit()))) - return failure(); - if (failed(verifyFrontendDataOpDirection(getOperation(), getId(), - /*expectC2V=*/true))) - return failure(); - return verifyFrontendTensorEntryMatchesInit(getOperation(), getId(), - getTile().getType()); -} - -LogicalResult TPushToAicOp::verify() { - if (failed(verifyFrontendSplitOp(getOperation(), FunctionKernelKind::Vector, - "vector", getId(), getSplit()))) - return failure(); - if (failed(verifyFrontendDataOpDirection(getOperation(), getId(), - /*expectC2V=*/false))) - return failure(); - return verifyFrontendTensorEntryMatchesInit(getOperation(), getId(), - getTile().getType()); -} - -LogicalResult 
TPopFromAicOp::verify() { - return verifyFrontendPopOp(*this, FunctionKernelKind::Vector, "vector", - /*expectC2V=*/true); -} - -LogicalResult TPopFromAivOp::verify() { - return verifyFrontendPopOp(*this, FunctionKernelKind::Cube, "cube", - /*expectC2V=*/false); -} - -LogicalResult TFreeFromAicOp::verify() { - if (failed(verifyFrontendSplitOp(getOperation(), FunctionKernelKind::Vector, - "vector", getId(), getSplit()))) - return failure(); - if (failed(verifyFrontendDataOpDirection(getOperation(), getId(), - /*expectC2V=*/true))) - return failure(); - if (getEntry()) - return verifyFrontendTensorEntryMatchesInit(getOperation(), getId(), - getEntry().getType()); - return success(); -} - -LogicalResult TFreeFromAivOp::verify() { - if (failed(verifyFrontendSplitOp(getOperation(), FunctionKernelKind::Cube, - "cube", getId(), getSplit()))) - return failure(); - if (failed(verifyFrontendDataOpDirection(getOperation(), getId(), - /*expectC2V=*/false))) - return failure(); - if (getEntry()) - return verifyFrontendTensorEntryMatchesInit(getOperation(), getId(), - getEntry().getType()); - return success(); -} - -LogicalResult InitializeL2G2LPipeOp::verify() { - if (failed(verifyPipeShape(getOperation(), getDirMask(), getSlotSize(), - getSlotNum(), - getFlagBaseAttr() - ? 
std::optional(getFlagBaseAttr().getInt()) - : std::nullopt))) - return failure(); - - if (!getLocalAddr()) { - if (getPeerLocalAddr()) - return emitOpError("'peer_local_addr' requires 'local_addr'"); - if (getLocalSlotNumAttr()) - return emitOpError( - "'local_slot_num' is only allowed when 'local_addr' is present"); - return success(); - } - - if (auto localSlotNumAttr = getLocalSlotNumAttr()) { - int32_t localSlotNum = localSlotNumAttr.getInt(); - if (localSlotNum <= 0) - return emitOpError("expects 'local_slot_num' to be greater than 0"); - if (static_cast(localSlotNum) > getSlotNum()) - return emitOpError( - "expects 'local_slot_num' to be less than or equal to slot_num"); - } - - if (getDirMask() == 3 && !getPeerLocalAddr()) - return emitOpError("expects 'peer_local_addr' when dir_mask is 3"); - if (getDirMask() != 3 && getPeerLocalAddr()) - return emitOpError("'peer_local_addr' is only allowed when dir_mask is 3"); - return success(); -} - -LogicalResult InitializeL2LPipeOp::verify() { - if (failed(verifyPipeShape(getOperation(), getDirMask(), getSlotSize(), - getSlotNum(), - getFlagBaseAttr() - ? 
std::optional(getFlagBaseAttr().getInt()) - : std::nullopt))) - return failure(); - - if (getDirMask() == 3 && !getPeerLocalAddr()) - return emitOpError("expects 'peer_local_addr' when dir_mask is 3"); - if (getDirMask() != 3 && getPeerLocalAddr()) - return emitOpError("'peer_local_addr' is only allowed when dir_mask is 3"); - return success(); -} - -LogicalResult TPushOp::verify() { - if (!isInsideSectionOrAttributedKernel(getOperation())) - return emitOpError("must be inside pto.section.cube/vector or a kernel_kind function"); - if (failed(verifyPipeHandleProducer(getOperation(), getPipeHandle()))) - return failure(); - if (failed(verifySplitAttr(getOperation(), getSplit()))) - return failure(); - if (failed(verifyTensorEntryMatchesInternalPipeInit( - getOperation(), getPipeHandle(), getTile().getType()))) - return failure(); - if (!isa(getTile().getType()) && - getPipe() == pto::PIPE::PIPE_UNASSIGNED) - return emitOpError("tile type must map to a supported producer pipe"); - return success(); -} - -LogicalResult TAllocOp::verify() { - if (!isInsideSectionOrAttributedKernel(getOperation())) - return emitOpError("must be inside pto.section.cube/vector or a kernel_kind function"); - if (failed(verifyPipeHandleProducer(getOperation(), getPipeHandle()))) - return failure(); - if (failed(verifyTensorEntryMatchesInternalPipeInit( - getOperation(), getPipeHandle(), getEntry().getType()))) - return failure(); - return verifySplitAttr(getOperation(), getSplit()); -} - -LogicalResult TPopOp::verify() { - if (!isInsideSectionOrAttributedKernel(getOperation())) - return emitOpError("must be inside pto.section.cube/vector or a kernel_kind function"); - if (failed(verifyPipeHandleProducer(getOperation(), getPipeHandle()))) - return failure(); - if (failed(verifySplitAttr(getOperation(), getSplit()))) - return failure(); - if (failed(verifyTensorEntryMatchesInternalPipeInit( - getOperation(), getPipeHandle(), getTile().getType()))) - return failure(); - if 
(!isa(getTile().getType()) && - getPipe() == pto::PIPE::PIPE_UNASSIGNED) - return emitOpError( - "tile type and target arch must map to a supported consumer pipe"); - return success(); -} - -LogicalResult TFreeOp::verify() { - if (!isInsideSectionOrAttributedKernel(getOperation())) - return emitOpError("must be inside pto.section.cube/vector or a kernel_kind function"); - if (failed(verifyPipeHandleProducer(getOperation(), getPipeHandle()))) - return failure(); - if (getEntry() && - failed(verifyTensorEntryMatchesInternalPipeInit( - getOperation(), getPipeHandle(), getEntry().getType()))) - return failure(); - return verifySplitAttr(getOperation(), getSplit()); -} - -ParseResult TFreeOp::parse(OpAsmParser &parser, OperationState &result) { - OpAsmParser::UnresolvedOperand first; - OpAsmParser::UnresolvedOperand pipe; - Type firstTy; - Type pipeTy; - bool hasEntry = false; - - if (parser.parseLParen() || parser.parseOperand(first)) - return failure(); - - if (succeeded(parser.parseOptionalComma())) { - hasEntry = true; - if (parser.parseOperand(pipe) || parser.parseColonType(firstTy) || - parser.parseComma() || parser.parseType(pipeTy) || parser.parseRParen()) - return failure(); - } else { - if (parser.parseColonType(pipeTy) || parser.parseRParen()) - return failure(); - pipe = first; - } - - NamedAttrList attrs; - if (parser.parseLBrace() || parser.parseKeyword("split") || - parser.parseEqual()) - return failure(); - IntegerAttr splitAttr; - if (parser.parseAttribute(splitAttr, parser.getBuilder().getI8Type(), - "split", attrs) || - parser.parseRBrace() || parser.parseOptionalAttrDict(attrs)) - return failure(); - - result.addAttributes(attrs); - if (hasEntry && - parser.resolveOperand(first, firstTy, result.operands)) - return failure(); - if (parser.resolveOperand(pipe, pipeTy, result.operands)) - return failure(); - return success(); -} - -void TFreeOp::print(OpAsmPrinter &p) { - p << "("; - if (getEntry()) { - p << getEntry() << ", " << getPipeHandle() << " : 
" - << getEntry().getType() << ", " << getPipeHandle().getType(); - } else { - p << getPipeHandle() << " : " << getPipeHandle().getType(); - } - p << ") {split = " << static_cast(getSplit()) << "}"; - p.printOptionalAttrDict((*this)->getAttrs(), - /*elidedAttrs=*/{"split"}); -} - -void BuildAsyncSessionOp::getEffects( - SmallVectorImpl> - &effects) { - addEffect(effects, &getScratchMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getWorkspaceMutable(), MemoryEffects::Read::get()); - addEffect(effects, getOperation()->getOpResult(0), MemoryEffects::Write::get()); -} - -void TPutAsyncOp::getEffects( - SmallVectorImpl> - &effects) { - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); - addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getSessionMutable(), MemoryEffects::Read::get()); - addEffect(effects, getOperation()->getOpResult(0), MemoryEffects::Write::get()); -} - -void TGetAsyncOp::getEffects( - SmallVectorImpl> - &effects) { - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); - addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getSessionMutable(), MemoryEffects::Read::get()); - addEffect(effects, getOperation()->getOpResult(0), MemoryEffects::Write::get()); -} - -void TPutOp::getEffects( - SmallVectorImpl> - &effects) { - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); - addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getPingMutable(), MemoryEffects::Read::get()); -} - -void TGetOp::getEffects( - SmallVectorImpl> - &effects) { - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); - addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getPingMutable(), MemoryEffects::Read::get()); -} - -void TNotifyOp::getEffects( - SmallVectorImpl> - &effects) { - addEffect(effects, &getSignalMutable(), MemoryEffects::Write::get()); - 
addEffect(effects, &getValueMutable(), MemoryEffects::Read::get()); -} - -void TWaitOp::getEffects( - SmallVectorImpl> - &effects) { - addEffect(effects, &getSignalMutable(), MemoryEffects::Write::get()); - addEffect(effects, &getCmpValueMutable(), MemoryEffects::Read::get()); -} - -void TTestOp::getEffects( - SmallVectorImpl> - &effects) { - addEffect(effects, &getSignalMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getCmpValueMutable(), MemoryEffects::Read::get()); - addEffect(effects, getOperation()->getOpResult(0), MemoryEffects::Write::get()); -} - -void TBroadcastOp::getEffects( - SmallVectorImpl> - &effects) { - addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getPingMutable(), MemoryEffects::Write::get()); - if (getPong()) { - auto pongRange = getPongMutable(); - if (auto it = pongRange.begin(); it != pongRange.end()) - addEffect(effects, &*it, MemoryEffects::Write::get()); - } -} - -void CommTGatherOp::getEffects( - SmallVectorImpl> - &effects) { - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); - addEffect(effects, &getPingMutable(), MemoryEffects::Read::get()); -} - -void CommTScatterOp::getEffects( - SmallVectorImpl> - &effects) { - addEffect(effects, &getSrcMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getPingMutable(), MemoryEffects::Write::get()); - if (getPong()) { - auto pongRange = getPongMutable(); - if (auto it = pongRange.begin(); it != pongRange.end()) - addEffect(effects, &*it, MemoryEffects::Write::get()); - } -} - -void TReduceOp::getEffects( - SmallVectorImpl> - &effects) { - addEffect(effects, &getDstMutable(), MemoryEffects::Write::get()); - addEffect(effects, &getAccMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getAccMutable(), MemoryEffects::Write::get()); - addEffect(effects, &getRecvPingMutable(), MemoryEffects::Write::get()); -} - -void WaitAsyncEventOp::getEffects( - SmallVectorImpl> - &effects) { - addEffect(effects, 
&getEventMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getSessionMutable(), MemoryEffects::Read::get()); - addEffect(effects, getOperation()->getOpResult(0), MemoryEffects::Write::get()); -} - -void TestAsyncEventOp::getEffects( - SmallVectorImpl> - &effects) { - addEffect(effects, &getEventMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getSessionMutable(), MemoryEffects::Read::get()); - addEffect(effects, getOperation()->getOpResult(0), MemoryEffects::Write::get()); -} - -void InitializeL2G2LPipeOp::getEffects( - SmallVectorImpl> - &effects) { - addEffect(effects, &getGmAddrMutable(), MemoryEffects::Read::get()); - auto localAddr = getLocalAddrMutable(); - if (!localAddr.empty()) - addEffect(effects, &*localAddr.begin(), MemoryEffects::Read::get()); - auto peerLocalAddr = getPeerLocalAddrMutable(); - if (!peerLocalAddr.empty()) - addEffect(effects, &*peerLocalAddr.begin(), MemoryEffects::Read::get()); - addEffect(effects, getOperation()->getOpResult(0), MemoryEffects::Write::get()); -} - -void InitializeL2LPipeOp::getEffects( - SmallVectorImpl> - &effects) { - addEffect(effects, &getLocalAddrMutable(), MemoryEffects::Read::get()); - addEffect(effects, getOperation()->getOpResult(0), MemoryEffects::Write::get()); -} - -void TPushOp::getEffects( - SmallVectorImpl> - &effects) { - addEffect(effects, &getTileMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getPipeHandleMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getPipeHandleMutable(), MemoryEffects::Write::get()); -} - -void TAllocOp::getEffects( - SmallVectorImpl> - &effects) { - addEffect(effects, &getEntryMutable(), MemoryEffects::Write::get()); - addEffect(effects, &getPipeHandleMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getPipeHandleMutable(), MemoryEffects::Write::get()); -} - -void TPopOp::getEffects( - SmallVectorImpl> - &effects) { - addEffect(effects, &getPipeHandleMutable(), MemoryEffects::Read::get()); - addEffect(effects, 
&getPipeHandleMutable(), MemoryEffects::Write::get()); - addEffect(effects, &getTileMutable(), MemoryEffects::Write::get()); -} - -void TFreeOp::getEffects( - SmallVectorImpl> - &effects) { - auto entry = getEntryMutable(); - if (!entry.empty()) - addEffect(effects, &*entry.begin(), MemoryEffects::Read::get()); - addEffect(effects, &getPipeHandleMutable(), MemoryEffects::Read::get()); - addEffect(effects, &getPipeHandleMutable(), MemoryEffects::Write::get()); -} - -// [Include 必须放在最后] -#include "PTO/IR/PTOInterfaces.cpp.inc" -#define GET_OP_CLASSES -#include "PTO/IR/PTOOps.cpp.inc" diff --git a/lib/PTO/Transforms/GraphSyncSolver/SyncSolver.cpp b/lib/PTO/Transforms/GraphSyncSolver/SyncSolver.cpp index 0e1e75998..23a4032a6 100644 --- a/lib/PTO/Transforms/GraphSyncSolver/SyncSolver.cpp +++ b/lib/PTO/Transforms/GraphSyncSolver/SyncSolver.cpp @@ -6,5 +6,2571 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. 
+//===--------- SyncSolver.cpp ------- Graph Sync Solver -------------------===// +//===----------------------------------------------------------------------===// -#include "SyncSolver.def" +#include "PTO/Transforms/GraphSyncSolver/SyncSolver.h" +#include "PTO/Transforms/GraphSyncSolver/GraphSolver.h" +#include "PTO/Transforms/GraphSyncSolver/MemInfo.h" +#include "PTO/Transforms/GraphSyncSolver/SyncSolverIR.h" +#include "PTO/Transforms/GraphSyncSolver/Utility.h" + +#include "PTO/IR/PTO.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Value.h" +#include "mlir/Interfaces/LoopLikeInterface.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/LogicalResult.h" +#include +#include +#include +#include +#include +#include +#include + +#define DEBUG_TYPE "PTO-gss-solver" + +using namespace mlir; +using namespace pto::syncsolver; + +// Reset per-pass bookkeeping to start fresh. 
+void Solver::reset(bool resetEventIdRanOutOpts) { + if (resetEventIdRanOutOpts) { + reusePairs.clear(); + disabledMultiEventIdPairs.clear(); + backwardSyncEventsAfterMerge.clear(); + moveBackwardSyncPairsToOutmostLoop = false; + dontMoveBackwardSyncPairsToOutmostLoop = false; + } + skipOcc.clear(); + syncedPairs.clear(); + processedOccPairs.clear(); + chosenConflictedPairs.clear(); + scopeOccChosenConflicts.clear(); + scopeOccPairChosenConflicts.clear(); + backwardSyncEvents.clear(); + replacedWithReusableSyncedPairs.clear(); + reusedPairs.clear(); + barrierAllPairs.clear(); + insertedBarrierAllBefore.clear(); + eventIdSolver.clear(); + resetUnitFlag(); +} + +void Solver::resetUnitFlag() { + for (auto *rwOp : unitFlagFeaturedOps) { + rwOp->mergedUnitFlagInfo.reset(); + for (auto *occ : opAllOccurrences[rwOp]) { + occ->unitFlagInfo.reset(); + } + } +} + +// Helpers to find first/last iteration occurrences relative to parent +// occurrences. +Occurrence *Solver::getFirstIterOcc(Occurrence *occ, Occurrence *parOcc) { + assert(occ != nullptr && parOcc != nullptr); + if (parOcc->depth + 1 < occ->depth) { + auto *newParOcc = getFirstIterOcc( + occ->getNthParent(occ->depth - parOcc->depth - 1), parOcc); + return getFirstIterOcc(occ, newParOcc); + } + auto *it = + std::find_if(parOcc->childOccs.begin(), parOcc->childOccs.end(), + [occ](Occurrence *curOcc) { return occ->op == curOcc->op; }); + assert(it != parOcc->childOccs.end()); + return *it; +} + +Occurrence *Solver::getLastIterOcc(Occurrence *occ, Occurrence *parOcc) { + assert(occ != nullptr && parOcc != nullptr); + if (parOcc->depth + 1 < occ->depth) { + auto *newParOcc = getLastIterOcc( + occ->getNthParent(occ->depth - parOcc->depth - 1), parOcc); + return getLastIterOcc(occ, newParOcc); + } + auto it = + std::find_if(parOcc->childOccs.rbegin(), parOcc->childOccs.rend(), + [occ](Occurrence *curOcc) { return occ->op == curOcc->op; }); + assert(it != parOcc->childOccs.rend()); + return *it; +} + +bool 
Solver::checkSkipCrossCorePair(Occurrence *occ1, Occurrence *occ2) { + if (!options.isCrossCoreMode()) { + return false; + } + auto *rwOp1 = llvm::dyn_cast(occ1->op); + auto *rwOp2 = llvm::dyn_cast(occ2->op); + assert(rwOp1 != nullptr && rwOp2 != nullptr); + assert(rwOp1->coreType != pto::TCoreType::CUBE_OR_VECTOR); + assert(rwOp2->coreType != pto::TCoreType::CUBE_OR_VECTOR); + if (rwOp1->coreType == rwOp2->coreType) { + return true; + } + if (rwOp1->coreType == pto::TCoreType::CUBE_AND_VECTOR) { + return true; + } + return false; +} + +bool Solver::checkSkipParallelLoop(Occurrence *occ1, Occurrence *occ2) { + if (!isBackwardSync(occ1, occ2)) { + return false; + } + auto [parOcc1, parOcc2] = Occurrence::getLCAPair(occ1, occ2); + assert(parOcc1 != nullptr && parOcc2 != nullptr); + auto *parentLCALoopOcc = Occurrence::getParentloop(parOcc1); + assert(parentLCALoopOcc != nullptr); + auto *parentLCALoopOp = llvm::cast(parentLCALoopOcc->op); + return parentLCALoopOp->isParallel; +} + +// Check whether occurrences belong to impossible (if-else) pairing. +bool Solver::checkImpossibleOccPair(Occurrence *occ1, Occurrence *occ2) { + assert(occ1 != nullptr && occ2 != nullptr); + if (occ1->op == occ2->op) { + return false; + } + auto [parOcc1, parOcc2] = Occurrence::getLCAPair(occ1, occ2); + assert(parOcc1 != nullptr && parOcc2 != nullptr); + bool isIfElseSituation = + parOcc1->parentOcc != nullptr && + parOcc1->parentOcc == parOcc2->parentOcc && + llvm::isa_and_present(parOcc1->parentOcc->op); + return isIfElseSituation; +} + +// Detect whether occ1 and occ2 have already been covered by an earlier sync. 
+bool Solver::checkAlreadySynced(Occurrence *occ1, Occurrence *occ2) { + assert(occ1 != nullptr && occ2 != nullptr); + assert(occ1->op != nullptr && occ2->op != nullptr); + + auto [parOcc1, parOcc2] = Occurrence::getLCAPair(occ1, occ2); + assert(parOcc1 != nullptr && parOcc2 != nullptr); + assert(parOcc1->parentOcc != nullptr && parOcc2->parentOcc != nullptr); + + auto [parOp1, parOp2] = OperationBase::getLCAPair(occ1->op, occ2->op); + assert(parOp1 != nullptr && parOp2 != nullptr); + assert(parOp1->parentOp != nullptr && parOp2->parentOp != nullptr); + + auto *parentLoop = OperationBase::getParentloop(parOcc1->op); + auto *curLoop = OperationBase::getParentloop(parOp1); + if (parentLoop == nullptr || parentLoop == curLoop) { + return false; + } + + assert(curLoop != nullptr); + assert(parentLoop->isProperAncestor(curLoop)); + while (curLoop != parentLoop) { + if (!llvm::cast(curLoop)->isParallel) { + return true; + } + curLoop = OperationBase::getParentloop(curLoop); + assert(curLoop != nullptr); + } + return false; +} + +// Unit-flag reuse check between two RWOperations. 
+bool Solver::checkAlreadySyncedWithUnitFlag(Occurrence *occ1, + Occurrence *occ2) { + assert(occ1 != nullptr && occ2 != nullptr); + if (!options.enableUnitFlagFeature) { + return false; + } + if (!occ1->hasUnitFlagFeat || !occ2->hasUnitFlagFeat) { + return false; + } + llvm::DenseSet visited; + DEBUG_WITH_TYPE("gss-sync-solver-check-unit-flag", { + llvm::dbgs() << "unit-flag-step: " << occ1->syncIrIndex << ' ' + << occ1->op->str(0, false) << "\n"; + }); + Occurrence *curOcc = occ1->unitFlagInfo.linkedElementAsSet; + while (curOcc != nullptr) { + DEBUG_WITH_TYPE("gss-sync-solver-check-unit-flag", { + llvm::dbgs() << "unit-flag-step: " << curOcc->syncIrIndex << ' ' + << curOcc->op->str(0, false) << "\n"; + }); + auto [it, isInserted] = visited.insert(curOcc); + if (!isInserted) { + break; + } + if (curOcc == occ2) { + return true; + } + curOcc = curOcc->unitFlagInfo.linkedElementAsSet; + } + return false; +} + +bool Solver::ignoreMemoryConflict(RWOperation *rwOp1, RWOperation *rwOp2, + const MemInfo &memInfo1, + const MemInfo &memInfo2) { + if (options.isIntraCoreMode()) { + if (memInfo1.isWorkSpace && memInfo2.isWorkSpace) { + if (options.intraCoreIgnoreWorkSpaceFunctionArguments) { + return true; + } + } + } + return false; +} + +bool Solver::checkMemInfoConflict(RWOperation *rwOp1, RWOperation *rwOp2, + const MemInfo &memInfo1, + const MemInfo &memInfo2, + std::optional lcmLen, + std::optional eventIdNum) { + if (ignoreMemoryConflict(rwOp1, rwOp2, memInfo1, memInfo2)) { + return false; + } + return MemInfo::checkConflict(memInfo1, memInfo2, lcmLen, eventIdNum); +} + +bool Solver::checkMemInfoConflict( + RWOperation *rwOp1, RWOperation *rwOp2, + const llvm::SmallVector &memInfoList1, + const llvm::SmallVector &memInfoList2, + std::optional lcmLen, std::optional eventIdNum) { + for (auto &memInfo1 : memInfoList1) { + for (auto &memInfo2 : memInfoList2) { + if (checkMemInfoConflict(rwOp1, rwOp2, memInfo1, memInfo2, lcmLen, + eventIdNum)) { + return true; + } + } + } 
+ return false; +} + +// High-level wrapper computing pipe pairs that represent memory conflicts +// between two RW ops. +llvm::SmallVector> +Solver::checkMemoryConflicts(RWOperation *rwOp1, RWOperation *rwOp2) { + assert(rwOp1 != nullptr && rwOp2 != nullptr); + auto [it, isInserted] = checkMemoryConflictsMem.insert({{rwOp1, rwOp2}, {}}); + if (!isInserted) { + return it->second; + } + auto coreSrc = rwOp1->coreType; + auto coreDst = rwOp2->coreType; + if (options.isCrossCoreMode()) { + if (coreDst == pto::TCoreType::CUBE_AND_VECTOR) { + coreDst = (coreSrc == pto::TCoreType::VECTOR) ? pto::TCoreType::CUBE + : pto::TCoreType::VECTOR; + } + assert(coreSrc == pto::TCoreType::VECTOR || + coreSrc == pto::TCoreType::CUBE); + assert(coreDst == pto::TCoreType::VECTOR || + coreDst == pto::TCoreType::CUBE); + } + llvm::SetVector> collectedConflictsSet; + if (checkMemInfoConflict(rwOp1, rwOp2, rwOp1->readMemInfo, + rwOp2->writeMemInfo)) { + collectedConflictsSet.insert({CorePipeInfo(coreSrc, rwOp1->pipeRead), + CorePipeInfo(coreDst, rwOp2->pipeWrite)}); + } + if (checkMemInfoConflict(rwOp1, rwOp2, rwOp1->writeMemInfo, + rwOp2->readMemInfo)) { + collectedConflictsSet.insert({CorePipeInfo(coreSrc, rwOp1->pipeWrite), + CorePipeInfo(coreDst, rwOp2->pipeRead)}); + } + if (checkMemInfoConflict(rwOp1, rwOp2, rwOp1->writeMemInfo, + rwOp2->writeMemInfo)) { + collectedConflictsSet.insert({CorePipeInfo(coreSrc, rwOp1->pipeWrite), + CorePipeInfo(coreDst, rwOp2->pipeWrite)}); + } + llvm::SmallVector> collectedConflicts( + collectedConflictsSet.begin(), collectedConflictsSet.end()); + return it->second = collectedConflicts; +} + +bool Solver::checkMemoryConflictBetweenOccExclusive( + Occurrence *occ1, Occurrence *occ2, + std::function filter) { + assert(occ1 != nullptr && occ2 != nullptr); + auto *rwOp1 = llvm::dyn_cast_if_present(occ1->op); + auto *rwOp2 = llvm::dyn_cast_if_present(occ2->op); + assert(rwOp1 != nullptr && rwOp2 != nullptr); + for (int i = occ1->syncIrEndIndex; i < 
occ2->syncIrIndex; i++) { + if (auto *otherOp = llvm::dyn_cast_if_present(syncIr[i]->op)) { + if (!filter(otherOp)) { + continue; + } + if (!checkMemoryConflicts(rwOp1, otherOp).empty()) { + return true; + } + if (!checkMemoryConflicts(rwOp2, otherOp).empty()) { + return true; + } + } + } + return false; +} + +std::optional +Solver::getMultiBufferLoop(RWOperation *rwOp1, RWOperation *rwOp2, + const llvm::SmallVector &memInfoList1, + const llvm::SmallVector &memInfoList2) { + std::optional multibufferLoop; + for (auto &memInfo1 : memInfoList1) { + for (auto &memInfo2 : memInfoList2) { + if (checkMemInfoConflict(rwOp1, rwOp2, memInfo1, memInfo2)) { + if (!memInfo1.pointerLikeInfo.has_value() || + !memInfo2.pointerLikeInfo.has_value()) { + return {}; + } + auto multibufferLoop1 = memInfo1.pointerLikeInfo->parentLoop; + auto multibufferLoop2 = memInfo2.pointerLikeInfo->parentLoop; + if (multibufferLoop1 == nullptr || + multibufferLoop1 != multibufferLoop2) { + return {}; + } + if (multibufferLoop.has_value() && + multibufferLoop.value() != multibufferLoop1) { + return {}; + } + multibufferLoop = multibufferLoop1; + } + } + } + return multibufferLoop; +} + +std::optional +Solver::getMultiBufferLoop(RWOperation *rwOp1, RWOperation *rwOp2) { + std::optional multibufferLoop; + if (checkMemInfoConflict(rwOp1, rwOp2, rwOp1->readMemInfo, + rwOp2->writeMemInfo)) { + auto curMultibufferLoop = getMultiBufferLoop( + rwOp1, rwOp2, rwOp1->readMemInfo, rwOp2->writeMemInfo); + if (multibufferLoop.has_value() && + multibufferLoop.value() != curMultibufferLoop) { + return {}; + } + multibufferLoop = curMultibufferLoop; + } + if (checkMemInfoConflict(rwOp1, rwOp2, rwOp1->writeMemInfo, + rwOp2->readMemInfo)) { + auto curMultibufferLoop = getMultiBufferLoop( + rwOp1, rwOp2, rwOp1->writeMemInfo, rwOp2->readMemInfo); + if (multibufferLoop.has_value() && + multibufferLoop.value() != curMultibufferLoop) { + return {}; + } + multibufferLoop = curMultibufferLoop; + } + if 
(checkMemInfoConflict(rwOp1, rwOp2, rwOp1->writeMemInfo, + rwOp2->writeMemInfo)) { + auto curMultibufferLoop = getMultiBufferLoop( + rwOp1, rwOp2, rwOp1->writeMemInfo, rwOp2->writeMemInfo); + if (multibufferLoop.has_value() && + multibufferLoop.value() != curMultibufferLoop) { + return {}; + } + multibufferLoop = curMultibufferLoop; + } + return multibufferLoop; +} + +std::optional +Solver::getMultiBufferEventIdInfo(Occurrence *occ1, Occurrence *occ2, + RWOperation *rwOp1, RWOperation *rwOp2) { + assert(rwOp1 != nullptr && rwOp2 != nullptr); + + int64_t lcm = 1; + int64_t minWriteSize = LONG_MAX; + LoopLikeOpInterface multibufferLoop{nullptr}; + + if (options.isTestMode()) { + auto *parLoop1 = occ1->getParentOfType(); + auto *parLoop2 = occ2->getParentOfType(); + if (!parLoop1 || parLoop1 != parLoop2) { + return {}; + } + auto [setOcc, waitOcc] = getSetWaitOcc(occ1, occ2); + if (!parLoop1->isProperAncestor(setOcc) || + !parLoop1->isProperAncestor(waitOcc)) { + return {}; + } + } else { + auto multibufferLoopOpt = getMultiBufferLoop(rwOp1, rwOp2); + if (!multibufferLoopOpt.has_value() || !multibufferLoopOpt.value()) { + return {}; + } + multibufferLoop = multibufferLoopOpt.value(); + assert(multibufferLoop != nullptr); + auto [setOcc, waitOcc] = getSetWaitOcc(occ1, occ2); + if (!setOcc->getParentWithOp(multibufferLoop, + /*assertExists=*/false) || + !waitOcc->getParentWithOp(multibufferLoop, + /*assertExists=*/false)) { + return {}; + } + } + + for (auto &memInfo1 : rwOp1->readMemInfo) { + for (auto &memInfo2 : rwOp2->writeMemInfo) { + if (checkMemInfoConflict(rwOp1, rwOp2, memInfo1, memInfo2)) { + int64_t curLcm = std::lcm(memInfo1.getSz(), memInfo2.getSz()); + lcm = std::lcm(lcm, curLcm); + minWriteSize = std::min(minWriteSize, memInfo2.getSz()); + } + } + } + for (auto &memInfo1 : rwOp1->writeMemInfo) { + for (auto &memInfo2 : rwOp2->readMemInfo) { + if (checkMemInfoConflict(rwOp1, rwOp2, memInfo1, memInfo2)) { + int64_t curLcm = std::lcm(memInfo1.getSz(), 
memInfo2.getSz()); + lcm = std::lcm(lcm, curLcm); + minWriteSize = std::min(minWriteSize, memInfo1.getSz()); + } + } + } + for (auto &memInfo1 : rwOp1->writeMemInfo) { + for (auto &memInfo2 : rwOp2->writeMemInfo) { + if (checkMemInfoConflict(rwOp1, rwOp2, memInfo1, memInfo2)) { + int64_t curLcm = std::lcm(memInfo1.getSz(), memInfo2.getSz()); + lcm = std::lcm(lcm, curLcm); + minWriteSize = std::min(minWriteSize, memInfo1.getSz()); + minWriteSize = std::min(minWriteSize, memInfo2.getSz()); + } + } + } + + // In case no write sizes were positive. + if (minWriteSize == LONG_MAX) { + minWriteSize = 1; + return {}; + } + + int64_t eventIdNum = minWriteSize; + for (; eventIdNum >= 1; eventIdNum--) { + // llvm::dbgs() << "checking event-id-num: " << eventIdNum << '\n'; + int64_t curLcm = std::lcm(lcm, eventIdNum); + bool okRW = !checkMemInfoConflict(rwOp1, rwOp2, rwOp1->readMemInfo, + rwOp2->writeMemInfo, curLcm, eventIdNum); + bool okWR = !checkMemInfoConflict(rwOp1, rwOp2, rwOp1->writeMemInfo, + rwOp2->readMemInfo, curLcm, eventIdNum); + bool okWW = !checkMemInfoConflict(rwOp1, rwOp2, rwOp1->writeMemInfo, + rwOp2->writeMemInfo, curLcm, eventIdNum); + if (okRW && okWR && okWW) { + break; + } + } + if (eventIdNum <= 1) { + return {}; + } + EventIdInfo eventIdInfo(eventIdNum); + eventIdInfo.multibufferLoop = multibufferLoop; + return eventIdInfo; +} + +std::optional +Solver::checkMultiBufferEventIdInfo(Occurrence *occ1, Occurrence *occ2, + RWOperation *rwOp1, RWOperation *rwOp2) { + assert(rwOp1 != nullptr && rwOp2 != nullptr); + if (!options.isTestMode()) { + if (!checkAllParentLoopsAreForLoops(rwOp1->op) || + !checkAllParentLoopsAreForLoops(rwOp2->op)) { + return {}; + } + } + if (auto eventIdInfo = getMultiBufferEventIdInfo(occ1, occ2, rwOp1, rwOp2)) { + return eventIdInfo; + } + return {}; +} + +std::optional +Solver::checkCVMultiBufferUnrollEventIdInfo(RWOperation *rwOp1, + RWOperation *rwOp2) { + assert(rwOp1 != nullptr && rwOp2 != nullptr); + if 
(!options.isCrossCoreMode()) { + return {}; + } + auto *parentLoop1 = rwOp1->getParentOfType(); + auto *parentLoop2 = rwOp2->getParentOfType(); + while (parentLoop1 != nullptr && !parentLoop1->multibufferUnrollNum) { + parentLoop1 = parentLoop1->getParentOfType(); + } + while (parentLoop2 != nullptr && !parentLoop2->multibufferUnrollNum) { + parentLoop2 = parentLoop2->getParentOfType(); + } + if (!parentLoop1 || !parentLoop2) { + return {}; + } + if (auto *parCond1 = rwOp1->getParentOfType()) { + if (!parCond1->isProperAncestor(rwOp2)) { + return {}; + } + } + if (auto *parCond2 = rwOp2->getParentOfType()) { + if (!parCond2->isProperAncestor(rwOp1)) { + return {}; + } + } + assert(parentLoop1->multibufferUnrollNum.value() == + parentLoop2->multibufferUnrollNum.value()); + EventIdInfo eventIdInfo; + eventIdInfo.eventIdNum = parentLoop1->multibufferUnrollNum.value(); + eventIdInfo.multibufferUnrollLoop1 = + cast(parentLoop1->op); + eventIdInfo.multibufferUnrollLoop2 = + cast(parentLoop2->op); + return eventIdInfo; +} + +std::optional +Solver::checkCVMultiBufferPreloadEventIdInfo(RWOperation *rwOp1, + RWOperation *rwOp2) { + assert(rwOp1 != nullptr && rwOp2 != nullptr); + if (!options.isCrossCoreMode()) { + return {}; + } + auto *parentScope1 = rwOp1->getParentOfType(); + auto *parentScope2 = rwOp2->getParentOfType(); + while (parentScope1 != nullptr && !parentScope1->maxPreloadNum.has_value()) { + parentScope1 = parentScope1->getParentOfType(); + } + while (parentScope2 != nullptr && !parentScope2->maxPreloadNum.has_value()) { + parentScope2 = parentScope2->getParentOfType(); + } + if (!parentScope1 || !parentScope2) { + return {}; + } + if (auto *parCond1 = rwOp1->getParentOfType()) { + if (!parCond1->isProperAncestor(rwOp2)) { + return {}; + } + } + if (auto *parCond2 = rwOp2->getParentOfType()) { + if (!parCond2->isProperAncestor(rwOp1)) { + return {}; + } + } + + auto *parentLoop1 = parentScope1->getParentOfType(); + auto *parentLoop2 = 
parentScope2->getParentOfType(); + if (parentLoop1 == nullptr || parentLoop1 != parentLoop2) { + return {}; + } + + assert(parentScope1->preloadNum.has_value()); + assert(parentScope2->preloadNum.has_value()); + assert(parentScope1->maxPreloadNum.value() == + parentScope2->maxPreloadNum.value()); + + auto parentForLoop = llvm::dyn_cast_if_present(parentLoop1->op); + assert(parentForLoop != nullptr); + + EventIdInfo eventIdInfo; + eventIdInfo.eventIdNum = parentScope1->maxPreloadNum.value(); + eventIdInfo.preloadOffset1 = parentScope1->maxPreloadNum.value() - + parentScope1->preloadNum.value() - 1; + eventIdInfo.preloadOffset2 = parentScope2->maxPreloadNum.value() - + parentScope2->preloadNum.value() - 1; + eventIdInfo.multibufferLoop = parentForLoop; + return eventIdInfo; +} + +// Determine required event id count and optional multibuffer loop parent for +// occurrences. +EventIdInfo Solver::getEventIdInfo(Occurrence *occ1, Occurrence *occ2, + RWOperation *rwOp1, RWOperation *rwOp2, + CorePipeInfo corePipeSrc, + CorePipeInfo corePipeDst) { + assert(occ1 != nullptr && occ2 != nullptr); + assert(rwOp1 != nullptr && rwOp2 != nullptr); + EventIdInfo singleEventId(1); + if (!isBackwardSync(occ1, occ2)) { + return singleEventId; + } + if (auto eventIdInfo = checkCVMultiBufferUnrollEventIdInfo(rwOp1, rwOp2)) { + return eventIdInfo.value(); + } + if (auto eventIdInfo = checkCVMultiBufferPreloadEventIdInfo(rwOp1, rwOp2)) { + return eventIdInfo.value(); + } + if (auto eventIdInfo = + checkMultiBufferEventIdInfo(occ1, occ2, rwOp1, rwOp2)) { + return eventIdInfo.value(); + } + return singleEventId; +} + +// Graph-based check to determine if adding a sync between occ1 and occ2 would +// block progress. Uses GraphSolver (Dijkstra) to estimate minimal reachable +// index. 
+bool Solver::checkGraphConflict( + Occurrence *occ1, Occurrence *occ2, CorePipeInfo corePipeSrc, + CorePipeInfo corePipeDst, EventIdInfo eventIdInfo, + std::optional startIndex, std::optional endIndex, + const llvm::SmallVector &extraConflictPairs, + const llvm::SmallVector &ignoreConflictPairs) { + assert(occ1 != nullptr && occ2 != nullptr); + if (!startIndex.has_value()) { + startIndex = occ1->endIndex; + } + if (!endIndex.has_value()) { + endIndex = occ2->startIndex; + } + GraphSolver graphSolver(options); + llvm::DenseSet visited; + auto handleConflictPair = [&](ConflictPair *conflictPair) { + if (conflictPair->couldNotRun) { + return; + } + if (conflictPair->endIndex < startIndex.value() || + conflictPair->startIndex > endIndex.value()) { + return; + } + if (conflictPair->isInnerBackward) { + if ((eventIdInfo.eventIdNum * eventIdInfo.eventIdRepeatNum) < + (conflictPair->eventIdInfo.eventIdNum * + conflictPair->eventIdInfo.eventIdRepeatNum)) { + return; + } + } + if (llvm::find(ignoreConflictPairs, conflictPair) != + ignoreConflictPairs.end()) { + return; + } + auto [it, isInserted] = visited.insert(conflictPair); + if (!isInserted) { + return; + } + DEBUG_WITH_TYPE("gss-sync-solver-check-graph-conflict", { + llvm::dbgs() << "add-conflict-pair: " << conflictPair->str() << '\n'; + }); + graphSolver.addConflictPair(conflictPair); + }; + + for (auto *parOcc : occ1->getAllParents()) { + if (scopeOccChosenConflicts.contains(parOcc)) { + for (auto *conflictPair : scopeOccChosenConflicts[parOcc]) { + handleConflictPair(conflictPair); + } + } + } + for (auto *parOcc : occ2->getAllParents()) { + if (scopeOccChosenConflicts.contains(parOcc)) { + for (auto *conflictPair : scopeOccChosenConflicts[parOcc]) { + handleConflictPair(conflictPair); + } + } + } + for (auto &[scopeOccPair, chosenConflicts] : scopeOccPairChosenConflicts) { + auto [scopeOcc1, scopeOcc2] = scopeOccPair; + if (scopeOcc1->isProperAncestor(occ1) && + scopeOcc2->isProperAncestor(occ2)) { + for (auto 
*conflictPair : chosenConflicts) { + handleConflictPair(conflictPair); + } + } + } + for (auto *parOcc : occ1->getAllParents()) { + if (persistentScopeOccChosenConflicts.contains(parOcc)) { + for (auto *conflictPair : persistentScopeOccChosenConflicts[parOcc]) { + handleConflictPair(conflictPair); + } + } + } + for (auto *parOcc : occ2->getAllParents()) { + if (persistentScopeOccChosenConflicts.contains(parOcc)) { + for (auto *conflictPair : persistentScopeOccChosenConflicts[parOcc]) { + handleConflictPair(conflictPair); + } + } + } + for (auto *conflictPair : extraConflictPairs) { + handleConflictPair(conflictPair); + } + std::optional mnDistance; + if (options.enableUnitFlagFeature) { + mnDistance = graphSolver.runDijkstraUnitFlagEnabled( + occ1, occ2, corePipeSrc, corePipeDst, startIndex.value(), + endIndex.value()); + } else { + mnDistance = graphSolver.runDijkstra(corePipeSrc, corePipeDst, + startIndex.value(), endIndex.value()); + } + return !mnDistance.has_value() || mnDistance.value() > endIndex.value(); +} + +bool Solver::checkSyncOpsConflicts(ConflictPair *conflictPair1, + ConflictPair *conflictPair2) { + if (conflictPair1->isBarrier() || conflictPair2->isBarrier()) { + return false; + } + if (conflictPair1->startIndex > conflictPair2->startIndex) { + std::swap(conflictPair1, conflictPair2); + } + if (conflictPair1->startIndex >= conflictPair2->startIndex || + conflictPair1->endIndex >= conflictPair2->endIndex) { + return true; + } + bool result = false; + if (conflictPair1->setCorePipeInfo != conflictPair2->setCorePipeInfo) { + auto corePipeSrc = conflictPair1->setCorePipeInfo; + auto corePipeDst = conflictPair2->setCorePipeInfo; + Occurrence *occ1 = conflictPair1->setOcc; + Occurrence *occ2 = conflictPair2->setOcc; + auto startIndex = conflictPair1->startIndex + 1; + auto endIndex = conflictPair2->startIndex; + conflictPair1->startIndex += 1; + assert(occ1 != nullptr && occ2 != nullptr); + result = result || + checkGraphConflict(occ1, occ2, corePipeSrc, 
corePipeDst, + conflictPair1->eventIdInfo, startIndex, + endIndex, {conflictPair1}, {conflictPair2}); + conflictPair1->startIndex -= 1; + } + if (conflictPair1->waitCorePipeInfo != conflictPair2->waitCorePipeInfo) { + auto corePipeSrc = conflictPair1->waitCorePipeInfo; + auto corePipeDst = conflictPair2->waitCorePipeInfo; + Occurrence *occ1 = conflictPair1->waitOcc; + Occurrence *occ2 = conflictPair2->waitOcc; + auto startIndex = conflictPair1->endIndex; + auto endIndex = conflictPair2->endIndex - 1; + conflictPair2->endIndex -= 1; + assert(occ1 != nullptr && occ2 != nullptr); + result = result || + checkGraphConflict(occ1, occ2, corePipeSrc, corePipeDst, + conflictPair1->eventIdInfo, startIndex, + endIndex, {conflictPair1}, {conflictPair2}); + conflictPair2->endIndex += 1; + } + DEBUG_WITH_TYPE("gss-check-sync-ops-conflicts", { + if (result) { + llvm::dbgs() << "sync-ops-conflict-found: " << "\n"; + llvm::dbgs() << " " << conflictPair1->str() << '\n'; + llvm::dbgs() << " " << conflictPair2->str() << '\n'; + } + }); + return result; +} + +// Check whether two ConflictPair entries conflict in pipe and time ranges. 
+bool Solver::checkIntersect(ConflictPair *conflictPair1, + ConflictPair *conflictPair2) { + assert(conflictPair1 != nullptr && conflictPair2 != nullptr); + if (conflictPair1 == conflictPair2) { + return false; + } + if (conflictPair1->isBarrier() || conflictPair2->isBarrier()) { + return false; + } + if (conflictPair1->dontCheckForConflict || + conflictPair2->dontCheckForConflict) { + return false; + } + if (options.isCrossCoreMode()) { + return checkSyncOpsConflicts(conflictPair1, conflictPair2); + } + if (conflictPair1->setCorePipeInfo != conflictPair2->setCorePipeInfo || + conflictPair1->waitCorePipeInfo != conflictPair2->waitCorePipeInfo) { + return false; + } + for (auto [l1, r1] : getRanges(conflictPair1)) { + for (auto [l2, r2] : getRanges(conflictPair2)) { + if (checkRangesIntersect(l1, r1 + 1, l2, r2 + 1)) { + return true; + } + } + } + return false; +} + +// Obtain available event ids while accounting for already chosen conflicts. +std::vector +Solver::getIntersectingConflictPairs(ConflictPair *conflictPair) { + assert(conflictPair != nullptr); + if (conflictPair->isBarrier()) { + return {}; + } + if (conflictPair->dontCheckForConflict) { + return {}; + } + std::vector intersectingConflictPairs; + for (auto &curConflictPair : chosenConflictedPairs) { + if (checkIntersect(conflictPair, curConflictPair.get())) { + intersectingConflictPairs.push_back(curConflictPair.get()); + } + } + for (auto &curConflictPair : persistentChosenConflictedPairs) { + if (checkIntersect(conflictPair, curConflictPair.get())) { + intersectingConflictPairs.push_back(curConflictPair.get()); + } + } + return intersectingConflictPairs; +} + +// Processed-pair tracking helpers. 
+bool Solver::checkVisited(Occurrence *occ1, Occurrence *occ2) { + auto [it, isInserted] = processedOccPairs.insert(std::make_pair(occ1, occ2)); + return !isInserted; +} + +bool Solver::checkSkippable(bool reverseOrder, Occurrence *occ) { + return skipOcc[reverseOrder].contains(occ); +} + +// Synced-pair memoization helpers. +EventIdNode *Solver::getOldEventIdNodeIfExists(ConflictPair *conflictPair) { + assert(conflictPair != nullptr); + auto oldConflictPairs = getMemorizedSyncedPairs(conflictPair); + if (oldConflictPairs.empty()) { + return {}; + } + ConflictPair *oldConflictPair = *oldConflictPairs.begin(); + assert(oldConflictPair != nullptr && oldConflictPair->eventIdNode != nullptr); + return oldConflictPair->eventIdNode; +} + +llvm::DenseSet +Solver::getMemorizedSyncedPairs(ConflictPair *conflictPair) { + auto key = std::make_tuple( + conflictPair->backwardSyncLoopOp, conflictPair->op1, conflictPair->op2, + conflictPair->setCorePipeInfo, conflictPair->waitCorePipeInfo); + return syncedPairs[key]; +} + +void Solver::memorizeSyncedPair(ConflictPair *conflictPair) { + auto key = std::make_tuple( + conflictPair->backwardSyncLoopOp, conflictPair->op1, conflictPair->op2, + conflictPair->setCorePipeInfo, conflictPair->waitCorePipeInfo); + syncedPairs[key].insert(conflictPair); +#ifndef NDEBUG + for (auto *oldConflictPair : syncedPairs[key]) { + assert(oldConflictPair->eventIdNode == conflictPair->eventIdNode); + } +#endif +} + +void Solver::forgetSyncedPair(ConflictPair *conflictPair) { + assert(conflictPair != nullptr); + auto key = std::make_tuple( + conflictPair->backwardSyncLoopOp, conflictPair->op1, conflictPair->op2, + conflictPair->setCorePipeInfo, conflictPair->waitCorePipeInfo); + syncedPairs[key].erase(conflictPair); +} + +void Solver::memorizeReusedSyncedPair(ConflictPair *conflictPair, + ConflictPair *reusedConflictPair) { + assert(conflictPair != nullptr); + replacedWithReusableSyncedPairs[{ + conflictPair->backwardSyncLoopOp, conflictPair->op1, 
conflictPair->op2, + conflictPair->setCorePipeInfo, conflictPair->waitCorePipeInfo}] = + reusedConflictPair; +} + +bool Solver::skipMMad1DecomposedLoopOpt(Occurrence *occ1, Occurrence *occ2) { + auto *parentLoopOp1 = OperationBase::getParentloop(occ1->op); + auto *parentLoopOp2 = OperationBase::getParentloop(occ2->op); + if (parentLoopOp1 != nullptr && parentLoopOp2 != nullptr) { + if (parentLoopOp1 != parentLoopOp2) { + if (isa(parentLoopOp1) && + isa(parentLoopOp2)) { + return true; + } + } + } + return false; +} + +std::optional> +Solver::checkAndApplyMmadl0LoopOpt(ConflictPair *conflictPair, Occurrence *occ1, + Occurrence *occ2, Occurrence *parOcc1, + Occurrence *parOcc2) { + if (!options.decomposeMmadl1Op) { + return {}; + } + if (occ1->parentOcc != nullptr && occ1->parentOcc->parentOcc != nullptr && + occ1->parentOcc->parentOcc->parentOcc == parOcc1 && + llvm::isa_and_present( + occ1->op) && + llvm::isa_and_present( + occ1->parentOcc->parentOcc->op)) { + conflictPair->setOnLastIterOnly = true; + return std::make_pair(occ1, parOcc2); + } + if (!conflictPair->isInnerBackward && occ2->parentOcc != nullptr && + occ2->parentOcc->parentOcc != nullptr && + occ2->parentOcc->parentOcc->parentOcc == parOcc2 && + llvm::isa_and_present( + occ2->op) && + llvm::isa_and_present( + occ2->parentOcc->parentOcc->op)) { + conflictPair->waitOnFirstIterOnly = true; + return std::make_pair(parOcc1, occ2); + } + return {}; +} + +std::optional Solver::checkUnitFlagPatterns(Occurrence *occ1, + Occurrence *occ2) { + return {}; +} + +Occurrence *Solver::getBeforePlaceHolderOcc(Occurrence *occ) { + assert(occ != nullptr); + assert(llvm::isa_and_present(occ->op)); + int index = occ->syncIrIndex - 1; + assert(0 <= index && index < static_cast(syncIr.size())); + auto *placeHolderOcc = syncIr[index].get(); +#ifndef NDEBUG + auto *placeHolderOp = llvm::dyn_cast(placeHolderOcc->op); + assert(placeHolderOp != nullptr); + assert(placeHolderOp->beforeOp == occ->op); +#endif + return 
placeHolderOcc; +} + +Occurrence *Solver::getAfterPlaceHolderOcc(Occurrence *occ) { + assert(occ != nullptr); + assert(llvm::isa_and_present(occ->op)); + int index = occ->syncIrEndIndex; + assert(0 <= index && index < static_cast(syncIr.size())); + auto *placeHolderOcc = syncIr[index].get(); +#ifndef NDEBUG + auto *placeHolderOp = llvm::dyn_cast(placeHolderOcc->op); + assert(placeHolderOp != nullptr); + assert(placeHolderOp->afterOp == occ->op); +#endif + return placeHolderOcc; +} + +Occurrence *Solver::getScopeBeginPlaceHolderOcc(Occurrence *occ) { + assert(occ != nullptr); + assert(llvm::isa_and_present(occ->op)); + int index = occ->syncIrIndex + 1; + assert(0 <= index && index < static_cast(syncIr.size())); + auto *placeHolderOcc = syncIr[index].get(); +#ifndef NDEBUG + auto *placeHolderOp = llvm::dyn_cast(placeHolderOcc->op); + assert(placeHolderOp != nullptr); + assert(placeHolderOp->scopeBegin == occ->op); +#endif + return placeHolderOcc; +} + +Occurrence *Solver::getScopeEndPlaceHolderOcc(Occurrence *occ) { + assert(occ != nullptr); + assert(llvm::isa_and_present(occ->op)); + int index = occ->syncIrEndIndex - 1; + assert(0 <= index && index < static_cast(syncIr.size())); + auto *placeHolderOcc = syncIr[index].get(); +#ifndef NDEBUG + auto *placeHolderOp = llvm::dyn_cast(placeHolderOcc->op); + assert(placeHolderOp != nullptr); + assert(placeHolderOp->scopeEnd == occ->op); +#endif + return placeHolderOcc; +} + +std::pair +Solver::getSetWaitLCAPairOcc(Occurrence *occ1, Occurrence *occ2) { + assert(occ1 != nullptr && occ2 != nullptr); + + auto [grandParOcc1, grandParOcc2] = Occurrence::getLCAPair(occ1, occ2); + assert(grandParOcc1 != nullptr && grandParOcc2 != nullptr); + assert(grandParOcc1->parentOcc != nullptr && + grandParOcc2->parentOcc != nullptr); + + auto [parOp1, parOp2] = OperationBase::getLCAPair(occ1->op, occ2->op); + assert(parOp1 != nullptr && parOp2 != nullptr); + assert(parOp1->parentOp != nullptr && parOp2->parentOp != nullptr); + 
assert(parOp1->parentOp == parOp2->parentOp); + + auto *parOcc1 = occ1->getParentWithOp(parOp1->parentOp); + auto *parOcc2 = occ2->getParentWithOp(parOp2->parentOp); + assert(parOcc1 != nullptr && parOcc2 != nullptr); + assert(parOcc1 != occ1 && parOcc2 != occ2); + + auto *setOcc = occ1->getNthParent(occ1->depth - parOcc1->depth - 1); + auto *waitOcc = occ2->getNthParent(occ2->depth - parOcc2->depth - 1); + assert(setOcc != nullptr && waitOcc != nullptr); + assert(parOcc1->isProperAncestor(setOcc)); + assert(parOcc2->isProperAncestor(waitOcc)); + + auto *parLoop = Occurrence::getParentloop(setOcc); + while (parLoop != nullptr && grandParOcc1->isProperAncestor(parLoop)) { + setOcc = parLoop; + waitOcc = Occurrence::getParentloop(waitOcc); + parLoop = Occurrence::getParentloop(setOcc); + } + return std::make_pair(setOcc, waitOcc); +} + +std::pair +Solver::getFixedSetWaitOcc(Occurrence *occ1, Occurrence *occ2) { + // - get setOcc waitOcc where: + // setOcc->op->parent = waitOcc->op->parent = lca(occ1, occ2)->op + auto [setOcc, waitOcc] = getSetWaitLCAPairOcc(occ1, occ2); + + // - check if it's the case of while loop: + // while{ + // before{ + // occ1 + // } + // setOcc; + // waitOcc; + // after{ + // occ2 + // } + // } + // - and fix it to be: + // while{ + // before{ + // occ1 + // setOcc; + // ... 
+ // waitOcc; + // placeHolder + // } + // after{ + // occ2 + // } + // } + if (setOcc->op != waitOcc->op) { + if (auto *parLoopOp = + llvm::dyn_cast_if_present(setOcc->parentOcc->op)) { + if (parLoopOp->body.size() > 1 && !isa(waitOcc->op)) { + auto *placeHolderOcc = getScopeEndPlaceHolderOcc(setOcc); + std::tie(setOcc, waitOcc) = getSetWaitLCAPairOcc(occ1, placeHolderOcc); + } + } + } + + // - check if it's the case of: + // loop(iter-1){ + // condition{ + // true-scope{} + // setOcc() + // false-scope{} + // } + // } + // loop(iter-2){ + // condition{ + // true-scope{} + // waitOcc() + // false-scope{} + // } + // } + // - and fix it to be: + // loop(iter-1){ + // condition{ + // true-scope{} + // false-scope{} + // } + // setOcc() + // } + // loop(iter-2){ + // waitOcc() + // condition{ + // true-scope{} + // false-scope{} + // } + // } + if (isBackwardSync(occ1, occ2)) { + if (setOcc->parentOcc != nullptr) { + if (llvm::isa_and_present(setOcc->parentOcc->op)) { + setOcc = setOcc->parentOcc; + } + } + if (waitOcc->parentOcc != nullptr) { + if (llvm::isa_and_present(waitOcc->parentOcc->op)) { + waitOcc = waitOcc->parentOcc; + } + } + } + + // - for the case of cv-pipelining: + // loop(){ + // op1 + // } {unroll=x} + // setOcc + // waitOcc + // loop(){ + // op2 + // } {unroll=x} + // - and fix it to be: + // loop(){ + // op1 + // setOcc + // } {unroll=x} + // loop(){ + // waitOcc + // op2 + // } {unroll=x} + if (options.isCrossCoreMode()) { + assert(setOcc->op != nullptr && waitOcc->op != nullptr); + auto *forOp1 = llvm::dyn_cast_if_present(setOcc->op); + auto *forOp2 = llvm::dyn_cast_if_present(waitOcc->op); + if (forOp1 != nullptr && forOp2 != nullptr) { + if (forOp1->multibufferUnrollNum && forOp2->multibufferUnrollNum) { + assert(forOp1->multibufferUnrollNum == forOp2->multibufferUnrollNum); + setOcc = occ1->getNthParent(occ1->depth - setOcc->depth - 2); + waitOcc = occ2->getNthParent(occ2->depth - waitOcc->depth - 2); + } + } + } + + // - for the case of 
cv-pipelining: + // scope(){ + // op1 + // } {preload=x} + // setOcc + // waitOcc + // scope(){ + // op2 + // } {preload=x} + // - and fix it to be: + // scope(){ + // op1 + // setOcc + // } {preload=x} + // scope(){ + // waitOcc + // op2 + // } {preload=x} + if (options.isCrossCoreMode()) { + assert(setOcc->op != nullptr && waitOcc->op != nullptr); + auto *scopeOp1 = llvm::dyn_cast_if_present(setOcc->op); + auto *scopeOp2 = llvm::dyn_cast_if_present(waitOcc->op); + if (scopeOp1 != nullptr && scopeOp2 != nullptr) { + if (scopeOp1->maxPreloadNum && scopeOp2->maxPreloadNum) { + assert(scopeOp1->maxPreloadNum == scopeOp2->maxPreloadNum); + setOcc = occ1->getNthParent(occ1->depth - setOcc->depth - 2); + waitOcc = occ2->getNthParent(occ2->depth - waitOcc->depth - 2); + } + } + } + + // - check if it's the case of: + // { + // op1 + // setOcc + // ... + // waitOcc + // loop(){} + // setOcc + // ... + // waitOcc + // op2 + // } + // - and fix it to be: + // { + // op1 + // setOcc + // ... + // waitOcc + // placeHolder + // loop(){} + // placeHolder + // setOcc + // ... 
+ // waitOcc + // op2 + // } + if (llvm::isa_and_present(setOcc->op)) { + setOcc = getAfterPlaceHolderOcc(setOcc); + } + if (llvm::isa_and_present(waitOcc->op)) { + waitOcc = getBeforePlaceHolderOcc(waitOcc); + } + + return std::make_pair(setOcc, waitOcc); +} + +std::optional> +Solver::getFunctionBlockSetWaitOcc(Occurrence *occ1, Occurrence *occ2) { + assert(occ1 != nullptr && occ2 != nullptr); + auto *parFunctionBlock1 = occ1->getParentOfType(); + auto *parFunctionBlock2 = occ2->getParentOfType(); + if (parFunctionBlock1 == parFunctionBlock2) { + return {}; + } + auto *placeHolderOcc = getScopeBeginPlaceHolderOcc(parFunctionBlock2); + return std::make_pair(placeHolderOcc, occ2); +} + +std::optional> +Solver::getUnlikelyCondSetWaitOcc(Occurrence *occ1, Occurrence *occ2) { + assert(occ1 != nullptr && occ2 != nullptr); + if (options.isCrossCoreMode() && isBackwardSync(occ1, occ2)) { + return {}; + } + if (auto *unlikelyParCondOcc1 = + Occurrence::getUnlikelyParentCondition(occ1)) { + if (!unlikelyParCondOcc1->isProperAncestor(occ2)) { + auto *parentLoopOcc = Occurrence::getParentloop(unlikelyParCondOcc1); + if (parentLoopOcc == nullptr || parentLoopOcc->isProperAncestor(occ2)) { + auto *placeHolderOcc = getScopeEndPlaceHolderOcc( + occ1->getNthParent(occ1->depth - unlikelyParCondOcc1->depth - 1)); + return std::make_pair(occ1, placeHolderOcc); + } + } + } + if (auto *unlikelyParCondOcc2 = + Occurrence::getUnlikelyParentCondition(occ2)) { + if (!unlikelyParCondOcc2->isProperAncestor(occ1)) { + auto *parentLoopOcc = Occurrence::getParentloop(unlikelyParCondOcc2); + if (parentLoopOcc == nullptr || parentLoopOcc->isProperAncestor(occ1)) { + auto *placeHolderOcc = getScopeBeginPlaceHolderOcc( + occ2->getNthParent(occ2->depth - unlikelyParCondOcc2->depth - 1)); + return std::make_pair(placeHolderOcc, occ2); + } + } + } + return {}; +} + +std::pair Solver::getSetWaitOcc(Occurrence *occ1, + Occurrence *occ2) { + if (auto functionBlockOpt = getFunctionBlockSetWaitOcc(occ1, 
occ2)) { + std::tie(occ1, occ2) = functionBlockOpt.value(); + } + if (auto unlikelyOpt = getUnlikelyCondSetWaitOcc(occ1, occ2)) { + std::tie(occ1, occ2) = unlikelyOpt.value(); + } + return getFixedSetWaitOcc(occ1, occ2); +} + +Occurrence *Solver::getBarrierWaitOcc(Occurrence *occ1, Occurrence *occ2) { + auto [setOcc, waitOcc] = getSetWaitOcc(occ1, occ2); + if (!waitOcc->isProperAncestor(occ2)) { + return waitOcc; + } + auto allParents = occ2->getAllParents(); + while (!allParents.empty() && allParents.back()->isProperAncestor(waitOcc)) { + allParents.pop_back(); + } + while (allParents.size() >= 2 && + llvm::isa_and_present(allParents.back()->op)) { + allParents.pop_back(); + assert(llvm::isa_and_present(allParents.back()->op)); + allParents.pop_back(); + } + waitOcc = !allParents.empty() ? allParents.back() : occ2; + return waitOcc; +} + +void Solver::insertBarrierAllBeforeOcc(Occurrence *occ, bool isUseless, + bool isPersistent) { + assert(occ != nullptr); + auto *rwOp = llvm::dyn_cast_if_present(occ->op); + assert(rwOp != nullptr); + auto conflictPair = std::make_unique( + nullptr, nullptr, rwOp, rwOp, occ, occ, + CorePipeInfo(pto::TCoreType::CUBE_OR_VECTOR, pto::PIPE::PIPE_ALL), + CorePipeInfo(pto::TCoreType::CUBE_OR_VECTOR, pto::PIPE::PIPE_ALL), + occ->startIndex, occ->startIndex); + conflictPair->isUseless = isUseless; + auto *normScopeOcc = occ->parentOcc; + assert(normScopeOcc != nullptr); + LLVM_DEBUG(llvm::dbgs() << (isPersistent ? 
"is-persistent " : "") + << occ->op->str(0, false) << ' ' + << conflictPair->str() << '\n';); + if (isPersistent) { + persistentScopeOccChosenConflicts[normScopeOcc].insert(conflictPair.get()); + persistentChosenConflictedPairs.push_back(std::move(conflictPair)); + } else { + insertedBarrierAllBefore[occ->op].insert({occ, isUseless}); + scopeOccChosenConflicts[normScopeOcc].insert(conflictPair.get()); + chosenConflictedPairs.push_back(std::move(conflictPair)); + } +} + +void Solver::insertBarrierAllBeforeOp(OperationBase *op, bool isUseless, + bool isPersistent) { + assert(op != nullptr); + for (auto *occ : opAllOccurrences[op]) { + insertBarrierAllBeforeOcc(occ, isUseless, isPersistent); + isUseless = true; + } +} + +// When barrier-all markers need to be chosen, insert them before all +// occurrences for the chosen op. +void Solver::pickAndInsertABarrierAll() { + assert(!insertedBarrierAllBefore.empty()); + OperationBase *chosenOp = nullptr; + for (auto &[op, vec] : insertedBarrierAllBefore) { + if (vec.empty()) { + continue; + } + if (chosenOp == nullptr || chosenOp->id > op->id) { + chosenOp = op; + } + } + assert(chosenOp != nullptr); + insertBarrierAllBeforeOp(chosenOp, /*isUseless=*/false, + /*isPersistent=*/true); +} + +bool Solver::isBackwardSync(Occurrence *occ1, Occurrence *occ2) { + if (occ1->op->id >= occ2->op->id) { + return true; + } + assert(occ1 != nullptr && occ2 != nullptr); + assert(occ1->op != nullptr && occ2->op != nullptr); + auto [parOcc1, parOcc2] = Occurrence::getLCAPair(occ1, occ2); + auto [parOp1, parOp2] = OperationBase::getLCAPair(occ1->op, occ2->op); + return parOcc1->parentOcc->op != parOp1->parentOp; +} + +bool Solver::reuseCmp(ConflictPair *conflictPair1, + ConflictPair *conflictPair2) { + assert(conflictPair1 != nullptr && conflictPair2 != nullptr); + assert(conflictPair1->op1 != nullptr && conflictPair1->op2 != nullptr); + assert(conflictPair2->op1 != nullptr && conflictPair2->op2 != nullptr); + if (conflictPair1->startIndex != 
conflictPair2->startIndex) { + return conflictPair1->startIndex < conflictPair2->startIndex; + } + if (conflictPair1->endIndex != conflictPair2->endIndex) { + return conflictPair1->endIndex > conflictPair2->endIndex; + } + if (conflictPair1->op1 != conflictPair2->op1) { + return conflictPair1->op1->id > conflictPair2->op1->id; + } + if (conflictPair1->op2 != conflictPair2->op2) { + return conflictPair1->op2->id > conflictPair2->op2->id; + } + return false; +} + +ConflictPair *Solver::getReusableConflictPair( + ConflictPair *conflictPair, + const llvm::DenseSet &conflictPairsSet) { + assert(conflictPair != nullptr); + ConflictPair *ret = nullptr; + for (auto *curConflictPair : conflictPairsSet) { + if (curConflictPair->isBarrier() || curConflictPair->dontReuse) { + continue; + } + if (curConflictPair->op1 != conflictPair->op1 || + curConflictPair->op2 != conflictPair->op2 || + curConflictPair->setCorePipeInfo != conflictPair->setCorePipeInfo || + curConflictPair->waitCorePipeInfo != conflictPair->waitCorePipeInfo) { + continue; + } + if (!checkIntersect(conflictPair, curConflictPair)) { + continue; + } + if (curConflictPair->startIndex >= conflictPair->startIndex) { + continue; + } + if (conflictPair->eventIdNode->eventIdNum < + curConflictPair->eventIdNode->eventIdNum) { + continue; + } + assert(conflictPair->eventIdNode != nullptr); + assert(curConflictPair->eventIdNode != nullptr); + if (conflictPair->eventIdNode->eventIdNum > + curConflictPair->eventIdNode->eventIdNum) { + if (conflictPair->eventIdNode->eventIdNum % + curConflictPair->eventIdNode->eventIdNum) { + continue; + } + } + assert(conflictPair->startIndex <= curConflictPair->endIndex); + assert(curConflictPair->endIndex <= conflictPair->endIndex); + if (ret == nullptr || reuseCmp(ret, curConflictPair)) { + ret = curConflictPair; + } + } + return ret; +} + +bool Solver::reuseConflictPair(ConflictPair *conflictPair, + Occurrence *scopeOcc1, Occurrence *scopeOcc2) { + if (conflictPair->isBarrier()) { + 
return false; + } + if (scopeOcc1->op != scopeOcc2->op) { + return false; + } + if (!barrierAllPairs.empty()) { + return false; + } + + ConflictPair *oldReusedConflictPair = nullptr; + if (conflictPair->isUseless) { + auto it = replacedWithReusableSyncedPairs.find( + {conflictPair->backwardSyncLoopOp, conflictPair->op1, conflictPair->op2, + conflictPair->setCorePipeInfo, conflictPair->waitCorePipeInfo}); + if (it != replacedWithReusableSyncedPairs.end()) { + oldReusedConflictPair = it->second; + } + } + +#ifndef NDEBUG + if (!conflictPair->isUseless) { + auto it = replacedWithReusableSyncedPairs.find( + {conflictPair->backwardSyncLoopOp, conflictPair->op1, conflictPair->op2, + conflictPair->setCorePipeInfo, conflictPair->waitCorePipeInfo}); + assert(it == replacedWithReusableSyncedPairs.end()); + } +#endif + + if (conflictPair->isUseless && oldReusedConflictPair == nullptr) { + return false; + } + + auto corePipeSrc = conflictPair->setCorePipeInfo; + auto corePipeDst = conflictPair->waitCorePipeInfo; + + if (oldReusedConflictPair == nullptr) { + if (!reusePairs.contains({corePipeSrc, corePipeDst}) || + reusePairs[{corePipeSrc, corePipeDst}] <= + reusedPairs[{corePipeSrc, corePipeDst}]) { + return false; + } + } + + assert(reusePairs.contains(std::make_tuple(corePipeSrc, corePipeDst))); + assert(reusePairs[std::make_tuple(corePipeSrc, corePipeDst)] >= + reusedPairs[std::make_tuple(corePipeSrc, corePipeDst)]); + + ConflictPair *opt1 = nullptr; + ConflictPair *opt2 = nullptr; + ConflictPair *opt3 = nullptr; + ConflictPair *opt4 = nullptr; + ConflictPair *opt5 = nullptr; + + auto it1 = scopeOccChosenConflicts.find(scopeOcc1); + auto it2 = scopeOccChosenConflicts.find(scopeOcc2); + auto it3 = scopeOccPairChosenConflicts.find({scopeOcc1, scopeOcc2}); + auto it4 = persistentScopeOccChosenConflicts.find(scopeOcc1); + auto it5 = persistentScopeOccChosenConflicts.find(scopeOcc2); + + if (it1 != scopeOccChosenConflicts.end()) { + opt1 = getReusableConflictPair(conflictPair, 
it1->second); + } + if (it2 != scopeOccChosenConflicts.end()) { + opt2 = getReusableConflictPair(conflictPair, it2->second); + } + if (it3 != scopeOccPairChosenConflicts.end()) { + opt3 = getReusableConflictPair(conflictPair, it3->second); + } + if (it4 != persistentScopeOccChosenConflicts.end()) { + opt4 = getReusableConflictPair(conflictPair, it4->second); + } + if (it5 != persistentScopeOccChosenConflicts.end()) { + opt5 = getReusableConflictPair(conflictPair, it5->second); + } + + ConflictPair *reusableConflictPair = nullptr; + for (auto *opt : {opt1, opt2, opt3, opt4, opt5}) { + if (opt != nullptr) { + if (reusableConflictPair == nullptr || + reuseCmp(reusableConflictPair, opt)) { + reusableConflictPair = opt; + } + } + } + + if (reusableConflictPair == nullptr) { + return false; + } + + DEBUG_WITH_TYPE("gss-sync-solver-reuse", { + llvm::dbgs() << "reuse: " << conflictPair->str() << '\n'; + llvm::dbgs() << "with: " << reusableConflictPair->str() << '\n'; + }); + + assert(reusableConflictPair->startIndex < conflictPair->startIndex); + assert(reusableConflictPair->endIndex <= conflictPair->endIndex); + reusableConflictPair->setOp = conflictPair->setOp; + reusableConflictPair->setOcc = conflictPair->setOcc; + reusableConflictPair->startIndex = conflictPair->startIndex; + + if (!conflictPair->isUseless) { + memorizeReusedSyncedPair(conflictPair, reusableConflictPair); + } + + DEBUG_WITH_TYPE("gss-sync-solver-reuse", { + if (oldReusedConflictPair != nullptr) { + llvm::dbgs() << "old-reuse: " << oldReusedConflictPair->str() << '\n'; + } + }); + + if (oldReusedConflictPair != nullptr) { + assert(oldReusedConflictPair->op1 == reusableConflictPair->op1); + assert(oldReusedConflictPair->op2 == reusableConflictPair->op2); + assert(oldReusedConflictPair->waitOp == reusableConflictPair->waitOp); + } + + if (!conflictPair->isUseless) { + reusedPairs[{corePipeSrc, corePipeDst}] += 1; + } + + return true; +} + +std::unique_ptr & +Solver::getEventIdSolverRef(pto::PIPE pipeSrc, 
pto::PIPE pipeDst) { + if (options.isCrossCoreMode()) { + pipeSrc = pto::PIPE::PIPE_UNASSIGNED; + pipeDst = pto::PIPE::PIPE_UNASSIGNED; + } + auto key = std::make_tuple(pipeSrc, pipeDst); + if (!eventIdSolver.contains(key)) { + int64_t eventIdNumMax = + getHWAvailableEventIdNum(options.syncMode, pipeSrc, pipeDst); + if (options.eventIdNumMax.has_value()) { + eventIdNumMax = std::min(eventIdNumMax, options.eventIdNumMax.value()); + eventIdNumMax = std::max(eventIdNumMax, 1); + } + eventIdSolver[key] = std::make_unique(eventIdNumMax); + } + return eventIdSolver[key]; +} + +bool Solver::checkReuseMultiBufferFlagId(ConflictPair *conflictPair) { + if (options.useDifferentMultiBufferFlagIds) { + return false; + } + if (!conflictPair->isInnerBackward || + conflictPair->eventIdInfo.eventIdNum <= 1 || + conflictPair->movedToOuterLoop) { + return false; + } + auto [setOcc, waitOcc] = + std::tie(conflictPair->setOcc, conflictPair->waitOcc); + auto *backwardSyncLoopOcc = conflictPair->backwardSyncLoopOcc; + assert(backwardSyncLoopOcc != nullptr); + if (auto *parCondOcc1 = setOcc->getParentOfType()) { + if (!parCondOcc1->isProperAncestor(backwardSyncLoopOcc)) { + return false; + } + } + if (auto *parCondOcc2 = waitOcc->getParentOfType()) { + if (!parCondOcc2->isProperAncestor(backwardSyncLoopOcc)) { + return false; + } + } + return true; +} + +void Solver::handleSetWaitConflict(Occurrence *occ1, Occurrence *occ2, + CorePipeInfo corePipeSrc, + CorePipeInfo corePipeDst, + EventIdInfo eventIdInfo, bool isUseless) { + assert(occ1 != nullptr && occ2 != nullptr); + auto *rwOp1 = llvm::dyn_cast_if_present(occ1->op); + auto *rwOp2 = llvm::dyn_cast_if_present(occ2->op); + assert(rwOp1 != nullptr && rwOp2 != nullptr); + assert(corePipeSrc != corePipeDst); + + Loop *parentLCALoopOp{nullptr}; + Occurrence *parentLCALoopOcc{nullptr}; + Occurrence *parentLCALoopBeforePHOcc{nullptr}; + Occurrence *parentLCALoopAfterPHOcc{nullptr}; + auto [setOcc, waitOcc] = getSetWaitOcc(occ1, occ2); + + auto 
[lcaSetOp, lcaWaitOp] = + OperationBase::getLCAPair(setOcc->op, waitOcc->op); + auto *normScopeOcc1 = setOcc->getParentWithOp(lcaSetOp->parentOp); + auto *normScopeOcc2 = waitOcc->getParentWithOp(lcaWaitOp->parentOp); + assert(normScopeOcc1->op == normScopeOcc2->op); + auto *normScopeOp = normScopeOcc1->op; + assert(normScopeOp != nullptr); + assert(normScopeOp->parentOp != nullptr); + + auto conflictPair = std::make_unique( + rwOp1, rwOp2, setOcc->op, waitOcc->op, setOcc, waitOcc, corePipeSrc, + corePipeDst, setOcc->endIndex, waitOcc->startIndex); + assert(conflictPair->startIndex <= conflictPair->endIndex); + + conflictPair->isUseless = isUseless; + conflictPair->isInnerBackward = isBackwardSync(setOcc, waitOcc); + conflictPair->eventIdInfo = eventIdInfo; + + if (conflictPair->isInnerBackward) { + auto [parOcc1, parOcc2] = Occurrence::getLCAPair(occ1, occ2); + assert(parOcc1 != nullptr && parOcc2 != nullptr); + + parentLCALoopOcc = parOcc1->getParentOfType(); + if (moveBackwardSyncPairsToOutmostLoop) { + while (auto *grandParentLoopOcc = + parentLCALoopOcc->getParentOfType()) { + conflictPair->movedToOuterLoop = true; + parentLCALoopOcc = grandParentLoopOcc; + } + } + assert(parentLCALoopOcc != nullptr); + conflictPair->backwardSyncLoopOcc = parentLCALoopOcc; + + parentLCALoopOp = llvm::dyn_cast(parentLCALoopOcc->op); + assert(parentLCALoopOp != nullptr); + conflictPair->backwardSyncLoopOp = parentLCALoopOp; + + parentLCALoopBeforePHOcc = getBeforePlaceHolderOcc(parentLCALoopOcc); + assert(parentLCALoopBeforePHOcc != nullptr); + parentLCALoopAfterPHOcc = getAfterPlaceHolderOcc(parentLCALoopOcc); + assert(parentLCALoopAfterPHOcc != nullptr); + } + + if (auto setWaitOccs = checkAndApplyMmadl0LoopOpt(conflictPair.get(), occ1, + occ2, setOcc, waitOcc)) { + std::tie(setOcc, waitOcc) = setWaitOccs.value(); + conflictPair->updateSetWaitOccs(setOcc, waitOcc); + } + + if (!conflictPair->isInnerBackward || + disabledMultiEventIdPairs.contains({corePipeSrc, corePipeDst})) { 
+ conflictPair->eventIdInfo = EventIdInfo(1); + } + if (checkReuseMultiBufferFlagId(conflictPair.get())) { + conflictPair->eventIdInfo.eventIdRepeatNum = + conflictPair->eventIdInfo.eventIdNum; + conflictPair->eventIdInfo.eventIdNum = 1; + } + + auto &curEventIdSolver = getEventIdSolverRef( + conflictPair->setCorePipeInfo.pipe, conflictPair->waitCorePipeInfo.pipe); + curEventIdSolver->pushActionNone(); + + auto checkColorable = [&]() -> bool { + if (curEventIdSolver->isColorable()) { + return true; + } + LLVM_DEBUG(llvm::dbgs() << "will-be-converted-to-barrier-all " + << conflictPair->str() << '\n';); + insertBarrierAllBeforeOp(occ2->op, conflictPair->isUseless, + /*isPersistent=*/false); + barrierAllPairs.insert({corePipeSrc, corePipeDst}); + curEventIdSolver->undoActions(); + return false; + }; + + if (auto *oldEventIdNode = getOldEventIdNodeIfExists(conflictPair.get())) { + conflictPair->eventIdNode = oldEventIdNode; + curEventIdSolver->insertConflictPair(oldEventIdNode, conflictPair.get()); + } else { + bool reversedPriority = false; + if (conflictPair->isInnerBackward) { + if (OperationBase::getParentloop(occ1->op) == normScopeOp->parentOp && + OperationBase::getParentloop(occ2->op) == normScopeOp->parentOp) { + reversedPriority = true; + } + } + conflictPair->eventIdNode = curEventIdSolver->createNode( + conflictPair.get(), conflictPair->eventIdInfo.eventIdNum, + reversedPriority); + } + + if (options.reuseSyncPairToSaveEventIds) { + if (reuseConflictPair(conflictPair.get(), normScopeOcc1, normScopeOcc2)) { + curEventIdSolver->undoActions(); + return; + } + } + + auto intersectingConflictPairs = + getIntersectingConflictPairs(conflictPair.get()); + curEventIdSolver->addConflicts(conflictPair.get(), intersectingConflictPairs); + if (!checkColorable()) { + return; + } + + LLVM_DEBUG({ + llvm::dbgs() << conflictPair->str() << '\n'; + if (parentLCALoopOcc != nullptr) { + llvm::dbgs() << parentLCALoopOcc->op->str(0, false) << '\n'; + } + }); + + llvm::SmallVector, 
Occurrence *>> + extraConflictPairs; + + auto insertExtraConflictPair = [&](Occurrence *setOcc, Occurrence *waitOcc, + Occurrence *parentScope, + bool couldNotRun = false) -> bool { + assert(setOcc != nullptr && waitOcc != nullptr && parentScope != nullptr); + auto extraConflictPair = conflictPair->clone(setOcc, waitOcc); + extraConflictPair->isUseless = true; + extraConflictPair->dontReuse = true; + if (couldNotRun || options.moveOutAndMergeBackwardSyncPairs) { + extraConflictPair->couldNotRun = true; + } + LLVM_DEBUG({ + llvm::dbgs() << "extra-conflict-pair: " << extraConflictPair->str() + << "\n"; + }); + curEventIdSolver->insertConflictPair(conflictPair->eventIdNode, + extraConflictPair.get()); + auto intersectingConflictPairs = + getIntersectingConflictPairs(extraConflictPair.get()); + curEventIdSolver->addConflicts(extraConflictPair.get(), + intersectingConflictPairs); + if (!checkColorable()) { + return false; + } + extraConflictPairs.push_back( + std::make_pair(std::move(extraConflictPair), parentScope)); + return true; + }; + + if (conflictPair->isInnerBackward && conflictPair->eventIdNode != nullptr) { + bool insertOuterBwdConflictPair = false; + if ((conflictPair->eventIdInfo.eventIdNum * + conflictPair->eventIdInfo.eventIdRepeatNum) > 1) { + insertOuterBwdConflictPair = true; + } else if (options.isCrossCoreMode()) { + if (setOcc->parentOcc == nullptr || + setOcc->parentOcc->parentOcc == nullptr || + setOcc->parentOcc->parentOcc->op != parentLCALoopOp) { + insertOuterBwdConflictPair = true; + } else if (waitOcc->parentOcc == nullptr || + waitOcc->parentOcc->parentOcc == nullptr || + waitOcc->parentOcc->parentOcc->op != parentLCALoopOp) { + insertOuterBwdConflictPair = true; + } + } + if (insertOuterBwdConflictPair) { + // insert useless conflictPair to cover the whole loop when having + // multi-eventid backward sync to reserve the eventIds. 
+ if (!insertExtraConflictPair(parentLCALoopBeforePHOcc, + parentLCALoopAfterPHOcc, + parentLCALoopOcc->parentOcc)) { + return; + } + } + } + + if (conflictPair->isInnerBackward && conflictPair->eventIdNode != nullptr) { + // insert header/footer useless conflictPairs to reserve the eventIds. + auto *loopOpOcc1 = getFirstIterOcc(waitOcc, normScopeOcc1); + auto *loopOpOcc2 = getLastIterOcc(setOcc, normScopeOcc2); + if (!insertExtraConflictPair(parentLCALoopBeforePHOcc, loopOpOcc1, + parentLCALoopOcc, /*couldNotRun=*/true)) { + return; + } + if (!insertExtraConflictPair(loopOpOcc2, parentLCALoopAfterPHOcc, + parentLCALoopOcc, /*couldNotRun=*/true)) { + return; + } + } + + bool dontInsert = false; + if (conflictPair->isInnerBackward && normScopeOcc1 != normScopeOcc2) { + auto *parCond = OperationBase::getParentCondition(conflictPair->setOp); + if (auto *conditionOp = llvm::dyn_cast_if_present(parCond)) { + if (parentLCALoopOcc->op->isProperAncestor(conditionOp)) { + scopeOccPairChosenConflicts[{normScopeOcc1, normScopeOcc2}].insert( + conflictPair.get()); + dontInsert = true; + } + } + } + if (!dontInsert) { + assert(parentLCALoopOcc != nullptr || normScopeOcc1 == normScopeOcc2); + scopeOccChosenConflicts[normScopeOcc1].insert(conflictPair.get()); + scopeOccChosenConflicts[normScopeOcc2].insert(conflictPair.get()); + } + + memorizeSyncedPair(conflictPair.get()); + chosenConflictedPairs.push_back(std::move(conflictPair)); + + for (auto &[extraConflictPair, parentScope] : extraConflictPairs) { + scopeOccChosenConflicts[parentScope].insert(extraConflictPair.get()); + chosenConflictedPairs.push_back(std::move(extraConflictPair)); + } + + curEventIdSolver->clearActionStack(); +} + +void Solver::handleBarrierConflict(Occurrence *occ1, Occurrence *occ2, + CorePipeInfo corePipeSrc, + CorePipeInfo corePipeDst, bool isUseless) { + assert(occ1 != nullptr && occ2 != nullptr); + auto *rwOp1 = llvm::dyn_cast_if_present(occ1->op); + auto *rwOp2 = llvm::dyn_cast_if_present(occ2->op); 
+ assert(rwOp1 != nullptr && rwOp2 != nullptr); + + assert(corePipeSrc == corePipeDst); + if (corePipeSrc.pipe == pto::PIPE::PIPE_S) { + return; + } + if (options.isRegBasedArch) { + if (corePipeSrc.pipe == pto::PIPE::PIPE_V || + corePipeSrc.pipe == pto::PIPE::PIPE_M) { + return; + } + } + auto *waitOcc = getBarrierWaitOcc(occ1, occ2); + + auto conflictPair = std::make_unique( + rwOp1, rwOp2, waitOcc->op, waitOcc->op, waitOcc, waitOcc, corePipeSrc, + corePipeDst, waitOcc->startIndex, waitOcc->startIndex); + conflictPair->isUseless = isUseless; + assert(conflictPair->startIndex <= conflictPair->endIndex); + + LLVM_DEBUG({ llvm::dbgs() << conflictPair->str() << '\n'; }); + + auto *normScopeOcc = waitOcc->parentOcc; + scopeOccChosenConflicts[normScopeOcc].insert(conflictPair.get()); + chosenConflictedPairs.push_back(std::move(conflictPair)); +} + +void Solver::handleUnitFlagConflict(Occurrence *occ1, Occurrence *occ2, + CorePipeInfo corePipeSrc, + CorePipeInfo corePipeDst, + UnitFlagInfo unitFlagInfo, bool isUseless) { + assert(occ1 != nullptr && occ2 != nullptr); + auto *rwOp1 = llvm::dyn_cast_if_present(occ1->op); + auto *rwOp2 = llvm::dyn_cast_if_present(occ2->op); + assert(rwOp1 != nullptr && rwOp2 != nullptr); + assert(corePipeSrc != corePipeDst); + + auto *setOcc = occ1; + auto *waitOcc = occ2; + auto *normScopeOcc1 = setOcc->parentOcc; + auto *normScopeOcc2 = waitOcc->parentOcc; + + auto conflictPair = std::make_unique( + rwOp1, rwOp2, setOcc->op, waitOcc->op, setOcc, waitOcc, corePipeSrc, + corePipeDst, setOcc->endIndex, waitOcc->startIndex); + assert(conflictPair->startIndex <= conflictPair->endIndex); + + conflictPair->isUseless = true; + conflictPair->dontReuse = true; + conflictPair->replacedWithUnitFlag = true; + conflictPair->dontCheckForConflict = true; + conflictPair->isInnerBackward = isBackwardSync(setOcc, waitOcc); + +#ifndef NDEBUG + Occurrence *parentLCALoopOcc{nullptr}; + if (conflictPair->isInnerBackward) { + auto [parOcc1, parOcc2] = 
Occurrence::getLCAPair(occ1, occ2); + assert(parOcc1 != nullptr && parOcc2 != nullptr); + parentLCALoopOcc = Occurrence::getParentloop(parOcc1); + assert(parentLCALoopOcc != nullptr); + } + + LLVM_DEBUG({ + llvm::dbgs() << conflictPair->str() << '\n'; + if (parentLCALoopOcc != nullptr) { + llvm::dbgs() << parentLCALoopOcc->op->str(0, false) << '\n'; + } + }); +#endif + + occ1->unitFlagInfo.merge(unitFlagInfo, occ1, occ2, + /*asSet=*/true, /*asWait=*/false); + occ2->unitFlagInfo.merge(unitFlagInfo, occ1, occ2, + /*asSet=*/false, /*asWait=*/true); + if (!isUseless) { + rwOp1->mergedUnitFlagInfo.merge(unitFlagInfo, /*asSet=*/true, + /*asWait=*/false); + rwOp2->mergedUnitFlagInfo.merge(unitFlagInfo, /*asSet=*/false, + /*asWait=*/true); + } + + scopeOccPairChosenConflicts[{normScopeOcc1, normScopeOcc2}].insert( + conflictPair.get()); + chosenConflictedPairs.push_back(std::move(conflictPair)); +} + +void Solver::handleConflict(Occurrence *occ1, Occurrence *occ2, + RWOperation *rwOp1, RWOperation *rwOp2, + CorePipeInfo corePipeSrc, CorePipeInfo corePipeDst, + EventIdInfo eventIdInfo, bool isUseless) { + if (!checkGraphConflict(occ1, occ2, corePipeSrc, corePipeDst, eventIdInfo)) { + return; + } + LLVM_DEBUG({ + llvm::dbgs() << "conflict found: " << "eventIdNum(" + << eventIdInfo.eventIdNum << ")\n"; + llvm::dbgs() << occ1->syncIrIndex << ' ' << occ1->startIndex << ' ' + << occ1->endIndex << ' ' << rwOp1->str(0, false) << '\n'; + llvm::dbgs() << occ2->syncIrIndex << ' ' << occ2->startIndex << ' ' + << occ2->endIndex << ' ' << rwOp2->str(0, false) << '\n'; + }); + if (corePipeSrc == corePipeDst) { + handleBarrierConflict(occ1, occ2, corePipeSrc, corePipeDst, isUseless); + } else if (auto unitFlagInfo = checkUnitFlagPatterns(occ1, occ2)) { + handleUnitFlagConflict(occ1, occ2, corePipeSrc, corePipeDst, + unitFlagInfo.value(), isUseless); + } else { + handleSetWaitConflict(occ1, occ2, corePipeSrc, corePipeDst, eventIdInfo, + isUseless); + } +} + +void Solver::calcAllEventIds() 
{ + for (auto &[pipes, eventIdSolver] : eventIdSolver) { + assert(eventIdSolver != nullptr); + + [[maybe_unused]] auto result = + eventIdSolver->shrinkEventIdMaxToEventIdNum(); + assert(llvm::succeeded(result)); + assert(eventIdSolver->isColorable()); + } +} + +void Solver::collectBackwardSyncEventIds() { + LLVM_DEBUG(llvm::dbgs() << "collectBackwardSyncEventIds\n";); + for (auto &conflictPair : chosenConflictedPairs) { + if (!conflictPair->isUseless && conflictPair->isInnerBackward && + conflictPair->eventIdNode != nullptr) { + LLVM_DEBUG(llvm::dbgs() << " " << conflictPair->str() << "\n";); + for (auto eventId : conflictPair->eventIdNode->getEventIds()) { + auto &e = backwardSyncEvents[conflictPair->backwardSyncLoopOp] + [{conflictPair->setCorePipeInfo, + conflictPair->waitCorePipeInfo}][eventId]; + e = std::max(e, conflictPair->eventIdInfo.eventIdRepeatNum); + } + } + } +} + +void Solver::resetAndBuildSetWaitOpIndex(const SyncMap &syncMapBefore, + const SyncMap &syncMapAfter) { + globalSetWaitIndex = 0; + setWaitStartIndex.clear(); + setWaitEndIndex.clear(); + setWaitStartIndexInclusive.clear(); + setWaitEndIndexInclusive.clear(); + setWaitFlagOpsIndex.clear(); + collectSetWaitOpsIndexes(funcIr.get(), syncMapBefore, syncMapAfter); +} + +std::set> & +Solver::getSetWaitOpsIndexRef(pto::PIPE pipeSrc, pto::PIPE pipeDst, + int64_t eventId) { + auto key = std::make_tuple(pipeSrc, pipeDst, eventId); + return setWaitFlagOpsIndex[key]; +} + +// Collect indices for all Set/Wait ops to facilitate merging decisions. 
+void Solver::collectSetWaitOpsIndexes(OperationBase *op, + const SyncMap &syncMapBefore, + const SyncMap &syncMapAfter) { + assert(op != nullptr); + setWaitStartIndexInclusive[op] = globalSetWaitIndex++; + if (syncMapBefore.count(op)) { + auto *it = syncMapBefore.find(op); + assert(it != syncMapBefore.end()); + for (auto &syncOp : it->second) { + if (auto *setWaitOp = llvm::dyn_cast(syncOp.get())) { + for (auto eventId : setWaitOp->eventIds) { + auto &index = getSetWaitOpsIndexRef(setWaitOp->pipeSrc, + setWaitOp->pipeDst, eventId); + index.insert({globalSetWaitIndex++, setWaitOp}); + } + } + } + } + setWaitStartIndex[op] = globalSetWaitIndex++; + if (auto *scopeOp = llvm::dyn_cast(op)) { + for (auto &childOp : scopeOp->body) { + collectSetWaitOpsIndexes(childOp.get(), syncMapBefore, syncMapAfter); + } + } + setWaitEndIndex[op] = globalSetWaitIndex++; + if (syncMapAfter.count(op)) { + auto *it = syncMapAfter.find(op); + assert(it != syncMapAfter.end()); + for (auto &syncOp : it->second) { + if (auto *setWaitOp = llvm::dyn_cast(syncOp.get())) { + for (auto eventId : setWaitOp->eventIds) { + auto &index = getSetWaitOpsIndexRef(setWaitOp->pipeSrc, + setWaitOp->pipeDst, eventId); + index.insert({globalSetWaitIndex++, setWaitOp}); + } + } + } + } + setWaitEndIndexInclusive[op] = globalSetWaitIndex++; +} + +bool Solver::checkBackwardSyncEventsContains(OperationBase *op, + CorePipeInfo corePipeSrc, + CorePipeInfo corePipeDst, + int64_t eventId) { + auto *it1 = backwardSyncEvents.find(op); + if (it1 == backwardSyncEvents.end()) { + return false; + } + auto it2 = it1->second.find({corePipeSrc, corePipeDst}); + if (it2 == it1->second.end()) { + return false; + } + return it2->second.contains(eventId); +} + +bool Solver::checkBackwardSyncEventsContainsAfterMerge( + OperationBase *op, CorePipeInfo corePipeSrc, CorePipeInfo corePipeDst) { + auto *it1 = backwardSyncEventsAfterMerge.find(op); + if (it1 == backwardSyncEventsAfterMerge.end()) { + return false; + } + return 
it1->second.contains({corePipeSrc, corePipeDst}); +} + +// Check whether a backward-sync event id can be merged at scope level. +bool Solver::checkMergeable(Scope *scopeOp, CorePipeInfo corePipeSrc, + CorePipeInfo corePipeDst, int64_t eventId, + bool shouldBeUsedAtleastOnce) { + auto &index = + getSetWaitOpsIndexRef(corePipeSrc.pipe, corePipeDst.pipe, eventId); + if (shouldBeUsedAtleastOnce) { + auto it = index.lower_bound({setWaitStartIndexInclusive[scopeOp], nullptr}); + bool usedAtleastOnce = + it != index.end() && it->first < setWaitEndIndexInclusive[scopeOp]; + if (!usedAtleastOnce) { + return false; + } + } + { + auto it1 = + index.lower_bound({setWaitStartIndexInclusive[scopeOp], nullptr}); + auto it2 = index.lower_bound({setWaitEndIndex[scopeOp], nullptr}); + bool usedBefore = + it1 != index.end() && it1->first < setWaitStartIndex[scopeOp]; + bool usedAfter = + it2 != index.end() && it2->first < setWaitEndIndexInclusive[scopeOp]; + if (usedBefore || usedAfter) { + return false; + } + } + if (auto *conditionOp = llvm::dyn_cast(scopeOp)) { + if (!conditionOp->hasFalseScope()) { + return false; + } + return checkMergeable(conditionOp->getTrueScope(), corePipeSrc, corePipeDst, + eventId, true) && + checkMergeable(conditionOp->getFalseScope(), corePipeSrc, + corePipeDst, eventId, true); + } + if (auto *loopOp = llvm::dyn_cast(scopeOp)) { + for (auto &childOp : loopOp->body) { + if (auto *childScopeOp = llvm::dyn_cast(childOp.get())) { + if (!checkMergeable(childScopeOp, corePipeSrc, corePipeDst, eventId, + false)) { + return false; + } + } + } + for (auto &childOp : loopOp->body) { + if (auto *childScopeOp = llvm::dyn_cast(childOp.get())) { + if (checkMergeable(childScopeOp, corePipeSrc, corePipeDst, eventId, + true)) { + return true; + } + } + } + return false; + } + for (auto &childOp : scopeOp->body) { + auto it1 = + index.lower_bound({setWaitStartIndexInclusive[childOp.get()], nullptr}); + auto it2 = index.lower_bound({setWaitEndIndex[childOp.get()], 
nullptr}); + bool usedAtleastOnce = it1 != index.end() && + it1->first < setWaitEndIndexInclusive[childOp.get()]; + if (!usedAtleastOnce) { + continue; + } + bool before = + it1 != index.end() && it1->first < setWaitStartIndex[childOp.get()]; + bool after = it2 != index.end() && + it2->first < setWaitEndIndexInclusive[childOp.get()]; + if (before || after) { + return false; + } + if (!checkBackwardSyncEventsContains(childOp.get(), corePipeSrc, + corePipeDst, eventId)) { + return false; + } + if (checkBackwardSyncEventsContainsAfterMerge(childOp.get(), corePipeSrc, + corePipeDst)) { + return false; + } + } + return true; +} + +// Attempt to merge backward sync events across children and prune duplicates. +void Solver::mergeBackwardSyncEventIds(OperationBase *op) { + auto *scopeOp = llvm::dyn_cast_if_present(op); + if (scopeOp == nullptr) { + return; + } + for (auto &op : scopeOp->body) { + mergeBackwardSyncEventIds(op.get()); + } + + if (llvm::isa_and_present(op)) { + return; + } + if (llvm::isa_and_present(op->parentOp)) { + return; + } + + auto *conditionOp = llvm::dyn_cast(op); + if (conditionOp != nullptr) { + if (!conditionOp->hasFalseScope()) { + return; + } + } + + llvm::DenseSet> toBeErased; + + llvm::SmallVector coreTypes; + if (options.isCrossCoreMode()) { + coreTypes = {pto::TCoreType::VECTOR, pto::TCoreType::CUBE}; + } else { + coreTypes = {pto::TCoreType::CUBE_OR_VECTOR}; + } + size_t pipeNumMax = static_cast(pto::PIPE::PIPE_NUM); + const int64_t eventIdMax = getHWAvailableEventIdNum(options.syncMode); + + for (int64_t eventId = 0; eventId < eventIdMax; ++eventId) { + for (auto coreSrc : coreTypes) { + for (auto coreDst : coreTypes) { + for (size_t pipeSrcInt = 0; pipeSrcInt < pipeNumMax; pipeSrcInt++) { + for (size_t pipeDstInt = 0; pipeDstInt < pipeNumMax; pipeDstInt++) { + auto pipeSrc = static_cast(pipeSrcInt); + auto pipeDst = static_cast(pipeDstInt); + auto corePipeSrc = CorePipeInfo(coreSrc, pipeSrc); + auto corePipeDst = CorePipeInfo(coreDst, 
pipeDst); + if (checkBackwardSyncEventsContains(scopeOp, corePipeSrc, + corePipeDst, eventId)) { + continue; + } + if (checkMergeable(scopeOp, corePipeSrc, corePipeDst, eventId)) { + toBeErased.insert({corePipeSrc, corePipeDst, eventId}); + backwardSyncEvents[scopeOp][{corePipeSrc, corePipeDst}].insert( + {eventId, 1}); + } + } + } + } + } + } + + if (isa(scopeOp)) { + for (auto &op : scopeOp->body) { + if (auto *block = llvm::dyn_cast(op.get())) { + for (auto &childOp : block->body) { + if (auto *childScopeOp = llvm::dyn_cast(childOp.get())) { + for (auto [corePipeSrc, corePipeDst, eventId] : toBeErased) { + if (checkBackwardSyncEventsContains(childScopeOp, corePipeSrc, + corePipeDst, eventId)) { + auto key = std::make_tuple(corePipeSrc, corePipeDst); + backwardSyncEvents[childScopeOp][key].erase(eventId); + if (backwardSyncEvents[childScopeOp][key].empty()) { + backwardSyncEvents[childScopeOp].erase(key); + } + } + } + } + } + } + } + } else { + for (auto &childOp : scopeOp->body) { + if (auto *childScopeOp = llvm::dyn_cast(childOp.get())) { + for (auto [corePipeSrc, corePipeDst, eventId] : toBeErased) { + if (checkBackwardSyncEventsContains(childScopeOp, corePipeSrc, + corePipeDst, eventId)) { + auto key = std::make_tuple(corePipeSrc, corePipeDst); + backwardSyncEvents[childScopeOp][key].erase(eventId); + if (backwardSyncEvents[childScopeOp][key].empty()) { + backwardSyncEvents[childScopeOp].erase(key); + } + } + } + } + } + } +} + +void Solver::mergeBackwardSyncPairs(SyncMap &syncMapBefore, + SyncMap &syncMapAfter) { + if (!options.moveOutAndMergeBackwardSyncPairs) { + return; + } + if (options.isIntraCoreMode()) { + resetAndBuildSetWaitOpIndex(syncMapBefore, syncMapAfter); + auto *scopeOp = llvm::dyn_cast(funcIr.get()); + assert(scopeOp != nullptr && scopeOp->body.front() != nullptr); + mergeBackwardSyncEventIds(scopeOp->body.front().get()); + } +} + +SyncBeforeAfterMap Solver::getBeforeAfterSyncMaps() { + calcAllEventIds(); + SyncMap syncMapBefore, 
syncMapAfter; + std::vector conflictPairs; + for (auto &conflictPair : chosenConflictedPairs) { + conflictPairs.push_back(conflictPair.get()); + } + for (auto &conflictPair : persistentChosenConflictedPairs) { + conflictPairs.push_back(conflictPair.get()); + } + + for (auto *conflictPair : conflictPairs) { + if (conflictPair->isUseless) { + continue; + } + if (conflictPair->replacedWithUnitFlag) { + continue; + } + assert(conflictPair->setOp != nullptr && conflictPair->waitOp != nullptr); + if (conflictPair->isBarrier()) { + auto barrierOp = std::make_unique( + conflictPair->waitOp->op, conflictPair->waitOp->parentOp, + conflictPair->waitCorePipeInfo.pipe); + LLVM_DEBUG(barrierOp->debugId = conflictPair->id); + syncMapBefore[conflictPair->waitOp].push_back(std::move(barrierOp)); + } else { + assert(conflictPair->eventIdNode != nullptr); + auto setOp = std::make_unique( + conflictPair->setOp->op, conflictPair->setOp->parentOp, + conflictPair->eventIdNode->getEventIds(), + conflictPair->setCorePipeInfo.pipe, + conflictPair->waitCorePipeInfo.pipe); + auto waitOp = std::make_unique( + conflictPair->waitOp->op, conflictPair->waitOp->parentOp, + conflictPair->eventIdNode->getEventIds(), + conflictPair->setCorePipeInfo.pipe, + conflictPair->waitCorePipeInfo.pipe); + if (options.isCrossCoreMode()) { + setOp->coreType = conflictPair->setCorePipeInfo.coreType; + waitOp->coreType = conflictPair->waitCorePipeInfo.coreType; + } + setOp->eventIdInfo = conflictPair->eventIdInfo; + waitOp->eventIdInfo = conflictPair->eventIdInfo; + setOp->checkLastIter = conflictPair->setOnLastIterOnly; + waitOp->checkFirstIter = conflictPair->waitOnFirstIterOnly; + LLVM_DEBUG({ + setOp->debugId = conflictPair->id; + waitOp->debugId = conflictPair->id; + }); + assert(setOp != nullptr && waitOp != nullptr); + syncMapAfter[conflictPair->setOp].push_back(std::move(setOp)); + syncMapBefore[conflictPair->waitOp].push_front(std::move(waitOp)); + } + } + + collectBackwardSyncEventIds(); + 
mergeBackwardSyncPairs(syncMapBefore, syncMapAfter); + + for (auto &[op, mp] : backwardSyncEvents) { + if (mp.empty()) { + continue; + } + auto *scopeOp = llvm::dyn_cast(op); + assert(scopeOp != nullptr); + for (auto [setWaitCorePipes, eventIdsMp] : mp) { + if (eventIdsMp.empty()) { + continue; + } + llvm::SmallVector eventIds; + for (auto [eventId, repeatNum] : eventIdsMp) { + llvm::SmallVector curEventIds(repeatNum, eventId); + llvm::append_range(eventIds, curEventIds); + } + llvm::sort(eventIds); + auto [corePipeSrc, corePipeDst] = setWaitCorePipes; + auto setOp = + std::make_unique(scopeOp->op, scopeOp->parentOp, eventIds, + corePipeSrc.pipe, corePipeDst.pipe); + auto waitOp = + std::make_unique(scopeOp->op, scopeOp->parentOp, eventIds, + corePipeSrc.pipe, corePipeDst.pipe); + setOp->allAtOnce = true; + waitOp->allAtOnce = true; + if (options.isCrossCoreMode()) { + setOp->coreType = corePipeSrc.coreType; + waitOp->coreType = corePipeDst.coreType; + } + assert(setOp != nullptr && waitOp != nullptr); + syncMapBefore[scopeOp].push_back(std::move(setOp)); + syncMapAfter[scopeOp].push_front(std::move(waitOp)); + } + } + return std::make_pair(std::move(syncMapBefore), std::move(syncMapAfter)); +} + +void Solver::processConflict(Occurrence *occ1, Occurrence *occ2, + RWOperation *rwOp1, RWOperation *rwOp2, + bool isUseless) { + for (auto [corePipeSrc, corePipeDst] : checkMemoryConflicts(rwOp1, rwOp2)) { + if (options.alwaysUsePipeSAsWaitingPipe) { + corePipeDst.pipe = pto::PIPE::PIPE_S; + } + auto eventIdInfo = + getEventIdInfo(occ1, occ2, rwOp1, rwOp2, corePipeSrc, corePipeDst); + handleConflict(occ1, occ2, rwOp1, rwOp2, corePipeSrc, corePipeDst, + eventIdInfo, isUseless); + } +} + +// Main processing loop that iterates processingOrders and attempts to +// discover and record conflicts. 
+void Solver::processOrders() { + for (auto &[occ1, occ2, rwOp1, rwOp2, isUseless] : processingOrders) { + assert(occ1 != occ2); + assert(occ1->syncIrIndex < occ2->syncIrIndex); + if (checkVisited(occ1, occ2)) { + assert(false && "expected to not check a pair more than once."); + continue; + } + if (checkImpossibleOccPair(occ1, occ2) || checkAlreadySynced(occ1, occ2) || + skipMMad1DecomposedLoopOpt(occ1, occ2) || + checkSkipParallelLoop(occ1, occ2) || + checkSkipCrossCorePair(occ1, occ2)) { + continue; + } + DEBUG_WITH_TYPE("gss-sync-solver-checking", { + llvm::dbgs() << "checking: " << (isUseless ? "is-useless\n" : "\n"); + llvm::dbgs() << occ1->syncIrIndex << ' ' << occ1->startIndex << ' ' + << occ1->endIndex << ' ' << occ1->op->str(0, false) << '\n'; + llvm::dbgs() << occ2->syncIrIndex << ' ' << occ2->startIndex << ' ' + << occ2->endIndex << ' ' << occ2->op->str(0, false) << '\n'; + }); + if (checkAlreadySyncedWithUnitFlag(occ1, occ2)) { + continue; + } + processConflict(occ1, occ2, rwOp1, rwOp2, isUseless); + } +} + +void Solver::insertMergedBackwardSyncPairs() { + for (auto &[scopeOp, st] : backwardSyncEventsAfterMerge) { + for (auto &corePipeInfoPair : st) { + auto [corePipeSrc, corePipeDst] = corePipeInfoPair; + for (auto *scopeOcc : opAllOccurrences[scopeOp]) { + auto *parentScopeOcc = scopeOcc->parentOcc; + assert(parentScopeOcc != nullptr); + Occurrence *setOcc = nullptr; + Occurrence *waitOcc = nullptr; + auto startIndex = scopeOcc->startIndex; + auto endIndex = scopeOcc->endIndex; + if (isa(scopeOp)) { + setOcc = getBeforePlaceHolderOcc(scopeOcc); + waitOcc = getAfterPlaceHolderOcc(scopeOcc); + startIndex = setOcc->endIndex; + endIndex = waitOcc->startIndex; + } + auto conflictPair = std::make_unique( + nullptr, nullptr, nullptr, nullptr, setOcc, waitOcc, corePipeSrc, + corePipeDst, startIndex, endIndex); + assert(conflictPair->startIndex <= conflictPair->endIndex); + conflictPair->isUseless = true; + conflictPair->dontReuse = true; + 
conflictPair->dontCheckForConflict = true; + conflictPair->couldNotRun = false; // notice this + LLVM_DEBUG({ + llvm::dbgs() << "consider-merged-backward-pair: " + << scopeOp->str(0, false) << ' ' << conflictPair->str() + << "\n"; + }); + scopeOccChosenConflicts[parentScopeOcc].insert(conflictPair.get()); + chosenConflictedPairs.push_back(std::move(conflictPair)); + } + } + } +} + +llvm::LogicalResult Solver::considerOuterBackwardSyncPairs() { + if (!options.considerOuterBackwardSyncPairs) { + return llvm::failure(); + } + bool backwardPairsPositionChanged = false; + for (auto &[scopeOp, st] : backwardSyncEventsAfterMerge) { + SmallVector> toBeErased; + for (auto &corePipeInfoPair : st) { + if (!backwardSyncEvents.contains(scopeOp) || + !backwardSyncEvents[scopeOp].contains(corePipeInfoPair)) { + toBeErased.push_back(corePipeInfoPair); + } + } + if (!toBeErased.empty()) { + backwardPairsPositionChanged = true; + for (auto &corePipeInfoPair : toBeErased) { + st.erase(corePipeInfoPair); + } + } + } + int chosenOpsDepth = -1; + SmallVector chosenOps; + for (auto &[scopeOp, mp] : backwardSyncEvents) { + if (backwardSyncEventsAfterMerge.contains(scopeOp)) { + continue; + } + int scopeOpDepth = scopeOp->getDepth(); + if (chosenOpsDepth == scopeOpDepth) { + chosenOps.push_back(scopeOp); + } else if (chosenOpsDepth == -1 || chosenOpsDepth < scopeOpDepth) { + chosenOps.clear(); + chosenOps.push_back(scopeOp); + chosenOpsDepth = scopeOpDepth; + } + } + if (chosenOps.empty()) { + return llvm::failure(); + } + bool newPairIsInserted = false; + for (auto *chosenOp : chosenOps) { + for (auto &[corePipeInfoPair, eventIdsMp] : backwardSyncEvents[chosenOp]) { + assert(!eventIdsMp.empty()); + if (!eventIdsMp.empty()) { + auto [it, isInserted] = + backwardSyncEventsAfterMerge[chosenOp].insert(corePipeInfoPair); + newPairIsInserted |= isInserted; + } + } + } + return llvm::success(backwardPairsPositionChanged || newPairIsInserted); +} + +llvm::LogicalResult 
Solver::reuseSyncPairToSaveEventIds() { + if (!options.reuseSyncPairToSaveEventIds || barrierAllPairs.empty()) { + return llvm::failure(); + } + bool limitReached = true; + for (auto [corePipeSrc, corePipeDst] : barrierAllPairs) { + if (reusePairs[{corePipeSrc, corePipeDst}] < maxReuseNum) { + if (reusePairs[{corePipeSrc, corePipeDst}] <= + reusedPairs[{corePipeSrc, corePipeDst}]) { + reusePairs[{corePipeSrc, corePipeDst}] += 1; + limitReached = false; + } + } + } + DEBUG_WITH_TYPE("gss-sync-solver-reuse", { + llvm::dbgs() << "reusePairs: \n"; + for (auto [pipeCorePairs, cnt] : reusePairs) { + llvm::dbgs() << get<0>(pipeCorePairs).pipe << ' ' + << get<1>(pipeCorePairs).pipe << ' ' << cnt << '\n'; + } + }); + return llvm::success(!limitReached); +} + +llvm::LogicalResult Solver::disableMultiEventIdForBarrierAllPairs() { + if (!options.disableMultiEventIdForBarrierAllPairs || + barrierAllPairs.empty()) { + return llvm::failure(); + } + bool newPairIsInserted = false; + for (auto corePipeInfoPair : barrierAllPairs) { + auto [it, isInserted] = disabledMultiEventIdPairs.insert(corePipeInfoPair); + newPairIsInserted |= isInserted; + } + LLVM_DEBUG({ + if (newPairIsInserted) { + llvm::dbgs() << "disabled-multi-event-id-pairs: \n"; + for (auto &[corePipeSrc, corePipeDst] : disabledMultiEventIdPairs) { + llvm::dbgs() << corePipeSrc.coreType << ' ' << corePipeSrc.pipe << ' ' + << corePipeDst.coreType << ' ' << corePipeDst.pipe << '\n'; + } + } + }); + return llvm::success(newPairIsInserted); +} + +llvm::LogicalResult Solver::tryMovingOutBackwardSyncPairsToOuterLoops() { + if (!options.moveOutAndMergeBackwardSyncPairs || !options.isCrossCoreMode() || + dontMoveBackwardSyncPairsToOutmostLoop) { + return llvm::failure(); + } + if (!moveBackwardSyncPairsToOutmostLoop) { + moveBackwardSyncPairsToOutmostLoop = true; + return llvm::success(); + } + if (!barrierAllPairs.empty()) { + moveBackwardSyncPairsToOutmostLoop = false; + dontMoveBackwardSyncPairsToOutmostLoop = true; + return 
llvm::success(); + } + return llvm::failure(); +} + +// High-level solve orchestration with multiple passes and optional merging +// iterations. +llvm::LogicalResult Solver::runSolver(bool enableOpts1, bool enableOpts2) { + reset(/*resetEventIdRanOutOpts=*/true); + + int64_t runNum = 0; + while (runNum++ < maxRunNum) { + LLVM_DEBUG(llvm::dbgs() << "runNum: " << runNum << '\n'); + + reset(); + insertMergedBackwardSyncPairs(); + processOrders(); + + if (llvm::succeeded(tryMovingOutBackwardSyncPairsToOuterLoops())) { + continue; + } + + if (enableOpts1) { + if (options.considerOuterBackwardSyncPairs) { + getBeforeAfterSyncMaps(); + if (llvm::succeeded(considerOuterBackwardSyncPairs())) { + continue; + } + if (!barrierAllPairs.empty()) { + backwardSyncEventsAfterMerge.clear(); + } + } + } + + if (enableOpts2) { + if (!barrierAllPairs.empty()) { + if (llvm::succeeded(reuseSyncPairToSaveEventIds())) { + continue; + } + if (llvm::succeeded(disableMultiEventIdForBarrierAllPairs())) { + continue; + } + } + } + + if (!barrierAllPairs.empty()) { + pickAndInsertABarrierAll(); + reset(/*resetEventIdRanOutOpts=*/true); + continue; + } + break; + } + + reset(); + insertMergedBackwardSyncPairs(); + processOrders(); + + return llvm::success(runNum < maxRunNum); +} + +void Solver::solve() { + if (llvm::succeeded(runSolver())) { + return; + } + if (!options.isTestMode()) { + if (llvm::succeeded(runSolver(/*enableOpts1=*/false))) { + return; + } + if (llvm::succeeded( + runSolver(/*enableOpts1=*/false, /*enableOpts2=*/false))) { + return; + } + } + llvm_unreachable("GSS: runSolver() failed."); +} diff --git a/lib/PTO/Transforms/GraphSyncSolver/SyncSolver.def b/lib/PTO/Transforms/GraphSyncSolver/SyncSolver.def deleted file mode 100644 index 23a4032a6..000000000 --- a/lib/PTO/Transforms/GraphSyncSolver/SyncSolver.def +++ /dev/null @@ -1,2576 +0,0 @@ -// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
-// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - -//===--------- SyncSolver.cpp ------- Graph Sync Solver -------------------===// -//===----------------------------------------------------------------------===// - -#include "PTO/Transforms/GraphSyncSolver/SyncSolver.h" -#include "PTO/Transforms/GraphSyncSolver/GraphSolver.h" -#include "PTO/Transforms/GraphSyncSolver/MemInfo.h" -#include "PTO/Transforms/GraphSyncSolver/SyncSolverIR.h" -#include "PTO/Transforms/GraphSyncSolver/Utility.h" - -#include "PTO/IR/PTO.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/IR/Value.h" -#include "mlir/Interfaces/LoopLikeInterface.h" -#include "llvm/ADT/DenseSet.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringExtras.h" -#include "llvm/Support/Casting.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/ErrorHandling.h" -#include "llvm/Support/LogicalResult.h" -#include -#include -#include -#include -#include -#include -#include - -#define DEBUG_TYPE "PTO-gss-solver" - -using namespace mlir; -using namespace pto::syncsolver; - -// Reset per-pass bookkeeping to start fresh. 
-void Solver::reset(bool resetEventIdRanOutOpts) { - if (resetEventIdRanOutOpts) { - reusePairs.clear(); - disabledMultiEventIdPairs.clear(); - backwardSyncEventsAfterMerge.clear(); - moveBackwardSyncPairsToOutmostLoop = false; - dontMoveBackwardSyncPairsToOutmostLoop = false; - } - skipOcc.clear(); - syncedPairs.clear(); - processedOccPairs.clear(); - chosenConflictedPairs.clear(); - scopeOccChosenConflicts.clear(); - scopeOccPairChosenConflicts.clear(); - backwardSyncEvents.clear(); - replacedWithReusableSyncedPairs.clear(); - reusedPairs.clear(); - barrierAllPairs.clear(); - insertedBarrierAllBefore.clear(); - eventIdSolver.clear(); - resetUnitFlag(); -} - -void Solver::resetUnitFlag() { - for (auto *rwOp : unitFlagFeaturedOps) { - rwOp->mergedUnitFlagInfo.reset(); - for (auto *occ : opAllOccurrences[rwOp]) { - occ->unitFlagInfo.reset(); - } - } -} - -// Helpers to find first/last iteration occurrences relative to parent -// occurrences. -Occurrence *Solver::getFirstIterOcc(Occurrence *occ, Occurrence *parOcc) { - assert(occ != nullptr && parOcc != nullptr); - if (parOcc->depth + 1 < occ->depth) { - auto *newParOcc = getFirstIterOcc( - occ->getNthParent(occ->depth - parOcc->depth - 1), parOcc); - return getFirstIterOcc(occ, newParOcc); - } - auto *it = - std::find_if(parOcc->childOccs.begin(), parOcc->childOccs.end(), - [occ](Occurrence *curOcc) { return occ->op == curOcc->op; }); - assert(it != parOcc->childOccs.end()); - return *it; -} - -Occurrence *Solver::getLastIterOcc(Occurrence *occ, Occurrence *parOcc) { - assert(occ != nullptr && parOcc != nullptr); - if (parOcc->depth + 1 < occ->depth) { - auto *newParOcc = getLastIterOcc( - occ->getNthParent(occ->depth - parOcc->depth - 1), parOcc); - return getLastIterOcc(occ, newParOcc); - } - auto it = - std::find_if(parOcc->childOccs.rbegin(), parOcc->childOccs.rend(), - [occ](Occurrence *curOcc) { return occ->op == curOcc->op; }); - assert(it != parOcc->childOccs.rend()); - return *it; -} - -bool 
Solver::checkSkipCrossCorePair(Occurrence *occ1, Occurrence *occ2) { - if (!options.isCrossCoreMode()) { - return false; - } - auto *rwOp1 = llvm::dyn_cast(occ1->op); - auto *rwOp2 = llvm::dyn_cast(occ2->op); - assert(rwOp1 != nullptr && rwOp2 != nullptr); - assert(rwOp1->coreType != pto::TCoreType::CUBE_OR_VECTOR); - assert(rwOp2->coreType != pto::TCoreType::CUBE_OR_VECTOR); - if (rwOp1->coreType == rwOp2->coreType) { - return true; - } - if (rwOp1->coreType == pto::TCoreType::CUBE_AND_VECTOR) { - return true; - } - return false; -} - -bool Solver::checkSkipParallelLoop(Occurrence *occ1, Occurrence *occ2) { - if (!isBackwardSync(occ1, occ2)) { - return false; - } - auto [parOcc1, parOcc2] = Occurrence::getLCAPair(occ1, occ2); - assert(parOcc1 != nullptr && parOcc2 != nullptr); - auto *parentLCALoopOcc = Occurrence::getParentloop(parOcc1); - assert(parentLCALoopOcc != nullptr); - auto *parentLCALoopOp = llvm::cast(parentLCALoopOcc->op); - return parentLCALoopOp->isParallel; -} - -// Check whether occurrences belong to impossible (if-else) pairing. -bool Solver::checkImpossibleOccPair(Occurrence *occ1, Occurrence *occ2) { - assert(occ1 != nullptr && occ2 != nullptr); - if (occ1->op == occ2->op) { - return false; - } - auto [parOcc1, parOcc2] = Occurrence::getLCAPair(occ1, occ2); - assert(parOcc1 != nullptr && parOcc2 != nullptr); - bool isIfElseSituation = - parOcc1->parentOcc != nullptr && - parOcc1->parentOcc == parOcc2->parentOcc && - llvm::isa_and_present(parOcc1->parentOcc->op); - return isIfElseSituation; -} - -// Detect whether occ1 and occ2 have already been covered by an earlier sync. 
-bool Solver::checkAlreadySynced(Occurrence *occ1, Occurrence *occ2) { - assert(occ1 != nullptr && occ2 != nullptr); - assert(occ1->op != nullptr && occ2->op != nullptr); - - auto [parOcc1, parOcc2] = Occurrence::getLCAPair(occ1, occ2); - assert(parOcc1 != nullptr && parOcc2 != nullptr); - assert(parOcc1->parentOcc != nullptr && parOcc2->parentOcc != nullptr); - - auto [parOp1, parOp2] = OperationBase::getLCAPair(occ1->op, occ2->op); - assert(parOp1 != nullptr && parOp2 != nullptr); - assert(parOp1->parentOp != nullptr && parOp2->parentOp != nullptr); - - auto *parentLoop = OperationBase::getParentloop(parOcc1->op); - auto *curLoop = OperationBase::getParentloop(parOp1); - if (parentLoop == nullptr || parentLoop == curLoop) { - return false; - } - - assert(curLoop != nullptr); - assert(parentLoop->isProperAncestor(curLoop)); - while (curLoop != parentLoop) { - if (!llvm::cast(curLoop)->isParallel) { - return true; - } - curLoop = OperationBase::getParentloop(curLoop); - assert(curLoop != nullptr); - } - return false; -} - -// Unit-flag reuse check between two RWOperations. 
-bool Solver::checkAlreadySyncedWithUnitFlag(Occurrence *occ1, - Occurrence *occ2) { - assert(occ1 != nullptr && occ2 != nullptr); - if (!options.enableUnitFlagFeature) { - return false; - } - if (!occ1->hasUnitFlagFeat || !occ2->hasUnitFlagFeat) { - return false; - } - llvm::DenseSet visited; - DEBUG_WITH_TYPE("gss-sync-solver-check-unit-flag", { - llvm::dbgs() << "unit-flag-step: " << occ1->syncIrIndex << ' ' - << occ1->op->str(0, false) << "\n"; - }); - Occurrence *curOcc = occ1->unitFlagInfo.linkedElementAsSet; - while (curOcc != nullptr) { - DEBUG_WITH_TYPE("gss-sync-solver-check-unit-flag", { - llvm::dbgs() << "unit-flag-step: " << curOcc->syncIrIndex << ' ' - << curOcc->op->str(0, false) << "\n"; - }); - auto [it, isInserted] = visited.insert(curOcc); - if (!isInserted) { - break; - } - if (curOcc == occ2) { - return true; - } - curOcc = curOcc->unitFlagInfo.linkedElementAsSet; - } - return false; -} - -bool Solver::ignoreMemoryConflict(RWOperation *rwOp1, RWOperation *rwOp2, - const MemInfo &memInfo1, - const MemInfo &memInfo2) { - if (options.isIntraCoreMode()) { - if (memInfo1.isWorkSpace && memInfo2.isWorkSpace) { - if (options.intraCoreIgnoreWorkSpaceFunctionArguments) { - return true; - } - } - } - return false; -} - -bool Solver::checkMemInfoConflict(RWOperation *rwOp1, RWOperation *rwOp2, - const MemInfo &memInfo1, - const MemInfo &memInfo2, - std::optional lcmLen, - std::optional eventIdNum) { - if (ignoreMemoryConflict(rwOp1, rwOp2, memInfo1, memInfo2)) { - return false; - } - return MemInfo::checkConflict(memInfo1, memInfo2, lcmLen, eventIdNum); -} - -bool Solver::checkMemInfoConflict( - RWOperation *rwOp1, RWOperation *rwOp2, - const llvm::SmallVector &memInfoList1, - const llvm::SmallVector &memInfoList2, - std::optional lcmLen, std::optional eventIdNum) { - for (auto &memInfo1 : memInfoList1) { - for (auto &memInfo2 : memInfoList2) { - if (checkMemInfoConflict(rwOp1, rwOp2, memInfo1, memInfo2, lcmLen, - eventIdNum)) { - return true; - } - } - } 
- return false; -} - -// High-level wrapper computing pipe pairs that represent memory conflicts -// between two RW ops. -llvm::SmallVector> -Solver::checkMemoryConflicts(RWOperation *rwOp1, RWOperation *rwOp2) { - assert(rwOp1 != nullptr && rwOp2 != nullptr); - auto [it, isInserted] = checkMemoryConflictsMem.insert({{rwOp1, rwOp2}, {}}); - if (!isInserted) { - return it->second; - } - auto coreSrc = rwOp1->coreType; - auto coreDst = rwOp2->coreType; - if (options.isCrossCoreMode()) { - if (coreDst == pto::TCoreType::CUBE_AND_VECTOR) { - coreDst = (coreSrc == pto::TCoreType::VECTOR) ? pto::TCoreType::CUBE - : pto::TCoreType::VECTOR; - } - assert(coreSrc == pto::TCoreType::VECTOR || - coreSrc == pto::TCoreType::CUBE); - assert(coreDst == pto::TCoreType::VECTOR || - coreDst == pto::TCoreType::CUBE); - } - llvm::SetVector> collectedConflictsSet; - if (checkMemInfoConflict(rwOp1, rwOp2, rwOp1->readMemInfo, - rwOp2->writeMemInfo)) { - collectedConflictsSet.insert({CorePipeInfo(coreSrc, rwOp1->pipeRead), - CorePipeInfo(coreDst, rwOp2->pipeWrite)}); - } - if (checkMemInfoConflict(rwOp1, rwOp2, rwOp1->writeMemInfo, - rwOp2->readMemInfo)) { - collectedConflictsSet.insert({CorePipeInfo(coreSrc, rwOp1->pipeWrite), - CorePipeInfo(coreDst, rwOp2->pipeRead)}); - } - if (checkMemInfoConflict(rwOp1, rwOp2, rwOp1->writeMemInfo, - rwOp2->writeMemInfo)) { - collectedConflictsSet.insert({CorePipeInfo(coreSrc, rwOp1->pipeWrite), - CorePipeInfo(coreDst, rwOp2->pipeWrite)}); - } - llvm::SmallVector> collectedConflicts( - collectedConflictsSet.begin(), collectedConflictsSet.end()); - return it->second = collectedConflicts; -} - -bool Solver::checkMemoryConflictBetweenOccExclusive( - Occurrence *occ1, Occurrence *occ2, - std::function filter) { - assert(occ1 != nullptr && occ2 != nullptr); - auto *rwOp1 = llvm::dyn_cast_if_present(occ1->op); - auto *rwOp2 = llvm::dyn_cast_if_present(occ2->op); - assert(rwOp1 != nullptr && rwOp2 != nullptr); - for (int i = occ1->syncIrEndIndex; i < 
occ2->syncIrIndex; i++) { - if (auto *otherOp = llvm::dyn_cast_if_present(syncIr[i]->op)) { - if (!filter(otherOp)) { - continue; - } - if (!checkMemoryConflicts(rwOp1, otherOp).empty()) { - return true; - } - if (!checkMemoryConflicts(rwOp2, otherOp).empty()) { - return true; - } - } - } - return false; -} - -std::optional -Solver::getMultiBufferLoop(RWOperation *rwOp1, RWOperation *rwOp2, - const llvm::SmallVector &memInfoList1, - const llvm::SmallVector &memInfoList2) { - std::optional multibufferLoop; - for (auto &memInfo1 : memInfoList1) { - for (auto &memInfo2 : memInfoList2) { - if (checkMemInfoConflict(rwOp1, rwOp2, memInfo1, memInfo2)) { - if (!memInfo1.pointerLikeInfo.has_value() || - !memInfo2.pointerLikeInfo.has_value()) { - return {}; - } - auto multibufferLoop1 = memInfo1.pointerLikeInfo->parentLoop; - auto multibufferLoop2 = memInfo2.pointerLikeInfo->parentLoop; - if (multibufferLoop1 == nullptr || - multibufferLoop1 != multibufferLoop2) { - return {}; - } - if (multibufferLoop.has_value() && - multibufferLoop.value() != multibufferLoop1) { - return {}; - } - multibufferLoop = multibufferLoop1; - } - } - } - return multibufferLoop; -} - -std::optional -Solver::getMultiBufferLoop(RWOperation *rwOp1, RWOperation *rwOp2) { - std::optional multibufferLoop; - if (checkMemInfoConflict(rwOp1, rwOp2, rwOp1->readMemInfo, - rwOp2->writeMemInfo)) { - auto curMultibufferLoop = getMultiBufferLoop( - rwOp1, rwOp2, rwOp1->readMemInfo, rwOp2->writeMemInfo); - if (multibufferLoop.has_value() && - multibufferLoop.value() != curMultibufferLoop) { - return {}; - } - multibufferLoop = curMultibufferLoop; - } - if (checkMemInfoConflict(rwOp1, rwOp2, rwOp1->writeMemInfo, - rwOp2->readMemInfo)) { - auto curMultibufferLoop = getMultiBufferLoop( - rwOp1, rwOp2, rwOp1->writeMemInfo, rwOp2->readMemInfo); - if (multibufferLoop.has_value() && - multibufferLoop.value() != curMultibufferLoop) { - return {}; - } - multibufferLoop = curMultibufferLoop; - } - if 
(checkMemInfoConflict(rwOp1, rwOp2, rwOp1->writeMemInfo, - rwOp2->writeMemInfo)) { - auto curMultibufferLoop = getMultiBufferLoop( - rwOp1, rwOp2, rwOp1->writeMemInfo, rwOp2->writeMemInfo); - if (multibufferLoop.has_value() && - multibufferLoop.value() != curMultibufferLoop) { - return {}; - } - multibufferLoop = curMultibufferLoop; - } - return multibufferLoop; -} - -std::optional -Solver::getMultiBufferEventIdInfo(Occurrence *occ1, Occurrence *occ2, - RWOperation *rwOp1, RWOperation *rwOp2) { - assert(rwOp1 != nullptr && rwOp2 != nullptr); - - int64_t lcm = 1; - int64_t minWriteSize = LONG_MAX; - LoopLikeOpInterface multibufferLoop{nullptr}; - - if (options.isTestMode()) { - auto *parLoop1 = occ1->getParentOfType(); - auto *parLoop2 = occ2->getParentOfType(); - if (!parLoop1 || parLoop1 != parLoop2) { - return {}; - } - auto [setOcc, waitOcc] = getSetWaitOcc(occ1, occ2); - if (!parLoop1->isProperAncestor(setOcc) || - !parLoop1->isProperAncestor(waitOcc)) { - return {}; - } - } else { - auto multibufferLoopOpt = getMultiBufferLoop(rwOp1, rwOp2); - if (!multibufferLoopOpt.has_value() || !multibufferLoopOpt.value()) { - return {}; - } - multibufferLoop = multibufferLoopOpt.value(); - assert(multibufferLoop != nullptr); - auto [setOcc, waitOcc] = getSetWaitOcc(occ1, occ2); - if (!setOcc->getParentWithOp(multibufferLoop, - /*assertExists=*/false) || - !waitOcc->getParentWithOp(multibufferLoop, - /*assertExists=*/false)) { - return {}; - } - } - - for (auto &memInfo1 : rwOp1->readMemInfo) { - for (auto &memInfo2 : rwOp2->writeMemInfo) { - if (checkMemInfoConflict(rwOp1, rwOp2, memInfo1, memInfo2)) { - int64_t curLcm = std::lcm(memInfo1.getSz(), memInfo2.getSz()); - lcm = std::lcm(lcm, curLcm); - minWriteSize = std::min(minWriteSize, memInfo2.getSz()); - } - } - } - for (auto &memInfo1 : rwOp1->writeMemInfo) { - for (auto &memInfo2 : rwOp2->readMemInfo) { - if (checkMemInfoConflict(rwOp1, rwOp2, memInfo1, memInfo2)) { - int64_t curLcm = std::lcm(memInfo1.getSz(), 
memInfo2.getSz()); - lcm = std::lcm(lcm, curLcm); - minWriteSize = std::min(minWriteSize, memInfo1.getSz()); - } - } - } - for (auto &memInfo1 : rwOp1->writeMemInfo) { - for (auto &memInfo2 : rwOp2->writeMemInfo) { - if (checkMemInfoConflict(rwOp1, rwOp2, memInfo1, memInfo2)) { - int64_t curLcm = std::lcm(memInfo1.getSz(), memInfo2.getSz()); - lcm = std::lcm(lcm, curLcm); - minWriteSize = std::min(minWriteSize, memInfo1.getSz()); - minWriteSize = std::min(minWriteSize, memInfo2.getSz()); - } - } - } - - // In case no write sizes were positive. - if (minWriteSize == LONG_MAX) { - minWriteSize = 1; - return {}; - } - - int64_t eventIdNum = minWriteSize; - for (; eventIdNum >= 1; eventIdNum--) { - // llvm::dbgs() << "checking event-id-num: " << eventIdNum << '\n'; - int64_t curLcm = std::lcm(lcm, eventIdNum); - bool okRW = !checkMemInfoConflict(rwOp1, rwOp2, rwOp1->readMemInfo, - rwOp2->writeMemInfo, curLcm, eventIdNum); - bool okWR = !checkMemInfoConflict(rwOp1, rwOp2, rwOp1->writeMemInfo, - rwOp2->readMemInfo, curLcm, eventIdNum); - bool okWW = !checkMemInfoConflict(rwOp1, rwOp2, rwOp1->writeMemInfo, - rwOp2->writeMemInfo, curLcm, eventIdNum); - if (okRW && okWR && okWW) { - break; - } - } - if (eventIdNum <= 1) { - return {}; - } - EventIdInfo eventIdInfo(eventIdNum); - eventIdInfo.multibufferLoop = multibufferLoop; - return eventIdInfo; -} - -std::optional -Solver::checkMultiBufferEventIdInfo(Occurrence *occ1, Occurrence *occ2, - RWOperation *rwOp1, RWOperation *rwOp2) { - assert(rwOp1 != nullptr && rwOp2 != nullptr); - if (!options.isTestMode()) { - if (!checkAllParentLoopsAreForLoops(rwOp1->op) || - !checkAllParentLoopsAreForLoops(rwOp2->op)) { - return {}; - } - } - if (auto eventIdInfo = getMultiBufferEventIdInfo(occ1, occ2, rwOp1, rwOp2)) { - return eventIdInfo; - } - return {}; -} - -std::optional -Solver::checkCVMultiBufferUnrollEventIdInfo(RWOperation *rwOp1, - RWOperation *rwOp2) { - assert(rwOp1 != nullptr && rwOp2 != nullptr); - if 
(!options.isCrossCoreMode()) { - return {}; - } - auto *parentLoop1 = rwOp1->getParentOfType(); - auto *parentLoop2 = rwOp2->getParentOfType(); - while (parentLoop1 != nullptr && !parentLoop1->multibufferUnrollNum) { - parentLoop1 = parentLoop1->getParentOfType(); - } - while (parentLoop2 != nullptr && !parentLoop2->multibufferUnrollNum) { - parentLoop2 = parentLoop2->getParentOfType(); - } - if (!parentLoop1 || !parentLoop2) { - return {}; - } - if (auto *parCond1 = rwOp1->getParentOfType()) { - if (!parCond1->isProperAncestor(rwOp2)) { - return {}; - } - } - if (auto *parCond2 = rwOp2->getParentOfType()) { - if (!parCond2->isProperAncestor(rwOp1)) { - return {}; - } - } - assert(parentLoop1->multibufferUnrollNum.value() == - parentLoop2->multibufferUnrollNum.value()); - EventIdInfo eventIdInfo; - eventIdInfo.eventIdNum = parentLoop1->multibufferUnrollNum.value(); - eventIdInfo.multibufferUnrollLoop1 = - cast(parentLoop1->op); - eventIdInfo.multibufferUnrollLoop2 = - cast(parentLoop2->op); - return eventIdInfo; -} - -std::optional -Solver::checkCVMultiBufferPreloadEventIdInfo(RWOperation *rwOp1, - RWOperation *rwOp2) { - assert(rwOp1 != nullptr && rwOp2 != nullptr); - if (!options.isCrossCoreMode()) { - return {}; - } - auto *parentScope1 = rwOp1->getParentOfType(); - auto *parentScope2 = rwOp2->getParentOfType(); - while (parentScope1 != nullptr && !parentScope1->maxPreloadNum.has_value()) { - parentScope1 = parentScope1->getParentOfType(); - } - while (parentScope2 != nullptr && !parentScope2->maxPreloadNum.has_value()) { - parentScope2 = parentScope2->getParentOfType(); - } - if (!parentScope1 || !parentScope2) { - return {}; - } - if (auto *parCond1 = rwOp1->getParentOfType()) { - if (!parCond1->isProperAncestor(rwOp2)) { - return {}; - } - } - if (auto *parCond2 = rwOp2->getParentOfType()) { - if (!parCond2->isProperAncestor(rwOp1)) { - return {}; - } - } - - auto *parentLoop1 = parentScope1->getParentOfType(); - auto *parentLoop2 = 
parentScope2->getParentOfType(); - if (parentLoop1 == nullptr || parentLoop1 != parentLoop2) { - return {}; - } - - assert(parentScope1->preloadNum.has_value()); - assert(parentScope2->preloadNum.has_value()); - assert(parentScope1->maxPreloadNum.value() == - parentScope2->maxPreloadNum.value()); - - auto parentForLoop = llvm::dyn_cast_if_present(parentLoop1->op); - assert(parentForLoop != nullptr); - - EventIdInfo eventIdInfo; - eventIdInfo.eventIdNum = parentScope1->maxPreloadNum.value(); - eventIdInfo.preloadOffset1 = parentScope1->maxPreloadNum.value() - - parentScope1->preloadNum.value() - 1; - eventIdInfo.preloadOffset2 = parentScope2->maxPreloadNum.value() - - parentScope2->preloadNum.value() - 1; - eventIdInfo.multibufferLoop = parentForLoop; - return eventIdInfo; -} - -// Determine required event id count and optional multibuffer loop parent for -// occurrences. -EventIdInfo Solver::getEventIdInfo(Occurrence *occ1, Occurrence *occ2, - RWOperation *rwOp1, RWOperation *rwOp2, - CorePipeInfo corePipeSrc, - CorePipeInfo corePipeDst) { - assert(occ1 != nullptr && occ2 != nullptr); - assert(rwOp1 != nullptr && rwOp2 != nullptr); - EventIdInfo singleEventId(1); - if (!isBackwardSync(occ1, occ2)) { - return singleEventId; - } - if (auto eventIdInfo = checkCVMultiBufferUnrollEventIdInfo(rwOp1, rwOp2)) { - return eventIdInfo.value(); - } - if (auto eventIdInfo = checkCVMultiBufferPreloadEventIdInfo(rwOp1, rwOp2)) { - return eventIdInfo.value(); - } - if (auto eventIdInfo = - checkMultiBufferEventIdInfo(occ1, occ2, rwOp1, rwOp2)) { - return eventIdInfo.value(); - } - return singleEventId; -} - -// Graph-based check to determine if adding a sync between occ1 and occ2 would -// block progress. Uses GraphSolver (Dijkstra) to estimate minimal reachable -// index. 
-bool Solver::checkGraphConflict( - Occurrence *occ1, Occurrence *occ2, CorePipeInfo corePipeSrc, - CorePipeInfo corePipeDst, EventIdInfo eventIdInfo, - std::optional startIndex, std::optional endIndex, - const llvm::SmallVector &extraConflictPairs, - const llvm::SmallVector &ignoreConflictPairs) { - assert(occ1 != nullptr && occ2 != nullptr); - if (!startIndex.has_value()) { - startIndex = occ1->endIndex; - } - if (!endIndex.has_value()) { - endIndex = occ2->startIndex; - } - GraphSolver graphSolver(options); - llvm::DenseSet visited; - auto handleConflictPair = [&](ConflictPair *conflictPair) { - if (conflictPair->couldNotRun) { - return; - } - if (conflictPair->endIndex < startIndex.value() || - conflictPair->startIndex > endIndex.value()) { - return; - } - if (conflictPair->isInnerBackward) { - if ((eventIdInfo.eventIdNum * eventIdInfo.eventIdRepeatNum) < - (conflictPair->eventIdInfo.eventIdNum * - conflictPair->eventIdInfo.eventIdRepeatNum)) { - return; - } - } - if (llvm::find(ignoreConflictPairs, conflictPair) != - ignoreConflictPairs.end()) { - return; - } - auto [it, isInserted] = visited.insert(conflictPair); - if (!isInserted) { - return; - } - DEBUG_WITH_TYPE("gss-sync-solver-check-graph-conflict", { - llvm::dbgs() << "add-conflict-pair: " << conflictPair->str() << '\n'; - }); - graphSolver.addConflictPair(conflictPair); - }; - - for (auto *parOcc : occ1->getAllParents()) { - if (scopeOccChosenConflicts.contains(parOcc)) { - for (auto *conflictPair : scopeOccChosenConflicts[parOcc]) { - handleConflictPair(conflictPair); - } - } - } - for (auto *parOcc : occ2->getAllParents()) { - if (scopeOccChosenConflicts.contains(parOcc)) { - for (auto *conflictPair : scopeOccChosenConflicts[parOcc]) { - handleConflictPair(conflictPair); - } - } - } - for (auto &[scopeOccPair, chosenConflicts] : scopeOccPairChosenConflicts) { - auto [scopeOcc1, scopeOcc2] = scopeOccPair; - if (scopeOcc1->isProperAncestor(occ1) && - scopeOcc2->isProperAncestor(occ2)) { - for (auto 
*conflictPair : chosenConflicts) { - handleConflictPair(conflictPair); - } - } - } - for (auto *parOcc : occ1->getAllParents()) { - if (persistentScopeOccChosenConflicts.contains(parOcc)) { - for (auto *conflictPair : persistentScopeOccChosenConflicts[parOcc]) { - handleConflictPair(conflictPair); - } - } - } - for (auto *parOcc : occ2->getAllParents()) { - if (persistentScopeOccChosenConflicts.contains(parOcc)) { - for (auto *conflictPair : persistentScopeOccChosenConflicts[parOcc]) { - handleConflictPair(conflictPair); - } - } - } - for (auto *conflictPair : extraConflictPairs) { - handleConflictPair(conflictPair); - } - std::optional mnDistance; - if (options.enableUnitFlagFeature) { - mnDistance = graphSolver.runDijkstraUnitFlagEnabled( - occ1, occ2, corePipeSrc, corePipeDst, startIndex.value(), - endIndex.value()); - } else { - mnDistance = graphSolver.runDijkstra(corePipeSrc, corePipeDst, - startIndex.value(), endIndex.value()); - } - return !mnDistance.has_value() || mnDistance.value() > endIndex.value(); -} - -bool Solver::checkSyncOpsConflicts(ConflictPair *conflictPair1, - ConflictPair *conflictPair2) { - if (conflictPair1->isBarrier() || conflictPair2->isBarrier()) { - return false; - } - if (conflictPair1->startIndex > conflictPair2->startIndex) { - std::swap(conflictPair1, conflictPair2); - } - if (conflictPair1->startIndex >= conflictPair2->startIndex || - conflictPair1->endIndex >= conflictPair2->endIndex) { - return true; - } - bool result = false; - if (conflictPair1->setCorePipeInfo != conflictPair2->setCorePipeInfo) { - auto corePipeSrc = conflictPair1->setCorePipeInfo; - auto corePipeDst = conflictPair2->setCorePipeInfo; - Occurrence *occ1 = conflictPair1->setOcc; - Occurrence *occ2 = conflictPair2->setOcc; - auto startIndex = conflictPair1->startIndex + 1; - auto endIndex = conflictPair2->startIndex; - conflictPair1->startIndex += 1; - assert(occ1 != nullptr && occ2 != nullptr); - result = result || - checkGraphConflict(occ1, occ2, corePipeSrc, 
corePipeDst, - conflictPair1->eventIdInfo, startIndex, - endIndex, {conflictPair1}, {conflictPair2}); - conflictPair1->startIndex -= 1; - } - if (conflictPair1->waitCorePipeInfo != conflictPair2->waitCorePipeInfo) { - auto corePipeSrc = conflictPair1->waitCorePipeInfo; - auto corePipeDst = conflictPair2->waitCorePipeInfo; - Occurrence *occ1 = conflictPair1->waitOcc; - Occurrence *occ2 = conflictPair2->waitOcc; - auto startIndex = conflictPair1->endIndex; - auto endIndex = conflictPair2->endIndex - 1; - conflictPair2->endIndex -= 1; - assert(occ1 != nullptr && occ2 != nullptr); - result = result || - checkGraphConflict(occ1, occ2, corePipeSrc, corePipeDst, - conflictPair1->eventIdInfo, startIndex, - endIndex, {conflictPair1}, {conflictPair2}); - conflictPair2->endIndex += 1; - } - DEBUG_WITH_TYPE("gss-check-sync-ops-conflicts", { - if (result) { - llvm::dbgs() << "sync-ops-conflict-found: " << "\n"; - llvm::dbgs() << " " << conflictPair1->str() << '\n'; - llvm::dbgs() << " " << conflictPair2->str() << '\n'; - } - }); - return result; -} - -// Check whether two ConflictPair entries conflict in pipe and time ranges. 
-bool Solver::checkIntersect(ConflictPair *conflictPair1, - ConflictPair *conflictPair2) { - assert(conflictPair1 != nullptr && conflictPair2 != nullptr); - if (conflictPair1 == conflictPair2) { - return false; - } - if (conflictPair1->isBarrier() || conflictPair2->isBarrier()) { - return false; - } - if (conflictPair1->dontCheckForConflict || - conflictPair2->dontCheckForConflict) { - return false; - } - if (options.isCrossCoreMode()) { - return checkSyncOpsConflicts(conflictPair1, conflictPair2); - } - if (conflictPair1->setCorePipeInfo != conflictPair2->setCorePipeInfo || - conflictPair1->waitCorePipeInfo != conflictPair2->waitCorePipeInfo) { - return false; - } - for (auto [l1, r1] : getRanges(conflictPair1)) { - for (auto [l2, r2] : getRanges(conflictPair2)) { - if (checkRangesIntersect(l1, r1 + 1, l2, r2 + 1)) { - return true; - } - } - } - return false; -} - -// Obtain available event ids while accounting for already chosen conflicts. -std::vector -Solver::getIntersectingConflictPairs(ConflictPair *conflictPair) { - assert(conflictPair != nullptr); - if (conflictPair->isBarrier()) { - return {}; - } - if (conflictPair->dontCheckForConflict) { - return {}; - } - std::vector intersectingConflictPairs; - for (auto &curConflictPair : chosenConflictedPairs) { - if (checkIntersect(conflictPair, curConflictPair.get())) { - intersectingConflictPairs.push_back(curConflictPair.get()); - } - } - for (auto &curConflictPair : persistentChosenConflictedPairs) { - if (checkIntersect(conflictPair, curConflictPair.get())) { - intersectingConflictPairs.push_back(curConflictPair.get()); - } - } - return intersectingConflictPairs; -} - -// Processed-pair tracking helpers. 
-bool Solver::checkVisited(Occurrence *occ1, Occurrence *occ2) { - auto [it, isInserted] = processedOccPairs.insert(std::make_pair(occ1, occ2)); - return !isInserted; -} - -bool Solver::checkSkippable(bool reverseOrder, Occurrence *occ) { - return skipOcc[reverseOrder].contains(occ); -} - -// Synced-pair memoization helpers. -EventIdNode *Solver::getOldEventIdNodeIfExists(ConflictPair *conflictPair) { - assert(conflictPair != nullptr); - auto oldConflictPairs = getMemorizedSyncedPairs(conflictPair); - if (oldConflictPairs.empty()) { - return {}; - } - ConflictPair *oldConflictPair = *oldConflictPairs.begin(); - assert(oldConflictPair != nullptr && oldConflictPair->eventIdNode != nullptr); - return oldConflictPair->eventIdNode; -} - -llvm::DenseSet -Solver::getMemorizedSyncedPairs(ConflictPair *conflictPair) { - auto key = std::make_tuple( - conflictPair->backwardSyncLoopOp, conflictPair->op1, conflictPair->op2, - conflictPair->setCorePipeInfo, conflictPair->waitCorePipeInfo); - return syncedPairs[key]; -} - -void Solver::memorizeSyncedPair(ConflictPair *conflictPair) { - auto key = std::make_tuple( - conflictPair->backwardSyncLoopOp, conflictPair->op1, conflictPair->op2, - conflictPair->setCorePipeInfo, conflictPair->waitCorePipeInfo); - syncedPairs[key].insert(conflictPair); -#ifndef NDEBUG - for (auto *oldConflictPair : syncedPairs[key]) { - assert(oldConflictPair->eventIdNode == conflictPair->eventIdNode); - } -#endif -} - -void Solver::forgetSyncedPair(ConflictPair *conflictPair) { - assert(conflictPair != nullptr); - auto key = std::make_tuple( - conflictPair->backwardSyncLoopOp, conflictPair->op1, conflictPair->op2, - conflictPair->setCorePipeInfo, conflictPair->waitCorePipeInfo); - syncedPairs[key].erase(conflictPair); -} - -void Solver::memorizeReusedSyncedPair(ConflictPair *conflictPair, - ConflictPair *reusedConflictPair) { - assert(conflictPair != nullptr); - replacedWithReusableSyncedPairs[{ - conflictPair->backwardSyncLoopOp, conflictPair->op1, 
conflictPair->op2, - conflictPair->setCorePipeInfo, conflictPair->waitCorePipeInfo}] = - reusedConflictPair; -} - -bool Solver::skipMMad1DecomposedLoopOpt(Occurrence *occ1, Occurrence *occ2) { - auto *parentLoopOp1 = OperationBase::getParentloop(occ1->op); - auto *parentLoopOp2 = OperationBase::getParentloop(occ2->op); - if (parentLoopOp1 != nullptr && parentLoopOp2 != nullptr) { - if (parentLoopOp1 != parentLoopOp2) { - if (isa(parentLoopOp1) && - isa(parentLoopOp2)) { - return true; - } - } - } - return false; -} - -std::optional> -Solver::checkAndApplyMmadl0LoopOpt(ConflictPair *conflictPair, Occurrence *occ1, - Occurrence *occ2, Occurrence *parOcc1, - Occurrence *parOcc2) { - if (!options.decomposeMmadl1Op) { - return {}; - } - if (occ1->parentOcc != nullptr && occ1->parentOcc->parentOcc != nullptr && - occ1->parentOcc->parentOcc->parentOcc == parOcc1 && - llvm::isa_and_present( - occ1->op) && - llvm::isa_and_present( - occ1->parentOcc->parentOcc->op)) { - conflictPair->setOnLastIterOnly = true; - return std::make_pair(occ1, parOcc2); - } - if (!conflictPair->isInnerBackward && occ2->parentOcc != nullptr && - occ2->parentOcc->parentOcc != nullptr && - occ2->parentOcc->parentOcc->parentOcc == parOcc2 && - llvm::isa_and_present( - occ2->op) && - llvm::isa_and_present( - occ2->parentOcc->parentOcc->op)) { - conflictPair->waitOnFirstIterOnly = true; - return std::make_pair(parOcc1, occ2); - } - return {}; -} - -std::optional Solver::checkUnitFlagPatterns(Occurrence *occ1, - Occurrence *occ2) { - return {}; -} - -Occurrence *Solver::getBeforePlaceHolderOcc(Occurrence *occ) { - assert(occ != nullptr); - assert(llvm::isa_and_present(occ->op)); - int index = occ->syncIrIndex - 1; - assert(0 <= index && index < static_cast(syncIr.size())); - auto *placeHolderOcc = syncIr[index].get(); -#ifndef NDEBUG - auto *placeHolderOp = llvm::dyn_cast(placeHolderOcc->op); - assert(placeHolderOp != nullptr); - assert(placeHolderOp->beforeOp == occ->op); -#endif - return 
placeHolderOcc; -} - -Occurrence *Solver::getAfterPlaceHolderOcc(Occurrence *occ) { - assert(occ != nullptr); - assert(llvm::isa_and_present(occ->op)); - int index = occ->syncIrEndIndex; - assert(0 <= index && index < static_cast(syncIr.size())); - auto *placeHolderOcc = syncIr[index].get(); -#ifndef NDEBUG - auto *placeHolderOp = llvm::dyn_cast(placeHolderOcc->op); - assert(placeHolderOp != nullptr); - assert(placeHolderOp->afterOp == occ->op); -#endif - return placeHolderOcc; -} - -Occurrence *Solver::getScopeBeginPlaceHolderOcc(Occurrence *occ) { - assert(occ != nullptr); - assert(llvm::isa_and_present(occ->op)); - int index = occ->syncIrIndex + 1; - assert(0 <= index && index < static_cast(syncIr.size())); - auto *placeHolderOcc = syncIr[index].get(); -#ifndef NDEBUG - auto *placeHolderOp = llvm::dyn_cast(placeHolderOcc->op); - assert(placeHolderOp != nullptr); - assert(placeHolderOp->scopeBegin == occ->op); -#endif - return placeHolderOcc; -} - -Occurrence *Solver::getScopeEndPlaceHolderOcc(Occurrence *occ) { - assert(occ != nullptr); - assert(llvm::isa_and_present(occ->op)); - int index = occ->syncIrEndIndex - 1; - assert(0 <= index && index < static_cast(syncIr.size())); - auto *placeHolderOcc = syncIr[index].get(); -#ifndef NDEBUG - auto *placeHolderOp = llvm::dyn_cast(placeHolderOcc->op); - assert(placeHolderOp != nullptr); - assert(placeHolderOp->scopeEnd == occ->op); -#endif - return placeHolderOcc; -} - -std::pair -Solver::getSetWaitLCAPairOcc(Occurrence *occ1, Occurrence *occ2) { - assert(occ1 != nullptr && occ2 != nullptr); - - auto [grandParOcc1, grandParOcc2] = Occurrence::getLCAPair(occ1, occ2); - assert(grandParOcc1 != nullptr && grandParOcc2 != nullptr); - assert(grandParOcc1->parentOcc != nullptr && - grandParOcc2->parentOcc != nullptr); - - auto [parOp1, parOp2] = OperationBase::getLCAPair(occ1->op, occ2->op); - assert(parOp1 != nullptr && parOp2 != nullptr); - assert(parOp1->parentOp != nullptr && parOp2->parentOp != nullptr); - 
assert(parOp1->parentOp == parOp2->parentOp); - - auto *parOcc1 = occ1->getParentWithOp(parOp1->parentOp); - auto *parOcc2 = occ2->getParentWithOp(parOp2->parentOp); - assert(parOcc1 != nullptr && parOcc2 != nullptr); - assert(parOcc1 != occ1 && parOcc2 != occ2); - - auto *setOcc = occ1->getNthParent(occ1->depth - parOcc1->depth - 1); - auto *waitOcc = occ2->getNthParent(occ2->depth - parOcc2->depth - 1); - assert(setOcc != nullptr && waitOcc != nullptr); - assert(parOcc1->isProperAncestor(setOcc)); - assert(parOcc2->isProperAncestor(waitOcc)); - - auto *parLoop = Occurrence::getParentloop(setOcc); - while (parLoop != nullptr && grandParOcc1->isProperAncestor(parLoop)) { - setOcc = parLoop; - waitOcc = Occurrence::getParentloop(waitOcc); - parLoop = Occurrence::getParentloop(setOcc); - } - return std::make_pair(setOcc, waitOcc); -} - -std::pair -Solver::getFixedSetWaitOcc(Occurrence *occ1, Occurrence *occ2) { - // - get setOcc waitOcc where: - // setOcc->op->parent = waitOcc->op->parent = lca(occ1, occ2)->op - auto [setOcc, waitOcc] = getSetWaitLCAPairOcc(occ1, occ2); - - // - check if it's the case of while loop: - // while{ - // before{ - // occ1 - // } - // setOcc; - // waitOcc; - // after{ - // occ2 - // } - // } - // - and fix it to be: - // while{ - // before{ - // occ1 - // setOcc; - // ... 
- // waitOcc; - // placeHolder - // } - // after{ - // occ2 - // } - // } - if (setOcc->op != waitOcc->op) { - if (auto *parLoopOp = - llvm::dyn_cast_if_present(setOcc->parentOcc->op)) { - if (parLoopOp->body.size() > 1 && !isa(waitOcc->op)) { - auto *placeHolderOcc = getScopeEndPlaceHolderOcc(setOcc); - std::tie(setOcc, waitOcc) = getSetWaitLCAPairOcc(occ1, placeHolderOcc); - } - } - } - - // - check if it's the case of: - // loop(iter-1){ - // condition{ - // true-scope{} - // setOcc() - // false-scope{} - // } - // } - // loop(iter-2){ - // condition{ - // true-scope{} - // waitOcc() - // false-scope{} - // } - // } - // - and fix it to be: - // loop(iter-1){ - // condition{ - // true-scope{} - // false-scope{} - // } - // setOcc() - // } - // loop(iter-2){ - // waitOcc() - // condition{ - // true-scope{} - // false-scope{} - // } - // } - if (isBackwardSync(occ1, occ2)) { - if (setOcc->parentOcc != nullptr) { - if (llvm::isa_and_present(setOcc->parentOcc->op)) { - setOcc = setOcc->parentOcc; - } - } - if (waitOcc->parentOcc != nullptr) { - if (llvm::isa_and_present(waitOcc->parentOcc->op)) { - waitOcc = waitOcc->parentOcc; - } - } - } - - // - for the case of cv-pipelining: - // loop(){ - // op1 - // } {unroll=x} - // setOcc - // waitOcc - // loop(){ - // op2 - // } {unroll=x} - // - and fix it to be: - // loop(){ - // op1 - // setOcc - // } {unroll=x} - // loop(){ - // waitOcc - // op2 - // } {unroll=x} - if (options.isCrossCoreMode()) { - assert(setOcc->op != nullptr && waitOcc->op != nullptr); - auto *forOp1 = llvm::dyn_cast_if_present(setOcc->op); - auto *forOp2 = llvm::dyn_cast_if_present(waitOcc->op); - if (forOp1 != nullptr && forOp2 != nullptr) { - if (forOp1->multibufferUnrollNum && forOp2->multibufferUnrollNum) { - assert(forOp1->multibufferUnrollNum == forOp2->multibufferUnrollNum); - setOcc = occ1->getNthParent(occ1->depth - setOcc->depth - 2); - waitOcc = occ2->getNthParent(occ2->depth - waitOcc->depth - 2); - } - } - } - - // - for the case of 
cv-pipelining: - // scope(){ - // op1 - // } {preload=x} - // setOcc - // waitOcc - // scope(){ - // op2 - // } {preload=x} - // - and fix it to be: - // scope(){ - // op1 - // setOcc - // } {preload=x} - // scope(){ - // waitOcc - // op2 - // } {preload=x} - if (options.isCrossCoreMode()) { - assert(setOcc->op != nullptr && waitOcc->op != nullptr); - auto *scopeOp1 = llvm::dyn_cast_if_present(setOcc->op); - auto *scopeOp2 = llvm::dyn_cast_if_present(waitOcc->op); - if (scopeOp1 != nullptr && scopeOp2 != nullptr) { - if (scopeOp1->maxPreloadNum && scopeOp2->maxPreloadNum) { - assert(scopeOp1->maxPreloadNum == scopeOp2->maxPreloadNum); - setOcc = occ1->getNthParent(occ1->depth - setOcc->depth - 2); - waitOcc = occ2->getNthParent(occ2->depth - waitOcc->depth - 2); - } - } - } - - // - check if it's the case of: - // { - // op1 - // setOcc - // ... - // waitOcc - // loop(){} - // setOcc - // ... - // waitOcc - // op2 - // } - // - and fix it to be: - // { - // op1 - // setOcc - // ... - // waitOcc - // placeHolder - // loop(){} - // placeHolder - // setOcc - // ... 
- // waitOcc - // op2 - // } - if (llvm::isa_and_present(setOcc->op)) { - setOcc = getAfterPlaceHolderOcc(setOcc); - } - if (llvm::isa_and_present(waitOcc->op)) { - waitOcc = getBeforePlaceHolderOcc(waitOcc); - } - - return std::make_pair(setOcc, waitOcc); -} - -std::optional> -Solver::getFunctionBlockSetWaitOcc(Occurrence *occ1, Occurrence *occ2) { - assert(occ1 != nullptr && occ2 != nullptr); - auto *parFunctionBlock1 = occ1->getParentOfType(); - auto *parFunctionBlock2 = occ2->getParentOfType(); - if (parFunctionBlock1 == parFunctionBlock2) { - return {}; - } - auto *placeHolderOcc = getScopeBeginPlaceHolderOcc(parFunctionBlock2); - return std::make_pair(placeHolderOcc, occ2); -} - -std::optional> -Solver::getUnlikelyCondSetWaitOcc(Occurrence *occ1, Occurrence *occ2) { - assert(occ1 != nullptr && occ2 != nullptr); - if (options.isCrossCoreMode() && isBackwardSync(occ1, occ2)) { - return {}; - } - if (auto *unlikelyParCondOcc1 = - Occurrence::getUnlikelyParentCondition(occ1)) { - if (!unlikelyParCondOcc1->isProperAncestor(occ2)) { - auto *parentLoopOcc = Occurrence::getParentloop(unlikelyParCondOcc1); - if (parentLoopOcc == nullptr || parentLoopOcc->isProperAncestor(occ2)) { - auto *placeHolderOcc = getScopeEndPlaceHolderOcc( - occ1->getNthParent(occ1->depth - unlikelyParCondOcc1->depth - 1)); - return std::make_pair(occ1, placeHolderOcc); - } - } - } - if (auto *unlikelyParCondOcc2 = - Occurrence::getUnlikelyParentCondition(occ2)) { - if (!unlikelyParCondOcc2->isProperAncestor(occ1)) { - auto *parentLoopOcc = Occurrence::getParentloop(unlikelyParCondOcc2); - if (parentLoopOcc == nullptr || parentLoopOcc->isProperAncestor(occ1)) { - auto *placeHolderOcc = getScopeBeginPlaceHolderOcc( - occ2->getNthParent(occ2->depth - unlikelyParCondOcc2->depth - 1)); - return std::make_pair(placeHolderOcc, occ2); - } - } - } - return {}; -} - -std::pair Solver::getSetWaitOcc(Occurrence *occ1, - Occurrence *occ2) { - if (auto functionBlockOpt = getFunctionBlockSetWaitOcc(occ1, 
occ2)) { - std::tie(occ1, occ2) = functionBlockOpt.value(); - } - if (auto unlikelyOpt = getUnlikelyCondSetWaitOcc(occ1, occ2)) { - std::tie(occ1, occ2) = unlikelyOpt.value(); - } - return getFixedSetWaitOcc(occ1, occ2); -} - -Occurrence *Solver::getBarrierWaitOcc(Occurrence *occ1, Occurrence *occ2) { - auto [setOcc, waitOcc] = getSetWaitOcc(occ1, occ2); - if (!waitOcc->isProperAncestor(occ2)) { - return waitOcc; - } - auto allParents = occ2->getAllParents(); - while (!allParents.empty() && allParents.back()->isProperAncestor(waitOcc)) { - allParents.pop_back(); - } - while (allParents.size() >= 2 && - llvm::isa_and_present(allParents.back()->op)) { - allParents.pop_back(); - assert(llvm::isa_and_present(allParents.back()->op)); - allParents.pop_back(); - } - waitOcc = !allParents.empty() ? allParents.back() : occ2; - return waitOcc; -} - -void Solver::insertBarrierAllBeforeOcc(Occurrence *occ, bool isUseless, - bool isPersistent) { - assert(occ != nullptr); - auto *rwOp = llvm::dyn_cast_if_present(occ->op); - assert(rwOp != nullptr); - auto conflictPair = std::make_unique( - nullptr, nullptr, rwOp, rwOp, occ, occ, - CorePipeInfo(pto::TCoreType::CUBE_OR_VECTOR, pto::PIPE::PIPE_ALL), - CorePipeInfo(pto::TCoreType::CUBE_OR_VECTOR, pto::PIPE::PIPE_ALL), - occ->startIndex, occ->startIndex); - conflictPair->isUseless = isUseless; - auto *normScopeOcc = occ->parentOcc; - assert(normScopeOcc != nullptr); - LLVM_DEBUG(llvm::dbgs() << (isPersistent ? 
"is-persistent " : "") - << occ->op->str(0, false) << ' ' - << conflictPair->str() << '\n';); - if (isPersistent) { - persistentScopeOccChosenConflicts[normScopeOcc].insert(conflictPair.get()); - persistentChosenConflictedPairs.push_back(std::move(conflictPair)); - } else { - insertedBarrierAllBefore[occ->op].insert({occ, isUseless}); - scopeOccChosenConflicts[normScopeOcc].insert(conflictPair.get()); - chosenConflictedPairs.push_back(std::move(conflictPair)); - } -} - -void Solver::insertBarrierAllBeforeOp(OperationBase *op, bool isUseless, - bool isPersistent) { - assert(op != nullptr); - for (auto *occ : opAllOccurrences[op]) { - insertBarrierAllBeforeOcc(occ, isUseless, isPersistent); - isUseless = true; - } -} - -// When barrier-all markers need to be chosen, insert them before all -// occurrences for the chosen op. -void Solver::pickAndInsertABarrierAll() { - assert(!insertedBarrierAllBefore.empty()); - OperationBase *chosenOp = nullptr; - for (auto &[op, vec] : insertedBarrierAllBefore) { - if (vec.empty()) { - continue; - } - if (chosenOp == nullptr || chosenOp->id > op->id) { - chosenOp = op; - } - } - assert(chosenOp != nullptr); - insertBarrierAllBeforeOp(chosenOp, /*isUseless=*/false, - /*isPersistent=*/true); -} - -bool Solver::isBackwardSync(Occurrence *occ1, Occurrence *occ2) { - if (occ1->op->id >= occ2->op->id) { - return true; - } - assert(occ1 != nullptr && occ2 != nullptr); - assert(occ1->op != nullptr && occ2->op != nullptr); - auto [parOcc1, parOcc2] = Occurrence::getLCAPair(occ1, occ2); - auto [parOp1, parOp2] = OperationBase::getLCAPair(occ1->op, occ2->op); - return parOcc1->parentOcc->op != parOp1->parentOp; -} - -bool Solver::reuseCmp(ConflictPair *conflictPair1, - ConflictPair *conflictPair2) { - assert(conflictPair1 != nullptr && conflictPair2 != nullptr); - assert(conflictPair1->op1 != nullptr && conflictPair1->op2 != nullptr); - assert(conflictPair2->op1 != nullptr && conflictPair2->op2 != nullptr); - if (conflictPair1->startIndex != 
conflictPair2->startIndex) { - return conflictPair1->startIndex < conflictPair2->startIndex; - } - if (conflictPair1->endIndex != conflictPair2->endIndex) { - return conflictPair1->endIndex > conflictPair2->endIndex; - } - if (conflictPair1->op1 != conflictPair2->op1) { - return conflictPair1->op1->id > conflictPair2->op1->id; - } - if (conflictPair1->op2 != conflictPair2->op2) { - return conflictPair1->op2->id > conflictPair2->op2->id; - } - return false; -} - -ConflictPair *Solver::getReusableConflictPair( - ConflictPair *conflictPair, - const llvm::DenseSet &conflictPairsSet) { - assert(conflictPair != nullptr); - ConflictPair *ret = nullptr; - for (auto *curConflictPair : conflictPairsSet) { - if (curConflictPair->isBarrier() || curConflictPair->dontReuse) { - continue; - } - if (curConflictPair->op1 != conflictPair->op1 || - curConflictPair->op2 != conflictPair->op2 || - curConflictPair->setCorePipeInfo != conflictPair->setCorePipeInfo || - curConflictPair->waitCorePipeInfo != conflictPair->waitCorePipeInfo) { - continue; - } - if (!checkIntersect(conflictPair, curConflictPair)) { - continue; - } - if (curConflictPair->startIndex >= conflictPair->startIndex) { - continue; - } - if (conflictPair->eventIdNode->eventIdNum < - curConflictPair->eventIdNode->eventIdNum) { - continue; - } - assert(conflictPair->eventIdNode != nullptr); - assert(curConflictPair->eventIdNode != nullptr); - if (conflictPair->eventIdNode->eventIdNum > - curConflictPair->eventIdNode->eventIdNum) { - if (conflictPair->eventIdNode->eventIdNum % - curConflictPair->eventIdNode->eventIdNum) { - continue; - } - } - assert(conflictPair->startIndex <= curConflictPair->endIndex); - assert(curConflictPair->endIndex <= conflictPair->endIndex); - if (ret == nullptr || reuseCmp(ret, curConflictPair)) { - ret = curConflictPair; - } - } - return ret; -} - -bool Solver::reuseConflictPair(ConflictPair *conflictPair, - Occurrence *scopeOcc1, Occurrence *scopeOcc2) { - if (conflictPair->isBarrier()) { - 
return false; - } - if (scopeOcc1->op != scopeOcc2->op) { - return false; - } - if (!barrierAllPairs.empty()) { - return false; - } - - ConflictPair *oldReusedConflictPair = nullptr; - if (conflictPair->isUseless) { - auto it = replacedWithReusableSyncedPairs.find( - {conflictPair->backwardSyncLoopOp, conflictPair->op1, conflictPair->op2, - conflictPair->setCorePipeInfo, conflictPair->waitCorePipeInfo}); - if (it != replacedWithReusableSyncedPairs.end()) { - oldReusedConflictPair = it->second; - } - } - -#ifndef NDEBUG - if (!conflictPair->isUseless) { - auto it = replacedWithReusableSyncedPairs.find( - {conflictPair->backwardSyncLoopOp, conflictPair->op1, conflictPair->op2, - conflictPair->setCorePipeInfo, conflictPair->waitCorePipeInfo}); - assert(it == replacedWithReusableSyncedPairs.end()); - } -#endif - - if (conflictPair->isUseless && oldReusedConflictPair == nullptr) { - return false; - } - - auto corePipeSrc = conflictPair->setCorePipeInfo; - auto corePipeDst = conflictPair->waitCorePipeInfo; - - if (oldReusedConflictPair == nullptr) { - if (!reusePairs.contains({corePipeSrc, corePipeDst}) || - reusePairs[{corePipeSrc, corePipeDst}] <= - reusedPairs[{corePipeSrc, corePipeDst}]) { - return false; - } - } - - assert(reusePairs.contains(std::make_tuple(corePipeSrc, corePipeDst))); - assert(reusePairs[std::make_tuple(corePipeSrc, corePipeDst)] >= - reusedPairs[std::make_tuple(corePipeSrc, corePipeDst)]); - - ConflictPair *opt1 = nullptr; - ConflictPair *opt2 = nullptr; - ConflictPair *opt3 = nullptr; - ConflictPair *opt4 = nullptr; - ConflictPair *opt5 = nullptr; - - auto it1 = scopeOccChosenConflicts.find(scopeOcc1); - auto it2 = scopeOccChosenConflicts.find(scopeOcc2); - auto it3 = scopeOccPairChosenConflicts.find({scopeOcc1, scopeOcc2}); - auto it4 = persistentScopeOccChosenConflicts.find(scopeOcc1); - auto it5 = persistentScopeOccChosenConflicts.find(scopeOcc2); - - if (it1 != scopeOccChosenConflicts.end()) { - opt1 = getReusableConflictPair(conflictPair, 
it1->second); - } - if (it2 != scopeOccChosenConflicts.end()) { - opt2 = getReusableConflictPair(conflictPair, it2->second); - } - if (it3 != scopeOccPairChosenConflicts.end()) { - opt3 = getReusableConflictPair(conflictPair, it3->second); - } - if (it4 != persistentScopeOccChosenConflicts.end()) { - opt4 = getReusableConflictPair(conflictPair, it4->second); - } - if (it5 != persistentScopeOccChosenConflicts.end()) { - opt5 = getReusableConflictPair(conflictPair, it5->second); - } - - ConflictPair *reusableConflictPair = nullptr; - for (auto *opt : {opt1, opt2, opt3, opt4, opt5}) { - if (opt != nullptr) { - if (reusableConflictPair == nullptr || - reuseCmp(reusableConflictPair, opt)) { - reusableConflictPair = opt; - } - } - } - - if (reusableConflictPair == nullptr) { - return false; - } - - DEBUG_WITH_TYPE("gss-sync-solver-reuse", { - llvm::dbgs() << "reuse: " << conflictPair->str() << '\n'; - llvm::dbgs() << "with: " << reusableConflictPair->str() << '\n'; - }); - - assert(reusableConflictPair->startIndex < conflictPair->startIndex); - assert(reusableConflictPair->endIndex <= conflictPair->endIndex); - reusableConflictPair->setOp = conflictPair->setOp; - reusableConflictPair->setOcc = conflictPair->setOcc; - reusableConflictPair->startIndex = conflictPair->startIndex; - - if (!conflictPair->isUseless) { - memorizeReusedSyncedPair(conflictPair, reusableConflictPair); - } - - DEBUG_WITH_TYPE("gss-sync-solver-reuse", { - if (oldReusedConflictPair != nullptr) { - llvm::dbgs() << "old-reuse: " << oldReusedConflictPair->str() << '\n'; - } - }); - - if (oldReusedConflictPair != nullptr) { - assert(oldReusedConflictPair->op1 == reusableConflictPair->op1); - assert(oldReusedConflictPair->op2 == reusableConflictPair->op2); - assert(oldReusedConflictPair->waitOp == reusableConflictPair->waitOp); - } - - if (!conflictPair->isUseless) { - reusedPairs[{corePipeSrc, corePipeDst}] += 1; - } - - return true; -} - -std::unique_ptr & -Solver::getEventIdSolverRef(pto::PIPE pipeSrc, 
pto::PIPE pipeDst) { - if (options.isCrossCoreMode()) { - pipeSrc = pto::PIPE::PIPE_UNASSIGNED; - pipeDst = pto::PIPE::PIPE_UNASSIGNED; - } - auto key = std::make_tuple(pipeSrc, pipeDst); - if (!eventIdSolver.contains(key)) { - int64_t eventIdNumMax = - getHWAvailableEventIdNum(options.syncMode, pipeSrc, pipeDst); - if (options.eventIdNumMax.has_value()) { - eventIdNumMax = std::min(eventIdNumMax, options.eventIdNumMax.value()); - eventIdNumMax = std::max(eventIdNumMax, 1); - } - eventIdSolver[key] = std::make_unique(eventIdNumMax); - } - return eventIdSolver[key]; -} - -bool Solver::checkReuseMultiBufferFlagId(ConflictPair *conflictPair) { - if (options.useDifferentMultiBufferFlagIds) { - return false; - } - if (!conflictPair->isInnerBackward || - conflictPair->eventIdInfo.eventIdNum <= 1 || - conflictPair->movedToOuterLoop) { - return false; - } - auto [setOcc, waitOcc] = - std::tie(conflictPair->setOcc, conflictPair->waitOcc); - auto *backwardSyncLoopOcc = conflictPair->backwardSyncLoopOcc; - assert(backwardSyncLoopOcc != nullptr); - if (auto *parCondOcc1 = setOcc->getParentOfType()) { - if (!parCondOcc1->isProperAncestor(backwardSyncLoopOcc)) { - return false; - } - } - if (auto *parCondOcc2 = waitOcc->getParentOfType()) { - if (!parCondOcc2->isProperAncestor(backwardSyncLoopOcc)) { - return false; - } - } - return true; -} - -void Solver::handleSetWaitConflict(Occurrence *occ1, Occurrence *occ2, - CorePipeInfo corePipeSrc, - CorePipeInfo corePipeDst, - EventIdInfo eventIdInfo, bool isUseless) { - assert(occ1 != nullptr && occ2 != nullptr); - auto *rwOp1 = llvm::dyn_cast_if_present(occ1->op); - auto *rwOp2 = llvm::dyn_cast_if_present(occ2->op); - assert(rwOp1 != nullptr && rwOp2 != nullptr); - assert(corePipeSrc != corePipeDst); - - Loop *parentLCALoopOp{nullptr}; - Occurrence *parentLCALoopOcc{nullptr}; - Occurrence *parentLCALoopBeforePHOcc{nullptr}; - Occurrence *parentLCALoopAfterPHOcc{nullptr}; - auto [setOcc, waitOcc] = getSetWaitOcc(occ1, occ2); - - auto 
[lcaSetOp, lcaWaitOp] = - OperationBase::getLCAPair(setOcc->op, waitOcc->op); - auto *normScopeOcc1 = setOcc->getParentWithOp(lcaSetOp->parentOp); - auto *normScopeOcc2 = waitOcc->getParentWithOp(lcaWaitOp->parentOp); - assert(normScopeOcc1->op == normScopeOcc2->op); - auto *normScopeOp = normScopeOcc1->op; - assert(normScopeOp != nullptr); - assert(normScopeOp->parentOp != nullptr); - - auto conflictPair = std::make_unique( - rwOp1, rwOp2, setOcc->op, waitOcc->op, setOcc, waitOcc, corePipeSrc, - corePipeDst, setOcc->endIndex, waitOcc->startIndex); - assert(conflictPair->startIndex <= conflictPair->endIndex); - - conflictPair->isUseless = isUseless; - conflictPair->isInnerBackward = isBackwardSync(setOcc, waitOcc); - conflictPair->eventIdInfo = eventIdInfo; - - if (conflictPair->isInnerBackward) { - auto [parOcc1, parOcc2] = Occurrence::getLCAPair(occ1, occ2); - assert(parOcc1 != nullptr && parOcc2 != nullptr); - - parentLCALoopOcc = parOcc1->getParentOfType(); - if (moveBackwardSyncPairsToOutmostLoop) { - while (auto *grandParentLoopOcc = - parentLCALoopOcc->getParentOfType()) { - conflictPair->movedToOuterLoop = true; - parentLCALoopOcc = grandParentLoopOcc; - } - } - assert(parentLCALoopOcc != nullptr); - conflictPair->backwardSyncLoopOcc = parentLCALoopOcc; - - parentLCALoopOp = llvm::dyn_cast(parentLCALoopOcc->op); - assert(parentLCALoopOp != nullptr); - conflictPair->backwardSyncLoopOp = parentLCALoopOp; - - parentLCALoopBeforePHOcc = getBeforePlaceHolderOcc(parentLCALoopOcc); - assert(parentLCALoopBeforePHOcc != nullptr); - parentLCALoopAfterPHOcc = getAfterPlaceHolderOcc(parentLCALoopOcc); - assert(parentLCALoopAfterPHOcc != nullptr); - } - - if (auto setWaitOccs = checkAndApplyMmadl0LoopOpt(conflictPair.get(), occ1, - occ2, setOcc, waitOcc)) { - std::tie(setOcc, waitOcc) = setWaitOccs.value(); - conflictPair->updateSetWaitOccs(setOcc, waitOcc); - } - - if (!conflictPair->isInnerBackward || - disabledMultiEventIdPairs.contains({corePipeSrc, corePipeDst})) { 
- conflictPair->eventIdInfo = EventIdInfo(1); - } - if (checkReuseMultiBufferFlagId(conflictPair.get())) { - conflictPair->eventIdInfo.eventIdRepeatNum = - conflictPair->eventIdInfo.eventIdNum; - conflictPair->eventIdInfo.eventIdNum = 1; - } - - auto &curEventIdSolver = getEventIdSolverRef( - conflictPair->setCorePipeInfo.pipe, conflictPair->waitCorePipeInfo.pipe); - curEventIdSolver->pushActionNone(); - - auto checkColorable = [&]() -> bool { - if (curEventIdSolver->isColorable()) { - return true; - } - LLVM_DEBUG(llvm::dbgs() << "will-be-converted-to-barrier-all " - << conflictPair->str() << '\n';); - insertBarrierAllBeforeOp(occ2->op, conflictPair->isUseless, - /*isPersistent=*/false); - barrierAllPairs.insert({corePipeSrc, corePipeDst}); - curEventIdSolver->undoActions(); - return false; - }; - - if (auto *oldEventIdNode = getOldEventIdNodeIfExists(conflictPair.get())) { - conflictPair->eventIdNode = oldEventIdNode; - curEventIdSolver->insertConflictPair(oldEventIdNode, conflictPair.get()); - } else { - bool reversedPriority = false; - if (conflictPair->isInnerBackward) { - if (OperationBase::getParentloop(occ1->op) == normScopeOp->parentOp && - OperationBase::getParentloop(occ2->op) == normScopeOp->parentOp) { - reversedPriority = true; - } - } - conflictPair->eventIdNode = curEventIdSolver->createNode( - conflictPair.get(), conflictPair->eventIdInfo.eventIdNum, - reversedPriority); - } - - if (options.reuseSyncPairToSaveEventIds) { - if (reuseConflictPair(conflictPair.get(), normScopeOcc1, normScopeOcc2)) { - curEventIdSolver->undoActions(); - return; - } - } - - auto intersectingConflictPairs = - getIntersectingConflictPairs(conflictPair.get()); - curEventIdSolver->addConflicts(conflictPair.get(), intersectingConflictPairs); - if (!checkColorable()) { - return; - } - - LLVM_DEBUG({ - llvm::dbgs() << conflictPair->str() << '\n'; - if (parentLCALoopOcc != nullptr) { - llvm::dbgs() << parentLCALoopOcc->op->str(0, false) << '\n'; - } - }); - - llvm::SmallVector, 
Occurrence *>> - extraConflictPairs; - - auto insertExtraConflictPair = [&](Occurrence *setOcc, Occurrence *waitOcc, - Occurrence *parentScope, - bool couldNotRun = false) -> bool { - assert(setOcc != nullptr && waitOcc != nullptr && parentScope != nullptr); - auto extraConflictPair = conflictPair->clone(setOcc, waitOcc); - extraConflictPair->isUseless = true; - extraConflictPair->dontReuse = true; - if (couldNotRun || options.moveOutAndMergeBackwardSyncPairs) { - extraConflictPair->couldNotRun = true; - } - LLVM_DEBUG({ - llvm::dbgs() << "extra-conflict-pair: " << extraConflictPair->str() - << "\n"; - }); - curEventIdSolver->insertConflictPair(conflictPair->eventIdNode, - extraConflictPair.get()); - auto intersectingConflictPairs = - getIntersectingConflictPairs(extraConflictPair.get()); - curEventIdSolver->addConflicts(extraConflictPair.get(), - intersectingConflictPairs); - if (!checkColorable()) { - return false; - } - extraConflictPairs.push_back( - std::make_pair(std::move(extraConflictPair), parentScope)); - return true; - }; - - if (conflictPair->isInnerBackward && conflictPair->eventIdNode != nullptr) { - bool insertOuterBwdConflictPair = false; - if ((conflictPair->eventIdInfo.eventIdNum * - conflictPair->eventIdInfo.eventIdRepeatNum) > 1) { - insertOuterBwdConflictPair = true; - } else if (options.isCrossCoreMode()) { - if (setOcc->parentOcc == nullptr || - setOcc->parentOcc->parentOcc == nullptr || - setOcc->parentOcc->parentOcc->op != parentLCALoopOp) { - insertOuterBwdConflictPair = true; - } else if (waitOcc->parentOcc == nullptr || - waitOcc->parentOcc->parentOcc == nullptr || - waitOcc->parentOcc->parentOcc->op != parentLCALoopOp) { - insertOuterBwdConflictPair = true; - } - } - if (insertOuterBwdConflictPair) { - // insert useless conflictPair to cover the whole loop when having - // multi-eventid backward sync to reserve the eventIds. 
- if (!insertExtraConflictPair(parentLCALoopBeforePHOcc, - parentLCALoopAfterPHOcc, - parentLCALoopOcc->parentOcc)) { - return; - } - } - } - - if (conflictPair->isInnerBackward && conflictPair->eventIdNode != nullptr) { - // insert header/footer useless conflictPairs to reserve the eventIds. - auto *loopOpOcc1 = getFirstIterOcc(waitOcc, normScopeOcc1); - auto *loopOpOcc2 = getLastIterOcc(setOcc, normScopeOcc2); - if (!insertExtraConflictPair(parentLCALoopBeforePHOcc, loopOpOcc1, - parentLCALoopOcc, /*couldNotRun=*/true)) { - return; - } - if (!insertExtraConflictPair(loopOpOcc2, parentLCALoopAfterPHOcc, - parentLCALoopOcc, /*couldNotRun=*/true)) { - return; - } - } - - bool dontInsert = false; - if (conflictPair->isInnerBackward && normScopeOcc1 != normScopeOcc2) { - auto *parCond = OperationBase::getParentCondition(conflictPair->setOp); - if (auto *conditionOp = llvm::dyn_cast_if_present(parCond)) { - if (parentLCALoopOcc->op->isProperAncestor(conditionOp)) { - scopeOccPairChosenConflicts[{normScopeOcc1, normScopeOcc2}].insert( - conflictPair.get()); - dontInsert = true; - } - } - } - if (!dontInsert) { - assert(parentLCALoopOcc != nullptr || normScopeOcc1 == normScopeOcc2); - scopeOccChosenConflicts[normScopeOcc1].insert(conflictPair.get()); - scopeOccChosenConflicts[normScopeOcc2].insert(conflictPair.get()); - } - - memorizeSyncedPair(conflictPair.get()); - chosenConflictedPairs.push_back(std::move(conflictPair)); - - for (auto &[extraConflictPair, parentScope] : extraConflictPairs) { - scopeOccChosenConflicts[parentScope].insert(extraConflictPair.get()); - chosenConflictedPairs.push_back(std::move(extraConflictPair)); - } - - curEventIdSolver->clearActionStack(); -} - -void Solver::handleBarrierConflict(Occurrence *occ1, Occurrence *occ2, - CorePipeInfo corePipeSrc, - CorePipeInfo corePipeDst, bool isUseless) { - assert(occ1 != nullptr && occ2 != nullptr); - auto *rwOp1 = llvm::dyn_cast_if_present(occ1->op); - auto *rwOp2 = llvm::dyn_cast_if_present(occ2->op); 
- assert(rwOp1 != nullptr && rwOp2 != nullptr); - - assert(corePipeSrc == corePipeDst); - if (corePipeSrc.pipe == pto::PIPE::PIPE_S) { - return; - } - if (options.isRegBasedArch) { - if (corePipeSrc.pipe == pto::PIPE::PIPE_V || - corePipeSrc.pipe == pto::PIPE::PIPE_M) { - return; - } - } - auto *waitOcc = getBarrierWaitOcc(occ1, occ2); - - auto conflictPair = std::make_unique( - rwOp1, rwOp2, waitOcc->op, waitOcc->op, waitOcc, waitOcc, corePipeSrc, - corePipeDst, waitOcc->startIndex, waitOcc->startIndex); - conflictPair->isUseless = isUseless; - assert(conflictPair->startIndex <= conflictPair->endIndex); - - LLVM_DEBUG({ llvm::dbgs() << conflictPair->str() << '\n'; }); - - auto *normScopeOcc = waitOcc->parentOcc; - scopeOccChosenConflicts[normScopeOcc].insert(conflictPair.get()); - chosenConflictedPairs.push_back(std::move(conflictPair)); -} - -void Solver::handleUnitFlagConflict(Occurrence *occ1, Occurrence *occ2, - CorePipeInfo corePipeSrc, - CorePipeInfo corePipeDst, - UnitFlagInfo unitFlagInfo, bool isUseless) { - assert(occ1 != nullptr && occ2 != nullptr); - auto *rwOp1 = llvm::dyn_cast_if_present(occ1->op); - auto *rwOp2 = llvm::dyn_cast_if_present(occ2->op); - assert(rwOp1 != nullptr && rwOp2 != nullptr); - assert(corePipeSrc != corePipeDst); - - auto *setOcc = occ1; - auto *waitOcc = occ2; - auto *normScopeOcc1 = setOcc->parentOcc; - auto *normScopeOcc2 = waitOcc->parentOcc; - - auto conflictPair = std::make_unique( - rwOp1, rwOp2, setOcc->op, waitOcc->op, setOcc, waitOcc, corePipeSrc, - corePipeDst, setOcc->endIndex, waitOcc->startIndex); - assert(conflictPair->startIndex <= conflictPair->endIndex); - - conflictPair->isUseless = true; - conflictPair->dontReuse = true; - conflictPair->replacedWithUnitFlag = true; - conflictPair->dontCheckForConflict = true; - conflictPair->isInnerBackward = isBackwardSync(setOcc, waitOcc); - -#ifndef NDEBUG - Occurrence *parentLCALoopOcc{nullptr}; - if (conflictPair->isInnerBackward) { - auto [parOcc1, parOcc2] = 
Occurrence::getLCAPair(occ1, occ2); - assert(parOcc1 != nullptr && parOcc2 != nullptr); - parentLCALoopOcc = Occurrence::getParentloop(parOcc1); - assert(parentLCALoopOcc != nullptr); - } - - LLVM_DEBUG({ - llvm::dbgs() << conflictPair->str() << '\n'; - if (parentLCALoopOcc != nullptr) { - llvm::dbgs() << parentLCALoopOcc->op->str(0, false) << '\n'; - } - }); -#endif - - occ1->unitFlagInfo.merge(unitFlagInfo, occ1, occ2, - /*asSet=*/true, /*asWait=*/false); - occ2->unitFlagInfo.merge(unitFlagInfo, occ1, occ2, - /*asSet=*/false, /*asWait=*/true); - if (!isUseless) { - rwOp1->mergedUnitFlagInfo.merge(unitFlagInfo, /*asSet=*/true, - /*asWait=*/false); - rwOp2->mergedUnitFlagInfo.merge(unitFlagInfo, /*asSet=*/false, - /*asWait=*/true); - } - - scopeOccPairChosenConflicts[{normScopeOcc1, normScopeOcc2}].insert( - conflictPair.get()); - chosenConflictedPairs.push_back(std::move(conflictPair)); -} - -void Solver::handleConflict(Occurrence *occ1, Occurrence *occ2, - RWOperation *rwOp1, RWOperation *rwOp2, - CorePipeInfo corePipeSrc, CorePipeInfo corePipeDst, - EventIdInfo eventIdInfo, bool isUseless) { - if (!checkGraphConflict(occ1, occ2, corePipeSrc, corePipeDst, eventIdInfo)) { - return; - } - LLVM_DEBUG({ - llvm::dbgs() << "conflict found: " << "eventIdNum(" - << eventIdInfo.eventIdNum << ")\n"; - llvm::dbgs() << occ1->syncIrIndex << ' ' << occ1->startIndex << ' ' - << occ1->endIndex << ' ' << rwOp1->str(0, false) << '\n'; - llvm::dbgs() << occ2->syncIrIndex << ' ' << occ2->startIndex << ' ' - << occ2->endIndex << ' ' << rwOp2->str(0, false) << '\n'; - }); - if (corePipeSrc == corePipeDst) { - handleBarrierConflict(occ1, occ2, corePipeSrc, corePipeDst, isUseless); - } else if (auto unitFlagInfo = checkUnitFlagPatterns(occ1, occ2)) { - handleUnitFlagConflict(occ1, occ2, corePipeSrc, corePipeDst, - unitFlagInfo.value(), isUseless); - } else { - handleSetWaitConflict(occ1, occ2, corePipeSrc, corePipeDst, eventIdInfo, - isUseless); - } -} - -void Solver::calcAllEventIds() 
{ - for (auto &[pipes, eventIdSolver] : eventIdSolver) { - assert(eventIdSolver != nullptr); - - [[maybe_unused]] auto result = - eventIdSolver->shrinkEventIdMaxToEventIdNum(); - assert(llvm::succeeded(result)); - assert(eventIdSolver->isColorable()); - } -} - -void Solver::collectBackwardSyncEventIds() { - LLVM_DEBUG(llvm::dbgs() << "collectBackwardSyncEventIds\n";); - for (auto &conflictPair : chosenConflictedPairs) { - if (!conflictPair->isUseless && conflictPair->isInnerBackward && - conflictPair->eventIdNode != nullptr) { - LLVM_DEBUG(llvm::dbgs() << " " << conflictPair->str() << "\n";); - for (auto eventId : conflictPair->eventIdNode->getEventIds()) { - auto &e = backwardSyncEvents[conflictPair->backwardSyncLoopOp] - [{conflictPair->setCorePipeInfo, - conflictPair->waitCorePipeInfo}][eventId]; - e = std::max(e, conflictPair->eventIdInfo.eventIdRepeatNum); - } - } - } -} - -void Solver::resetAndBuildSetWaitOpIndex(const SyncMap &syncMapBefore, - const SyncMap &syncMapAfter) { - globalSetWaitIndex = 0; - setWaitStartIndex.clear(); - setWaitEndIndex.clear(); - setWaitStartIndexInclusive.clear(); - setWaitEndIndexInclusive.clear(); - setWaitFlagOpsIndex.clear(); - collectSetWaitOpsIndexes(funcIr.get(), syncMapBefore, syncMapAfter); -} - -std::set> & -Solver::getSetWaitOpsIndexRef(pto::PIPE pipeSrc, pto::PIPE pipeDst, - int64_t eventId) { - auto key = std::make_tuple(pipeSrc, pipeDst, eventId); - return setWaitFlagOpsIndex[key]; -} - -// Collect indices for all Set/Wait ops to facilitate merging decisions. 
-void Solver::collectSetWaitOpsIndexes(OperationBase *op, - const SyncMap &syncMapBefore, - const SyncMap &syncMapAfter) { - assert(op != nullptr); - setWaitStartIndexInclusive[op] = globalSetWaitIndex++; - if (syncMapBefore.count(op)) { - auto *it = syncMapBefore.find(op); - assert(it != syncMapBefore.end()); - for (auto &syncOp : it->second) { - if (auto *setWaitOp = llvm::dyn_cast(syncOp.get())) { - for (auto eventId : setWaitOp->eventIds) { - auto &index = getSetWaitOpsIndexRef(setWaitOp->pipeSrc, - setWaitOp->pipeDst, eventId); - index.insert({globalSetWaitIndex++, setWaitOp}); - } - } - } - } - setWaitStartIndex[op] = globalSetWaitIndex++; - if (auto *scopeOp = llvm::dyn_cast(op)) { - for (auto &childOp : scopeOp->body) { - collectSetWaitOpsIndexes(childOp.get(), syncMapBefore, syncMapAfter); - } - } - setWaitEndIndex[op] = globalSetWaitIndex++; - if (syncMapAfter.count(op)) { - auto *it = syncMapAfter.find(op); - assert(it != syncMapAfter.end()); - for (auto &syncOp : it->second) { - if (auto *setWaitOp = llvm::dyn_cast(syncOp.get())) { - for (auto eventId : setWaitOp->eventIds) { - auto &index = getSetWaitOpsIndexRef(setWaitOp->pipeSrc, - setWaitOp->pipeDst, eventId); - index.insert({globalSetWaitIndex++, setWaitOp}); - } - } - } - } - setWaitEndIndexInclusive[op] = globalSetWaitIndex++; -} - -bool Solver::checkBackwardSyncEventsContains(OperationBase *op, - CorePipeInfo corePipeSrc, - CorePipeInfo corePipeDst, - int64_t eventId) { - auto *it1 = backwardSyncEvents.find(op); - if (it1 == backwardSyncEvents.end()) { - return false; - } - auto it2 = it1->second.find({corePipeSrc, corePipeDst}); - if (it2 == it1->second.end()) { - return false; - } - return it2->second.contains(eventId); -} - -bool Solver::checkBackwardSyncEventsContainsAfterMerge( - OperationBase *op, CorePipeInfo corePipeSrc, CorePipeInfo corePipeDst) { - auto *it1 = backwardSyncEventsAfterMerge.find(op); - if (it1 == backwardSyncEventsAfterMerge.end()) { - return false; - } - return 
it1->second.contains({corePipeSrc, corePipeDst}); -} - -// Check whether a backward-sync event id can be merged at scope level. -bool Solver::checkMergeable(Scope *scopeOp, CorePipeInfo corePipeSrc, - CorePipeInfo corePipeDst, int64_t eventId, - bool shouldBeUsedAtleastOnce) { - auto &index = - getSetWaitOpsIndexRef(corePipeSrc.pipe, corePipeDst.pipe, eventId); - if (shouldBeUsedAtleastOnce) { - auto it = index.lower_bound({setWaitStartIndexInclusive[scopeOp], nullptr}); - bool usedAtleastOnce = - it != index.end() && it->first < setWaitEndIndexInclusive[scopeOp]; - if (!usedAtleastOnce) { - return false; - } - } - { - auto it1 = - index.lower_bound({setWaitStartIndexInclusive[scopeOp], nullptr}); - auto it2 = index.lower_bound({setWaitEndIndex[scopeOp], nullptr}); - bool usedBefore = - it1 != index.end() && it1->first < setWaitStartIndex[scopeOp]; - bool usedAfter = - it2 != index.end() && it2->first < setWaitEndIndexInclusive[scopeOp]; - if (usedBefore || usedAfter) { - return false; - } - } - if (auto *conditionOp = llvm::dyn_cast(scopeOp)) { - if (!conditionOp->hasFalseScope()) { - return false; - } - return checkMergeable(conditionOp->getTrueScope(), corePipeSrc, corePipeDst, - eventId, true) && - checkMergeable(conditionOp->getFalseScope(), corePipeSrc, - corePipeDst, eventId, true); - } - if (auto *loopOp = llvm::dyn_cast(scopeOp)) { - for (auto &childOp : loopOp->body) { - if (auto *childScopeOp = llvm::dyn_cast(childOp.get())) { - if (!checkMergeable(childScopeOp, corePipeSrc, corePipeDst, eventId, - false)) { - return false; - } - } - } - for (auto &childOp : loopOp->body) { - if (auto *childScopeOp = llvm::dyn_cast(childOp.get())) { - if (checkMergeable(childScopeOp, corePipeSrc, corePipeDst, eventId, - true)) { - return true; - } - } - } - return false; - } - for (auto &childOp : scopeOp->body) { - auto it1 = - index.lower_bound({setWaitStartIndexInclusive[childOp.get()], nullptr}); - auto it2 = index.lower_bound({setWaitEndIndex[childOp.get()], 
nullptr}); - bool usedAtleastOnce = it1 != index.end() && - it1->first < setWaitEndIndexInclusive[childOp.get()]; - if (!usedAtleastOnce) { - continue; - } - bool before = - it1 != index.end() && it1->first < setWaitStartIndex[childOp.get()]; - bool after = it2 != index.end() && - it2->first < setWaitEndIndexInclusive[childOp.get()]; - if (before || after) { - return false; - } - if (!checkBackwardSyncEventsContains(childOp.get(), corePipeSrc, - corePipeDst, eventId)) { - return false; - } - if (checkBackwardSyncEventsContainsAfterMerge(childOp.get(), corePipeSrc, - corePipeDst)) { - return false; - } - } - return true; -} - -// Attempt to merge backward sync events across children and prune duplicates. -void Solver::mergeBackwardSyncEventIds(OperationBase *op) { - auto *scopeOp = llvm::dyn_cast_if_present(op); - if (scopeOp == nullptr) { - return; - } - for (auto &op : scopeOp->body) { - mergeBackwardSyncEventIds(op.get()); - } - - if (llvm::isa_and_present(op)) { - return; - } - if (llvm::isa_and_present(op->parentOp)) { - return; - } - - auto *conditionOp = llvm::dyn_cast(op); - if (conditionOp != nullptr) { - if (!conditionOp->hasFalseScope()) { - return; - } - } - - llvm::DenseSet> toBeErased; - - llvm::SmallVector coreTypes; - if (options.isCrossCoreMode()) { - coreTypes = {pto::TCoreType::VECTOR, pto::TCoreType::CUBE}; - } else { - coreTypes = {pto::TCoreType::CUBE_OR_VECTOR}; - } - size_t pipeNumMax = static_cast(pto::PIPE::PIPE_NUM); - const int64_t eventIdMax = getHWAvailableEventIdNum(options.syncMode); - - for (int64_t eventId = 0; eventId < eventIdMax; ++eventId) { - for (auto coreSrc : coreTypes) { - for (auto coreDst : coreTypes) { - for (size_t pipeSrcInt = 0; pipeSrcInt < pipeNumMax; pipeSrcInt++) { - for (size_t pipeDstInt = 0; pipeDstInt < pipeNumMax; pipeDstInt++) { - auto pipeSrc = static_cast(pipeSrcInt); - auto pipeDst = static_cast(pipeDstInt); - auto corePipeSrc = CorePipeInfo(coreSrc, pipeSrc); - auto corePipeDst = CorePipeInfo(coreDst, 
pipeDst); - if (checkBackwardSyncEventsContains(scopeOp, corePipeSrc, - corePipeDst, eventId)) { - continue; - } - if (checkMergeable(scopeOp, corePipeSrc, corePipeDst, eventId)) { - toBeErased.insert({corePipeSrc, corePipeDst, eventId}); - backwardSyncEvents[scopeOp][{corePipeSrc, corePipeDst}].insert( - {eventId, 1}); - } - } - } - } - } - } - - if (isa(scopeOp)) { - for (auto &op : scopeOp->body) { - if (auto *block = llvm::dyn_cast(op.get())) { - for (auto &childOp : block->body) { - if (auto *childScopeOp = llvm::dyn_cast(childOp.get())) { - for (auto [corePipeSrc, corePipeDst, eventId] : toBeErased) { - if (checkBackwardSyncEventsContains(childScopeOp, corePipeSrc, - corePipeDst, eventId)) { - auto key = std::make_tuple(corePipeSrc, corePipeDst); - backwardSyncEvents[childScopeOp][key].erase(eventId); - if (backwardSyncEvents[childScopeOp][key].empty()) { - backwardSyncEvents[childScopeOp].erase(key); - } - } - } - } - } - } - } - } else { - for (auto &childOp : scopeOp->body) { - if (auto *childScopeOp = llvm::dyn_cast(childOp.get())) { - for (auto [corePipeSrc, corePipeDst, eventId] : toBeErased) { - if (checkBackwardSyncEventsContains(childScopeOp, corePipeSrc, - corePipeDst, eventId)) { - auto key = std::make_tuple(corePipeSrc, corePipeDst); - backwardSyncEvents[childScopeOp][key].erase(eventId); - if (backwardSyncEvents[childScopeOp][key].empty()) { - backwardSyncEvents[childScopeOp].erase(key); - } - } - } - } - } - } -} - -void Solver::mergeBackwardSyncPairs(SyncMap &syncMapBefore, - SyncMap &syncMapAfter) { - if (!options.moveOutAndMergeBackwardSyncPairs) { - return; - } - if (options.isIntraCoreMode()) { - resetAndBuildSetWaitOpIndex(syncMapBefore, syncMapAfter); - auto *scopeOp = llvm::dyn_cast(funcIr.get()); - assert(scopeOp != nullptr && scopeOp->body.front() != nullptr); - mergeBackwardSyncEventIds(scopeOp->body.front().get()); - } -} - -SyncBeforeAfterMap Solver::getBeforeAfterSyncMaps() { - calcAllEventIds(); - SyncMap syncMapBefore, 
syncMapAfter; - std::vector conflictPairs; - for (auto &conflictPair : chosenConflictedPairs) { - conflictPairs.push_back(conflictPair.get()); - } - for (auto &conflictPair : persistentChosenConflictedPairs) { - conflictPairs.push_back(conflictPair.get()); - } - - for (auto *conflictPair : conflictPairs) { - if (conflictPair->isUseless) { - continue; - } - if (conflictPair->replacedWithUnitFlag) { - continue; - } - assert(conflictPair->setOp != nullptr && conflictPair->waitOp != nullptr); - if (conflictPair->isBarrier()) { - auto barrierOp = std::make_unique( - conflictPair->waitOp->op, conflictPair->waitOp->parentOp, - conflictPair->waitCorePipeInfo.pipe); - LLVM_DEBUG(barrierOp->debugId = conflictPair->id); - syncMapBefore[conflictPair->waitOp].push_back(std::move(barrierOp)); - } else { - assert(conflictPair->eventIdNode != nullptr); - auto setOp = std::make_unique( - conflictPair->setOp->op, conflictPair->setOp->parentOp, - conflictPair->eventIdNode->getEventIds(), - conflictPair->setCorePipeInfo.pipe, - conflictPair->waitCorePipeInfo.pipe); - auto waitOp = std::make_unique( - conflictPair->waitOp->op, conflictPair->waitOp->parentOp, - conflictPair->eventIdNode->getEventIds(), - conflictPair->setCorePipeInfo.pipe, - conflictPair->waitCorePipeInfo.pipe); - if (options.isCrossCoreMode()) { - setOp->coreType = conflictPair->setCorePipeInfo.coreType; - waitOp->coreType = conflictPair->waitCorePipeInfo.coreType; - } - setOp->eventIdInfo = conflictPair->eventIdInfo; - waitOp->eventIdInfo = conflictPair->eventIdInfo; - setOp->checkLastIter = conflictPair->setOnLastIterOnly; - waitOp->checkFirstIter = conflictPair->waitOnFirstIterOnly; - LLVM_DEBUG({ - setOp->debugId = conflictPair->id; - waitOp->debugId = conflictPair->id; - }); - assert(setOp != nullptr && waitOp != nullptr); - syncMapAfter[conflictPair->setOp].push_back(std::move(setOp)); - syncMapBefore[conflictPair->waitOp].push_front(std::move(waitOp)); - } - } - - collectBackwardSyncEventIds(); - 
mergeBackwardSyncPairs(syncMapBefore, syncMapAfter); - - for (auto &[op, mp] : backwardSyncEvents) { - if (mp.empty()) { - continue; - } - auto *scopeOp = llvm::dyn_cast(op); - assert(scopeOp != nullptr); - for (auto [setWaitCorePipes, eventIdsMp] : mp) { - if (eventIdsMp.empty()) { - continue; - } - llvm::SmallVector eventIds; - for (auto [eventId, repeatNum] : eventIdsMp) { - llvm::SmallVector curEventIds(repeatNum, eventId); - llvm::append_range(eventIds, curEventIds); - } - llvm::sort(eventIds); - auto [corePipeSrc, corePipeDst] = setWaitCorePipes; - auto setOp = - std::make_unique(scopeOp->op, scopeOp->parentOp, eventIds, - corePipeSrc.pipe, corePipeDst.pipe); - auto waitOp = - std::make_unique(scopeOp->op, scopeOp->parentOp, eventIds, - corePipeSrc.pipe, corePipeDst.pipe); - setOp->allAtOnce = true; - waitOp->allAtOnce = true; - if (options.isCrossCoreMode()) { - setOp->coreType = corePipeSrc.coreType; - waitOp->coreType = corePipeDst.coreType; - } - assert(setOp != nullptr && waitOp != nullptr); - syncMapBefore[scopeOp].push_back(std::move(setOp)); - syncMapAfter[scopeOp].push_front(std::move(waitOp)); - } - } - return std::make_pair(std::move(syncMapBefore), std::move(syncMapAfter)); -} - -void Solver::processConflict(Occurrence *occ1, Occurrence *occ2, - RWOperation *rwOp1, RWOperation *rwOp2, - bool isUseless) { - for (auto [corePipeSrc, corePipeDst] : checkMemoryConflicts(rwOp1, rwOp2)) { - if (options.alwaysUsePipeSAsWaitingPipe) { - corePipeDst.pipe = pto::PIPE::PIPE_S; - } - auto eventIdInfo = - getEventIdInfo(occ1, occ2, rwOp1, rwOp2, corePipeSrc, corePipeDst); - handleConflict(occ1, occ2, rwOp1, rwOp2, corePipeSrc, corePipeDst, - eventIdInfo, isUseless); - } -} - -// Main processing loop that iterates processingOrders and attempts to -// discover and record conflicts. 
-void Solver::processOrders() { - for (auto &[occ1, occ2, rwOp1, rwOp2, isUseless] : processingOrders) { - assert(occ1 != occ2); - assert(occ1->syncIrIndex < occ2->syncIrIndex); - if (checkVisited(occ1, occ2)) { - assert(false && "expected to not check a pair more than once."); - continue; - } - if (checkImpossibleOccPair(occ1, occ2) || checkAlreadySynced(occ1, occ2) || - skipMMad1DecomposedLoopOpt(occ1, occ2) || - checkSkipParallelLoop(occ1, occ2) || - checkSkipCrossCorePair(occ1, occ2)) { - continue; - } - DEBUG_WITH_TYPE("gss-sync-solver-checking", { - llvm::dbgs() << "checking: " << (isUseless ? "is-useless\n" : "\n"); - llvm::dbgs() << occ1->syncIrIndex << ' ' << occ1->startIndex << ' ' - << occ1->endIndex << ' ' << occ1->op->str(0, false) << '\n'; - llvm::dbgs() << occ2->syncIrIndex << ' ' << occ2->startIndex << ' ' - << occ2->endIndex << ' ' << occ2->op->str(0, false) << '\n'; - }); - if (checkAlreadySyncedWithUnitFlag(occ1, occ2)) { - continue; - } - processConflict(occ1, occ2, rwOp1, rwOp2, isUseless); - } -} - -void Solver::insertMergedBackwardSyncPairs() { - for (auto &[scopeOp, st] : backwardSyncEventsAfterMerge) { - for (auto &corePipeInfoPair : st) { - auto [corePipeSrc, corePipeDst] = corePipeInfoPair; - for (auto *scopeOcc : opAllOccurrences[scopeOp]) { - auto *parentScopeOcc = scopeOcc->parentOcc; - assert(parentScopeOcc != nullptr); - Occurrence *setOcc = nullptr; - Occurrence *waitOcc = nullptr; - auto startIndex = scopeOcc->startIndex; - auto endIndex = scopeOcc->endIndex; - if (isa(scopeOp)) { - setOcc = getBeforePlaceHolderOcc(scopeOcc); - waitOcc = getAfterPlaceHolderOcc(scopeOcc); - startIndex = setOcc->endIndex; - endIndex = waitOcc->startIndex; - } - auto conflictPair = std::make_unique( - nullptr, nullptr, nullptr, nullptr, setOcc, waitOcc, corePipeSrc, - corePipeDst, startIndex, endIndex); - assert(conflictPair->startIndex <= conflictPair->endIndex); - conflictPair->isUseless = true; - conflictPair->dontReuse = true; - 
conflictPair->dontCheckForConflict = true; - conflictPair->couldNotRun = false; // notice this - LLVM_DEBUG({ - llvm::dbgs() << "consider-merged-backward-pair: " - << scopeOp->str(0, false) << ' ' << conflictPair->str() - << "\n"; - }); - scopeOccChosenConflicts[parentScopeOcc].insert(conflictPair.get()); - chosenConflictedPairs.push_back(std::move(conflictPair)); - } - } - } -} - -llvm::LogicalResult Solver::considerOuterBackwardSyncPairs() { - if (!options.considerOuterBackwardSyncPairs) { - return llvm::failure(); - } - bool backwardPairsPositionChanged = false; - for (auto &[scopeOp, st] : backwardSyncEventsAfterMerge) { - SmallVector> toBeErased; - for (auto &corePipeInfoPair : st) { - if (!backwardSyncEvents.contains(scopeOp) || - !backwardSyncEvents[scopeOp].contains(corePipeInfoPair)) { - toBeErased.push_back(corePipeInfoPair); - } - } - if (!toBeErased.empty()) { - backwardPairsPositionChanged = true; - for (auto &corePipeInfoPair : toBeErased) { - st.erase(corePipeInfoPair); - } - } - } - int chosenOpsDepth = -1; - SmallVector chosenOps; - for (auto &[scopeOp, mp] : backwardSyncEvents) { - if (backwardSyncEventsAfterMerge.contains(scopeOp)) { - continue; - } - int scopeOpDepth = scopeOp->getDepth(); - if (chosenOpsDepth == scopeOpDepth) { - chosenOps.push_back(scopeOp); - } else if (chosenOpsDepth == -1 || chosenOpsDepth < scopeOpDepth) { - chosenOps.clear(); - chosenOps.push_back(scopeOp); - chosenOpsDepth = scopeOpDepth; - } - } - if (chosenOps.empty()) { - return llvm::failure(); - } - bool newPairIsInserted = false; - for (auto *chosenOp : chosenOps) { - for (auto &[corePipeInfoPair, eventIdsMp] : backwardSyncEvents[chosenOp]) { - assert(!eventIdsMp.empty()); - if (!eventIdsMp.empty()) { - auto [it, isInserted] = - backwardSyncEventsAfterMerge[chosenOp].insert(corePipeInfoPair); - newPairIsInserted |= isInserted; - } - } - } - return llvm::success(backwardPairsPositionChanged || newPairIsInserted); -} - -llvm::LogicalResult 
Solver::reuseSyncPairToSaveEventIds() { - if (!options.reuseSyncPairToSaveEventIds || barrierAllPairs.empty()) { - return llvm::failure(); - } - bool limitReached = true; - for (auto [corePipeSrc, corePipeDst] : barrierAllPairs) { - if (reusePairs[{corePipeSrc, corePipeDst}] < maxReuseNum) { - if (reusePairs[{corePipeSrc, corePipeDst}] <= - reusedPairs[{corePipeSrc, corePipeDst}]) { - reusePairs[{corePipeSrc, corePipeDst}] += 1; - limitReached = false; - } - } - } - DEBUG_WITH_TYPE("gss-sync-solver-reuse", { - llvm::dbgs() << "reusePairs: \n"; - for (auto [pipeCorePairs, cnt] : reusePairs) { - llvm::dbgs() << get<0>(pipeCorePairs).pipe << ' ' - << get<1>(pipeCorePairs).pipe << ' ' << cnt << '\n'; - } - }); - return llvm::success(!limitReached); -} - -llvm::LogicalResult Solver::disableMultiEventIdForBarrierAllPairs() { - if (!options.disableMultiEventIdForBarrierAllPairs || - barrierAllPairs.empty()) { - return llvm::failure(); - } - bool newPairIsInserted = false; - for (auto corePipeInfoPair : barrierAllPairs) { - auto [it, isInserted] = disabledMultiEventIdPairs.insert(corePipeInfoPair); - newPairIsInserted |= isInserted; - } - LLVM_DEBUG({ - if (newPairIsInserted) { - llvm::dbgs() << "disabled-multi-event-id-pairs: \n"; - for (auto &[corePipeSrc, corePipeDst] : disabledMultiEventIdPairs) { - llvm::dbgs() << corePipeSrc.coreType << ' ' << corePipeSrc.pipe << ' ' - << corePipeDst.coreType << ' ' << corePipeDst.pipe << '\n'; - } - } - }); - return llvm::success(newPairIsInserted); -} - -llvm::LogicalResult Solver::tryMovingOutBackwardSyncPairsToOuterLoops() { - if (!options.moveOutAndMergeBackwardSyncPairs || !options.isCrossCoreMode() || - dontMoveBackwardSyncPairsToOutmostLoop) { - return llvm::failure(); - } - if (!moveBackwardSyncPairsToOutmostLoop) { - moveBackwardSyncPairsToOutmostLoop = true; - return llvm::success(); - } - if (!barrierAllPairs.empty()) { - moveBackwardSyncPairsToOutmostLoop = false; - dontMoveBackwardSyncPairsToOutmostLoop = true; - return 
llvm::success(); - } - return llvm::failure(); -} - -// High-level solve orchestration with multiple passes and optional merging -// iterations. -llvm::LogicalResult Solver::runSolver(bool enableOpts1, bool enableOpts2) { - reset(/*resetEventIdRanOutOpts=*/true); - - int64_t runNum = 0; - while (runNum++ < maxRunNum) { - LLVM_DEBUG(llvm::dbgs() << "runNum: " << runNum << '\n'); - - reset(); - insertMergedBackwardSyncPairs(); - processOrders(); - - if (llvm::succeeded(tryMovingOutBackwardSyncPairsToOuterLoops())) { - continue; - } - - if (enableOpts1) { - if (options.considerOuterBackwardSyncPairs) { - getBeforeAfterSyncMaps(); - if (llvm::succeeded(considerOuterBackwardSyncPairs())) { - continue; - } - if (!barrierAllPairs.empty()) { - backwardSyncEventsAfterMerge.clear(); - } - } - } - - if (enableOpts2) { - if (!barrierAllPairs.empty()) { - if (llvm::succeeded(reuseSyncPairToSaveEventIds())) { - continue; - } - if (llvm::succeeded(disableMultiEventIdForBarrierAllPairs())) { - continue; - } - } - } - - if (!barrierAllPairs.empty()) { - pickAndInsertABarrierAll(); - reset(/*resetEventIdRanOutOpts=*/true); - continue; - } - break; - } - - reset(); - insertMergedBackwardSyncPairs(); - processOrders(); - - return llvm::success(runNum < maxRunNum); -} - -void Solver::solve() { - if (llvm::succeeded(runSolver())) { - return; - } - if (!options.isTestMode()) { - if (llvm::succeeded(runSolver(/*enableOpts1=*/false))) { - return; - } - if (llvm::succeeded( - runSolver(/*enableOpts1=*/false, /*enableOpts2=*/false))) { - return; - } - } - llvm_unreachable("GSS: runSolver() failed."); -} diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index 3a7a2e5a4..ea9466da1 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -6,5 +6,12898 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
// See LICENSE in the root of the software repository for the full text of the License. +//===- PTOToEmitC.cpp - PTO to EmitC conversion pass ----------------------===// +//===----------------------------------------------------------------------===// -#include "PTOToEmitC.def" +#pragma GCC diagnostic ignored "-Woverloaded-virtual" +// https://discourse.llvm.org/t/matchandrewrite-hiding-virtual-functions/84933/8 + +#include +#include + +#include "PTO/IR/PTO.h" +#include "PTO/IR/PTOTypeUtils.h" +#include "PTO/IR/PTOSyncUtils.h" +#include "PTO/Transforms/Passes.h" + +#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" +#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" +#include "mlir/Analysis/DataFlowFramework.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" + +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeRange.h" + +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Target/Cpp/CppEmitter.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h" +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" + +#include +#include +#include +#include + +#define DEBUG_TYPE "pto-emitc" + +namespace mlir { +#define GEN_PASS_DEF_EMITPTOMANUAL 
+#include "PTO/Transforms/Passes.h.inc" +} // namespace mlir + +using namespace mlir; +using namespace mlir::pto; + +static std::string getElemTypeStringForGT(Type elemTy); +static bool getStaticMemrefLayout(MemRefType mrTy, + SmallVectorImpl &strides, + int64_t &offset); +static int64_t multiplyOrDynamic(int64_t lhs, int64_t rhs); +static void buildGlobalTensorShapeAndStride(ArrayRef shape, + ArrayRef strides, + SmallVectorImpl &shape5D, + SmallVectorImpl &stride5D); +static std::string joinIntTemplateParams(ArrayRef values); +static SmallVector buildRowMajorStrides(ArrayRef shape); +static std::string getGlobalTensorTypeStringFromShape(Type elemTy, + ArrayRef shape, + StringRef layoutEnum = + "pto::Layout::ND"); +static std::string getGlobalTensorTypeStringFromShapeAndStrides( + Type elemTy, ArrayRef shape, ArrayRef strides, + StringRef layoutEnum = "pto::Layout::ND"); +static emitc::OpaqueType getGlobalTensorOpaqueTypeFromShape( + MLIRContext *ctx, Type elemTy, ArrayRef shape, + StringRef layoutEnum = "pto::Layout::ND"); + +static const char *addrSpaceQualifier(pto::AddressSpace as) { + switch (as) { + case pto::AddressSpace::Zero: + return "__gm__"; + case pto::AddressSpace::VEC: + return "__ubuf__"; + case pto::AddressSpace::GM: + return "__gm__"; + case pto::AddressSpace::MAT: + return "__cbuf__"; + case pto::AddressSpace::LEFT: + return "__ca__"; + case pto::AddressSpace::RIGHT: + return "__cb__"; + case pto::AddressSpace::ACC: + return "__cc__"; + case pto::AddressSpace::BIAS: + // Bias tiles are special in pto-isa; keep a safe fallback qualifier. + return "__gm__"; + case pto::AddressSpace::SCALING: + // pto-isa TileType::Scaling maps to __fbuf__ (see pto/common/memory.hpp). 
+ return "__fbuf__"; + } + return "__gm__"; +} + +[[maybe_unused]] static constexpr llvm::StringLiteral kLoweredSetValidShapeAttrName = + "__pto.lowered_set_validshape"; +[[maybe_unused]] static constexpr llvm::StringLiteral kLoweredSetValidShapeConfigAttrName = + "__pto.lowered_set_validshape_config"; +static constexpr llvm::StringLiteral kForceDynamicValidShapeAttrName = + "__pto.force_dynamic_valid_shape"; +static constexpr llvm::StringLiteral kGlobalTensorStridesAttrName = + "__pto.globaltensor_strides"; + +static Value peelUnrealized(Value v) { + if (auto castOp = v.getDefiningOp()) + return castOp.getOperand(0); + return v; +} + +static Value buildGlobalTensorFromMemref(ConversionPatternRewriter &rewriter, + Location loc, Value basePtr, + MemRefType mrTy, Operation *anchor); + +static Value maybeWrapGlobalMemrefAsGlobalTensor( + ConversionPatternRewriter &rewriter, Location loc, Value loweredValue, + Type originalType, Operation *anchor); + +static bool hasCompatibleKnownExtentForMGather(int64_t lhs, int64_t rhs) { + return lhs == ShapedType::kDynamic || rhs == ShapedType::kDynamic || + lhs == rhs; +} + +static bool isKnownUnitExtentForMGather(int64_t value) { + return value == ShapedType::kDynamic || value == 1; +} + +struct GatherScatterShapeLayoutInfo { + SmallVector shape; + bool rowMajor = false; + bool colMajor = false; +}; + +static std::optional +getGatherScatterShapeLayoutInfo(Type ty) { + if (auto tileTy = dyn_cast(ty)) { + ArrayRef validShape = tileTy.getValidShape(); + if (validShape.size() != 2) + return std::nullopt; + + GatherScatterShapeLayoutInfo info; + info.shape.assign(validShape.begin(), validShape.end()); + int32_t blayout = tileTy.getBLayoutValueI32(); + info.rowMajor = blayout == static_cast(pto::BLayout::RowMajor); + info.colMajor = blayout == static_cast(pto::BLayout::ColMajor); + return info; + } + + auto memRefTy = dyn_cast(ty); + if (!memRefTy || memRefTy.getRank() != 2) + return std::nullopt; + + SmallVector strides; + int64_t 
offset = ShapedType::kDynamic; + if (failed(getStridesAndOffset(memRefTy, strides, offset)) || + strides.size() != 2) + return std::nullopt; + + GatherScatterShapeLayoutInfo info; + info.shape.assign(memRefTy.getShape().begin(), memRefTy.getShape().end()); + info.rowMajor = strides[1] == 1; + info.colMajor = strides[0] == 1; + return info; +} + +static bool isRowCoalescedMGatherIndexType(Type dataTy, Type idxTy) { + auto dataInfo = getGatherScatterShapeLayoutInfo(dataTy); + auto idxInfo = getGatherScatterShapeLayoutInfo(idxTy); + if (!dataInfo || !idxInfo) + return false; + + const bool rowCoalesce1xR = + idxInfo->rowMajor && isKnownUnitExtentForMGather(idxInfo->shape[0]) && + hasCompatibleKnownExtentForMGather(idxInfo->shape[1], dataInfo->shape[0]); + const bool rowCoalesceRx1 = + idxInfo->colMajor && + hasCompatibleKnownExtentForMGather(idxInfo->shape[0], dataInfo->shape[0]) && + isKnownUnitExtentForMGather(idxInfo->shape[1]); + return rowCoalesce1xR || rowCoalesceRx1; +} + +static std::optional getLayoutAttrFromOp(Operation *op) { + if (!op) + return std::nullopt; + if (auto attr = op->getAttrOfType("layout")) + return attr.getLayout(); + return std::nullopt; +} + +static std::optional resolveLayoutFromValueChain(Value v) { + v = peelUnrealized(v); + while (Operation *def = v.getDefiningOp()) { + if (auto layout = getLayoutAttrFromOp(def)) + return layout; + if (auto subview = dyn_cast(def)) { + v = peelUnrealized(subview.getSource()); + continue; + } + if (auto reinterpret = dyn_cast(def)) { + v = peelUnrealized(reinterpret.getSource()); + continue; + } + if (auto cast = dyn_cast(def)) { + v = peelUnrealized(cast.getSource()); + continue; + } + if (auto unrealized = dyn_cast(def)) { + if (unrealized->getNumOperands() == 0) + break; + v = peelUnrealized(unrealized.getOperand(0)); + continue; + } + break; + } + return std::nullopt; +} + +static std::optional +resolveLayoutForGlobalTensor(Operation *anchor, Value basePtr) { + if (auto layout = 
getLayoutAttrFromOp(anchor)) + return layout; + return resolveLayoutFromValueChain(basePtr); +} + +static std::string layoutToEmitCString(mlir::pto::Layout layout) { + switch (layout) { + case mlir::pto::Layout::ND: + return "pto::Layout::ND"; + case mlir::pto::Layout::DN: + return "pto::Layout::DN"; + case mlir::pto::Layout::NZ: + return "pto::Layout::NZ"; + } + return "pto::Layout::ND"; +} + +static bool isEmitCGlobalTensorLikeType(Type ty) { + auto opaqueTy = dyn_cast(ty); + return opaqueTy && opaqueTy.getValue().contains("GlobalTensor<"); +} + +static std::string getEmitCScalarTypeToken(Type elemTy) { + if (pto::isPTOFloat8Type(elemTy) && + (elemTy.isFloat8E4M3() || elemTy.isFloat8E4M3FN() || + elemTy.isFloat8E4M3FNUZ() || elemTy.isFloat8E4M3B11FNUZ())) + return "float8_e4m3_t"; + if (pto::isPTOFloat8Type(elemTy) && + (elemTy.isFloat8E5M2() || elemTy.isFloat8E5M2FNUZ())) + return "float8_e5m2_t"; + if (isa(elemTy)) + return "hifloat8_t"; + if (isa(elemTy)) + return "float4_e1m2x2_t"; + if (isa(elemTy)) + return "float4_e2m1x2_t"; + if (elemTy.isF16()) + return "half"; + if (elemTy.isBF16()) + return "bfloat16_t"; + if (elemTy.isF32()) + return "float"; + if (elemTy.isF64()) + return "double"; + if (elemTy.isInteger(8)) + return (elemTy.isSignlessInteger(8) || elemTy.isSignedInteger(8)) ? "int8_t" + : "uint8_t"; + if (elemTy.isInteger(16)) + return (elemTy.isSignlessInteger(16) || elemTy.isSignedInteger(16)) + ? "int16_t" + : "uint16_t"; + if (elemTy.isInteger(32)) + return (elemTy.isSignlessInteger(32) || elemTy.isSignedInteger(32)) + ? "int32_t" + : "uint32_t"; + if (elemTy.isInteger(64)) + return cast(elemTy).isUnsigned() ? 
"uint64_t" : "int64_t"; + return "float"; +} + +static emitc::PointerType getEmitCPointerType(MLIRContext *ctx, + StringRef pointeeTypeStr) { + return emitc::PointerType::get(emitc::OpaqueType::get(ctx, pointeeTypeStr)); +} + +static emitc::PointerType getEmitCPointerType(MLIRContext *ctx, + StringRef qualifier, + StringRef elemTypeStr) { + return getEmitCPointerType(ctx, (qualifier + " " + elemTypeStr).str()); +} + +static bool isEmitCPointerLikeType(Type ty) { + if (isa(ty)) + return true; + if (auto opaqueTy = dyn_cast(ty)) + return opaqueTy.getValue().ends_with("*"); + return false; +} + +static int64_t getEmitCScalarByteWidth(Type elemTy) { + if (pto::getPTOStorageElemByteSize(elemTy) == 1) + return 1; + if (elemTy.isF16() || elemTy.isBF16() || elemTy.isInteger(16)) + return 2; + if (elemTy.isF32() || elemTy.isInteger(32)) + return 4; + if (elemTy.isF64() || elemTy.isInteger(64)) + return 8; + return 4; +} + +static std::string tileBufBLayoutToken(pto::TileBufConfigAttr configAttr); +static std::string tileBufSLayoutToken(pto::TileBufConfigAttr configAttr); +static std::string tileBufPadToken(pto::TileBufConfigAttr configAttr); +static pto::BLayout getTileBufBLayoutValue(pto::TileBufConfigAttr configAttr); +static int64_t renderTileTemplateDim(int64_t rawDim, Type elemTy, + pto::BLayout blayout, int dimIdx); + +static const char *tileRoleToken(Attribute memorySpace) { + if (auto asAttr = dyn_cast_or_null(memorySpace)) { + switch (asAttr.getAddressSpace()) { + case pto::AddressSpace::VEC: + return "TileType::Vec"; + case pto::AddressSpace::MAT: + return "TileType::Mat"; + case pto::AddressSpace::LEFT: + return "TileType::Left"; + case pto::AddressSpace::RIGHT: + return "TileType::Right"; + case pto::AddressSpace::ACC: + return "TileType::Acc"; + case pto::AddressSpace::BIAS: + return "TileType::Bias"; + case pto::AddressSpace::SCALING: + return "TileType::Scaling"; + case pto::AddressSpace::GM: + case pto::AddressSpace::Zero: + return "TileType::Vec"; + } + } + 
return "TileType::Vec"; +} + +static std::string tileBufCompactToken(pto::TileBufConfigAttr configAttr) { + std::string compactTok = "CompactMode::Null"; + if (auto compactAttr = dyn_cast(configAttr.getCompactMode())) { + switch (static_cast(compactAttr.getValue())) { + case 1: + compactTok = "CompactMode::Normal"; + break; + case 2: + compactTok = "CompactMode::RowPlusOne"; + break; + default: + compactTok = "CompactMode::Null"; + break; + } + } + return compactTok; +} + +static std::optional getEmitCTileTypeString(pto::TileBufType type) { + if (type.getRank() != 2) + return std::nullopt; + auto validShape = type.getValidShape(); + if (validShape.size() != 2) + return std::nullopt; + + Type elemTy = type.getElementType(); + auto configAttr = type.getConfigAttr(); + pto::BLayout blayout = getTileBufBLayoutValue(configAttr); + ArrayRef shape = type.getShape(); + int64_t rows = shape[0]; + int64_t cols = shape[1]; + + auto render = [&](int64_t dim, int dimIdx) { + return renderTileTemplateDim(dim, elemTy, blayout, dimIdx); + }; + + std::string vrowTok = + validShape[0] == ShapedType::kDynamic + ? "-1" + : std::to_string(render(validShape[0], 0)); + std::string vcolTok = + validShape[1] == ShapedType::kDynamic + ? 
"-1" + : std::to_string(render(validShape[1], 1)); + + int32_t fractal = 512; + if (auto frAttr = dyn_cast(configAttr.getSFractalSize())) + fractal = frAttr.getInt(); + + return std::string("Tile<") + tileRoleToken(type.getMemorySpace()) + ", " + + getEmitCScalarTypeToken(elemTy) + ", " + + std::to_string(render(rows, 0)) + ", " + + std::to_string(render(cols, 1)) + ", " + + tileBufBLayoutToken(configAttr) + ", " + vrowTok + ", " + vcolTok + + ", " + tileBufSLayoutToken(configAttr) + ", " + + std::to_string(fractal) + ", " + tileBufPadToken(configAttr) + ", " + + tileBufCompactToken(configAttr) + ">"; +} + +//===----------------------------------------------------------------------===// +// Type Converter +//===----------------------------------------------------------------------===// + +class PTOToEmitCTypeConverter : public TypeConverter { +public: + PTOToEmitCTypeConverter(MLIRContext *Ctx, PTOArch targetArch) { + // --------------------------------------------------------- + // 1. 基本类型 (f32, i32, index) + // --------------------------------------------------------- + addConversion([Ctx](FloatType type) -> Type { + if (type.isFloat8E4M3() || type.isFloat8E4M3FN() || + type.isFloat8E4M3FNUZ() || type.isFloat8E4M3B11FNUZ()) + return emitc::OpaqueType::get(Ctx, "float8_e4m3_t"); + if (type.isFloat8E5M2() || type.isFloat8E5M2FNUZ()) + return emitc::OpaqueType::get(Ctx, "float8_e5m2_t"); + if (type.isF32()) return emitc::OpaqueType::get(Ctx, "float"); + if (type.isF16()) return emitc::OpaqueType::get(Ctx, "half"); + if (type.isBF16()) return emitc::OpaqueType::get(Ctx, "bfloat16_t"); + if (type.isF64()) return emitc::OpaqueType::get(Ctx, "double"); + llvm::errs() << "[Debug] Unsupported FloatType: " << type << "\n"; + return Type{}; + }); + + addConversion([Ctx](pto::HiF8Type) -> Type { + return emitc::OpaqueType::get(Ctx, "hifloat8_t"); + }); + addConversion([Ctx](pto::F4E1M2x2Type) -> Type { + return emitc::OpaqueType::get(Ctx, "float4_e1m2x2_t"); + }); + 
addConversion([Ctx](pto::F4E2M1x2Type) -> Type { + return emitc::OpaqueType::get(Ctx, "float4_e2m1x2_t"); + }); + + addConversion([Ctx](IntegerType type) -> Type { + if (type.getWidth() == 1) + return type; + + // Prefer fixed-width C types. Preserve signedness if the MLIR integer is + // explicitly signed/unsigned; treat signless as signed by default. + const bool isUnsigned = type.isUnsignedInteger(); + switch (type.getWidth()) { + case 8: + return emitc::OpaqueType::get(Ctx, isUnsigned ? "uint8_t" : "int8_t"); + case 16: + return emitc::OpaqueType::get(Ctx, + isUnsigned ? "uint16_t" : "int16_t"); + case 32: + return emitc::OpaqueType::get(Ctx, + isUnsigned ? "uint32_t" : "int32_t"); + case 64: + return emitc::OpaqueType::get(Ctx, + isUnsigned ? "uint64_t" : "int64_t"); + default: + llvm::errs() << "[Debug] Unsupported IntegerType width: " + << type.getWidth() << "\n"; + return emitc::OpaqueType::get(Ctx, "int32_t"); // Fallback + } + }); + + addConversion([Ctx](IndexType type) -> Type { + return emitc::OpaqueType::get(Ctx, "int32_t"); + }); + + // vector<4xi16> (e.g. TMRGSORT executedNumList) -> pto::MrgSortExecutedNumList + addConversion([Ctx](VectorType type) -> Type { + if (type.getRank() == 1 && type.getNumElements() == 4 && + type.getElementType().isInteger(16)) + return emitc::OpaqueType::get(Ctx, "pto::MrgSortExecutedNumList"); + return Type{}; + }); + + // --------------------------------------------------------- + // 2. 
PTO 特殊类型 (透传或转换) + // --------------------------------------------------------- + addConversion([](emitc::OpaqueType type) { return type; }); + addConversion([](emitc::PointerType type) { return type; }); + + // --------------------------------------------------------- + // 2.5 PtrType 转换 (指针类型) + // --------------------------------------------------------- + addConversion([this, Ctx](pto::PtrType type) -> std::optional { + Type elemType = type.getElementType(); + Type newElemType = convertType(elemType); + if (!newElemType) + return std::nullopt; + + std::string elemTypeStr; + if (auto opq = dyn_cast(newElemType)) { + elemTypeStr = opq.getValue().str(); + } else { + llvm::errs() << " [Error] PtrType elem type is not OpaqueType: " + << newElemType << "\n"; + return std::nullopt; + } + + std::string qualifier = "__gm__"; + + std::string finalTypeStr = qualifier + " " + elemTypeStr; + return getEmitCPointerType(Ctx, finalTypeStr); + }); + + addConversion([Ctx](pto::PipeType type) -> Type { + (void)type; + return emitc::OpaqueType::get(Ctx, "auto"); + }); + + addConversion([Ctx](pto::EventIdArrayType type) -> Type { + std::string tok = "PTOAS_EventIdArray<" + std::to_string(type.getSize()) + ">"; + return emitc::OpaqueType::get(Ctx, tok); + }); + + // !pto.local_array -> !emitc.array. + // Variables of this type render as `T a[D1][D2]...;` in the emitted C++. 
+ addConversion([this](pto::LocalArrayType type) -> std::optional { + Type convertedElem = convertType(type.getElementType()); + if (!convertedElem) + return std::nullopt; + return emitc::ArrayType::get(type.getShape(), convertedElem); + }); + + addConversion([Ctx](pto::AsyncSessionType type) -> Type { + (void)type; + return emitc::OpaqueType::get(Ctx, "pto::comm::AsyncSession"); + }); + + addConversion([Ctx](pto::AsyncEventType type) -> Type { + (void)type; + return emitc::OpaqueType::get(Ctx, "pto::comm::AsyncEvent"); + }); + + addConversion([Ctx](pto::PrefetchAsyncContextType type) -> Type { + (void)type; + return emitc::OpaqueType::get(Ctx, "pto::PrefetchAsyncContext"); + }); + + addConversion([Ctx](pto::TensorViewType type) -> Type { + return getGlobalTensorOpaqueTypeFromShape( + Ctx, type.getElementType(), type.getShape()); + }); + + addConversion([Ctx](pto::PartitionTensorViewType type) -> Type { + return getGlobalTensorOpaqueTypeFromShape( + Ctx, type.getElementType(), type.getShape()); + }); + + addConversion([Ctx](pto::TileBufType type) -> std::optional { + auto typeString = getEmitCTileTypeString(type); + if (!typeString) + return std::nullopt; + return emitc::OpaqueType::get(Ctx, *typeString); + }); + + // --------------------------------------------------------- + // 3. MemRef 转换 (Debug 重点) + // --------------------------------------------------------- + addConversion([this, Ctx](MemRefType type) -> std::optional { + LLVM_DEBUG(llvm::dbgs() << "Converting MemRef: " << type << "\n"); + + // A. 
转换元素类型 + Type elemType = type.getElementType(); + Type newElemType = convertType(elemType); + if (!newElemType) { + llvm::errs() << " [Error] Failed to convert element type: " << elemType << "\n"; + return std::nullopt; + } + + // 获取元素类型的字符串 + std::string elemTypeStr; + if (auto opq = dyn_cast(newElemType)) { + elemTypeStr = opq.getValue().str(); + } else { + llvm::errs() << " [Error] Converted element type is not OpaqueType: " << newElemType << "\n"; + return std::nullopt; + } + + // B. 处理 Memory Space + std::string qualifier = ""; + Attribute memorySpace = type.getMemorySpace(); + + if (!memorySpace) { + qualifier = "__gm__"; + } else if (auto ptoAttr = dyn_cast(memorySpace)) { + qualifier = addrSpaceQualifier(ptoAttr.getAddressSpace()); + } else { + llvm::errs() << " [Warning] Unknown MemorySpace Attribute type: " << memorySpace << "\n"; + qualifier = "__gm__"; // Fallback + } + + std::string finalTypeStr = qualifier + " " + elemTypeStr; + LLVM_DEBUG(llvm::dbgs() << " [Success] -> " << finalTypeStr << "*\n"); + + return getEmitCPointerType(Ctx, finalTypeStr); + }); + + // --------------------------------------------------------- + // 4. Function & Materialization + // --------------------------------------------------------- + addConversion([this](FunctionType type) -> Type { + SmallVector inputs; + if (failed(convertTypes(type.getInputs(), inputs))) return Type{}; + SmallVector results; + if (failed(convertTypes(type.getResults(), results))) return Type{}; + return FunctionType::get(type.getContext(), inputs, results); + }); + + auto materializeCast = [](OpBuilder &Builder, Type ResultType, + ValueRange Inputs, Location Loc) -> Value { + if (Inputs.size() != 1) return Value(); + return Builder.create(Loc, ResultType, Inputs[0]).getResult(0); + }; + + addSourceMaterialization(materializeCast); + addTargetMaterialization(materializeCast); + // Needed for region/block signature conversions (e.g. CFG block args). 
+ addArgumentMaterialization(materializeCast); + } +}; + +static constexpr unsigned kPTOIndexBitWidth = + 32; // keep consistent with IndexType conversion + +// Forward declarations (definitions below). +static inline std::string pipeTokFromPipeAttr(mlir::pto::PipeAttr a); +static emitc::OpaqueType getSignedIntOpaqueType(MLIRContext *ctx, + unsigned bitWidth); +static emitc::OpaqueType getUnsignedIntOpaqueType(MLIRContext *ctx, + unsigned bitWidth); +static emitc::OpaqueType getWiderSignedIntOpaqueType(MLIRContext *ctx, + unsigned bitWidth); +static emitc::OpaqueType getWiderUnsignedIntOpaqueType(MLIRContext *ctx, + unsigned bitWidth); +static Value makeEmitCOpaqueConstant(ConversionPatternRewriter &rewriter, + Location loc, Type type, + llvm::StringRef literal); +static Value makeEmitCIntConstant(ConversionPatternRewriter &rewriter, + Location loc, Type type, int64_t value); +static Value emitCCast(ConversionPatternRewriter &rewriter, Location loc, + Type dstType, Value src); +static FailureOr buildEmitCOpaqueConstantLiteral(Type targetType, + Attribute valueAttr); +static Value castSignlessIntToUnsignedSameWidth(ConversionPatternRewriter &rewriter, + Location loc, Value v, + unsigned bitWidth); +static bool needsA5NoSplitVectorGuard(Operation *op); + +static FailureOr getTileSplitToken(int64_t split) { + switch (split) { + case 0: + return std::string("TileSplitAxis::TILE_NO_SPLIT"); + case 1: + return std::string("TileSplitAxis::TILE_UP_DOWN"); + case 2: + return std::string("TileSplitAxis::TILE_LEFT_RIGHT"); + default: + return failure(); + } +} + +static FailureOr +getTPipeDirectionToken(bool isL2G2L, int8_t dirMask, PTOArch targetArch) { + if (dirMask == 1) { + if (isL2G2L && targetArch == PTOArch::A5) + return std::string("Direction::DIR_C2V_GM"); + return std::string("Direction::DIR_C2V"); + } + if (dirMask == 2) { + if (isL2G2L && targetArch == PTOArch::A5) + return std::string("Direction::DIR_V2C_GM"); + return std::string("Direction::DIR_V2C"); + } + if 
(dirMask == 3) + return std::string("Direction::DIR_BOTH"); + return failure(); +} + +static std::string buildTPipeToken(int32_t flagBase, llvm::StringRef dirTok, + int32_t slotSize, int32_t slotNum, + int32_t localSlotNum, bool nosplit) { + std::string token = "TPipe<" + std::to_string(flagBase) + ", " + dirTok.str() + + ", " + std::to_string(slotSize) + ", " + + std::to_string(slotNum); + token += ", " + std::to_string(localSlotNum); + token += nosplit ? ", true" : ", false"; + token += ">"; + return token; +} + +static FailureOr buildTPipeTokenFromInitOp(Operation *op, + PTOArch targetArch) { + if (auto initOp = dyn_cast(op)) { + if (!initOp.getFlagBaseAttr()) + return failure(); + auto dirTok = + getTPipeDirectionToken(/*isL2G2L=*/true, initOp.getDirMask(), targetArch); + if (failed(dirTok)) + return failure(); + int32_t localSlotNum = initOp.getLocalSlotNumAttr() + ? initOp.getLocalSlotNumAttr().getInt() + : initOp.getSlotNum(); + return buildTPipeToken(initOp.getFlagBaseAttr().getInt(), *dirTok, + initOp.getSlotSize(), initOp.getSlotNum(), + localSlotNum, + initOp.getNosplitAttr() && + initOp.getNosplitAttr().getValue()); + } + + if (auto initOp = dyn_cast(op)) { + if (!initOp.getFlagBaseAttr()) + return failure(); + auto dirTok = + getTPipeDirectionToken(/*isL2G2L=*/false, initOp.getDirMask(), targetArch); + if (failed(dirTok)) + return failure(); + return buildTPipeToken(initOp.getFlagBaseAttr().getInt(), *dirTok, + initOp.getSlotSize(), initOp.getSlotNum(), 2, + initOp.getNosplitAttr() && + initOp.getNosplitAttr().getValue()); + } + + return failure(); +} + +static FailureOr getTPipeTokenFromValue(Value pipeHandle, + PTOArch targetArch) { + pipeHandle = peelUnrealized(pipeHandle); + Operation *def = pipeHandle.getDefiningOp(); + if (!def) + return failure(); + return buildTPipeTokenFromInitOp(def, targetArch); +} + +static bool isSetFFTsPointerLikeType(Type ty) { + return isEmitCPointerLikeType(ty); +} + +static bool 
tileDataReturnsIntegralAddress(pto::AddressSpace as) { + return as == pto::AddressSpace::BIAS; +} + +static Type getTileDataResultType(MLIRContext *ctx, pto::AddressSpace as, + StringRef elemTok) { + if (tileDataReturnsIntegralAddress(as)) + return emitc::OpaqueType::get(ctx, "uint64_t"); + return getEmitCPointerType(ctx, addrSpaceQualifier(as), elemTok); +} + +static Value materializeTileDataValue(ConversionPatternRewriter &rewriter, + Location loc, Value tile, + pto::AddressSpace as, + StringRef elemTok) { + auto rawTy = getTileDataResultType(rewriter.getContext(), as, elemTok); + return rewriter + .create(loc, rawTy, "PTOAS__TILE_DATA", + ArrayAttr{}, ArrayAttr{}, + ValueRange{tile}) + .getResult(0); +} + +static Value materializeAddressAsPointer(ConversionPatternRewriter &rewriter, + Location loc, Value addr, + pto::AddressSpace as, + StringRef elemTok) { + auto *ctx = rewriter.getContext(); + std::string ptrTyStr = + std::string(addrSpaceQualifier(as)) + " " + elemTok.str() + "*"; + auto ptrTy = getEmitCPointerType(ctx, addrSpaceQualifier(as), elemTok); + if (isSetFFTsPointerLikeType(addr.getType())) { + if (addr.getType() == ptrTy) + return addr; + return rewriter.create(loc, ptrTy, addr).getResult(); + } + auto castTyAttr = + rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, ptrTyStr)}); + return rewriter + .create(loc, ptrTy, "reinterpret_cast", + ArrayAttr{}, castTyAttr, + ValueRange{addr}) + .getResult(0); +} + +struct InterCoreSyncCallDesc { + const char *callee = nullptr; + ArrayAttr args; + SmallVector operands; +}; + +static Value castInterCoreEventIdToI32(ConversionPatternRewriter &rewriter, + Location loc, Value eventId) { + auto i32Ty = emitc::OpaqueType::get(rewriter.getContext(), "int32_t"); + if (eventId.getType() == i32Ty) + return eventId; + return emitCCast(rewriter, loc, i32Ty, eventId); +} + +static Attribute getFFTSModeCodegenArg(ConversionPatternRewriter &rewriter, + int64_t fftsMode) { + auto *ctx = rewriter.getContext(); + if (fftsMode 
== 2) + return emitc::OpaqueAttr::get(ctx, "FFTS_MODE_VAL"); + return emitc::OpaqueAttr::get(ctx, std::to_string(fftsMode)); +} + +static Value createFFTSMsg(ConversionPatternRewriter &rewriter, Location loc, + Value eventI32, int64_t fftsMode) { + auto *ctx = rewriter.getContext(); + auto msgTy = emitc::OpaqueType::get(ctx, "uint16_t"); + auto msgArgs = rewriter.getArrayAttr({ + getFFTSModeCodegenArg(rewriter, fftsMode), + IntegerAttr::get(IndexType::get(ctx), 0), + }); + return rewriter + .create(loc, msgTy, "getFFTSMsg", + /*args=*/msgArgs, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{eventI32}) + .getResult(0); +} + +static InterCoreSyncCallDesc buildInterCoreSyncSetCall( + ConversionPatternRewriter &rewriter, Location loc, PTOArch targetArch, + pto::PipeAttr pipeAttr, IntegerAttr eventIdAttr, int64_t fftsMode) { + auto *ctx = rewriter.getContext(); + std::string pipeTok = pipeTokFromPipeAttr(pipeAttr); + + if (targetArch == PTOArch::A3) { + auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); + Value eventVal = + makeEmitCIntConstant(rewriter, loc, i32Ty, eventIdAttr.getInt()); + Value msgVal = createFFTSMsg(rewriter, loc, eventVal, fftsMode); + + InterCoreSyncCallDesc desc; + desc.callee = "ffts_cross_core_sync"; + desc.args = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, pipeTok), + IntegerAttr::get(IndexType::get(ctx), 0), + }); + desc.operands.push_back(msgVal); + return desc; + } + + InterCoreSyncCallDesc desc; + desc.callee = "set_intra_block"; + desc.args = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, pipeTok), eventIdAttr}); + return desc; +} + +static InterCoreSyncCallDesc buildInterCoreSyncSetCallDyn( + ConversionPatternRewriter &rewriter, Location loc, PTOArch targetArch, + pto::PipeAttr pipeAttr, Value eventIdVal, int64_t fftsMode) { + auto *ctx = rewriter.getContext(); + std::string pipeTok = pipeTokFromPipeAttr(pipeAttr); + Value eventI32 = castInterCoreEventIdToI32(rewriter, loc, eventIdVal); + + if (targetArch == 
PTOArch::A3) { + Value msgVal = createFFTSMsg(rewriter, loc, eventI32, fftsMode); + + InterCoreSyncCallDesc desc; + desc.callee = "ffts_cross_core_sync"; + desc.args = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, pipeTok), + IntegerAttr::get(IndexType::get(ctx), 0), + }); + desc.operands.push_back(msgVal); + return desc; + } + + InterCoreSyncCallDesc desc; + desc.callee = "set_intra_block"; + desc.args = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, pipeTok), + IntegerAttr::get(IndexType::get(ctx), 0), + }); + desc.operands.push_back(eventI32); + return desc; +} + +static InterCoreSyncCallDesc buildInterCoreSyncWaitCall( + ConversionPatternRewriter &rewriter, PTOArch targetArch, + pto::PipeAttr pipeAttr, IntegerAttr eventIdAttr) { + auto *ctx = rewriter.getContext(); + std::string pipeTok = pipeTokFromPipeAttr(pipeAttr); + + InterCoreSyncCallDesc desc; + if (targetArch == PTOArch::A3) { + desc.callee = "wait_flag_dev"; + desc.args = rewriter.getArrayAttr({eventIdAttr}); + return desc; + } + + desc.callee = "wait_intra_block"; + desc.args = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, pipeTok), eventIdAttr}); + return desc; +} + +static InterCoreSyncCallDesc buildInterCoreSyncWaitCallDyn( + ConversionPatternRewriter &rewriter, Location loc, PTOArch targetArch, + pto::PipeAttr pipeAttr, Value eventIdVal) { + auto *ctx = rewriter.getContext(); + std::string pipeTok = pipeTokFromPipeAttr(pipeAttr); + Value eventI32 = castInterCoreEventIdToI32(rewriter, loc, eventIdVal); + + InterCoreSyncCallDesc desc; + if (targetArch == PTOArch::A3) { + desc.callee = "wait_flag_dev"; + desc.args = rewriter.getArrayAttr({IntegerAttr::get(IndexType::get(ctx), 0)}); + desc.operands.push_back(eventI32); + return desc; + } + + desc.callee = "wait_intra_block"; + desc.args = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, pipeTok), + IntegerAttr::get(IndexType::get(ctx), 0), + }); + desc.operands.push_back(eventI32); + return desc; +} + +static bool 
hasInterCoreSyncOp(func::FuncOp func) { + bool found = false; + func.walk([&](Operation *op) { + if (isa(op)) { + found = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return found; +} + +static bool hasSetFFTsOp(func::FuncOp func) { + bool found = false; + func.walk([&](Operation *op) { + if (isa(op)) { + found = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return found; +} + +//===----------------------------------------------------------------------===// +// Arith -> EmitC (full dialect coverage for scalar ops) +//===----------------------------------------------------------------------===// + +template +struct ArithSimpleBinaryToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type dstTy = this->getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getOperands()); + return success(); + } +}; + +// Integer bitwise ops (andi/ori/xori) on signless integers: perform in unsigned +// to avoid signedness pitfalls, then cast back. +template +struct ArithUnsignedBitwiseBinaryToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const unsigned bitWidth = + intTy ? 
intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = this->getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + if (bitWidth == 1) { + rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getLhs(), + adaptor.getRhs()); + return success(); + } + + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value resU = rewriter.create(loc, uTy, lhsU, rhsU); + Value result = emitCCast(rewriter, loc, dstTy, resU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithDivUIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::DivUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const unsigned bitWidth = + intTy ? 
intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value divU = rewriter.create(loc, uTy, lhsU, rhsU); + Value result = emitCCast(rewriter, loc, dstTy, divU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithRemUIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::RemUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const unsigned bitWidth = + intTy ? 
intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value remU = rewriter.create(loc, uTy, lhsU, rhsU); + Value result = emitCCast(rewriter, loc, dstTy, remU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithCeilDivUIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::CeilDivUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const unsigned bitWidth = + intTy ? 
intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value one = makeEmitCIntConstant(rewriter, loc, uTy, 1); + Value rhsMinusOne = rewriter.create(loc, uTy, rhsU, one); + Value num = rewriter.create(loc, uTy, lhsU, rhsMinusOne); + Value divU = rewriter.create(loc, uTy, num, rhsU); + Value result = emitCCast(rewriter, loc, dstTy, divU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithCeilDivSIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::CeilDivSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + Value zero = makeEmitCIntConstant(rewriter, loc, dstTy, 0); + Value one = makeEmitCIntConstant(rewriter, loc, dstTy, 1); + + Value q0 = rewriter.create(loc, dstTy, adaptor.getLhs(), + adaptor.getRhs()); + Value r = rewriter.create(loc, dstTy, adaptor.getLhs(), + adaptor.getRhs()); + + Value rNeZero = rewriter.create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::ne, r, + zero); + Value lhsLt0 = + rewriter.create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, adaptor.getLhs(), + zero); + Value rhsLt0 = + rewriter.create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, adaptor.getRhs(), + zero); + Value signsSame = + rewriter.create(loc, 
rewriter.getI1Type(), + emitc::CmpPredicate::eq, lhsLt0, rhsLt0); + Value adjust = + rewriter.create(loc, rewriter.getI1Type(), + rNeZero, signsSame); + + Value qPlusOne = rewriter.create(loc, dstTy, q0, one); + Value result = rewriter.create(loc, dstTy, adjust, + qPlusOne, q0); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithFloorDivSIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::FloorDivSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + Value zero = makeEmitCIntConstant(rewriter, loc, dstTy, 0); + Value one = makeEmitCIntConstant(rewriter, loc, dstTy, 1); + + Value q0 = rewriter.create(loc, dstTy, adaptor.getLhs(), + adaptor.getRhs()); + Value r = rewriter.create(loc, dstTy, adaptor.getLhs(), + adaptor.getRhs()); + + Value rNeZero = rewriter.create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::ne, r, + zero); + Value lhsLt0 = + rewriter.create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, adaptor.getLhs(), + zero); + Value rhsLt0 = + rewriter.create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, adaptor.getRhs(), + zero); + Value signsDifferent = + rewriter.create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::ne, lhsLt0, rhsLt0); + Value adjust = + rewriter.create(loc, rewriter.getI1Type(), + rNeZero, signsDifferent); + + Value qMinusOne = rewriter.create(loc, dstTy, q0, one); + Value result = rewriter.create(loc, dstTy, adjust, + qMinusOne, q0); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithShiftLeftToEmitC : public OpConversionPattern { + using 
OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::ShLIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const unsigned bitWidth = + intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + if (bitWidth == 1) { + // Compute on u8 and truncate to i1. + auto u8Ty = getUnsignedIntOpaqueType(rewriter.getContext(), 8); + Value lhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getLhs()); + Value rhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getRhs()); + Value sh = rewriter.create(loc, u8Ty, lhsU8, + rhsU8); + Value masked = + rewriter.create(loc, u8Ty, sh, + makeEmitCIntConstant(rewriter, loc, + u8Ty, 1)); + rewriter.replaceOp(op, emitCCast(rewriter, loc, dstTy, masked)); + return success(); + } + + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value shU = + rewriter.create(loc, uTy, lhsU, rhsU); + Value result = emitCCast(rewriter, loc, dstTy, shU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithShiftRightUIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::ShRUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + 
const unsigned bitWidth = + intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + if (bitWidth == 1) { + // (x >> y) on i1 is either x (y==0) or 0 (y!=0); approximate in u8. + auto u8Ty = getUnsignedIntOpaqueType(rewriter.getContext(), 8); + Value lhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getLhs()); + Value rhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getRhs()); + Value sh = rewriter.create(loc, u8Ty, lhsU8, + rhsU8); + Value masked = + rewriter.create(loc, u8Ty, sh, + makeEmitCIntConstant(rewriter, loc, + u8Ty, 1)); + rewriter.replaceOp(op, emitCCast(rewriter, loc, dstTy, masked)); + return success(); + } + + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value shU = + rewriter.create(loc, uTy, lhsU, rhsU); + Value result = emitCCast(rewriter, loc, dstTy, shU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithShiftRightSIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::ShRSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const unsigned bitWidth = + intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + if (bitWidth == 1) { + // (x >> y) on i1 is either x (y==0) or 0 (y!=0); approximate in u8. 
+ auto u8Ty = getUnsignedIntOpaqueType(rewriter.getContext(), 8); + Value lhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getLhs()); + Value rhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getRhs()); + Value sh = rewriter.create(loc, u8Ty, lhsU8, + rhsU8); + Value masked = + rewriter.create(loc, u8Ty, sh, + makeEmitCIntConstant(rewriter, loc, + u8Ty, 1)); + rewriter.replaceOp(op, emitCCast(rewriter, loc, dstTy, masked)); + return success(); + } + + // Signed arithmetic shift; cast RHS to unsigned to interpret shift amount. + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value sh = + rewriter.create(loc, dstTy, adaptor.getLhs(), + rhsU); + rewriter.replaceOp(op, sh); + return success(); + } +}; + +struct ArithNegFToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::NegFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type dstTy = getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getOperand()); + return success(); + } +}; + +struct ArithRemFToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::RemFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type dstTy = getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + + // Use builtin `fmod` when possible. For f16, compute in float and cast back. 
+ Type callTy = dstTy; + Value lhs = adaptor.getLhs(); + Value rhs = adaptor.getRhs(); + + if (auto opFloatTy = dyn_cast(op.getType())) { + if (opFloatTy.isF16()) { + auto f32Ty = emitc::OpaqueType::get(rewriter.getContext(), "float"); + lhs = emitCCast(rewriter, loc, f32Ty, lhs); + rhs = emitCCast(rewriter, loc, f32Ty, rhs); + callTy = f32Ty; + } + } + + // Prefer `__builtin_fmod*` to avoid relying on extra headers. + llvm::StringRef callee = "__builtin_fmod"; + if (auto opFloatTy = dyn_cast(op.getType())) { + if (opFloatTy.isF32() || opFloatTy.isF16()) + callee = "__builtin_fmodf"; + else if (opFloatTy.isF64()) + callee = "__builtin_fmod"; + } + + auto call = rewriter.create( + loc, TypeRange{callTy}, callee, ValueRange{lhs, rhs}, + /*args=*/ArrayAttr{}, /*template_args=*/ArrayAttr{}); + Value result = call.getResult(0); + if (callTy != dstTy) + result = emitCCast(rewriter, loc, dstTy, result); + + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithSelectToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!op.getCondition().getType().isInteger(1)) + return rewriter.notifyMatchFailure( + op, "only scalar i1 conditions supported for arith.select"); + + Type dstTy = getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + + auto cond = + rewriter.create(op.getLoc(), dstTy, + adaptor.getCondition(), + adaptor.getTrueValue(), + adaptor.getFalseValue()); + rewriter.replaceOp(op, cond.getResult()); + return success(); + } +}; + +struct ArithExtUIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto dstIntTy = dyn_cast(op.getType()); + auto srcIntTy = 
dyn_cast(op.getIn().getType()); + if (!dstIntTy || !srcIntTy) + return rewriter.notifyMatchFailure(op, "expected scalar integer types"); + + Type dstTy = getTypeConverter()->convertType(dstIntTy); + if (!dstTy) + return failure(); + + // i1 -> iN: bool to integer already behaves as 0/1. + if (srcIntTy.getWidth() == 1) { + rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); + return success(); + } + + auto uDstTy = + getUnsignedIntOpaqueType(rewriter.getContext(), dstIntTy.getWidth()); + Value srcU = + castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getIn(), + srcIntTy.getWidth()); + Value extU = emitCCast(rewriter, loc, uDstTy, srcU); + Value result = emitCCast(rewriter, loc, dstTy, extU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithExtSIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto dstIntTy = dyn_cast(op.getType()); + auto srcIntTy = dyn_cast(op.getIn().getType()); + if (!dstIntTy || !srcIntTy) + return rewriter.notifyMatchFailure(op, "expected scalar integer types"); + + Type dstTy = getTypeConverter()->convertType(dstIntTy); + if (!dstTy) + return failure(); + + // i1 sign-extension: 0 -> 0, 1 -> -1. 
+ if (srcIntTy.getWidth() == 1) { + Value zero = makeEmitCIntConstant(rewriter, loc, dstTy, 0); + Value asInt = emitCCast(rewriter, loc, dstTy, adaptor.getIn()); + Value neg = rewriter.create(loc, dstTy, zero, asInt).getResult(); + rewriter.replaceOp(op, neg); + return success(); + } + + rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); + return success(); + } +}; + +template +struct ArithCastToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type dstTy = this->getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); + return success(); + } +}; + +struct ArithIndexCastUIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::IndexCastUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type dstTy = getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + + // MemRef casts are handled elsewhere; for safety, fall back to emitc.cast. 
+ if (isa(op.getIn().getType()) || isa(op.getType())) { + rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); + return success(); + } + + auto getBW = [](Type t) -> std::optional { + if (auto i = dyn_cast(t)) + return i.getWidth(); + if (isa(t)) + return kPTOIndexBitWidth; + return std::nullopt; + }; + + auto srcBW = getBW(op.getIn().getType()); + auto dstBW = getBW(op.getType()); + if (!srcBW || !dstBW) + return rewriter.notifyMatchFailure(op, "unsupported index_castui types"); + + if (*dstBW <= *srcBW) { + rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); + return success(); + } + + auto uSrcTy = getUnsignedIntOpaqueType(rewriter.getContext(), *srcBW); + auto uDstTy = getUnsignedIntOpaqueType(rewriter.getContext(), *dstBW); + Value srcU = emitCCast(rewriter, loc, uSrcTy, adaptor.getIn()); + Value extU = emitCCast(rewriter, loc, uDstTy, srcU); + Value result = emitCCast(rewriter, loc, dstTy, extU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithUIToFPToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto srcIntTy = dyn_cast(op.getIn().getType()); + if (!srcIntTy) + return rewriter.notifyMatchFailure(op, "expected scalar integer input"); + + Type dstTy = getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + + // Convert via an unsigned integer type of the same width. 
+ if (srcIntTy.getWidth() == 1) { + rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); + return success(); + } + Value srcU = + castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getIn(), + srcIntTy.getWidth()); + Value fp = rewriter.create(loc, dstTy, srcU).getResult(); + rewriter.replaceOp(op, fp); + return success(); + } +}; + +struct ArithFPToUIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::FPToUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto dstIntTy = dyn_cast(op.getType()); + if (!dstIntTy) + return rewriter.notifyMatchFailure(op, "expected scalar integer result"); + + Type dstTy = getTypeConverter()->convertType(dstIntTy); + if (!dstTy) + return failure(); + + auto uDstTy = + getUnsignedIntOpaqueType(rewriter.getContext(), dstIntTy.getWidth()); + Value asU = rewriter.create(loc, uDstTy, adaptor.getIn()).getResult(); + Value result = emitCCast(rewriter, loc, dstTy, asU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithBitcastToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type dstTy = getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + + // For pointer-like types, a regular cast is fine. + if (isa(dstTy)) { + rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); + return success(); + } + + // Only support scalar int/float/index bitcasts here. 
+ auto srcTy = op.getIn().getType(); + auto dstOrigTy = op.getType(); + + auto getBitWidth = [](Type t) -> std::optional { + if (auto it = dyn_cast(t)) + return it.getWidth(); + if (auto ft = dyn_cast(t)) + return ft.getWidth(); + if (isa(t)) + return kPTOIndexBitWidth; + return std::nullopt; + }; + auto srcBW = getBitWidth(srcTy); + auto dstBW = getBitWidth(dstOrigTy); + if (!srcBW || !dstBW || *srcBW != *dstBW) + return rewriter.notifyMatchFailure(op, "bitcast requires equal bitwidth"); + + // Determine the template argument from the destination type string. + auto dstOpaque = dyn_cast(dstTy); + if (!dstOpaque) + return rewriter.notifyMatchFailure(op, "expected emitc opaque dest type"); + + auto templateArgs = + rewriter.getArrayAttr({emitc::OpaqueAttr::get(rewriter.getContext(), + dstOpaque.getValue())}); + auto call = rewriter.create( + loc, TypeRange{dstTy}, "ptoas_bitcast", /*operands=*/ValueRange{adaptor.getIn()}, + /*args=*/ArrayAttr{}, /*template_args=*/templateArgs); + rewriter.replaceOp(op, call.getResult(0)); + return success(); + } +}; + +// arith.cmpf lowering with ordered/unordered semantics. 
+struct ArithCmpFToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + struct CmpFConfig { + bool unordered = false; + emitc::CmpPredicate predicate = emitc::CmpPredicate::eq; + }; + + static Value isNaN(ConversionPatternRewriter &rewriter, Location loc, + Value v) { + return rewriter + .create(loc, rewriter.getI1Type(), emitc::CmpPredicate::ne, + v, v) + .getResult(); + } + + static Value isNotNaN(ConversionPatternRewriter &rewriter, Location loc, + Value v) { + return rewriter + .create(loc, rewriter.getI1Type(), emitc::CmpPredicate::eq, + v, v) + .getResult(); + } + + static std::optional buildSpecialCmpFResult( + arith::CmpFPredicate predicate, ConversionPatternRewriter &rewriter, + Location loc, Type i1Ty, Value lhs, Value rhs) { + switch (predicate) { + case arith::CmpFPredicate::AlwaysFalse: + return makeEmitCOpaqueConstant(rewriter, loc, i1Ty, "false"); + case arith::CmpFPredicate::AlwaysTrue: + return makeEmitCOpaqueConstant(rewriter, loc, i1Ty, "true"); + case arith::CmpFPredicate::ORD: + return rewriter.create( + loc, i1Ty, isNotNaN(rewriter, loc, lhs), + isNotNaN(rewriter, loc, rhs)) + .getResult(); + case arith::CmpFPredicate::UNO: + return rewriter.create( + loc, i1Ty, isNaN(rewriter, loc, lhs), + isNaN(rewriter, loc, rhs)) + .getResult(); + default: + return std::nullopt; + } + } + + static std::optional + getCmpFConfig(arith::CmpFPredicate predicate) { + switch (predicate) { + case arith::CmpFPredicate::OEQ: + return CmpFConfig{false, emitc::CmpPredicate::eq}; + case arith::CmpFPredicate::OGT: + return CmpFConfig{false, emitc::CmpPredicate::gt}; + case arith::CmpFPredicate::OGE: + return CmpFConfig{false, emitc::CmpPredicate::ge}; + case arith::CmpFPredicate::OLT: + return CmpFConfig{false, emitc::CmpPredicate::lt}; + case arith::CmpFPredicate::OLE: + return CmpFConfig{false, emitc::CmpPredicate::le}; + case arith::CmpFPredicate::ONE: + return CmpFConfig{false, emitc::CmpPredicate::ne}; + case 
arith::CmpFPredicate::UEQ: + return CmpFConfig{true, emitc::CmpPredicate::eq}; + case arith::CmpFPredicate::UGT: + return CmpFConfig{true, emitc::CmpPredicate::gt}; + case arith::CmpFPredicate::UGE: + return CmpFConfig{true, emitc::CmpPredicate::ge}; + case arith::CmpFPredicate::ULT: + return CmpFConfig{true, emitc::CmpPredicate::lt}; + case arith::CmpFPredicate::ULE: + return CmpFConfig{true, emitc::CmpPredicate::le}; + case arith::CmpFPredicate::UNE: + return CmpFConfig{true, emitc::CmpPredicate::ne}; + default: + return std::nullopt; + } + } + + static Value buildCmpFResult(const CmpFConfig &config, + ConversionPatternRewriter &rewriter, + Location loc, Type i1Ty, Value lhs, Value rhs) { + Value cmp = rewriter + .create(loc, i1Ty, config.predicate, lhs, rhs) + .getResult(); + Value unord = rewriter.create( + loc, i1Ty, isNaN(rewriter, loc, lhs), isNaN(rewriter, loc, rhs)); + if (config.unordered) + return rewriter + .create(loc, i1Ty, unord, cmp) + .getResult(); + Value ord = rewriter.create( + loc, i1Ty, isNotNaN(rewriter, loc, lhs), isNotNaN(rewriter, loc, rhs)); + return rewriter + .create(loc, i1Ty, ord, cmp) + .getResult(); + } + + LogicalResult matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!isa(op.getLhs().getType())) + return rewriter.notifyMatchFailure(op, "cmpf only supported on scalar floats"); + + auto loc = op.getLoc(); + auto i1Ty = rewriter.getI1Type(); + if (auto special = buildSpecialCmpFResult(op.getPredicate(), rewriter, loc, + i1Ty, adaptor.getLhs(), + adaptor.getRhs())) { + rewriter.replaceOp(op, *special); + return success(); + } + + auto config = getCmpFConfig(op.getPredicate()); + if (!config) + return rewriter.notifyMatchFailure(op, "unsupported cmpf predicate"); + rewriter.replaceOp(op, buildCmpFResult(*config, rewriter, loc, i1Ty, + adaptor.getLhs(), adaptor.getRhs())); + return success(); + } +}; + +struct ArithAddUIExtendedToEmitC + : public OpConversionPattern { + 
using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getSum().getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, + "expected scalar integer or index operands"); + + const unsigned bitWidth = + intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + SmallVector newResultTypes; + if (failed(getTypeConverter()->convertTypes(op->getResultTypes(), + newResultTypes))) + return failure(); + if (newResultTypes.size() != 2) + return failure(); + + Type sumDstTy = newResultTypes[0]; + Type overflowDstTy = newResultTypes[1]; + + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + auto wideTy = getWiderUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + + Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value lhsWide = emitCCast(rewriter, loc, wideTy, lhsU); + Value rhsWide = emitCCast(rewriter, loc, wideTy, rhsU); + Value sumWide = + rewriter.create(loc, wideTy, lhsWide, rhsWide).getResult(); + + Value sumN = emitCCast(rewriter, loc, uTy, sumWide); + Value sum = emitCCast(rewriter, loc, sumDstTy, sumN); + + Value shiftAmt = makeEmitCIntConstant(rewriter, loc, wideTy, bitWidth); + Value high = rewriter + .create(loc, wideTy, sumWide, + shiftAmt) + .getResult(); + Value zeroWide = makeEmitCIntConstant(rewriter, loc, wideTy, 0); + Value overflow = + rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::ne, high, zeroWide) + .getResult(); + overflow = emitCCast(rewriter, loc, overflowDstTy, overflow); + + rewriter.replaceOp(op, {sum, overflow}); + return success(); + } +}; + +template +struct ArithMulExtendedToEmitC : public 
OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getResult(0).getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, + "expected scalar integer or index operands"); + + const unsigned bitWidth = + intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + SmallVector newResultTypes; + if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), + newResultTypes))) + return failure(); + if (newResultTypes.size() != 2) + return failure(); + + Type lowDstTy = newResultTypes[0]; + Type highDstTy = newResultTypes[1]; + + Type wideTy = isUnsigned ? (Type)getWiderUnsignedIntOpaqueType(rewriter.getContext(), + bitWidth) + : (Type)getWiderSignedIntOpaqueType(rewriter.getContext(), + bitWidth); + + Value lhsWide; + Value rhsWide; + if constexpr (isUnsigned) { + Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + lhsWide = emitCCast(rewriter, loc, wideTy, lhsU); + rhsWide = emitCCast(rewriter, loc, wideTy, rhsU); + } else { + lhsWide = emitCCast(rewriter, loc, wideTy, adaptor.getLhs()); + rhsWide = emitCCast(rewriter, loc, wideTy, adaptor.getRhs()); + } + + Value prodWide = + rewriter.create(loc, wideTy, lhsWide, rhsWide).getResult(); + Value low = emitCCast(rewriter, loc, lowDstTy, prodWide); + + Value shiftAmt = makeEmitCIntConstant(rewriter, loc, wideTy, bitWidth); + Value highWide = rewriter + .create(loc, wideTy, prodWide, + shiftAmt) + .getResult(); + Value high = emitCCast(rewriter, loc, highDstTy, highWide); + + rewriter.replaceOp(op, {low, high}); + return success(); + } +}; + +using ArithMulSIExtendedToEmitC = + 
ArithMulExtendedToEmitC; +using ArithMulUIExtendedToEmitC = + ArithMulExtendedToEmitC; + +struct ArithMinMaxIToEmitCBase { + static Value makeSelect(ConversionPatternRewriter &rewriter, Location loc, + Type dstTy, Value cond, Value trueV, Value falseV) { + return rewriter + .create(loc, dstTy, cond, trueV, falseV) + .getResult(); + } +}; + +struct ArithMaxSIToEmitC : public OpConversionPattern, + ArithMinMaxIToEmitCBase { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::MaxSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type dstTy = getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + Value cond = rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, + adaptor.getLhs(), adaptor.getRhs()) + .getResult(); + Value res = makeSelect(rewriter, loc, dstTy, cond, adaptor.getRhs(), + adaptor.getLhs()); + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct ArithMinSIToEmitC : public OpConversionPattern, + ArithMinMaxIToEmitCBase { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::MinSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type dstTy = getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + Value cond = rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, + adaptor.getLhs(), adaptor.getRhs()) + .getResult(); + Value res = makeSelect(rewriter, loc, dstTy, cond, adaptor.getLhs(), + adaptor.getRhs()); + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct ArithMaxUIToEmitC : public OpConversionPattern, + ArithMinMaxIToEmitCBase { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::MaxUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = 
op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const unsigned bitWidth = + intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + Value lhsU = + castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = + castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value cond = rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, lhsU, rhsU) + .getResult(); + Value res = makeSelect(rewriter, loc, dstTy, cond, adaptor.getRhs(), + adaptor.getLhs()); + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct ArithMinUIToEmitC : public OpConversionPattern, + ArithMinMaxIToEmitCBase { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::MinUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const unsigned bitWidth = + intTy ? 
intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + Value lhsU = + castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = + castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value cond = rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, lhsU, rhsU) + .getResult(); + Value res = makeSelect(rewriter, loc, dstTy, cond, adaptor.getLhs(), + adaptor.getRhs()); + rewriter.replaceOp(op, res); + return success(); + } +}; + +// Floating-point max/min variants. +struct ArithFloatMinMaxToEmitCBase { + static Value isNaN(ConversionPatternRewriter &rewriter, Location loc, + Value v) { + return rewriter + .create(loc, rewriter.getI1Type(), emitc::CmpPredicate::ne, + v, v) + .getResult(); + } + + static Value makeFZero(ConversionPatternRewriter &rewriter, Location loc, + Type ty) { + return makeEmitCOpaqueConstant(rewriter, loc, ty, "0.0f"); + } +}; + +struct ArithMaxNumFToEmitC : public OpConversionPattern, + ArithFloatMinMaxToEmitCBase { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::MaxNumFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type dstTy = getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + + Value lhsNaN = isNaN(rewriter, loc, adaptor.getLhs()); + Value rhsNaN = isNaN(rewriter, loc, adaptor.getRhs()); + + Value cmpLt = rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, + adaptor.getLhs(), adaptor.getRhs()) + .getResult(); + Value maxNoNaN = + rewriter + .create(loc, dstTy, cmpLt, adaptor.getRhs(), + adaptor.getLhs()) + .getResult(); + + Value rhsOrMax = + rewriter + .create(loc, dstTy, rhsNaN, adaptor.getLhs(), + maxNoNaN) + .getResult(); + Value res = + rewriter + .create(loc, dstTy, lhsNaN, adaptor.getRhs(), + rhsOrMax) + 
.getResult(); + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct ArithMinNumFToEmitC : public OpConversionPattern, + ArithFloatMinMaxToEmitCBase { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::MinNumFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type dstTy = getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + + Value lhsNaN = isNaN(rewriter, loc, adaptor.getLhs()); + Value rhsNaN = isNaN(rewriter, loc, adaptor.getRhs()); + + Value cmpLt = rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, + adaptor.getLhs(), adaptor.getRhs()) + .getResult(); + Value minNoNaN = + rewriter + .create(loc, dstTy, cmpLt, adaptor.getLhs(), + adaptor.getRhs()) + .getResult(); + + Value rhsOrMin = + rewriter + .create(loc, dstTy, rhsNaN, adaptor.getLhs(), + minNoNaN) + .getResult(); + Value res = + rewriter + .create(loc, dstTy, lhsNaN, adaptor.getRhs(), + rhsOrMin) + .getResult(); + rewriter.replaceOp(op, res); + return success(); + } +}; + +template +struct ArithMinMaxFPropagateNaNToEmitC : public OpConversionPattern, + ArithFloatMinMaxToEmitCBase { + using OpConversionPattern::OpConversionPattern; + + static Value buildPrimaryCandidate(ConversionPatternRewriter &rewriter, + Location loc, Type dstTy, Value lhs, + Value rhs) { + Value cmpLt = + rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, lhs, rhs) + .getResult(); + return rewriter + .create( + loc, dstTy, cmpLt, isMaximum ? rhs : lhs, isMaximum ? 
lhs : rhs) + .getResult(); + } + + static Value buildSignBitValue(ConversionPatternRewriter &rewriter, + Location loc, Value lhs, FloatType floatTy) { + auto bitsTy = + getUnsignedIntOpaqueType(rewriter.getContext(), floatTy.getWidth()); + auto templateArgs = rewriter.getArrayAttr({emitc::OpaqueAttr::get( + rewriter.getContext(), cast(bitsTy).getValue())}); + Value lhsBits = + rewriter + .create(loc, TypeRange{bitsTy}, "ptoas_bitcast", + ValueRange{lhs}, ArrayAttr{}, + templateArgs) + .getResult(0); + Value oneBits = makeEmitCIntConstant(rewriter, loc, bitsTy, 1); + Value shiftAmount = + makeEmitCIntConstant(rewriter, loc, bitsTy, floatTy.getWidth() - 1); + Value signMask = rewriter + .create(loc, bitsTy, oneBits, + shiftAmount) + .getResult(); + return rewriter + .create(loc, bitsTy, lhsBits, signMask) + .getResult(); + } + + static Value buildSignedZeroCandidate(ConversionPatternRewriter &rewriter, + Location loc, Type dstTy, Value lhs, + Value rhs, FloatType floatTy) { + Value zero = makeFZero(rewriter, loc, dstTy); + Value equal = rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::eq, lhs, rhs) + .getResult(); + Value lhsZero = rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::eq, lhs, + zero) + .getResult(); + Value bothZero = rewriter + .create(loc, rewriter.getI1Type(), + equal, lhsZero) + .getResult(); + auto bitsTy = + getUnsignedIntOpaqueType(rewriter.getContext(), floatTy.getWidth()); + Value zeroBits = makeEmitCIntConstant(rewriter, loc, bitsTy, 0); + Value lhsIsNegZero = + rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::ne, + buildSignBitValue(rewriter, loc, lhs, floatTy), + zeroBits) + .getResult(); + Value tie = rewriter + .create( + loc, dstTy, lhsIsNegZero, isMaximum ? rhs : lhs, + isMaximum ? 
lhs : rhs) + .getResult(); + return rewriter + .create(loc, dstTy, bothZero, tie, + buildPrimaryCandidate(rewriter, loc, dstTy, + lhs, rhs)) + .getResult(); + } + + static Value buildNaNPropagatingResult(ConversionPatternRewriter &rewriter, + Location loc, Type dstTy, Value lhs, + Value rhs, FloatType floatTy) { + Value lhsNaN = isNaN(rewriter, loc, lhs); + Value rhsNaN = isNaN(rewriter, loc, rhs); + Value noNaN = + buildSignedZeroCandidate(rewriter, loc, dstTy, lhs, rhs, floatTy); + Value rhsOrNoNaN = rewriter + .create(loc, dstTy, rhsNaN, rhs, + noNaN) + .getResult(); + return rewriter + .create(loc, dstTy, lhsNaN, lhs, rhsOrNoNaN) + .getResult(); + } + + LogicalResult + matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!isa(op.getType())) + return rewriter.notifyMatchFailure(op, "expected scalar float type"); + + auto loc = op.getLoc(); + Type dstTy = this->getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + + auto floatTy = cast(op.getType()); + rewriter.replaceOp(op, buildNaNPropagatingResult( + rewriter, loc, dstTy, adaptor.getLhs(), + adaptor.getRhs(), floatTy)); + return success(); + } +}; + +using ArithMaximumFToEmitC = + ArithMinMaxFPropagateNaNToEmitC; +using ArithMinimumFToEmitC = + ArithMinMaxFPropagateNaNToEmitC; + +//===----------------------------------------------------------------------===// +// Arith -> EmitC helpers +//===----------------------------------------------------------------------===// + +static emitc::OpaqueType getSignedIntOpaqueType(MLIRContext *ctx, + unsigned bitWidth) { + switch (bitWidth) { + case 1: + return emitc::OpaqueType::get(ctx, "int8_t"); + case 8: + return emitc::OpaqueType::get(ctx, "int8_t"); + case 16: + return emitc::OpaqueType::get(ctx, "int16_t"); + case 32: + return emitc::OpaqueType::get(ctx, "int32_t"); + case 64: + return emitc::OpaqueType::get(ctx, "int64_t"); + case 128: + return 
emitc::OpaqueType::get(ctx, "__int128"); + default: + llvm::errs() << "[Debug] Unsupported signed integer bitwidth: " << bitWidth + << "\n"; + return emitc::OpaqueType::get(ctx, "int64_t"); + } +} + +/* Maps an integer bit width to an unsigned C++ type name wrapped in emitc::OpaqueType. Width 1 shares uint8_t with width 8, and 128 maps to "unsigned __int128". Any other width prints a message to llvm::errs() and falls back to uint64_t instead of failing, so callers get no error signal. */static emitc::OpaqueType getUnsignedIntOpaqueType(MLIRContext *ctx, + unsigned bitWidth) { + switch (bitWidth) { + case 1: + return emitc::OpaqueType::get(ctx, "uint8_t"); + case 8: + return emitc::OpaqueType::get(ctx, "uint8_t"); + case 16: + return emitc::OpaqueType::get(ctx, "uint16_t"); + case 32: + return emitc::OpaqueType::get(ctx, "uint32_t"); + case 64: + return emitc::OpaqueType::get(ctx, "uint64_t"); + case 128: + return emitc::OpaqueType::get(ctx, "unsigned __int128"); + default: + llvm::errs() << "[Debug] Unsupported unsigned integer bitwidth: " + << bitWidth << "\n"; + return emitc::OpaqueType::get(ctx, "uint64_t"); + } +} + +/* Returns the signed opaque type one step wider than bitWidth (1 or 8 gives 16, 16 gives 32, 32 gives 64, 64 gives 128). Every width not listed, including 128 itself, also yields the 128-bit type. */static emitc::OpaqueType getWiderSignedIntOpaqueType(MLIRContext *ctx, + unsigned bitWidth) { + switch (bitWidth) { + case 1: + case 8: + return getSignedIntOpaqueType(ctx, 16); + case 16: + return getSignedIntOpaqueType(ctx, 32); + case 32: + return getSignedIntOpaqueType(ctx, 64); + case 64: + return getSignedIntOpaqueType(ctx, 128); + default: + return getSignedIntOpaqueType(ctx, 128); + } +} + +/* Unsigned counterpart of the wider-signed helper, with the same width table and the same 128-bit default. */static emitc::OpaqueType getWiderUnsignedIntOpaqueType(MLIRContext *ctx, + unsigned bitWidth) { + switch (bitWidth) { + case 1: + case 8: + return getUnsignedIntOpaqueType(ctx, 16); + case 16: + return getUnsignedIntOpaqueType(ctx, 32); + case 32: + return getUnsignedIntOpaqueType(ctx, 64); + case 64: + return getUnsignedIntOpaqueType(ctx, 128); + default: + return getUnsignedIntOpaqueType(ctx, 128); + } +} + +/* Creates one op of the given type from an emitc::OpaqueAttr that carries `literal` verbatim; the literal text is not validated here. */static Value makeEmitCOpaqueConstant(ConversionPatternRewriter &rewriter, + Location loc, Type type, + llvm::StringRef literal) { + auto attr = emitc::OpaqueAttr::get(rewriter.getContext(), literal); + return rewriter.create(loc, type, attr); +} + +static Value makeEmitCIntConstant(ConversionPatternRewriter &rewriter, + Location loc, 
Type type, int64_t value) { + return makeEmitCOpaqueConstant(rewriter, loc, type, std::to_string(value)); +} + +/* Builds the C++ initializer text for a constant whose converted type is an opaque type. Only the type named "pto::MrgSortExecutedNumList" is handled: valueAttr must cast to the attribute kind held in `dense`, its type must be rank 1 with exactly 4 elements of 16-bit integer type, and the elements are printed zero-extended, comma separated, inside braces. Every other input returns failure(). */static FailureOr buildEmitCOpaqueConstantLiteral(Type targetType, + Attribute valueAttr) { + auto opaqueTy = dyn_cast(targetType); + if (!opaqueTy) + return failure(); + + if (opaqueTy.getValue() == "pto::MrgSortExecutedNumList") { + auto dense = dyn_cast_or_null(valueAttr); + if (!dense) + return failure(); + + auto vecTy = dyn_cast(dense.getType()); + if (!vecTy || vecTy.getRank() != 1 || vecTy.getNumElements() != 4 || + !vecTy.getElementType().isInteger(16)) + return failure(); + + std::string literal; + llvm::raw_string_ostream os(literal); + os << "pto::MrgSortExecutedNumList{"; + bool first = true; + for (APInt elem : dense.getValues()) { + if (!first) + os << ", "; + first = false; + os << elem.getZExtValue(); + } + os << "}"; + os.flush(); + return literal; + } + + return failure(); +} + +/* Returns src unchanged when it already has dstType; otherwise builds a conversion to dstType through createOrFold, which may fold instead of creating a new op. */static Value emitCCast(ConversionPatternRewriter &rewriter, Location loc, + Type dstType, Value src) { + if (src.getType() == dstType) + return src; + return rewriter.createOrFold(loc, dstType, src); +} + +// For signless iN integers lowered to signed C++ types, this creates a value +// representing the same N-bit pattern in an unsigned C++ type of the same +// width. This avoids incorrect sign-extension when later widening to a larger +// unsigned type. 
+static Value castSignlessIntToUnsignedSameWidth(ConversionPatternRewriter &rewriter, + Location loc, Value v, + unsigned bitWidth) { + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + return emitCCast(rewriter, loc, uTy, v); +} + +/* Lowers arith.muli on scalar integer or index types (index uses kPTOIndexBitWidth as its width). The 1-bit case is replaced by a single op; every other width casts both operands to the same-width unsigned type, combines them there, and casts the result to the converted type dstTy. Other operand types are reported through notifyMatchFailure. */struct ArithMulIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(arith::MulIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const unsigned bitWidth = + intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + // i1 mul is equivalent to bitwise AND (mod 2 arithmetic). NOTE(review): this branch passes the original opTy, not the converted dstTy, as the result type; confirm the type converter maps i1 to itself. + if (bitWidth == 1) { + rewriter.replaceOpWithNewOp(op, opTy, adaptor.getLhs(), + adaptor.getRhs()); + return success(); + } + + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value mulU = rewriter.create(loc, uTy, lhsU, rhsU); + Value result = emitCCast(rewriter, loc, dstTy, mulU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithAddIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(arith::AddIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const 
unsigned bitWidth = + intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + // i1 add is equivalent to XOR (mod 2 arithmetic). + if (bitWidth == 1) { + rewriter.replaceOpWithNewOp(op, opTy, adaptor.getLhs(), + adaptor.getRhs()); + return success(); + } + + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value addU = rewriter.create(loc, uTy, lhsU, rhsU); + Value result = emitCCast(rewriter, loc, dstTy, addU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithCastOPToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type newTy = getTypeConverter()->convertType(op.getType()); + if (!newTy) + return failure(); + if (adaptor.getIn().getType() == newTy) { + rewriter.replaceOp(op, adaptor.getIn()); + return success(); + } + rewriter.replaceOpWithNewOp(op, newTy, adaptor.getIn()); + return success(); + } +}; + +struct ArithSubIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(arith::SubIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const unsigned bitWidth = + intTy ? 
intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + // i1 sub is equivalent to XOR (mod 2 arithmetic). + if (bitWidth == 1) { + rewriter.replaceOpWithNewOp(op, opTy, adaptor.getLhs(), + adaptor.getRhs()); + return success(); + } + + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value subU = rewriter.create(loc, uTy, lhsU, rhsU); + Value result = emitCCast(rewriter, loc, dstTy, subU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithDivSIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(arith::DivSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type newTy = getTypeConverter()->convertType(op.getType()); + if (!newTy) + return failure(); + rewriter.replaceOpWithNewOp(op, newTy, adaptor.getLhs(), + adaptor.getRhs()); + return success(); + } +}; + +struct ArithRemSIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type newTy = getTypeConverter()->convertType(op.getType()); + if (!newTy) + return failure(); + rewriter.replaceOpWithNewOp(op, newTy, adaptor.getLhs(), + adaptor.getRhs()); + return success(); + } +}; + +struct ArithTruncIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + auto dstIntTy = dyn_cast(op.getType()); + auto srcIntTy = 
dyn_cast(op.getIn().getType()); + if (!dstIntTy || !srcIntTy) + return rewriter.notifyMatchFailure(op, "expected scalar integer types"); + + Type dstTy = getTypeConverter()->convertType(dstIntTy); + if (!dstTy) + return failure(); + + // to-i1 conversions: Arith wants truncation to the low bit, while C/C++ + // casts to bool are equivalent to `v != 0`. Implement as `(bool)(v & 1)`. + if (dstIntTy.getWidth() == 1) { + if (srcIntTy.getWidth() == 1) { + rewriter.replaceOp(op, adaptor.getIn()); + return success(); + } + + auto uSrcTy = + getUnsignedIntOpaqueType(rewriter.getContext(), srcIntTy.getWidth()); + Value inU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getIn(), + srcIntTy.getWidth()); + Value one = makeEmitCIntConstant(rewriter, loc, uSrcTy, 1); + Value masked = + rewriter.create(loc, uSrcTy, inU, one); + Value asBool = emitCCast(rewriter, loc, dstTy, masked); + rewriter.replaceOp(op, asBool); + return success(); + } + + rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); + return success(); + } +}; + +struct ArithConstantToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type newType = getTypeConverter()->convertType(op.getType()); + if (!newType) + return failure(); + + // `adaptor.getValue()` may be null if attribute conversion isn't defined. + // Use the original attribute as fallback and always cast null-safely. 
+ Attribute valueAttr = adaptor.getValue(); + if (!valueAttr) + valueAttr = op.getValue(); + + if (auto opaqueLiteral = buildEmitCOpaqueConstantLiteral(newType, valueAttr); + succeeded(opaqueLiteral)) { + auto constAttr = emitc::OpaqueAttr::get(rewriter.getContext(), *opaqueLiteral); + rewriter.replaceOpWithNewOp(op, newType, constAttr); + return success(); + } + + if (auto floatAttr = dyn_cast_or_null(valueAttr)) { + SmallString<32> valStr; + floatAttr.getValue().toString(valStr); + llvm::StringRef s(valStr); + // Ensure the literal parses as a floating-point constant in C/C++. + // `APFloat::toString` may emit "1" for integral values; make it "1.0". + const bool hasFloatMarker = + s.contains('.') || s.contains('e') || s.contains('E') || + s.contains('p') || s.contains('P') || s.starts_with("0x") || + s.starts_with("0X") || s.starts_with("nan") || + s.starts_with("-nan") || s.starts_with("inf") || + s.starts_with("-inf"); + if (!hasFloatMarker) + valStr.append(".0"); + // Suffix: keep `f` for f16/f32; omit for f64. 
+ if (!floatAttr.getType().isF64()) + valStr.append("f"); + auto constAttr = emitc::OpaqueAttr::get(rewriter.getContext(), valStr); + rewriter.replaceOpWithNewOp(op, newType, constAttr); + return success(); + } + + if (auto intAttr = dyn_cast_or_null(valueAttr)) { + std::string valStr = std::to_string(intAttr.getValue().getSExtValue()); + auto constAttr = emitc::OpaqueAttr::get(rewriter.getContext(), valStr); + rewriter.replaceOpWithNewOp(op, newType, constAttr); + return success(); + } + + return failure(); + } +}; +//===----------------------------------------------------------------------===// +// pto.mgather lowering -> MGATHER(dst, src, indexes) (pto-isa) +//===----------------------------------------------------------------------===// + +struct PTOMGatherToMGATHER : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::MGatherOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto *ctx = rewriter.getContext(); + Value mem = peelUnrealized(adaptor.getMem()); + Value idx = peelUnrealized(adaptor.getIdx()); + Value dst = peelUnrealized(adaptor.getDst()); + + Value memArg = maybeWrapGlobalMemrefAsGlobalTensor( + rewriter, op.getLoc(), mem, op.getMem().getType(), op.getOperation()); + + auto gatherOobTok = [&](pto::GatherOOB mode) -> StringRef { + switch (mode) { + case pto::GatherOOB::Undefined: + return "pto::GatherOOB::Undefined"; + case pto::GatherOOB::Clamp: + return "pto::GatherOOB::Clamp"; + case pto::GatherOOB::Wrap: + return "pto::GatherOOB::Wrap"; + case pto::GatherOOB::Zero: + return "pto::GatherOOB::Zero"; + } + llvm_unreachable("unknown GatherOOB"); + }; + + SmallVector templateArgVec; + const bool rowCoalesce = + isRowCoalescedMGatherIndexType(op.getDst().getType(), op.getIdx().getType()); + templateArgVec.push_back(emitc::OpaqueAttr::get( + ctx, rowCoalesce ? 
"pto::Coalesce::Row" : "pto::Coalesce::Elem")); + if (op.getGatherOob() != pto::GatherOOB::Undefined) { + templateArgVec.push_back( + emitc::OpaqueAttr::get(ctx, gatherOobTok(op.getGatherOob()))); + } + ArrayAttr templateArgs = rewriter.getArrayAttr(templateArgVec); + + rewriter.create( + op.getLoc(), TypeRange{}, "MGATHER", + ArrayAttr{}, templateArgs, + ValueRange{dst, memArg, idx}); + + if (op->getNumResults() == 0) { + rewriter.eraseOp(op); + } else { + rewriter.replaceOp(op, dst); + } + return success(); + } +}; + +/* Lowers affine.apply only for maps with zero dims, exactly one symbol, and a single result of the form (lhs * rhs) where lhs casts to the kind held in symExpr and rhs to the kind held in constExpr; every other map returns failure(). The constant is emitted as an opaque literal typed like the converted operand, and the replacement op reuses that type. NOTE(review): the result type comes from the converted operand rather than from converting op.getType(); confirm the two always agree. */struct AffineApplyMulConstToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(affine::AffineApplyOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto map = op.getAffineMap(); + + if (map.getNumDims() != 0 || map.getNumSymbols() != 1) + return failure(); + + auto expr = map.getResult(0); + auto bin = dyn_cast(expr); + if (!bin || bin.getKind() != AffineExprKind::Mul) + return failure(); + + auto lhs = bin.getLHS(); + auto rhs = bin.getRHS(); + + auto symExpr = dyn_cast(lhs); + auto constExpr = dyn_cast(rhs); + if (!symExpr || !constExpr) + return failure(); + + Value inputVal = adaptor.getMapOperands()[0]; + + std::string valStr = std::to_string(constExpr.getValue()); + auto cstAttr = emitc::OpaqueAttr::get(rewriter.getContext(), valStr); + auto cstOp = rewriter.create( + op.getLoc(), inputVal.getType(), cstAttr); + + rewriter.replaceOpWithNewOp( + op, inputVal.getType(), inputVal, cstOp); + + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Kernel inference helpers +//===----------------------------------------------------------------------===// + +enum class KernelKind { VecAdd, Matmul, Unknown }; + +[[maybe_unused]] static KernelKind inferKernelKind(func::FuncOp f) { + bool hasAdd = false; + bool hasMM = false; + f.walk([&](Operation *op) { + if (isa(op)) hasAdd = true; + if 
(isa(op)) hasMM = true; + if (isa(op)) hasMM = true; + }); + if (hasMM) return KernelKind::Matmul; + if (hasAdd) return KernelKind::VecAdd; + return KernelKind::Unknown; +} + +/* Guesses tile sizes from the memref.subview ops found in f. M, N and K start at 32 and stay there when no subview exists. With one subview, M and N come from its result shape; with two or more, M and K come from the first subview and N from dimension 1 of the second. A result type that is not rank 2 with a static shape leaves the 32 defaults in place. NOTE(review): subs[0] is read a second time inside the two-subview branch and b0 is never used; both look redundant. */[[maybe_unused]] static void inferTileMNK(func::FuncOp f, int &M, int &N, int &K) { + M = 32; N = 32; K = 32; + SmallVector subs; + f.walk([&](memref::SubViewOp sv) { subs.push_back(sv); }); + + auto readShape2D = [&](memref::SubViewOp sv, int &d0, int &d1) { + auto resTy = mlir::cast(sv.getResult().getType()); + if (resTy.getRank() == 2 && resTy.hasStaticShape()) { + d0 = (int)resTy.getDimSize(0); + d1 = (int)resTy.getDimSize(1); + } + }; + + if (subs.empty()) return; + + int a0=32, a1=32; + readShape2D(subs[0], a0, a1); + M = a0; N = a1; + + if (subs.size() >= 2) { + int b0=32, b1=32; + readShape2D(subs[0], a0, a1); + readShape2D(subs[1], b0, b1); + M = a0; K = a1; N = b1; + } +} + +/* Returns the preprocessor macro name matching the kernel-kind attribute on funcOp ("__DAV_CUBE__" for Cube, "__DAV_VEC__" for Vector), or std::nullopt when the attribute is absent. */static std::optional getKernelKindMacro(func::FuncOp funcOp) { + auto kernelKindAttr = + funcOp->getAttrOfType(FunctionKernelKindAttr::name); + if (!kernelKindAttr) + return std::nullopt; + + switch (kernelKindAttr.getKernelKind()) { + case FunctionKernelKind::Cube: + return StringRef("__DAV_CUBE__"); + case FunctionKernelKind::Vector: + return StringRef("__DAV_VEC__"); + } + + llvm_unreachable("unexpected kernel kind"); +} + +struct FuncToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(func::FuncOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Convert the function signature with the type converter. + Type convertedTy = getTypeConverter()->convertType(op.getFunctionType()); + auto funcType = dyn_cast_or_null(convertedTy); + if (!funcType) + return rewriter.notifyMatchFailure(op, "failed to convert function type"); + if (funcType.getNumResults() > 1) + return rewriter.notifyMatchFailure( + op, "EmitC cannot return multiple values"); + + // Create the EmitC function with the converted signature. 
+ auto emitcFunc = + rewriter.create(op.getLoc(), op.getName(), funcType); + + for (const auto &namedAttr : op->getAttrs()) { + StringRef name = namedAttr.getName().strref(); + if (name == op.getFunctionTypeAttrName() || + name == SymbolTable::getSymbolAttrName() || + name == pto::kPTOEntryAttrName || + name == pto::kLegacyHACCEntryAttrName || + name == "pto.internal.entry") + continue; + emitcFunc->setAttr(namedAttr.getName(), namedAttr.getValue()); + } + + if (op.isDeclaration()) { + emitcFunc.setSpecifiersAttr(rewriter.getStrArrayAttr({"extern"})); + rewriter.eraseOp(op); + return success(); + } + + if (pto::isPTOEntryFunction(op)) { + emitcFunc.setSpecifiersAttr( + rewriter.getStrArrayAttr({"__global__ AICORE"})); + } else if (op.isPrivate()) { + emitcFunc.setSpecifiersAttr( + rewriter.getStrArrayAttr({"static", "AICORE"})); + } else { + emitcFunc.setSpecifiersAttr(rewriter.getStrArrayAttr({"AICORE"})); + } + + std::optional kernelKindMacro = getKernelKindMacro(op); + bool needsNoSplitGuard = needsA5NoSplitVectorGuard(op.getOperation()); + + // Inline the original body, then convert region/block argument types to + // match the converted signature (also covers CFG blocks introduced by + // pre-lowering, e.g. scf.while -> cf.br/cf.cond_br). + rewriter.inlineRegionBefore(op.getBody(), emitcFunc.getBody(), + emitcFunc.end()); + + TypeConverter::SignatureConversion entryConv(op.getNumArguments()); + for (unsigned i = 0; i < op.getNumArguments(); ++i) + entryConv.addInputs(i, funcType.getInput(i)); + + if (failed(rewriter.convertRegionTypes(&emitcFunc.getBody(), + *getTypeConverter(), &entryConv))) + return failure(); + + // Preserve the existing function prologue shape. `kernel_kind` functions are + // emitted with the same macro guard/reset sequence that used to come from + // early pto.section wrapping, but only after SCF pre-lowering has finished. 
+ { + Block &entryBlock = emitcFunc.getBody().front(); + rewriter.setInsertionPointToStart(&entryBlock); + rewriter.create(op.getLoc(), "using T = float;"); + if (kernelKindMacro) { + std::string startMacro = "\n#if defined(" + kernelKindMacro->str() + ")"; + rewriter.create(op.getLoc(), startMacro); + if (*kernelKindMacro == "__DAV_VEC__") { + rewriter.create(op.getLoc(), "set_mask_norm();"); + rewriter.create(op.getLoc(), + "set_vector_mask(-1, -1);"); + if (needsNoSplitGuard) + rewriter.create( + op.getLoc(), "if (get_subblockid() == 0) {"); + } + } + } + + if (kernelKindMacro) { + /* NOTE(review): the closing brace and #endif are inserted only before the terminator of the last block of the body; confirm that a body with several returning blocks cannot reach this point. */Block &lastBlock = emitcFunc.getBody().back(); + rewriter.setInsertionPoint(lastBlock.getTerminator()); + if (*kernelKindMacro == "__DAV_VEC__" && needsNoSplitGuard) + rewriter.create(op.getLoc(), "}"); + std::string endMacro = "#endif // " + kernelKindMacro->str() + "\n"; + rewriter.create(op.getLoc(), endMacro); + } + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// SubView lowering to GlobalTensor (keep your existing code) +//===----------------------------------------------------------------------===// + +enum class Role { A, B, C, Unknown }; + +/* Classifies buffer against a matmul-like op: Role::A when it is the op's lhs, Role::B when it is the rhs, std::nullopt otherwise. */template +static std::optional inferMatmulLikeSubviewRole(MatmulLikeOp op, + Value buffer) { + if (op.getLhs() == buffer) + return Role::A; + if (op.getRhs() == buffer) + return Role::B; + return std::nullopt; +} + +/* Walks the users of the load's destination and returns the role reported by the first user, of either of the two op kinds tested below, that has the buffer as lhs or rhs. Returns std::nullopt when the destination is null or no such user exists. */static std::optional inferSubviewRoleFromLoadUser(mlir::pto::TLoadOp load) { + Value buffer = load.getDst(); + if (!buffer) + return std::nullopt; + for (Operation *user : buffer.getUsers()) { + if (auto matmul = dyn_cast(user)) { + if (auto role = inferMatmulLikeSubviewRole(matmul, buffer)) + return role; + continue; + } + if (auto matmulAcc = dyn_cast(user)) { + if (auto role = inferMatmulLikeSubviewRole(matmulAcc, buffer)) + return role; + } + } + return std::nullopt; +} + +static std::optional inferSubviewRoleFromUser(Operation *user, 
Value result) { + if (auto load = dyn_cast(user)) + return inferSubviewRoleFromLoadUser(load); + if (auto store = dyn_cast(user)) { + if (store.getDst() == result) + return Role::C; + } + return std::nullopt; +} + +[[maybe_unused]] static Role inferSubviewRole(memref::SubViewOp sv) { + Value result = sv.getResult(); + for (Operation *user : result.getUsers()) { + if (auto role = inferSubviewRoleFromUser(user, result)) + return *role; + } + return Role::Unknown; +} + +// ============================================================================= +// 4. MemRef SubView -> Explicit Shape/Stride Construction (Full Implementation) +// ============================================================================= +struct SubviewToEmitCPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + // 辅助函数:尝试从 OpFoldResult 中提取静态整数值 + std::optional extractStaticInt(OpFoldResult ofr) const { + if (auto attr = ofr.dyn_cast()) { + if (auto intAttr = dyn_cast(attr)) + return intAttr.getInt(); + } else { + Value v = ofr.get(); + if (auto cOp = v.getDefiningOp()) { + if (auto iAttr = dyn_cast(cOp.getValue())) + return iAttr.getInt(); + } else if (auto idxOp = v.getDefiningOp()) { + return idxOp.value(); + } + } + return std::nullopt; + } + + LogicalResult matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + // 获取源 MemRef 类型信息 + auto srcType = mlir::cast(op.getSource().getType()); + int64_t rank = srcType.getRank(); + + auto elemTypeToString = [&](Type elemTy) -> std::string { + if (elemTy.isF16()) + return "half"; + if (elemTy.isBF16()) + return "bfloat16_t"; + if (elemTy.isF32()) + return "float"; + if (elemTy.isF64()) + return "double"; + if (elemTy.isInteger(8)) { + if (elemTy.isSignlessInteger(8) || elemTy.isSignedInteger(8)) + return "int8_t"; + return "uint8_t"; + } + if (elemTy.isInteger(16)) { + if 
(elemTy.isSignlessInteger(16) || elemTy.isSignedInteger(16)) + return "int16_t"; + return "uint16_t"; + } + if (elemTy.isInteger(32)) { + if (elemTy.isSignlessInteger(32) || elemTy.isSignedInteger(32)) + return "int32_t"; + return "uint32_t"; + } + if (elemTy.isInteger(64)) { + return cast(elemTy).isUnsigned() ? "uint64_t" : "int64_t"; + } + return "float"; + }; + + // ------------------------------------------------------------------------- + // Part 1: 指针偏移计算 (Runtime Pointer Arithmetic) + // ------------------------------------------------------------------------- + + // 准备类型: unsigned + Type u32Ty = emitc::OpaqueType::get(ctx, "unsigned"); + + // Helper: 创建 unsigned 常量 + auto mkU32 = [&](int64_t v) -> Value { + return rewriter.create( + loc, u32Ty, emitc::OpaqueAttr::get(ctx, std::to_string(v))); + }; + + // Helper: 将 OpFoldResult 转为 EmitC Value (用于计算) + auto ofrToEmitCValue = [&](OpFoldResult ofr) -> Value { + if (auto v = ofr.dyn_cast()) { + Value rv = rewriter.getRemappedValue(v); + // 如果类型不匹配,插入 Cast + if (rv.getType() != u32Ty) + return rewriter.create(loc, u32Ty, rv).getResult(); + return rv; + } + if (auto attr = ofr.dyn_cast()) { + if (auto ia = dyn_cast(attr)) + return mkU32(ia.getValue().getSExtValue()); + } + return mkU32(0); + }; + + // 1. 
获取 Source 的 Strides (支持动态 Stride 收集) + SmallVector sourceStrides; + + if (auto rc = op.getSource().getDefiningOp()) { + sourceStrides = rc.getMixedStrides(); + } else { + SmallVector strideInts; + int64_t offset = ShapedType::kDynamic; + bool useTypeStrides = succeeded(getStridesAndOffset(srcType, strideInts, offset)); + (void)offset; + if (useTypeStrides) { + for (int64_t s : strideInts) { + if (s == ShapedType::kDynamic) + useTypeStrides = false; + } + } + if (useTypeStrides) { + for (int64_t s : strideInts) { + sourceStrides.push_back(rewriter.getIndexAttr(s)); + } + } else { + // Fallback: Compact Layout + auto shape = srcType.getShape(); + int64_t current = 1; + sourceStrides.resize(rank); + for (int i = rank - 1; i >= 0; --i) { + sourceStrides[i] = rewriter.getIndexAttr(current); + if (shape[i] != ShapedType::kDynamic) current *= shape[i]; + } + } + } + + // 2. 计算运行时 Offset + auto staticOffsets = op.getStaticOffsets(); + auto dynamicOffsets = adaptor.getOffsets(); + int dynOffIdx = 0; + Value totalOffset = mkU32(0); + + for (int i = 0; i < rank; ++i) { + // A. 获取 Offset + Value offVal; + if (staticOffsets[i] == ShapedType::kDynamic) { + Value rawDyn = dynamicOffsets[dynOffIdx++]; + offVal = rewriter.create(loc, u32Ty, rawDyn); + } else { + offVal = mkU32(staticOffsets[i]); + } + + // B. 获取 Stride (用于指针计算) + Value strideVal = mkU32(1); + if (i < (int)sourceStrides.size()) { + strideVal = ofrToEmitCValue(sourceStrides[i]); + } + + // C. 累加 + Value term = rewriter.create(loc, u32Ty, offVal, strideVal); + totalOffset = rewriter.create(loc, u32Ty, totalOffset, term); + } + + // 3. 生成新指针 + // + // NOTE: Some toolchains may materialize kernel pointer params as `void*` even + // when the underlying element type is i16. Pointer arithmetic on `void*` + // is ill-formed in C++, so we explicitly cast to a typed pointer for i16. 
+ Value sourcePtr = adaptor.getSource(); + Value tileCandidate = sourcePtr; + if (auto castOp = sourcePtr.getDefiningOp()) { + tileCandidate = castOp.getOperand(); + } else if (auto uc = + sourcePtr.getDefiningOp()) { + tileCandidate = uc.getOperand(0); + } + if (auto ot = dyn_cast(tileCandidate.getType())) { + auto tyStr = ot.getValue(); + if (tyStr.find("Tile<") != std::string::npos || + tyStr.find("ConvTile<") != std::string::npos) { + std::string elemTok = elemTypeToString(srcType.getElementType()); + pto::AddressSpace as = pto::AddressSpace::GM; + if (auto asAttr = + dyn_cast_or_null(srcType.getMemorySpace())) + as = asAttr.getAddressSpace(); + sourcePtr = + materializeTileDataValue(rewriter, loc, tileCandidate, as, elemTok); + if (tileDataReturnsIntegralAddress(as)) + sourcePtr = + materializeAddressAsPointer(rewriter, loc, sourcePtr, as, elemTok); + } + } + Value newPtr; + { + auto resTy = mlir::cast(op.getResult().getType()); + Type elemTy = resTy.getElementType(); + if (elemTy.isInteger(16)) { + std::string castElemTypeStr = "int16_t"; + if (cast(elemTy).isUnsigned()) + castElemTypeStr = "uint16_t"; + + std::string qualifier = "__gm__"; + if (Attribute ms = srcType.getMemorySpace()) { + if (auto ptoAttr = dyn_cast(ms)) { + qualifier = addrSpaceQualifier(ptoAttr.getAddressSpace()); + } + } + + auto typedPtrTy = emitc::PointerType::get( + emitc::OpaqueType::get(ctx, qualifier + " " + castElemTypeStr)); + Value typedSourcePtr = rewriter.create(loc, typedPtrTy, sourcePtr); + newPtr = rewriter.create(loc, typedPtrTy, typedSourcePtr, totalOffset); + } else { + newPtr = rewriter.create(loc, sourcePtr.getType(), sourcePtr, totalOffset); + } + } + + + // ------------------------------------------------------------------------- + // Part 2: For non-GM memrefs, keep pointer (no GlobalTensor). 
+ // ------------------------------------------------------------------------- + bool isGlobal = true; + if (auto asAttr = dyn_cast_or_null(srcType.getMemorySpace())) { + auto as = asAttr.getAddressSpace(); + isGlobal = (as == pto::AddressSpace::GM || as == pto::AddressSpace::Zero); + } + if (!isGlobal) { + Type dstTy = getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + if (newPtr.getType() != dstTy) + newPtr = rewriter.create(loc, dstTy, newPtr); + rewriter.replaceOp(op, newPtr); + return success(); + } + + // ------------------------------------------------------------------------- + // Part 3: 生成 GlobalTensor 类型 (Shape/Stride Template Generation) + // ------------------------------------------------------------------------- + + // When emitting C++ with `declareVariablesAtTop`, value declarations are + // hoisted before body statements. Avoid introducing local `using` aliases + // for templated types (Shape/Stride/GlobalTensor) because those aliases + // would appear after the hoisted declarations and break compilation + // (`unknown type name`). + // + // Instead, use the fully spelled template types as EmitC opaque types. + + auto resTy = mlir::cast(op.getResult().getType()); + + // 1. 解析具体元素类型 + std::string elemTypeStr = getElemTypeStringForGT(resTy.getElementType()); + + // 2. 生成 Shape 模板参数,之后会右对齐有效维度并补齐到 5 维(高维填 1) + SmallVector shapeParamsVec; + SmallVector sizeValues; // 每个维度对应的运行时 size(统一为 unsigned) + auto resShape = resTy.getShape(); + auto mixedSizes = op.getMixedSizes(); + sizeValues.reserve(rank); + for (int i = 0; i < resTy.getRank(); ++i) { + if (resShape[i] == ShapedType::kDynamic) { + shapeParamsVec.push_back(-1); + } else { + shapeParamsVec.push_back(resShape[i]); + } + // size 值:优先从 op.getMixedSizes() 取(可动态/静态),否则退化为类型里的静态 shape。 + if (i < (int)mixedSizes.size()) + sizeValues.push_back(ofrToEmitCValue(mixedSizes[i])); + else + sizeValues.push_back( + mkU32(resShape[i] == ShapedType::kDynamic ? 
1 : resShape[i])); + } + + // 3. 生成 Stride 模板参数 + 运行时 stride 值(考虑 subview step) + SmallVector strideTemplateVec; + SmallVector strideValues; // 每个维度对应的运行时 stride(统一为 unsigned) + strideTemplateVec.reserve(rank); + strideValues.reserve(rank); + auto subViewSteps = op.getMixedStrides(); + for (int i = 0; i < rank; ++i) { + OpFoldResult srcStrideOfr = + (i < (int)sourceStrides.size()) ? sourceStrides[i] + : rewriter.getIndexAttr(1); + OpFoldResult stepOfr = (i < (int)subViewSteps.size()) + ? subViewSteps[i] + : rewriter.getIndexAttr(1); + + auto srcStatic = extractStaticInt(srcStrideOfr); + auto stepStatic = extractStaticInt(stepOfr); + if (srcStatic && stepStatic) { + int64_t finalStride = (*srcStatic) * (*stepStatic); + strideTemplateVec.push_back(finalStride); + strideValues.push_back(mkU32(finalStride)); + continue; + } + + strideTemplateVec.push_back(-1); + Value srcV = ofrToEmitCValue(srcStrideOfr); + Value stepV = ofrToEmitCValue(stepOfr); + // 尽量避免乘以 1 生成冗余指令 + if (stepStatic && *stepStatic == 1) + strideValues.push_back(srcV); + else if (srcStatic && *srcStatic == 1) + strideValues.push_back(stepV); + else + strideValues.push_back( + rewriter.create(loc, u32Ty, srcV, stepV)); + } + + // 3.1 右对齐到 5 维:shape 补 1;已有维度继承原 stride; + // 被补出来的高维按“紧密升维”规则连续推导:stride[i] = shape[i+1] * stride[i+1] + SmallVector finalShape; + SmallVector finalStride; + buildGlobalTensorShapeAndStride(shapeParamsVec, strideTemplateVec, + finalShape, finalStride); + Value oneU32 = mkU32(1); + SmallVector finalShapeValues(5, oneU32); + SmallVector finalStrideValues(5, oneU32); + int shift = 5 - rank; + + // 先放入原始 shape/stride(保持用户提供的值) + for (int i = 0; i < rank && i < 5; ++i) { + finalShapeValues[shift + i] = sizeValues[i]; + finalStrideValues[shift + i] = strideValues[i]; + } + + // 从低维到高维倒推补齐 stride(仅对补出来的前置维度生效) + for (int i = 3; i >= 0; --i) { + // 如果该维已由原始 rank 覆盖,则保持原值 + if (i >= shift) + continue; + if (finalStride[i] != -1) { + finalStrideValues[i] = mkU32(finalStride[i]); + 
continue; + } + // 动态推导:stride[i] = shape[i+1] * stride[i+1] + if (finalShape[i + 1] == 1) { + finalStrideValues[i] = finalStrideValues[i + 1]; + } else { + finalStrideValues[i] = rewriter.create( + loc, u32Ty, finalShapeValues[i + 1], finalStrideValues[i + 1]); + } + } + + std::string shapeParams = joinIntTemplateParams(finalShape); + std::string strideParams = joinIntTemplateParams(finalStride); + + // Spelled-out C++ types. + std::string shapeCppType = "pto::Shape<" + shapeParams + ">"; + std::string strideCppType = "pto::Stride<" + strideParams + ">"; + + // 3.0 Layout: prefer the attribute from InferPTOLayout; only fall back to + // local inference when the pass is disabled. + std::string layoutEnum = "pto::Layout::ND"; + if (auto layout = resolveLayoutForGlobalTensor(op, op.getSource())) { + layoutEnum = layoutToEmitCString(*layout); + } else { + bool allStatic = + llvm::all_of(finalShape, [](int64_t value) { return value != -1; }) && + llvm::all_of(finalStride, [](int64_t value) { return value != -1; }); + + int layoutTag = 0; // ND + auto elemBytes = 4; // default float + if (elemTypeStr.find("half") != std::string::npos || + elemTypeStr.find("f16") != std::string::npos || + elemTypeStr.find("bf16") != std::string::npos) + elemBytes = 2; + else if (elemTypeStr.find("double") != std::string::npos || + elemTypeStr.find("f64") != std::string::npos) + elemBytes = 8; + + if (allStatic) { + if (finalShape[2] == 16 && + finalShape[2] * finalShape[3] * elemBytes == 512 && + finalStride[4] == 1 && finalStride[3] == finalShape[4]) { + layoutTag = 2; // NZ + } else { + bool isRow = finalStride[4] == 1; + for (int i = 3; i >= 0; --i) + isRow &= (finalStride[i] == + multiplyOrDynamic(finalStride[i + 1], finalShape[i + 1])); + bool isCol = finalStride[0] == 1; + for (int i = 0; i < 4; ++i) + isCol &= (finalStride[i + 1] == + multiplyOrDynamic(finalStride[i], finalShape[i])); + if (isCol) + layoutTag = 1; // DN + else + layoutTag = isRow ? 
0 : 0; // fallback ND + } + } + + if (layoutTag == 1) + layoutEnum = "pto::Layout::DN"; + else if (layoutTag == 2) + layoutEnum = "pto::Layout::NZ"; + } + // GlobalTensor takes a Layout non-type template parameter; directly use the + // enum constant. + + + // ------------------------------------------------------------------------- + // Part 3: 显式对象实例化 (Explicit Object Instantiation) + // ------------------------------------------------------------------------- + + // A. Instantiate Shape object. + auto shapeTypeOpaque = emitc::OpaqueType::get(ctx, shapeCppType); + SmallVector shapeArgs; + // 从 adaptor.getSizes() 获取 subview 的所有 dynamic sizes + for (Value dynSize : adaptor.getSizes()) { + shapeArgs.push_back(dynSize); + } + + auto shapeInstOp = rewriter.create( + loc, + shapeTypeOpaque, // 返回类型 + shapeCppType, // 调用的“函数名”即类名构造函数 + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange(shapeArgs) + ); + + // B. Instantiate Stride object. + auto strideTypeOpaque = emitc::OpaqueType::get(ctx, strideCppType); + // 仅传入动态 stride 维度对应的值,匹配 pto::Stride 的 N-parameter ctor(并满足其 static_assert)。 + SmallVector strideCtorArgs; + strideCtorArgs.reserve(5); + for (int i = 0; i < 5; ++i) { + if (finalStride[i] == -1) + strideCtorArgs.push_back(finalStrideValues[i]); + } + auto strideInstOp = rewriter.create( + loc, strideTypeOpaque, strideCppType, + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange(strideCtorArgs)); + + // C. Instantiate GlobalTensor object (ptr + shape + stride). 
+ std::string gtCppType = "GlobalTensor<" + elemTypeStr + ", " + shapeCppType + + ", " + strideCppType + ", " + layoutEnum + ">"; + auto gtType = emitc::OpaqueType::get(ctx, gtCppType); + + // 准备构造参数: [ptr, shape_instance, stride_instance] + SmallVector gtConstructorArgs; + gtConstructorArgs.push_back(newPtr); + gtConstructorArgs.push_back(shapeInstOp.getResult(0)); // 拿到 shape_inst 的 SSA Value + gtConstructorArgs.push_back(strideInstOp.getResult(0)); // 拿到 stride_inst 的 SSA Value + + rewriter.replaceOpWithNewOp( + op, + gtType, + gtCppType, + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange(gtConstructorArgs) + ); + + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Helper: build GlobalTensor from a static MemRef (for TLOAD/TSTORE) +//===----------------------------------------------------------------------===// + +static std::string getElemTypeStringForGT(Type elemTy) { + return getEmitCScalarTypeToken(elemTy); +} + +static bool hasStaticShape(MemRefType mrTy) { + return llvm::none_of(mrTy.getShape(), [](int64_t dim) { + return dim == ShapedType::kDynamic; + }); +} + +static bool getStaticMemrefLayout(MemRefType mrTy, SmallVectorImpl &strides, + int64_t &offset) { + if (failed(getStridesAndOffset(mrTy, strides, offset))) { + strides.clear(); + int64_t stride = 1; + ArrayRef shape = mrTy.getShape(); + for (int i = static_cast(shape.size()) - 1; i >= 0; --i) { + strides.push_back(stride); + stride *= shape[i]; + } + std::reverse(strides.begin(), strides.end()); + offset = 0; + } + return offset != ShapedType::kDynamic && + llvm::none_of(strides, [](int64_t strideValue) { + return strideValue == ShapedType::kDynamic; + }); +} + +static Value applyStaticMemrefOffset(ConversionPatternRewriter &rewriter, + Location loc, Value basePtr, + int64_t offset) { + if (offset == 0) + return basePtr; + auto *ctx = rewriter.getContext(); + Type u32Ty = emitc::OpaqueType::get(ctx, 
"unsigned"); + auto offVal = rewriter.create( + loc, u32Ty, emitc::OpaqueAttr::get(ctx, std::to_string(offset))); + return rewriter.create(loc, basePtr.getType(), basePtr, offVal); +} + +static int getGlobalTensorElementBytes(Type elemTy) { + return static_cast(getPTOStorageElemByteSize(elemTy)); +} + +static int64_t multiplyOrDynamic(int64_t lhs, int64_t rhs) { + if (lhs < 0 || rhs < 0) + return -1; + return lhs * rhs; +} + +static void buildGlobalTensorShapeAndStride(ArrayRef shape, + ArrayRef strides, + SmallVectorImpl &shape5D, + SmallVectorImpl &stride5D) { + shape5D.assign(5, 1); + stride5D.assign(5, 1); + int rank = static_cast(shape.size()); + int shift = 5 - rank; + for (int i = 0; i < rank && i < 5; ++i) { + shape5D[shift + i] = shape[i]; + stride5D[shift + i] = strides[i]; + } + for (int i = 3; i >= 0; --i) { + if (i >= shift) + continue; + stride5D[i] = multiplyOrDynamic(shape5D[i + 1], stride5D[i + 1]); + } +} + +static std::string joinIntTemplateParams(ArrayRef values) { + std::string result; + for (size_t i = 0; i < values.size(); ++i) { + if (i != 0) + result += ", "; + result += std::to_string(values[i]); + } + return result; +} + +static SmallVector buildRowMajorStrides(ArrayRef shape) { + SmallVector strides(shape.size(), 1); + int64_t running = 1; + for (int i = static_cast(shape.size()) - 1; i >= 0; --i) { + strides[i] = running; + running = multiplyOrDynamic(running, shape[i]); + } + return strides; +} + +static std::string getGlobalTensorTypeStringFromShape(Type elemTy, + ArrayRef shape, + StringRef layoutEnum) { + SmallVector strides = buildRowMajorStrides(shape); + return getGlobalTensorTypeStringFromShapeAndStrides(elemTy, shape, strides, + layoutEnum); +} + +static std::string getGlobalTensorTypeStringFromShapeAndStrides( + Type elemTy, ArrayRef shape, ArrayRef strides, + StringRef layoutEnum) { + SmallVector shape5D; + SmallVector stride5D; + buildGlobalTensorShapeAndStride(shape, strides, shape5D, stride5D); + + std::string elemTypeStr 
= getElemTypeStringForGT(elemTy); + std::string shapeType = "pto::Shape<" + joinIntTemplateParams(shape5D) + ">"; + std::string strideType = + "pto::Stride<" + joinIntTemplateParams(stride5D) + ">"; + return "GlobalTensor<" + elemTypeStr + ", " + shapeType + ", " + + strideType + ", " + layoutEnum.str() + ">"; +} + +static emitc::OpaqueType getGlobalTensorOpaqueTypeFromShape( + MLIRContext *ctx, Type elemTy, ArrayRef shape, + StringRef layoutEnum) { + return emitc::OpaqueType::get( + ctx, getGlobalTensorTypeStringFromShape(elemTy, shape, layoutEnum)); +} + +static std::string inferFallbackGlobalTensorLayout(ArrayRef shape5D, + ArrayRef stride5D, + Type elemTy) { + int elemBytes = getGlobalTensorElementBytes(elemTy); + if (elemBytes == 0) + return "pto::Layout::ND"; + if (shape5D[2] == 16 && multiplyOrDynamic(shape5D[2], shape5D[3]) * elemBytes == 512 && + stride5D[4] == 1 && stride5D[3] == shape5D[4]) { + return "pto::Layout::NZ"; + } + + bool isRowMajor = stride5D[4] == 1; + for (int i = 3; i >= 0 && isRowMajor; --i) + isRowMajor = stride5D[i] == multiplyOrDynamic(stride5D[i + 1], shape5D[i + 1]); + + bool isColMajor = stride5D[0] == 1; + for (int i = 0; i < 4 && isColMajor; ++i) + isColMajor = stride5D[i + 1] == multiplyOrDynamic(stride5D[i], shape5D[i]); + + if (isColMajor) + return "pto::Layout::DN"; + return isRowMajor ? 
"pto::Layout::ND" : "pto::Layout::ND"; +} + +static std::string resolveGlobalTensorLayout(Operation *anchor, Value basePtr, + ArrayRef shape5D, + ArrayRef stride5D, + Type elemTy) { + if (auto layout = resolveLayoutForGlobalTensor(anchor, basePtr)) + return layoutToEmitCString(*layout); + return inferFallbackGlobalTensorLayout(shape5D, stride5D, elemTy); +} + +struct GlobalTensorTypeNames { + std::string shapeTypeName; + std::string strideTypeName; + std::string tensorTypeName; + std::string layoutConstName; +}; + +static GlobalTensorTypeNames getGlobalTensorTypeNames(Operation *anchor) { + std::string suffix = "_" + std::to_string(reinterpret_cast(anchor)); + return { + "GTShape" + suffix, + "GTStride" + suffix, + "GT" + suffix, + "GT" + suffix + "_layout", + }; +} +static Value buildGlobalTensorFromMemref(ConversionPatternRewriter &rewriter, + Location loc, Value basePtr, + MemRefType mrTy, + Operation *anchor) { + auto *ctx = rewriter.getContext(); + + ArrayRef shape = mrTy.getShape(); + if (!hasStaticShape(mrTy)) + return Value(); + + SmallVector strides; + int64_t offset = 0; + if (!getStaticMemrefLayout(mrTy, strides, offset)) + return Value(); + + Value ptr = applyStaticMemrefOffset(rewriter, loc, basePtr, offset); + GlobalTensorTypeNames names = getGlobalTensorTypeNames(anchor); + std::string elemTypeStr = getElemTypeStringForGT(mrTy.getElementType()); + SmallVector shape5D; + SmallVector stride5D; + buildGlobalTensorShapeAndStride(shape, strides, shape5D, stride5D); + + rewriter.create( + loc, "using " + names.shapeTypeName + " = pto::Shape<" + + joinIntTemplateParams(shape5D) + ">;"); + rewriter.create( + loc, "using " + names.strideTypeName + " = pto::Stride<" + + joinIntTemplateParams(stride5D) + ">;"); + + std::string layoutEnum = resolveGlobalTensorLayout( + anchor, basePtr, shape5D, stride5D, mrTy.getElementType()); + rewriter.create(loc, "constexpr pto::Layout " + + names.layoutConstName + " = " + + layoutEnum + ";"); + + auto shapeTypeOpaque = 
emitc::OpaqueType::get(ctx, names.shapeTypeName); + auto strideTypeOpaque = emitc::OpaqueType::get(ctx, names.strideTypeName); + auto shapeInstOp = rewriter.create( + loc, shapeTypeOpaque, names.shapeTypeName, ArrayAttr{}, ArrayAttr{}, + ValueRange{}); + auto strideInstOp = rewriter.create( + loc, strideTypeOpaque, names.strideTypeName, ArrayAttr{}, ArrayAttr{}, + ValueRange{}); + + rewriter.create( + loc, "using " + names.tensorTypeName + " = GlobalTensor<" + elemTypeStr + + ", " + names.shapeTypeName + ", " + names.strideTypeName + + ", " + names.layoutConstName + ">;"); + auto gtType = emitc::OpaqueType::get(ctx, names.tensorTypeName); + + SmallVector gtArgs; + gtArgs.push_back(ptr); + gtArgs.push_back(shapeInstOp.getResult(0)); + gtArgs.push_back(strideInstOp.getResult(0)); + + auto gtInst = rewriter.create( + loc, gtType, names.tensorTypeName, ArrayAttr{}, ArrayAttr{}, + ValueRange(gtArgs)); + + return gtInst.getResult(0); +} + +static Value maybeWrapGlobalMemrefAsGlobalTensor( + ConversionPatternRewriter &rewriter, Location loc, Value loweredValue, + Type originalType, Operation *anchor) { + auto mrTy = dyn_cast(originalType); + if (!mrTy) + return loweredValue; + + bool isGlobal = true; + if (auto asAttr = + dyn_cast_or_null(mrTy.getMemorySpace())) { + auto as = asAttr.getAddressSpace(); + isGlobal = (as == pto::AddressSpace::GM || as == pto::AddressSpace::Zero); + } + if (!isGlobal) + return loweredValue; + + if (Value gt = + buildGlobalTensorFromMemref(rewriter, loc, loweredValue, mrTy, anchor)) + return gt; + return loweredValue; +} + +static Value castToGMBytePointer(ConversionPatternRewriter &rewriter, + Location loc, Value value) { + auto *ctx = rewriter.getContext(); + auto targetTy = + emitc::PointerType::get(emitc::OpaqueType::get(ctx, "__gm__ uint8_t")); + if (value.getType() == targetTy) + return value; + + auto castTyAttr = + rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "__gm__ uint8_t*")}); + if (isSetFFTsPointerLikeType(value.getType())) 
{ + return rewriter + .create(loc, targetTy, "reinterpret_cast", + ArrayAttr{}, castTyAttr, + ValueRange{value}) + .getResult(0); + } + return rewriter.create(loc, targetTy, value).getResult(); +} + +static Value materializeTensorViewDataPointer( + ConversionPatternRewriter &rewriter, Location loc, Value value, + Type sourceType) { + auto tvTy = dyn_cast(sourceType); + if (!tvTy) + return value; + + auto *ctx = rewriter.getContext(); + std::string elemTypeStr = getElemTypeStringForGT(tvTy.getElementType()); + auto ptrTy = emitc::PointerType::get( + emitc::OpaqueType::get(ctx, "__gm__ " + elemTypeStr)); + return rewriter + .create(loc, ptrTy, "PTOAS__GLOBAL_TENSOR_DATA", + ArrayAttr{}, ArrayAttr{}, ValueRange{value}) + .getResult(0); +} + +static std::string tileBufBLayoutToken(pto::TileBufConfigAttr configAttr) { + std::string blTok = "BLayout::RowMajor"; + if (auto blAttr = dyn_cast(configAttr.getBLayout())) { + if (static_cast(blAttr.getValue()) == 1) + blTok = "BLayout::ColMajor"; + } + return blTok; +} + +static std::string tileBufSLayoutToken(pto::TileBufConfigAttr configAttr) { + std::string slTok = "SLayout::NoneBox"; + if (auto slAttr = dyn_cast(configAttr.getSLayout())) { + int32_t slVal = static_cast(slAttr.getValue()); + slTok = (slVal == 1) ? "SLayout::RowMajor" + : (slVal == 2) ? 
"SLayout::ColMajor" + : "SLayout::NoneBox"; + } + return slTok; +} + +static std::string tileBufPadToken(pto::TileBufConfigAttr configAttr) { + std::string padTok = "PadValue::Null"; + if (auto padAttr = dyn_cast(configAttr.getPad())) { + switch (static_cast(padAttr.getValue())) { + case 1: + padTok = "PadValue::Zero"; + break; + case 2: + padTok = "PadValue::Max"; + break; + case 3: + padTok = "PadValue::Min"; + break; + default: + padTok = "PadValue::Null"; + break; + } + } + return padTok; +} + +static pto::BLayout getTileBufBLayoutValue(pto::TileBufConfigAttr configAttr) { + if (auto blAttr = dyn_cast(configAttr.getBLayout())) + return blAttr.getValue(); + return pto::BLayout::RowMajor; +} + +static int64_t renderTileTemplateDim(int64_t rawDim, Type elemTy, + pto::BLayout blayout, int dimIdx) { + assert(dimIdx >= 0 && dimIdx < 2 && + "renderTileTemplateDim expects a rank-2 rows/cols dimension index"); + if (rawDim == ShapedType::kDynamic) + return rawDim; + if (!pto::isPTOFloat4PackedType(elemTy)) + return rawDim; + int packedDim = blayout == pto::BLayout::ColMajor ? 0 : 1; + return dimIdx == packedDim ? rawDim * 2 : rawDim; +} + +static FailureOr buildAsyncScratchTileValue( + ConversionPatternRewriter &rewriter, Location loc, Value originalScratch, + Value emittedScratch) { + Value scratch = peelUnrealized(emittedScratch); + if (auto opaqueTy = dyn_cast(scratch.getType())) { + StringRef typeStr = opaqueTy.getValue(); + if (typeStr.contains("Tile<") || typeStr.contains("ConvTile<")) + return scratch; + } + + auto memTy = dyn_cast(originalScratch.getType()); + if (!memTy) + return failure(); + + ArrayRef shape = memTy.getShape(); + if (!memTy.hasStaticShape() || shape.empty() || shape.size() > 2) + return failure(); + + int64_t rows = shape.size() == 1 ? 1 : shape[0]; + int64_t cols = shape.size() == 1 ? 
shape[0] : shape[1]; + + auto *ctx = rewriter.getContext(); + pto::TileBufConfigAttr configAttr = pto::TileBufConfigAttr::getDefault(ctx); + if (auto bind = originalScratch.getDefiningOp()) { + configAttr = bind.getConfig(); + } else if (auto cast = originalScratch.getDefiningOp()) { + if (auto config = cast.getConfig()) + configAttr = *config; + } + + int32_t fractal = 512; + if (auto frAttr = dyn_cast(configAttr.getSFractalSize())) + fractal = frAttr.getInt(); + + Type elemTy = memTy.getElementType(); + pto::BLayout blayout = getTileBufBLayoutValue(configAttr); + int64_t templateRows = renderTileTemplateDim(rows, elemTy, blayout, 0); + int64_t templateCols = renderTileTemplateDim(cols, elemTy, blayout, 1); + std::string elemTypeStr = getEmitCScalarTypeToken(elemTy); + std::string tileTypeStr = + "Tile"; + + Value tile = rewriter + .create( + loc, emitc::OpaqueType::get(ctx, tileTypeStr), + emitc::OpaqueAttr::get(ctx, "")) + .getResult(); + auto addr = rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); + Value scratchAddr = + rewriter + .create(loc, emitc::OpaqueType::get(ctx, "uint64_t"), + "reinterpret_cast", ArrayAttr{}, addr, + ValueRange{scratch}) + .getResult(0); + rewriter.create(loc, TypeRange{}, "TASSIGN", + ArrayAttr{}, ArrayAttr{}, + ValueRange{tile, scratchAddr}); + return tile; +} + +static FailureOr buildSyncAllWorkspaceTileValue( + ConversionPatternRewriter &rewriter, Location loc, Value originalWorkspace, + Value emittedWorkspace) { + Value workspace = peelUnrealized(emittedWorkspace); + if (auto opaqueTy = dyn_cast(workspace.getType())) { + StringRef typeStr = opaqueTy.getValue(); + if (typeStr.contains("Tile<") || typeStr.contains("ConvTile<")) + return workspace; + } + + auto memTy = dyn_cast(originalWorkspace.getType()); + if (!memTy) + return failure(); + if (!memTy.hasStaticShape()) + return failure(); + + ArrayRef rawShape = memTy.getShape(); + if (rawShape.empty() || rawShape.size() > 2) + return failure(); + + int64_t rows = 
rawShape.size() == 1 ? 1 : rawShape[0]; + int64_t cols = rawShape.size() == 1 ? rawShape[0] : rawShape[1]; + SmallVector shape{rows, cols}; + SmallVector validShape{rows, cols}; + + auto *ctx = rewriter.getContext(); + pto::TileBufConfigAttr configAttr = pto::TileBufConfigAttr::getDefault(ctx); + if (auto bind = originalWorkspace.getDefiningOp()) { + configAttr = bind.getConfig(); + } else if (auto cast = originalWorkspace.getDefiningOp()) { + if (auto config = cast.getConfig()) + configAttr = *config; + } + + Attribute memorySpace = memTy.getMemorySpace(); + if (!memorySpace) + return failure(); + + auto tileTy = pto::TileBufType::get(ctx, shape, memTy.getElementType(), + memorySpace, validShape, configAttr); + auto tileTypeString = getEmitCTileTypeString(tileTy); + if (!tileTypeString) + return failure(); + + auto tileEmitTy = emitc::OpaqueType::get(ctx, *tileTypeString); + Value tile = rewriter + .create(loc, tileEmitTy, + emitc::OpaqueAttr::get(ctx, "")) + .getResult(); + + Value rawPtr = workspace; + auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); + if (isSetFFTsPointerLikeType(rawPtr.getType())) { + auto rcU64 = + rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); + rawPtr = rewriter + .create(loc, u64Ty, "reinterpret_cast", + ArrayAttr{}, rcU64, + ValueRange{rawPtr}) + .getResult(0); + } else if (rawPtr.getType() != u64Ty) { + rawPtr = rewriter.create(loc, u64Ty, rawPtr).getResult(); + } + + rewriter.create(loc, TypeRange{}, "TASSIGN", + ArrayAttr{}, ArrayAttr{}, + ValueRange{tile, rawPtr}); + return tile; +} + +//===----------------------------------------------------------------------===// +// pto.pointer_cast lowering +//===----------------------------------------------------------------------=== +struct PointerCastConversion : public OpConversionPattern { + static bool getIndexConst(Value v, int64_t &out) { + if (auto cst = v.getDefiningOp()) { + if (auto ia = dyn_cast(cst.getValue())) { + out = ia.getValue().getSExtValue(); + return 
true; + } + } + return false; + } + + using OpConversionPattern::OpConversionPattern; + + enum class TileRole { Vec, Mat, Left, Right, Acc, Bias, Scaling }; + + static void collectUserOpsThroughCasts(Value v, SmallVectorImpl &out) { + for (Operation *u : v.getUsers()) { + if (auto castOp = dyn_cast(u)) { + for (Value r : castOp.getResults()) + collectUserOpsThroughCasts(r, out); + continue; + } + out.push_back(u); + } + } + + static Value peelUnrealized(Value v) { + while (auto castOp = v.getDefiningOp()) { + v = castOp.getOperand(0); + } + return v; + } + + static TileRole inferRole(pto::PointerCastOp op) { + // 1. 优先检查 AddressSpace + if (auto memRefTy = dyn_cast(op.getType())) { + Attribute memorySpace = memRefTy.getMemorySpace(); + if (auto ptoAttr = dyn_cast_or_null(memorySpace)) { + switch (ptoAttr.getAddressSpace()) { + case pto::AddressSpace::LEFT: return TileRole::Left; + case pto::AddressSpace::RIGHT: return TileRole::Right; + case pto::AddressSpace::ACC: return TileRole::Acc; + case pto::AddressSpace::BIAS: return TileRole::Bias; + case pto::AddressSpace::MAT: return TileRole::Mat; + case pto::AddressSpace::SCALING: return TileRole::Scaling; + default: break; + } + } + } + + // 2. 
通过 Usage 推导 (Fallback) + SmallVector users; + collectUserOpsThroughCasts(op.getResult(), users); + + for (Operation *user : users) { + if (auto mm = dyn_cast(user)) { + if (mm.getDst() && peelUnrealized(mm.getDst()) == op.getResult()) return TileRole::Acc; + if (peelUnrealized(mm.getLhs()) == op.getResult()) return TileRole::Left; + if (peelUnrealized(mm.getRhs()) == op.getResult()) return TileRole::Right; + } + if (auto mmacc = dyn_cast(user)) { + if (mmacc.getDst() && peelUnrealized(mmacc.getDst()) == op.getResult()) return TileRole::Acc; + if (peelUnrealized(mmacc.getAccIn()) == op.getResult()) return TileRole::Acc; + if (peelUnrealized(mmacc.getLhs()) == op.getResult()) return TileRole::Left; + if (peelUnrealized(mmacc.getRhs()) == op.getResult()) return TileRole::Right; + } + } + + return TileRole::Vec; + } + + // [新增] 辅助函数:判断 Value 是否源自 arith.constant + static bool isConstant(Value v, int64_t &outVal) { + if (!v) return false; + if (auto cst = v.getDefiningOp()) { + if (auto attr = dyn_cast(cst.getValue())) { + outVal = attr.getInt(); + return true; + } + } + return false; + } + + LogicalResult matchAndRewrite(pto::PointerCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + auto selfType = mlir::cast(op.getType()); + ArrayRef shape = selfType.getShape(); + Type elemType = selfType.getElementType(); + + // 1. 推导 Tile Role + TileRole role = inferRole(op); + + // 2. 类型字符串生成 (elemTypeStr, dimStr) + std::string elemTypeStr = getEmitCScalarTypeToken(elemType); + + std::string dimStr; + pto::BLayout blayout = pto::BLayout::RowMajor; + auto dimToString = [&](int64_t dim, const char *symbol, + int dimIdx) -> std::string { + if (dim == ShapedType::kDynamic) + return std::string(symbol); + return std::to_string(renderTileTemplateDim(dim, elemType, blayout, + dimIdx)); + }; + + // 3. 
Role Token + const char *roleTok = "TileType::Vec"; + switch (role) { + case TileRole::Left: roleTok = "TileType::Left"; break; + case TileRole::Right: roleTok = "TileType::Right"; break; + case TileRole::Acc: roleTok = "TileType::Acc"; break; + case TileRole::Bias: roleTok = "TileType::Bias"; break; + case TileRole::Mat: roleTok = "TileType::Mat"; break; + case TileRole::Vec: roleTok = "TileType::Vec"; break; + case TileRole::Scaling: roleTok = "TileType::Scaling"; break; + } + + // 4. Config & Layout (support BLayoutAttr/SLayoutAttr/PadValueAttr after namespace change) + std::string layoutParams = "BLayout::RowMajor"; + std::string extraParams = ""; + if (auto configOpt = op.getConfig()) { + auto config = *configOpt; + int32_t blVal = 0; + if (auto attr = dyn_cast(config.getBLayout())) + blVal = static_cast(attr.getValue()); + + if (blVal == 1) layoutParams = "BLayout::ColMajor"; + blayout = blVal == 1 ? pto::BLayout::ColMajor : pto::BLayout::RowMajor; + + int32_t slVal = 0; + if (auto attr = dyn_cast(config.getSLayout())) + slVal = static_cast(attr.getValue()); + + std::string slStr = (slVal == 1) ? "SLayout::RowMajor" : (slVal == 2) ? 
"SLayout::ColMajor" : "SLayout::NoneBox"; + + int32_t frVal = 0; + if (auto attr = dyn_cast(config.getSFractalSize())) frVal = attr.getInt(); + + int32_t padVal = 0; + if (auto attr = dyn_cast(config.getPad())) + padVal = static_cast(attr.getValue()); + + std::string padStr = "PadValue::Null"; + switch (padVal) { + case 1: padStr = "PadValue::Zero"; break; + case 2: padStr = "PadValue::Max"; break; + case 3: padStr = "PadValue::Min"; break; + } + + int32_t compactVal = 0; + if (auto attr = dyn_cast(config.getCompactMode())) + compactVal = static_cast(attr.getValue()); + + std::string compactStr = "CompactMode::Null"; + switch (compactVal) { + case 1: compactStr = "CompactMode::Normal"; break; + case 2: compactStr = "CompactMode::RowPlusOne"; break; + } + + if (!slStr.empty()) { + extraParams += ", " + slStr + ", " + std::to_string(frVal) + ", " + + padStr + ", " + compactStr; + } + } else { + extraParams = ", SLayout::NoneBox, 512, PadValue::Null, CompactMode::Null"; + } + + if (role == TileRole::Left) + dimStr = dimToString(shape[0], "M", 0) + ", " + + dimToString(shape[1], "K", 1); + else if (role == TileRole::Right) + dimStr = dimToString(shape[0], "K", 0) + ", " + + dimToString(shape[1], "N", 1); + else if (role == TileRole::Bias) + dimStr = "1, " + dimToString(shape[1], "N", 1); + else + dimStr = dimToString(shape[0], "M", 0) + ", " + + dimToString(shape[1], "N", 1); + + // [核心修改] Valid Dims 处理逻辑 (支持混合静态/动态) + std::string vrowTok, vcolTok; + bool useConstructor = false; + + bool rowIsDynamic = false; + bool colIsDynamic = false; + + SmallVector constructorArgs; + + Value vRow = op.getValidRow(); + Value vCol = op.getValidCol(); + Value vRowEmitC = adaptor.getValidRow(); + Value vColEmitC = adaptor.getValidCol(); + bool forceDynamicValid = op->hasAttr(kForceDynamicValidShapeAttrName); + + int64_t cRow = 0, cCol = 0; + bool rowIsConst = vRow && isConstant(vRow, cRow); + bool colIsConst = vCol && isConstant(vCol, cCol); + + auto makeCtorDimValue = [&](Value 
emitted, int64_t fallback) -> Value { + if (emitted) + return emitted; + return makeEmitCIntConstant( + rewriter, loc, emitc::OpaqueType::get(ctx, "int32_t"), fallback); + }; + auto maybeScaleDynamicValid = [&](Value emitted, int dimIdx) -> Value { + if (!emitted || !pto::isPTOFloat4PackedType(elemType)) + return emitted; + int packedDim = blayout == pto::BLayout::ColMajor ? 0 : 1; + if (dimIdx != packedDim) + return emitted; + auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); + Value two = makeEmitCIntConstant(rewriter, loc, i32Ty, 2); + return rewriter.create(loc, i32Ty, emitted, two).getResult(); + }; + + if (forceDynamicValid) { + vrowTok = "-1"; + vcolTok = "-1"; + useConstructor = true; + constructorArgs.push_back( + makeCtorDimValue(maybeScaleDynamicValid(vRowEmitC, 0), + renderTileTemplateDim(rowIsConst ? cRow : shape[0], + elemType, blayout, 0))); + constructorArgs.push_back( + makeCtorDimValue(maybeScaleDynamicValid(vColEmitC, 1), + renderTileTemplateDim(colIsConst ? cCol : shape[1], + elemType, blayout, 1))); + } else { + if (rowIsConst) { + vrowTok = std::to_string( + renderTileTemplateDim(cRow, elemType, blayout, 0)); + } else if (vRow) { + vrowTok = "-1"; + rowIsDynamic = true; + useConstructor = true; + } else { + vrowTok = std::to_string( + renderTileTemplateDim(shape[0], elemType, blayout, 0)); + } + + if (colIsConst) { + vcolTok = std::to_string( + renderTileTemplateDim(cCol, elemType, blayout, 1)); + } else if (vCol) { + vcolTok = "-1"; + colIsDynamic = true; + useConstructor = true; + } else { + vcolTok = std::to_string( + renderTileTemplateDim(shape[1], elemType, blayout, 1)); + } + + if (useConstructor) { + if (rowIsDynamic && vRowEmitC) + constructorArgs.push_back(maybeScaleDynamicValid(vRowEmitC, 0)); + if (colIsDynamic && vColEmitC) + constructorArgs.push_back(maybeScaleDynamicValid(vColEmitC, 1)); + } + } + + // 5. 
生成 Tile 类型字符串 + std::string tileTypeStr = + std::string("Tile<") + roleTok + ", " + elemTypeStr + ", " + dimStr + ", " + + layoutParams + ", " + vrowTok + ", " + vcolTok + extraParams + ">"; + + auto tileType = emitc::OpaqueType::get(ctx, tileTypeStr); + Value resultValue; + + if (useConstructor) { + // 使用 CallOpaqueOp 生成构造函数调用 (Tile v = Tile(...)) + auto ctorOp = rewriter.create( + loc, + tileType, // Result Type + tileTypeStr, // Callee Name (类名) + ArrayAttr{}, // args + ArrayAttr{}, // template_args + ValueRange(constructorArgs) // operands + ); + resultValue = ctorOp.getResult(0); + } else { + // 静态情况 (Tile v;) + auto varOp = rewriter.create( + loc, + tileType, + emitc::OpaqueAttr::get(ctx, "") + ); + resultValue = varOp.getResult(); + } + + // TASSIGN: pto-isa expects an integral address. + Value addr = adaptor.getAddrs()[0]; + if (isa(addr.getType()) || + (isa(addr.getType()) && + cast(addr.getType()).getValue().ends_with("*"))) { + auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); + auto rcU64 = rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); + addr = rewriter.create( + loc, u64Ty, "reinterpret_cast", + /*args=*/ArrayAttr{}, /*templateArgs=*/rcU64, + /*operands=*/ValueRange{addr}) + .getResult(0); + } + + rewriter.create( + loc, TypeRange{}, "TASSIGN", + ArrayAttr{}, ArrayAttr{}, + ValueRange{resultValue, addr}); + + rewriter.replaceOp(op, resultValue); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// pto.load_dps / pto.store_dps lowering (FIX: keep optional result) +//===----------------------------------------------------------------------=== + +struct PTOTLoadToTLOAD : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TLoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!op.getDst()) + return rewriter.notifyMatchFailure(op, "expected outs(dst) on pto.tload"); + 
+ Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + Value srcArg = src; + if (auto srcMrTy = dyn_cast(op.getSrc().getType())) { + bool isGlobal = true; + if (auto asAttr = dyn_cast_or_null(srcMrTy.getMemorySpace())) { + auto as = asAttr.getAddressSpace(); + isGlobal = (as == pto::AddressSpace::GM || as == pto::AddressSpace::Zero); + } + if (isGlobal) { + if (Value gt = buildGlobalTensorFromMemref(rewriter, op.getLoc(), src, srcMrTy, + op.getOperation())) + srcArg = gt; + } + } + + rewriter.create( + op.getLoc(), TypeRange{}, "TLOAD", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, srcArg}); + + if (op->getNumResults() == 1) { + rewriter.replaceOp(op, dst); + } else { + rewriter.eraseOp(op); + } + return success(); + } +}; + +struct PTOTPrefetchToTPREFETCH : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TPrefetchOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!op.getDst()) + return rewriter.notifyMatchFailure(op, "expected outs(dst) on pto.tprefetch"); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + Value srcArg = src; + if (auto srcMrTy = dyn_cast(op.getSrc().getType())) { + bool isGlobal = true; + if (auto asAttr = dyn_cast_or_null(srcMrTy.getMemorySpace())) { + auto as = asAttr.getAddressSpace(); + isGlobal = (as == pto::AddressSpace::GM || as == pto::AddressSpace::Zero); + } + if (isGlobal) { + if (Value gt = buildGlobalTensorFromMemref(rewriter, op.getLoc(), src, srcMrTy, + op.getOperation())) + srcArg = gt; + } + } + + rewriter.create( + op.getLoc(), TypeRange{}, "TPREFETCH", + ArrayAttr{}, ArrayAttr{}, ValueRange{dst, srcArg}); + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOTPrefetchAsyncToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TPrefetchAsyncOp op, 
OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src = peelUnrealized(adaptor.getSrc()); + Value srcArg = src; + if (!isEmitCGlobalTensorLikeType(srcArg.getType())) { + auto srcMrTy = dyn_cast(op.getSrc().getType()); + if (!srcMrTy) + return rewriter.notifyMatchFailure( + op, "expected src to lower to GlobalTensor or memref"); + srcArg = buildGlobalTensorFromMemref(rewriter, op.getLoc(), src, srcMrTy, + op.getSrc().getDefiningOp() + ? op.getSrc().getDefiningOp() + : op.getOperation()); + } + if (!srcArg) + return rewriter.notifyMatchFailure(op, + "failed to build GlobalTensor src"); + + Value prefetchCtx = peelUnrealized(adaptor.getCtx()); + + Type eventTy = getTypeConverter()->convertType(op.getEvent().getType()); + if (!eventTy) + return rewriter.notifyMatchFailure( + op, "failed to convert tprefetch_async result type"); + + Value event = rewriter + .create( + op.getLoc(), TypeRange{eventTy}, "TPREFETCH_ASYNC", + ArrayAttr{}, ArrayAttr{}, + ValueRange{srcArg, prefetchCtx}) + .getResult(0); + + rewriter.replaceOp(op, ValueRange{event}); + return success(); + } +}; + +struct PTOMakePrefetchAsyncContextToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::MakePrefetchAsyncContextOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type ctxTy = getTypeConverter()->convertType(op.getCtx().getType()); + if (!ctxTy) + return rewriter.notifyMatchFailure( + op, "failed to convert make_prefetch_async_context result type"); + + Value workspace = peelUnrealized(adaptor.getWorkspace()); + workspace = castToGMBytePointer(rewriter, op.getLoc(), workspace); + + Value ctx = rewriter + .create( + op.getLoc(), TypeRange{ctxTy}, "pto::PrefetchAsyncContext", + ArrayAttr{}, ArrayAttr{}, ValueRange{workspace}) + .getResult(0); + + rewriter.replaceOp(op, ValueRange{ctx}); + return success(); + } +}; + +struct PTOGetPrefetchAsyncSessionToEmitC + : 
public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::GetPrefetchAsyncSessionOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type sessionTy = getTypeConverter()->convertType(op.getSession().getType()); + if (!sessionTy) + return rewriter.notifyMatchFailure( + op, "failed to convert get_prefetch_async_session result type"); + + Value ctx = peelUnrealized(adaptor.getCtx()); + Value session = rewriter + .create( + op.getLoc(), TypeRange{sessionTy}, + "PTOAS__PREFETCH_CTX_SESSION", ArrayAttr{}, + ArrayAttr{}, ValueRange{ctx}) + .getResult(0); + + rewriter.replaceOp(op, ValueRange{session}); + return success(); + } +}; + +struct PTOTStoreToTSTORE : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + static std::string stPhaseTok(pto::STPhase phase) { + switch (phase) { + case pto::STPhase::Unspecified: return "STPhase::Unspecified"; + case pto::STPhase::Partial: return "STPhase::Partial"; + case pto::STPhase::Final: return "STPhase::Final"; + } + return "STPhase::Unspecified"; + } + + static std::string atomicTypeTok(pto::AtomicType atomicType) { + switch (atomicType) { + case pto::AtomicType::AtomicNone: return "AtomicType::AtomicNone"; + case pto::AtomicType::AtomicAdd: return "AtomicType::AtomicAdd"; + } + return "AtomicType::AtomicNone"; + } + + static std::string reluPreModeTok(pto::ReluPreMode reluPreMode) { + switch (reluPreMode) { + case pto::ReluPreMode::NoRelu: return "ReluPreMode::NoRelu"; + case pto::ReluPreMode::NormalRelu: return "ReluPreMode::NormalRelu"; + } + return "ReluPreMode::NoRelu"; + } + + LogicalResult matchAndRewrite(pto::TStoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!op.getDst()) + return rewriter.notifyMatchFailure(op, "expected outs(dst) on pto.tstore"); + + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + Value src = 
peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + Value preQuantScalar; + if (op.getPreQuantScalar()) + preQuantScalar = peelUnrealized(adaptor.getPreQuantScalar()); + Value dstArg = dst; + if (auto dstMrTy = dyn_cast(op.getDst().getType())) { + bool isGlobal = true; + if (auto asAttr = dyn_cast_or_null(dstMrTy.getMemorySpace())) { + auto as = asAttr.getAddressSpace(); + isGlobal = (as == pto::AddressSpace::GM || as == pto::AddressSpace::Zero); + } + if (isGlobal) { + if (Value gt = buildGlobalTensorFromMemref(rewriter, op.getLoc(), dst, dstMrTy, + op.getOperation())) + dstArg = gt; + } + } + + const auto phase = op.getStPhase(); + const auto atomicType = op.getAtomicType(); + const auto reluPreMode = op.getReluPreMode(); + const bool hasPreQuantScalar = static_cast(preQuantScalar); + const bool phaseNonDefault = phase != pto::STPhase::Unspecified; + const bool atomicNonDefault = atomicType != pto::AtomicType::AtomicNone; + const bool reluNonDefault = reluPreMode != pto::ReluPreMode::NoRelu; + + auto getOpaqueTok = [&](Value v, StringRef name) -> FailureOr { + if (auto ot = mlir::dyn_cast(v.getType())) + return ot.getValue().str(); + return rewriter.notifyMatchFailure(op, (name + " must be emitc::OpaqueType").str()); + }; + + ArrayAttr targs; + // Map op attributes/operands to the exact TSTORE overload family: + // 1) TSTORE(dst, src) + // 2) TSTORE(dst, src) + // 3) TSTORE(dst, src) + // 4) TSTORE(dst, src) + // 5) TSTORE(dst, src) + // 6) TSTORE(dst, src) + // 7) TSTORE(dst, src, preQuant) + // 8) TSTORE(dst, src, preQuant) + if (!hasPreQuantScalar && !reluNonDefault && !atomicNonDefault) { + if (phaseNonDefault) { + targs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, stPhaseTok(phase)), + }); + } else { + targs = ArrayAttr{}; + } + } else { + auto srcTokOr = getOpaqueTok(src, "src"); + auto dstTokOr = getOpaqueTok(dstArg, "dst"); + if (failed(srcTokOr) || failed(dstTokOr)) + return failure(); + + // If there is no 
preQuant and relu stays default, emit the atomic-only + // overloads (#3/#4) without ReluPreMode template argument. + if (!hasPreQuantScalar && !reluNonDefault) { + if (phaseNonDefault) { + targs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, stPhaseTok(phase)), + emitc::OpaqueAttr::get(ctx, *srcTokOr), + emitc::OpaqueAttr::get(ctx, *dstTokOr), + emitc::OpaqueAttr::get(ctx, atomicTypeTok(atomicType)), + }); + } else { + targs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, *srcTokOr), + emitc::OpaqueAttr::get(ctx, *dstTokOr), + emitc::OpaqueAttr::get(ctx, atomicTypeTok(atomicType)), + }); + } + } else { + // Relu/preQuant families (#5/#6/#7/#8): keep AtomicType + ReluPreMode. + if (phaseNonDefault) { + targs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, stPhaseTok(phase)), + emitc::OpaqueAttr::get(ctx, *srcTokOr), + emitc::OpaqueAttr::get(ctx, *dstTokOr), + emitc::OpaqueAttr::get(ctx, atomicTypeTok(atomicType)), + emitc::OpaqueAttr::get(ctx, reluPreModeTok(reluPreMode)), + }); + } else { + targs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, *srcTokOr), + emitc::OpaqueAttr::get(ctx, *dstTokOr), + emitc::OpaqueAttr::get(ctx, atomicTypeTok(atomicType)), + emitc::OpaqueAttr::get(ctx, reluPreModeTok(reluPreMode)), + }); + } + } + } + + SmallVector operands{dstArg, src}; + if (hasPreQuantScalar) + operands.push_back(preQuantScalar); + + rewriter.create( + loc, TypeRange{}, "TSTORE", + /*args=*/ArrayAttr{}, /*templateArgs=*/targs, + /*operands=*/operands); + + if (op->getNumResults() == 1) { + rewriter.replaceOp(op, dst); + } else { + rewriter.eraseOp(op); + } + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// pto.matmul_dps lowering (Simplified: No internal copy/sync) +//===----------------------------------------------------------------------===// +// +// Render `pto.tmatmul` as one of three forms depending on the optional +// `acc_phase` attribute: +// * absent / 
Unspecified -> `TMATMUL(dst, lhs, rhs)` +// * Partial -> `TMATMUL(dst, lhs, rhs)` +// * Final -> `TMATMUL(dst, lhs, rhs)` +// The Unspecified default keeps backward compatibility with all upstream IR +// that does not yet emit an explicit phase attribute. +static ArrayAttr buildAccPhaseTemplateArgs(ConversionPatternRewriter &rewriter, + pto::AccPhase phase) { + StringRef tmpl; + switch (phase) { + case pto::AccPhase::Unspecified: + return ArrayAttr{}; + case pto::AccPhase::Partial: + tmpl = "AccPhase::Partial"; + break; + case pto::AccPhase::Final: + tmpl = "AccPhase::Final"; + break; + } + if (tmpl.empty()) + return ArrayAttr{}; + return rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(rewriter.getContext(), tmpl)}); +} + +struct PTOTMatmulToTMATMUL : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMatmulOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // 1. 获取操作数 (剥离 Cast) + Value lhs = peelUnrealized(adaptor.getLhs()); // A (Left) + Value rhs = peelUnrealized(adaptor.getRhs()); // B (Right) + Value dst = peelUnrealized(adaptor.getDst()); // C (Acc) + + // 2. 根据 acc_phase 属性决定是否生成 TMATMUL(...) + ArrayAttr templateArgs = + buildAccPhaseTemplateArgs(rewriter, op.getAccPhase()); + + rewriter.create( + op.getLoc(), TypeRange{}, "TMATMUL", + /*args=*/ArrayAttr{}, /*template_args=*/templateArgs, + ValueRange{dst, lhs, rhs}); + + // 3. 
处理 Op 替换/删除 + if (op->getNumResults() == 1) { + rewriter.replaceOp(op, dst); + } else { + rewriter.eraseOp(op); + } + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// pto.tgemv lowering +//===----------------------------------------------------------------------===// +struct PTOTGemvToTGEMV : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TGemvOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // 1. 获取操作数 (剥离 Cast) + Value lhs = peelUnrealized(adaptor.getLhs()); // A (Matrix) + Value rhs = peelUnrealized(adaptor.getRhs()); // B (Vector) + Value dst = peelUnrealized(adaptor.getDst()); // C (Result) + + // 2. 直接生成函数调用 TGEMV(dst, lhs, rhs) + rewriter.create( + op.getLoc(), TypeRange{}, "TGEMV", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, lhs, rhs}); + + // 3. 处理 Op 替换/删除 + if (op->getNumResults() == 1) { + rewriter.replaceOp(op, dst); + } else { + rewriter.eraseOp(op); + } + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// pto.tgemv.acc lowering +//===----------------------------------------------------------------------===// +struct PTOTGemvAccToTGEMVACC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TGemvAccOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!op.getDst()) + return rewriter.notifyMatchFailure(op, "expected outs(dst) for pto.tgemv.acc"); + + // 1. 获取操作数 + Value accIn = peelUnrealized(adaptor.getAccIn()); // AccOld + Value lhs = peelUnrealized(adaptor.getLhs()); // A (Matrix) + Value rhs = peelUnrealized(adaptor.getRhs()); // B (Vector) + Value dst = peelUnrealized(adaptor.getDst()); // AccNew + + // 2. 
直接生成函数调用 TGEMV_ACC(dst, accIn, lhs, rhs) + rewriter.create( + op.getLoc(), TypeRange{}, "TGEMV_ACC", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, accIn, lhs, rhs}); + + // 3. 处理 Op 替换/删除 + if (op->getNumResults() == 1) { + rewriter.replaceOp(op, dst); + } else { + rewriter.eraseOp(op); + } + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// pto.matmul_acc_dps lowering (Simplified: No internal copy/sync) +//===----------------------------------------------------------------------===// +struct PTOTMatmulAccToTMATMULACC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMatmulAccOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!op.getDst()) + return rewriter.notifyMatchFailure(op, "expected outs(dst) for pto.tmatmul.acc"); + + // 1. 获取操作数 + Value accIn = peelUnrealized(adaptor.getAccIn()); // AccOld + Value lhs = peelUnrealized(adaptor.getLhs()); // A (Left) + Value rhs = peelUnrealized(adaptor.getRhs()); // B (Right) + Value dst = peelUnrealized(adaptor.getDst()); // AccNew + + // 2. 根据 acc_phase 属性决定是否生成 TMATMUL_ACC(...) + ArrayAttr templateArgs = + buildAccPhaseTemplateArgs(rewriter, op.getAccPhase()); + + rewriter.create( + op.getLoc(), TypeRange{}, "TMATMUL_ACC", + /*args=*/ArrayAttr{}, /*template_args=*/templateArgs, + ValueRange{dst, accIn, lhs, rhs}); + + // 3. 
处理 Op 替换/删除 + if (op->getNumResults() == 1) { + rewriter.replaceOp(op, dst); + } else { + rewriter.eraseOp(op); + } + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Return lowering +//===----------------------------------------------------------------------=== + +static constexpr llvm::StringLiteral kAutoSyncTailPendingModeAttr = + "__pto.auto_sync_tail_mode"; + +struct ReturnToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (auto emitcFunc = op->getParentOfType()) { + if (auto modeAttr = + emitcFunc->getAttrOfType(kAutoSyncTailPendingModeAttr)) { + auto *ctx = rewriter.getContext(); + rewriter.setInsertionPoint(op); + auto args = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, modeAttr.getValue())}); + rewriter.create( + op.getLoc(), TypeRange{}, "ptoas_auto_sync_tail", + args, ArrayAttr{}, ValueRange{}); + } + } + + auto vals = adaptor.getOperands(); + if (vals.empty()) { + rewriter.replaceOpWithNewOp(op, Value{}); + return success(); + } + if (vals.size() == 1) { + rewriter.replaceOpWithNewOp(op, vals[0]); + return success(); + } + return rewriter.notifyMatchFailure(op, "EmitC cannot return multiple values"); + } +}; + +struct CallToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(func::CallOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (op.getNumResults() > 1) + return rewriter.notifyMatchFailure( + op, "EmitC cannot lower calls with multiple results"); + + SmallVector resultTypes; + if (failed( + getTypeConverter()->convertTypes(op.getResultTypes(), resultTypes))) + return rewriter.notifyMatchFailure(op, + "failed to convert call result types"); + + rewriter.replaceOpWithNewOp(op, op.getCalleeAttr(), + 
resultTypes, + adaptor.getOperands()); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Sync lowering +//===----------------------------------------------------------------------=== + +static constexpr llvm::StringLiteral kAutoSyncTailBarrierAttr = + "pto.auto_sync_tail_barrier"; +static constexpr llvm::StringLiteral kAutoSyncTailHintAttr = + "pto.auto_sync_tail_hint"; +static constexpr llvm::StringLiteral kAutoSyncTailPolicyBarrierAll = + "barrier_all"; +static constexpr llvm::StringLiteral kAutoSyncTailPolicyMte3ToSEvent0 = + "setwait_mte3_to_s_event0"; +static constexpr llvm::StringLiteral kAutoSyncTailModeBarrierAllToken = + "PTOAutoSyncTailMode::kBarrierAll"; +static constexpr llvm::StringLiteral kAutoSyncTailModeMte3ToSEvent0Token = + "PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0"; + +static std::string getAutoSyncTailModeToken(Operation *op) { + if (op) { + if (auto hintAttr = op->getAttrOfType(kAutoSyncTailHintAttr)) { + if (hintAttr.getValue() == kAutoSyncTailPolicyBarrierAll) + return kAutoSyncTailModeBarrierAllToken.str(); + if (hintAttr.getValue() == kAutoSyncTailPolicyMte3ToSEvent0) + return kAutoSyncTailModeMte3ToSEvent0Token.str(); + } + } + + auto func = op ? op->getParentOfType() : func::FuncOp(); + if (!func) + return kAutoSyncTailModeBarrierAllToken.str(); + + auto hintAttr = func->getAttrOfType(kAutoSyncTailHintAttr); + if (!hintAttr) + return kAutoSyncTailModeBarrierAllToken.str(); + + if (hintAttr.getValue() == kAutoSyncTailPolicyBarrierAll) + return kAutoSyncTailModeBarrierAllToken.str(); + if (hintAttr.getValue() == kAutoSyncTailPolicyMte3ToSEvent0) + return kAutoSyncTailModeMte3ToSEvent0Token.str(); + + // Fallback to the conservative behavior when seeing unknown policies. 
+ return kAutoSyncTailModeBarrierAllToken.str(); +} + +[[maybe_unused]] static std::string getPipeName(pto::PIPE pipe) { + switch (pipe) { + case pto::PIPE::PIPE_S: return "PIPE_S"; + case pto::PIPE::PIPE_V: return "PIPE_V"; + case pto::PIPE::PIPE_M: return "PIPE_M"; + case pto::PIPE::PIPE_MTE1: return "PIPE_MTE1"; + case pto::PIPE::PIPE_MTE2: return "PIPE_MTE2"; + case pto::PIPE::PIPE_MTE3: return "PIPE_MTE3"; + case pto::PIPE::PIPE_ALL: return "PIPE_ALL"; + case pto::PIPE::PIPE_MTE4: return "PIPE_MTE4"; + case pto::PIPE::PIPE_MTE5: return "PIPE_MTE5"; + case pto::PIPE::PIPE_V2: return "PIPE_V2"; + case pto::PIPE::PIPE_FIX: return "PIPE_FIX"; + case pto::PIPE::VIRTUAL_PIPE_MTE2_L1A: return "VIRTUAL_PIPE_MTE2_L1A"; + case pto::PIPE::VIRTUAL_PIPE_MTE2_L1B: return "VIRTUAL_PIPE_MTE2_L1B"; + // 默认回退 + default: return "PIPE_ALL"; + } +} + +//===----------------------------------------------------------------------===// +// pto.barrier lowering -> pipe_barrier(...) +//===----------------------------------------------------------------------===// +struct PTOBarrierToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::BarrierOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (op->hasAttr(kAutoSyncTailBarrierAttr)) { + auto modeAttr = rewriter.getStringAttr(getAutoSyncTailModeToken(op)); + if (auto emitcFunc = op->getParentOfType()) { + emitcFunc->setAttr(kAutoSyncTailPendingModeAttr, modeAttr); + } else if (auto funcOp = op->getParentOfType()) { + funcOp->setAttr(kAutoSyncTailPendingModeAttr, modeAttr); + } + rewriter.eraseOp(op); + return success(); + } + + // [FIX] op.getPipe() returns PipeAttr. + // We must call .getPipe() on the attribute to get the actual Enum value. 
+ pto::PIPE pipeEnum = op.getPipe().getPipe(); + + // Convert Enum to String (e.g., PIPE_ALL -> "PIPE_ALL") + std::string pipeStr = pto::stringifyPIPE(pipeEnum).str(); + auto *ctx = rewriter.getContext(); + + auto args = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, pipeStr) + }); + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, // void return + "pipe_barrier", // function name + args, // arguments + ArrayAttr{}, // template args + ValueRange{} // operands + ); + + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Sync lowering (robust for bracket form pto.set_flag[...] / pto.wait_flag[...]) +// Replace your PTOSyncToRuntimeCall with the code below. +//===----------------------------------------------------------------------===// + +static bool tryConvertPipeAttrToToken(Attribute attr, std::string &token) { + if (!attr) + return false; + if (auto pipe = dyn_cast(attr)) { + token = mlir::pto::stringifyPIPE(pipe.getPipe()).str(); + return true; + } + if (auto stringAttr = dyn_cast(attr)) { + token = stringAttr.getValue().str(); + return true; + } + return false; +} + +static bool tryConvertEventAttrToToken(Attribute attr, std::string &token) { + if (!attr) + return false; + if (auto event = dyn_cast(attr)) { + token = mlir::pto::stringifyEVENT(event.getEvent()).str(); + return true; + } + if (auto stringAttr = dyn_cast(attr)) { + token = stringAttr.getValue().str(); + return true; + } + return false; +} + +static bool tryAssignSyncTokens(Attribute srcAttr, Attribute dstAttr, + Attribute evtAttr, std::string &srcTok, + std::string &dstTok, std::string &evtTok) { + std::string localSrc; + std::string localDst; + std::string localEvt; + if (!tryConvertPipeAttrToToken(srcAttr, localSrc) || + !tryConvertPipeAttrToToken(dstAttr, localDst) || + !tryConvertEventAttrToToken(evtAttr, localEvt)) { + return false; + } + srcTok = std::move(localSrc); + dstTok = std::move(localDst); + evtTok = 
std::move(localEvt); + return true; +} + +static bool tryExtractSyncTokensFromNamedAttrs(Operation *op, + StringRef srcName, + StringRef dstName, + StringRef evtName, + std::string &srcTok, + std::string &dstTok, + std::string &evtTok) { + return tryAssignSyncTokens(op->getAttr(srcName), op->getAttr(dstName), + op->getAttr(evtName), srcTok, dstTok, evtTok); +} + +static bool tryExtractSyncTokensFromArrayAttr(Operation *op, StringRef attrName, + std::string &srcTok, + std::string &dstTok, + std::string &evtTok) { + auto arrayAttr = op->getAttrOfType(attrName); + if (!arrayAttr || arrayAttr.size() < 3) + return false; + return tryAssignSyncTokens(arrayAttr[0], arrayAttr[1], arrayAttr[2], srcTok, + dstTok, evtTok); +} + +static bool tryExtractFallbackSyncTokens(Operation *op, std::string &srcTok, + std::string &dstTok, + std::string &evtTok) { + SmallVector pipes; + std::string event; + for (NamedAttribute namedAttr : op->getAttrs()) { + std::string token; + if (tryConvertPipeAttrToToken(namedAttr.getValue(), token)) { + pipes.push_back(std::move(token)); + continue; + } + if (event.empty() && + tryConvertEventAttrToToken(namedAttr.getValue(), token)) { + event = std::move(token); + } + } + if (pipes.size() < 2 || event.empty()) + return false; + srcTok = pipes[0]; + dstTok = pipes[1]; + evtTok = event; + return true; +} + +static LogicalResult extractSyncTripletTokens(Operation *op, + std::string &srcTok, + std::string &dstTok, + std::string &evtTok, + ConversionPatternRewriter &rewriter) { + if (tryExtractSyncTokensFromNamedAttrs(op, "src_pipe", "dst_pipe", "event_id", + srcTok, dstTok, evtTok) || + tryExtractSyncTokensFromNamedAttrs(op, "srcPipe", "dstPipe", "eventId", + srcTok, dstTok, evtTok) || + tryExtractSyncTokensFromNamedAttrs(op, "src", "dst", "event", srcTok, + dstTok, evtTok)) { + return success(); + } + + for (StringRef attrName : {"args", "pipes", "sync", "triplet", "attrs"}) { + if (tryExtractSyncTokensFromArrayAttr(op, attrName, srcTok, dstTok, + 
evtTok)) { + return success(); + } + } + + if (tryExtractFallbackSyncTokens(op, srcTok, dstTok, evtTok)) + return success(); + return rewriter.notifyMatchFailure( + op, "cannot extract PIPE/PIPE/EVENT tokens from pto.{set,wait}_flag"); +} +static inline std::string pipeTokFromPipeEnum(mlir::pto::PIPE p) { + return mlir::pto::stringifyPIPE(p).str(); +} +[[maybe_unused]] static inline std::string evtTokFromEventEnum(mlir::pto::EVENT e) { + return mlir::pto::stringifyEVENT(e).str(); +} +static inline std::string pipeTokFromPipeAttr(mlir::pto::PipeAttr a) { + return mlir::pto::stringifyPIPE(a.getPipe()).str(); +} +static inline std::string evtTokFromEventAttr(mlir::pto::EventAttr a) { + return mlir::pto::stringifyEVENT(a.getEvent()).str(); +} + +template +struct HasGetSrcPipe : std::false_type {}; +template +struct HasGetSrcPipe().getSrcPipe())>> : std::true_type {}; + +template +struct HasGetDstPipe : std::false_type {}; +template +struct HasGetDstPipe().getDstPipe())>> : std::true_type {}; + +template +struct HasGetEventId : std::false_type {}; +template +struct HasGetEventId().getEventId())>> : std::true_type {}; + +template +struct HasGetSrcPipeAttr : std::false_type {}; +template +struct HasGetSrcPipeAttr().getSrcPipeAttr())>> : std::true_type {}; + +template +struct HasGetDstPipeAttr : std::false_type {}; +template +struct HasGetDstPipeAttr().getDstPipeAttr())>> : std::true_type {}; + +template +struct HasGetEventIdAttr : std::false_type {}; +template +struct HasGetEventIdAttr().getEventIdAttr())>> : std::true_type {}; + +template +static LogicalResult extractSyncTokens(SyncOpT op, + std::string &srcTok, + std::string &dstTok, + std::string &evtTok, + ConversionPatternRewriter &rewriter) { + if constexpr (HasGetSrcPipe::value && + HasGetDstPipe::value && + HasGetEventId::value) { + auto s = op.getSrcPipe(); + auto d = op.getDstPipe(); + auto e = op.getEventId(); + + if constexpr (std::is_same::value) srcTok = pipeTokFromPipeEnum(s); + else srcTok = 
pipeTokFromPipeAttr(s); + + if constexpr (std::is_same::value) dstTok = pipeTokFromPipeEnum(d); + else dstTok = pipeTokFromPipeAttr(d); + + if constexpr (std::is_same::value) evtTok = evtTokFromEventEnum(e); + else evtTok = evtTokFromEventAttr(e); + + return success(); + } + + if constexpr (HasGetSrcPipeAttr::value && + HasGetDstPipeAttr::value && + HasGetEventIdAttr::value) { + auto s = op.getSrcPipeAttr(); + auto d = op.getDstPipeAttr(); + auto e = op.getEventIdAttr(); + srcTok = pipeTokFromPipeAttr(s); + dstTok = pipeTokFromPipeAttr(d); + evtTok = evtTokFromEventAttr(e); + return success(); + } + + return extractSyncTripletTokens(op.getOperation(), srcTok, dstTok, evtTok, rewriter); +} +struct PTOSetFlagToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::SetFlagOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + auto *ctx = rewriter.getContext(); + + std::string srcTok, dstTok, evtTok; + if (failed(extractSyncTokens(op, srcTok, dstTok, evtTok, rewriter))) + return failure(); + + auto argsAttr = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, srcTok), + emitc::OpaqueAttr::get(ctx, dstTok), + emitc::OpaqueAttr::get(ctx, evtTok), + }); + + rewriter.replaceOpWithNewOp( + op, TypeRange{}, "set_flag", + /*args=*/argsAttr, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{}); + return success(); + } +}; + +struct PTOWaitFlagToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::WaitFlagOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + auto *ctx = rewriter.getContext(); + + std::string srcTok, dstTok, evtTok; + if (failed(extractSyncTokens(op, srcTok, dstTok, evtTok, rewriter))) + return failure(); + + auto argsAttr = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, srcTok), + 
emitc::OpaqueAttr::get(ctx, dstTok), + emitc::OpaqueAttr::get(ctx, evtTok), + }); + + rewriter.replaceOpWithNewOp( + op, TypeRange{}, "wait_flag", + /*args=*/argsAttr, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{}); + return success(); + } +}; + +struct PTOSyncToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::TSyncOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector operands; + operands.reserve(adaptor.getEvents().size()); + for (Value event : adaptor.getEvents()) + operands.push_back(peelUnrealized(event)); + + rewriter.create( + op.getLoc(), TypeRange{}, "TSYNC", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange(operands)); + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOSyncAllToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + static StringRef coreTypeTok(pto::SyncCoreType coreType) { + switch (coreType) { + case pto::SyncCoreType::AIVOnly: + return "SyncCoreType::AIVOnly"; + case pto::SyncCoreType::AICOnly: + return "SyncCoreType::AICOnly"; + case pto::SyncCoreType::Mix: + return "SyncCoreType::Mix"; + } + llvm_unreachable("unhandled SyncCoreType"); + } + + LogicalResult matchAndRewrite(mlir::pto::SyncAllOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto mode = op.getMode().getValue(); + auto coreType = op.getCoreType().getValue(); + + auto buildGmWorkspace = [&]() -> FailureOr { + Value gm = peelUnrealized(adaptor.getGmWorkspace()); + if (isEmitCGlobalTensorLikeType(gm.getType())) + return gm; + + auto memTy = dyn_cast(op.getGmWorkspace().getType()); + if (!memTy) + return failure(); + + Value gt = buildGlobalTensorFromMemref(rewriter, op.getLoc(), gm, memTy, + op.getGmWorkspace().getDefiningOp() + ? 
op.getGmWorkspace().getDefiningOp() + : op.getOperation()); + if (!gt) + return failure(); + return gt; + }; + + if (mode == pto::SyncAllMode::Hard) { + std::string callee = "SYNCALL<" + coreTypeTok(coreType).str() + ">"; + rewriter.create(op.getLoc(), TypeRange{}, callee, + ArrayAttr{}, ArrayAttr{}, + ValueRange{}); + rewriter.eraseOp(op); + return success(); + } + + FailureOr gmWorkspace = buildGmWorkspace(); + if (failed(gmWorkspace)) + return rewriter.notifyMatchFailure(op, + "failed to build gm_workspace GlobalTensor"); + + auto i32Ty = emitc::OpaqueType::get(rewriter.getContext(), "int32_t"); + Value usedCores = adaptor.getUsedCores() + ? peelUnrealized(adaptor.getUsedCores()) + : makeEmitCIntConstant(rewriter, op.getLoc(), i32Ty, 0); + if (usedCores.getType() != i32Ty) + usedCores = rewriter.create(op.getLoc(), i32Ty, usedCores) + .getResult(); + + std::string callee = + "SYNCALL"; + + SmallVector operands{*gmWorkspace}; + switch (coreType) { + case pto::SyncCoreType::AIVOnly: { + FailureOr ubWorkspace = + buildSyncAllWorkspaceTileValue(rewriter, op.getLoc(), + op.getUbWorkspace(), + adaptor.getUbWorkspace()); + if (failed(ubWorkspace)) + return rewriter.notifyMatchFailure( + op, "failed to materialize ub_workspace tile"); + operands.push_back(*ubWorkspace); + break; + } + case pto::SyncCoreType::AICOnly: { + FailureOr l1Workspace = + buildSyncAllWorkspaceTileValue(rewriter, op.getLoc(), + op.getL1Workspace(), + adaptor.getL1Workspace()); + if (failed(l1Workspace)) + return rewriter.notifyMatchFailure( + op, "failed to materialize l1_workspace tile"); + operands.push_back(*l1Workspace); + break; + } + case pto::SyncCoreType::Mix: { + FailureOr ubWorkspace = + buildSyncAllWorkspaceTileValue(rewriter, op.getLoc(), + op.getUbWorkspace(), + adaptor.getUbWorkspace()); + FailureOr l1Workspace = + buildSyncAllWorkspaceTileValue(rewriter, op.getLoc(), + op.getL1Workspace(), + adaptor.getL1Workspace()); + if (failed(ubWorkspace) || failed(l1Workspace)) + return 
rewriter.notifyMatchFailure( + op, "failed to materialize mixed syncall workspace tiles"); + operands.push_back(*ubWorkspace); + operands.push_back(*l1Workspace); + break; + } + } + + operands.push_back(usedCores); + rewriter.create(op.getLoc(), TypeRange{}, callee, + ArrayAttr{}, ArrayAttr{}, + ValueRange(operands)); + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOSyncFlagDynToEmitC : public ConversionPattern { + PTOSyncFlagDynToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + StringRef opName, StringRef callee) + : ConversionPattern(typeConverter, opName, /*benefit=*/1, ctx), + callee(callee.str()) {} + + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + if (operands.size() != 1) + return rewriter.notifyMatchFailure(op, "expected exactly one dynamic event-id operand"); + + auto srcAttr = op->getAttrOfType("src_pipe"); + auto dstAttr = op->getAttrOfType("dst_pipe"); + if (!srcAttr || !dstAttr) + return rewriter.notifyMatchFailure(op, "missing PipeAttr src_pipe/dst_pipe attrs"); + + auto *ctx = rewriter.getContext(); + std::string srcTok = pipeTokFromPipeAttr(srcAttr); + std::string dstTok = pipeTokFromPipeAttr(dstAttr); + + Value eventVal = operands.front(); + eventVal = + emitCCast(rewriter, op->getLoc(), emitc::OpaqueType::get(ctx, "event_t"), eventVal); + + auto argsAttr = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, srcTok), + emitc::OpaqueAttr::get(ctx, dstTok), + IntegerAttr::get(IndexType::get(ctx), 0), + }); + + rewriter.replaceOpWithNewOp( + op, TypeRange{}, callee, + /*args=*/argsAttr, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{eventVal}); + return success(); + } + +private: + std::string callee; +}; + +struct PTOGetBufToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::GetBufOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { 
+ (void)adaptor; + auto *ctx = rewriter.getContext(); + + auto opTypeOr = parseSyncOpTypeLikeAttr(op.getOpTypeAttr()); + if (failed(opTypeOr)) + return rewriter.notifyMatchFailure(op, "get_buf expects pipe_event_type/sync_op_type attr"); + auto pipe = mapSyncOpTypeToPipe(*opTypeOr); + if (!isConcreteSyncPipe(pipe)) + return rewriter.notifyMatchFailure(op, "get_buf op_type cannot map to a concrete pipe"); + std::string pipeTok = pipeTokFromPipeEnum(pipe); + auto argsAttr = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, pipeTok), + op.getBufIdAttr(), + op.getModeAttr(), + }); + + rewriter.replaceOpWithNewOp( + op, TypeRange{}, "get_buf", + /*args=*/argsAttr, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{}); + return success(); + } +}; + +struct PTORlsBufToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::RlsBufOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + auto *ctx = rewriter.getContext(); + + auto opTypeOr = parseSyncOpTypeLikeAttr(op.getOpTypeAttr()); + if (failed(opTypeOr)) + return rewriter.notifyMatchFailure(op, "rls_buf expects pipe_event_type/sync_op_type attr"); + auto pipe = mapSyncOpTypeToPipe(*opTypeOr); + if (!isConcreteSyncPipe(pipe)) + return rewriter.notifyMatchFailure(op, "rls_buf op_type cannot map to a concrete pipe"); + std::string pipeTok = pipeTokFromPipeEnum(pipe); + auto argsAttr = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, pipeTok), + op.getBufIdAttr(), + op.getModeAttr(), + }); + + rewriter.replaceOpWithNewOp( + op, TypeRange{}, "rls_buf", + /*args=*/argsAttr, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{}); + return success(); + } +}; + +struct PTOSetFFTsToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::SetFFTsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const 
override { + auto *ctx = rewriter.getContext(); + auto loc = op.getLoc(); + + Value fftsAddr = peelUnrealized(adaptor.getFfts()); + auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); + + if (isSetFFTsPointerLikeType(fftsAddr.getType())) { + auto castTyAttr = + rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); + fftsAddr = + rewriter + .create(loc, u64Ty, "reinterpret_cast", + /*args=*/ArrayAttr{}, + /*templateArgs=*/castTyAttr, + /*operands=*/ValueRange{fftsAddr}) + .getResult(0); + } else if (fftsAddr.getType() != u64Ty) { + fftsAddr = + rewriter.create(loc, u64Ty, fftsAddr).getResult(); + } + + rewriter.replaceOpWithNewOp( + op, TypeRange{}, "set_ffts_base_addr", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{fftsAddr}); + return success(); + } +}; + +struct PTOSyncSetToEmitC : public OpConversionPattern { + PTOSyncSetToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + PTOArch targetArch) + : OpConversionPattern(typeConverter, ctx), + targetArch(targetArch) {} + + LogicalResult + matchAndRewrite(mlir::pto::SyncSetOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto *ctx = rewriter.getContext(); + IntegerAttr eventIdAttr = op.getEventIdAttr(); + Value eventIdDyn = adaptor.getEventIdDyn(); + int64_t fftsMode = 2; + if (IntegerAttr fftsModeAttr = op.getFftsModeAttr()) + fftsMode = fftsModeAttr.getInt(); + + if ((eventIdAttr != nullptr) == static_cast(eventIdDyn)) + return rewriter.notifyMatchFailure( + op, "expects exactly one of static event_id attr or dynamic event_id operand"); + + // A5 inter-core sync mirrors +16 only for cube-side producer (PIPE_FIX). + // Vec-side producer (PIPE_MTE3) emits a single set; hardware handles the + // subblock mapping in PTO-ISA custom flow. 
+ if (targetArch == PTOArch::A5) { + pto::PIPE pipe = op.getPipe().getPipe(); + bool needsMirrorPlus16 = (pipe == pto::PIPE::PIPE_FIX); + std::string pipeTok = pipeTokFromPipeAttr(op.getPipe()); + auto emitSet = [&](Value eventOperand, IntegerAttr eventLiteral, + bool isDynamic) { + if (isDynamic) { + auto args = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, pipeTok), + IntegerAttr::get(IndexType::get(ctx), 0), + }); + rewriter.create(loc, TypeRange{}, "set_intra_block", + /*args=*/args, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{eventOperand}); + return; + } + auto args = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, pipeTok), + eventLiteral, + }); + rewriter.create(loc, TypeRange{}, "set_intra_block", + /*args=*/args, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{}); + }; + + if (eventIdAttr) { + emitSet(Value{}, eventIdAttr, /*isDynamic=*/false); + if (needsMirrorPlus16) { + auto plus16 = IntegerAttr::get(eventIdAttr.getType(), + eventIdAttr.getInt() + 16); + emitSet(Value{}, plus16, /*isDynamic=*/false); + } + } else { + Value eventI32 = castInterCoreEventIdToI32(rewriter, loc, eventIdDyn); + emitSet(eventI32, IntegerAttr{}, /*isDynamic=*/true); + if (needsMirrorPlus16) { + auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); + Value c16 = makeEmitCIntConstant(rewriter, loc, i32Ty, 16); + Value eventI32Plus16 = + rewriter.create(loc, i32Ty, eventI32, c16).getResult(); + emitSet(eventI32Plus16, IntegerAttr{}, /*isDynamic=*/true); + } + } + + rewriter.eraseOp(op); + return success(); + } + + InterCoreSyncCallDesc desc; + if (eventIdAttr) { + desc = buildInterCoreSyncSetCall(rewriter, loc, targetArch, op.getPipe(), + eventIdAttr, fftsMode); + } else { + desc = buildInterCoreSyncSetCallDyn(rewriter, loc, targetArch, op.getPipe(), + eventIdDyn, fftsMode); + } + rewriter.create(loc, TypeRange{}, desc.callee, + /*args=*/desc.args, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/desc.operands); + + rewriter.eraseOp(op); + 
return success(); + } + + PTOArch targetArch; +}; + +struct PTOSyncWaitToEmitC : public OpConversionPattern { + PTOSyncWaitToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + PTOArch targetArch) + : OpConversionPattern(typeConverter, ctx), + targetArch(targetArch) {} + + LogicalResult + matchAndRewrite(mlir::pto::SyncWaitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + IntegerAttr eventIdAttr = op.getEventIdAttr(); + Value eventIdDyn = adaptor.getEventIdDyn(); + + if ((eventIdAttr != nullptr) == static_cast(eventIdDyn)) + return rewriter.notifyMatchFailure( + op, "expects exactly one of static event_id attr or dynamic event_id operand"); + + InterCoreSyncCallDesc desc; + if (eventIdAttr) { + desc = buildInterCoreSyncWaitCall(rewriter, targetArch, op.getPipe(), + eventIdAttr); + } else { + desc = buildInterCoreSyncWaitCallDyn(rewriter, loc, targetArch, op.getPipe(), + eventIdDyn); + } + rewriter.create(loc, TypeRange{}, desc.callee, + desc.args, ArrayAttr{}, desc.operands); + + rewriter.eraseOp(op); + return success(); + } + + PTOArch targetArch; +}; + +// GetBlockIdxOp Lowering (pto.get_block_idx -> get_block_idx()) +struct PTOGetBlockIdxToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(mlir::pto::GetBlockIdxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + rewriter.replaceOpWithNewOp( + op, op.getType(), "get_block_idx", ValueRange{}, ArrayAttr{}, + ArrayAttr{}); + + return success(); + } +}; + +// GetBlockNumOp Lowering (pto.get_block_num -> get_block_num()) +struct PTOGetBlockNumToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(mlir::pto::GetBlockNumOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + rewriter.replaceOpWithNewOp( + op, op.getType(), "get_block_num", 
ValueRange{}, ArrayAttr{}, + ArrayAttr{}); + + return success(); + } +}; + +// GetSubBlockIdxOp Lowering (pto.get_block_idx -> get_subblockid()) +struct PTOGetSubBlockIdxToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(mlir::pto::GetSubBlockIdxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + rewriter.replaceOpWithNewOp( + op, op.getType(), "get_subblockid", ValueRange{}, ArrayAttr{}, + ArrayAttr{}); + + return success(); + } +}; + +// GetSubBlockNumOp Lowering. +struct PTOGetSubBlockNumToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(mlir::pto::GetSubBlockNumOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + rewriter.replaceOpWithNewOp( + op, op.getType(), "get_subblockdim", ValueRange{}, ArrayAttr{}, + ArrayAttr{}); + + return success(); + } +}; + + +struct PTOMScatterToMSCATTER : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::MScatterOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto *ctx = rewriter.getContext(); + Value src = peelUnrealized(adaptor.getSrc()); + Value idx = peelUnrealized(adaptor.getIdx()); + Value mem = peelUnrealized(adaptor.getMem()); + + Value memArg = maybeWrapGlobalMemrefAsGlobalTensor( + rewriter, op.getLoc(), mem, op.getMem().getType(), op.getOperation()); + + auto scatterAtomicTok = [&](pto::ScatterAtomicOp atomic) -> StringRef { + switch (atomic) { + case pto::ScatterAtomicOp::None: + return "pto::ScatterAtomicOp::None"; + case pto::ScatterAtomicOp::Add: + return "pto::ScatterAtomicOp::Add"; + case pto::ScatterAtomicOp::Max: + return "pto::ScatterAtomicOp::Max"; + case pto::ScatterAtomicOp::Min: + return "pto::ScatterAtomicOp::Min"; + } + llvm_unreachable("unknown ScatterAtomicOp"); + }; + auto 
scatterOobTok = [&](pto::ScatterOOB mode) -> StringRef { + switch (mode) { + case pto::ScatterOOB::Undefined: + return "pto::ScatterOOB::Undefined"; + case pto::ScatterOOB::Skip: + return "pto::ScatterOOB::Skip"; + case pto::ScatterOOB::Clamp: + return "pto::ScatterOOB::Clamp"; + case pto::ScatterOOB::Wrap: + return "pto::ScatterOOB::Wrap"; + } + llvm_unreachable("unknown ScatterOOB"); + }; + + SmallVector templateArgVec; + const bool rowCoalesce = + isRowCoalescedMGatherIndexType(op.getSrc().getType(), op.getIdx().getType()); + templateArgVec.push_back(emitc::OpaqueAttr::get( + ctx, rowCoalesce ? "pto::Coalesce::Row" : "pto::Coalesce::Elem")); + if (op.getScatterAtomicOp() != pto::ScatterAtomicOp::None || + op.getScatterOob() != pto::ScatterOOB::Undefined) { + templateArgVec.push_back(emitc::OpaqueAttr::get( + ctx, scatterAtomicTok(op.getScatterAtomicOp()))); + if (op.getScatterOob() != pto::ScatterOOB::Undefined) + templateArgVec.push_back( + emitc::OpaqueAttr::get(ctx, scatterOobTok(op.getScatterOob()))); + } + ArrayAttr templateArgs = rewriter.getArrayAttr(templateArgVec); + + rewriter.create( + op.getLoc(), TypeRange{}, "MSCATTER", + ArrayAttr{}, templateArgs, + ValueRange{memArg, src, idx}); + + rewriter.eraseOp(op); + return success(); + } +}; +struct PTOSetValToSETVAL : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TSetValOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value dst = peelUnrealized(adaptor.getDst()); + Value val = peelUnrealized(adaptor.getVal()); + + // ---- offset: SSA index operand ---- + Value offset = peelUnrealized(adaptor.getOffset()); + + // Emit a marker call and let the ptoas post-processing step lower it to + // the corresponding tile setter. 
+ rewriter.create( + op.getLoc(), TypeRange{}, "PTOAS__TILE_SET_VALUE", + ArrayAttr{}, ArrayAttr{}, ValueRange{dst, offset, val}); + + rewriter.eraseOp(op); + return success(); + } +}; +struct PTOGetValToGETVAL : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TGetValOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src = peelUnrealized(adaptor.getSrc()); + + // ---- offset: SSA index operand ---- + Value offset = peelUnrealized(adaptor.getOffset()); + + // Emit a marker call and let the ptoas post-processing step lower it to + // the corresponding tile getter. + Type dstTy = getTypeConverter()->convertType(op.getDst().getType()); + if (!dstTy) + return failure(); + auto call = rewriter.create( + op.getLoc(), + TypeRange{dstTy}, + "PTOAS__TILE_GET_VALUE", + ArrayAttr{}, ArrayAttr{}, + ValueRange{src, offset}); + + rewriter.replaceOp(op, call.getResults()); + return success(); + } +}; + +struct PTOTAxpyToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TAxpyOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + Value scalar = peelUnrealized(adaptor.getScalar()); + + rewriter.create( + loc, TypeRange{}, "TAXPY", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src, scalar}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOHistogramToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::THistogramOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value idx = 
peelUnrealized(adaptor.getIdx()); + Value dst = peelUnrealized(adaptor.getDst()); + + auto templateArgs = rewriter.getArrayAttr({emitc::OpaqueAttr::get( + ctx, op.getIsMSB() ? "HistByte::BYTE_1" : "HistByte::BYTE_0")}); + rewriter.create( + loc, TypeRange{}, "THISTOGRAM", + /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, + /*operands=*/ValueRange{dst, src, idx}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOGetScaleAddrToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TGetScaleAddrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TGET_SCALE_ADDR", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOSetValidShapeToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::SetValidShapeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto peelAllCasts = [](Value v) { + while (auto castOp = v.getDefiningOp()) + v = castOp.getOperand(0); + if (auto castOp = v.getDefiningOp()) + v = castOp.getOperand(); + return v; + }; + auto isTileLike = [](Value v) -> bool { + auto ot = dyn_cast(v.getType()); + if (!ot) + return false; + StringRef s = ot.getValue(); + return s.contains("Tile<") || s.contains("ConvTile<"); + }; + + Value src = peelAllCasts(peelUnrealized(adaptor.getSource())); + Value row = peelUnrealized(adaptor.getValidRow()); + Value col = peelUnrealized(adaptor.getValidCol()); + + if (!isTileLike(src)) + return rewriter.notifyMatchFailure( + op, "set_validshape source must lower to a tile-like value"); + + rewriter.create( + op.getLoc(), TypeRange{}, 
"PTOAS__TILE_SET_VALIDSHAPE", ArrayAttr{}, + ArrayAttr{}, ValueRange{src, row, col}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOGetValidShapeToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::GetValidShapeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto peelAllCasts = [](Value v) { + while (auto castOp = v.getDefiningOp()) + v = castOp.getOperand(0); + if (auto castOp = v.getDefiningOp()) + v = castOp.getOperand(); + return v; + }; + auto isTileLike = [](Value v) -> bool { + auto ot = dyn_cast(v.getType()); + if (!ot) + return false; + StringRef s = ot.getValue(); + return s.contains("Tile<") || s.contains("ConvTile<"); + }; + + Value src = peelAllCasts(peelUnrealized(adaptor.getSource())); + if (!isTileLike(src)) + return rewriter.notifyMatchFailure( + op, "get_validshape source must lower to a tile-like value"); + + auto resultTy = getTypeConverter()->convertType(rewriter.getIndexType()); + if (!resultTy) + return failure(); + + Value row = rewriter + .create( + op.getLoc(), resultTy, + "PTOAS__TILE_GET_VALID_ROW", ArrayAttr{}, + ArrayAttr{}, ValueRange{src}) + .getResult(0); + Value col = rewriter + .create( + op.getLoc(), resultTy, + "PTOAS__TILE_GET_VALID_COL", ArrayAttr{}, + ArrayAttr{}, ValueRange{src}) + .getResult(0); + rewriter.replaceOp(op, ValueRange{row, col}); + return success(); + } +}; + +struct PTOTAssignToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TAssignOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto peelAllCasts = [](Value v) { + while (auto castOp = v.getDefiningOp()) + v = castOp.getOperand(0); + if (auto castOp = v.getDefiningOp()) + v = castOp.getOperand(); + return v; + }; + auto isTileLike = [](Value v) -> bool { + auto ot = dyn_cast(v.getType()); + if (!ot) + return false; + 
StringRef s = ot.getValue(); + return s.contains("Tile<") || s.contains("ConvTile<"); + }; + + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value tile = peelAllCasts(peelUnrealized(adaptor.getTile())); + if (!isTileLike(tile)) + return rewriter.notifyMatchFailure( + op, "tassign tile must lower to a tile-like value"); + + Value addr = peelUnrealized(adaptor.getAddr()); + auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); + if (isa(addr.getType()) || + (isa(addr.getType()) && + cast(addr.getType()).getValue().ends_with("*"))) { + auto rcU64 = + rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); + addr = rewriter + .create(loc, u64Ty, "reinterpret_cast", + ArrayAttr{}, rcU64, + ValueRange{addr}) + .getResult(0); + } else if (addr.getType() != u64Ty) { + addr = rewriter.create(loc, u64Ty, addr).getResult(); + } + + rewriter.create(loc, TypeRange{}, "TASSIGN", + ArrayAttr{}, ArrayAttr{}, + ValueRange{tile, addr}); + rewriter.replaceOp(op, tile); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// pto.load_scalar / pto.store_scalar lowering -> ptr[offset] +//===----------------------------------------------------------------------===// + +static Type getPointerLikeElementType(Type type) { + if (auto ptrTy = dyn_cast(type)) + return ptrTy.getElementType(); + if (auto memTy = dyn_cast(type)) + return memTy.getElementType(); + return Type(); +} + +struct PTOPtrToIntToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::PtrToIntOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value ptr = peelUnrealized(adaptor.getPtr()); + Type dstTy = getTypeConverter()->convertType(op.getResult().getType()); + if (!dstTy) + return failure(); + + auto dstOpaque = dyn_cast(dstTy); + if (!dstOpaque) + return failure(); + + auto templateArgs = + 
rewriter.getArrayAttr({emitc::OpaqueAttr::get(rewriter.getContext(), + dstOpaque.getValue())}); + auto cast = rewriter.create( + op.getLoc(), dstTy, "reinterpret_cast", ArrayAttr{}, templateArgs, + ValueRange{ptr}); + rewriter.replaceOp(op, cast.getResult(0)); + return success(); + } +}; + +struct PTOIntToPtrToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::IntToPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value addr = peelUnrealized(adaptor.getAddr()); + Type dstTy = getTypeConverter()->convertType(op.getResult().getType()); + if (!dstTy) + return failure(); + + Type dstElemTy = getPointerLikeElementType(op.getResult().getType()); + if (!dstElemTy) + return failure(); + + std::string castType = + std::string("__gm__ ") + getEmitCScalarTypeToken(dstElemTy) + "*"; + auto templateArgs = + rewriter.getArrayAttr({emitc::OpaqueAttr::get(rewriter.getContext(), + castType)}); + auto cast = rewriter.create( + op.getLoc(), dstTy, "reinterpret_cast", ArrayAttr{}, templateArgs, + ValueRange{addr}); + rewriter.replaceOp(op, cast.getResult(0)); + return success(); + } +}; + +struct PTOLoadScalarToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::LoadScalarOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value ptr = peelUnrealized(adaptor.getPtr()); + Value offset = peelUnrealized(adaptor.getOffset()); + + Type dstTy = getTypeConverter()->convertType(op.getValue().getType()); + if (!dstTy) + return failure(); + + auto call = rewriter.create( + op.getLoc(), TypeRange{dstTy}, "PTOAS__PTR_LOAD", + ArrayAttr{}, ArrayAttr{}, ValueRange{ptr, offset}); + + rewriter.replaceOp(op, call.getResults()); + return success(); + } +}; + +struct PTOStoreScalarToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult 
matchAndRewrite(pto::StoreScalarOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value ptr = peelUnrealized(adaptor.getPtr()); + Value offset = peelUnrealized(adaptor.getOffset()); + Value val = peelUnrealized(adaptor.getValue()); + + rewriter.create( + op.getLoc(), TypeRange{}, "PTOAS__PTR_STORE", + ArrayAttr{}, ArrayAttr{}, ValueRange{ptr, offset, val}); + rewriter.create( + op.getLoc(), TypeRange{}, "PTOAS__SCALAR_GM_STORE_FLUSH", + ArrayAttr{}, ArrayAttr{}, ValueRange{ptr}); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// pto.tabs lowering -> TABS(dst, src) +//===----------------------------------------------------------------------===// + +struct PTOTAbsToTABS : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TAbsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + // intrinsic: TABS(dst, src) + rewriter.create( + op.getLoc(), TypeRange{}, "TABS", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.tadd lowering -> TADD(dst, src0, src1) +//===----------------------------------------------------------------------===// + +struct PTOTAddToTADD : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TAddOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + op.getLoc(), TypeRange{}, "TADD", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src0, 
src1}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOInitializeL2G2LPipeToEmitC + : public OpConversionPattern { + PTOInitializeL2G2LPipeToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + PTOArch targetArch) + : OpConversionPattern(typeConverter, ctx), + targetArch(targetArch) {} + + LogicalResult matchAndRewrite(mlir::pto::InitializeL2G2LPipeOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto tpipeTok = buildTPipeTokenFromInitOp(op.getOperation(), targetArch); + if (failed(tpipeTok)) + return rewriter.notifyMatchFailure(op, "failed to build TPipe token"); + + auto *ctx = rewriter.getContext(); + auto emitPipeTy = + cast(getTypeConverter()->convertType(op.getPipe().getType())); + + Value gmAddr = peelUnrealized(adaptor.getGmAddr()); + gmAddr = materializeTensorViewDataPointer( + rewriter, op.getLoc(), gmAddr, op.getGmAddr().getType()); + Value localAddr = + op.getLocalAddr() ? peelUnrealized(adaptor.getLocalAddr()) : Value(); + auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); + Value zero = makeEmitCIntConstant(rewriter, op.getLoc(), i32Ty, 0); + + Value c2vBuf = zero; + Value v2cBuf = zero; + if (op.getDirMask() == 1) + c2vBuf = localAddr ? localAddr : zero; + else if (op.getDirMask() == 2) + v2cBuf = localAddr ? 
localAddr : zero; + else if (op.getDirMask() == 3) { + if (localAddr) { + if (!op.getPeerLocalAddr()) + return rewriter.notifyMatchFailure( + op, "bidirectional l2g2l pipe requires peer local buffer"); + c2vBuf = localAddr; + v2cBuf = peelUnrealized(adaptor.getPeerLocalAddr()); + } + } else + return rewriter.notifyMatchFailure(op, "unsupported dir_mask"); + + rewriter.replaceOpWithNewOp( + op, TypeRange{emitPipeTy}, *tpipeTok, ArrayAttr{}, ArrayAttr{}, + ValueRange{gmAddr, c2vBuf, v2cBuf}); + return success(); + } + + PTOArch targetArch; +}; + +struct PTOInitializeL2LPipeToEmitC + : public OpConversionPattern { + PTOInitializeL2LPipeToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + PTOArch targetArch) + : OpConversionPattern(typeConverter, ctx), + targetArch(targetArch) {} + + LogicalResult matchAndRewrite(mlir::pto::InitializeL2LPipeOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto tpipeTok = buildTPipeTokenFromInitOp(op.getOperation(), targetArch); + if (failed(tpipeTok)) + return rewriter.notifyMatchFailure(op, "failed to build TPipe token"); + + auto *ctx = rewriter.getContext(); + auto emitPipeTy = + cast(getTypeConverter()->convertType(op.getPipe().getType())); + + auto gmPtrTy = + emitc::PointerType::get(emitc::OpaqueType::get(ctx, "__gm__ void")); + Value nullGm = + makeEmitCOpaqueConstant(rewriter, op.getLoc(), gmPtrTy, "nullptr"); + auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); + Value zero = makeEmitCIntConstant(rewriter, op.getLoc(), i32Ty, 0); + Value localAddr = peelUnrealized(adaptor.getLocalAddr()); + + Value c2vBuf = zero; + Value v2cBuf = zero; + if (op.getDirMask() == 1) + c2vBuf = localAddr; + else if (op.getDirMask() == 2) + v2cBuf = localAddr; + else if (op.getDirMask() == 3) { + c2vBuf = localAddr; + v2cBuf = peelUnrealized(adaptor.getPeerLocalAddr()); + } else + return rewriter.notifyMatchFailure(op, "unsupported dir_mask"); + + rewriter.replaceOpWithNewOp( + op, TypeRange{emitPipeTy}, 
*tpipeTok, ArrayAttr{}, ArrayAttr{}, + ValueRange{nullGm, c2vBuf, v2cBuf}); + return success(); + } + + PTOArch targetArch; +}; + +struct PTOBuildAsyncSessionToEmitC + : public OpConversionPattern { + PTOBuildAsyncSessionToEmitC(TypeConverter &typeConverter, MLIRContext *ctx) + : OpConversionPattern(typeConverter, ctx) {} + + LogicalResult matchAndRewrite(mlir::pto::BuildAsyncSessionOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto *ctx = rewriter.getContext(); + Location loc = op.getLoc(); + + auto sessionTy = + dyn_cast(getTypeConverter()->convertType(op.getSession().getType())); + if (!sessionTy) + return rewriter.notifyMatchFailure(op, "failed to convert async session type"); + + FailureOr scratchTile = + buildAsyncScratchTileValue(rewriter, loc, op.getScratch(), + adaptor.getScratch()); + if (failed(scratchTile)) + return rewriter.notifyMatchFailure(op, "failed to materialize async scratch tile"); + + Value workspace = + castToGMBytePointer(rewriter, loc, peelUnrealized(adaptor.getWorkspace())); + + Value session = rewriter + .create( + loc, sessionTy, emitc::OpaqueAttr::get(ctx, "")) + .getResult(); + + auto u32Ty = emitc::OpaqueType::get(ctx, "uint32_t"); + + auto makeU32Const = [&](uint64_t value) -> Value { + return makeEmitCOpaqueConstant(rewriter, loc, u32Ty, + std::to_string(value) + "u"); + }; + uint64_t syncId = op.getSyncIdAttr() ? op.getSyncIdAttr().getInt() : 0; + uint64_t blockBytes = + op.getBlockBytesAttr() ? op.getBlockBytesAttr().getInt() : 32 * 1024; + uint64_t commBlockOffset = + op.getCommBlockOffsetAttr() ? op.getCommBlockOffsetAttr().getInt() : 0; + uint64_t queueNum = op.getQueueNumAttr() ? op.getQueueNumAttr().getInt() : 1; + uint64_t channelGroupIdx = op.getChannelGroupIdxAttr() + ? op.getChannelGroupIdxAttr().getInt() + : UINT32_MAX; + + Value syncIdVal = makeU32Const(syncId); + Value channelGroupIdxVal = + channelGroupIdx == UINT32_MAX + ? 
makeEmitCOpaqueConstant(rewriter, loc, u32Ty, "UINT32_MAX") + : makeU32Const(channelGroupIdx); + + auto baseConfigTy = + emitc::OpaqueType::get(ctx, "pto::comm::sdma::SdmaBaseConfig"); + Value baseConfig = + rewriter + .create( + loc, baseConfigTy, + emitc::OpaqueAttr::get( + ctx, "{" + std::to_string(blockBytes) + "ULL, " + + std::to_string(commBlockOffset) + "ULL, " + + std::to_string(queueNum) + "u}")) + .getResult(); + + rewriter.create( + loc, TypeRange{}, "pto::comm::BuildAsyncSession", + ArrayAttr{}, ArrayAttr{}, + ValueRange{*scratchTile, workspace, session, syncIdVal, baseConfig, + channelGroupIdxVal}); + + rewriter.replaceOp(op, session); + return success(); + } +}; + +template +struct PTOAsyncTransferToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + explicit PTOAsyncTransferToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + StringRef callee) + : OpConversionPattern(typeConverter, ctx), callee(callee.str()) {} + + LogicalResult matchAndRewrite(AsyncOp op, typename AsyncOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value dst = peelUnrealized(adaptor.getDst()); + Value src = peelUnrealized(adaptor.getSrc()); + Value dstGT = dst; + Value srcGT = src; + if (!isEmitCGlobalTensorLikeType(dstGT.getType())) { + auto dstMrTy = dyn_cast(op.getDst().getType()); + if (!dstMrTy) + return rewriter.notifyMatchFailure(op, "expected dst to lower to GlobalTensor or memref"); + dstGT = buildGlobalTensorFromMemref(rewriter, op.getLoc(), dst, dstMrTy, + op.getDst().getDefiningOp() + ? op.getDst().getDefiningOp() + : op.getOperation()); + } + if (!isEmitCGlobalTensorLikeType(srcGT.getType())) { + auto srcMrTy = dyn_cast(op.getSrc().getType()); + if (!srcMrTy) + return rewriter.notifyMatchFailure(op, "expected src to lower to GlobalTensor or memref"); + srcGT = buildGlobalTensorFromMemref(rewriter, op.getLoc(), src, srcMrTy, + op.getSrc().getDefiningOp() + ? 
op.getSrc().getDefiningOp() + : op.getOperation()); + } + if (!dstGT || !srcGT) + return rewriter.notifyMatchFailure(op, "failed to build GlobalTensor operands"); + + Type eventTy = this->getTypeConverter()->convertType(op.getEvent().getType()); + if (!eventTy) + return rewriter.notifyMatchFailure(op, "failed to convert async event type"); + + rewriter.replaceOpWithNewOp( + op, TypeRange{eventTy}, callee, ArrayAttr{}, ArrayAttr{}, + ValueRange{dstGT, srcGT, peelUnrealized(adaptor.getSession())}); + return success(); + } + + std::string callee; +}; + +template +struct PTOAsyncEventToEmitC : public OpConversionPattern { + explicit PTOAsyncEventToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + StringRef callee) + : OpConversionPattern(typeConverter, ctx), + callee(callee.str()) {} + + LogicalResult matchAndRewrite(AsyncEventOp op, + typename AsyncEventOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type resultTy = + this->getTypeConverter()->convertType(op.getCompleted().getType()); + if (!resultTy) + return rewriter.notifyMatchFailure(op, "failed to convert async event result type"); + + rewriter.replaceOpWithNewOp( + op, TypeRange{resultTy}, callee, ArrayAttr{}, ArrayAttr{}, + ValueRange{peelUnrealized(adaptor.getEvent()), + peelUnrealized(adaptor.getSession())}); + return success(); + } + + std::string callee; +}; + +static FailureOr buildCommGlobalTensorValue( + ConversionPatternRewriter &rewriter, Location loc, Value originalValue, + Value emittedValue, Operation *anchor) { + Value value = peelUnrealized(emittedValue); + if (isEmitCGlobalTensorLikeType(value.getType())) + return value; + + auto memTy = dyn_cast(originalValue.getType()); + if (!memTy) + return failure(); + + Value gt = buildGlobalTensorFromMemref(rewriter, loc, value, memTy, anchor); + if (!gt) + return failure(); + return gt; +} + +static FailureOr buildCommTileValue(ConversionPatternRewriter &rewriter, + Location loc, Value originalValue, + Value 
emittedValue) { + Value value = peelUnrealized(emittedValue); + if (auto opaqueTy = dyn_cast(value.getType())) { + StringRef typeStr = opaqueTy.getValue(); + if (typeStr.contains("Tile<") || typeStr.contains("ConvTile<")) + return value; + } + return buildAsyncScratchTileValue(rewriter, loc, originalValue, emittedValue); +} + +static FailureOr buildCollectiveParallelGroup( + ConversionPatternRewriter &rewriter, Location loc, + ArrayRef groupGTs, int64_t root) { + if (groupGTs.empty()) + return failure(); + + auto firstTy = dyn_cast(groupGTs.front().getType()); + if (!firstTy) + return failure(); + + auto *ctx = rewriter.getContext(); + auto arrayTy = emitc::ArrayType::get({static_cast(groupGTs.size())}, + firstTy); + auto groupArray = cast>( + rewriter + .create(loc, arrayTy, + emitc::OpaqueAttr::get(ctx, "{}")) + .getResult()); + + auto indexTy = emitc::OpaqueType::get(ctx, "int"); + for (auto [idx, groupVal] : llvm::enumerate(groupGTs)) { + Value idxVal = + makeEmitCIntConstant(rewriter, loc, indexTy, static_cast(idx)); + Value slot = + rewriter.create(loc, groupArray, ValueRange{idxVal}) + .getResult(); + rewriter.create(loc, slot, groupVal); + } + + std::string pgTypeStr = + (Twine("pto::comm::ParallelGroup<") + firstTy.getValue() + ">").str(); + auto pgTy = emitc::OpaqueType::get(ctx, pgTypeStr); + Value sizeVal = makeEmitCIntConstant(rewriter, loc, indexTy, + static_cast(groupGTs.size())); + Value rootVal = makeEmitCIntConstant(rewriter, loc, indexTy, root); + return rewriter + .create( + loc, TypeRange{pgTy}, (Twine(pgTypeStr) + "::Create").str(), + ArrayAttr{}, ArrayAttr{}, ValueRange{groupArray, sizeVal, rootVal}) + .getResult(0); +} + +static std::string notifyOpTok(pto::NotifyOp op) { + switch (op) { + case pto::NotifyOp::AtomicAdd: + return "pto::comm::NotifyOp::AtomicAdd"; + case pto::NotifyOp::Set: + return "pto::comm::NotifyOp::Set"; + } + return "pto::comm::NotifyOp::Set"; +} + +static std::string waitCmpTok(pto::WaitCmp cmp) { + switch (cmp) { + 
case pto::WaitCmp::EQ: + return "pto::comm::WaitCmp::EQ"; + case pto::WaitCmp::NE: + return "pto::comm::WaitCmp::NE"; + case pto::WaitCmp::GT: + return "pto::comm::WaitCmp::GT"; + case pto::WaitCmp::GE: + return "pto::comm::WaitCmp::GE"; + case pto::WaitCmp::LT: + return "pto::comm::WaitCmp::LT"; + case pto::WaitCmp::LE: + return "pto::comm::WaitCmp::LE"; + } + return "pto::comm::WaitCmp::EQ"; +} + +static std::string reduceOpTok(pto::ReduceOp op) { + switch (op) { + case pto::ReduceOp::Sum: + return "pto::comm::ReduceOp::Sum"; + case pto::ReduceOp::Max: + return "pto::comm::ReduceOp::Max"; + case pto::ReduceOp::Min: + return "pto::comm::ReduceOp::Min"; + } + return "pto::comm::ReduceOp::Sum"; +} + +template +static FailureOr> buildCommGroupGlobalTensors( + ConversionPatternRewriter &rewriter, Location loc, OpTy op, + ValueRange originalGroup, ValueRange emittedGroup) { + SmallVector groupGTs; + groupGTs.reserve(originalGroup.size()); + for (auto [orig, emitted] : llvm::zip(originalGroup, emittedGroup)) { + FailureOr gt = + buildCommGlobalTensorValue(rewriter, loc, orig, emitted, op.getOperation()); + if (failed(gt)) + return failure(); + groupGTs.push_back(*gt); + } + return groupGTs; +} + +template +struct PTOCommCollectiveToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + explicit PTOCommCollectiveToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + StringRef apiName) + : OpConversionPattern(typeConverter, ctx), + apiName(apiName.str()) {} + + LogicalResult matchAndRewrite(CollectiveOp op, typename CollectiveOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + + auto buildPong = [&](Value original, Value emitted, StringRef name) -> FailureOr { + if (!original) + return failure(); + return buildCommTileValue(rewriter, loc, original, emitted); + }; + + if constexpr (std::is_same_v) { + FailureOr srcGT = + buildCommGlobalTensorValue(rewriter, loc, op.getSrc(), 
adaptor.getSrc(), + op.getOperation()); + FailureOr pingTile = + buildCommTileValue(rewriter, loc, op.getPing(), adaptor.getPing()); + auto groupGTs = + buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); + if (failed(srcGT) || failed(pingTile) || failed(groupGTs)) + return rewriter.notifyMatchFailure(op, "failed to materialize broadcast operands"); + FailureOr pg = buildCollectiveParallelGroup(rewriter, loc, *groupGTs, op.getRoot()); + if (failed(pg)) + return rewriter.notifyMatchFailure(op, "failed to materialize broadcast group"); + if (op.getPong()) { + FailureOr pongTile = + buildPong(op.getPong(), adaptor.getPong(), "__pong"); + if (failed(pongTile)) + return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); + rewriter.create( + loc, TypeRange{}, "pto::comm::TBROADCAST", ArrayAttr{}, ArrayAttr{}, + ValueRange{*pg, *srcGT, *pingTile, *pongTile}); + } else { + rewriter.create( + loc, TypeRange{}, "pto::comm::TBROADCAST", ArrayAttr{}, ArrayAttr{}, + ValueRange{*pg, *srcGT, *pingTile}); + } + } else if constexpr (std::is_same_v) { + FailureOr dstGT = + buildCommGlobalTensorValue(rewriter, loc, op.getDst(), adaptor.getDst(), + op.getOperation()); + FailureOr pingTile = + buildCommTileValue(rewriter, loc, op.getPing(), adaptor.getPing()); + auto groupGTs = + buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); + if (failed(dstGT) || failed(pingTile) || failed(groupGTs)) + return rewriter.notifyMatchFailure(op, "failed to materialize gather operands"); + FailureOr pg = buildCollectiveParallelGroup(rewriter, loc, *groupGTs, op.getRoot()); + if (failed(pg)) + return rewriter.notifyMatchFailure(op, "failed to materialize gather group"); + if (op.getPong()) { + FailureOr pongTile = + buildPong(op.getPong(), adaptor.getPong(), "__pong"); + if (failed(pongTile)) + return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); + rewriter.create( + loc, TypeRange{}, 
"pto::comm::TGATHER", ArrayAttr{}, ArrayAttr{}, + ValueRange{*pg, *dstGT, *pingTile, *pongTile}); + } else { + rewriter.create( + loc, TypeRange{}, "pto::comm::TGATHER", ArrayAttr{}, ArrayAttr{}, + ValueRange{*pg, *dstGT, *pingTile}); + } + } else if constexpr (std::is_same_v) { + FailureOr srcGT = + buildCommGlobalTensorValue(rewriter, loc, op.getSrc(), adaptor.getSrc(), + op.getOperation()); + FailureOr pingTile = + buildCommTileValue(rewriter, loc, op.getPing(), adaptor.getPing()); + auto groupGTs = + buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); + if (failed(srcGT) || failed(pingTile) || failed(groupGTs)) + return rewriter.notifyMatchFailure(op, "failed to materialize scatter operands"); + FailureOr pg = buildCollectiveParallelGroup(rewriter, loc, *groupGTs, op.getRoot()); + if (failed(pg)) + return rewriter.notifyMatchFailure(op, "failed to materialize scatter group"); + if (op.getPong()) { + FailureOr pongTile = + buildPong(op.getPong(), adaptor.getPong(), "__pong"); + if (failed(pongTile)) + return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); + rewriter.create( + loc, TypeRange{}, "pto::comm::TSCATTER", ArrayAttr{}, ArrayAttr{}, + ValueRange{*pg, *srcGT, *pingTile, *pongTile}); + } else { + rewriter.create( + loc, TypeRange{}, "pto::comm::TSCATTER", ArrayAttr{}, ArrayAttr{}, + ValueRange{*pg, *srcGT, *pingTile}); + } + } else { + FailureOr dstGT = + buildCommGlobalTensorValue(rewriter, loc, op.getDst(), adaptor.getDst(), + op.getOperation()); + FailureOr accTile = + buildCommTileValue(rewriter, loc, op.getAcc(), adaptor.getAcc()); + FailureOr recvPing = + buildCommTileValue(rewriter, loc, op.getRecvPing(), adaptor.getRecvPing()); + auto groupGTs = + buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); + if (failed(dstGT) || failed(accTile) || failed(recvPing) || failed(groupGTs)) + return rewriter.notifyMatchFailure(op, "failed to materialize reduce operands"); + 
FailureOr pg = buildCollectiveParallelGroup(rewriter, loc, *groupGTs, op.getRoot()); + if (failed(pg)) + return rewriter.notifyMatchFailure(op, "failed to materialize reduce group"); + if (op.getRecvPong()) { + FailureOr recvPong = + buildPong(op.getRecvPong(), adaptor.getRecvPong(), "__recv_pong"); + if (failed(recvPong)) + return rewriter.notifyMatchFailure(op, "failed to materialize recv_pong"); + auto reduceTy = + emitc::OpaqueType::get(rewriter.getContext(), "pto::comm::ReduceOp"); + Value reduceOp = makeEmitCOpaqueConstant(rewriter, loc, reduceTy, + reduceOpTok(op.getReduceOp())); + rewriter.create( + loc, TypeRange{}, "pto::comm::TREDUCE", ArrayAttr{}, ArrayAttr{}, + ValueRange{*pg, *dstGT, *accTile, *recvPing, *recvPong, reduceOp}); + } else { + auto reduceTy = + emitc::OpaqueType::get(rewriter.getContext(), "pto::comm::ReduceOp"); + Value reduceOp = makeEmitCOpaqueConstant(rewriter, loc, reduceTy, + reduceOpTok(op.getReduceOp())); + rewriter.create( + loc, TypeRange{}, "pto::comm::TREDUCE", ArrayAttr{}, ArrayAttr{}, + ValueRange{*pg, *dstGT, *accTile, *recvPing, reduceOp}); + } + } + rewriter.eraseOp(op); + return success(); + } + + std::string apiName; +}; + +template +struct PTOP2PCommToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + explicit PTOP2PCommToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + StringRef callee) + : OpConversionPattern(typeConverter, ctx), callee(callee.str()) {} + + LogicalResult matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr dstGT = + buildCommGlobalTensorValue(rewriter, op.getLoc(), op.getDst(), adaptor.getDst(), + op.getOperation()); + FailureOr srcGT = + buildCommGlobalTensorValue(rewriter, op.getLoc(), op.getSrc(), adaptor.getSrc(), + op.getOperation()); + FailureOr pingTile = + buildCommTileValue(rewriter, op.getLoc(), op.getPing(), adaptor.getPing()); + if (failed(dstGT) || failed(srcGT) || 
failed(pingTile)) + return rewriter.notifyMatchFailure(op, "failed to materialize p2p operands"); + + SmallVector operands{*dstGT, *srcGT, *pingTile}; + std::string actualCallee = callee; + if constexpr (std::is_same_v) { + if (op.getAtomicType() == pto::AtomicType::AtomicAdd) + actualCallee = "pto::comm::TPUT"; + } + if (op.getPong()) { + FailureOr pongTile = + buildCommTileValue(rewriter, op.getLoc(), op.getPong(), adaptor.getPong()); + if (failed(pongTile)) + return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); + operands.push_back(*pongTile); + } + + rewriter.create(op.getLoc(), TypeRange{}, actualCallee, + ArrayAttr{}, ArrayAttr{}, operands); + rewriter.eraseOp(op); + return success(); + } + + std::string callee; +}; + +template +struct PTOSignalCommToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + explicit PTOSignalCommToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + StringRef callee) + : OpConversionPattern(typeConverter, ctx), + callee(callee.str()) {} + + LogicalResult matchAndRewrite(SignalOp op, typename SignalOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr signalGT = buildCommGlobalTensorValue( + rewriter, op.getLoc(), op.getSignal(), adaptor.getSignal(), op.getOperation()); + if (failed(signalGT)) + return rewriter.notifyMatchFailure(op, "failed to materialize signal operand"); + + if constexpr (std::is_same_v) { + auto notifyTy = + emitc::OpaqueType::get(rewriter.getContext(), "pto::comm::NotifyOp"); + Value notifyOp = makeEmitCOpaqueConstant( + rewriter, op.getLoc(), notifyTy, notifyOpTok(op.getNotifyOp())); + SmallVector operands{*signalGT, peelUnrealized(adaptor.getValue()), + notifyOp}; + rewriter.create(op.getLoc(), TypeRange{}, callee, + ArrayAttr{}, ArrayAttr{}, operands); + rewriter.eraseOp(op); + } else { + auto waitCmpTy = + emitc::OpaqueType::get(rewriter.getContext(), "pto::comm::WaitCmp"); + Value waitCmp = 
makeEmitCOpaqueConstant( + rewriter, op.getLoc(), waitCmpTy, waitCmpTok(op.getCmp())); + SmallVector operands{*signalGT, peelUnrealized(adaptor.getCmpValue()), + waitCmp}; + if constexpr (std::is_same_v) { + Type resultTy = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultTy) + return rewriter.notifyMatchFailure(op, "failed to convert ttest result type"); + rewriter.replaceOpWithNewOp( + op, TypeRange{resultTy}, callee, ArrayAttr{}, ArrayAttr{}, operands); + } else { + rewriter.create(op.getLoc(), TypeRange{}, callee, + ArrayAttr{}, ArrayAttr{}, operands); + rewriter.eraseOp(op); + } + } + return success(); + } + + std::string callee; +}; + +struct PTODeclareTileMemRefToEmitC + : public OpConversionPattern { + using OpConversionPattern< + mlir::pto::DeclareTileMemRefOp>::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::DeclareTileMemRefOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + Type convertedType = getTypeConverter()->convertType(op.getResult().getType()); + if (!convertedType) + return rewriter.notifyMatchFailure( + op, "failed to convert declare_tile_memref result type"); + rewriter.replaceOp(op, makeEmitCOpaqueConstant(rewriter, op.getLoc(), + convertedType, "nullptr")); + return success(); + } +}; + +struct PTODeclareGlobalToEmitC + : public OpConversionPattern { + using OpConversionPattern< + mlir::pto::DeclareGlobalOp>::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::DeclareGlobalOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + Type convertedType = getTypeConverter()->convertType(op.getEntry().getType()); + if (!convertedType) + return rewriter.notifyMatchFailure( + op, "failed to convert declare_global result type"); + if (auto tvTy = dyn_cast(op.getEntry().getType())) { + if (auto stridesAttr = + op->getAttrOfType(kGlobalTensorStridesAttrName)) { + auto strides = 
stridesAttr.asArrayRef(); + if (strides.size() == static_cast(tvTy.getRank())) { + convertedType = emitc::OpaqueType::get( + rewriter.getContext(), + getGlobalTensorTypeStringFromShapeAndStrides( + tvTy.getElementType(), tvTy.getShape(), strides)); + } + } + } + auto var = rewriter.create( + op.getLoc(), convertedType, + emitc::OpaqueAttr::get(rewriter.getContext(), "")); + rewriter.replaceOp(op, var.getResult()); + return success(); + } +}; + +struct PTODeclareEventIdArrayToEmitC + : public OpConversionPattern { + using OpConversionPattern< + mlir::pto::DeclareEventIdArrayOp>::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::DeclareEventIdArrayOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + Type arrayTy = getTypeConverter()->convertType(op.getArray().getType()); + if (!arrayTy) + return rewriter.notifyMatchFailure(op, + "failed to map declared eventid_array type"); + + auto array = rewriter + .create( + op.getLoc(), arrayTy, + emitc::OpaqueAttr::get(rewriter.getContext(), "")) + .getResult(); + rewriter.replaceOp(op, array); + return success(); + } +}; + +struct PTOEventIdArrayGetToEmitC + : public OpConversionPattern { + using OpConversionPattern< + mlir::pto::EventIdArrayGetOp>::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::EventIdArrayGetOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value array = peelUnrealized(adaptor.getArray()); + Value index = peelUnrealized(adaptor.getIndex()); + + Type resultTy = getTypeConverter()->convertType(op.getResult().getType()); + if (!resultTy) + return rewriter.notifyMatchFailure(op, + "failed to map eventid_array get result type"); + + auto load = + rewriter.create(op.getLoc(), resultTy, array, index); + rewriter.replaceOp(op, load.getResult()); + return success(); + } +}; + +struct PTOEventIdArraySetToEmitC + : public OpConversionPattern { + using OpConversionPattern< + 
mlir::pto::EventIdArraySetOp>::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::EventIdArraySetOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value array = peelUnrealized(adaptor.getArray()); + Value index = peelUnrealized(adaptor.getIndex()); + Value value = peelUnrealized(adaptor.getValue()); + + rewriter.create( + op.getLoc(), TypeRange{}, "PTOAS__EVENTID_ARRAY_STORE", + ArrayAttr{}, ArrayAttr{}, ValueRange{array, index, value}); + rewriter.eraseOp(op); + return success(); + } +}; + +// pto.declare_local_array -> emitc.variable of !emitc.array<...>. +// Renders as `T a[D1][D2]...;` in the emitted C++. +struct PTODeclareLocalArrayToEmitC + : public OpConversionPattern { + using OpConversionPattern< + mlir::pto::DeclareLocalArrayOp>::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::DeclareLocalArrayOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + Type arrayTy = getTypeConverter()->convertType(op.getArray().getType()); + if (!arrayTy) + return rewriter.notifyMatchFailure(op, + "failed to map !pto.local_array type"); + + auto var = rewriter + .create( + op.getLoc(), arrayTy, + emitc::OpaqueAttr::get(rewriter.getContext(), "")) + .getResult(); + rewriter.replaceOp(op, var); + return success(); + } +}; + +// pto.local_array_get %a[%i0, %i1, ...] -> rvalue. +// Lowers to a single emitc.subscript with the full index pack; the C++ emitter +// prints it as `a[i0][i1]...`. The adaptor already exposes target-typed values +// (the type converter has remapped !pto.local_array -> !emitc.array and +// index/integer indices), so they're forwarded directly to the builder. 
+struct PTOLocalArrayGetToEmitC + : public OpConversionPattern { + using OpConversionPattern< + mlir::pto::LocalArrayGetOp>::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::LocalArrayGetOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type resultTy = + getTypeConverter()->convertType(op.getResult().getType()); + if (!resultTy) + return rewriter.notifyMatchFailure( + op, "failed to map local_array element type"); + + auto sub = rewriter.create( + op.getLoc(), resultTy, adaptor.getArray(), adaptor.getIndices()); + rewriter.replaceOp(op, sub.getResult()); + return success(); + } +}; + +// pto.local_array_set %a[%i0, %i1, ...], %v -> emitc.assign to subscript slot. +// The C++ emitter prints this as `a[i0][i1]... = v;`. As above, adaptor values +// are already target-typed; pass them through directly. +struct PTOLocalArraySetToEmitC + : public OpConversionPattern { + using OpConversionPattern< + mlir::pto::LocalArraySetOp>::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::LocalArraySetOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value value = adaptor.getValue(); + Type elemTy = value.getType(); + + Value slot = rewriter + .create(op.getLoc(), elemTy, + adaptor.getArray(), + adaptor.getIndices()) + .getResult(); + rewriter.create(op.getLoc(), slot, value); + rewriter.eraseOp(op); + return success(); + } +}; + +static std::optional getStaticIndexLikeValue(Value value) { + if (!value) + return std::nullopt; + if (auto cst = value.getDefiningOp()) + return cst.value(); + if (auto cst = value.getDefiningOp()) + return cst.value(); + if (auto cst = value.getDefiningOp()) { + if (auto intAttr = dyn_cast(cst.getValue())) + return intAttr.getInt(); + } + return std::nullopt; +} + +static FailureOr buildGlobalTensorViewFromPointer( + ConversionPatternRewriter &rewriter, Location loc, Value ptr, Type elemTy, + ArrayRef shape, ArrayRef strides = {}, + StringRef 
layoutEnum = "pto::Layout::ND") { + if (llvm::any_of(shape, [](int64_t dim) { + return dim == ShapedType::kDynamic; + })) + return failure(); + + auto *ctx = rewriter.getContext(); + SmallVector rowMajorStrides; + ArrayRef effectiveStrides = strides; + if (effectiveStrides.empty()) { + rowMajorStrides = buildRowMajorStrides(shape); + effectiveStrides = rowMajorStrides; + } + SmallVector shape5D; + SmallVector stride5D; + buildGlobalTensorShapeAndStride(shape, effectiveStrides, shape5D, stride5D); + + std::string shapeType = "pto::Shape<" + joinIntTemplateParams(shape5D) + ">"; + std::string strideType = + "pto::Stride<" + joinIntTemplateParams(stride5D) + ">"; + auto shapeVal = rewriter + .create( + loc, emitc::OpaqueType::get(ctx, shapeType), + shapeType, ArrayAttr{}, ArrayAttr{}, ValueRange{}) + .getResult(0); + auto strideVal = rewriter + .create( + loc, emitc::OpaqueType::get(ctx, strideType), + strideType, ArrayAttr{}, ArrayAttr{}, ValueRange{}) + .getResult(0); + + std::string gtTypeStr = + getGlobalTensorTypeStringFromShapeAndStrides(elemTy, shape, + effectiveStrides, + layoutEnum); + auto gtType = emitc::OpaqueType::get(ctx, gtTypeStr); + auto gt = rewriter.create( + loc, gtType, gtTypeStr, ArrayAttr{}, ArrayAttr{}, + ValueRange{ptr, shapeVal, strideVal}); + return gt.getResult(0); +} + +static bool parseIntegerTemplateList(StringRef token, StringRef marker, + SmallVectorImpl &values) { + size_t pos = token.find(marker); + if (pos == StringRef::npos) + return false; + pos += marker.size(); + size_t end = token.find('>', pos); + if (end == StringRef::npos) + return false; + + SmallVector parts; + token.slice(pos, end).split(parts, ','); + values.clear(); + for (StringRef part : parts) { + int64_t value = 0; + if (part.trim().getAsInteger(10, value)) + return false; + values.push_back(value); + } + return true; +} + +static LogicalResult getStaticTensorViewStrides( + Value source, Value convertedSource, pto::TensorViewType sourceType, + SmallVectorImpl 
&strides) { + int64_t rank = sourceType.getRank(); + strides.clear(); + + if (auto makeView = source.getDefiningOp()) { + if ((int64_t)makeView.getStrides().size() != rank) + return failure(); + for (Value strideValue : makeView.getStrides()) { + auto cst = getStaticIndexLikeValue(strideValue); + if (!cst) + return failure(); + strides.push_back(*cst); + } + return success(); + } + + Value src = peelUnrealized(convertedSource); + if (auto opaqueTy = dyn_cast(src.getType())) { + SmallVector stride5D; + StringRef token = opaqueTy.getValue(); + if ((parseIntegerTemplateList(token, "pto::Stride<", stride5D) || + parseIntegerTemplateList(token, "Stride<", stride5D)) && + (int64_t)stride5D.size() >= rank) { + strides.append(stride5D.end() - rank, stride5D.end()); + return success(); + } + } + + auto fallback = buildRowMajorStrides(sourceType.getShape()); + strides.append(fallback.begin(), fallback.end()); + return success(); +} + +struct PTOPartitionViewToEmitC + : public OpConversionPattern { + using OpConversionPattern< + mlir::pto::PartitionViewOp>::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::PartitionViewOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto srcTy = dyn_cast(op.getSource().getType()); + auto resTy = dyn_cast(op.getResult().getType()); + if (!srcTy || !resTy) + return rewriter.notifyMatchFailure( + op, "expected tensor_view source and partition_tensor_view result"); + + if (op.getOffsets().size() != static_cast(srcTy.getRank()) || + op.getSizes().size() != static_cast(srcTy.getRank())) + return rewriter.notifyMatchFailure(op, "rank mismatch"); + + for (auto [idx, value] : llvm::enumerate(op.getSizes())) { + auto cst = getStaticIndexLikeValue(value); + if (!cst) + return rewriter.notifyMatchFailure( + op, "globaltensor partition_view requires static sizes"); + int64_t resultDim = resTy.getShape()[idx]; + if (resultDim != ShapedType::kDynamic && resultDim != *cst) + return 
rewriter.notifyMatchFailure( + op, "partition_view static size does not match result type"); + } + + SmallVector srcStrides; + if (failed(getStaticTensorViewStrides(op.getSource(), adaptor.getSource(), + srcTy, srcStrides))) + return rewriter.notifyMatchFailure( + op, "partition_view requires static source strides"); + int64_t staticLinearOffset = 0; + SmallVector> dynamicOffsetTerms; + for (auto [idx, values] : + llvm::enumerate(llvm::zip(op.getOffsets(), adaptor.getOffsets()))) { + Value originalOffset = std::get<0>(values); + Value convertedOffset = std::get<1>(values); + int64_t stride = srcStrides[idx]; + if (stride == ShapedType::kDynamic) + return rewriter.notifyMatchFailure( + op, "dynamic source stride is not supported"); + + if (auto cst = getStaticIndexLikeValue(originalOffset)) { + if (*cst != 0) + staticLinearOffset += (*cst) * stride; + continue; + } + dynamicOffsetTerms.push_back({convertedOffset, stride}); + } + + auto *ctx = rewriter.getContext(); + std::string elemTypeStr = getElemTypeStringForGT(srcTy.getElementType()); + auto ptrTy = emitc::PointerType::get( + emitc::OpaqueType::get(ctx, "__gm__ " + elemTypeStr)); + Value src = peelUnrealized(adaptor.getSource()); + auto data = rewriter + .create( + op.getLoc(), ptrTy, "PTOAS__GLOBAL_TENSOR_DATA", + ArrayAttr{}, ArrayAttr{}, ValueRange{src}) + .getResult(0); + Value ptr = data; + if (!dynamicOffsetTerms.empty()) { + Type u32Ty = emitc::OpaqueType::get(ctx, "unsigned"); + auto makeU32 = [&](int64_t value) { + return makeEmitCIntConstant(rewriter, op.getLoc(), u32Ty, value); + }; + auto asU32 = [&](Value value) -> Value { + if (value.getType() == u32Ty) + return value; + return rewriter.create(op.getLoc(), u32Ty, value) + .getResult(); + }; + + Value totalOffset = makeU32(staticLinearOffset); + for (auto [offsetValue, stride] : dynamicOffsetTerms) { + Value term = asU32(offsetValue); + if (stride != 1) { + Value strideValue = makeU32(stride); + term = rewriter + .create(op.getLoc(), u32Ty, term, + 
strideValue) + .getResult(); + } + totalOffset = rewriter + .create(op.getLoc(), u32Ty, + totalOffset, term) + .getResult(); + } + ptr = rewriter + .create(op.getLoc(), data.getType(), data, + totalOffset) + .getResult(); + } else { + ptr = applyStaticMemrefOffset(rewriter, op.getLoc(), data, + staticLinearOffset); + } + + auto resultOr = buildGlobalTensorViewFromPointer( + rewriter, op.getLoc(), ptr, resTy.getElementType(), resTy.getShape(), + srcStrides); + if (failed(resultOr)) + return rewriter.notifyMatchFailure( + op, "failed to materialize partition GlobalTensor"); + + rewriter.replaceOp(op, *resultOr); + return success(); + } +}; + +static FailureOr getPipeDataTypeToken(Value value) { + auto opaqueTy = dyn_cast(value.getType()); + if (!opaqueTy) + return failure(); + StringRef token = opaqueTy.getValue(); + if (!token.contains("Tile<") && !token.contains("GlobalTensor<")) + return failure(); + return token.str(); +} + +struct PTOTAllocToEmitC : public OpConversionPattern { + PTOTAllocToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + PTOArch targetArch) + : OpConversionPattern(typeConverter, ctx), + targetArch(targetArch) {} + + LogicalResult matchAndRewrite(mlir::pto::TAllocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto pipeTok = getTPipeTokenFromValue(op.getPipeHandle(), targetArch); + if (failed(pipeTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve pipe token"); + Value entry = peelUnrealized(adaptor.getEntry()); + auto entryTok = getPipeDataTypeToken(entry); + if (failed(entryTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve entry token"); + auto splitTok = getTileSplitToken(op.getSplit()); + if (failed(splitTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve split token"); + + std::string callee = + "TALLOC<" + *pipeTok + ", " + *entryTok + ", " + *splitTok + ">"; + rewriter.replaceOpWithNewOp( + op, TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, + 
ValueRange{peelUnrealized(adaptor.getPipeHandle()), entry}); + return success(); + } + + PTOArch targetArch; +}; + +struct PTOTPushToEmitC : public OpConversionPattern { + PTOTPushToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + PTOArch targetArch) + : OpConversionPattern(typeConverter, ctx), + targetArch(targetArch) {} + + LogicalResult matchAndRewrite(mlir::pto::TPushOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto pipeTok = getTPipeTokenFromValue(op.getPipeHandle(), targetArch); + if (failed(pipeTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve pipe token"); + // Read the tile type token from the already-converted OpaqueType, which + // preserves the exact layout produced by BindTileOp / PointerCastOp EmitC. + Value convertedTile = peelUnrealized(adaptor.getTile()); + auto tileTok = getPipeDataTypeToken(convertedTile); + if (failed(tileTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve tile token"); + auto splitTok = getTileSplitToken(op.getSplit()); + if (failed(splitTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve split token"); + + std::string callee = + "TPUSH<" + *pipeTok + ", " + *tileTok + ", " + *splitTok + ">"; + rewriter.replaceOpWithNewOp( + op, TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, + ValueRange{peelUnrealized(adaptor.getPipeHandle()), convertedTile}); + return success(); + } + + PTOArch targetArch; +}; + +struct PTOTPopToEmitC : public OpConversionPattern { + PTOTPopToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + PTOArch targetArch) + : OpConversionPattern(typeConverter, ctx), + targetArch(targetArch) {} + + LogicalResult matchAndRewrite(mlir::pto::TPopOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto pipeTok = getTPipeTokenFromValue(op.getPipeHandle(), targetArch); + if (failed(pipeTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve pipe token"); + Value convertedTile = 
peelUnrealized(adaptor.getTile()); + auto tileTok = getPipeDataTypeToken(convertedTile); + if (failed(tileTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve tile token"); + auto splitTok = getTileSplitToken(op.getSplit()); + if (failed(splitTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve split token"); + + std::string callee = + "TPOP<" + *pipeTok + ", " + *tileTok + ", " + *splitTok + ">"; + rewriter.replaceOpWithNewOp( + op, TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, + ValueRange{peelUnrealized(adaptor.getPipeHandle()), convertedTile}); + return success(); + } + + PTOArch targetArch; +}; + +struct PTOTFreeToEmitC : public OpConversionPattern { + PTOTFreeToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + PTOArch targetArch) + : OpConversionPattern(typeConverter, ctx), + targetArch(targetArch) {} + + LogicalResult matchAndRewrite(mlir::pto::TFreeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto pipeTok = getTPipeTokenFromValue(op.getPipeHandle(), targetArch); + if (failed(pipeTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve pipe token"); + auto splitTok = getTileSplitToken(op.getSplit()); + if (failed(splitTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve split token"); + + SmallVector operands{peelUnrealized(adaptor.getPipeHandle())}; + std::string callee; + if (op.getEntry()) { + Value entry = peelUnrealized(adaptor.getEntry()); + auto entryTok = getPipeDataTypeToken(entry); + if (failed(entryTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve entry token"); + callee = "TFREE<" + *pipeTok + ", " + *entryTok + ", " + *splitTok + ">"; + operands.push_back(entry); + } else { + callee = "TFREE<" + *pipeTok + ", " + *splitTok + ">"; + } + rewriter.replaceOpWithNewOp( + op, TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, operands); + return success(); + } + + PTOArch targetArch; +}; + 
+//===----------------------------------------------------------------------===// +// populate patterns +//===----------------------------------------------------------------------=== +struct ReinterpretCastToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + auto resMrTy = dyn_cast(op.getType()); + if (!resMrTy) + return failure(); + + auto asAttr = dyn_cast_or_null(resMrTy.getMemorySpace()); + const bool isGm = (!asAttr || asAttr.getAddressSpace() == pto::AddressSpace::GM); + + bool emitAddPtrTrace = op->hasAttr("pto.addptr_trace"); + Value source = peelUnrealized(adaptor.getSource()); + auto offsets = adaptor.getOffsets(); + Value offsetVal = offsets.empty() ? Value() : offsets[0]; + + // GM: keep pointer arithmetic. + if (isGm) { + if (!offsetVal) { + rewriter.replaceOp(op, source); + return success(); + } + + Type resultType = getTypeConverter()->convertType(op.getType()); + if (!resultType) + return failure(); + + auto addOp = rewriter.create(loc, resultType, source, offsetVal); + if (emitAddPtrTrace) { + rewriter.setInsertionPointAfter(addOp); + rewriter.create( + loc, TypeRange{}, "PTOAS__ADDPTR_TRACE", + ArrayAttr{}, ArrayAttr{}, + ValueRange{addOp.getResult(), source, offsetVal}); + } + rewriter.replaceOp(op, addOp.getResult()); + return success(); + } + + // UB/L1/L0 tiles: materialize a new Tile view by assigning an adjusted + // underlying pointer (in elements). + pto::AddressSpace as = asAttr.getAddressSpace(); + + // Element type token. + Type elemTy = resMrTy.getElementType(); + std::string elemTok = getEmitCScalarTypeToken(elemTy); + int64_t elemBytes = getEmitCScalarByteWidth(elemTy); + + // Tile role. 
+ const char *roleTok = "TileType::Vec"; + switch (as) { + case pto::AddressSpace::VEC: + roleTok = "TileType::Vec"; + break; + case pto::AddressSpace::MAT: + roleTok = "TileType::Mat"; + break; + case pto::AddressSpace::LEFT: + roleTok = "TileType::Left"; + break; + case pto::AddressSpace::RIGHT: + roleTok = "TileType::Right"; + break; + case pto::AddressSpace::ACC: + roleTok = "TileType::Acc"; + break; + case pto::AddressSpace::BIAS: + roleTok = "TileType::Bias"; + break; + case pto::AddressSpace::GM: + roleTok = "TileType::Vec"; + break; + case pto::AddressSpace::Zero: + roleTok = "TileType::Vec"; + break; + case pto::AddressSpace::SCALING: + roleTok = "TileType::Scaling"; + break; + } + + // Shape (fallback to 32x32). + int64_t rows = 32, cols = 32; + if (resMrTy.getRank() >= 2 && resMrTy.hasStaticShape()) { + rows = resMrTy.getDimSize(0); + cols = resMrTy.getDimSize(1); + } + int64_t templateRows = + renderTileTemplateDim(rows, elemTy, pto::BLayout::RowMajor, 0); + int64_t templateCols = + renderTileTemplateDim(cols, elemTy, pto::BLayout::RowMajor, 1); + + // Keep a conservative default config for now. + std::string tileTypeStr = + std::string("Tile<") + roleTok + ", " + elemTok + ", " + + std::to_string(templateRows) + ", " + std::to_string(templateCols) + + ", BLayout::RowMajor, " + std::to_string(templateRows) + ", " + + std::to_string(templateCols) + + ", SLayout::NoneBox, 512, PadValue::Null, CompactMode::Null>"; + + auto tileType = emitc::OpaqueType::get(ctx, tileTypeStr); + Value tile = rewriter + .create(loc, tileType, + emitc::OpaqueAttr::get(ctx, "")) + .getResult(); + + // Compute an integer address and assign it to the new tile. + // NOTE: pto-isa TASSIGN requires an integral address (not a pointer). + auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); + auto rcU64 = rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); + + // Non-GM reinterpret_cast operands come from UB/L1/L0 tiles. 
+ // We need the underlying address, but `__cce_get_tile_ptr()` is only valid + // inside `__tf__` functions. Use `tile.data()` (via a post-processed marker) + // and compute the adjusted address in bytes. + Value rawPtr = source; + if (auto ot = dyn_cast(source.getType())) { + // Only Tiles have a `.data()` member. For plain address-space pointers + // (e.g. `__ubuf__ float*`), use the pointer value directly. + if (ot.getValue().starts_with("Tile<")) { + rawPtr = materializeTileDataValue(rewriter, loc, source, as, elemTok); + } + } + + Value baseAddr = rawPtr; + if (isSetFFTsPointerLikeType(rawPtr.getType())) { + baseAddr = rewriter + .create(loc, u64Ty, "reinterpret_cast", + /*args=*/ArrayAttr{}, + /*templateArgs=*/rcU64, + /*operands=*/ValueRange{rawPtr}) + .getResult(0); + } else if (rawPtr.getType() != u64Ty) { + baseAddr = rewriter.create(loc, u64Ty, rawPtr).getResult(); + } + + Value addr = baseAddr; + if (offsetVal) { + Value offU64 = offsetVal; + if (offU64.getType() != u64Ty) + offU64 = rewriter.create(loc, u64Ty, offU64).getResult(); + + auto bytesAttr = emitc::OpaqueAttr::get(ctx, std::to_string(elemBytes)); + Value bytesVal = rewriter.create(loc, u64Ty, bytesAttr); + Value byteOff = rewriter.create(loc, u64Ty, offU64, bytesVal); + addr = rewriter.create(loc, u64Ty, baseAddr, byteOff); + } + + rewriter.create(loc, TypeRange{}, "TASSIGN", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{tile, addr}); + + rewriter.replaceOp(op, tile); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.taddc lowering -> TADDC(dst, src0, src1, src2) +//===----------------------------------------------------------------------===// + +struct PTOTAddCToTADDC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TAddCOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = 
op.getLoc(); + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value src2 = peelUnrealized(adaptor.getSrc2()); + Value dst = peelUnrealized(adaptor.getDst()); + + // pto-isa does not provide NPU implementation for TADDC yet. + // Decompose: dst = src0 + src1 + src2 + rewriter.create( + loc, TypeRange{}, "TADD", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src0, src1}); + rewriter.create( + loc, TypeRange{}, "TADD", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, dst, src2}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.tadds lowering -> TADDS(dst, src, scalar) +//===----------------------------------------------------------------------===// + +struct PTOAddSToTADDS : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TAddSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + Value scalar = peelUnrealized(adaptor.getScalar()); + + rewriter.create( + op.getLoc(), TypeRange{}, "TADDS", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src, scalar}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.taddsc lowering -> TADDSC(dst, src0, scalar, src1) +//===----------------------------------------------------------------------===// + +struct PTOAddSCToTADDSC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TAddSCOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value scalar = peelUnrealized(adaptor.getScalar()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value 
dst = peelUnrealized(adaptor.getDst()); + + // pto-isa does not provide NPU implementation for TADDSC yet. + // Decompose: dst = src0 + scalar + src1 + rewriter.create( + loc, TypeRange{}, "TADDS", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src0, scalar}); + rewriter.create( + loc, TypeRange{}, "TADD", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, dst, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; +struct PTOTAndToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TAndOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value a = peelUnrealized(adaptor.getSrc0()); + Value b = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + op.getLoc(), TypeRange{}, "TAND", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, a, b}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOConcatToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TConcatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + op.getLoc(), TypeRange{}, "TCONCAT", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src0, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; +struct PTOConcatidxToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TConcatidxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value src0Idx = peelUnrealized(adaptor.getSrc0Idx()); + Value src1Idx = peelUnrealized(adaptor.getSrc1Idx()); + Value dst = 
peelUnrealized(adaptor.getDst()); + + rewriter.create( + op.getLoc(), TypeRange{}, "TCONCAT", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src0, src1, src0Idx, src1Idx}); + + rewriter.eraseOp(op); + return success(); + } +}; +struct PTOAndSToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TAndSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src = peelUnrealized(adaptor.getSrc()); + Value scalar = peelUnrealized(adaptor.getScalar()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + op.getLoc(), TypeRange{}, "TANDS", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src, scalar}); + + rewriter.eraseOp(op); + return success(); + } +}; + + +struct PTOTCIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TCIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value dst = peelUnrealized(adaptor.getDst()); + Value S = peelUnrealized(adaptor.getOperands()[0]); + + // The TCI scalar template parameter should follow the original PTO IR + // scalar type, not the converted EmitC value type. + std::string scalarTok = "int32_t"; + if (auto it = dyn_cast(op->getOperand(0).getType())) { + bool isUnsigned = it.isUnsigned(); + if (it.getWidth() == 16) + scalarTok = isUnsigned ? "uint16_t" : "int16_t"; + else + scalarTok = isUnsigned ? "uint32_t" : "int32_t"; + } + + // descending -> "0"/"1" + std::string descTok = op.getDescending() ? 
"1" : "0"; + + ArrayAttr targs; + if (auto ot = mlir::dyn_cast(dst.getType())) { + std::string tileTok = ot.getValue().str(); // "Tile<...>" + targs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, tileTok), + emitc::OpaqueAttr::get(ctx, scalarTok), + emitc::OpaqueAttr::get(ctx, descTok), + }); + } else { + targs = rewriter.getArrayAttr({}); + } + + rewriter.create( + loc, TypeRange{}, "TCI", + /*args=*/ArrayAttr{}, + /*templateArgs=*/targs, + /*operands=*/ValueRange{dst, S}); + + rewriter.eraseOp(op); + return success(); + } +}; +static std::string cmpModeTok(pto::CmpModeAttr a) { + // 生成 "CmpMode::GT" 这种 token + auto m = a.getValue(); // 取 enum + switch (m) { + case pto::CmpMode::EQ: return "CmpMode::EQ"; + case pto::CmpMode::NE: return "CmpMode::NE"; + case pto::CmpMode::LT: return "CmpMode::LT"; + case pto::CmpMode::LE: return "CmpMode::LE"; + case pto::CmpMode::GT: return "CmpMode::GT"; + case pto::CmpMode::GE: return "CmpMode::GE"; + } + return "CmpMode::EQ"; +} +struct PTOColExpandToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColExpandOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value dst = peelUnrealized(adaptor.getDst()); + Value src = peelUnrealized(adaptor.getSrc()); + + rewriter.create( + loc, TypeRange{}, "TCOLEXPAND", + /*args=*/ArrayAttr(), + /*templateArgs=*/ArrayAttr(), + /*operands=*/ValueRange{dst, src}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColExpandMulToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColExpandMulOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + 
rewriter.create( + loc, TypeRange{}, "TCOLEXPANDMUL", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColExpandAddToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColExpandAddOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TCOLEXPANDADD", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColExpandDivToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColExpandDivOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TCOLEXPANDDIV", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColExpandExpdifToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColExpandExpdifOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TCOLEXPANDEXPDIF", + /*args=*/ArrayAttr{}, 
+ /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColExpandSubToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColExpandSubOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TCOLEXPANDSUB", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColExpandMaxToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColExpandMaxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TCOLEXPANDMAX", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColExpandMinToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColExpandMinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TCOLEXPANDMIN", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, src1}); + + 
rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOTTriToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TTriOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value dst = peelUnrealized(adaptor.getDst()); + Value diagonal = peelUnrealized(adaptor.getDiagonal()); + + ArrayAttr templateArgs; + if (auto dstOT = mlir::dyn_cast(dst.getType())) { + templateArgs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), + emitc::OpaqueAttr::get(ctx, std::to_string(op.getUpperOrLower())), + }); + } else { + templateArgs = ArrayAttr{}; + } + + SmallVector operands{dst, diagonal}; + rewriter.create( + loc, TypeRange{}, "TTRI", + /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOCmpToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TCmpOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value dst = peelUnrealized(adaptor.getDst()); + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + + std::string tok = "CmpMode::EQ"; + if (auto a = op.getCmpModeAttr()) + tok = cmpModeTok(a); + + auto modeTy = emitc::OpaqueType::get(ctx, "CmpMode"); + Value modeVal = rewriter.create( + loc, modeTy, emitc::OpaqueAttr::get(ctx, tok)); + + rewriter.create( + loc, + TypeRange{}, + "TCMP", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, src1, modeVal}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOCmpSToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult 
matchAndRewrite(pto::TCmpSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value dst = peelUnrealized(adaptor.getDst()); + Value src = peelUnrealized(adaptor.getSrc()); + Value scalar = peelUnrealized(adaptor.getScalar()); + + // cmpMode -> token + auto cmpAttr = op.getCmpModeAttr(); // PTO_CmpModeAttr + std::string tok = cmpModeTok(cmpAttr); + + auto modeTy = emitc::OpaqueType::get(ctx, "CmpMode"); + Value modeVal = rewriter.create( + loc, modeTy, emitc::OpaqueAttr::get(ctx, tok)); + + rewriter.create( + loc, + TypeRange{}, + "TCMPS", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src, scalar, modeVal}); + + rewriter.eraseOp(op); + return success(); + } +}; + + +struct PTOColMaxToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColMaxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + // intrinsic: TCOLMAX(dst, src) + rewriter.create( + loc, TypeRange{}, "TCOLMAX", + /*args=*/ArrayAttr{}, // default: print all operands + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColArgMaxToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColArgMaxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TCOLARGMAX", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src, 
tmp}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColMinToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColMinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + // intrinsic: TCOLMIN(dst, src) + rewriter.create( + loc, TypeRange{}, "TCOLMIN", + /*args=*/ArrayAttr{}, // default: print all operands + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColArgMinToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColArgMinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TCOLARGMIN", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src, tmp}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColSumToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColSumOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + // Check if tmp exists before accessing it + if (op.getTmp()) { + // Format 2: with tmp and isBinary + Value tmp = peelUnrealized(adaptor.getTmp()); + bool isBinary = false; + if (auto a = op.getIsBinaryAttr()) + isBinary = a.getValue(); + + auto boolTy = emitc::OpaqueType::get(ctx, 
"bool"); + auto tok = isBinary ? "true" : "false"; + Value isBinaryVal = rewriter.create( + loc, boolTy, emitc::OpaqueAttr::get(ctx, tok)); + + rewriter.create( + loc, TypeRange{}, "TCOLSUM", + /*args=*/ArrayAttr(), + /*templateArgs=*/ArrayAttr(), + /*operands=*/ValueRange{dst, src, tmp, isBinaryVal}); + } else { + // Format 1: without tmp and isBinary + rewriter.create( + loc, TypeRange{}, "TCOLSUM", + /*args=*/ArrayAttr(), + /*templateArgs=*/ArrayAttr(), + /*operands=*/ValueRange{dst, src}); + } + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColProdToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColProdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TCOLPROD", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src}); + + rewriter.eraseOp(op); + return success(); + } +}; +static std::string roundModeTok(mlir::pto::RoundModeAttr attr) { + using RM = mlir::pto::RoundMode; + switch (attr.getValue()) { + case RM::NONE: return "RoundMode::CAST_NONE"; + case RM::RINT: return "RoundMode::CAST_RINT"; + case RM::ROUND: return "RoundMode::CAST_ROUND"; + case RM::FLOOR: return "RoundMode::CAST_FLOOR"; + case RM::CEIL: return "RoundMode::CAST_CEIL"; + case RM::TRUNC: return "RoundMode::CAST_TRUNC"; + case RM::ODD: return "RoundMode::CAST_ODD"; + case RM::CAST_RINT: return "RoundMode::CAST_RINT"; + } + return "RoundMode::CAST_RINT"; +} +static std::string saturationModeTok(mlir::pto::SaturationModeAttr attr) { + using SM = mlir::pto::SaturationMode; + switch (attr.getValue()) { + case SM::ON: return "SaturationMode::ON"; + case SM::OFF: return "SaturationMode::OFF"; + } + return "SaturationMode::OFF"; +} +struct PTOCvtToEmitC : public 
OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TCvtOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + pto::RoundModeAttr rmAttr = op.getRmodeAttr(); + std::string rmTok = rmAttr ? roundModeTok(rmAttr) + : std::string("RoundMode::CAST_RINT"); + auto rmodeTy = emitc::OpaqueType::get(ctx, "RoundMode"); + Value rmodeVal = rewriter.create( + loc, rmodeTy, emitc::OpaqueAttr::get(ctx, rmTok)); + + auto satModeTy = emitc::OpaqueType::get(ctx, "SaturationMode"); + auto satAttr = op.getSatModeAttr(); + std::string satTok = satAttr ? saturationModeTok(satAttr) + : std::string("SaturationMode::OFF"); + Value satModeVal = rewriter.create( + loc, satModeTy, emitc::OpaqueAttr::get(ctx, satTok)); + + SmallVector operands{dst, src, rmodeVal, satModeVal}; + + rewriter.create( + loc, TypeRange{}, "TCVT", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +struct PTORandomToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRandomOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value dst = peelUnrealized(adaptor.getDst()); + SmallVector operands{ + dst, + peelUnrealized(adaptor.getKey0()), + peelUnrealized(adaptor.getKey1()), + peelUnrealized(adaptor.getCounter0()), + peelUnrealized(adaptor.getCounter1()), + peelUnrealized(adaptor.getCounter2()), + peelUnrealized(adaptor.getCounter3()), + }; + ArrayAttr templateArgs = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, std::to_string(op.getRounds()))}); + + rewriter.create( + loc, TypeRange{}, "PTOAS__TRANDOM", + 
/*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, operands); + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.tdiv lowering -> TDIV(dst, src0, src1) +//===----------------------------------------------------------------------===// + +struct PTODivToTDIV : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TDivOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + op.getLoc(), TypeRange{}, "TDIV", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src0, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.tdivs lowering -> TDIVS(dst, src, scalar) or TDIVS(dst, scalar, src) +// Order is determined by operand types: if src is tile_buf, order is (tile, scalar) +// Otherwise, order is (scalar, tile) +//===----------------------------------------------------------------------===// + +struct PTODivSToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TDivSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value scalar = peelUnrealized(adaptor.getScalar()); + Value dst = peelUnrealized(adaptor.getDst()); + // Preserve source order from textual parse: + // ins(tile, scalar) -> TDIVS(dst, tile, scalar) + // ins(scalar, tile) -> TDIVS(dst, scalar, tile) + rewriter.create( + loc, TypeRange{}, "TDIVS", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src, scalar}); + + rewriter.eraseOp(op); + return success(); + } +}; + 
+//===----------------------------------------------------------------------===// +// pto.tdivs (TDivSOp) lowering -> TDIVS(dst, src, scalar) or TDIVS(dst, scalar, src) +// Order is determined by operand types: if src is tile_buf, order is (tile, scalar) +// Otherwise, order is (scalar, tile) +//===----------------------------------------------------------------------===// + +struct PTOTDivSToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TDivSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value scalar = peelUnrealized(adaptor.getScalar()); + Value dst = peelUnrealized(adaptor.getDst()); + rewriter.create( + loc, TypeRange{}, "TDIVS", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src, scalar}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.texp lowering -> TEXP(dst, src) +//===----------------------------------------------------------------------===// + +struct PTOExpToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TExpOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TEXP", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.texpands lowering -> TEXPANDS(dst, scalar) +//===----------------------------------------------------------------------===// + +struct PTOExpandsToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + 
LogicalResult matchAndRewrite(pto::TExpandsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value scalar = peelUnrealized(adaptor.getScalar()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TEXPANDS", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, scalar}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.textract lowering -> TEXTRACT(dst, src, indexRow, indexCol) +//===----------------------------------------------------------------------===// + +struct PTOExtractToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TExtractOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + Value r0 = peelUnrealized(adaptor.getIndexRow()); + Value c0 = peelUnrealized(adaptor.getIndexCol()); + + rewriter.create( + loc, TypeRange{}, "TEXTRACT", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src, r0, c0}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.textract_fp lowering -> TEXTRACT_FP(dst, src, fp, indexRow, indexCol) +//===----------------------------------------------------------------------===// + +struct PTOExtractFPToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TExtractFPOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value fp = peelUnrealized(adaptor.getFp()); + Value dst = peelUnrealized(adaptor.getDst()); + Value r0 = 
peelUnrealized(adaptor.getIndexRow()); + Value c0 = peelUnrealized(adaptor.getIndexCol()); + + rewriter.create( + loc, TypeRange{}, "TEXTRACT_FP", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src, fp, r0, c0}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.tinsert lowering -> TINSERT(dst, src, indexRow, indexCol) +// Keep lowering arch-agnostic and let PTO-ISA infer proper A5 path. +//===----------------------------------------------------------------------===// + +struct PTOInsertToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TInsertOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + Value r0 = peelUnrealized(adaptor.getIndexRow()); + Value c0 = peelUnrealized(adaptor.getIndexCol()); + + rewriter.create( + loc, TypeRange{}, "TINSERT", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src, r0, c0}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.tinsert_fp lowering -> TINSERT_FP(dst, src, fp, indexRow, indexCol) +//===----------------------------------------------------------------------===// + +struct PTOInsertFPToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TInsertFPOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value fp = peelUnrealized(adaptor.getFp()); + Value dst = peelUnrealized(adaptor.getDst()); + Value r0 = peelUnrealized(adaptor.getIndexRow()); + Value c0 = 
peelUnrealized(adaptor.getIndexCol()); + + rewriter.create( + loc, TypeRange{}, "TINSERT_FP", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src, fp, r0, c0}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.tfillpad lowering -> TFILLPAD(dst, src) +//===----------------------------------------------------------------------===// + +struct PTOFillPadToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TFillPadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TFILLPAD", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.tfillpad_inplace lowering -> TFILLPAD_INPLACE(dst, src) +//===----------------------------------------------------------------------===// + +struct PTOFillPadInplaceToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TFillPadInplaceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TFILLPAD_INPLACE", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.tfillpad_expand lowering -> TFILLPAD_EXPAND(dst, src) 
+//===----------------------------------------------------------------------===// + +struct PTOFillPadExpandToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TFillPadExpandOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TFILLPAD_EXPAND", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.tgather lowering +// - Index form : TGATHER(dst, src0, indices, tmp) +// - Compare form: TGATHER(dst, src0, kValue, cdst, tmp) +// - Mask form : TGATHER(dst, src0) +//===----------------------------------------------------------------------===// + +[[maybe_unused]] static std::string maskPatternTok(mlir::pto::MaskPatternAttr a) { + + auto v = a.getValue(); // enum + return (std::string("pto::MaskPattern::") + mlir::pto::stringifyMaskPattern(v).str()); +} + +struct PTOGatherToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TGatherOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value dst = peelUnrealized(adaptor.getDst()); + Value src0 = peelUnrealized(adaptor.getSrc()); + + auto getOpaqueTok = [&](Value v, StringRef name) -> FailureOr { + if (auto ot = mlir::dyn_cast(v.getType())) + return ot.getValue().str(); + return rewriter.notifyMatchFailure(op, (name + " must be emitc::OpaqueType (tile)").str()); + }; + + // Case 1: index-based TGATHER(dst, src0, indices, tmp) + if (Value idx = adaptor.getIndices()) { + idx = peelUnrealized(idx); + Value tmp = 
peelUnrealized(adaptor.getTmp()); + + rewriter.create( + loc, TypeRange{}, "TGATHER", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, idx, tmp}); + + rewriter.eraseOp(op); + return success(); + } + + // Case 2: compare-based TGATHER( + // dst, src0, kValue, tmp, cdst, offset) + if (Value cdst = adaptor.getCdst()) { + cdst = peelUnrealized(cdst); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value kValue = peelUnrealized(adaptor.getKValue()); + + auto dstTokOr = getOpaqueTok(dst, "dst"); + auto srcTokOr = getOpaqueTok(src0, "src0"); + auto cdstTokOr = getOpaqueTok(cdst, "cdst"); + auto tmpTokOr = getOpaqueTok(tmp, "tmp"); + if (failed(dstTokOr) || failed(srcTokOr) || failed(cdstTokOr) || failed(tmpTokOr)) + return failure(); + + auto cmpAttr = op.getCmpModeAttr(); + std::string cmpTok = cmpAttr ? cmpModeTok(cmpAttr) : "CmpMode::EQ"; + int64_t offset = 0; + if (auto offsetAttr = op.getOffsetAttr()) + offset = offsetAttr.getInt(); + auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); + Value offsetVal = makeEmitCIntConstant(rewriter, loc, i32Ty, offset); + + auto targs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, *dstTokOr), + emitc::OpaqueAttr::get(ctx, *srcTokOr), + emitc::OpaqueAttr::get(ctx, *tmpTokOr), + emitc::OpaqueAttr::get(ctx, *cdstTokOr), + emitc::OpaqueAttr::get(ctx, cmpTok), + }); + + rewriter.create( + loc, TypeRange{}, "TGATHER", + /*args=*/ArrayAttr{}, /*templateArgs=*/targs, + /*operands=*/ValueRange{dst, src0, kValue, tmp, cdst, offsetVal}); + + rewriter.eraseOp(op); + return success(); + } + + // Case 3: mask-pattern TGATHER(dst, src0) + auto mp = op.getMaskPatternAttr(); + if (!mp) + return rewriter.notifyMatchFailure(op, "expected maskPattern, indices, or cdst on tgather"); + + auto dstTokOr = getOpaqueTok(dst, "dst"); + auto srcTokOr = getOpaqueTok(src0, "src0"); + if (failed(dstTokOr) || failed(srcTokOr)) + return failure(); + + // mp is an EnumAttr; stringify name is "P0101" etc. 
+ // We emit MaskPattern::P0101 (because generated C++ has `using namespace pto;`) + std::string mpTok = std::string("MaskPattern::") + + mlir::pto::stringifyMaskPattern(mp.getValue()).str(); + + auto targs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, *dstTokOr), + emitc::OpaqueAttr::get(ctx, *srcTokOr), + emitc::OpaqueAttr::get(ctx, mpTok), + }); + + rewriter.create( + loc, TypeRange{}, "TGATHER", + /*args=*/ArrayAttr{}, + /*templateArgs=*/targs, + /*operands=*/ValueRange{dst, src0}); + + rewriter.eraseOp(op); + return success(); + } +}; + + +struct PTOGatherbToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TGatherBOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value offsets = peelUnrealized(adaptor.getOffsets()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TGATHERB", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src, offsets}); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// TLOG lowering to EmitC (PTOConvert.cpp) +//===----------------------------------------------------------------------===// + +struct PTOLogToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TLogOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src}; + rewriter.create( + loc, TypeRange{}, "TLOG", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + + + 
+//===----------------------------------------------------------------------===// +// TLRELU lowering to EmitC (PTOConvert.cpp) +//===----------------------------------------------------------------------===// + + struct PTOLReluToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TLReluOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value slope = peelUnrealized(adaptor.getSlope()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src, slope}; + + rewriter.create( + loc, TypeRange{}, "TLRELU", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// TMAX lowering to EmitC (PTOConvert.cpp) +//===----------------------------------------------------------------------===// + +struct PTOMaxToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMaxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src0, src1}; + rewriter.create( + loc, TypeRange{}, "TMAX", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// TMAXS lowering to EmitC (PTOConvert.cpp) +//===----------------------------------------------------------------------===// + + struct PTOMaxSToEmitC : public OpConversionPattern { + using 
OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMaxSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc()); + Value scalar = peelUnrealized(adaptor.getScalar()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src0, scalar}; + rewriter.create( + loc, TypeRange{}, "TMAXS", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + + +//===----------------------------------------------------------------------===// +// TMIN lowering to EmitC (PTOConvert.cpp) +//===----------------------------------------------------------------------===// + +struct PTOMinToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src0, src1}; + rewriter.create( + loc, TypeRange{}, "TMIN", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// TMINS lowering to EmitC (PTOConvert.cpp) +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// TMINS lowering to EmitC (fix APFloat -> FloatAttr) (PTOConvert.cpp) +//===----------------------------------------------------------------------===// + +struct PTOMinsToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult 
matchAndRewrite(pto::TMinSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + Value scalar = peelUnrealized(adaptor.getScalar()); + + SmallVector operands{dst, src, scalar}; + rewriter.create( + loc, TypeRange{}, "TMINS", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering for TMOV op -> EmitC) +//===----------------------------------------------------------------------===// + +struct PTOMovToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMovOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + Value fp; + if (op.getFp()) + fp = peelUnrealized(adaptor.getFp()); + Value preQuantScalar; + if (op.getPreQuantScalar()) + preQuantScalar = peelUnrealized(adaptor.getPreQuantScalar()); + + auto dstOT = mlir::dyn_cast(dst.getType()); + auto srcOT = mlir::dyn_cast(src.getType()); + if (!dstOT || !srcOT) + return rewriter.notifyMatchFailure( + op, "tmov lowering expects opaque dst/src types"); + + auto modeTok = [&](pto::AccToVecMode mode) -> StringRef { + switch (mode) { + case pto::AccToVecMode::SingleModeVec0: + return "pto::AccToVecMode::SingleModeVec0"; + case pto::AccToVecMode::SingleModeVec1: + return "pto::AccToVecMode::SingleModeVec1"; + case pto::AccToVecMode::DualModeSplitM: + return "pto::AccToVecMode::DualModeSplitM"; + case pto::AccToVecMode::DualModeSplitN: + return "pto::AccToVecMode::DualModeSplitN"; + } + llvm_unreachable("unknown AccToVecMode"); 
+ }; + + auto modeAttr = op.getAccToVecModeAttr(); + auto reluTok = [&](pto::ReluPreMode mode) -> StringRef { + switch (mode) { + case pto::ReluPreMode::NoRelu: + return "ReluPreMode::NoRelu"; + case pto::ReluPreMode::NormalRelu: + return "ReluPreMode::NormalRelu"; + } + llvm_unreachable("unknown ReluPreMode"); + }; + + const bool hasFp = static_cast(fp); + const bool hasPreQuantScalar = static_cast(preQuantScalar); + const bool hasMode = static_cast(modeAttr); + const bool reluNonDefault = op.getReluPreMode() != pto::ReluPreMode::NoRelu; + + SmallVector operands{dst, src}; + SmallVector templateArgVec{ + emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), + emitc::OpaqueAttr::get(ctx, srcOT.getValue().str()), + }; + StringRef callee = "TMOV"; + + if (hasFp) { + auto fpOT = mlir::dyn_cast(fp.getType()); + if (!fpOT) + return rewriter.notifyMatchFailure( + op, "tmov fp lowering expects opaque fp type"); + operands.push_back(fp); + templateArgVec.push_back(emitc::OpaqueAttr::get(ctx, fpOT.getValue().str())); + if (hasMode) + templateArgVec.push_back( + emitc::OpaqueAttr::get(ctx, modeTok(modeAttr.getValue()))); + if (hasMode || reluNonDefault) + templateArgVec.push_back( + emitc::OpaqueAttr::get(ctx, reluTok(op.getReluPreMode()))); + callee = hasMode ? 
"TMOV" : "TMOV_FP"; + } else if (hasPreQuantScalar) { + operands.push_back(preQuantScalar); + if (hasMode) + templateArgVec.push_back( + emitc::OpaqueAttr::get(ctx, modeTok(modeAttr.getValue()))); + if (hasMode || reluNonDefault) + templateArgVec.push_back( + emitc::OpaqueAttr::get(ctx, reluTok(op.getReluPreMode()))); + } else if (hasMode) { + templateArgVec.push_back( + emitc::OpaqueAttr::get(ctx, modeTok(modeAttr.getValue()))); + templateArgVec.push_back( + emitc::OpaqueAttr::get(ctx, reluTok(op.getReluPreMode()))); + } else if (reluNonDefault) { + templateArgVec.push_back( + emitc::OpaqueAttr::get(ctx, reluTok(op.getReluPreMode()))); + } + + ArrayAttr templateArgs = + templateArgVec.size() == 2 && !hasFp && !hasPreQuantScalar && + !hasMode && !reluNonDefault + ? ArrayAttr{} + : rewriter.getArrayAttr(templateArgVec); + + rewriter.create( + loc, TypeRange{}, callee, + /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TMOV_FP DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOMovFPToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMovFPOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value dst = peelUnrealized(adaptor.getDst()); + Value src = peelUnrealized(adaptor.getSrc()); + Value fp = peelUnrealized(adaptor.getFp()); + + // TMOV_FP(dstTileData, cTile, fbTile) + ArrayAttr templateArgs; + auto dstOT = mlir::dyn_cast(dst.getType()); + auto srcOT = mlir::dyn_cast(src.getType()); + auto fpOT = mlir::dyn_cast(fp.getType()); + if (dstOT && srcOT && fpOT) { + templateArgs = rewriter.getArrayAttr({ + 
emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), + emitc::OpaqueAttr::get(ctx, srcOT.getValue().str()), + emitc::OpaqueAttr::get(ctx, fpOT.getValue().str()), + }); + } else { + templateArgs = ArrayAttr{}; + } + + SmallVector operands{dst, src, fp}; + rewriter.create( + loc, TypeRange{}, "TMOV_FP", + /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOQuantToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TQuantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value dst = peelUnrealized(adaptor.getDst()); + Value src = peelUnrealized(adaptor.getSrc()); + Value fp = peelUnrealized(adaptor.getFp()); + + // Optional offset (INT8_ASYM only): passed as pointer (&offset) + Value offsetPtr; + if (op.getOffset()) { + Value offset = peelUnrealized(adaptor.getOffset()); + auto offsetOT = mlir::dyn_cast(offset.getType()); + if (offsetOT) { + offsetPtr = rewriter + .create( + loc, emitc::PointerType::get(offsetOT), "&", offset) + .getResult(); + } + } + + // TQUANT(dst, src, fp[, &offset]) + std::string quantTypeStr = + op.getQuantType() == pto::QuantType::INT8_SYM + ? 
"pto::QuantType::INT8_SYM" + : "pto::QuantType::INT8_ASYM"; + ArrayAttr templateArgs; + auto dstOT = mlir::dyn_cast(dst.getType()); + auto srcOT = mlir::dyn_cast(src.getType()); + auto fpOT = mlir::dyn_cast(fp.getType()); + if (dstOT && srcOT && fpOT) { + templateArgs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, quantTypeStr), + emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), + emitc::OpaqueAttr::get(ctx, srcOT.getValue().str()), + emitc::OpaqueAttr::get(ctx, fpOT.getValue().str()), + }); + } else { + templateArgs = ArrayAttr{}; + } + + SmallVector operands{dst, src, fp}; + if (offsetPtr) + operands.push_back(offsetPtr); + + rewriter.create( + loc, TypeRange{}, "TQUANT", + /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +struct PTODequantToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TDequantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value dst = peelUnrealized(adaptor.getDst()); + Value src = peelUnrealized(adaptor.getSrc()); + Value scale = peelUnrealized(adaptor.getScale()); + Value offset = peelUnrealized(adaptor.getOffset()); + + // TDEQUANT(dst, src, scale, offset) + ArrayAttr templateArgs; + auto dstOT = mlir::dyn_cast(dst.getType()); + auto srcOT = mlir::dyn_cast(src.getType()); + auto scaleOT = mlir::dyn_cast(scale.getType()); + if (dstOT && srcOT && scaleOT) { + templateArgs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), + emitc::OpaqueAttr::get(ctx, srcOT.getValue().str()), + emitc::OpaqueAttr::get(ctx, scaleOT.getValue().str()), + }); + } else { + templateArgs = ArrayAttr{}; + } + + rewriter.create( + loc, TypeRange{}, "TDEQUANT", + /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, + /*operands=*/SmallVector{dst, src, scale, offset}); 
+ + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TMRGSORT DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOMrgSortToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMrgSortOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + if (op.isFormat1()) { + Value src = peelUnrealized(adaptor.getSrcs().front()); + Value dst = peelUnrealized(adaptor.getDsts().front()); + Value blockLen = peelUnrealized(adaptor.getBlockLen()); + + SmallVector operands{dst, src, blockLen}; + rewriter.create( + loc, TypeRange{}, "TMRGSORT", + ArrayAttr{}, ArrayAttr{}, operands); + } else if (op.isFormat2()) { + // pto-isa API: + // TMRGSORT( + // dst, executedNumList, tmp, src0, src1[, src2[, src3]]); + auto *ctx = rewriter.getContext(); + + Value dst = peelUnrealized(adaptor.getDsts()[0]); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value excuted = peelUnrealized(adaptor.getExcuted()); + + SmallVector srcs; + srcs.reserve(adaptor.getSrcs().size()); + for (Value v : adaptor.getSrcs()) + srcs.push_back(peelUnrealized(v)); + + auto dstOT = mlir::dyn_cast(dst.getType()); + auto tmpOT = mlir::dyn_cast(tmp.getType()); + if (!dstOT || !tmpOT || srcs.size() < 2 || srcs.size() > 4) + return op.emitOpError("format2 expects dst/tmp tilebufs and 2 to 4 srcs"); + + SmallVector targs; + targs.reserve(2 + srcs.size() + 1); + targs.push_back(emitc::OpaqueAttr::get(ctx, dstOT.getValue().str())); + targs.push_back(emitc::OpaqueAttr::get(ctx, tmpOT.getValue().str())); + for (Value v : srcs) { + auto ot = mlir::dyn_cast(v.getType()); + if (!ot) + return op.emitOpError("format2 expects tilebuf srcs"); + targs.push_back(emitc::OpaqueAttr::get(ctx, ot.getValue().str())); + } + 
targs.push_back(emitc::OpaqueAttr::get(ctx, op.getExhausted() ? "true" : "false")); + ArrayAttr templateArgs = rewriter.getArrayAttr(targs); + + SmallVector operands{dst, excuted, tmp}; + operands.append(srcs.begin(), srcs.end()); + + rewriter.create( + loc, TypeRange{}, "TMRGSORT", + /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, operands); + } else { + return op.emitOpError("unsupported mrgsort_dps format"); + } + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TMUL DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOMulToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMulOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src0, src1}; + rewriter.create( + loc, TypeRange{}, "TMUL", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TMULS DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOMulsToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMulSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc0()); + Value dst = peelUnrealized(adaptor.getDst()); + Value scalar = peelUnrealized(adaptor.getScalar()); + + 
SmallVector operands{dst, src, scalar}; + rewriter.create( + loc, TypeRange{}, "TMULS", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TNEG DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTONegToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TNegOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src}; + rewriter.create( + loc, TypeRange{}, "TNEG", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TNOT DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTONotToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TNotOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src}; + rewriter.create( + loc, TypeRange{}, "TNOT", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TOR DPS/memref op) 
+//===----------------------------------------------------------------------===//
+
+// NOTE(review): the `<...>` template arguments in this section were
+// reconstructed. The OpConversionPattern source op comes from each
+// matchAndRewrite signature; the created op is assumed to be
+// emitc::CallOpaqueOp -- confirm against the repository before applying.
+
+/// Lowers pto::TOrOp to the opaque call `TOR(dst, src0, src1)` and erases it.
+struct PTOOrToEmitC : public OpConversionPattern<pto::TOrOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TOrOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src0 = peelUnrealized(adaptor.getSrc0());
+    Value src1 = peelUnrealized(adaptor.getSrc1());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    SmallVector<Value> operands{dst, src0, src1};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TOR",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+//===----------------------------------------------------------------------===//
+// PTOConvert.cpp (add lowering + patterns.add for TORS DPS/memref op)
+//===----------------------------------------------------------------------===//
+
+/// Lowers pto::TOrSOp to the opaque call `TORS(dst, src, scalar)` and erases it.
+struct PTOOrsToEmitC : public OpConversionPattern<pto::TOrSOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TOrSOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src0 = peelUnrealized(adaptor.getSrc());
+    Value dst = peelUnrealized(adaptor.getDst());
+    // NOTE: The conversion type system may materialize integers as emitc.opaque
+    // (e.g. "int32_t"). For EmitC call emission we can pass the scalar through
+    // directly without arith casts here. Unlike the other operands, the scalar
+    // is therefore forwarded without peelUnrealized.
+    Value s = adaptor.getScalar();
+
+    SmallVector<Value> operands{dst, src0, s};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TORS",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+//===----------------------------------------------------------------------===//
+// PTOConvert.cpp (add lowering + patterns.add for TPARTADD DPS/memref op)
+//===----------------------------------------------------------------------===//
+
+/// Lowers pto::TPartAddOp to `TPARTADD(dst, src0, src1)` and erases it.
+struct PTOPartAddToEmitC : public OpConversionPattern<pto::TPartAddOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TPartAddOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src0 = peelUnrealized(adaptor.getSrc0());
+    Value src1 = peelUnrealized(adaptor.getSrc1());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    SmallVector<Value> operands{dst, src0, src1};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TPARTADD",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+//===----------------------------------------------------------------------===//
+// PTOConvert.cpp (add lowering + patterns.add for TPARTMAX DPS/memref op)
+//===----------------------------------------------------------------------===//
+
+/// Lowers pto::TPartMaxOp to `TPARTMAX(dst, src0, src1)` and erases it.
+struct PTOPartMaxToEmitC : public OpConversionPattern<pto::TPartMaxOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TPartMaxOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src0 = peelUnrealized(adaptor.getSrc0());
+    Value src1 = peelUnrealized(adaptor.getSrc1());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    SmallVector<Value> operands{dst, src0, src1};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TPARTMAX",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// PTOConvert.cpp (add lowering + patterns.add for TPARTMIN DPS/memref op)
+//===----------------------------------------------------------------------===//
+
+/// Lowers pto::TPartMinOp to `TPARTMIN(dst, src0, src1)` and erases it.
+struct PTOPartMinToEmitC : public OpConversionPattern<pto::TPartMinOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TPartMinOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src0 = peelUnrealized(adaptor.getSrc0());
+    Value src1 = peelUnrealized(adaptor.getSrc1());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    SmallVector<Value> operands{dst, src0, src1};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TPARTMIN",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Lowers pto::TPartArgMaxOp to
+/// `TPARTARGMAX(dst, src0, src1, dstIdx, src0Idx, src1Idx)` and erases it.
+struct PTOPartArgMaxToEmitC
+    : public OpConversionPattern<pto::TPartArgMaxOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TPartArgMaxOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    Value src0 = peelUnrealized(adaptor.getSrc0());
+    Value src1 = peelUnrealized(adaptor.getSrc1());
+    Value src0Idx = peelUnrealized(adaptor.getSrc0Idx());
+    Value src1Idx = peelUnrealized(adaptor.getSrc1Idx());
+    Value dst = peelUnrealized(adaptor.getDst());
+    Value dstIdx = peelUnrealized(adaptor.getDstIdx());
+
+    rewriter.create<emitc::CallOpaqueOp>(
+        op.getLoc(), TypeRange{}, "TPARTARGMAX",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/ValueRange{dst, src0, src1, dstIdx, src0Idx, src1Idx});
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Lowers pto::TPartArgMinOp to
+/// `TPARTARGMIN(dst, src0, src1, dstIdx, src0Idx, src1Idx)` and erases it.
+struct PTOPartArgMinToEmitC
+    : public OpConversionPattern<pto::TPartArgMinOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TPartArgMinOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    Value src0 = peelUnrealized(adaptor.getSrc0());
+    Value src1 = peelUnrealized(adaptor.getSrc1());
+    Value src0Idx = peelUnrealized(adaptor.getSrc0Idx());
+    Value src1Idx = peelUnrealized(adaptor.getSrc1Idx());
+    Value dst = peelUnrealized(adaptor.getDst());
+    Value dstIdx = peelUnrealized(adaptor.getDstIdx());
+
+    rewriter.create<emitc::CallOpaqueOp>(
+        op.getLoc(), TypeRange{}, "TPARTARGMIN",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/ValueRange{dst, src0, src1, dstIdx, src0Idx, src1Idx});
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+//===----------------------------------------------------------------------===//
+// PTOConvert.cpp (add lowering + patterns.add for TPARTMUL DPS/memref op)
+//===----------------------------------------------------------------------===//
+
+/// Lowers pto::TPartMulOp to `TPARTMUL(dst, src0, src1)` and erases it.
+struct PTOPartMulToEmitC : public OpConversionPattern<pto::TPartMulOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TPartMulOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src0 = peelUnrealized(adaptor.getSrc0());
+    Value src1 = peelUnrealized(adaptor.getSrc1());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    SmallVector<Value> operands{dst, src0, src1};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TPARTMUL",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+//===----------------------------------------------------------------------===//
+// PTOConvert.cpp (add lowering + patterns.add for TPRELU DPS/memref op)
+//===----------------------------------------------------------------------===//
+
+/// Lowers pto::TPReluOp to `TPRELU(dst, src0, src1, tmp)` and erases it.
+struct PTOPreluToEmitC : public OpConversionPattern<pto::TPReluOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TPReluOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src0 = peelUnrealized(adaptor.getSrc0());
+    Value src1 = peelUnrealized(adaptor.getSrc1());
+    Value tmp = peelUnrealized(adaptor.getTmp());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    // C++ interface: TPRELU(dst, src0, src1, tmp) — last parameter is tmp.
+    SmallVector<Value> operands{dst, src0, src1, tmp};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TPRELU",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+//===----------------------------------------------------------------------===//
+// PTOConvert.cpp (add lowering + patterns.add for TRECIP DPS/memref op)
+//===----------------------------------------------------------------------===//
+
+/// Lowers pto::TRecipOp to `TRECIP(dst, src)` and erases it.
+struct PTORecipToEmitC : public OpConversionPattern<pto::TRecipOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TRecipOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src = peelUnrealized(adaptor.getSrc());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    SmallVector<Value> operands{dst, src};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TRECIP",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+//===----------------------------------------------------------------------===//
+// PTOConvert.cpp (add lowering + patterns.add for TRELU DPS/memref op)
+//===----------------------------------------------------------------------===//
+
+/// Lowers pto::TReluOp to `TRELU(dst, src)` and erases it.
+struct PTOReluToEmitC : public OpConversionPattern<pto::TReluOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TReluOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src = peelUnrealized(adaptor.getSrc());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    SmallVector<Value> operands{dst, src};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TRELU",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+//===----------------------------------------------------------------------===//
+// PTOConvert.cpp (add lowering + patterns.add for TREM DPS/memref op)
+//===----------------------------------------------------------------------===//
+
+/// Lowers pto::TRemOp to `TREM(dst, src0, src1, tmp)` and erases it.
+struct PTORemToEmitC : public OpConversionPattern<pto::TRemOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TRemOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src0 = peelUnrealized(adaptor.getSrc0());
+    Value src1 = peelUnrealized(adaptor.getSrc1());
+    Value tmp = peelUnrealized(adaptor.getTmp());
+    Value dst = peelUnrealized(adaptor.getDst());
+    SmallVector<Value> operands{dst, src0, src1, tmp};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TREM",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Lowers pto::TFModOp to `TFMOD(dst, src0, src1)` and erases it.
+struct PTOFModToEmitC : public OpConversionPattern<pto::TFModOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TFModOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src0 = peelUnrealized(adaptor.getSrc0());
+    Value src1 = peelUnrealized(adaptor.getSrc1());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    SmallVector<Value> operands{dst, src0, src1};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TFMOD",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+//===----------------------------------------------------------------------===//
+// PTOConvert.cpp (add lowering + patterns.add for TREMS DPS/memref op)
+//===----------------------------------------------------------------------===//
+
+/// Lowers pto::TRemSOp to `TREMS(dst, src, scalar, tmp)` and erases it.
+struct PTORemSToEmitC : public OpConversionPattern<pto::TRemSOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TRemSOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src = peelUnrealized(adaptor.getSrc());
+    Value tmp = peelUnrealized(adaptor.getTmp());
+    Value dst = peelUnrealized(adaptor.getDst());
+    Value scalar = peelUnrealized(adaptor.getScalar());
+    SmallVector<Value> operands{dst, src, scalar, tmp};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TREMS",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Lowers pto::TFModSOp to `TFMODS(dst, src, scalar)` and erases it.
+struct PTOFModSToEmitC : public OpConversionPattern<pto::TFModSOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TFModSOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src = peelUnrealized(adaptor.getSrc());
+    Value dst = peelUnrealized(adaptor.getDst());
+    Value scalar = peelUnrealized(adaptor.getScalar());
+
+    SmallVector<Value> operands{dst, src, scalar};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TFMODS",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// PTOConvert.cpp (add lowering + patterns.add for TROWEXPAND DPS/memref op)
+//===----------------------------------------------------------------------===//
+
+/// Lowers pto::TRowExpandOp to `TROWEXPAND(dst, src)` and erases it.
+struct PTORowExpandToEmitC : public OpConversionPattern<pto::TRowExpandOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TRowExpandOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src = peelUnrealized(adaptor.getSrc());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    SmallVector<Value> operands{dst, src};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TROWEXPAND",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Lowers pto::TRowExpandAddOp to `TROWEXPANDADD(dst, src0, src1)` and erases it.
+struct PTORowExpandAddToEmitC : public OpConversionPattern<pto::TRowExpandAddOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TRowExpandAddOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src0 = peelUnrealized(adaptor.getSrc0());
+    Value src1 = peelUnrealized(adaptor.getSrc1());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    SmallVector<Value> operands{dst, src0, src1};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TROWEXPANDADD",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Lowers pto::TRowExpandExpdifOp to `TROWEXPANDEXPDIF(dst, src0, src1[, tmp])`
+/// and erases it; tmp is passed only when the op carries a tmp operand.
+struct PTORowExpandExpdifToEmitC
+    : public OpConversionPattern<pto::TRowExpandExpdifOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TRowExpandExpdifOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src0 = peelUnrealized(adaptor.getSrc0());
+    Value src1 = peelUnrealized(adaptor.getSrc1());
+    Value dst = peelUnrealized(adaptor.getDst());
+    // A null Value marks "no tmp operand" and selects the 3-operand call.
+    Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value();
+
+    SmallVector<Value> operands;
+    if (tmp)
+      operands.assign({dst, src0, src1, tmp});
+    else
+      operands.assign({dst, src0, src1});
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TROWEXPANDEXPDIF",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// PTOConvert.cpp (add lowering + patterns.add for TROWEXPANDDIV DPS/memref op)
+//===----------------------------------------------------------------------===//
+// Helper: replace or erase based on whether op has results.
+static void replaceOrEraseWithOpaqueCall(Operation *op,
+                                         StringRef callee,
+                                         ArrayRef<Value> args,
+                                         ConversionPatternRewriter &rewriter) {
+  // NOTE(review): the `<...>` template arguments in this section were
+  // reconstructed. The OpConversionPattern source op comes from each
+  // matchAndRewrite signature; the created op is assumed to be
+  // emitc::CallOpaqueOp -- confirm against the repository before applying.
+  //
+  // Emits `callee(args...)` with the same result types as `op`. An op without
+  // results is erased; otherwise it is replaced by the call's results.
+  TypeRange resultTypes = op->getResultTypes();
+  auto call = rewriter.create<emitc::CallOpaqueOp>(
+      op->getLoc(), resultTypes, callee, ArrayAttr{}, ArrayAttr{}, ValueRange(args));
+  if (resultTypes.empty())
+    rewriter.eraseOp(op);
+  else
+    rewriter.replaceOp(op, call.getResults());
+}
+
+/// Emits a result-less `callee(args...)` call. When `op` has exactly one
+/// result, its uses are redirected to `dst`; otherwise `op` is erased.
+static void replaceOrEraseWithOpaqueCallAndReturnDst(Operation *op, Value dst,
+                                                     StringRef callee,
+                                                     ArrayRef<Value> args,
+                                                     ConversionPatternRewriter &rewriter) {
+  rewriter.create<emitc::CallOpaqueOp>(
+      op->getLoc(), TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, ValueRange(args));
+  if (op->getNumResults() == 1)
+    rewriter.replaceOp(op, dst);
+  else
+    rewriter.eraseOp(op);
+}
+
+// ---------- TGEMV / TMATMUL variants ----------
+/// Lowers pto::TGemvBiasOp to `TGEMV_BIAS(dst, a, b, bias)`.
+struct PTOTGemvBiasToTGEMV_BIAS
+    : public OpConversionPattern<pto::TGemvBiasOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TGemvBiasOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    Value a = peelUnrealized(adaptor.getA());
+    Value b = peelUnrealized(adaptor.getB());
+    Value bias = peelUnrealized(adaptor.getBias());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    replaceOrEraseWithOpaqueCall(op.getOperation(), "TGEMV_BIAS",
+                                 {dst, a, b, bias}, rewriter);
+    return success();
+  }
+};
+
+/// Lowers pto::TGemvMxOp to `TGEMV_MX(dst, a, aScale, b, bScale)`; a single
+/// op result, if present, is replaced by dst.
+struct PTOTGemvMXToTGEMV_MX
+    : public OpConversionPattern<pto::TGemvMxOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TGemvMxOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    Value a = peelUnrealized(adaptor.getA());
+    Value aScale = peelUnrealized(adaptor.getAScale());
+    Value b = peelUnrealized(adaptor.getB());
+    Value bScale = peelUnrealized(adaptor.getBScale());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TGEMV_MX",
+                                             {dst, a, aScale, b, bScale}, rewriter);
+    return success();
+  }
+};
+
+/// Lowers pto::TGemvMxAccOp to `TGEMV_MX(dst, cIn, a, aScale, b, bScale)`.
+/// The callee name is the same as for the non-accumulating form; only the
+/// operand list differs.
+struct PTOTGemvMXAccToTGEMV_MX
+    : public OpConversionPattern<pto::TGemvMxAccOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TGemvMxAccOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    Value cIn = peelUnrealized(adaptor.getCIn());
+    Value a = peelUnrealized(adaptor.getA());
+    Value aScale = peelUnrealized(adaptor.getAScale());
+    Value b = peelUnrealized(adaptor.getB());
+    Value bScale = peelUnrealized(adaptor.getBScale());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TGEMV_MX",
+                                             {dst, cIn, a, aScale, b, bScale}, rewriter);
+    return success();
+  }
+};
+
+/// Lowers pto::TGemvMxBiasOp to `TGEMV_MX(dst, a, aScale, b, bScale, bias)`.
+struct PTOTGemvMXBiasToTGEMV_MX
+    : public OpConversionPattern<pto::TGemvMxBiasOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TGemvMxBiasOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    Value a = peelUnrealized(adaptor.getA());
+    Value aScale = peelUnrealized(adaptor.getAScale());
+    Value b = peelUnrealized(adaptor.getB());
+    Value bScale = peelUnrealized(adaptor.getBScale());
+    Value bias = peelUnrealized(adaptor.getBias());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TGEMV_MX",
+                                             {dst, a, aScale, b, bScale, bias}, rewriter);
+    return success();
+  }
+};
+
+/// Lowers pto::TMatmulBiasOp to `TMATMUL_BIAS(dst, a, b, bias)`.
+struct PTOTMatmulBiasToTMATMUL_BIAS
+    : public OpConversionPattern<pto::TMatmulBiasOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TMatmulBiasOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    Value a = peelUnrealized(adaptor.getA());
+    Value b = peelUnrealized(adaptor.getB());
+    Value bias = peelUnrealized(adaptor.getBias());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    replaceOrEraseWithOpaqueCall(op.getOperation(), "TMATMUL_BIAS",
+                                 {dst, a, b, bias}, rewriter);
+    return success();
+  }
+};
+
+/// Lowers pto::TMatmulMxOp to `TMATMUL_MX(dst, a, aScale, b, bScale)`.
+struct PTOTMatmulMXToTMATMUL_MX
+    : public OpConversionPattern<pto::TMatmulMxOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TMatmulMxOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    Value a = peelUnrealized(adaptor.getA());
+    Value aScale = peelUnrealized(adaptor.getAScale());
+    Value b = peelUnrealized(adaptor.getB());
+    Value bScale = peelUnrealized(adaptor.getBScale());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    replaceOrEraseWithOpaqueCall(op.getOperation(), "TMATMUL_MX",
+                                 {dst, a, aScale, b, bScale}, rewriter);
+    return success();
+  }
+};
+
+/// Lowers pto::TMatmulMxAccOp to `TMATMUL_MX(dst, cIn, a, aScale, b, bScale)`.
+/// Despite the struct name, the emitted callee is "TMATMUL_MX".
+struct PTOTMatmulMXAccToTMATMUL_MX_ACC
+    : public OpConversionPattern<pto::TMatmulMxAccOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TMatmulMxAccOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    Value cIn = peelUnrealized(adaptor.getCIn());
+    Value a = peelUnrealized(adaptor.getA());
+    Value aScale = peelUnrealized(adaptor.getAScale());
+    Value b = peelUnrealized(adaptor.getB());
+    Value bScale = peelUnrealized(adaptor.getBScale());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    replaceOrEraseWithOpaqueCall(op.getOperation(), "TMATMUL_MX",
+                                 {dst, cIn, a, aScale, b, bScale}, rewriter);
+    return success();
+  }
+};
+
+/// Lowers pto::TMatmulMxBiasOp to `TMATMUL_MX(dst, a, aScale, b, bScale, bias)`.
+/// Despite the struct name, the emitted callee is "TMATMUL_MX".
+struct PTOTMatmulMXBiasToTMATMUL_MX_BIAS
+    : public OpConversionPattern<pto::TMatmulMxBiasOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TMatmulMxBiasOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    Value a = peelUnrealized(adaptor.getA());
+    Value aScale = peelUnrealized(adaptor.getAScale());
+    Value b = peelUnrealized(adaptor.getB());
+    Value bScale = peelUnrealized(adaptor.getBScale());
+    Value bias = peelUnrealized(adaptor.getBias());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    replaceOrEraseWithOpaqueCall(op.getOperation(), "TMATMUL_MX",
+                                 {dst, a, aScale, b, bScale, bias}, rewriter);
+    return success();
+  }
+};
+
+/// Lowers pto::TRowExpandDivOp to `TROWEXPANDDIV(dst, src0, src1[, tmp])` and
+/// erases it; tmp is passed only when the op carries a tmp operand.
+struct PTORowExpandDivToEmitC : public OpConversionPattern<pto::TRowExpandDivOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TRowExpandDivOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src0 = peelUnrealized(adaptor.getSrc0());
+    Value src1 = peelUnrealized(adaptor.getSrc1());
+    Value dst = peelUnrealized(adaptor.getDst());
+    Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value();
+
+    SmallVector<Value> operands;
+    if (tmp)
+      operands.assign({dst, src0, src1, tmp});
+    else
+      operands.assign({dst, src0, src1});
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TROWEXPANDDIV",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+//===----------------------------------------------------------------------===//
+// PTOConvert.cpp (add lowering + patterns.add for TROWEXPANDMUL DPS/memref op)
+//===----------------------------------------------------------------------===//
+
+/// Lowers pto::TRowExpandMulOp to `TROWEXPANDMUL(dst, src0, src1[, tmp])` and
+/// erases it; tmp is passed only when the op carries a tmp operand.
+struct PTORowExpandMulToEmitC : public OpConversionPattern<pto::TRowExpandMulOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TRowExpandMulOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src0 = peelUnrealized(adaptor.getSrc0());
+    Value src1 = peelUnrealized(adaptor.getSrc1());
+    Value dst = peelUnrealized(adaptor.getDst());
+    Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value();
+
+    SmallVector<Value> operands;
+    if (tmp)
+      operands.assign({dst, src0, src1, tmp});
+    else
+      operands.assign({dst, src0, src1});
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TROWEXPANDMUL",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// PTOConvert.cpp (add lowering + patterns.add for TROWEXPANDSUB DPS/memref op)
+//===----------------------------------------------------------------------===//
+
+/// Lowers pto::TRowExpandSubOp to `TROWEXPANDSUB(dst, src0, src1[, tmp])` and
+/// erases it; tmp is passed only when the op carries a tmp operand.
+struct PTORowExpandSubToEmitC : public OpConversionPattern<pto::TRowExpandSubOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TRowExpandSubOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src0 = peelUnrealized(adaptor.getSrc0());
+    Value src1 = peelUnrealized(adaptor.getSrc1());
+    Value dst = peelUnrealized(adaptor.getDst());
+    Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value();
+
+    SmallVector<Value> operands;
+    if (tmp)
+      operands.assign({dst, src0, src1, tmp});
+    else
+      operands.assign({dst, src0, src1});
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TROWEXPANDSUB",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Lowers pto::TRowExpandMaxOp to `TROWEXPANDMAX(dst, src0, src1[, tmp])` and
+/// erases it; tmp is passed only when the op carries a tmp operand.
+struct PTORowExpandMaxToEmitC : public OpConversionPattern<pto::TRowExpandMaxOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TRowExpandMaxOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src0 = peelUnrealized(adaptor.getSrc0());
+    Value src1 = peelUnrealized(adaptor.getSrc1());
+    Value dst = peelUnrealized(adaptor.getDst());
+    Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value();
+
+    SmallVector<Value> operands;
+    if (tmp)
+      operands.assign({dst, src0, src1, tmp});
+    else
+      operands.assign({dst, src0, src1});
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TROWEXPANDMAX",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Lowers pto::TRowExpandMinOp to `TROWEXPANDMIN(dst, src0, src1[, tmp])` and
+/// erases it; tmp is passed only when the op carries a tmp operand.
+struct PTORowExpandMinToEmitC : public OpConversionPattern<pto::TRowExpandMinOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TRowExpandMinOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src0 = peelUnrealized(adaptor.getSrc0());
+    Value src1 = peelUnrealized(adaptor.getSrc1());
+    Value dst = peelUnrealized(adaptor.getDst());
+    Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value();
+
+    SmallVector<Value> operands;
+    if (tmp)
+      operands.assign({dst, src0, src1, tmp});
+    else
+      operands.assign({dst, src0, src1});
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TROWEXPANDMIN",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// PTOConvert.cpp (add lowering + patterns.add for TROWMAX DPS/memref op)
+//===----------------------------------------------------------------------===//
+
+/// Lowers pto::TRowMaxOp to `TROWMAX(dst, src, tmp)` and erases it.
+struct PTORowMaxToEmitC : public OpConversionPattern<pto::TRowMaxOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TRowMaxOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src = peelUnrealized(adaptor.getSrc());
+    Value tmp = peelUnrealized(adaptor.getTmp());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    SmallVector<Value> operands{dst, src, tmp};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TROWMAX",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Lowers pto::TRowArgMaxOp to `TROWARGMAX(dst, src, tmp)` and erases it.
+struct PTORowArgMaxToEmitC
+    : public OpConversionPattern<pto::TRowArgMaxOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TRowArgMaxOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src = peelUnrealized(adaptor.getSrc());
+    Value tmp = peelUnrealized(adaptor.getTmp());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TROWARGMAX",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/ValueRange{dst, src, tmp});
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+//===----------------------------------------------------------------------===//
+// PTOConvert.cpp (add lowering + patterns.add for TROWMIN DPS/memref op)
+//===----------------------------------------------------------------------===//
+
+/// Lowers pto::TRowMinOp to `TROWMIN(dst, src, tmp)` and erases it.
+struct PTORowMinToEmitC : public OpConversionPattern<pto::TRowMinOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TRowMinOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src = peelUnrealized(adaptor.getSrc());
+    Value tmp = peelUnrealized(adaptor.getTmp());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    SmallVector<Value> operands{dst, src, tmp};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TROWMIN",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Lowers pto::TRowArgMinOp to `TROWARGMIN(dst, src, tmp)` and erases it.
+struct PTORowArgMinToEmitC
+    : public OpConversionPattern<pto::TRowArgMinOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TRowArgMinOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src = peelUnrealized(adaptor.getSrc());
+    Value tmp = peelUnrealized(adaptor.getTmp());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TROWARGMIN",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/ValueRange{dst, src, tmp});
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// PTOConvert.cpp (add lowering + patterns.add for TROWSUM DPS/memref op)
+//===----------------------------------------------------------------------===//
+
+/// Lowers pto::TRowSumOp to `TROWSUM(dst, src, tmp)` and erases it.
+struct PTORowSumToEmitC : public OpConversionPattern<pto::TRowSumOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TRowSumOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src = peelUnrealized(adaptor.getSrc());
+    Value tmp = peelUnrealized(adaptor.getTmp());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    SmallVector<Value> operands{dst, src, tmp};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TROWSUM",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Lowers pto::TRowProdOp to `TROWPROD(dst, src, tmp)` and erases it.
+struct PTORowProdToEmitC : public OpConversionPattern<pto::TRowProdOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TRowProdOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src = peelUnrealized(adaptor.getSrc());
+    Value tmp = peelUnrealized(adaptor.getTmp());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    SmallVector<Value> operands{dst, src, tmp};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TROWPROD",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+//===----------------------------------------------------------------------===//
+// PTOConvert.cpp (add lowering + patterns.add for TRSQRT DPS/memref op)
+// - no-tmp form : TRSQRT(dst, src)
+// - tmp form : TRSQRT(dst, src, tmp)
+//===----------------------------------------------------------------------===//
+
+/// Lowers pto::TRsqrtOp to the EmitC call `TRSQRT(dst, src)`; when the op
+/// carries a tmp operand it is appended as a third argument.
+struct PTORsqrtToEmitC : public OpConversionPattern<pto::TRsqrtOp> {
+  using OpConversionPattern<pto::TRsqrtOp>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TRsqrtOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src = peelUnrealized(adaptor.getSrc());
+    Value dst = peelUnrealized(adaptor.getDst());
+    SmallVector<Value> operands{dst, src};
+    // tmp is optional: a null Value means the no-tmp form.
+    if (Value tmp = adaptor.getTmp())
+      operands.push_back(peelUnrealized(tmp));
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TRSQRT",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+//===----------------------------------------------------------------------===//
+// PTOConvert.cpp (add lowering + patterns.add for TSCATTER DPS/memref op)
+//===----------------------------------------------------------------------===//
+
+/// Lowers pto::TScatterOp to an EmitC `TSCATTER` call.
+///
+/// Exactly one of the `indexes` operand or the `maskPattern` attribute must be
+/// present; otherwise the match fails.
+///  - maskPattern form: `TSCATTER<pattern>(dst, src)`, where the template
+///    argument is the token returned by maskPatternTok().
+///  - indexes form:     `TSCATTER(dst, src, idx)`.
+struct PTOScatterToEmitC : public OpConversionPattern<pto::TScatterOp> {
+  using OpConversionPattern<pto::TScatterOp>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TScatterOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    const bool hasMaskPattern = static_cast<bool>(op.getMaskPatternAttr());
+    const bool hasIndexes = static_cast<bool>(op.getIndexes());
+    // Both present or both absent is rejected.
+    if (hasMaskPattern == hasIndexes) {
+      return rewriter.notifyMatchFailure(
+          op, "expected exactly one of indexes operand or maskPattern attribute");
+    }
+
+    Value src = peelUnrealized(adaptor.getSrc());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    if (auto mp = op.getMaskPatternAttr()) {
+      auto *ctx = rewriter.getContext();
+      auto targs = rewriter.getArrayAttr({
+          emitc::OpaqueAttr::get(ctx, maskPatternTok(mp)),
+      });
+      rewriter.create<emitc::CallOpaqueOp>(
+          loc, TypeRange{}, "TSCATTER",
+          /*args=*/ArrayAttr{}, /*templateArgs=*/targs,
+          /*operands=*/ValueRange{dst, src});
+    } else {
+      Value idx = peelUnrealized(adaptor.getIndexes());
+      rewriter.create<emitc::CallOpaqueOp>(
+          loc, TypeRange{}, "TSCATTER",
+          /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+          /*operands=*/ValueRange{dst, src, idx});
+    }
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// PTOConvert.cpp (add lowering + patterns.add for TSEL DPS/memref op)
+//===----------------------------------------------------------------------===//
+
+/// Lowers pto::TSelOp to the EmitC call `TSEL(dst, mask, src0, src1, tmp)`.
+struct PTOSelToEmitC : public OpConversionPattern<pto::TSelOp> {
+  using OpConversionPattern<pto::TSelOp>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TSelOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value mask = peelUnrealized(adaptor.getMask());
+    Value src0 = peelUnrealized(adaptor.getSrc0());
+    Value src1 = peelUnrealized(adaptor.getSrc1());
+    Value tmp = peelUnrealized(adaptor.getTmp());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    SmallVector<Value> operands{dst, mask, src0, src1, tmp};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TSEL",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+//===----------------------------------------------------------------------===//
+// PTOConvert.cpp (add lowering + patterns.add for TSELS DPS/memref op)
+//===----------------------------------------------------------------------===//
+
+/// Lowers pto::TSelSOp to the EmitC call `TSELS(dst, mask, src, tmp, scalar)`.
+/// Note that the scalar is passed last, after tmp.
+struct PTOSelSToEmitC : public OpConversionPattern<pto::TSelSOp> {
+  using OpConversionPattern<pto::TSelSOp>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TSelSOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value mask = peelUnrealized(adaptor.getMask());
+    Value src = peelUnrealized(adaptor.getSrc());
+    Value scalar = peelUnrealized(adaptor.getScalar());
+    Value tmp = peelUnrealized(adaptor.getTmp());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    SmallVector<Value> operands{dst, mask, src, tmp, scalar};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TSELS",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+//===----------------------------------------------------------------------===//
+// PTOConvert.cpp (add lowering + patterns.add for TSHL DPS/memref op)
+//===----------------------------------------------------------------------===//
+
+/// Lowers pto::TShlOp (tile-by-tile shift) to `TSHL(dst, src0, src1)`.
+/// Despite the "S" in the struct name this handles TShlOp, not TShlSOp.
+struct PTOShlSToEmitC : public OpConversionPattern<pto::TShlOp> {
+  using OpConversionPattern<pto::TShlOp>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TShlOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src0 = peelUnrealized(adaptor.getSrc0());
+    Value src1 = peelUnrealized(adaptor.getSrc1());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    SmallVector<Value> operands{dst, src0, src1};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TSHL",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+//===----------------------------------------------------------------------===//
+// PTOConvert.cpp (add lowering + patterns.add for TSHR DPS/memref op)
+//===----------------------------------------------------------------------===//
+
+/// Lowers pto::TShrOp (tile-by-tile shift) to `TSHR(dst, src0, src1)`.
+/// Despite the "S" in the struct name this handles TShrOp, not TShrSOp.
+struct PTOShrSToEmitC : public OpConversionPattern<pto::TShrOp> {
+  using OpConversionPattern<pto::TShrOp>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TShrOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src0 = peelUnrealized(adaptor.getSrc0());
+    Value src1 = peelUnrealized(adaptor.getSrc1());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    SmallVector<Value> operands{dst, src0, src1};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TSHR",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// PTOConvert.cpp
(add lowering for TSHLS/TSHRS DPS: shift by scalar)
+//===----------------------------------------------------------------------===//
+
+/// Lowers pto::TShlSOp (shift by scalar) to `TSHLS(dst, src, scalar)`.
+struct PTOShlSConstToEmitC : public OpConversionPattern<pto::TShlSOp> {
+  using OpConversionPattern<pto::TShlSOp>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TShlSOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    Value dst = peelUnrealized(adaptor.getDst());
+    Value src = peelUnrealized(adaptor.getSrc());
+    Value scalar = peelUnrealized(adaptor.getScalar());
+    SmallVector<Value> operands{dst, src, scalar};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TSHLS",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Lowers pto::TShrSOp (shift by scalar) to `TSHRS(dst, src, scalar)`.
+struct PTOShrSConstToEmitC : public OpConversionPattern<pto::TShrSOp> {
+  using OpConversionPattern<pto::TShrSOp>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TShrSOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    Value dst = peelUnrealized(adaptor.getDst());
+    Value src = peelUnrealized(adaptor.getSrc());
+    Value scalar = peelUnrealized(adaptor.getScalar());
+    SmallVector<Value> operands{dst, src, scalar};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TSHRS",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// PTOConvert.cpp (TSORT32 DPS/memref op: ins(src, idx[, tmp]) outs(dst))
+//===----------------------------------------------------------------------===//
+
+/// Lowers pto::TSort32Op to `TSORT32(dst, src, idx)`, or to
+/// `TSORT32(dst, src, idx, tmp)` when the op has a tmp operand.
+struct PTOSORT32SToEmitC : public OpConversionPattern<pto::TSort32Op> {
+  using OpConversionPattern<pto::TSort32Op>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TSort32Op op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src = peelUnrealized(adaptor.getSrc());
+    Value dst = peelUnrealized(adaptor.getDst());
+    Value idx = peelUnrealized(adaptor.getIdx());
+    // tmp is optional; presence is decided on the original op.
+    Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value();
+
+    SmallVector<Value> operands;
+    if (tmp)
+      operands.assign({dst, src, idx, tmp});
+    else
+      operands.assign({dst, src, idx});
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TSORT32",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+//===----------------------------------------------------------------------===//
+// PTOConvert.cpp (add lowering + patterns.add for TSQRT DPS/memref op)
+//===----------------------------------------------------------------------===//
+
+/// Lowers pto::TSqrtOp to the EmitC call `TSQRT(dst, src)`.
+struct PTOSqrtSToEmitC : public OpConversionPattern<pto::TSqrtOp> {
+  using OpConversionPattern<pto::TSqrtOp>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TSqrtOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src = peelUnrealized(adaptor.getSrc());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    SmallVector<Value> operands{dst, src};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TSQRT",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// PTOConvert.cpp (add lowering + patterns.add for TSTORE_FP DPS/memref op)
+//===----------------------------------------------------------------------===//
+
+/// Lowers pto::TStoreFPOp to the EmitC call `TSTORE_FP(dst, src, fp)`.
+struct PTOStoreFPSToEmitC : public OpConversionPattern<pto::TStoreFPOp> {
+  using OpConversionPattern<pto::TStoreFPOp>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TStoreFPOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src = peelUnrealized(adaptor.getSrc());
+    Value fp = peelUnrealized(adaptor.getFp());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    SmallVector<Value> operands{dst, src, fp};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TSTORE_FP",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// PTOConvert.cpp (add lowering + patterns.add for TSUB DPS/memref op)
+//===----------------------------------------------------------------------===//
+
+/// Lowers pto::TSubOp to the EmitC call `TSUB(dst, src0, src1)`.
+struct PTOSubSToEmitC : public OpConversionPattern<pto::TSubOp> {
+  using OpConversionPattern<pto::TSubOp>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TSubOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src0 = peelUnrealized(adaptor.getSrc0());
+    Value src1 = peelUnrealized(adaptor.getSrc1());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    SmallVector<Value> operands{dst, src0, src1};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TSUB",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+//===----------------------------------------------------------------------===//
+// PTOConvert.cpp (add lowering + patterns.add for TSUBC DPS/memref op)
+//===----------------------------------------------------------------------===//
+
+/// Lowers pto::TSubCOp by decomposition into two calls:
+/// `TSUB(dst, src0, src1)` followed by `TADD(dst, dst, src2)`.
+struct PTOSubCSToEmitC : public OpConversionPattern<pto::TSubCOp> {
+  using OpConversionPattern<pto::TSubCOp>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TSubCOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src0 = peelUnrealized(adaptor.getSrc0());
+    Value src1 = peelUnrealized(adaptor.getSrc1());
+    Value src2 = peelUnrealized(adaptor.getSrc2());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    // pto-isa does not provide NPU implementation for TSUBC yet.
+    // Decompose: dst = src0 - src1 + src2
+    // NOTE(review): dst is written by TSUB before TADD reads src2, so this
+    // is only correct if dst does not alias src2 -- confirm the op verifier
+    // or callers guarantee that.
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TSUB",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/ValueRange{dst, src0, src1});
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TADD",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/ValueRange{dst, dst, src2});
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+//===----------------------------------------------------------------------===//
+// PTOConvert.cpp (add lowering + patterns.add for TSUBS DPS/memref op)
+//===----------------------------------------------------------------------===//
+
+/// Lowers pto::TSubSOp to the EmitC call `TSUBS(dst, src, scalar)`.
+struct PTOSubSSToEmitC : public OpConversionPattern<pto::TSubSOp> {
+  using OpConversionPattern<pto::TSubSOp>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TSubSOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src = peelUnrealized(adaptor.getSrc());
+    Value scalar = peelUnrealized(adaptor.getScalar());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    SmallVector<Value> operands{dst, src, scalar};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TSUBS",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+//===----------------------------------------------------------------------===//
+// PTOConvert.cpp (add lowering + patterns.add for TSUBSC DPS/memref op)
+//===----------------------------------------------------------------------===//
+
+/// Lowers pto::TSubSCOp by decomposition into two calls:
+/// `TSUBS(dst, src0, scalar)` followed by `TADD(dst, dst, src1)`.
+struct PTOSubSCToEmitC : public OpConversionPattern<pto::TSubSCOp> {
+  using OpConversionPattern<pto::TSubSCOp>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TSubSCOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src0 = peelUnrealized(adaptor.getSrc0());
+    Value scalar = peelUnrealized(adaptor.getScalar());
+    Value src1 = peelUnrealized(adaptor.getSrc1());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    // pto-isa does not provide NPU implementation for TSUBSC yet.
+    // Decompose: dst = src0 - scalar + src1
+    // NOTE(review): dst is written by TSUBS before TADD reads src1, so this
+    // is only correct if dst does not alias src1 -- confirm the op verifier
+    // or callers guarantee that.
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TSUBS",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/ValueRange{dst, src0, scalar});
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TADD",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/ValueRange{dst, dst, src1});
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+
+//===----------------------------------------------------------------------===//
+// PTOConvert.cpp (add lowering + patterns.add for TXOR DPS/memref op)
+//===----------------------------------------------------------------------===//
+
+/// Lowers pto::TXorOp to the EmitC call `TXOR(dst, src0, src1, tmp)`.
+struct PTOXORToEmitC : public OpConversionPattern<pto::TXorOp> {
+  using OpConversionPattern<pto::TXorOp>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TXorOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src0 = peelUnrealized(adaptor.getSrc0());
+    Value src1 = peelUnrealized(adaptor.getSrc1());
+    Value dst = peelUnrealized(adaptor.getDst());
+    Value tmp = peelUnrealized(adaptor.getTmp());
+    SmallVector<Value> operands{dst, src0, src1, tmp};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TXOR",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+/// Lowers pto::TTransOp to the EmitC call `TTRANS(dst, src, tmp)`.
+struct PTOTTransToEmitC : public OpConversionPattern<pto::TTransOp> {
+  using OpConversionPattern<pto::TTransOp>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TTransOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src = peelUnrealized(adaptor.getSrc());
+    Value tmp = peelUnrealized(adaptor.getTmp());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    SmallVector<Value> operands{dst, src, tmp};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TTRANS",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+//===----------------------------------------------------------------------===//
+// PTOConvert.cpp (add lowering + patterns.add for TXORS DPS/memref op)
+//===----------------------------------------------------------------------===//
+
+/// Lowers pto::TXorSOp to the EmitC call `TXORS(dst, src, scalar, tmp)`.
+struct PTOXORSToEmitC : public OpConversionPattern<pto::TXorSOp> {
+  using OpConversionPattern<pto::TXorSOp>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TXorSOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src = peelUnrealized(adaptor.getSrc());
+    Value scalar = peelUnrealized(adaptor.getScalar());
+    Value tmp = peelUnrealized(adaptor.getTmp());
+    Value dst = peelUnrealized(adaptor.getDst());
+
+    SmallVector<Value> operands{dst, src, scalar, tmp};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TXORS",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+/// Lowers pto::TPrintOp to the EmitC call `TPRINT(src)`.
+struct PTOPrintToTPRINT : public OpConversionPattern<pto::TPrintOp> {
+  using OpConversionPattern<pto::TPrintOp>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TPrintOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value src = peelUnrealized(adaptor.getSrc());
+
+    SmallVector<Value> operands{src};
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "TPRINT",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/operands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+// pto.print "format", %scalar -> cce::printf("format", scalar)
+//
+// The format string is emitted as a quoted C string literal: `"` and `\` are
+// backslash-escaped, newline and tab become `\n` and `\t`, and every other
+// character is copied through unchanged. An empty format defaults to "%f".
+struct PTOPrintOpToEmitC : public OpConversionPattern<pto::PrintOp> {
+  using OpConversionPattern<pto::PrintOp>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::PrintOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto *ctx = rewriter.getContext();
+
+    std::string fmt = op.getFormat().str();
+    if (fmt.empty())
+      fmt = "%f";
+    std::string quoted = "\"";
+    for (char c : fmt) {
+      if (c == '"' || c == '\\') {
+        // Emit the escape AND the character itself; emitting only the
+        // backslash would drop the character and escape whatever follows.
+        quoted += '\\';
+        quoted += c;
+      } else if (c == '\n') {
+        quoted += "\\n";
+      } else if (c == '\t') {
+        quoted += "\\t";
+      } else {
+        quoted += c;
+      }
+    }
+    quoted += "\"";
+
+    Value scalar = peelUnrealized(adaptor.getScalar());
+    // args = [format literal, operand #0]: the index-typed integer attribute
+    // is a placeholder for the call's first operand (see the EmitC
+    // call_opaque documentation).
+    auto argsAttr = rewriter.getArrayAttr(
+        {emitc::OpaqueAttr::get(ctx, quoted),
+         IntegerAttr::get(IndexType::get(ctx), 0)});
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "cce::printf",
+        /*args=*/argsAttr,
+        /*templateArgs=*/ArrayAttr{},
+        /*operands=*/ValueRange{scalar});
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+// pto.trap -> trap()
+struct PTOTrapOpToEmitC : public OpConversionPattern<pto::TrapOp> {
+  using OpConversionPattern<pto::TrapOp>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TrapOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    // No arguments, no template arguments, no results.
+    rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "trap",
+        /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{},
+        /*operands=*/ValueRange{});
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+// =============================================================================
+// 2.
BindTileOp Lowering (FIX: Trace back to physical address)
+// =============================================================================
+// Conversion pattern for pto::BindTileOp. Depending on where the source comes
+// from and on the op's "pto.view_semantics" string attribute, matchAndRewrite
+// takes one of these paths (checked in this order):
+//   1. source defined by a specific op kind -> fresh tile value only;
+//   2. "bitcast" on a tile-like source       -> fresh tile + TASSIGN(tile, addr);
+//   3. "treshape" on a tile-like source      -> declared tile + TRESHAPE(tile, src);
+//   4. "subview"                             -> fresh tile + TASSIGN(tile, addr);
+//   5. any other tile-like source            -> reuse the source tile when the
+//      type strings match and no constructor is needed, else tile + TASSIGN;
+//   6. otherwise                             -> re-create a cast-like op from
+//      the traced-back physical addresses.
+struct PTOBindTileToEmitC : public OpConversionPattern {
+  using OpConversionPattern::OpConversionPattern;
+
+  // Everything needed to materialize one C++ tile: the full `Tile<...>` type
+  // string, whether a constructor call is required, and its arguments.
+  struct TileBuildSpec {
+    std::string tileTypeStr;
+    bool useConstructor = false;
+    SmallVector constructorArgs;
+  };
+
+  // Returns true and stores the signed integer value in `out` when `v` is
+  // defined by a matching constant op; returns false for a null value or
+  // when either cast does not match.
+  static bool getIndexConst(Value v, int64_t &out) {
+    if (!v)
+      return false;
+    if (auto cst = v.getDefiningOp()) {
+      if (auto ia = dyn_cast(cst.getValue())) {
+        out = ia.getValue().getSExtValue();
+        return true;
+      }
+    }
+    return false;
+  }
+
+  // Computes the row/column strides implied by the tile config for a tile of
+  // static shape rows x cols. Returns false (outputs untouched) when the shape
+  // is dynamic, the element byte size is 0, or the layout/fractal combination
+  // is not handled below.
+  static bool getTilePointerStrides(pto::TileBufConfigAttr configAttr,
+                                    Type elemTy, int64_t rows, int64_t cols,
+                                    int64_t &rowStride,
+                                    int64_t &colStride) {
+    if (rows == ShapedType::kDynamic || cols == ShapedType::kDynamic)
+      return false;
+
+    // The layouts may be stored as either of two attribute kinds; both are
+    // reduced to a plain integer. In this function 0 is the default,
+    // blVal == 1 selects the column-major-style strides and slVal != 0
+    // means "boxed".
+    int32_t blVal = 0;
+    if (auto blAttr = dyn_cast(configAttr.getBLayout()))
+      blVal = static_cast(blAttr.getValue());
+    else if (auto intAttr = dyn_cast(configAttr.getBLayout()))
+      blVal = static_cast(intAttr.getInt());
+
+    int32_t slVal = 0;
+    if (auto slAttr = dyn_cast(configAttr.getSLayout()))
+      slVal = static_cast(slAttr.getValue());
+    else if (auto intAttr = dyn_cast(configAttr.getSLayout()))
+      slVal = static_cast(intAttr.getInt());
+
+    bool boxed = slVal != 0;
+    int64_t innerRows = 1;
+    int64_t innerCols = 1;
+    if (boxed) {
+      // Inner box shape is selected by the fractal size (default 512).
+      int32_t fractal = 512;
+      if (auto frAttr = dyn_cast(configAttr.getSFractalSize()))
+        fractal = static_cast(frAttr.getInt());
+
+      unsigned elemBytes = pto::getPTOStorageElemByteSize(elemTy);
+      if (elemBytes == 0)
+        return false;
+
+      switch (fractal) {
+      case 1024:
+        innerRows = 16;
+        innerCols = 16;
+        break;
+      case 32:
+        innerRows = 16;
+        innerCols = 2;
+        break;
+      case 512:
+        // For 512 the box is 16 x (32 / elemBytes), oriented by slVal.
+        if (slVal == 1) {
+          innerRows = 16;
+          innerCols = 32 / elemBytes;
+        } else if (slVal == 2) {
+          innerRows = 32 / elemBytes;
+          innerCols = 16;
+        } else {
+          return false;
+        }
+        break;
+      default:
+        return false;
+      }
+      // 32 / elemBytes truncates to 0 for element sizes above 32 bytes.
+      if (innerRows <= 0 || innerCols <= 0)
+        return false;
+    }
+
+    if (!boxed) {
+      if (blVal == 1) {
+        rowStride = 1;
+        colStride = rows;
+      } else {
+        rowStride = cols;
+        colStride = 1;
+      }
+      return true;
+    }
+
+    // Boxed with blVal == 1 is only handled for slVal == 1.
+    if (blVal == 1) {
+      if (slVal != 1)
+        return false;
+      rowStride = innerCols;
+      colStride = rows;
+      return true;
+    }
+
+    rowStride = cols;
+    colStride = innerRows;
+    return true;
+  }
+
+  LogicalResult matchAndRewrite(pto::BindTileOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto *ctx = rewriter.getContext();
+    auto configAttr = op.getConfigAttr();
+    auto viewSemantics = op->getAttrOfType("pto.view_semantics");
+    bool isSubView = viewSemantics && viewSemantics.getValue() == "subview";
+
+    // Walks up through any number of casts of the first kind (taking
+    // operand 0), then through at most one cast of the second kind.
+    auto peelAllCasts = [](Value v) {
+      while (auto castOp = v.getDefiningOp())
+        v = castOp.getOperand(0);
+      if (auto castOp = v.getDefiningOp())
+        v = castOp.getOperand();
+      return v;
+    };
+    // A value is "tile-like" when its type's string contains "Tile<" or
+    // "ConvTile<".
+    auto isTileLike = [](Value v) -> bool {
+      auto ot = dyn_cast(v.getType());
+      if (!ot)
+        return false;
+      StringRef s = ot.getValue();
+      return s.contains("Tile<") || s.contains("ConvTile<");
+    };
+    // Builds the `Tile<role, elem, rows, cols, BLayout, validRow, validCol,
+    // SLayout, fractal, pad, compact>` type string for the op result and
+    // decides whether the tile must be created through its constructor.
+    // Fails when the result type, element type or static shape cannot be
+    // resolved.
+    auto buildTileSpec = [&]() -> FailureOr {
+      auto resMrTy = dyn_cast(op.getType());
+      if (!resMrTy)
+        return failure();
+
+      // Map the result's address space to the tile role token; GM, Zero and
+      // a missing memory-space attribute all map to TileType::Vec.
+      const char *roleTok = "TileType::Vec";
+      if (auto asAttr =
+              dyn_cast_or_null(resMrTy.getMemorySpace())) {
+        switch (asAttr.getAddressSpace()) {
+        case pto::AddressSpace::VEC:
+          roleTok = "TileType::Vec";
+          break;
+        case pto::AddressSpace::MAT:
+          roleTok = "TileType::Mat";
+          break;
+        case pto::AddressSpace::LEFT:
+          roleTok = "TileType::Left";
+          break;
+        case pto::AddressSpace::RIGHT:
+          roleTok = "TileType::Right";
+          break;
+        case pto::AddressSpace::ACC:
+          roleTok = "TileType::Acc";
+          break;
+        case pto::AddressSpace::BIAS:
+          roleTok = "TileType::Bias";
+          break;
+        case pto::AddressSpace::SCALING:
+          roleTok = "TileType::Scaling";
+          break;
+        case pto::AddressSpace::GM:
+        case pto::AddressSpace::Zero:
+          roleTok = "TileType::Vec";
+          break;
+        }
+      }
+
+      Type elemTy = resMrTy.getElementType();
+      Type emitElemTy = getTypeConverter()->convertType(elemTy);
+      if (!emitElemTy)
+        return failure();
+      auto emitElemOpaque = dyn_cast(emitElemTy);
+      if (!emitElemOpaque)
+        return failure();
+      std::string elemTypeStr = emitElemOpaque.getValue().str();
+
+      // Only rank >= 2 results with static leading two dims are supported.
+      if (resMrTy.getRank() < 2)
+        return failure();
+      int64_t rows = resMrTy.getDimSize(0);
+      int64_t cols = resMrTy.getDimSize(1);
+      if (rows == ShapedType::kDynamic || cols == ShapedType::kDynamic)
+        return failure();
+
+      // NOTE(review): blTok only looks at one attribute kind, whereas
+      // getTilePointerStrides() also accepts a second kind for the same
+      // field -- confirm the two cannot disagree.
+      std::string blTok = "BLayout::RowMajor";
+      if (auto blAttr = dyn_cast(configAttr.getBLayout())) {
+        if (static_cast(blAttr.getValue()) == 1)
+          blTok = "BLayout::ColMajor";
+      }
+      pto::BLayout blayout = getTileBufBLayoutValue(configAttr);
+
+      // For subviews, use the source's own rows/cols in the tile type, but
+      // only when the source's strides equal the strides this tile config
+      // would give a tile of that shape. Packed float4 element types are
+      // excluded.
+      if (isSubView) {
+        auto subMrTy = dyn_cast(op.getSource().getType());
+        auto subViewOp = op.getSource().getDefiningOp();
+        if (subMrTy && subMrTy.getRank() >= 2 && subViewOp) {
+          int64_t subRows = subMrTy.getDimSize(0);
+          int64_t subCols = subMrTy.getDimSize(1);
+          SmallVector inheritedStrides;
+          int64_t inheritedOffset = ShapedType::kDynamic;
+
+          if (!pto::isPTOFloat4PackedType(elemTy) &&
+              subRows != ShapedType::kDynamic &&
+              subCols != ShapedType::kDynamic &&
+              succeeded(getStridesAndOffset(subMrTy, inheritedStrides,
+                                            inheritedOffset)) &&
+              inheritedStrides.size() >= 2) {
+            int64_t childRowStride = 0;
+            int64_t childColStride = 0;
+            bool sameStrides = getTilePointerStrides(
+                configAttr, elemTy, subRows, subCols, childRowStride,
+                childColStride);
+            sameStrides = sameStrides &&
+                          inheritedStrides[0] == childRowStride &&
+                          inheritedStrides[1] == childColStride;
+            if (sameStrides) {
+              rows = subRows;
+              cols = subCols;
+            }
+          }
+        }
+      }
+
+      std::string slTok = "SLayout::NoneBox";
+      if (auto slAttr = dyn_cast(configAttr.getSLayout())) {
+        int32_t slVal = static_cast(slAttr.getValue());
+        slTok = (slVal == 1) ? "SLayout::RowMajor"
+                : (slVal == 2) ? "SLayout::ColMajor"
+                               : "SLayout::NoneBox";
+      }
+
+      int32_t fractal = 512;
+      if (auto frAttr = dyn_cast(configAttr.getSFractalSize()))
+        fractal = frAttr.getInt();
+
+      std::string padTok = "PadValue::Null";
+      if (auto padAttr = dyn_cast(configAttr.getPad())) {
+        switch (static_cast(padAttr.getValue())) {
+        case 1:
+          padTok = "PadValue::Zero";
+          break;
+        case 2:
+          padTok = "PadValue::Max";
+          break;
+        case 3:
+          padTok = "PadValue::Min";
+          break;
+        default:
+          padTok = "PadValue::Null";
+          break;
+        }
+      }
+
+      std::string compactTok = "CompactMode::Null";
+      if (auto compactAttr = dyn_cast(configAttr.getCompactMode())) {
+        switch (static_cast(compactAttr.getValue())) {
+        case 1:
+          compactTok = "CompactMode::Normal";
+          break;
+        case 2:
+          compactTok = "CompactMode::RowPlusOne";
+          break;
+        default:
+          compactTok = "CompactMode::Null";
+          break;
+        }
+      }
+
+      // Valid-shape handling: a constant valid dim is rendered into the
+      // template, a non-constant one becomes "-1" in the template plus a
+      // constructor argument.
+      std::string vrowTok, vcolTok;
+      bool useConstructor = false;
+      bool rowIsDynamic = false;
+      bool colIsDynamic = false;
+      SmallVector constructorArgs;
+
+      Value vRow = op.getValidRow();
+      Value vCol = op.getValidCol();
+      Value vRowEmitC = adaptor.getValidRow();
+      Value vColEmitC = adaptor.getValidCol();
+      bool forceDynamicValid = op->hasAttr(kForceDynamicValidShapeAttrName);
+      int64_t cRow = 0, cCol = 0;
+      bool rowIsConst = vRow && getIndexConst(vRow, cRow);
+      bool colIsConst = vCol && getIndexConst(vCol, cCol);
+
+      // Uses the already-converted value when there is one, otherwise
+      // materializes `fallback` as an int32_t constant.
+      auto makeCtorDimValue = [&](Value emitted, int64_t fallback) -> Value {
+        if (emitted)
+          return emitted;
+        return makeEmitCIntConstant(
+            rewriter, loc, emitc::OpaqueType::get(ctx, "int32_t"), fallback);
+      };
+      // For packed float4 element types only, combines the dynamic valid
+      // size of the packed dimension (dim 0 for ColMajor, dim 1 otherwise)
+      // with the constant 2; all other values pass through unchanged.
+      // NOTE(review): the arithmetic op created here is not identifiable
+      // from this text -- confirm whether it multiplies or divides.
+      auto maybeScaleDynamicValid = [&](Value emitted, int dimIdx) -> Value {
+        if (!emitted || !pto::isPTOFloat4PackedType(elemTy))
+          return emitted;
+        int packedDim = blayout == pto::BLayout::ColMajor ? 0 : 1;
+        if (dimIdx != packedDim)
+          return emitted;
+        auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t");
+        Value two = makeEmitCIntConstant(rewriter, loc, i32Ty, 2);
+        return rewriter.create(loc, i32Ty, emitted, two).getResult();
+      };
+
+      if (forceDynamicValid) {
+        // Forced-dynamic: both valid dims are "-1" in the template and both
+        // are always passed to the constructor, falling back to the constant
+        // or static size when no converted value exists.
+        vrowTok = "-1";
+        vcolTok = "-1";
+        useConstructor = true;
+        constructorArgs.push_back(
+            makeCtorDimValue(maybeScaleDynamicValid(vRowEmitC, 0),
+                             renderTileTemplateDim(rowIsConst ? cRow : rows,
+                                                   elemTy, blayout, 0)));
+        constructorArgs.push_back(
+            makeCtorDimValue(maybeScaleDynamicValid(vColEmitC, 1),
+                             renderTileTemplateDim(colIsConst ? cCol : cols,
+                                                   elemTy, blayout, 1)));
+      } else {
+        if (rowIsConst) {
+          vrowTok = std::to_string(
+              renderTileTemplateDim(cRow, elemTy, blayout, 0));
+        } else if (vRow) {
+          vrowTok = "-1";
+          rowIsDynamic = true;
+          useConstructor = true;
+        } else {
+          vrowTok = std::to_string(
+              renderTileTemplateDim(rows, elemTy, blayout, 0));
+        }
+
+        if (colIsConst) {
+          vcolTok = std::to_string(
+              renderTileTemplateDim(cCol, elemTy, blayout, 1));
+        } else if (vCol) {
+          vcolTok = "-1";
+          colIsDynamic = true;
+          useConstructor = true;
+        } else {
+          vcolTok = std::to_string(
+              renderTileTemplateDim(cols, elemTy, blayout, 1));
+        }
+
+        // Only the dynamic dims are passed, row before col.
+        // NOTE(review): a dynamic dim whose converted value is null is
+        // silently skipped, leaving fewer constructor args than "-1" slots.
+        if (useConstructor) {
+          if (rowIsDynamic && vRowEmitC)
+            constructorArgs.push_back(maybeScaleDynamicValid(vRowEmitC, 0));
+          if (colIsDynamic && vColEmitC)
+            constructorArgs.push_back(maybeScaleDynamicValid(vColEmitC, 1));
+        }
+      }
+
+      std::string tileTypeStr = std::string("Tile<") + roleTok + ", " +
+                                elemTypeStr + ", " +
+                                std::to_string(renderTileTemplateDim(
+                                    rows, elemTy, blayout, 0)) +
+                                ", " +
+                                std::to_string(renderTileTemplateDim(
+                                    cols, elemTy, blayout, 1)) +
+                                ", " + blTok +
+                                ", " + vrowTok + ", " + vcolTok + ", " + slTok +
+                                ", " + std::to_string(fractal) + ", " + padTok +
+                                ", " + compactTok +
+                                ">";
+      return TileBuildSpec{tileTypeStr, useConstructor, constructorArgs};
+    };
+
+    // Materializes a tile of the spec's type: via a call named after the
+    // tile type with the constructor args, or -- when no constructor is
+    // needed or forceDeclaration is set -- via an op taking an empty opaque
+    // attribute (presumably a plain variable declaration).
+    auto buildTileValue = [&](const TileBuildSpec &spec,
+                              bool forceDeclaration = false) -> Value {
+      auto tileType = emitc::OpaqueType::get(ctx, spec.tileTypeStr);
+      if (spec.useConstructor && !forceDeclaration) {
+        return rewriter
+            .create(loc, tileType, spec.tileTypeStr,
+                    ArrayAttr{}, ArrayAttr{},
+                    ValueRange(spec.constructorArgs))
+            .getResult(0);
+      }
+
+      return rewriter
+          .create(loc, tileType, emitc::OpaqueAttr::get(ctx, ""))
+          .getResult();
+    };
+
+    auto emitElemTypeToString = [&](Type elemTy) -> std::string {
+      return getEmitCScalarTypeToken(elemTy);
+    };
+
+    // Produces a uint64_t address for `sourceValue`. Tile-like sources are
+    // first turned into a data value via materializeTileDataValue();
+    // pointer-like values go through `reinterpret_cast<uint64_t>`; values
+    // that are already uint64_t are returned as is; anything else is
+    // converted by the final create() call.
+    auto buildIntegralAddress = [&](Value sourceValue) -> FailureOr {
+      auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t");
+      auto rcU64 =
+          rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")});
+
+      Value rawPtr = sourceValue;
+      if (auto ot = dyn_cast(sourceValue.getType())) {
+        StringRef tyStr = ot.getValue();
+        if (tyStr.contains("Tile<") || tyStr.contains("ConvTile<")) {
+          auto srcMrTy = dyn_cast(op.getSource().getType());
+          if (!srcMrTy)
+            return failure();
+          std::string elemTok = emitElemTypeToString(srcMrTy.getElementType());
+          // Address space defaults to GM when the source has no
+          // memory-space attribute of the expected kind.
+          pto::AddressSpace as = pto::AddressSpace::GM;
+          if (auto asAttr =
+                  dyn_cast_or_null(srcMrTy.getMemorySpace()))
+            as = asAttr.getAddressSpace();
+          rawPtr = materializeTileDataValue(rewriter, loc, sourceValue, as,
+                                            elemTok);
+        }
+      }
+
+      if (isSetFFTsPointerLikeType(rawPtr.getType())) {
+        return rewriter
+            .create(loc, u64Ty, "reinterpret_cast",
+                    ArrayAttr{}, rcU64, ValueRange{rawPtr})
+            .getResult(0);
+      }
+
+      if (rawPtr.getType() == u64Ty)
+        return rawPtr;
+      return rewriter.create(loc, u64Ty, rawPtr).getResult();
+    };
+
+    // Path 1: the source is produced directly by a specific op kind (not
+    // identifiable from this text); only a fresh tile is needed, no TASSIGN.
+    if (op.getSource().getDefiningOp()) {
+      FailureOr tileSpec = buildTileSpec();
+      if (failed(tileSpec))
+        return failure();
+      rewriter.replaceOp(op, buildTileValue(*tileSpec));
+      return success();
+    }
+
+    // Path 2: bitcast of an existing tile -- new tile bound to the same
+    // address.
+    Value tileCandidate = peelAllCasts(adaptor.getSource());
+    if (viewSemantics && viewSemantics.getValue() == "bitcast" &&
+        isTileLike(tileCandidate)) {
+      FailureOr tileSpec = buildTileSpec();
+      if (failed(tileSpec))
+        return failure();
+      Value dstTile = buildTileValue(*tileSpec);
+      FailureOr addr = buildIntegralAddress(tileCandidate);
+      if (failed(addr))
+        return failure();
+
+      rewriter.create(loc, TypeRange{}, "TASSIGN",
+                      ArrayAttr{}, ArrayAttr{},
+                      ValueRange{dstTile, *addr});
+      rewriter.replaceOp(op, dstTile);
+      return success();
+    }
+
+    // Path 3: treshape of an existing tile -- the destination is always
+    // declared (never constructor-built) and filled by TRESHAPE(dst, src).
+    if (viewSemantics && viewSemantics.getValue() == "treshape" &&
+        isTileLike(tileCandidate)) {
+      FailureOr tileSpec = buildTileSpec();
+      if (failed(tileSpec))
+        return failure();
+      Value dstTile = buildTileValue(*tileSpec, /*forceDeclaration=*/true);
+
+      rewriter.create(loc, TypeRange{}, "TRESHAPE",
+                      ArrayAttr{}, ArrayAttr{},
+                      ValueRange{dstTile, tileCandidate});
+      rewriter.replaceOp(op, dstTile);
+      return success();
+    }
+
+    // Subview origins are kept distinct from generic tile rebinding:
+    // even when source/destination C++ tile types match, subview may carry
+    // shifted base address semantics and should materialize a fresh handle.
+    if (isSubView) {
+      FailureOr tileSpec = buildTileSpec();
+      if (failed(tileSpec))
+        return failure();
+      Value dstTile = buildTileValue(*tileSpec);
+      FailureOr addr = buildIntegralAddress(tileCandidate);
+      if (failed(addr))
+        return failure();
+
+      rewriter.create(loc, TypeRange{}, "TASSIGN",
+                      ArrayAttr{}, ArrayAttr{},
+                      ValueRange{dstTile, *addr});
+      rewriter.replaceOp(op, dstTile);
+      return success();
+    }
+
+    // Generic tile-to-tile rebind path: preserve the same backing storage and
+    // rebuild a sibling tile with updated metadata/valid dims.
+    if (isTileLike(tileCandidate)) {
+      FailureOr tileSpec = buildTileSpec();
+      if (failed(tileSpec))
+        return failure();
+
+      // Fast path: identical type string and no constructor needed, so the
+      // source tile can stand in for the result directly.
+      if (!tileSpec->useConstructor) {
+        if (auto srcTy = dyn_cast(tileCandidate.getType())) {
+          if (srcTy.getValue() == tileSpec->tileTypeStr) {
+            rewriter.replaceOp(op, tileCandidate);
+            return success();
+          }
+        }
+      }
+
+      Value dstTile = buildTileValue(*tileSpec);
+      FailureOr addr = buildIntegralAddress(tileCandidate);
+      if (failed(addr))
+        return failure();
+
+      rewriter.create(loc, TypeRange{}, "TASSIGN",
+                      ArrayAttr{}, ArrayAttr{},
+                      ValueRange{dstTile, *addr});
+      rewriter.replaceOp(op, dstTile);
+      return success();
+    }
+
+    // Fallback: the source is not a tile. Trace the original (unconverted)
+    // source back through casts; if it comes from an op exposing getAddrs(),
+    // reuse those addresses, otherwise use the converted source itself.
+    SmallVector physAddrs;
+    Value source = op.getSource();
+
+    while (auto castOp = source.getDefiningOp())
+      source = castOp.getOperand(0);
+
+    if (auto upstreamCast = source.getDefiningOp()) {
+      auto upstreamOperands = upstreamCast.getAddrs();
+      physAddrs.append(upstreamOperands.begin(), upstreamOperands.end());
+    } else {
+      physAddrs.push_back(adaptor.getSource());
+    }
+
+    Value vRow = op.getValidRow();
+    Value vCol = op.getValidCol();
+
+    // Re-create the cast-like op with the same result type and config, and
+    // carry over the view-semantics and force-dynamic attributes.
+    auto newCast = rewriter.create(
+        loc, op.getType(), physAddrs, vRow ? vRow : Value(),
+        vCol ? vCol : Value(), configAttr);
+    if (viewSemantics)
+      newCast->setAttr("pto.view_semantics", viewSemantics);
+    if (op->hasAttr(kForceDynamicValidShapeAttrName))
+      newCast->setAttr(kForceDynamicValidShapeAttrName,
+                       op->getAttr(kForceDynamicValidShapeAttrName));
+    rewriter.replaceOp(op, newCast.getResult());
+
+    return success();
+  }
+};
+
+// Conversion pattern for pto::AllocTileOp. Creates an EmitC tile value of the
+// converted tile type: through a constructor-style call when any valid-shape
+// dim is negative (dynamic), otherwise through an op with an empty opaque
+// initializer. When the op has an address operand, it is converted to
+// uint64_t and bound with TASSIGN(tile, addr).
+struct PTOAllocTileToEmitC
+    : public OpConversionPattern {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::AllocTileOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    MLIRContext *ctx = rewriter.getContext();
+    auto tileTy = cast(op.getResult().getType());
+    auto tileTypeString = getEmitCTileTypeString(tileTy);
+    if (!tileTypeString)
+      return rewriter.notifyMatchFailure(
+          op, "only rank-2 alloc_tile handles can be converted to EmitC");
+
+    // Fall back to an opaque type built from the tile type string when the
+    // type converter has no mapping.
+    Type convertedTy = getTypeConverter()->convertType(tileTy);
+    if (!convertedTy)
+      convertedTy = emitc::OpaqueType::get(ctx, *tileTypeString);
+
+    // Negative valid-shape entries mark dynamic dims.
+    auto validShape = tileTy.getValidShape();
+    bool hasDynamicValidDim =
+        llvm::any_of(validShape, [](int64_t dim) { return dim < 0; });
+    bool useConstructor = hasDynamicValidDim;
+
+    SmallVector constructorArgs;
+    if (useConstructor) {
+      Type elemTy = tileTy.getElementType();
+      pto::BLayout blayout = getTileBufBLayoutValue(tileTy.getConfigAttr());
+      // Same packed-float4 adjustment as in PTOBindTileToEmitC: only the
+      // packed dimension of a packed float4 tile is combined with 2.
+      auto maybeScaleDynamicValid = [&](Value emitted, int dimIdx) -> Value {
+        if (!emitted || !pto::isPTOFloat4PackedType(elemTy))
+          return emitted;
+        int packedDim = blayout == pto::BLayout::ColMajor ? 0 : 1;
+        if (dimIdx != packedDim)
+          return emitted;
+        auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t");
+        Value two = makeEmitCIntConstant(rewriter, loc, i32Ty, 2);
+        return rewriter.create(loc, i32Ty, emitted, two)
+            .getResult();
+      };
+
+      if (validShape.size() > 0 && validShape[0] < 0) {
+        Value validRow = adaptor.getValidRow();
+        if (!validRow)
+          return rewriter.notifyMatchFailure(
+              op, "dynamic alloc_tile valid row must have an operand");
+        // NOTE(review): this check is always true after the early return
+        // above.
+        if (validRow)
+          validRow = peelUnrealized(validRow);
+        constructorArgs.push_back(maybeScaleDynamicValid(validRow, 0));
+      }
+      if (validShape.size() > 1 && validShape[1] < 0) {
+        Value validCol = adaptor.getValidCol();
+        if (!validCol)
+          return rewriter.notifyMatchFailure(
+              op, "dynamic alloc_tile valid col must have an operand");
+        // NOTE(review): this check is always true after the early return
+        // above.
+        if (validCol)
+          validCol = peelUnrealized(validCol);
+        constructorArgs.push_back(maybeScaleDynamicValid(validCol, 1));
+      }
+    }
+
+    Value tile;
+    if (useConstructor) {
+      tile = rewriter
+                 .create(
+                     loc, convertedTy, *tileTypeString, ArrayAttr{},
+                     ArrayAttr{}, ValueRange(constructorArgs))
+                 .getResult(0);
+    } else {
+      tile =
+          rewriter
+              .create(
+                  loc, convertedTy, emitc::OpaqueAttr::get(ctx, ""))
+              .getResult();
+    }
+
+    // Optional address binding. The first branch (which includes opaque
+    // types whose string ends in "*") goes through
+    // `reinterpret_cast<uint64_t>`; other non-uint64_t values are converted
+    // by the second create() call.
+    Value addr = adaptor.getAddr();
+    if (addr) {
+      addr = peelUnrealized(addr);
+      auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t");
+      if (isa(addr.getType()) ||
+          (isa(addr.getType()) &&
+           cast(addr.getType()).getValue().ends_with("*"))) {
+        auto rcU64 =
+            rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")});
+        addr = rewriter
+                   .create(loc, u64Ty, "reinterpret_cast",
+                           ArrayAttr{}, rcU64,
+                           ValueRange{addr})
+                   .getResult(0);
+      } else if (addr.getType() != u64Ty) {
+        addr = rewriter.create(loc, u64Ty, addr).getResult();
+      }
+
+      rewriter.create(loc, TypeRange{}, "TASSIGN",
+                      ArrayAttr{}, ArrayAttr{},
+                      ValueRange{tile, addr});
+    }
+
+    rewriter.replaceOp(op, tile);
+    return success();
+  }
+};
+
+// Creates a value of the EmitC type for `tileTy` using an empty opaque
+// initializer (presumably a variable declaration). The type comes from
+// `typeConverter`, falling back to an opaque type built from the tile type
+// string. Fails when no tile type string can be derived for `tileTy`.
+static FailureOr
+createEmitCTileVariable(ConversionPatternRewriter &rewriter, Location loc,
+                        const TypeConverter *typeConverter,
+                        pto::TileBufType tileTy) {
+  auto tileTypeString = getEmitCTileTypeString(tileTy);
+  if (!tileTypeString)
+    return failure();
+
+  Type convertedTy = typeConverter->convertType(tileTy);
+  if (!convertedTy)
+    convertedTy = emitc::OpaqueType::get(rewriter.getContext(), *tileTypeString);
+
+  return rewriter
+      .create(
+          loc, convertedTy, emitc::OpaqueAttr::get(rewriter.getContext(), ""))
+      .getResult();
+}
+
+// Conversion pattern for pto::TReshapeOp: declares a destination tile of the
+// result type and emits TRESHAPE(dst, src), replacing the op with dst.
+struct PTOTReshapeToEmitC : public OpConversionPattern {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::TReshapeOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto tileTy = dyn_cast(op.getResult().getType());
+    if (!tileTy)
+      return failure();
+
+    FailureOr dst =
+        createEmitCTileVariable(rewriter, op.getLoc(), getTypeConverter(), tileTy);
+    if (failed(dst))
+      return failure();
+
+    // Look through conversion casts and at most one further cast op.
+    Value src = peelUnrealized(adaptor.getSrc());
+    if (auto castOp = src.getDefiningOp())
+      src = castOp.getOperand();
+
+    rewriter.create(op.getLoc(), TypeRange{}, "TRESHAPE",
+                    ArrayAttr{}, ArrayAttr{},
+                    ValueRange{*dst, src});
+    rewriter.replaceOp(op, *dst);
+    return success();
+  }
+};
+
+struct PTOBitcastToEmitC : public OpConversionPattern {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(pto::BitcastOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    auto dstTy = dyn_cast(op.getResult().getType());
+    auto srcTy = dyn_cast(op.getSrc().getType());
+    if (!dstTy || !srcTy)
+      return failure();
+
+    FailureOr dst =
+        createEmitCTileVariable(rewriter, op.getLoc(), getTypeConverter(), dstTy);
+    if (failed(dst))
+      return failure();
+
+    Value src = peelUnrealized(adaptor.getSrc());
+    if (auto castOp = src.getDefiningOp())
+      src = castOp.getOperand();
+
+    pto::AddressSpace as = pto::AddressSpace::GM;
+    if (auto asAttr =
+            dyn_cast_or_null(srcTy.getMemorySpace()))
+      as =
asAttr.getAddressSpace(); + std::string elemTok = getEmitCScalarTypeToken(srcTy.getElementType()); + + Value rawPtr = materializeTileDataValue(rewriter, op.getLoc(), src, as, elemTok); + auto u64Ty = emitc::OpaqueType::get(rewriter.getContext(), "uint64_t"); + Value addr = rawPtr; + if (isSetFFTsPointerLikeType(rawPtr.getType())) { + auto rcU64 = + rewriter.getArrayAttr({emitc::OpaqueAttr::get(rewriter.getContext(), + "uint64_t")}); + addr = rewriter + .create(op.getLoc(), u64Ty, + "reinterpret_cast", ArrayAttr{}, + rcU64, ValueRange{rawPtr}) + .getResult(0); + } else if (addr.getType() != u64Ty) { + addr = rewriter.create(op.getLoc(), u64Ty, addr).getResult(); + } + + rewriter.create(op.getLoc(), TypeRange{}, "TASSIGN", + ArrayAttr{}, ArrayAttr{}, + ValueRange{*dst, addr}); + rewriter.replaceOp(op, *dst); + return success(); + } +}; + +struct PTOMaterializeTileToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + static bool isTileLike(Value v) { + auto ot = dyn_cast(v.getType()); + if (!ot) + return false; + StringRef s = ot.getValue(); + return s.contains("Tile<") || s.contains("ConvTile<"); + } + + LogicalResult matchAndRewrite(pto::MaterializeTileOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + MLIRContext *ctx = rewriter.getContext(); + auto tileTy = cast(op.getResult().getType()); + auto tileTypeString = getEmitCTileTypeString(tileTy); + if (!tileTypeString) + return rewriter.notifyMatchFailure( + op, "only rank-2 tile_buf handles can be materialized to EmitC"); + + Type convertedTy = getTypeConverter()->convertType(tileTy); + if (!convertedTy) + convertedTy = emitc::OpaqueType::get(ctx, *tileTypeString); + + Value source = peelUnrealized(adaptor.getSource()); + if (auto castOp = source.getDefiningOp()) + source = castOp.getOperand(); + + auto viewSemantics = op->getAttrOfType("pto.view_semantics"); + bool forceDynamicValid = 
op->hasAttr(kForceDynamicValidShapeAttrName); + bool isReshape = viewSemantics && viewSemantics.getValue() == "treshape"; + bool isSubview = viewSemantics && viewSemantics.getValue() == "subview"; + bool sourceIsDeclaredTile = + op.getSource().getDefiningOp(); + + auto createTileValue = [&]() -> Value { + SmallVector constructorArgs; + bool useConstructor = false; + pto::BLayout blayout = getTileBufBLayoutValue(tileTy.getConfigAttr()); + Type elemTy = tileTy.getElementType(); + auto shape = tileTy.getShape(); + auto validShape = tileTy.getValidShape(); + + auto makeCtorDimValue = [&](Value emitted, int64_t fallback) -> Value { + if (emitted) + return emitted; + return makeEmitCIntConstant( + rewriter, loc, emitc::OpaqueType::get(ctx, "int32_t"), fallback); + }; + auto maybeScaleDynamicValid = [&](Value emitted, int dimIdx) -> Value { + if (!emitted || !pto::isPTOFloat4PackedType(elemTy)) + return emitted; + int packedDim = blayout == pto::BLayout::ColMajor ? 0 : 1; + if (dimIdx != packedDim) + return emitted; + auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); + Value two = makeEmitCIntConstant(rewriter, loc, i32Ty, 2); + return rewriter.create(loc, i32Ty, emitted, two).getResult(); + }; + auto fallbackDim = [&](int dimIdx) { + return renderTileTemplateDim(shape[dimIdx], elemTy, blayout, dimIdx); + }; + + if (forceDynamicValid) { + useConstructor = true; + constructorArgs.push_back(makeCtorDimValue( + maybeScaleDynamicValid(adaptor.getValidRow(), 0), fallbackDim(0))); + constructorArgs.push_back(makeCtorDimValue( + maybeScaleDynamicValid(adaptor.getValidCol(), 1), fallbackDim(1))); + } else { + if (validShape[0] == ShapedType::kDynamic) { + useConstructor = true; + constructorArgs.push_back(makeCtorDimValue( + maybeScaleDynamicValid(adaptor.getValidRow(), 0), fallbackDim(0))); + } + if (validShape[1] == ShapedType::kDynamic) { + useConstructor = true; + constructorArgs.push_back(makeCtorDimValue( + maybeScaleDynamicValid(adaptor.getValidCol(), 1), 
fallbackDim(1))); + } + } + + if (useConstructor) { + return rewriter + .create(loc, convertedTy, *tileTypeString, + ArrayAttr{}, ArrayAttr{}, + ValueRange(constructorArgs)) + .getResult(0); + } + + return rewriter + .create(loc, convertedTy, + emitc::OpaqueAttr::get(ctx, "")) + .getResult(); + }; + + if (!isSubview && !forceDynamicValid && isTileLike(source)) { + if (auto srcTy = dyn_cast(source.getType())) { + if (srcTy.getValue() == *tileTypeString) { + rewriter.replaceOp(op, source); + return success(); + } + } + } + + Value tile = createTileValue(); + if (sourceIsDeclaredTile) { + rewriter.replaceOp(op, tile); + return success(); + } + + if (isReshape && isTileLike(source)) { + rewriter.create(loc, TypeRange{}, "TRESHAPE", + ArrayAttr{}, ArrayAttr{}, + ValueRange{tile, source}); + rewriter.replaceOp(op, tile); + return success(); + } + + pto::AddressSpace as = pto::AddressSpace::GM; + if (auto asAttr = + dyn_cast_or_null(tileTy.getMemorySpace())) + as = asAttr.getAddressSpace(); + std::string elemTok = getEmitCScalarTypeToken(tileTy.getElementType()); + + Value rawPtr = source; + if (isTileLike(rawPtr)) + rawPtr = materializeTileDataValue(rewriter, loc, rawPtr, as, elemTok); + + auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); + Value addr = rawPtr; + if (isSetFFTsPointerLikeType(rawPtr.getType())) { + auto rcU64 = + rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); + addr = rewriter + .create(loc, u64Ty, "reinterpret_cast", + ArrayAttr{}, rcU64, + ValueRange{rawPtr}) + .getResult(0); + } else if (rawPtr.getType() != u64Ty) { + addr = rewriter.create(loc, u64Ty, rawPtr).getResult(); + } + + rewriter.create(loc, TypeRange{}, "TASSIGN", + ArrayAttr{}, ArrayAttr{}, + ValueRange{tile, addr}); + rewriter.replaceOp(op, tile); + return success(); + } +}; + +// ============================================================================= +// Arith CmpI -> EmitC Cmp +// ============================================================================= 
+class ArithCmpIToEmitC : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + // Convert arith.cmpi to emitc.cmp. + // Map the arith predicate to its EmitC counterpart: eq -> eq, slt -> lt, etc. + emitc::CmpPredicate emitcPred = emitc::CmpPredicate::eq; + const bool isUnsignedPred = + op.getPredicate() == arith::CmpIPredicate::ult || + op.getPredicate() == arith::CmpIPredicate::ule || + op.getPredicate() == arith::CmpIPredicate::ugt || + op.getPredicate() == arith::CmpIPredicate::uge; + switch (op.getPredicate()) { + case arith::CmpIPredicate::eq: emitcPred = emitc::CmpPredicate::eq; break; + case arith::CmpIPredicate::ne: emitcPred = emitc::CmpPredicate::ne; break; + case arith::CmpIPredicate::slt: emitcPred = emitc::CmpPredicate::lt; break; + case arith::CmpIPredicate::sle: emitcPred = emitc::CmpPredicate::le; break; + case arith::CmpIPredicate::sgt: emitcPred = emitc::CmpPredicate::gt; break; + case arith::CmpIPredicate::sge: emitcPred = emitc::CmpPredicate::ge; break; + // Unsigned comparisons (ult, ule, etc.) use the same predicates; operands are cast to unsigned below. + case arith::CmpIPredicate::ult: emitcPred = emitc::CmpPredicate::lt; break; + case arith::CmpIPredicate::ule: emitcPred = emitc::CmpPredicate::le; break; + case arith::CmpIPredicate::ugt: emitcPred = emitc::CmpPredicate::gt; break; + case arith::CmpIPredicate::uge: emitcPred = emitc::CmpPredicate::ge; break; + } + + Type resTy = getTypeConverter()->convertType(op.getType()); + if (!resTy) + return failure(); + + Value lhs = adaptor.getLhs(); + Value rhs = adaptor.getRhs(); + if (isUnsignedPred) { + Type opTy = op.getLhs().getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure( + op, "expected scalar integer or index operands"); + + const unsigned bitWidth = + intTy ? 
intTy.getWidth() : static_cast(kPTOIndexBitWidth); + if (bitWidth != 1) { + lhs = castSignlessIntToUnsignedSameWidth(rewriter, loc, lhs, bitWidth); + rhs = castSignlessIntToUnsignedSameWidth(rewriter, loc, rhs, bitWidth); + } + } + + rewriter.replaceOpWithNewOp( + op, + /*resultType=*/resTy, // i1 -> bool/i1 + emitcPred, + lhs, + rhs + ); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Section Op Lowering +//===----------------------------------------------------------------------===// +static bool isA5NoSplitPipeOp(Operation *op) { + if (auto talloc = dyn_cast(op)) + return talloc.getSplit() == 0; + if (auto tpush = dyn_cast(op)) + return tpush.getSplit() == 0; + if (auto tpop = dyn_cast(op)) + return tpop.getSplit() == 0; + if (auto tfree = dyn_cast(op)) + return tfree.getSplit() == 0; + if (auto tpush = dyn_cast(op)) + return tpush.getSplit() == 0; + if (auto tpush = dyn_cast(op)) + return tpush.getSplit() == 0; + if (auto talloc = dyn_cast(op)) + return talloc.getSplit() == 0; + if (auto talloc = dyn_cast(op)) + return talloc.getSplit() == 0; + if (auto tpop = dyn_cast(op)) + return tpop.getSplit() == 0; + if (auto tpop = dyn_cast(op)) + return tpop.getSplit() == 0; + if (auto tfree = dyn_cast(op)) + return tfree.getSplit() == 0; + if (auto tfree = dyn_cast(op)) + return tfree.getSplit() == 0; + return false; +} + +static bool hasExplicitSubblockControl(Operation *op) { + bool hasControl = false; + op->walk([&](Operation *nested) { + if (isa(nested)) { + hasControl = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return hasControl; +} + +static bool needsA5NoSplitVectorGuard(Operation *op) { + auto arch = getTargetArch(op); + if (arch != PTOArch::A5) + return false; + bool isVectorScope = isa(op); + if (auto func = dyn_cast(op)) { + if (auto kernelKindAttr = + func->getAttrOfType( + FunctionKernelKindAttr::name)) { + isVectorScope = + 
kernelKindAttr.getKernelKind() == FunctionKernelKind::Vector; + } + } + if (!isVectorScope) + return false; + if (hasExplicitSubblockControl(op)) + return false; + + bool hasNoSplitPipe = false; + op->walk([&](Operation *nested) { + if (!isA5NoSplitPipeOp(nested)) + return WalkResult::advance(); + hasNoSplitPipe = true; + return WalkResult::interrupt(); + }); + return hasNoSplitPipe; +} + +template +struct SectionToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + std::string getMacroName() const { + if (std::is_same::value) + return "__DAV_CUBE__"; + if (std::is_same::value) + return "__DAV_VEC__"; + return "UNKNOWN_MACRO"; + } + + LogicalResult + matchAndRewrite(SectionOpTy op, typename SectionOpTy::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + bool needsNoSplitGuard = needsA5NoSplitVectorGuard(op.getOperation()); + + std::string startMacro = "\n#if defined(" + getMacroName() + ")"; + rewriter.create(loc, startMacro); + + if constexpr (std::is_same_v) { + // Vector mask is a global HW state and may be modified by previous kernels + // (or earlier sections). Reset it to a well-defined state for deterministic + // execution of VEC ops. 
+ rewriter.create(loc, "set_mask_norm();"); + rewriter.create(loc, "set_vector_mask(-1, -1);"); + } + + if (needsNoSplitGuard) { + rewriter.create( + loc, "if (get_subblockid() == 0) {"); + } + + Block &innerBlock = op.getBody().front(); + if (!innerBlock.empty()) { + rewriter.inlineBlockBefore(&innerBlock, op.getOperation(), ValueRange{}); + } + + if (needsNoSplitGuard) + rewriter.create(loc, "}"); + + std::string endMacro = "#endif // " + getMacroName() + "\n"; + rewriter.create(loc, endMacro); + + rewriter.eraseOp(op); + + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// SCF Control-Flow Pre-Lowering +// +// EmitC translation supports `emitc.for`/`emitc.if` plus CFG-style +// `cf.br`/`cf.cond_br`. Upstream SCFToEmitC patterns only cover `scf.for` and +// `scf.if`, so we pre-lower some SCF ops into those supported forms. +//===----------------------------------------------------------------------===// + +namespace { + +static bool isTriviallyInlineableExecuteRegion(scf::ExecuteRegionOp op) { + Region &r = op.getRegion(); + if (!r.hasOneBlock()) + return false; + Block &b = r.front(); + return isa_and_nonnull(b.getTerminator()); +} + +static bool needsWholeFunctionSCFToCF(func::FuncOp func) { + bool needs = false; + func.walk([&](Operation *op) { + if (!isa(op)) + return WalkResult::advance(); + Operation *parentOp = op->getParentOp(); + + // `scf.execute_region` can legally appear in single-block parents. Only + // require whole-function SCFToCF if we need to lower it into CFG blocks + // (multi-block region / non-trivial terminators). 
+ if (auto exec = dyn_cast(op)) { + if (parentOp && parentOp->hasTrait() && + !isTriviallyInlineableExecuteRegion(exec)) { + needs = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + } + + if (parentOp && parentOp->hasTrait()) { + needs = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return needs; +} + +// scf.execute_region is semantically just an inlined region producing results +// via scf.yield. Inline it to the parent block to avoid extra lowering needs. +struct SCFExecuteRegionInline + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::ExecuteRegionOp op, + PatternRewriter &rewriter) const override { + if (op.getRegion().empty()) + return rewriter.notifyMatchFailure(op, "expected non-empty region"); + + Block &innerBlock = op.getRegion().front(); + auto yield = dyn_cast(innerBlock.getTerminator()); + if (!yield) + return rewriter.notifyMatchFailure(op, "expected scf.yield terminator"); + + // Move the body operations before the execute_region op. + rewriter.inlineBlockBefore(&innerBlock, op.getOperation(), ValueRange{}); + + // Replace execute_region results with yielded values, then erase the yield. + rewriter.replaceOp(op, yield.getOperands()); + rewriter.eraseOp(yield); + return success(); + } +}; + +// Lower scf.execute_region into CFG blocks with cf.br/cf.cond_br by inlining the +// region blocks into the parent region and rewriting scf.yield to branch into a +// continuation block carrying results. +// +// Note: This requires the parent region to allow multiple blocks (e.g. the +// function body CFG region). For execute_region nested in single-block regions +// (scf.for/scf.if), run SCFToCF first to eliminate the single-block constraint. 
+struct SCFExecuteRegionToCF : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::ExecuteRegionOp op, + PatternRewriter &rewriter) const override { + if (isTriviallyInlineableExecuteRegion(op)) + return rewriter.notifyMatchFailure(op, "trivially inlineable"); + + Operation *parentOp = op->getParentOp(); + if (parentOp && parentOp->hasTrait()) { + return rewriter.notifyMatchFailure( + op, "cannot lower scf.execute_region inside a single-block parent region"); + } + + if (op.getRegion().empty()) + return rewriter.notifyMatchFailure(op, "expected non-empty region"); + + Location loc = op.getLoc(); + Block *curBlock = op->getBlock(); + Region *parentRegion = curBlock->getParent(); + + // Split the parent block so we can branch to a continuation block with phi + // arguments for the execute_region results. + auto execIt = Block::iterator(op.getOperation()); + Block *continueBlock = rewriter.splitBlock(curBlock, std::next(execIt)); + + SmallVector contArgs; + contArgs.reserve(op.getNumResults()); + for (Type t : op.getResultTypes()) + contArgs.push_back(continueBlock->addArgument(t, loc)); + + for (auto it : llvm::enumerate(op.getResults())) + it.value().replaceAllUsesWith(contArgs[it.index()]); + + // Capture blocks before moving the region. + SmallVector movedBlocks; + movedBlocks.reserve(op.getRegion().getBlocks().size()); + for (Block &b : op.getRegion()) + movedBlocks.push_back(&b); + Block *entryBlock = &op.getRegion().front(); + + // Inline the execute_region blocks into the parent region right before the + // continuation block. + rewriter.inlineRegionBefore(op.getRegion(), *parentRegion, + continueBlock->getIterator()); + + // Replace all scf.yield terminators with a branch to the continuation. 
+ for (Block *b : movedBlocks) { + auto yield = dyn_cast(b->getTerminator()); + if (!yield) + continue; + rewriter.setInsertionPoint(yield); + rewriter.create(loc, continueBlock, yield.getOperands()); + rewriter.eraseOp(yield); + } + + // Replace execute_region itself with a branch to the inlined entry block. + rewriter.setInsertionPoint(op); + rewriter.create(loc, entryBlock, ValueRange{}); + rewriter.eraseOp(op); + return success(); + } +}; + +// Lower scf.index_switch into CFG blocks with cf.cond_br/cf.br so that we can +// avoid `scf.if` result materialization quirks (and avoid relying on cf.switch, +// which is not supported by EmitC C++ translation). +struct SCFIndexSwitchToCF : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + static LogicalResult cloneYieldingBlockAndBranchTo( + PatternRewriter &rewriter, Location loc, Block &srcBlock, Block *destBlock, + Block *continueBlock) { + rewriter.setInsertionPointToEnd(destBlock); + + IRMapping mapping; + for (Operation &inner : srcBlock.without_terminator()) + rewriter.clone(inner, mapping); + + auto yield = dyn_cast(srcBlock.getTerminator()); + if (!yield) + return failure(); + + SmallVector yieldOperands; + yieldOperands.reserve(yield.getNumOperands()); + for (Value v : yield.getOperands()) + yieldOperands.push_back(mapping.lookupOrDefault(v)); + + rewriter.create(loc, continueBlock, yieldOperands); + return success(); + } + + static Block *splitBlockForContinuation(PatternRewriter &rewriter, + scf::IndexSwitchOp op) { + auto switchIt = Block::iterator(op.getOperation()); + return rewriter.splitBlock(op->getBlock(), std::next(switchIt)); + } + + static void addContinuationArguments(PatternRewriter &rewriter, + scf::IndexSwitchOp op, Location loc, + Block *continueBlock) { + SmallVector contArgs; + contArgs.reserve(op.getNumResults()); + for (Type type : op.getResultTypes()) + contArgs.push_back(continueBlock->addArgument(type, loc)); + for (auto result : llvm::enumerate(op.getResults())) 
+ result.value().replaceAllUsesWith(contArgs[result.index()]); + } + + static void createIndexSwitchBlocks(PatternRewriter &rewriter, + Region *parentRegion, + Region::iterator insertPt, + unsigned numCases, + SmallVectorImpl &checkBlocks, + Block *&defaultBlock, + SmallVectorImpl &caseBlocks) { + checkBlocks.reserve(numCases); + caseBlocks.reserve(numCases); + for (unsigned i = 0; i < numCases; ++i) + checkBlocks.push_back(rewriter.createBlock(parentRegion, insertPt)); + defaultBlock = rewriter.createBlock(parentRegion, insertPt); + for (unsigned i = 0; i < numCases; ++i) + caseBlocks.push_back(rewriter.createBlock(parentRegion, insertPt)); + } + + static void populateIndexSwitchCheckBlocks( + PatternRewriter &rewriter, Location loc, Value selector, + ArrayRef cases, ArrayRef checkBlocks, + ArrayRef caseBlocks, Block *defaultBlock) { + for (unsigned i = 0; i < checkBlocks.size(); ++i) { + rewriter.setInsertionPointToEnd(checkBlocks[i]); + Value caseVal = rewriter.create(loc, cases[i]); + Value cond = rewriter.create( + loc, arith::CmpIPredicate::eq, selector, caseVal); + Block *falseDest = + (i + 1 < checkBlocks.size()) ? 
checkBlocks[i + 1] : defaultBlock; + rewriter.create(loc, cond, caseBlocks[i], ValueRange{}, + falseDest, ValueRange{}); + } + } + + LogicalResult matchAndRewrite(scf::IndexSwitchOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Operation *parentOp = op->getParentOp(); + if (parentOp && parentOp->hasTrait()) { + return rewriter.notifyMatchFailure( + op, "cannot lower scf.index_switch inside a single-block parent region"); + } + + Block *curBlock = op->getBlock(); + Region *parentRegion = curBlock->getParent(); + Block *continueBlock = splitBlockForContinuation(rewriter, op); + addContinuationArguments(rewriter, op, loc, continueBlock); + + unsigned numCases = op.getCases().size(); + auto insertPt = continueBlock->getIterator(); + + SmallVector checkBlocks; + SmallVector caseBlocks; + Block *defaultBlock = nullptr; + createIndexSwitchBlocks(rewriter, parentRegion, insertPt, numCases, + checkBlocks, defaultBlock, caseBlocks); + + Value selector = op.getArg(); + auto cases = op.getCases(); + populateIndexSwitchCheckBlocks(rewriter, loc, selector, cases, checkBlocks, + caseBlocks, defaultBlock); + + // Fill case blocks and default block with cloned bodies + branch to cont. + for (unsigned i = 0; i < numCases; ++i) { + if (failed(cloneYieldingBlockAndBranchTo( + rewriter, loc, op.getCaseBlock(i), caseBlocks[i], continueBlock))) + return rewriter.notifyMatchFailure(op, "expected scf.yield terminator"); + } + if (failed(cloneYieldingBlockAndBranchTo(rewriter, loc, op.getDefaultBlock(), + defaultBlock, continueBlock))) + return rewriter.notifyMatchFailure(op, "expected scf.yield terminator"); + + // Replace the original switch op with a branch into the check chain. + Block *entryDest = numCases ? checkBlocks[0] : defaultBlock; + rewriter.setInsertionPointAfter(op); + rewriter.create(loc, entryDest, ValueRange{}); + rewriter.eraseOp(op); + return success(); + } +}; + +// Lower scf.while into CFG blocks with cf.br/cf.cond_br. 
+// +// Note: This requires the parent region to allow multiple blocks. In +// particular, scf.if/scf.for regions are single-block and cannot contain this +// lowering. +struct SCFWhileToCF : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + static LogicalResult validateWhileResultUses(scf::WhileOp op) { + Block *parentBlock = op->getBlock(); + for (Value result : op.getResults()) { + for (OpOperand &use : result.getUses()) { + if (use.getOwner()->getBlock() != parentBlock) + return failure(); + } + } + return success(); + } + + static Block *splitAfterWhileBlock(PatternRewriter &rewriter, + scf::WhileOp op) { + auto whileIt = Block::iterator(op.getOperation()); + return rewriter.splitBlock(op->getBlock(), std::next(whileIt)); + } + + static void addWhileExitArguments(PatternRewriter &rewriter, scf::WhileOp op, + Location loc, Block *afterWhileBlock) { + SmallVector exitArgs; + exitArgs.reserve(op.getNumResults()); + for (Type type : op.getResultTypes()) + exitArgs.push_back(afterWhileBlock->addArgument(type, loc)); + for (auto result : llvm::enumerate(op.getResults())) + result.value().replaceAllUsesWith(exitArgs[result.index()]); + } + + static Block *createWhileHeaderBlock(PatternRewriter &rewriter, + scf::WhileOp op, Location loc, + Block *afterWhileBlock) { + SmallVector headerArgTypes; + for (Value init : op.getInits()) + headerArgTypes.push_back(init.getType()); + SmallVector headerArgLocs(headerArgTypes.size(), loc); + return rewriter.createBlock(afterWhileBlock->getParent(), + afterWhileBlock->getIterator(), headerArgTypes, + headerArgLocs); + } + + static Block *createWhileBodyBlock(PatternRewriter &rewriter, scf::WhileOp op, + Location loc, Block *afterWhileBlock) { + Block &afterRegionBlock = op.getAfter().front(); + SmallVector bodyArgTypes(afterRegionBlock.getArgumentTypes().begin(), + afterRegionBlock.getArgumentTypes().end()); + SmallVector bodyArgLocs(bodyArgTypes.size(), loc); + return 
rewriter.createBlock(afterWhileBlock->getParent(), + afterWhileBlock->getIterator(), bodyArgTypes, + bodyArgLocs); + } + + static void rewriteWhileTerminators(PatternRewriter &rewriter, Location loc, + Block *headerBlock, Block *bodyBlock, + Block *afterWhileBlock) { + auto condOp = cast(headerBlock->getTerminator()); + rewriter.setInsertionPoint(condOp); + rewriter.create(loc, condOp.getCondition(), + /*trueDest=*/bodyBlock, + /*trueOperands=*/condOp.getArgs(), + /*falseDest=*/afterWhileBlock, + /*falseOperands=*/condOp.getArgs()); + rewriter.eraseOp(condOp); + + auto yieldOp = cast(bodyBlock->getTerminator()); + rewriter.setInsertionPoint(yieldOp); + rewriter.create(loc, headerBlock, yieldOp.getOperands()); + rewriter.eraseOp(yieldOp); + } + + LogicalResult matchAndRewrite(scf::WhileOp op, + PatternRewriter &rewriter) const override { + Operation *parentOp = op->getParentOp(); + if (parentOp && parentOp->hasTrait()) { + return rewriter.notifyMatchFailure( + op, "cannot lower scf.while inside a single-block parent region"); + } + + if (failed(validateWhileResultUses(op))) + return rewriter.notifyMatchFailure( + op, "unsupported: while results used outside the parent block"); + + auto loc = op.getLoc(); + Block *afterWhileBlock = splitAfterWhileBlock(rewriter, op); + addWhileExitArguments(rewriter, op, loc, afterWhileBlock); + Block *headerBlock = createWhileHeaderBlock(rewriter, op, loc, + afterWhileBlock); + Block *bodyBlock = createWhileBodyBlock(rewriter, op, loc, afterWhileBlock); + + // Move the before/after region bodies into the new CFG blocks. + Block &afterRegionBlock = op.getAfter().front(); + rewriter.mergeBlocks(&op.getBefore().front(), headerBlock, + headerBlock->getArguments()); + rewriter.mergeBlocks(&afterRegionBlock, bodyBlock, bodyBlock->getArguments()); + rewriteWhileTerminators(rewriter, loc, headerBlock, bodyBlock, + afterWhileBlock); + + // Replace scf.while itself with a branch to the header. 
+ rewriter.setInsertionPoint(op); + rewriter.create(loc, headerBlock, op.getInits()); + rewriter.eraseOp(op); + return success(); + } +}; + +// Lower cf.switch into chained comparisons and cf.cond_br/cf.br. +// +// EmitC C++ translation currently supports cf.br/cf.cond_br, but not cf.switch. +struct CFSwitchToCondBr : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + static SmallVector> + collectSwitchCaseOperands(cf::SwitchOp op) { + SmallVector> caseOperands; + caseOperands.reserve(op.getCaseDestinations().size()); + for (auto range : op.getCaseOperands()) + caseOperands.emplace_back(range.begin(), range.end()); + return caseOperands; + } + + static SmallVector getSwitchCaseValues(cf::SwitchOp op) { + SmallVector caseValues; + if (auto caseValuesAttr = op.getCaseValues()) { + for (APInt value : caseValuesAttr->getValues()) + caseValues.push_back(value); + } + return caseValues; + } + + static SmallVector createSwitchCheckBlocks(PatternRewriter &rewriter, + Region *parentRegion, + Block *curBlock, + size_t numCases) { + auto insertPt = std::next(curBlock->getIterator()); + SmallVector checkBlocks; + checkBlocks.reserve(numCases); + for (size_t i = 0; i < numCases; ++i) + checkBlocks.push_back(rewriter.createBlock(parentRegion, insertPt)); + return checkBlocks; + } + + static LogicalResult populateSwitchCheckBlocks( + PatternRewriter &rewriter, Location loc, Value flag, IntegerType flagTy, + ArrayRef caseValues, ArrayRef caseDests, + ArrayRef> caseOperands, Block *defaultDest, + ValueRange defaultOperands, ArrayRef checkBlocks, + cf::SwitchOp op) { + for (size_t i = 0; i < caseDests.size(); ++i) { + rewriter.setInsertionPointToEnd(checkBlocks[i]); + APInt caseVal = caseValues[i]; + if (caseVal.getBitWidth() != flagTy.getWidth()) { + return rewriter.notifyMatchFailure( + op, "case value bitwidth doesn't match flag type"); + } + + Value caseConst = rewriter.create( + loc, flagTy, rewriter.getIntegerAttr(flagTy, caseVal)); + Value cond = 
rewriter.create( + loc, arith::CmpIPredicate::eq, flag, caseConst); + Block *falseDest = + (i + 1 < checkBlocks.size()) ? checkBlocks[i + 1] : defaultDest; + ValueRange falseOperands = + (i + 1 < checkBlocks.size()) ? ValueRange{} : defaultOperands; + rewriter.create(loc, cond, caseDests[i], + caseOperands[i], falseDest, + falseOperands); + } + return success(); + } + + LogicalResult matchAndRewrite(cf::SwitchOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Operation *parentOp = op->getParentOp(); + if (parentOp && parentOp->hasTrait()) { + return rewriter.notifyMatchFailure( + op, "cannot lower cf.switch inside a single-block parent region"); + } + + Block *curBlock = op->getBlock(); + Region *parentRegion = curBlock->getParent(); + + Value flag = op.getFlag(); + auto flagTy = dyn_cast(flag.getType()); + if (!flagTy) + return rewriter.notifyMatchFailure(op, "expected integer switch flag"); + + SmallVector defaultOperands(op.getDefaultOperands().begin(), + op.getDefaultOperands().end()); + Block *defaultDest = op.getDefaultDestination(); + + SmallVector caseDests(op.getCaseDestinations().begin(), + op.getCaseDestinations().end()); + SmallVector> caseOperands = collectSwitchCaseOperands(op); + + if (caseDests.empty()) { + rewriter.replaceOpWithNewOp(op, defaultDest, defaultOperands); + return success(); + } + + if (!op.getCaseValues()) + return rewriter.notifyMatchFailure(op, "missing case_values"); + SmallVector caseValues = getSwitchCaseValues(op); + + if (caseValues.size() != caseDests.size()) + return rewriter.notifyMatchFailure(op, "case_values/destinations mismatch"); + if (caseOperands.size() != caseDests.size()) + return rewriter.notifyMatchFailure(op, "case_operands/destinations mismatch"); + + SmallVector checkBlocks = + createSwitchCheckBlocks(rewriter, parentRegion, curBlock, + caseDests.size()); + if (failed(populateSwitchCheckBlocks(rewriter, loc, flag, flagTy, + caseValues, caseDests, caseOperands, + defaultDest, 
defaultOperands, + checkBlocks, op))) { + return failure(); + } + + // Replace the switch terminator with a branch into the first check block. + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp(op, checkBlocks.front(), + ValueRange{}); + return success(); + } +}; + +} // namespace + +static void populatePTOToEmitCPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter, + MLIRContext *ctx, + DataFlowSolver &solver, + PTOArch targetArch) { + (void)solver; + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx, "pto.set_flag_dyn", + "set_flag"); + patterns.add(typeConverter, ctx, "pto.wait_flag_dyn", + "wait_flag"); + // Backward-compatible aliases used in some downstream branches. + patterns.add(typeConverter, ctx, "pto.set_flag_d", + "set_flag"); + patterns.add(typeConverter, ctx, "pto.wait_flag_d", + "wait_flag"); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + 
patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add( + typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + 
patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add( + typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add>( + typeConverter, ctx); + patterns.add>( + typeConverter, ctx); + patterns.add>( + typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add>(typeConverter, + ctx); + patterns.add>(typeConverter, + ctx); + patterns.add>(typeConverter, + ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add>(typeConverter, 
ctx); + patterns.add>(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add>(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add>(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add>( + typeConverter, ctx, + "pto::comm::TPUT_ASYNC"); + patterns.add>( + typeConverter, ctx, + "pto::comm::TGET_ASYNC"); + patterns.add>(typeConverter, ctx, + "pto::comm::TPUT"); + patterns.add>(typeConverter, ctx, + "pto::comm::TGET"); + patterns.add>(typeConverter, ctx, + "pto::comm::TNOTIFY"); + patterns.add>(typeConverter, ctx, + "pto::comm::TWAIT"); + patterns.add>(typeConverter, ctx, + "pto::comm::TTEST"); + patterns.add>(typeConverter, ctx, + "TBROADCAST"); + patterns.add>(typeConverter, ctx, + "TGATHER"); + patterns.add>(typeConverter, ctx, + "TSCATTER"); + patterns.add>(typeConverter, ctx, + "TREDUCE"); + patterns.add>( + typeConverter, ctx, "PTOAS__ASYNC_EVENT_WAIT"); + patterns.add>( + typeConverter, ctx, "PTOAS__ASYNC_EVENT_TEST"); 
+ patterns.add(typeConverter, ctx, targetArch); + patterns.add(typeConverter, ctx, targetArch); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx, targetArch); + patterns.add(typeConverter, ctx, targetArch); + patterns.add(typeConverter, ctx, targetArch); + patterns.add(typeConverter, ctx, targetArch); + patterns.add(typeConverter, ctx, targetArch); + patterns.add(typeConverter, ctx, targetArch); + patterns.add>(typeConverter, ctx); + patterns.add>(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add< + PTOTMatmulBiasToTMATMUL_BIAS, + PTOTMatmulMXToTMATMUL_MX, + PTOTMatmulMXAccToTMATMUL_MX_ACC, + PTOTMatmulMXBiasToTMATMUL_MX_BIAS, + PTOTMatmulBiasToTMATMUL_BIAS, + PTOTMatmulMXToTMATMUL_MX, + PTOTMatmulMXAccToTMATMUL_MX_ACC, + PTOTMatmulMXBiasToTMATMUL_MX_BIAS, + PTOTGemvBiasToTGEMV_BIAS, + PTOTGemvMXToTGEMV_MX, + PTOTGemvMXAccToTGEMV_MX, + PTOTGemvMXBiasToTGEMV_MX, + PTOBarrierToEmitC + >(typeConverter, ctx); + + patterns.add(typeConverter, ctx); + + populateSCFToEmitCConversionPatterns(patterns); + // Keep CFG-style branches type-consistent when block argument types are + // converted (e.g. after lowering scf.while to cf.br/cf.cond_br). 
+ populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter); +} + +//===----------------------------------------------------------------------===// +// Pass +//===----------------------------------------------------------------------===// + +namespace { +struct EmitPTOManualPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(EmitPTOManualPass) + + PTOArch targetArch; + + EmitPTOManualPass() : targetArch(PTOArch::A3) {} + + explicit EmitPTOManualPass(PTOArch arch) : targetArch(arch) {} + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + LLVM_DEBUG(llvm::dbgs() << "DEBUG: Start PTOToEmitC Pass\n"); + MLIRContext *ctx = &getContext(); + ModuleOp mop = getOperation(); + + if (failed(pto::validatePTOEntryFunctions(mop))) + return signalPassFailure(); + pto::annotatePTOEntryFunctions(mop); + + // A3 requires explicit FFTS base setup for inter-core sync ops. + if (targetArch == PTOArch::A3) { + bool hasMissingSetFFTs = false; + for (auto func : mop.getOps()) { + if (!hasInterCoreSyncOp(func)) + continue; + if (hasSetFFTsOp(func)) + continue; + hasMissingSetFFTs = true; + func.emitError() + << "A3 inter-core sync requires explicit `pto.set_ffts` in the " + "same function when using `pto.sync.set`/`pto.sync.wait`"; + } + if (hasMissingSetFFTs) + return signalPassFailure(); + } + + bool needsEventIdArrayHelper = false; + bool needsTRandomHelper = false; + bool needsGlobalTensorDataHelper = false; + bool needsCommInclude = false; + mop.walk([&](Operation *op) { + if (isa(op)) + needsEventIdArrayHelper = true; + if (isa(op)) + needsTRandomHelper = true; + if (isa(op)) + needsGlobalTensorDataHelper = true; + if (isa(op)) + needsCommInclude = true; + }); + + // 1. 
插入头文件 + auto loc = mop->getLoc(); + OpBuilder builder(ctx); + builder.setInsertionPointToStart(mop.getBody()); + builder.create( + loc, "pto/pto-inst.hpp", /*is_standard_include=*/false); + if (needsCommInclude) { + builder.create( + loc, builder.getStringAttr(R"cpp( +#ifndef PIPE_FIX +#define PIPE_FIX PIPE_M +#endif +)cpp")); + builder.create( + loc, "pto/comm/pto_comm_inst.hpp", /*is_standard_include=*/false); + } + builder.create( + loc, builder.getStringAttr("using namespace pto;")); + if (needsGlobalTensorDataHelper) { + builder.create( + loc, builder.getStringAttr(R"cpp( +template +static AICORE inline auto PTOAS__GLOBAL_TENSOR_DATA(Tensor &tensor) + -> decltype(tensor.data()) { + return tensor.data(); +} +)cpp")); + } + if (needsEventIdArrayHelper) { + builder.create( + loc, builder.getStringAttr(R"cpp( +template +struct PTOAS_EventIdArray { + static_assert(N > 0, "PTOAS_EventIdArray requires a positive static size"); + int32_t data[N] = {}; + + AICORE inline int32_t &operator[](int32_t idx) { return data[idx]; } + AICORE inline const int32_t &operator[](int32_t idx) const { return data[idx]; } +}; +)cpp")); + } + if (needsTRandomHelper) { + builder.create( + loc, builder.getStringAttr(R"cpp( +template +static AICORE inline void PTOAS__TRANDOM( + DstTile &dst, uint32_t key0, uint32_t key1, uint32_t counter0, + uint32_t counter1, uint32_t counter2, uint32_t counter3) { + TRandomKey key = {key0, key1}; + TRandomCounter counter = {counter0, counter1, counter2, counter3}; + TRANDOM(dst, key, counter); +} +)cpp")); + } + builder.create( + loc, builder.getStringAttr(R"cpp( +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static AICORE inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case 
PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} +)cpp")); + // Only inject the bitcast helper when we actually lower ops that need it + // (e.g. arith.bitcast or arith.maximumf/minimumf tie-breaking on zeros). + bool needsBitcastHelper = false; + mop.walk([&](Operation *op) { + if (isa(op)) { + needsBitcastHelper = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + if (needsBitcastHelper) { + builder.create( + loc, builder.getStringAttr(R"cpp( + template + static inline To ptoas_bitcast(From from) { + static_assert(sizeof(To) == sizeof(From), "ptoas_bitcast: size mismatch"); + To to; + __builtin_memcpy(&to, &from, sizeof(To)); + return to; + } + )cpp")); + } + + // 1.5 Pre-lower SCF constructs not handled by SCFToEmitC. + { + // scf.while / scf.index_switch are lowered via CFG blocks. This is not + // possible inside ops that require single-block regions (e.g. scf.for / + // scf.if). If we see such nesting, lower the entire function to the + // ControlFlow dialect first. + bool needsAnySCFToCF = false; + for (auto func : mop.getOps()) { + if (needsWholeFunctionSCFToCF(func)) { + needsAnySCFToCF = true; + break; + } + } + if (needsAnySCFToCF) { + RewritePatternSet scfToCfPatterns(ctx); + populateSCFToControlFlowConversionPatterns(scfToCfPatterns); + FrozenRewritePatternSet frozenSCFToCF(std::move(scfToCfPatterns)); + + ConversionTarget scfToCfTarget(*ctx); + // Only eliminate the single-block SCF constructs; we'll pre-lower + // scf.while/index_switch/execute_region ourselves afterwards. 
+ scfToCfTarget.addIllegalOp(); + scfToCfTarget.markUnknownOpDynamicallyLegal( + [](Operation *) { return true; }); + + for (auto func : mop.getOps()) { + if (!needsWholeFunctionSCFToCF(func)) + continue; + if (failed(applyPartialConversion(func, scfToCfTarget, + frozenSCFToCF))) { + func.emitError() + << "failed to lower nested SCF to ControlFlow (SCFToCF)"; + return signalPassFailure(); + } + } + } + + RewritePatternSet scfLoweringPatterns(ctx); + scfLoweringPatterns.add(ctx); + (void)applyPatternsAndFoldGreedily(mop, std::move(scfLoweringPatterns)); + + bool hasUnsupportedSCF = false; + mop.walk([&](Operation *op) { + if (isa(op)) { + hasUnsupportedSCF = true; + op->emitError() << "Unsupported SCF op remained after pre-lowering"; + return WalkResult::interrupt(); + } + if (isa(op)) { + hasUnsupportedSCF = true; + op->emitError() + << "Unsupported CF op remained after pre-lowering: cf.switch"; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + if (hasUnsupportedSCF) + return signalPassFailure(); + } + + PTOToEmitCTypeConverter typeConverter(ctx, targetArch); + + // 2. Pre-convert SCF structural op types (e.g. scf.if/scf.for results) + // using the same type converter. This avoids creating emitc.variable with + // unsupported types such as memref. + { + RewritePatternSet scfTypePatterns(ctx); + ConversionTarget scfTypeTarget(*ctx); + scf::populateSCFStructuralTypeConversionsAndLegality( + typeConverter, scfTypePatterns, scfTypeTarget); + scfTypeTarget.markUnknownOpDynamicallyLegal( + [](Operation *) { return true; }); + + if (failed(applyPartialConversion(mop, scfTypeTarget, + std::move(scfTypePatterns)))) { + mop.emitError("failed to reconcile SCF structural types"); + return signalPassFailure(); + } + } + + // 3. 配置转换目标 + ConversionTarget target(*ctx); + + target.addIllegalDialect(); + target.addIllegalDialect(); + target.addIllegalDialect(); + target.addIllegalDialect(); + + // If we introduced CFG branches (e.g. 
from scf.while), make sure they are + // updated to use legalized operand types. + target.addDynamicallyLegalOp( + [&](Operation *op) { + return isLegalForBranchOpInterfaceTypeConversionPattern(op, + typeConverter); + }); + + // [关键] 允许 Cast 存在,最后统一清理 + target.addLegalOp(); + + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + + target.addLegalDialect(); + target.addLegalOp(); + + auto solver = std::make_unique(); + solver->load(); + solver->load(); + if (failed(solver->initializeAndRun(getOperation()))) + return signalPassFailure(); + + RewritePatternSet patterns(ctx); + populatePTOToEmitCPatterns(patterns, typeConverter, ctx, *solver, targetArch); + + // 4. 执行转换 + if (failed(applyPartialConversion(mop, target, std::move(patterns)))) { + llvm::errs() << "Conversion FAILED! Rolling back executed.\n"; + return signalPassFailure(); + } + + // ========================================================================= + // 5. [终极清理] + // 顺序至关重要: + // Step A: 先移除所有 Cast,让 Loop 的 Operand 类型变成底层类型 (如 int32) + // Step B: 再根据新的 Operand 类型,修复 Loop IV 的类型 + // ========================================================================= + + // --- Step A: 清理 UnrealizedConversionCastOp --- + // Prefer dropping redundant/unused casts; otherwise lower to emitc.cast + // so the C++ emitter can print it. 
+ auto isEmitCTileLikeType = [](Type ty) { + auto opaqueTy = dyn_cast(ty); + if (!opaqueTy) + return false; + StringRef value = opaqueTy.getValue(); + return value.contains("Tile<") || value.contains("ConvTile<"); + }; + + llvm::SmallVector castsToErase; + bool castCleanupFailed = false; + mop.walk([&](UnrealizedConversionCastOp cast) { + if (castCleanupFailed) + return; + + if (cast->getNumOperands() != 1 || cast->getNumResults() != 1) { + cast.emitError() << "unsupported unrealized_conversion_cast shape"; + castCleanupFailed = true; + return; + } + + Value input = cast.getOperand(0); + Value output = cast.getResult(0); + Type inTy = input.getType(); + Type outTy = output.getType(); + + if (output.use_empty()) { + castsToErase.push_back(cast); + return; + } + + if (inTy == outTy) { + output.replaceAllUsesWith(input); + castsToErase.push_back(cast); + return; + } + + // SCF/CFG type conversion can transiently materialize pointer->memref + // bridge casts. At this stage, the producing value is already in the + // lowered EmitC pointer form; keep it and drop the bridge cast. + if (isEmitCPointerLikeType(inTy) && isa(outTy)) { + output.replaceAllUsesWith(input); + castsToErase.push_back(cast); + return; + } + + // SCF structural type conversion may leave a bridge from the converted + // EmitC tile value back to the original pto.tile_buf type for PTO op + // users. After PTO ops are lowered, the EmitC tile value is the value we + // want to keep. 
+ if (isEmitCTileLikeType(inTy) && isa(outTy)) { + output.replaceAllUsesWith(input); + castsToErase.push_back(cast); + return; + } + + if (emitc::isSupportedEmitCType(inTy) && emitc::isSupportedEmitCType(outTy)) { + OpBuilder builder(cast); + auto c = builder.create(cast.getLoc(), outTy, input); + output.replaceAllUsesWith(c.getResult()); + castsToErase.push_back(cast); + return; + } + + cast.emitError() << "cannot lower unrealized_conversion_cast(" << inTy + << " -> " << outTy << ") to emitc.cast"; + castCleanupFailed = true; + }); + + for (auto cast : castsToErase) + cast.erase(); + + if (castCleanupFailed) + return signalPassFailure(); + + // --- Step A2: Sink casts of emitc.variable "reads" to their use sites --- + // + // SCFToEmitC lowers scf.if/scf.for results via mutable `emitc.variable` and + // `emitc.assign`. During type conversion, casts from the variable handle to + // the converted type may be materialized right after the variable + // declaration, effectively snapshotting the value *before* assignments. That + // produces wrong C++ (use-before-init / stale reads). + // + // Fix by re-materializing the cast at each use site so it reads the variable + // at the point of use. + { + SmallVector castOpsToSink; + mop.walk([&](emitc::CastOp castOp) { + if (castOp.getSource().getDefiningOp()) + castOpsToSink.push_back(castOp); + }); + + for (emitc::CastOp castOp : castOpsToSink) { + Value src = castOp.getSource(); + Type dstTy = castOp.getResult().getType(); + Value oldRes = castOp.getResult(); + + // Replace each use with a freshly inserted cast right before the user. 
+ for (OpOperand &use : llvm::make_early_inc_range(oldRes.getUses())) { + Operation *user = use.getOwner(); + OpBuilder b(user); + b.setInsertionPoint(user); + auto newCast = b.create(castOp.getLoc(), dstTy, src); + use.set(newCast.getResult()); + } + + castOp.erase(); + } + } + + // --- Step B: 修复 Loop 归纳变量 (IV) --- + // 此时 emitc.for 的 operand 已经是 int32 了,我们检查 IV 是否匹配,不匹配则修正 + mop.walk([&](emitc::ForOp forOp) { + Type boundTy = forOp.getLowerBound().getType(); + BlockArgument iv = forOp.getBody()->getArgument(0); + + if (iv.getType() != boundTy) { + iv.setType(boundTy); // 强制将 IV 类型 (index) 修改为与边界一致 (int32) + } + }); + + // --- Step C: 消除冗余 Tile 变量 (Dead Code Elimination) [新增] --- + // 逻辑:如果一个 emitc.variable 没有被读取(use_empty), + // 那么它自己,以及给它赋值的 TASSIGN 都可以删除。 + // 注意:TASSIGN(v15, v9) 会把 v15 作为 Operand 0 使用,所以 v15 不是严格的 use_empty。 + // 我们需要检查:v15 是否除了 TASSIGN 之外没有其他 User。 + + llvm::SmallVector deadVars; + mop.walk([&](emitc::VariableOp varOp) { + // 检查该变量的所有 User + bool isRead = false; + for (Operation* user : varOp.getResult().getUsers()) { + // 如果 User 是 TASSIGN 且变量是第0个参数(dst),不算"读取" + if (auto call = dyn_cast(user)) { + if (call.getCallee() == "TASSIGN" && call.getOperand(0) == varOp.getResult()) { + continue; // 这是一个赋值操作,不算有效使用 + } + } + // 如果还有其他用途(如 TLOAD, TMOV, TMATMUL),则该变量有用 + isRead = true; + break; + } + + if (!isRead) { + deadVars.push_back(varOp); + } + }); + + for (auto varOp : deadVars) { + // 1. 先删除所有使用该变量的 TASSIGN + llvm::SmallVector usersToErase; + for (Operation* user : varOp.getResult().getUsers()) { + // 我们上面已经确认过,剩下的 user 只能是 TASSIGN + usersToErase.push_back(user); + } + for (auto u : usersToErase) u->erase(); + + // 2. 
删除变量定义本身 + varOp.erase(); + } + + llvm::SmallVector deadConsts; + mop.walk([&](emitc::ConstantOp constOp) { + if (constOp.getResult().use_empty()) + deadConsts.push_back(constOp); + }); + for (auto constOp : deadConsts) + constOp.erase(); + + // ========================================================================= + } + }; +} // namespace + +std::unique_ptr mlir::pto::createEmitPTOManualPass() { + return std::make_unique(); +} + +std::unique_ptr mlir::pto::createEmitPTOManualPass(PTOArch arch) { + return std::make_unique(arch); +} diff --git a/lib/PTO/Transforms/PTOToEmitC.def b/lib/PTO/Transforms/PTOToEmitC.def deleted file mode 100644 index ea9466da1..000000000 --- a/lib/PTO/Transforms/PTOToEmitC.def +++ /dev/null @@ -1,12903 +0,0 @@ -// Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. 
- -//===- PTOToEmitC.cpp - PTO to EmitC conversion pass ----------------------===// -//===----------------------------------------------------------------------===// - -#pragma GCC diagnostic ignored "-Woverloaded-virtual" -// https://discourse.llvm.org/t/matchandrewrite-hiding-virtual-functions/84933/8 - -#include -#include - -#include "PTO/IR/PTO.h" -#include "PTO/IR/PTOTypeUtils.h" -#include "PTO/IR/PTOSyncUtils.h" -#include "PTO/Transforms/Passes.h" - -#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" -#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" -#include "mlir/Analysis/DataFlowFramework.h" - -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/EmitC/IR/EmitC.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" - -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/IRMapping.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/TypeRange.h" - -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Target/Cpp/CppEmitter.h" -#include "mlir/Transforms/DialectConversion.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/raw_ostream.h" -#include "mlir/Dialect/Func/Transforms/FuncConversions.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/SCF/Transforms/Patterns.h" -#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h" -#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" - -#include -#include -#include -#include - -#define DEBUG_TYPE "pto-emitc" - -namespace mlir { -#define GEN_PASS_DEF_EMITPTOMANUAL -#include "PTO/Transforms/Passes.h.inc" -} // namespace mlir - -using namespace mlir; -using namespace mlir::pto; - 
-static std::string getElemTypeStringForGT(Type elemTy); -static bool getStaticMemrefLayout(MemRefType mrTy, - SmallVectorImpl &strides, - int64_t &offset); -static int64_t multiplyOrDynamic(int64_t lhs, int64_t rhs); -static void buildGlobalTensorShapeAndStride(ArrayRef shape, - ArrayRef strides, - SmallVectorImpl &shape5D, - SmallVectorImpl &stride5D); -static std::string joinIntTemplateParams(ArrayRef values); -static SmallVector buildRowMajorStrides(ArrayRef shape); -static std::string getGlobalTensorTypeStringFromShape(Type elemTy, - ArrayRef shape, - StringRef layoutEnum = - "pto::Layout::ND"); -static std::string getGlobalTensorTypeStringFromShapeAndStrides( - Type elemTy, ArrayRef shape, ArrayRef strides, - StringRef layoutEnum = "pto::Layout::ND"); -static emitc::OpaqueType getGlobalTensorOpaqueTypeFromShape( - MLIRContext *ctx, Type elemTy, ArrayRef shape, - StringRef layoutEnum = "pto::Layout::ND"); - -static const char *addrSpaceQualifier(pto::AddressSpace as) { - switch (as) { - case pto::AddressSpace::Zero: - return "__gm__"; - case pto::AddressSpace::VEC: - return "__ubuf__"; - case pto::AddressSpace::GM: - return "__gm__"; - case pto::AddressSpace::MAT: - return "__cbuf__"; - case pto::AddressSpace::LEFT: - return "__ca__"; - case pto::AddressSpace::RIGHT: - return "__cb__"; - case pto::AddressSpace::ACC: - return "__cc__"; - case pto::AddressSpace::BIAS: - // Bias tiles are special in pto-isa; keep a safe fallback qualifier. - return "__gm__"; - case pto::AddressSpace::SCALING: - // pto-isa TileType::Scaling maps to __fbuf__ (see pto/common/memory.hpp). 
- return "__fbuf__"; - } - return "__gm__"; -} - -[[maybe_unused]] static constexpr llvm::StringLiteral kLoweredSetValidShapeAttrName = - "__pto.lowered_set_validshape"; -[[maybe_unused]] static constexpr llvm::StringLiteral kLoweredSetValidShapeConfigAttrName = - "__pto.lowered_set_validshape_config"; -static constexpr llvm::StringLiteral kForceDynamicValidShapeAttrName = - "__pto.force_dynamic_valid_shape"; -static constexpr llvm::StringLiteral kGlobalTensorStridesAttrName = - "__pto.globaltensor_strides"; - -static Value peelUnrealized(Value v) { - if (auto castOp = v.getDefiningOp()) - return castOp.getOperand(0); - return v; -} - -static Value buildGlobalTensorFromMemref(ConversionPatternRewriter &rewriter, - Location loc, Value basePtr, - MemRefType mrTy, Operation *anchor); - -static Value maybeWrapGlobalMemrefAsGlobalTensor( - ConversionPatternRewriter &rewriter, Location loc, Value loweredValue, - Type originalType, Operation *anchor); - -static bool hasCompatibleKnownExtentForMGather(int64_t lhs, int64_t rhs) { - return lhs == ShapedType::kDynamic || rhs == ShapedType::kDynamic || - lhs == rhs; -} - -static bool isKnownUnitExtentForMGather(int64_t value) { - return value == ShapedType::kDynamic || value == 1; -} - -struct GatherScatterShapeLayoutInfo { - SmallVector shape; - bool rowMajor = false; - bool colMajor = false; -}; - -static std::optional -getGatherScatterShapeLayoutInfo(Type ty) { - if (auto tileTy = dyn_cast(ty)) { - ArrayRef validShape = tileTy.getValidShape(); - if (validShape.size() != 2) - return std::nullopt; - - GatherScatterShapeLayoutInfo info; - info.shape.assign(validShape.begin(), validShape.end()); - int32_t blayout = tileTy.getBLayoutValueI32(); - info.rowMajor = blayout == static_cast(pto::BLayout::RowMajor); - info.colMajor = blayout == static_cast(pto::BLayout::ColMajor); - return info; - } - - auto memRefTy = dyn_cast(ty); - if (!memRefTy || memRefTy.getRank() != 2) - return std::nullopt; - - SmallVector strides; - int64_t 
offset = ShapedType::kDynamic; - if (failed(getStridesAndOffset(memRefTy, strides, offset)) || - strides.size() != 2) - return std::nullopt; - - GatherScatterShapeLayoutInfo info; - info.shape.assign(memRefTy.getShape().begin(), memRefTy.getShape().end()); - info.rowMajor = strides[1] == 1; - info.colMajor = strides[0] == 1; - return info; -} - -static bool isRowCoalescedMGatherIndexType(Type dataTy, Type idxTy) { - auto dataInfo = getGatherScatterShapeLayoutInfo(dataTy); - auto idxInfo = getGatherScatterShapeLayoutInfo(idxTy); - if (!dataInfo || !idxInfo) - return false; - - const bool rowCoalesce1xR = - idxInfo->rowMajor && isKnownUnitExtentForMGather(idxInfo->shape[0]) && - hasCompatibleKnownExtentForMGather(idxInfo->shape[1], dataInfo->shape[0]); - const bool rowCoalesceRx1 = - idxInfo->colMajor && - hasCompatibleKnownExtentForMGather(idxInfo->shape[0], dataInfo->shape[0]) && - isKnownUnitExtentForMGather(idxInfo->shape[1]); - return rowCoalesce1xR || rowCoalesceRx1; -} - -static std::optional getLayoutAttrFromOp(Operation *op) { - if (!op) - return std::nullopt; - if (auto attr = op->getAttrOfType("layout")) - return attr.getLayout(); - return std::nullopt; -} - -static std::optional resolveLayoutFromValueChain(Value v) { - v = peelUnrealized(v); - while (Operation *def = v.getDefiningOp()) { - if (auto layout = getLayoutAttrFromOp(def)) - return layout; - if (auto subview = dyn_cast(def)) { - v = peelUnrealized(subview.getSource()); - continue; - } - if (auto reinterpret = dyn_cast(def)) { - v = peelUnrealized(reinterpret.getSource()); - continue; - } - if (auto cast = dyn_cast(def)) { - v = peelUnrealized(cast.getSource()); - continue; - } - if (auto unrealized = dyn_cast(def)) { - if (unrealized->getNumOperands() == 0) - break; - v = peelUnrealized(unrealized.getOperand(0)); - continue; - } - break; - } - return std::nullopt; -} - -static std::optional -resolveLayoutForGlobalTensor(Operation *anchor, Value basePtr) { - if (auto layout = 
getLayoutAttrFromOp(anchor)) - return layout; - return resolveLayoutFromValueChain(basePtr); -} - -static std::string layoutToEmitCString(mlir::pto::Layout layout) { - switch (layout) { - case mlir::pto::Layout::ND: - return "pto::Layout::ND"; - case mlir::pto::Layout::DN: - return "pto::Layout::DN"; - case mlir::pto::Layout::NZ: - return "pto::Layout::NZ"; - } - return "pto::Layout::ND"; -} - -static bool isEmitCGlobalTensorLikeType(Type ty) { - auto opaqueTy = dyn_cast(ty); - return opaqueTy && opaqueTy.getValue().contains("GlobalTensor<"); -} - -static std::string getEmitCScalarTypeToken(Type elemTy) { - if (pto::isPTOFloat8Type(elemTy) && - (elemTy.isFloat8E4M3() || elemTy.isFloat8E4M3FN() || - elemTy.isFloat8E4M3FNUZ() || elemTy.isFloat8E4M3B11FNUZ())) - return "float8_e4m3_t"; - if (pto::isPTOFloat8Type(elemTy) && - (elemTy.isFloat8E5M2() || elemTy.isFloat8E5M2FNUZ())) - return "float8_e5m2_t"; - if (isa(elemTy)) - return "hifloat8_t"; - if (isa(elemTy)) - return "float4_e1m2x2_t"; - if (isa(elemTy)) - return "float4_e2m1x2_t"; - if (elemTy.isF16()) - return "half"; - if (elemTy.isBF16()) - return "bfloat16_t"; - if (elemTy.isF32()) - return "float"; - if (elemTy.isF64()) - return "double"; - if (elemTy.isInteger(8)) - return (elemTy.isSignlessInteger(8) || elemTy.isSignedInteger(8)) ? "int8_t" - : "uint8_t"; - if (elemTy.isInteger(16)) - return (elemTy.isSignlessInteger(16) || elemTy.isSignedInteger(16)) - ? "int16_t" - : "uint16_t"; - if (elemTy.isInteger(32)) - return (elemTy.isSignlessInteger(32) || elemTy.isSignedInteger(32)) - ? "int32_t" - : "uint32_t"; - if (elemTy.isInteger(64)) - return cast(elemTy).isUnsigned() ? 
"uint64_t" : "int64_t"; - return "float"; -} - -static emitc::PointerType getEmitCPointerType(MLIRContext *ctx, - StringRef pointeeTypeStr) { - return emitc::PointerType::get(emitc::OpaqueType::get(ctx, pointeeTypeStr)); -} - -static emitc::PointerType getEmitCPointerType(MLIRContext *ctx, - StringRef qualifier, - StringRef elemTypeStr) { - return getEmitCPointerType(ctx, (qualifier + " " + elemTypeStr).str()); -} - -static bool isEmitCPointerLikeType(Type ty) { - if (isa(ty)) - return true; - if (auto opaqueTy = dyn_cast(ty)) - return opaqueTy.getValue().ends_with("*"); - return false; -} - -static int64_t getEmitCScalarByteWidth(Type elemTy) { - if (pto::getPTOStorageElemByteSize(elemTy) == 1) - return 1; - if (elemTy.isF16() || elemTy.isBF16() || elemTy.isInteger(16)) - return 2; - if (elemTy.isF32() || elemTy.isInteger(32)) - return 4; - if (elemTy.isF64() || elemTy.isInteger(64)) - return 8; - return 4; -} - -static std::string tileBufBLayoutToken(pto::TileBufConfigAttr configAttr); -static std::string tileBufSLayoutToken(pto::TileBufConfigAttr configAttr); -static std::string tileBufPadToken(pto::TileBufConfigAttr configAttr); -static pto::BLayout getTileBufBLayoutValue(pto::TileBufConfigAttr configAttr); -static int64_t renderTileTemplateDim(int64_t rawDim, Type elemTy, - pto::BLayout blayout, int dimIdx); - -static const char *tileRoleToken(Attribute memorySpace) { - if (auto asAttr = dyn_cast_or_null(memorySpace)) { - switch (asAttr.getAddressSpace()) { - case pto::AddressSpace::VEC: - return "TileType::Vec"; - case pto::AddressSpace::MAT: - return "TileType::Mat"; - case pto::AddressSpace::LEFT: - return "TileType::Left"; - case pto::AddressSpace::RIGHT: - return "TileType::Right"; - case pto::AddressSpace::ACC: - return "TileType::Acc"; - case pto::AddressSpace::BIAS: - return "TileType::Bias"; - case pto::AddressSpace::SCALING: - return "TileType::Scaling"; - case pto::AddressSpace::GM: - case pto::AddressSpace::Zero: - return "TileType::Vec"; - } - } - 
return "TileType::Vec"; -} - -static std::string tileBufCompactToken(pto::TileBufConfigAttr configAttr) { - std::string compactTok = "CompactMode::Null"; - if (auto compactAttr = dyn_cast(configAttr.getCompactMode())) { - switch (static_cast(compactAttr.getValue())) { - case 1: - compactTok = "CompactMode::Normal"; - break; - case 2: - compactTok = "CompactMode::RowPlusOne"; - break; - default: - compactTok = "CompactMode::Null"; - break; - } - } - return compactTok; -} - -static std::optional getEmitCTileTypeString(pto::TileBufType type) { - if (type.getRank() != 2) - return std::nullopt; - auto validShape = type.getValidShape(); - if (validShape.size() != 2) - return std::nullopt; - - Type elemTy = type.getElementType(); - auto configAttr = type.getConfigAttr(); - pto::BLayout blayout = getTileBufBLayoutValue(configAttr); - ArrayRef shape = type.getShape(); - int64_t rows = shape[0]; - int64_t cols = shape[1]; - - auto render = [&](int64_t dim, int dimIdx) { - return renderTileTemplateDim(dim, elemTy, blayout, dimIdx); - }; - - std::string vrowTok = - validShape[0] == ShapedType::kDynamic - ? "-1" - : std::to_string(render(validShape[0], 0)); - std::string vcolTok = - validShape[1] == ShapedType::kDynamic - ? 
"-1" - : std::to_string(render(validShape[1], 1)); - - int32_t fractal = 512; - if (auto frAttr = dyn_cast(configAttr.getSFractalSize())) - fractal = frAttr.getInt(); - - return std::string("Tile<") + tileRoleToken(type.getMemorySpace()) + ", " + - getEmitCScalarTypeToken(elemTy) + ", " + - std::to_string(render(rows, 0)) + ", " + - std::to_string(render(cols, 1)) + ", " + - tileBufBLayoutToken(configAttr) + ", " + vrowTok + ", " + vcolTok + - ", " + tileBufSLayoutToken(configAttr) + ", " + - std::to_string(fractal) + ", " + tileBufPadToken(configAttr) + ", " + - tileBufCompactToken(configAttr) + ">"; -} - -//===----------------------------------------------------------------------===// -// Type Converter -//===----------------------------------------------------------------------===// - -class PTOToEmitCTypeConverter : public TypeConverter { -public: - PTOToEmitCTypeConverter(MLIRContext *Ctx, PTOArch targetArch) { - // --------------------------------------------------------- - // 1. 基本类型 (f32, i32, index) - // --------------------------------------------------------- - addConversion([Ctx](FloatType type) -> Type { - if (type.isFloat8E4M3() || type.isFloat8E4M3FN() || - type.isFloat8E4M3FNUZ() || type.isFloat8E4M3B11FNUZ()) - return emitc::OpaqueType::get(Ctx, "float8_e4m3_t"); - if (type.isFloat8E5M2() || type.isFloat8E5M2FNUZ()) - return emitc::OpaqueType::get(Ctx, "float8_e5m2_t"); - if (type.isF32()) return emitc::OpaqueType::get(Ctx, "float"); - if (type.isF16()) return emitc::OpaqueType::get(Ctx, "half"); - if (type.isBF16()) return emitc::OpaqueType::get(Ctx, "bfloat16_t"); - if (type.isF64()) return emitc::OpaqueType::get(Ctx, "double"); - llvm::errs() << "[Debug] Unsupported FloatType: " << type << "\n"; - return Type{}; - }); - - addConversion([Ctx](pto::HiF8Type) -> Type { - return emitc::OpaqueType::get(Ctx, "hifloat8_t"); - }); - addConversion([Ctx](pto::F4E1M2x2Type) -> Type { - return emitc::OpaqueType::get(Ctx, "float4_e1m2x2_t"); - }); - 
addConversion([Ctx](pto::F4E2M1x2Type) -> Type { - return emitc::OpaqueType::get(Ctx, "float4_e2m1x2_t"); - }); - - addConversion([Ctx](IntegerType type) -> Type { - if (type.getWidth() == 1) - return type; - - // Prefer fixed-width C types. Preserve signedness if the MLIR integer is - // explicitly signed/unsigned; treat signless as signed by default. - const bool isUnsigned = type.isUnsignedInteger(); - switch (type.getWidth()) { - case 8: - return emitc::OpaqueType::get(Ctx, isUnsigned ? "uint8_t" : "int8_t"); - case 16: - return emitc::OpaqueType::get(Ctx, - isUnsigned ? "uint16_t" : "int16_t"); - case 32: - return emitc::OpaqueType::get(Ctx, - isUnsigned ? "uint32_t" : "int32_t"); - case 64: - return emitc::OpaqueType::get(Ctx, - isUnsigned ? "uint64_t" : "int64_t"); - default: - llvm::errs() << "[Debug] Unsupported IntegerType width: " - << type.getWidth() << "\n"; - return emitc::OpaqueType::get(Ctx, "int32_t"); // Fallback - } - }); - - addConversion([Ctx](IndexType type) -> Type { - return emitc::OpaqueType::get(Ctx, "int32_t"); - }); - - // vector<4xi16> (e.g. TMRGSORT executedNumList) -> pto::MrgSortExecutedNumList - addConversion([Ctx](VectorType type) -> Type { - if (type.getRank() == 1 && type.getNumElements() == 4 && - type.getElementType().isInteger(16)) - return emitc::OpaqueType::get(Ctx, "pto::MrgSortExecutedNumList"); - return Type{}; - }); - - // --------------------------------------------------------- - // 2. 
PTO 特殊类型 (透传或转换) - // --------------------------------------------------------- - addConversion([](emitc::OpaqueType type) { return type; }); - addConversion([](emitc::PointerType type) { return type; }); - - // --------------------------------------------------------- - // 2.5 PtrType 转换 (指针类型) - // --------------------------------------------------------- - addConversion([this, Ctx](pto::PtrType type) -> std::optional { - Type elemType = type.getElementType(); - Type newElemType = convertType(elemType); - if (!newElemType) - return std::nullopt; - - std::string elemTypeStr; - if (auto opq = dyn_cast(newElemType)) { - elemTypeStr = opq.getValue().str(); - } else { - llvm::errs() << " [Error] PtrType elem type is not OpaqueType: " - << newElemType << "\n"; - return std::nullopt; - } - - std::string qualifier = "__gm__"; - - std::string finalTypeStr = qualifier + " " + elemTypeStr; - return getEmitCPointerType(Ctx, finalTypeStr); - }); - - addConversion([Ctx](pto::PipeType type) -> Type { - (void)type; - return emitc::OpaqueType::get(Ctx, "auto"); - }); - - addConversion([Ctx](pto::EventIdArrayType type) -> Type { - std::string tok = "PTOAS_EventIdArray<" + std::to_string(type.getSize()) + ">"; - return emitc::OpaqueType::get(Ctx, tok); - }); - - // !pto.local_array -> !emitc.array. - // Variables of this type render as `T a[D1][D2]...;` in the emitted C++. 
- addConversion([this](pto::LocalArrayType type) -> std::optional { - Type convertedElem = convertType(type.getElementType()); - if (!convertedElem) - return std::nullopt; - return emitc::ArrayType::get(type.getShape(), convertedElem); - }); - - addConversion([Ctx](pto::AsyncSessionType type) -> Type { - (void)type; - return emitc::OpaqueType::get(Ctx, "pto::comm::AsyncSession"); - }); - - addConversion([Ctx](pto::AsyncEventType type) -> Type { - (void)type; - return emitc::OpaqueType::get(Ctx, "pto::comm::AsyncEvent"); - }); - - addConversion([Ctx](pto::PrefetchAsyncContextType type) -> Type { - (void)type; - return emitc::OpaqueType::get(Ctx, "pto::PrefetchAsyncContext"); - }); - - addConversion([Ctx](pto::TensorViewType type) -> Type { - return getGlobalTensorOpaqueTypeFromShape( - Ctx, type.getElementType(), type.getShape()); - }); - - addConversion([Ctx](pto::PartitionTensorViewType type) -> Type { - return getGlobalTensorOpaqueTypeFromShape( - Ctx, type.getElementType(), type.getShape()); - }); - - addConversion([Ctx](pto::TileBufType type) -> std::optional { - auto typeString = getEmitCTileTypeString(type); - if (!typeString) - return std::nullopt; - return emitc::OpaqueType::get(Ctx, *typeString); - }); - - // --------------------------------------------------------- - // 3. MemRef 转换 (Debug 重点) - // --------------------------------------------------------- - addConversion([this, Ctx](MemRefType type) -> std::optional { - LLVM_DEBUG(llvm::dbgs() << "Converting MemRef: " << type << "\n"); - - // A. 
转换元素类型 - Type elemType = type.getElementType(); - Type newElemType = convertType(elemType); - if (!newElemType) { - llvm::errs() << " [Error] Failed to convert element type: " << elemType << "\n"; - return std::nullopt; - } - - // 获取元素类型的字符串 - std::string elemTypeStr; - if (auto opq = dyn_cast(newElemType)) { - elemTypeStr = opq.getValue().str(); - } else { - llvm::errs() << " [Error] Converted element type is not OpaqueType: " << newElemType << "\n"; - return std::nullopt; - } - - // B. 处理 Memory Space - std::string qualifier = ""; - Attribute memorySpace = type.getMemorySpace(); - - if (!memorySpace) { - qualifier = "__gm__"; - } else if (auto ptoAttr = dyn_cast(memorySpace)) { - qualifier = addrSpaceQualifier(ptoAttr.getAddressSpace()); - } else { - llvm::errs() << " [Warning] Unknown MemorySpace Attribute type: " << memorySpace << "\n"; - qualifier = "__gm__"; // Fallback - } - - std::string finalTypeStr = qualifier + " " + elemTypeStr; - LLVM_DEBUG(llvm::dbgs() << " [Success] -> " << finalTypeStr << "*\n"); - - return getEmitCPointerType(Ctx, finalTypeStr); - }); - - // --------------------------------------------------------- - // 4. Function & Materialization - // --------------------------------------------------------- - addConversion([this](FunctionType type) -> Type { - SmallVector inputs; - if (failed(convertTypes(type.getInputs(), inputs))) return Type{}; - SmallVector results; - if (failed(convertTypes(type.getResults(), results))) return Type{}; - return FunctionType::get(type.getContext(), inputs, results); - }); - - auto materializeCast = [](OpBuilder &Builder, Type ResultType, - ValueRange Inputs, Location Loc) -> Value { - if (Inputs.size() != 1) return Value(); - return Builder.create(Loc, ResultType, Inputs[0]).getResult(0); - }; - - addSourceMaterialization(materializeCast); - addTargetMaterialization(materializeCast); - // Needed for region/block signature conversions (e.g. CFG block args). 
- addArgumentMaterialization(materializeCast); - } -}; - -static constexpr unsigned kPTOIndexBitWidth = - 32; // keep consistent with IndexType conversion - -// Forward declarations (definitions below). -static inline std::string pipeTokFromPipeAttr(mlir::pto::PipeAttr a); -static emitc::OpaqueType getSignedIntOpaqueType(MLIRContext *ctx, - unsigned bitWidth); -static emitc::OpaqueType getUnsignedIntOpaqueType(MLIRContext *ctx, - unsigned bitWidth); -static emitc::OpaqueType getWiderSignedIntOpaqueType(MLIRContext *ctx, - unsigned bitWidth); -static emitc::OpaqueType getWiderUnsignedIntOpaqueType(MLIRContext *ctx, - unsigned bitWidth); -static Value makeEmitCOpaqueConstant(ConversionPatternRewriter &rewriter, - Location loc, Type type, - llvm::StringRef literal); -static Value makeEmitCIntConstant(ConversionPatternRewriter &rewriter, - Location loc, Type type, int64_t value); -static Value emitCCast(ConversionPatternRewriter &rewriter, Location loc, - Type dstType, Value src); -static FailureOr buildEmitCOpaqueConstantLiteral(Type targetType, - Attribute valueAttr); -static Value castSignlessIntToUnsignedSameWidth(ConversionPatternRewriter &rewriter, - Location loc, Value v, - unsigned bitWidth); -static bool needsA5NoSplitVectorGuard(Operation *op); - -static FailureOr getTileSplitToken(int64_t split) { - switch (split) { - case 0: - return std::string("TileSplitAxis::TILE_NO_SPLIT"); - case 1: - return std::string("TileSplitAxis::TILE_UP_DOWN"); - case 2: - return std::string("TileSplitAxis::TILE_LEFT_RIGHT"); - default: - return failure(); - } -} - -static FailureOr -getTPipeDirectionToken(bool isL2G2L, int8_t dirMask, PTOArch targetArch) { - if (dirMask == 1) { - if (isL2G2L && targetArch == PTOArch::A5) - return std::string("Direction::DIR_C2V_GM"); - return std::string("Direction::DIR_C2V"); - } - if (dirMask == 2) { - if (isL2G2L && targetArch == PTOArch::A5) - return std::string("Direction::DIR_V2C_GM"); - return std::string("Direction::DIR_V2C"); - } - if 
(dirMask == 3) - return std::string("Direction::DIR_BOTH"); - return failure(); -} - -static std::string buildTPipeToken(int32_t flagBase, llvm::StringRef dirTok, - int32_t slotSize, int32_t slotNum, - int32_t localSlotNum, bool nosplit) { - std::string token = "TPipe<" + std::to_string(flagBase) + ", " + dirTok.str() + - ", " + std::to_string(slotSize) + ", " + - std::to_string(slotNum); - token += ", " + std::to_string(localSlotNum); - token += nosplit ? ", true" : ", false"; - token += ">"; - return token; -} - -static FailureOr buildTPipeTokenFromInitOp(Operation *op, - PTOArch targetArch) { - if (auto initOp = dyn_cast(op)) { - if (!initOp.getFlagBaseAttr()) - return failure(); - auto dirTok = - getTPipeDirectionToken(/*isL2G2L=*/true, initOp.getDirMask(), targetArch); - if (failed(dirTok)) - return failure(); - int32_t localSlotNum = initOp.getLocalSlotNumAttr() - ? initOp.getLocalSlotNumAttr().getInt() - : initOp.getSlotNum(); - return buildTPipeToken(initOp.getFlagBaseAttr().getInt(), *dirTok, - initOp.getSlotSize(), initOp.getSlotNum(), - localSlotNum, - initOp.getNosplitAttr() && - initOp.getNosplitAttr().getValue()); - } - - if (auto initOp = dyn_cast(op)) { - if (!initOp.getFlagBaseAttr()) - return failure(); - auto dirTok = - getTPipeDirectionToken(/*isL2G2L=*/false, initOp.getDirMask(), targetArch); - if (failed(dirTok)) - return failure(); - return buildTPipeToken(initOp.getFlagBaseAttr().getInt(), *dirTok, - initOp.getSlotSize(), initOp.getSlotNum(), 2, - initOp.getNosplitAttr() && - initOp.getNosplitAttr().getValue()); - } - - return failure(); -} - -static FailureOr getTPipeTokenFromValue(Value pipeHandle, - PTOArch targetArch) { - pipeHandle = peelUnrealized(pipeHandle); - Operation *def = pipeHandle.getDefiningOp(); - if (!def) - return failure(); - return buildTPipeTokenFromInitOp(def, targetArch); -} - -static bool isSetFFTsPointerLikeType(Type ty) { - return isEmitCPointerLikeType(ty); -} - -static bool 
tileDataReturnsIntegralAddress(pto::AddressSpace as) { - return as == pto::AddressSpace::BIAS; -} - -static Type getTileDataResultType(MLIRContext *ctx, pto::AddressSpace as, - StringRef elemTok) { - if (tileDataReturnsIntegralAddress(as)) - return emitc::OpaqueType::get(ctx, "uint64_t"); - return getEmitCPointerType(ctx, addrSpaceQualifier(as), elemTok); -} - -static Value materializeTileDataValue(ConversionPatternRewriter &rewriter, - Location loc, Value tile, - pto::AddressSpace as, - StringRef elemTok) { - auto rawTy = getTileDataResultType(rewriter.getContext(), as, elemTok); - return rewriter - .create(loc, rawTy, "PTOAS__TILE_DATA", - ArrayAttr{}, ArrayAttr{}, - ValueRange{tile}) - .getResult(0); -} - -static Value materializeAddressAsPointer(ConversionPatternRewriter &rewriter, - Location loc, Value addr, - pto::AddressSpace as, - StringRef elemTok) { - auto *ctx = rewriter.getContext(); - std::string ptrTyStr = - std::string(addrSpaceQualifier(as)) + " " + elemTok.str() + "*"; - auto ptrTy = getEmitCPointerType(ctx, addrSpaceQualifier(as), elemTok); - if (isSetFFTsPointerLikeType(addr.getType())) { - if (addr.getType() == ptrTy) - return addr; - return rewriter.create(loc, ptrTy, addr).getResult(); - } - auto castTyAttr = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, ptrTyStr)}); - return rewriter - .create(loc, ptrTy, "reinterpret_cast", - ArrayAttr{}, castTyAttr, - ValueRange{addr}) - .getResult(0); -} - -struct InterCoreSyncCallDesc { - const char *callee = nullptr; - ArrayAttr args; - SmallVector operands; -}; - -static Value castInterCoreEventIdToI32(ConversionPatternRewriter &rewriter, - Location loc, Value eventId) { - auto i32Ty = emitc::OpaqueType::get(rewriter.getContext(), "int32_t"); - if (eventId.getType() == i32Ty) - return eventId; - return emitCCast(rewriter, loc, i32Ty, eventId); -} - -static Attribute getFFTSModeCodegenArg(ConversionPatternRewriter &rewriter, - int64_t fftsMode) { - auto *ctx = rewriter.getContext(); - if (fftsMode 
== 2) - return emitc::OpaqueAttr::get(ctx, "FFTS_MODE_VAL"); - return emitc::OpaqueAttr::get(ctx, std::to_string(fftsMode)); -} - -static Value createFFTSMsg(ConversionPatternRewriter &rewriter, Location loc, - Value eventI32, int64_t fftsMode) { - auto *ctx = rewriter.getContext(); - auto msgTy = emitc::OpaqueType::get(ctx, "uint16_t"); - auto msgArgs = rewriter.getArrayAttr({ - getFFTSModeCodegenArg(rewriter, fftsMode), - IntegerAttr::get(IndexType::get(ctx), 0), - }); - return rewriter - .create(loc, msgTy, "getFFTSMsg", - /*args=*/msgArgs, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{eventI32}) - .getResult(0); -} - -static InterCoreSyncCallDesc buildInterCoreSyncSetCall( - ConversionPatternRewriter &rewriter, Location loc, PTOArch targetArch, - pto::PipeAttr pipeAttr, IntegerAttr eventIdAttr, int64_t fftsMode) { - auto *ctx = rewriter.getContext(); - std::string pipeTok = pipeTokFromPipeAttr(pipeAttr); - - if (targetArch == PTOArch::A3) { - auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); - Value eventVal = - makeEmitCIntConstant(rewriter, loc, i32Ty, eventIdAttr.getInt()); - Value msgVal = createFFTSMsg(rewriter, loc, eventVal, fftsMode); - - InterCoreSyncCallDesc desc; - desc.callee = "ffts_cross_core_sync"; - desc.args = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, pipeTok), - IntegerAttr::get(IndexType::get(ctx), 0), - }); - desc.operands.push_back(msgVal); - return desc; - } - - InterCoreSyncCallDesc desc; - desc.callee = "set_intra_block"; - desc.args = rewriter.getArrayAttr( - {emitc::OpaqueAttr::get(ctx, pipeTok), eventIdAttr}); - return desc; -} - -static InterCoreSyncCallDesc buildInterCoreSyncSetCallDyn( - ConversionPatternRewriter &rewriter, Location loc, PTOArch targetArch, - pto::PipeAttr pipeAttr, Value eventIdVal, int64_t fftsMode) { - auto *ctx = rewriter.getContext(); - std::string pipeTok = pipeTokFromPipeAttr(pipeAttr); - Value eventI32 = castInterCoreEventIdToI32(rewriter, loc, eventIdVal); - - if (targetArch == 
PTOArch::A3) { - Value msgVal = createFFTSMsg(rewriter, loc, eventI32, fftsMode); - - InterCoreSyncCallDesc desc; - desc.callee = "ffts_cross_core_sync"; - desc.args = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, pipeTok), - IntegerAttr::get(IndexType::get(ctx), 0), - }); - desc.operands.push_back(msgVal); - return desc; - } - - InterCoreSyncCallDesc desc; - desc.callee = "set_intra_block"; - desc.args = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, pipeTok), - IntegerAttr::get(IndexType::get(ctx), 0), - }); - desc.operands.push_back(eventI32); - return desc; -} - -static InterCoreSyncCallDesc buildInterCoreSyncWaitCall( - ConversionPatternRewriter &rewriter, PTOArch targetArch, - pto::PipeAttr pipeAttr, IntegerAttr eventIdAttr) { - auto *ctx = rewriter.getContext(); - std::string pipeTok = pipeTokFromPipeAttr(pipeAttr); - - InterCoreSyncCallDesc desc; - if (targetArch == PTOArch::A3) { - desc.callee = "wait_flag_dev"; - desc.args = rewriter.getArrayAttr({eventIdAttr}); - return desc; - } - - desc.callee = "wait_intra_block"; - desc.args = rewriter.getArrayAttr( - {emitc::OpaqueAttr::get(ctx, pipeTok), eventIdAttr}); - return desc; -} - -static InterCoreSyncCallDesc buildInterCoreSyncWaitCallDyn( - ConversionPatternRewriter &rewriter, Location loc, PTOArch targetArch, - pto::PipeAttr pipeAttr, Value eventIdVal) { - auto *ctx = rewriter.getContext(); - std::string pipeTok = pipeTokFromPipeAttr(pipeAttr); - Value eventI32 = castInterCoreEventIdToI32(rewriter, loc, eventIdVal); - - InterCoreSyncCallDesc desc; - if (targetArch == PTOArch::A3) { - desc.callee = "wait_flag_dev"; - desc.args = rewriter.getArrayAttr({IntegerAttr::get(IndexType::get(ctx), 0)}); - desc.operands.push_back(eventI32); - return desc; - } - - desc.callee = "wait_intra_block"; - desc.args = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, pipeTok), - IntegerAttr::get(IndexType::get(ctx), 0), - }); - desc.operands.push_back(eventI32); - return desc; -} - -static bool 
hasInterCoreSyncOp(func::FuncOp func) { - bool found = false; - func.walk([&](Operation *op) { - if (isa(op)) { - found = true; - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - return found; -} - -static bool hasSetFFTsOp(func::FuncOp func) { - bool found = false; - func.walk([&](Operation *op) { - if (isa(op)) { - found = true; - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - return found; -} - -//===----------------------------------------------------------------------===// -// Arith -> EmitC (full dialect coverage for scalar ops) -//===----------------------------------------------------------------------===// - -template -struct ArithSimpleBinaryToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type dstTy = this->getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getOperands()); - return success(); - } -}; - -// Integer bitwise ops (andi/ori/xori) on signless integers: perform in unsigned -// to avoid signedness pitfalls, then cast back. -template -struct ArithUnsignedBitwiseBinaryToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - const unsigned bitWidth = - intTy ? 
intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - Type dstTy = this->getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - if (bitWidth == 1) { - rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getLhs(), - adaptor.getRhs()); - return success(); - } - - auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value resU = rewriter.create(loc, uTy, lhsU, rhsU); - Value result = emitCCast(rewriter, loc, dstTy, resU); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithDivUIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::DivUIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - const unsigned bitWidth = - intTy ? 
intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value divU = rewriter.create(loc, uTy, lhsU, rhsU); - Value result = emitCCast(rewriter, loc, dstTy, divU); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithRemUIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::RemUIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - const unsigned bitWidth = - intTy ? 
intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value remU = rewriter.create(loc, uTy, lhsU, rhsU); - Value result = emitCCast(rewriter, loc, dstTy, remU); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithCeilDivUIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::CeilDivUIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - const unsigned bitWidth = - intTy ? 
intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value one = makeEmitCIntConstant(rewriter, loc, uTy, 1); - Value rhsMinusOne = rewriter.create(loc, uTy, rhsU, one); - Value num = rewriter.create(loc, uTy, lhsU, rhsMinusOne); - Value divU = rewriter.create(loc, uTy, num, rhsU); - Value result = emitCCast(rewriter, loc, dstTy, divU); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithCeilDivSIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::CeilDivSIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - Value zero = makeEmitCIntConstant(rewriter, loc, dstTy, 0); - Value one = makeEmitCIntConstant(rewriter, loc, dstTy, 1); - - Value q0 = rewriter.create(loc, dstTy, adaptor.getLhs(), - adaptor.getRhs()); - Value r = rewriter.create(loc, dstTy, adaptor.getLhs(), - adaptor.getRhs()); - - Value rNeZero = rewriter.create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::ne, r, - zero); - Value lhsLt0 = - rewriter.create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, adaptor.getLhs(), - zero); - Value rhsLt0 = - rewriter.create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, adaptor.getRhs(), - zero); - Value signsSame = - rewriter.create(loc, 
rewriter.getI1Type(), - emitc::CmpPredicate::eq, lhsLt0, rhsLt0); - Value adjust = - rewriter.create(loc, rewriter.getI1Type(), - rNeZero, signsSame); - - Value qPlusOne = rewriter.create(loc, dstTy, q0, one); - Value result = rewriter.create(loc, dstTy, adjust, - qPlusOne, q0); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithFloorDivSIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::FloorDivSIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - Value zero = makeEmitCIntConstant(rewriter, loc, dstTy, 0); - Value one = makeEmitCIntConstant(rewriter, loc, dstTy, 1); - - Value q0 = rewriter.create(loc, dstTy, adaptor.getLhs(), - adaptor.getRhs()); - Value r = rewriter.create(loc, dstTy, adaptor.getLhs(), - adaptor.getRhs()); - - Value rNeZero = rewriter.create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::ne, r, - zero); - Value lhsLt0 = - rewriter.create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, adaptor.getLhs(), - zero); - Value rhsLt0 = - rewriter.create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, adaptor.getRhs(), - zero); - Value signsDifferent = - rewriter.create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::ne, lhsLt0, rhsLt0); - Value adjust = - rewriter.create(loc, rewriter.getI1Type(), - rNeZero, signsDifferent); - - Value qMinusOne = rewriter.create(loc, dstTy, q0, one); - Value result = rewriter.create(loc, dstTy, adjust, - qMinusOne, q0); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithShiftLeftToEmitC : public OpConversionPattern { - using 
OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::ShLIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - const unsigned bitWidth = - intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - if (bitWidth == 1) { - // Compute on u8 and truncate to i1. - auto u8Ty = getUnsignedIntOpaqueType(rewriter.getContext(), 8); - Value lhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getLhs()); - Value rhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getRhs()); - Value sh = rewriter.create(loc, u8Ty, lhsU8, - rhsU8); - Value masked = - rewriter.create(loc, u8Ty, sh, - makeEmitCIntConstant(rewriter, loc, - u8Ty, 1)); - rewriter.replaceOp(op, emitCCast(rewriter, loc, dstTy, masked)); - return success(); - } - - auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value shU = - rewriter.create(loc, uTy, lhsU, rhsU); - Value result = emitCCast(rewriter, loc, dstTy, shU); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithShiftRightUIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::ShRUIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - 
const unsigned bitWidth = - intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - if (bitWidth == 1) { - // (x >> y) on i1 is either x (y==0) or 0 (y!=0); approximate in u8. - auto u8Ty = getUnsignedIntOpaqueType(rewriter.getContext(), 8); - Value lhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getLhs()); - Value rhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getRhs()); - Value sh = rewriter.create(loc, u8Ty, lhsU8, - rhsU8); - Value masked = - rewriter.create(loc, u8Ty, sh, - makeEmitCIntConstant(rewriter, loc, - u8Ty, 1)); - rewriter.replaceOp(op, emitCCast(rewriter, loc, dstTy, masked)); - return success(); - } - - auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value shU = - rewriter.create(loc, uTy, lhsU, rhsU); - Value result = emitCCast(rewriter, loc, dstTy, shU); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithShiftRightSIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::ShRSIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - const unsigned bitWidth = - intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - if (bitWidth == 1) { - // (x >> y) on i1 is either x (y==0) or 0 (y!=0); approximate in u8. 
- auto u8Ty = getUnsignedIntOpaqueType(rewriter.getContext(), 8); - Value lhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getLhs()); - Value rhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getRhs()); - Value sh = rewriter.create(loc, u8Ty, lhsU8, - rhsU8); - Value masked = - rewriter.create(loc, u8Ty, sh, - makeEmitCIntConstant(rewriter, loc, - u8Ty, 1)); - rewriter.replaceOp(op, emitCCast(rewriter, loc, dstTy, masked)); - return success(); - } - - // Signed arithmetic shift; cast RHS to unsigned to interpret shift amount. - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value sh = - rewriter.create(loc, dstTy, adaptor.getLhs(), - rhsU); - rewriter.replaceOp(op, sh); - return success(); - } -}; - -struct ArithNegFToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::NegFOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type dstTy = getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getOperand()); - return success(); - } -}; - -struct ArithRemFToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::RemFOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type dstTy = getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - - // Use builtin `fmod` when possible. For f16, compute in float and cast back. 
- Type callTy = dstTy; - Value lhs = adaptor.getLhs(); - Value rhs = adaptor.getRhs(); - - if (auto opFloatTy = dyn_cast(op.getType())) { - if (opFloatTy.isF16()) { - auto f32Ty = emitc::OpaqueType::get(rewriter.getContext(), "float"); - lhs = emitCCast(rewriter, loc, f32Ty, lhs); - rhs = emitCCast(rewriter, loc, f32Ty, rhs); - callTy = f32Ty; - } - } - - // Prefer `__builtin_fmod*` to avoid relying on extra headers. - llvm::StringRef callee = "__builtin_fmod"; - if (auto opFloatTy = dyn_cast(op.getType())) { - if (opFloatTy.isF32() || opFloatTy.isF16()) - callee = "__builtin_fmodf"; - else if (opFloatTy.isF64()) - callee = "__builtin_fmod"; - } - - auto call = rewriter.create( - loc, TypeRange{callTy}, callee, ValueRange{lhs, rhs}, - /*args=*/ArrayAttr{}, /*template_args=*/ArrayAttr{}); - Value result = call.getResult(0); - if (callTy != dstTy) - result = emitCCast(rewriter, loc, dstTy, result); - - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithSelectToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!op.getCondition().getType().isInteger(1)) - return rewriter.notifyMatchFailure( - op, "only scalar i1 conditions supported for arith.select"); - - Type dstTy = getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - - auto cond = - rewriter.create(op.getLoc(), dstTy, - adaptor.getCondition(), - adaptor.getTrueValue(), - adaptor.getFalseValue()); - rewriter.replaceOp(op, cond.getResult()); - return success(); - } -}; - -struct ArithExtUIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto dstIntTy = dyn_cast(op.getType()); - auto srcIntTy = 
dyn_cast(op.getIn().getType()); - if (!dstIntTy || !srcIntTy) - return rewriter.notifyMatchFailure(op, "expected scalar integer types"); - - Type dstTy = getTypeConverter()->convertType(dstIntTy); - if (!dstTy) - return failure(); - - // i1 -> iN: bool to integer already behaves as 0/1. - if (srcIntTy.getWidth() == 1) { - rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); - return success(); - } - - auto uDstTy = - getUnsignedIntOpaqueType(rewriter.getContext(), dstIntTy.getWidth()); - Value srcU = - castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getIn(), - srcIntTy.getWidth()); - Value extU = emitCCast(rewriter, loc, uDstTy, srcU); - Value result = emitCCast(rewriter, loc, dstTy, extU); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithExtSIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto dstIntTy = dyn_cast(op.getType()); - auto srcIntTy = dyn_cast(op.getIn().getType()); - if (!dstIntTy || !srcIntTy) - return rewriter.notifyMatchFailure(op, "expected scalar integer types"); - - Type dstTy = getTypeConverter()->convertType(dstIntTy); - if (!dstTy) - return failure(); - - // i1 sign-extension: 0 -> 0, 1 -> -1. 
- if (srcIntTy.getWidth() == 1) { - Value zero = makeEmitCIntConstant(rewriter, loc, dstTy, 0); - Value asInt = emitCCast(rewriter, loc, dstTy, adaptor.getIn()); - Value neg = rewriter.create(loc, dstTy, zero, asInt).getResult(); - rewriter.replaceOp(op, neg); - return success(); - } - - rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); - return success(); - } -}; - -template -struct ArithCastToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type dstTy = this->getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); - return success(); - } -}; - -struct ArithIndexCastUIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::IndexCastUIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type dstTy = getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - - // MemRef casts are handled elsewhere; for safety, fall back to emitc.cast. 
- if (isa(op.getIn().getType()) || isa(op.getType())) { - rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); - return success(); - } - - auto getBW = [](Type t) -> std::optional { - if (auto i = dyn_cast(t)) - return i.getWidth(); - if (isa(t)) - return kPTOIndexBitWidth; - return std::nullopt; - }; - - auto srcBW = getBW(op.getIn().getType()); - auto dstBW = getBW(op.getType()); - if (!srcBW || !dstBW) - return rewriter.notifyMatchFailure(op, "unsupported index_castui types"); - - if (*dstBW <= *srcBW) { - rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); - return success(); - } - - auto uSrcTy = getUnsignedIntOpaqueType(rewriter.getContext(), *srcBW); - auto uDstTy = getUnsignedIntOpaqueType(rewriter.getContext(), *dstBW); - Value srcU = emitCCast(rewriter, loc, uSrcTy, adaptor.getIn()); - Value extU = emitCCast(rewriter, loc, uDstTy, srcU); - Value result = emitCCast(rewriter, loc, dstTy, extU); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithUIToFPToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto srcIntTy = dyn_cast(op.getIn().getType()); - if (!srcIntTy) - return rewriter.notifyMatchFailure(op, "expected scalar integer input"); - - Type dstTy = getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - - // Convert via an unsigned integer type of the same width. 
- if (srcIntTy.getWidth() == 1) { - rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); - return success(); - } - Value srcU = - castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getIn(), - srcIntTy.getWidth()); - Value fp = rewriter.create(loc, dstTy, srcU).getResult(); - rewriter.replaceOp(op, fp); - return success(); - } -}; - -struct ArithFPToUIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::FPToUIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto dstIntTy = dyn_cast(op.getType()); - if (!dstIntTy) - return rewriter.notifyMatchFailure(op, "expected scalar integer result"); - - Type dstTy = getTypeConverter()->convertType(dstIntTy); - if (!dstTy) - return failure(); - - auto uDstTy = - getUnsignedIntOpaqueType(rewriter.getContext(), dstIntTy.getWidth()); - Value asU = rewriter.create(loc, uDstTy, adaptor.getIn()).getResult(); - Value result = emitCCast(rewriter, loc, dstTy, asU); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithBitcastToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type dstTy = getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - - // For pointer-like types, a regular cast is fine. - if (isa(dstTy)) { - rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); - return success(); - } - - // Only support scalar int/float/index bitcasts here. 
- auto srcTy = op.getIn().getType(); - auto dstOrigTy = op.getType(); - - auto getBitWidth = [](Type t) -> std::optional { - if (auto it = dyn_cast(t)) - return it.getWidth(); - if (auto ft = dyn_cast(t)) - return ft.getWidth(); - if (isa(t)) - return kPTOIndexBitWidth; - return std::nullopt; - }; - auto srcBW = getBitWidth(srcTy); - auto dstBW = getBitWidth(dstOrigTy); - if (!srcBW || !dstBW || *srcBW != *dstBW) - return rewriter.notifyMatchFailure(op, "bitcast requires equal bitwidth"); - - // Determine the template argument from the destination type string. - auto dstOpaque = dyn_cast(dstTy); - if (!dstOpaque) - return rewriter.notifyMatchFailure(op, "expected emitc opaque dest type"); - - auto templateArgs = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(rewriter.getContext(), - dstOpaque.getValue())}); - auto call = rewriter.create( - loc, TypeRange{dstTy}, "ptoas_bitcast", /*operands=*/ValueRange{adaptor.getIn()}, - /*args=*/ArrayAttr{}, /*template_args=*/templateArgs); - rewriter.replaceOp(op, call.getResult(0)); - return success(); - } -}; - -// arith.cmpf lowering with ordered/unordered semantics. 
-struct ArithCmpFToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - struct CmpFConfig { - bool unordered = false; - emitc::CmpPredicate predicate = emitc::CmpPredicate::eq; - }; - - static Value isNaN(ConversionPatternRewriter &rewriter, Location loc, - Value v) { - return rewriter - .create(loc, rewriter.getI1Type(), emitc::CmpPredicate::ne, - v, v) - .getResult(); - } - - static Value isNotNaN(ConversionPatternRewriter &rewriter, Location loc, - Value v) { - return rewriter - .create(loc, rewriter.getI1Type(), emitc::CmpPredicate::eq, - v, v) - .getResult(); - } - - static std::optional buildSpecialCmpFResult( - arith::CmpFPredicate predicate, ConversionPatternRewriter &rewriter, - Location loc, Type i1Ty, Value lhs, Value rhs) { - switch (predicate) { - case arith::CmpFPredicate::AlwaysFalse: - return makeEmitCOpaqueConstant(rewriter, loc, i1Ty, "false"); - case arith::CmpFPredicate::AlwaysTrue: - return makeEmitCOpaqueConstant(rewriter, loc, i1Ty, "true"); - case arith::CmpFPredicate::ORD: - return rewriter.create( - loc, i1Ty, isNotNaN(rewriter, loc, lhs), - isNotNaN(rewriter, loc, rhs)) - .getResult(); - case arith::CmpFPredicate::UNO: - return rewriter.create( - loc, i1Ty, isNaN(rewriter, loc, lhs), - isNaN(rewriter, loc, rhs)) - .getResult(); - default: - return std::nullopt; - } - } - - static std::optional - getCmpFConfig(arith::CmpFPredicate predicate) { - switch (predicate) { - case arith::CmpFPredicate::OEQ: - return CmpFConfig{false, emitc::CmpPredicate::eq}; - case arith::CmpFPredicate::OGT: - return CmpFConfig{false, emitc::CmpPredicate::gt}; - case arith::CmpFPredicate::OGE: - return CmpFConfig{false, emitc::CmpPredicate::ge}; - case arith::CmpFPredicate::OLT: - return CmpFConfig{false, emitc::CmpPredicate::lt}; - case arith::CmpFPredicate::OLE: - return CmpFConfig{false, emitc::CmpPredicate::le}; - case arith::CmpFPredicate::ONE: - return CmpFConfig{false, emitc::CmpPredicate::ne}; - case 
arith::CmpFPredicate::UEQ: - return CmpFConfig{true, emitc::CmpPredicate::eq}; - case arith::CmpFPredicate::UGT: - return CmpFConfig{true, emitc::CmpPredicate::gt}; - case arith::CmpFPredicate::UGE: - return CmpFConfig{true, emitc::CmpPredicate::ge}; - case arith::CmpFPredicate::ULT: - return CmpFConfig{true, emitc::CmpPredicate::lt}; - case arith::CmpFPredicate::ULE: - return CmpFConfig{true, emitc::CmpPredicate::le}; - case arith::CmpFPredicate::UNE: - return CmpFConfig{true, emitc::CmpPredicate::ne}; - default: - return std::nullopt; - } - } - - static Value buildCmpFResult(const CmpFConfig &config, - ConversionPatternRewriter &rewriter, - Location loc, Type i1Ty, Value lhs, Value rhs) { - Value cmp = rewriter - .create(loc, i1Ty, config.predicate, lhs, rhs) - .getResult(); - Value unord = rewriter.create( - loc, i1Ty, isNaN(rewriter, loc, lhs), isNaN(rewriter, loc, rhs)); - if (config.unordered) - return rewriter - .create(loc, i1Ty, unord, cmp) - .getResult(); - Value ord = rewriter.create( - loc, i1Ty, isNotNaN(rewriter, loc, lhs), isNotNaN(rewriter, loc, rhs)); - return rewriter - .create(loc, i1Ty, ord, cmp) - .getResult(); - } - - LogicalResult matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!isa(op.getLhs().getType())) - return rewriter.notifyMatchFailure(op, "cmpf only supported on scalar floats"); - - auto loc = op.getLoc(); - auto i1Ty = rewriter.getI1Type(); - if (auto special = buildSpecialCmpFResult(op.getPredicate(), rewriter, loc, - i1Ty, adaptor.getLhs(), - adaptor.getRhs())) { - rewriter.replaceOp(op, *special); - return success(); - } - - auto config = getCmpFConfig(op.getPredicate()); - if (!config) - return rewriter.notifyMatchFailure(op, "unsupported cmpf predicate"); - rewriter.replaceOp(op, buildCmpFResult(*config, rewriter, loc, i1Ty, - adaptor.getLhs(), adaptor.getRhs())); - return success(); - } -}; - -struct ArithAddUIExtendedToEmitC - : public OpConversionPattern { - 
using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getSum().getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, - "expected scalar integer or index operands"); - - const unsigned bitWidth = - intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - SmallVector newResultTypes; - if (failed(getTypeConverter()->convertTypes(op->getResultTypes(), - newResultTypes))) - return failure(); - if (newResultTypes.size() != 2) - return failure(); - - Type sumDstTy = newResultTypes[0]; - Type overflowDstTy = newResultTypes[1]; - - auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - auto wideTy = getWiderUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - - Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value lhsWide = emitCCast(rewriter, loc, wideTy, lhsU); - Value rhsWide = emitCCast(rewriter, loc, wideTy, rhsU); - Value sumWide = - rewriter.create(loc, wideTy, lhsWide, rhsWide).getResult(); - - Value sumN = emitCCast(rewriter, loc, uTy, sumWide); - Value sum = emitCCast(rewriter, loc, sumDstTy, sumN); - - Value shiftAmt = makeEmitCIntConstant(rewriter, loc, wideTy, bitWidth); - Value high = rewriter - .create(loc, wideTy, sumWide, - shiftAmt) - .getResult(); - Value zeroWide = makeEmitCIntConstant(rewriter, loc, wideTy, 0); - Value overflow = - rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::ne, high, zeroWide) - .getResult(); - overflow = emitCCast(rewriter, loc, overflowDstTy, overflow); - - rewriter.replaceOp(op, {sum, overflow}); - return success(); - } -}; - -template -struct ArithMulExtendedToEmitC : public 
OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getResult(0).getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, - "expected scalar integer or index operands"); - - const unsigned bitWidth = - intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - SmallVector newResultTypes; - if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), - newResultTypes))) - return failure(); - if (newResultTypes.size() != 2) - return failure(); - - Type lowDstTy = newResultTypes[0]; - Type highDstTy = newResultTypes[1]; - - Type wideTy = isUnsigned ? (Type)getWiderUnsignedIntOpaqueType(rewriter.getContext(), - bitWidth) - : (Type)getWiderSignedIntOpaqueType(rewriter.getContext(), - bitWidth); - - Value lhsWide; - Value rhsWide; - if constexpr (isUnsigned) { - Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - lhsWide = emitCCast(rewriter, loc, wideTy, lhsU); - rhsWide = emitCCast(rewriter, loc, wideTy, rhsU); - } else { - lhsWide = emitCCast(rewriter, loc, wideTy, adaptor.getLhs()); - rhsWide = emitCCast(rewriter, loc, wideTy, adaptor.getRhs()); - } - - Value prodWide = - rewriter.create(loc, wideTy, lhsWide, rhsWide).getResult(); - Value low = emitCCast(rewriter, loc, lowDstTy, prodWide); - - Value shiftAmt = makeEmitCIntConstant(rewriter, loc, wideTy, bitWidth); - Value highWide = rewriter - .create(loc, wideTy, prodWide, - shiftAmt) - .getResult(); - Value high = emitCCast(rewriter, loc, highDstTy, highWide); - - rewriter.replaceOp(op, {low, high}); - return success(); - } -}; - -using ArithMulSIExtendedToEmitC = - 
ArithMulExtendedToEmitC; -using ArithMulUIExtendedToEmitC = - ArithMulExtendedToEmitC; - -struct ArithMinMaxIToEmitCBase { - static Value makeSelect(ConversionPatternRewriter &rewriter, Location loc, - Type dstTy, Value cond, Value trueV, Value falseV) { - return rewriter - .create(loc, dstTy, cond, trueV, falseV) - .getResult(); - } -}; - -struct ArithMaxSIToEmitC : public OpConversionPattern, - ArithMinMaxIToEmitCBase { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::MaxSIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type dstTy = getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - Value cond = rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, - adaptor.getLhs(), adaptor.getRhs()) - .getResult(); - Value res = makeSelect(rewriter, loc, dstTy, cond, adaptor.getRhs(), - adaptor.getLhs()); - rewriter.replaceOp(op, res); - return success(); - } -}; - -struct ArithMinSIToEmitC : public OpConversionPattern, - ArithMinMaxIToEmitCBase { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::MinSIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type dstTy = getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - Value cond = rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, - adaptor.getLhs(), adaptor.getRhs()) - .getResult(); - Value res = makeSelect(rewriter, loc, dstTy, cond, adaptor.getLhs(), - adaptor.getRhs()); - rewriter.replaceOp(op, res); - return success(); - } -}; - -struct ArithMaxUIToEmitC : public OpConversionPattern, - ArithMinMaxIToEmitCBase { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::MaxUIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = 
op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - const unsigned bitWidth = - intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - Value lhsU = - castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = - castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value cond = rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, lhsU, rhsU) - .getResult(); - Value res = makeSelect(rewriter, loc, dstTy, cond, adaptor.getRhs(), - adaptor.getLhs()); - rewriter.replaceOp(op, res); - return success(); - } -}; - -struct ArithMinUIToEmitC : public OpConversionPattern, - ArithMinMaxIToEmitCBase { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::MinUIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - const unsigned bitWidth = - intTy ? 
intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - Value lhsU = - castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = - castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value cond = rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, lhsU, rhsU) - .getResult(); - Value res = makeSelect(rewriter, loc, dstTy, cond, adaptor.getLhs(), - adaptor.getRhs()); - rewriter.replaceOp(op, res); - return success(); - } -}; - -// Floating-point max/min variants. -struct ArithFloatMinMaxToEmitCBase { - static Value isNaN(ConversionPatternRewriter &rewriter, Location loc, - Value v) { - return rewriter - .create(loc, rewriter.getI1Type(), emitc::CmpPredicate::ne, - v, v) - .getResult(); - } - - static Value makeFZero(ConversionPatternRewriter &rewriter, Location loc, - Type ty) { - return makeEmitCOpaqueConstant(rewriter, loc, ty, "0.0f"); - } -}; - -struct ArithMaxNumFToEmitC : public OpConversionPattern, - ArithFloatMinMaxToEmitCBase { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::MaxNumFOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type dstTy = getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - - Value lhsNaN = isNaN(rewriter, loc, adaptor.getLhs()); - Value rhsNaN = isNaN(rewriter, loc, adaptor.getRhs()); - - Value cmpLt = rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, - adaptor.getLhs(), adaptor.getRhs()) - .getResult(); - Value maxNoNaN = - rewriter - .create(loc, dstTy, cmpLt, adaptor.getRhs(), - adaptor.getLhs()) - .getResult(); - - Value rhsOrMax = - rewriter - .create(loc, dstTy, rhsNaN, adaptor.getLhs(), - maxNoNaN) - .getResult(); - Value res = - rewriter - .create(loc, dstTy, lhsNaN, adaptor.getRhs(), - rhsOrMax) - 
.getResult(); - rewriter.replaceOp(op, res); - return success(); - } -}; - -struct ArithMinNumFToEmitC : public OpConversionPattern, - ArithFloatMinMaxToEmitCBase { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::MinNumFOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type dstTy = getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - - Value lhsNaN = isNaN(rewriter, loc, adaptor.getLhs()); - Value rhsNaN = isNaN(rewriter, loc, adaptor.getRhs()); - - Value cmpLt = rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, - adaptor.getLhs(), adaptor.getRhs()) - .getResult(); - Value minNoNaN = - rewriter - .create(loc, dstTy, cmpLt, adaptor.getLhs(), - adaptor.getRhs()) - .getResult(); - - Value rhsOrMin = - rewriter - .create(loc, dstTy, rhsNaN, adaptor.getLhs(), - minNoNaN) - .getResult(); - Value res = - rewriter - .create(loc, dstTy, lhsNaN, adaptor.getRhs(), - rhsOrMin) - .getResult(); - rewriter.replaceOp(op, res); - return success(); - } -}; - -template -struct ArithMinMaxFPropagateNaNToEmitC : public OpConversionPattern, - ArithFloatMinMaxToEmitCBase { - using OpConversionPattern::OpConversionPattern; - - static Value buildPrimaryCandidate(ConversionPatternRewriter &rewriter, - Location loc, Type dstTy, Value lhs, - Value rhs) { - Value cmpLt = - rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, lhs, rhs) - .getResult(); - return rewriter - .create( - loc, dstTy, cmpLt, isMaximum ? rhs : lhs, isMaximum ? 
lhs : rhs) - .getResult(); - } - - static Value buildSignBitValue(ConversionPatternRewriter &rewriter, - Location loc, Value lhs, FloatType floatTy) { - auto bitsTy = - getUnsignedIntOpaqueType(rewriter.getContext(), floatTy.getWidth()); - auto templateArgs = rewriter.getArrayAttr({emitc::OpaqueAttr::get( - rewriter.getContext(), cast(bitsTy).getValue())}); - Value lhsBits = - rewriter - .create(loc, TypeRange{bitsTy}, "ptoas_bitcast", - ValueRange{lhs}, ArrayAttr{}, - templateArgs) - .getResult(0); - Value oneBits = makeEmitCIntConstant(rewriter, loc, bitsTy, 1); - Value shiftAmount = - makeEmitCIntConstant(rewriter, loc, bitsTy, floatTy.getWidth() - 1); - Value signMask = rewriter - .create(loc, bitsTy, oneBits, - shiftAmount) - .getResult(); - return rewriter - .create(loc, bitsTy, lhsBits, signMask) - .getResult(); - } - - static Value buildSignedZeroCandidate(ConversionPatternRewriter &rewriter, - Location loc, Type dstTy, Value lhs, - Value rhs, FloatType floatTy) { - Value zero = makeFZero(rewriter, loc, dstTy); - Value equal = rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::eq, lhs, rhs) - .getResult(); - Value lhsZero = rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::eq, lhs, - zero) - .getResult(); - Value bothZero = rewriter - .create(loc, rewriter.getI1Type(), - equal, lhsZero) - .getResult(); - auto bitsTy = - getUnsignedIntOpaqueType(rewriter.getContext(), floatTy.getWidth()); - Value zeroBits = makeEmitCIntConstant(rewriter, loc, bitsTy, 0); - Value lhsIsNegZero = - rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::ne, - buildSignBitValue(rewriter, loc, lhs, floatTy), - zeroBits) - .getResult(); - Value tie = rewriter - .create( - loc, dstTy, lhsIsNegZero, isMaximum ? rhs : lhs, - isMaximum ? 
lhs : rhs) - .getResult(); - return rewriter - .create(loc, dstTy, bothZero, tie, - buildPrimaryCandidate(rewriter, loc, dstTy, - lhs, rhs)) - .getResult(); - } - - static Value buildNaNPropagatingResult(ConversionPatternRewriter &rewriter, - Location loc, Type dstTy, Value lhs, - Value rhs, FloatType floatTy) { - Value lhsNaN = isNaN(rewriter, loc, lhs); - Value rhsNaN = isNaN(rewriter, loc, rhs); - Value noNaN = - buildSignedZeroCandidate(rewriter, loc, dstTy, lhs, rhs, floatTy); - Value rhsOrNoNaN = rewriter - .create(loc, dstTy, rhsNaN, rhs, - noNaN) - .getResult(); - return rewriter - .create(loc, dstTy, lhsNaN, lhs, rhsOrNoNaN) - .getResult(); - } - - LogicalResult - matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!isa(op.getType())) - return rewriter.notifyMatchFailure(op, "expected scalar float type"); - - auto loc = op.getLoc(); - Type dstTy = this->getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - - auto floatTy = cast(op.getType()); - rewriter.replaceOp(op, buildNaNPropagatingResult( - rewriter, loc, dstTy, adaptor.getLhs(), - adaptor.getRhs(), floatTy)); - return success(); - } -}; - -using ArithMaximumFToEmitC = - ArithMinMaxFPropagateNaNToEmitC; -using ArithMinimumFToEmitC = - ArithMinMaxFPropagateNaNToEmitC; - -//===----------------------------------------------------------------------===// -// Arith -> EmitC helpers -//===----------------------------------------------------------------------===// - -static emitc::OpaqueType getSignedIntOpaqueType(MLIRContext *ctx, - unsigned bitWidth) { - switch (bitWidth) { - case 1: - return emitc::OpaqueType::get(ctx, "int8_t"); - case 8: - return emitc::OpaqueType::get(ctx, "int8_t"); - case 16: - return emitc::OpaqueType::get(ctx, "int16_t"); - case 32: - return emitc::OpaqueType::get(ctx, "int32_t"); - case 64: - return emitc::OpaqueType::get(ctx, "int64_t"); - case 128: - return 
emitc::OpaqueType::get(ctx, "__int128"); - default: - llvm::errs() << "[Debug] Unsupported signed integer bitwidth: " << bitWidth - << "\n"; - return emitc::OpaqueType::get(ctx, "int64_t"); - } -} - -static emitc::OpaqueType getUnsignedIntOpaqueType(MLIRContext *ctx, - unsigned bitWidth) { - switch (bitWidth) { - case 1: - return emitc::OpaqueType::get(ctx, "uint8_t"); - case 8: - return emitc::OpaqueType::get(ctx, "uint8_t"); - case 16: - return emitc::OpaqueType::get(ctx, "uint16_t"); - case 32: - return emitc::OpaqueType::get(ctx, "uint32_t"); - case 64: - return emitc::OpaqueType::get(ctx, "uint64_t"); - case 128: - return emitc::OpaqueType::get(ctx, "unsigned __int128"); - default: - llvm::errs() << "[Debug] Unsupported unsigned integer bitwidth: " - << bitWidth << "\n"; - return emitc::OpaqueType::get(ctx, "uint64_t"); - } -} - -static emitc::OpaqueType getWiderSignedIntOpaqueType(MLIRContext *ctx, - unsigned bitWidth) { - switch (bitWidth) { - case 1: - case 8: - return getSignedIntOpaqueType(ctx, 16); - case 16: - return getSignedIntOpaqueType(ctx, 32); - case 32: - return getSignedIntOpaqueType(ctx, 64); - case 64: - return getSignedIntOpaqueType(ctx, 128); - default: - return getSignedIntOpaqueType(ctx, 128); - } -} - -static emitc::OpaqueType getWiderUnsignedIntOpaqueType(MLIRContext *ctx, - unsigned bitWidth) { - switch (bitWidth) { - case 1: - case 8: - return getUnsignedIntOpaqueType(ctx, 16); - case 16: - return getUnsignedIntOpaqueType(ctx, 32); - case 32: - return getUnsignedIntOpaqueType(ctx, 64); - case 64: - return getUnsignedIntOpaqueType(ctx, 128); - default: - return getUnsignedIntOpaqueType(ctx, 128); - } -} - -static Value makeEmitCOpaqueConstant(ConversionPatternRewriter &rewriter, - Location loc, Type type, - llvm::StringRef literal) { - auto attr = emitc::OpaqueAttr::get(rewriter.getContext(), literal); - return rewriter.create(loc, type, attr); -} - -static Value makeEmitCIntConstant(ConversionPatternRewriter &rewriter, - Location loc, 
Type type, int64_t value) { - return makeEmitCOpaqueConstant(rewriter, loc, type, std::to_string(value)); -} - -static FailureOr buildEmitCOpaqueConstantLiteral(Type targetType, - Attribute valueAttr) { - auto opaqueTy = dyn_cast(targetType); - if (!opaqueTy) - return failure(); - - if (opaqueTy.getValue() == "pto::MrgSortExecutedNumList") { - auto dense = dyn_cast_or_null(valueAttr); - if (!dense) - return failure(); - - auto vecTy = dyn_cast(dense.getType()); - if (!vecTy || vecTy.getRank() != 1 || vecTy.getNumElements() != 4 || - !vecTy.getElementType().isInteger(16)) - return failure(); - - std::string literal; - llvm::raw_string_ostream os(literal); - os << "pto::MrgSortExecutedNumList{"; - bool first = true; - for (APInt elem : dense.getValues()) { - if (!first) - os << ", "; - first = false; - os << elem.getZExtValue(); - } - os << "}"; - os.flush(); - return literal; - } - - return failure(); -} - -static Value emitCCast(ConversionPatternRewriter &rewriter, Location loc, - Type dstType, Value src) { - if (src.getType() == dstType) - return src; - return rewriter.createOrFold(loc, dstType, src); -} - -// For signless iN integers lowered to signed C++ types, this creates a value -// representing the same N-bit pattern in an unsigned C++ type of the same -// width. This avoids incorrect sign-extension when later widening to a larger -// unsigned type. 
-static Value castSignlessIntToUnsignedSameWidth(ConversionPatternRewriter &rewriter, - Location loc, Value v, - unsigned bitWidth) { - auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - return emitCCast(rewriter, loc, uTy, v); -} - -struct ArithMulIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(arith::MulIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - const unsigned bitWidth = - intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - // i1 mul is equivalent to bitwise AND (mod 2 arithmetic). - if (bitWidth == 1) { - rewriter.replaceOpWithNewOp(op, opTy, adaptor.getLhs(), - adaptor.getRhs()); - return success(); - } - - auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value mulU = rewriter.create(loc, uTy, lhsU, rhsU); - Value result = emitCCast(rewriter, loc, dstTy, mulU); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithAddIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(arith::AddIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - const 
unsigned bitWidth = - intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - // i1 add is equivalent to XOR (mod 2 arithmetic). - if (bitWidth == 1) { - rewriter.replaceOpWithNewOp(op, opTy, adaptor.getLhs(), - adaptor.getRhs()); - return success(); - } - - auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value addU = rewriter.create(loc, uTy, lhsU, rhsU); - Value result = emitCCast(rewriter, loc, dstTy, addU); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithCastOPToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type newTy = getTypeConverter()->convertType(op.getType()); - if (!newTy) - return failure(); - if (adaptor.getIn().getType() == newTy) { - rewriter.replaceOp(op, adaptor.getIn()); - return success(); - } - rewriter.replaceOpWithNewOp(op, newTy, adaptor.getIn()); - return success(); - } -}; - -struct ArithSubIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(arith::SubIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - const unsigned bitWidth = - intTy ? 
intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - // i1 sub is equivalent to XOR (mod 2 arithmetic). - if (bitWidth == 1) { - rewriter.replaceOpWithNewOp(op, opTy, adaptor.getLhs(), - adaptor.getRhs()); - return success(); - } - - auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value subU = rewriter.create(loc, uTy, lhsU, rhsU); - Value result = emitCCast(rewriter, loc, dstTy, subU); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithDivSIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(arith::DivSIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type newTy = getTypeConverter()->convertType(op.getType()); - if (!newTy) - return failure(); - rewriter.replaceOpWithNewOp(op, newTy, adaptor.getLhs(), - adaptor.getRhs()); - return success(); - } -}; - -struct ArithRemSIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type newTy = getTypeConverter()->convertType(op.getType()); - if (!newTy) - return failure(); - rewriter.replaceOpWithNewOp(op, newTy, adaptor.getLhs(), - adaptor.getRhs()); - return success(); - } -}; - -struct ArithTruncIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - auto dstIntTy = dyn_cast(op.getType()); - auto srcIntTy = 
dyn_cast(op.getIn().getType()); - if (!dstIntTy || !srcIntTy) - return rewriter.notifyMatchFailure(op, "expected scalar integer types"); - - Type dstTy = getTypeConverter()->convertType(dstIntTy); - if (!dstTy) - return failure(); - - // to-i1 conversions: Arith wants truncation to the low bit, while C/C++ - // casts to bool are equivalent to `v != 0`. Implement as `(bool)(v & 1)`. - if (dstIntTy.getWidth() == 1) { - if (srcIntTy.getWidth() == 1) { - rewriter.replaceOp(op, adaptor.getIn()); - return success(); - } - - auto uSrcTy = - getUnsignedIntOpaqueType(rewriter.getContext(), srcIntTy.getWidth()); - Value inU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getIn(), - srcIntTy.getWidth()); - Value one = makeEmitCIntConstant(rewriter, loc, uSrcTy, 1); - Value masked = - rewriter.create(loc, uSrcTy, inU, one); - Value asBool = emitCCast(rewriter, loc, dstTy, masked); - rewriter.replaceOp(op, asBool); - return success(); - } - - rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); - return success(); - } -}; - -struct ArithConstantToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type newType = getTypeConverter()->convertType(op.getType()); - if (!newType) - return failure(); - - // `adaptor.getValue()` may be null if attribute conversion isn't defined. - // Use the original attribute as fallback and always cast null-safely. 
- Attribute valueAttr = adaptor.getValue(); - if (!valueAttr) - valueAttr = op.getValue(); - - if (auto opaqueLiteral = buildEmitCOpaqueConstantLiteral(newType, valueAttr); - succeeded(opaqueLiteral)) { - auto constAttr = emitc::OpaqueAttr::get(rewriter.getContext(), *opaqueLiteral); - rewriter.replaceOpWithNewOp(op, newType, constAttr); - return success(); - } - - if (auto floatAttr = dyn_cast_or_null(valueAttr)) { - SmallString<32> valStr; - floatAttr.getValue().toString(valStr); - llvm::StringRef s(valStr); - // Ensure the literal parses as a floating-point constant in C/C++. - // `APFloat::toString` may emit "1" for integral values; make it "1.0". - const bool hasFloatMarker = - s.contains('.') || s.contains('e') || s.contains('E') || - s.contains('p') || s.contains('P') || s.starts_with("0x") || - s.starts_with("0X") || s.starts_with("nan") || - s.starts_with("-nan") || s.starts_with("inf") || - s.starts_with("-inf"); - if (!hasFloatMarker) - valStr.append(".0"); - // Suffix: keep `f` for f16/f32; omit for f64. 
- if (!floatAttr.getType().isF64()) - valStr.append("f"); - auto constAttr = emitc::OpaqueAttr::get(rewriter.getContext(), valStr); - rewriter.replaceOpWithNewOp(op, newType, constAttr); - return success(); - } - - if (auto intAttr = dyn_cast_or_null(valueAttr)) { - std::string valStr = std::to_string(intAttr.getValue().getSExtValue()); - auto constAttr = emitc::OpaqueAttr::get(rewriter.getContext(), valStr); - rewriter.replaceOpWithNewOp(op, newType, constAttr); - return success(); - } - - return failure(); - } -}; -//===----------------------------------------------------------------------===// -// pto.mgather lowering -> MGATHER(dst, src, indexes) (pto-isa) -//===----------------------------------------------------------------------===// - -struct PTOMGatherToMGATHER : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::MGatherOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto *ctx = rewriter.getContext(); - Value mem = peelUnrealized(adaptor.getMem()); - Value idx = peelUnrealized(adaptor.getIdx()); - Value dst = peelUnrealized(adaptor.getDst()); - - Value memArg = maybeWrapGlobalMemrefAsGlobalTensor( - rewriter, op.getLoc(), mem, op.getMem().getType(), op.getOperation()); - - auto gatherOobTok = [&](pto::GatherOOB mode) -> StringRef { - switch (mode) { - case pto::GatherOOB::Undefined: - return "pto::GatherOOB::Undefined"; - case pto::GatherOOB::Clamp: - return "pto::GatherOOB::Clamp"; - case pto::GatherOOB::Wrap: - return "pto::GatherOOB::Wrap"; - case pto::GatherOOB::Zero: - return "pto::GatherOOB::Zero"; - } - llvm_unreachable("unknown GatherOOB"); - }; - - SmallVector templateArgVec; - const bool rowCoalesce = - isRowCoalescedMGatherIndexType(op.getDst().getType(), op.getIdx().getType()); - templateArgVec.push_back(emitc::OpaqueAttr::get( - ctx, rowCoalesce ? 
"pto::Coalesce::Row" : "pto::Coalesce::Elem")); - if (op.getGatherOob() != pto::GatherOOB::Undefined) { - templateArgVec.push_back( - emitc::OpaqueAttr::get(ctx, gatherOobTok(op.getGatherOob()))); - } - ArrayAttr templateArgs = rewriter.getArrayAttr(templateArgVec); - - rewriter.create( - op.getLoc(), TypeRange{}, "MGATHER", - ArrayAttr{}, templateArgs, - ValueRange{dst, memArg, idx}); - - if (op->getNumResults() == 0) { - rewriter.eraseOp(op); - } else { - rewriter.replaceOp(op, dst); - } - return success(); - } -}; - -struct AffineApplyMulConstToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(affine::AffineApplyOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto map = op.getAffineMap(); - - if (map.getNumDims() != 0 || map.getNumSymbols() != 1) - return failure(); - - auto expr = map.getResult(0); - auto bin = dyn_cast(expr); - if (!bin || bin.getKind() != AffineExprKind::Mul) - return failure(); - - auto lhs = bin.getLHS(); - auto rhs = bin.getRHS(); - - auto symExpr = dyn_cast(lhs); - auto constExpr = dyn_cast(rhs); - if (!symExpr || !constExpr) - return failure(); - - Value inputVal = adaptor.getMapOperands()[0]; - - std::string valStr = std::to_string(constExpr.getValue()); - auto cstAttr = emitc::OpaqueAttr::get(rewriter.getContext(), valStr); - auto cstOp = rewriter.create( - op.getLoc(), inputVal.getType(), cstAttr); - - rewriter.replaceOpWithNewOp( - op, inputVal.getType(), inputVal, cstOp); - - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// Kernel inference helpers -//===----------------------------------------------------------------------===// - -enum class KernelKind { VecAdd, Matmul, Unknown }; - -[[maybe_unused]] static KernelKind inferKernelKind(func::FuncOp f) { - bool hasAdd = false; - bool hasMM = false; - f.walk([&](Operation *op) { - if (isa(op)) hasAdd = true; - if 
(isa(op)) hasMM = true; - if (isa(op)) hasMM = true; - }); - if (hasMM) return KernelKind::Matmul; - if (hasAdd) return KernelKind::VecAdd; - return KernelKind::Unknown; -} - -[[maybe_unused]] static void inferTileMNK(func::FuncOp f, int &M, int &N, int &K) { - M = 32; N = 32; K = 32; - SmallVector subs; - f.walk([&](memref::SubViewOp sv) { subs.push_back(sv); }); - - auto readShape2D = [&](memref::SubViewOp sv, int &d0, int &d1) { - auto resTy = mlir::cast(sv.getResult().getType()); - if (resTy.getRank() == 2 && resTy.hasStaticShape()) { - d0 = (int)resTy.getDimSize(0); - d1 = (int)resTy.getDimSize(1); - } - }; - - if (subs.empty()) return; - - int a0=32, a1=32; - readShape2D(subs[0], a0, a1); - M = a0; N = a1; - - if (subs.size() >= 2) { - int b0=32, b1=32; - readShape2D(subs[0], a0, a1); - readShape2D(subs[1], b0, b1); - M = a0; K = a1; N = b1; - } -} - -static std::optional getKernelKindMacro(func::FuncOp funcOp) { - auto kernelKindAttr = - funcOp->getAttrOfType(FunctionKernelKindAttr::name); - if (!kernelKindAttr) - return std::nullopt; - - switch (kernelKindAttr.getKernelKind()) { - case FunctionKernelKind::Cube: - return StringRef("__DAV_CUBE__"); - case FunctionKernelKind::Vector: - return StringRef("__DAV_VEC__"); - } - - llvm_unreachable("unexpected kernel kind"); -} - -struct FuncToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(func::FuncOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // Convert the function signature with the type converter. - Type convertedTy = getTypeConverter()->convertType(op.getFunctionType()); - auto funcType = dyn_cast_or_null(convertedTy); - if (!funcType) - return rewriter.notifyMatchFailure(op, "failed to convert function type"); - if (funcType.getNumResults() > 1) - return rewriter.notifyMatchFailure( - op, "EmitC cannot return multiple values"); - - // Create the EmitC function with the converted signature. 
- auto emitcFunc = - rewriter.create(op.getLoc(), op.getName(), funcType); - - for (const auto &namedAttr : op->getAttrs()) { - StringRef name = namedAttr.getName().strref(); - if (name == op.getFunctionTypeAttrName() || - name == SymbolTable::getSymbolAttrName() || - name == pto::kPTOEntryAttrName || - name == pto::kLegacyHACCEntryAttrName || - name == "pto.internal.entry") - continue; - emitcFunc->setAttr(namedAttr.getName(), namedAttr.getValue()); - } - - if (op.isDeclaration()) { - emitcFunc.setSpecifiersAttr(rewriter.getStrArrayAttr({"extern"})); - rewriter.eraseOp(op); - return success(); - } - - if (pto::isPTOEntryFunction(op)) { - emitcFunc.setSpecifiersAttr( - rewriter.getStrArrayAttr({"__global__ AICORE"})); - } else if (op.isPrivate()) { - emitcFunc.setSpecifiersAttr( - rewriter.getStrArrayAttr({"static", "AICORE"})); - } else { - emitcFunc.setSpecifiersAttr(rewriter.getStrArrayAttr({"AICORE"})); - } - - std::optional kernelKindMacro = getKernelKindMacro(op); - bool needsNoSplitGuard = needsA5NoSplitVectorGuard(op.getOperation()); - - // Inline the original body, then convert region/block argument types to - // match the converted signature (also covers CFG blocks introduced by - // pre-lowering, e.g. scf.while -> cf.br/cf.cond_br). - rewriter.inlineRegionBefore(op.getBody(), emitcFunc.getBody(), - emitcFunc.end()); - - TypeConverter::SignatureConversion entryConv(op.getNumArguments()); - for (unsigned i = 0; i < op.getNumArguments(); ++i) - entryConv.addInputs(i, funcType.getInput(i)); - - if (failed(rewriter.convertRegionTypes(&emitcFunc.getBody(), - *getTypeConverter(), &entryConv))) - return failure(); - - // Preserve the existing function prologue shape. `kernel_kind` functions are - // emitted with the same macro guard/reset sequence that used to come from - // early pto.section wrapping, but only after SCF pre-lowering has finished. 
- { - Block &entryBlock = emitcFunc.getBody().front(); - rewriter.setInsertionPointToStart(&entryBlock); - rewriter.create(op.getLoc(), "using T = float;"); - if (kernelKindMacro) { - std::string startMacro = "\n#if defined(" + kernelKindMacro->str() + ")"; - rewriter.create(op.getLoc(), startMacro); - if (*kernelKindMacro == "__DAV_VEC__") { - rewriter.create(op.getLoc(), "set_mask_norm();"); - rewriter.create(op.getLoc(), - "set_vector_mask(-1, -1);"); - if (needsNoSplitGuard) - rewriter.create( - op.getLoc(), "if (get_subblockid() == 0) {"); - } - } - } - - if (kernelKindMacro) { - Block &lastBlock = emitcFunc.getBody().back(); - rewriter.setInsertionPoint(lastBlock.getTerminator()); - if (*kernelKindMacro == "__DAV_VEC__" && needsNoSplitGuard) - rewriter.create(op.getLoc(), "}"); - std::string endMacro = "#endif // " + kernelKindMacro->str() + "\n"; - rewriter.create(op.getLoc(), endMacro); - } - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// SubView lowering to GlobalTensor (keep your existing code) -//===----------------------------------------------------------------------=== - -enum class Role { A, B, C, Unknown }; - -template -static std::optional inferMatmulLikeSubviewRole(MatmulLikeOp op, - Value buffer) { - if (op.getLhs() == buffer) - return Role::A; - if (op.getRhs() == buffer) - return Role::B; - return std::nullopt; -} - -static std::optional inferSubviewRoleFromLoadUser(mlir::pto::TLoadOp load) { - Value buffer = load.getDst(); - if (!buffer) - return std::nullopt; - for (Operation *user : buffer.getUsers()) { - if (auto matmul = dyn_cast(user)) { - if (auto role = inferMatmulLikeSubviewRole(matmul, buffer)) - return role; - continue; - } - if (auto matmulAcc = dyn_cast(user)) { - if (auto role = inferMatmulLikeSubviewRole(matmulAcc, buffer)) - return role; - } - } - return std::nullopt; -} - -static std::optional inferSubviewRoleFromUser(Operation *user, 
Value result) { - if (auto load = dyn_cast(user)) - return inferSubviewRoleFromLoadUser(load); - if (auto store = dyn_cast(user)) { - if (store.getDst() == result) - return Role::C; - } - return std::nullopt; -} - -[[maybe_unused]] static Role inferSubviewRole(memref::SubViewOp sv) { - Value result = sv.getResult(); - for (Operation *user : result.getUsers()) { - if (auto role = inferSubviewRoleFromUser(user, result)) - return *role; - } - return Role::Unknown; -} - -// ============================================================================= -// 4. MemRef SubView -> Explicit Shape/Stride Construction (Full Implementation) -// ============================================================================= -struct SubviewToEmitCPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - // 辅助函数:尝试从 OpFoldResult 中提取静态整数值 - std::optional extractStaticInt(OpFoldResult ofr) const { - if (auto attr = ofr.dyn_cast()) { - if (auto intAttr = dyn_cast(attr)) - return intAttr.getInt(); - } else { - Value v = ofr.get(); - if (auto cOp = v.getDefiningOp()) { - if (auto iAttr = dyn_cast(cOp.getValue())) - return iAttr.getInt(); - } else if (auto idxOp = v.getDefiningOp()) { - return idxOp.value(); - } - } - return std::nullopt; - } - - LogicalResult matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - // 获取源 MemRef 类型信息 - auto srcType = mlir::cast(op.getSource().getType()); - int64_t rank = srcType.getRank(); - - auto elemTypeToString = [&](Type elemTy) -> std::string { - if (elemTy.isF16()) - return "half"; - if (elemTy.isBF16()) - return "bfloat16_t"; - if (elemTy.isF32()) - return "float"; - if (elemTy.isF64()) - return "double"; - if (elemTy.isInteger(8)) { - if (elemTy.isSignlessInteger(8) || elemTy.isSignedInteger(8)) - return "int8_t"; - return "uint8_t"; - } - if (elemTy.isInteger(16)) { - if 
(elemTy.isSignlessInteger(16) || elemTy.isSignedInteger(16)) - return "int16_t"; - return "uint16_t"; - } - if (elemTy.isInteger(32)) { - if (elemTy.isSignlessInteger(32) || elemTy.isSignedInteger(32)) - return "int32_t"; - return "uint32_t"; - } - if (elemTy.isInteger(64)) { - return cast(elemTy).isUnsigned() ? "uint64_t" : "int64_t"; - } - return "float"; - }; - - // ------------------------------------------------------------------------- - // Part 1: 指针偏移计算 (Runtime Pointer Arithmetic) - // ------------------------------------------------------------------------- - - // 准备类型: unsigned - Type u32Ty = emitc::OpaqueType::get(ctx, "unsigned"); - - // Helper: 创建 unsigned 常量 - auto mkU32 = [&](int64_t v) -> Value { - return rewriter.create( - loc, u32Ty, emitc::OpaqueAttr::get(ctx, std::to_string(v))); - }; - - // Helper: 将 OpFoldResult 转为 EmitC Value (用于计算) - auto ofrToEmitCValue = [&](OpFoldResult ofr) -> Value { - if (auto v = ofr.dyn_cast()) { - Value rv = rewriter.getRemappedValue(v); - // 如果类型不匹配,插入 Cast - if (rv.getType() != u32Ty) - return rewriter.create(loc, u32Ty, rv).getResult(); - return rv; - } - if (auto attr = ofr.dyn_cast()) { - if (auto ia = dyn_cast(attr)) - return mkU32(ia.getValue().getSExtValue()); - } - return mkU32(0); - }; - - // 1. 
获取 Source 的 Strides (支持动态 Stride 收集) - SmallVector sourceStrides; - - if (auto rc = op.getSource().getDefiningOp()) { - sourceStrides = rc.getMixedStrides(); - } else { - SmallVector strideInts; - int64_t offset = ShapedType::kDynamic; - bool useTypeStrides = succeeded(getStridesAndOffset(srcType, strideInts, offset)); - (void)offset; - if (useTypeStrides) { - for (int64_t s : strideInts) { - if (s == ShapedType::kDynamic) - useTypeStrides = false; - } - } - if (useTypeStrides) { - for (int64_t s : strideInts) { - sourceStrides.push_back(rewriter.getIndexAttr(s)); - } - } else { - // Fallback: Compact Layout - auto shape = srcType.getShape(); - int64_t current = 1; - sourceStrides.resize(rank); - for (int i = rank - 1; i >= 0; --i) { - sourceStrides[i] = rewriter.getIndexAttr(current); - if (shape[i] != ShapedType::kDynamic) current *= shape[i]; - } - } - } - - // 2. 计算运行时 Offset - auto staticOffsets = op.getStaticOffsets(); - auto dynamicOffsets = adaptor.getOffsets(); - int dynOffIdx = 0; - Value totalOffset = mkU32(0); - - for (int i = 0; i < rank; ++i) { - // A. 获取 Offset - Value offVal; - if (staticOffsets[i] == ShapedType::kDynamic) { - Value rawDyn = dynamicOffsets[dynOffIdx++]; - offVal = rewriter.create(loc, u32Ty, rawDyn); - } else { - offVal = mkU32(staticOffsets[i]); - } - - // B. 获取 Stride (用于指针计算) - Value strideVal = mkU32(1); - if (i < (int)sourceStrides.size()) { - strideVal = ofrToEmitCValue(sourceStrides[i]); - } - - // C. 累加 - Value term = rewriter.create(loc, u32Ty, offVal, strideVal); - totalOffset = rewriter.create(loc, u32Ty, totalOffset, term); - } - - // 3. 生成新指针 - // - // NOTE: Some toolchains may materialize kernel pointer params as `void*` even - // when the underlying element type is i16. Pointer arithmetic on `void*` - // is ill-formed in C++, so we explicitly cast to a typed pointer for i16. 
- Value sourcePtr = adaptor.getSource(); - Value tileCandidate = sourcePtr; - if (auto castOp = sourcePtr.getDefiningOp()) { - tileCandidate = castOp.getOperand(); - } else if (auto uc = - sourcePtr.getDefiningOp()) { - tileCandidate = uc.getOperand(0); - } - if (auto ot = dyn_cast(tileCandidate.getType())) { - auto tyStr = ot.getValue(); - if (tyStr.find("Tile<") != std::string::npos || - tyStr.find("ConvTile<") != std::string::npos) { - std::string elemTok = elemTypeToString(srcType.getElementType()); - pto::AddressSpace as = pto::AddressSpace::GM; - if (auto asAttr = - dyn_cast_or_null(srcType.getMemorySpace())) - as = asAttr.getAddressSpace(); - sourcePtr = - materializeTileDataValue(rewriter, loc, tileCandidate, as, elemTok); - if (tileDataReturnsIntegralAddress(as)) - sourcePtr = - materializeAddressAsPointer(rewriter, loc, sourcePtr, as, elemTok); - } - } - Value newPtr; - { - auto resTy = mlir::cast(op.getResult().getType()); - Type elemTy = resTy.getElementType(); - if (elemTy.isInteger(16)) { - std::string castElemTypeStr = "int16_t"; - if (cast(elemTy).isUnsigned()) - castElemTypeStr = "uint16_t"; - - std::string qualifier = "__gm__"; - if (Attribute ms = srcType.getMemorySpace()) { - if (auto ptoAttr = dyn_cast(ms)) { - qualifier = addrSpaceQualifier(ptoAttr.getAddressSpace()); - } - } - - auto typedPtrTy = emitc::PointerType::get( - emitc::OpaqueType::get(ctx, qualifier + " " + castElemTypeStr)); - Value typedSourcePtr = rewriter.create(loc, typedPtrTy, sourcePtr); - newPtr = rewriter.create(loc, typedPtrTy, typedSourcePtr, totalOffset); - } else { - newPtr = rewriter.create(loc, sourcePtr.getType(), sourcePtr, totalOffset); - } - } - - - // ------------------------------------------------------------------------- - // Part 2: For non-GM memrefs, keep pointer (no GlobalTensor). 
- // ------------------------------------------------------------------------- - bool isGlobal = true; - if (auto asAttr = dyn_cast_or_null(srcType.getMemorySpace())) { - auto as = asAttr.getAddressSpace(); - isGlobal = (as == pto::AddressSpace::GM || as == pto::AddressSpace::Zero); - } - if (!isGlobal) { - Type dstTy = getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - if (newPtr.getType() != dstTy) - newPtr = rewriter.create(loc, dstTy, newPtr); - rewriter.replaceOp(op, newPtr); - return success(); - } - - // ------------------------------------------------------------------------- - // Part 3: 生成 GlobalTensor 类型 (Shape/Stride Template Generation) - // ------------------------------------------------------------------------- - - // When emitting C++ with `declareVariablesAtTop`, value declarations are - // hoisted before body statements. Avoid introducing local `using` aliases - // for templated types (Shape/Stride/GlobalTensor) because those aliases - // would appear after the hoisted declarations and break compilation - // (`unknown type name`). - // - // Instead, use the fully spelled template types as EmitC opaque types. - - auto resTy = mlir::cast(op.getResult().getType()); - - // 1. 解析具体元素类型 - std::string elemTypeStr = getElemTypeStringForGT(resTy.getElementType()); - - // 2. 生成 Shape 模板参数,之后会右对齐有效维度并补齐到 5 维(高维填 1) - SmallVector shapeParamsVec; - SmallVector sizeValues; // 每个维度对应的运行时 size(统一为 unsigned) - auto resShape = resTy.getShape(); - auto mixedSizes = op.getMixedSizes(); - sizeValues.reserve(rank); - for (int i = 0; i < resTy.getRank(); ++i) { - if (resShape[i] == ShapedType::kDynamic) { - shapeParamsVec.push_back(-1); - } else { - shapeParamsVec.push_back(resShape[i]); - } - // size 值:优先从 op.getMixedSizes() 取(可动态/静态),否则退化为类型里的静态 shape。 - if (i < (int)mixedSizes.size()) - sizeValues.push_back(ofrToEmitCValue(mixedSizes[i])); - else - sizeValues.push_back( - mkU32(resShape[i] == ShapedType::kDynamic ? 
1 : resShape[i])); - } - - // 3. 生成 Stride 模板参数 + 运行时 stride 值(考虑 subview step) - SmallVector strideTemplateVec; - SmallVector strideValues; // 每个维度对应的运行时 stride(统一为 unsigned) - strideTemplateVec.reserve(rank); - strideValues.reserve(rank); - auto subViewSteps = op.getMixedStrides(); - for (int i = 0; i < rank; ++i) { - OpFoldResult srcStrideOfr = - (i < (int)sourceStrides.size()) ? sourceStrides[i] - : rewriter.getIndexAttr(1); - OpFoldResult stepOfr = (i < (int)subViewSteps.size()) - ? subViewSteps[i] - : rewriter.getIndexAttr(1); - - auto srcStatic = extractStaticInt(srcStrideOfr); - auto stepStatic = extractStaticInt(stepOfr); - if (srcStatic && stepStatic) { - int64_t finalStride = (*srcStatic) * (*stepStatic); - strideTemplateVec.push_back(finalStride); - strideValues.push_back(mkU32(finalStride)); - continue; - } - - strideTemplateVec.push_back(-1); - Value srcV = ofrToEmitCValue(srcStrideOfr); - Value stepV = ofrToEmitCValue(stepOfr); - // 尽量避免乘以 1 生成冗余指令 - if (stepStatic && *stepStatic == 1) - strideValues.push_back(srcV); - else if (srcStatic && *srcStatic == 1) - strideValues.push_back(stepV); - else - strideValues.push_back( - rewriter.create(loc, u32Ty, srcV, stepV)); - } - - // 3.1 右对齐到 5 维:shape 补 1;已有维度继承原 stride; - // 被补出来的高维按“紧密升维”规则连续推导:stride[i] = shape[i+1] * stride[i+1] - SmallVector finalShape; - SmallVector finalStride; - buildGlobalTensorShapeAndStride(shapeParamsVec, strideTemplateVec, - finalShape, finalStride); - Value oneU32 = mkU32(1); - SmallVector finalShapeValues(5, oneU32); - SmallVector finalStrideValues(5, oneU32); - int shift = 5 - rank; - - // 先放入原始 shape/stride(保持用户提供的值) - for (int i = 0; i < rank && i < 5; ++i) { - finalShapeValues[shift + i] = sizeValues[i]; - finalStrideValues[shift + i] = strideValues[i]; - } - - // 从低维到高维倒推补齐 stride(仅对补出来的前置维度生效) - for (int i = 3; i >= 0; --i) { - // 如果该维已由原始 rank 覆盖,则保持原值 - if (i >= shift) - continue; - if (finalStride[i] != -1) { - finalStrideValues[i] = mkU32(finalStride[i]); - 
continue; - } - // 动态推导:stride[i] = shape[i+1] * stride[i+1] - if (finalShape[i + 1] == 1) { - finalStrideValues[i] = finalStrideValues[i + 1]; - } else { - finalStrideValues[i] = rewriter.create( - loc, u32Ty, finalShapeValues[i + 1], finalStrideValues[i + 1]); - } - } - - std::string shapeParams = joinIntTemplateParams(finalShape); - std::string strideParams = joinIntTemplateParams(finalStride); - - // Spelled-out C++ types. - std::string shapeCppType = "pto::Shape<" + shapeParams + ">"; - std::string strideCppType = "pto::Stride<" + strideParams + ">"; - - // 3.0 Layout: prefer the attribute from InferPTOLayout; only fall back to - // local inference when the pass is disabled. - std::string layoutEnum = "pto::Layout::ND"; - if (auto layout = resolveLayoutForGlobalTensor(op, op.getSource())) { - layoutEnum = layoutToEmitCString(*layout); - } else { - bool allStatic = - llvm::all_of(finalShape, [](int64_t value) { return value != -1; }) && - llvm::all_of(finalStride, [](int64_t value) { return value != -1; }); - - int layoutTag = 0; // ND - auto elemBytes = 4; // default float - if (elemTypeStr.find("half") != std::string::npos || - elemTypeStr.find("f16") != std::string::npos || - elemTypeStr.find("bf16") != std::string::npos) - elemBytes = 2; - else if (elemTypeStr.find("double") != std::string::npos || - elemTypeStr.find("f64") != std::string::npos) - elemBytes = 8; - - if (allStatic) { - if (finalShape[2] == 16 && - finalShape[2] * finalShape[3] * elemBytes == 512 && - finalStride[4] == 1 && finalStride[3] == finalShape[4]) { - layoutTag = 2; // NZ - } else { - bool isRow = finalStride[4] == 1; - for (int i = 3; i >= 0; --i) - isRow &= (finalStride[i] == - multiplyOrDynamic(finalStride[i + 1], finalShape[i + 1])); - bool isCol = finalStride[0] == 1; - for (int i = 0; i < 4; ++i) - isCol &= (finalStride[i + 1] == - multiplyOrDynamic(finalStride[i], finalShape[i])); - if (isCol) - layoutTag = 1; // DN - else - layoutTag = isRow ? 
0 : 0; // fallback ND - } - } - - if (layoutTag == 1) - layoutEnum = "pto::Layout::DN"; - else if (layoutTag == 2) - layoutEnum = "pto::Layout::NZ"; - } - // GlobalTensor takes a Layout non-type template parameter; directly use the - // enum constant. - - - // ------------------------------------------------------------------------- - // Part 3: 显式对象实例化 (Explicit Object Instantiation) - // ------------------------------------------------------------------------- - - // A. Instantiate Shape object. - auto shapeTypeOpaque = emitc::OpaqueType::get(ctx, shapeCppType); - SmallVector shapeArgs; - // 从 adaptor.getSizes() 获取 subview 的所有 dynamic sizes - for (Value dynSize : adaptor.getSizes()) { - shapeArgs.push_back(dynSize); - } - - auto shapeInstOp = rewriter.create( - loc, - shapeTypeOpaque, // 返回类型 - shapeCppType, // 调用的“函数名”即类名构造函数 - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange(shapeArgs) - ); - - // B. Instantiate Stride object. - auto strideTypeOpaque = emitc::OpaqueType::get(ctx, strideCppType); - // 仅传入动态 stride 维度对应的值,匹配 pto::Stride 的 N-parameter ctor(并满足其 static_assert)。 - SmallVector strideCtorArgs; - strideCtorArgs.reserve(5); - for (int i = 0; i < 5; ++i) { - if (finalStride[i] == -1) - strideCtorArgs.push_back(finalStrideValues[i]); - } - auto strideInstOp = rewriter.create( - loc, strideTypeOpaque, strideCppType, - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange(strideCtorArgs)); - - // C. Instantiate GlobalTensor object (ptr + shape + stride). 
- std::string gtCppType = "GlobalTensor<" + elemTypeStr + ", " + shapeCppType + - ", " + strideCppType + ", " + layoutEnum + ">"; - auto gtType = emitc::OpaqueType::get(ctx, gtCppType); - - // 准备构造参数: [ptr, shape_instance, stride_instance] - SmallVector gtConstructorArgs; - gtConstructorArgs.push_back(newPtr); - gtConstructorArgs.push_back(shapeInstOp.getResult(0)); // 拿到 shape_inst 的 SSA Value - gtConstructorArgs.push_back(strideInstOp.getResult(0)); // 拿到 stride_inst 的 SSA Value - - rewriter.replaceOpWithNewOp( - op, - gtType, - gtCppType, - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange(gtConstructorArgs) - ); - - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// Helper: build GlobalTensor from a static MemRef (for TLOAD/TSTORE) -//===----------------------------------------------------------------------===// - -static std::string getElemTypeStringForGT(Type elemTy) { - return getEmitCScalarTypeToken(elemTy); -} - -static bool hasStaticShape(MemRefType mrTy) { - return llvm::none_of(mrTy.getShape(), [](int64_t dim) { - return dim == ShapedType::kDynamic; - }); -} - -static bool getStaticMemrefLayout(MemRefType mrTy, SmallVectorImpl &strides, - int64_t &offset) { - if (failed(getStridesAndOffset(mrTy, strides, offset))) { - strides.clear(); - int64_t stride = 1; - ArrayRef shape = mrTy.getShape(); - for (int i = static_cast(shape.size()) - 1; i >= 0; --i) { - strides.push_back(stride); - stride *= shape[i]; - } - std::reverse(strides.begin(), strides.end()); - offset = 0; - } - return offset != ShapedType::kDynamic && - llvm::none_of(strides, [](int64_t strideValue) { - return strideValue == ShapedType::kDynamic; - }); -} - -static Value applyStaticMemrefOffset(ConversionPatternRewriter &rewriter, - Location loc, Value basePtr, - int64_t offset) { - if (offset == 0) - return basePtr; - auto *ctx = rewriter.getContext(); - Type u32Ty = emitc::OpaqueType::get(ctx, 
"unsigned"); - auto offVal = rewriter.create( - loc, u32Ty, emitc::OpaqueAttr::get(ctx, std::to_string(offset))); - return rewriter.create(loc, basePtr.getType(), basePtr, offVal); -} - -static int getGlobalTensorElementBytes(Type elemTy) { - return static_cast(getPTOStorageElemByteSize(elemTy)); -} - -static int64_t multiplyOrDynamic(int64_t lhs, int64_t rhs) { - if (lhs < 0 || rhs < 0) - return -1; - return lhs * rhs; -} - -static void buildGlobalTensorShapeAndStride(ArrayRef shape, - ArrayRef strides, - SmallVectorImpl &shape5D, - SmallVectorImpl &stride5D) { - shape5D.assign(5, 1); - stride5D.assign(5, 1); - int rank = static_cast(shape.size()); - int shift = 5 - rank; - for (int i = 0; i < rank && i < 5; ++i) { - shape5D[shift + i] = shape[i]; - stride5D[shift + i] = strides[i]; - } - for (int i = 3; i >= 0; --i) { - if (i >= shift) - continue; - stride5D[i] = multiplyOrDynamic(shape5D[i + 1], stride5D[i + 1]); - } -} - -static std::string joinIntTemplateParams(ArrayRef values) { - std::string result; - for (size_t i = 0; i < values.size(); ++i) { - if (i != 0) - result += ", "; - result += std::to_string(values[i]); - } - return result; -} - -static SmallVector buildRowMajorStrides(ArrayRef shape) { - SmallVector strides(shape.size(), 1); - int64_t running = 1; - for (int i = static_cast(shape.size()) - 1; i >= 0; --i) { - strides[i] = running; - running = multiplyOrDynamic(running, shape[i]); - } - return strides; -} - -static std::string getGlobalTensorTypeStringFromShape(Type elemTy, - ArrayRef shape, - StringRef layoutEnum) { - SmallVector strides = buildRowMajorStrides(shape); - return getGlobalTensorTypeStringFromShapeAndStrides(elemTy, shape, strides, - layoutEnum); -} - -static std::string getGlobalTensorTypeStringFromShapeAndStrides( - Type elemTy, ArrayRef shape, ArrayRef strides, - StringRef layoutEnum) { - SmallVector shape5D; - SmallVector stride5D; - buildGlobalTensorShapeAndStride(shape, strides, shape5D, stride5D); - - std::string elemTypeStr 
= getElemTypeStringForGT(elemTy); - std::string shapeType = "pto::Shape<" + joinIntTemplateParams(shape5D) + ">"; - std::string strideType = - "pto::Stride<" + joinIntTemplateParams(stride5D) + ">"; - return "GlobalTensor<" + elemTypeStr + ", " + shapeType + ", " + - strideType + ", " + layoutEnum.str() + ">"; -} - -static emitc::OpaqueType getGlobalTensorOpaqueTypeFromShape( - MLIRContext *ctx, Type elemTy, ArrayRef shape, - StringRef layoutEnum) { - return emitc::OpaqueType::get( - ctx, getGlobalTensorTypeStringFromShape(elemTy, shape, layoutEnum)); -} - -static std::string inferFallbackGlobalTensorLayout(ArrayRef shape5D, - ArrayRef stride5D, - Type elemTy) { - int elemBytes = getGlobalTensorElementBytes(elemTy); - if (elemBytes == 0) - return "pto::Layout::ND"; - if (shape5D[2] == 16 && multiplyOrDynamic(shape5D[2], shape5D[3]) * elemBytes == 512 && - stride5D[4] == 1 && stride5D[3] == shape5D[4]) { - return "pto::Layout::NZ"; - } - - bool isRowMajor = stride5D[4] == 1; - for (int i = 3; i >= 0 && isRowMajor; --i) - isRowMajor = stride5D[i] == multiplyOrDynamic(stride5D[i + 1], shape5D[i + 1]); - - bool isColMajor = stride5D[0] == 1; - for (int i = 0; i < 4 && isColMajor; ++i) - isColMajor = stride5D[i + 1] == multiplyOrDynamic(stride5D[i], shape5D[i]); - - if (isColMajor) - return "pto::Layout::DN"; - return isRowMajor ? 
"pto::Layout::ND" : "pto::Layout::ND"; -} - -static std::string resolveGlobalTensorLayout(Operation *anchor, Value basePtr, - ArrayRef shape5D, - ArrayRef stride5D, - Type elemTy) { - if (auto layout = resolveLayoutForGlobalTensor(anchor, basePtr)) - return layoutToEmitCString(*layout); - return inferFallbackGlobalTensorLayout(shape5D, stride5D, elemTy); -} - -struct GlobalTensorTypeNames { - std::string shapeTypeName; - std::string strideTypeName; - std::string tensorTypeName; - std::string layoutConstName; -}; - -static GlobalTensorTypeNames getGlobalTensorTypeNames(Operation *anchor) { - std::string suffix = "_" + std::to_string(reinterpret_cast(anchor)); - return { - "GTShape" + suffix, - "GTStride" + suffix, - "GT" + suffix, - "GT" + suffix + "_layout", - }; -} -static Value buildGlobalTensorFromMemref(ConversionPatternRewriter &rewriter, - Location loc, Value basePtr, - MemRefType mrTy, - Operation *anchor) { - auto *ctx = rewriter.getContext(); - - ArrayRef shape = mrTy.getShape(); - if (!hasStaticShape(mrTy)) - return Value(); - - SmallVector strides; - int64_t offset = 0; - if (!getStaticMemrefLayout(mrTy, strides, offset)) - return Value(); - - Value ptr = applyStaticMemrefOffset(rewriter, loc, basePtr, offset); - GlobalTensorTypeNames names = getGlobalTensorTypeNames(anchor); - std::string elemTypeStr = getElemTypeStringForGT(mrTy.getElementType()); - SmallVector shape5D; - SmallVector stride5D; - buildGlobalTensorShapeAndStride(shape, strides, shape5D, stride5D); - - rewriter.create( - loc, "using " + names.shapeTypeName + " = pto::Shape<" + - joinIntTemplateParams(shape5D) + ">;"); - rewriter.create( - loc, "using " + names.strideTypeName + " = pto::Stride<" + - joinIntTemplateParams(stride5D) + ">;"); - - std::string layoutEnum = resolveGlobalTensorLayout( - anchor, basePtr, shape5D, stride5D, mrTy.getElementType()); - rewriter.create(loc, "constexpr pto::Layout " + - names.layoutConstName + " = " + - layoutEnum + ";"); - - auto shapeTypeOpaque = 
emitc::OpaqueType::get(ctx, names.shapeTypeName); - auto strideTypeOpaque = emitc::OpaqueType::get(ctx, names.strideTypeName); - auto shapeInstOp = rewriter.create( - loc, shapeTypeOpaque, names.shapeTypeName, ArrayAttr{}, ArrayAttr{}, - ValueRange{}); - auto strideInstOp = rewriter.create( - loc, strideTypeOpaque, names.strideTypeName, ArrayAttr{}, ArrayAttr{}, - ValueRange{}); - - rewriter.create( - loc, "using " + names.tensorTypeName + " = GlobalTensor<" + elemTypeStr + - ", " + names.shapeTypeName + ", " + names.strideTypeName + - ", " + names.layoutConstName + ">;"); - auto gtType = emitc::OpaqueType::get(ctx, names.tensorTypeName); - - SmallVector gtArgs; - gtArgs.push_back(ptr); - gtArgs.push_back(shapeInstOp.getResult(0)); - gtArgs.push_back(strideInstOp.getResult(0)); - - auto gtInst = rewriter.create( - loc, gtType, names.tensorTypeName, ArrayAttr{}, ArrayAttr{}, - ValueRange(gtArgs)); - - return gtInst.getResult(0); -} - -static Value maybeWrapGlobalMemrefAsGlobalTensor( - ConversionPatternRewriter &rewriter, Location loc, Value loweredValue, - Type originalType, Operation *anchor) { - auto mrTy = dyn_cast(originalType); - if (!mrTy) - return loweredValue; - - bool isGlobal = true; - if (auto asAttr = - dyn_cast_or_null(mrTy.getMemorySpace())) { - auto as = asAttr.getAddressSpace(); - isGlobal = (as == pto::AddressSpace::GM || as == pto::AddressSpace::Zero); - } - if (!isGlobal) - return loweredValue; - - if (Value gt = - buildGlobalTensorFromMemref(rewriter, loc, loweredValue, mrTy, anchor)) - return gt; - return loweredValue; -} - -static Value castToGMBytePointer(ConversionPatternRewriter &rewriter, - Location loc, Value value) { - auto *ctx = rewriter.getContext(); - auto targetTy = - emitc::PointerType::get(emitc::OpaqueType::get(ctx, "__gm__ uint8_t")); - if (value.getType() == targetTy) - return value; - - auto castTyAttr = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "__gm__ uint8_t*")}); - if (isSetFFTsPointerLikeType(value.getType())) 
{ - return rewriter - .create(loc, targetTy, "reinterpret_cast", - ArrayAttr{}, castTyAttr, - ValueRange{value}) - .getResult(0); - } - return rewriter.create(loc, targetTy, value).getResult(); -} - -static Value materializeTensorViewDataPointer( - ConversionPatternRewriter &rewriter, Location loc, Value value, - Type sourceType) { - auto tvTy = dyn_cast(sourceType); - if (!tvTy) - return value; - - auto *ctx = rewriter.getContext(); - std::string elemTypeStr = getElemTypeStringForGT(tvTy.getElementType()); - auto ptrTy = emitc::PointerType::get( - emitc::OpaqueType::get(ctx, "__gm__ " + elemTypeStr)); - return rewriter - .create(loc, ptrTy, "PTOAS__GLOBAL_TENSOR_DATA", - ArrayAttr{}, ArrayAttr{}, ValueRange{value}) - .getResult(0); -} - -static std::string tileBufBLayoutToken(pto::TileBufConfigAttr configAttr) { - std::string blTok = "BLayout::RowMajor"; - if (auto blAttr = dyn_cast(configAttr.getBLayout())) { - if (static_cast(blAttr.getValue()) == 1) - blTok = "BLayout::ColMajor"; - } - return blTok; -} - -static std::string tileBufSLayoutToken(pto::TileBufConfigAttr configAttr) { - std::string slTok = "SLayout::NoneBox"; - if (auto slAttr = dyn_cast(configAttr.getSLayout())) { - int32_t slVal = static_cast(slAttr.getValue()); - slTok = (slVal == 1) ? "SLayout::RowMajor" - : (slVal == 2) ? 
"SLayout::ColMajor" - : "SLayout::NoneBox"; - } - return slTok; -} - -static std::string tileBufPadToken(pto::TileBufConfigAttr configAttr) { - std::string padTok = "PadValue::Null"; - if (auto padAttr = dyn_cast(configAttr.getPad())) { - switch (static_cast(padAttr.getValue())) { - case 1: - padTok = "PadValue::Zero"; - break; - case 2: - padTok = "PadValue::Max"; - break; - case 3: - padTok = "PadValue::Min"; - break; - default: - padTok = "PadValue::Null"; - break; - } - } - return padTok; -} - -static pto::BLayout getTileBufBLayoutValue(pto::TileBufConfigAttr configAttr) { - if (auto blAttr = dyn_cast(configAttr.getBLayout())) - return blAttr.getValue(); - return pto::BLayout::RowMajor; -} - -static int64_t renderTileTemplateDim(int64_t rawDim, Type elemTy, - pto::BLayout blayout, int dimIdx) { - assert(dimIdx >= 0 && dimIdx < 2 && - "renderTileTemplateDim expects a rank-2 rows/cols dimension index"); - if (rawDim == ShapedType::kDynamic) - return rawDim; - if (!pto::isPTOFloat4PackedType(elemTy)) - return rawDim; - int packedDim = blayout == pto::BLayout::ColMajor ? 0 : 1; - return dimIdx == packedDim ? rawDim * 2 : rawDim; -} - -static FailureOr buildAsyncScratchTileValue( - ConversionPatternRewriter &rewriter, Location loc, Value originalScratch, - Value emittedScratch) { - Value scratch = peelUnrealized(emittedScratch); - if (auto opaqueTy = dyn_cast(scratch.getType())) { - StringRef typeStr = opaqueTy.getValue(); - if (typeStr.contains("Tile<") || typeStr.contains("ConvTile<")) - return scratch; - } - - auto memTy = dyn_cast(originalScratch.getType()); - if (!memTy) - return failure(); - - ArrayRef shape = memTy.getShape(); - if (!memTy.hasStaticShape() || shape.empty() || shape.size() > 2) - return failure(); - - int64_t rows = shape.size() == 1 ? 1 : shape[0]; - int64_t cols = shape.size() == 1 ? 
shape[0] : shape[1]; - - auto *ctx = rewriter.getContext(); - pto::TileBufConfigAttr configAttr = pto::TileBufConfigAttr::getDefault(ctx); - if (auto bind = originalScratch.getDefiningOp()) { - configAttr = bind.getConfig(); - } else if (auto cast = originalScratch.getDefiningOp()) { - if (auto config = cast.getConfig()) - configAttr = *config; - } - - int32_t fractal = 512; - if (auto frAttr = dyn_cast(configAttr.getSFractalSize())) - fractal = frAttr.getInt(); - - Type elemTy = memTy.getElementType(); - pto::BLayout blayout = getTileBufBLayoutValue(configAttr); - int64_t templateRows = renderTileTemplateDim(rows, elemTy, blayout, 0); - int64_t templateCols = renderTileTemplateDim(cols, elemTy, blayout, 1); - std::string elemTypeStr = getEmitCScalarTypeToken(elemTy); - std::string tileTypeStr = - "Tile"; - - Value tile = rewriter - .create( - loc, emitc::OpaqueType::get(ctx, tileTypeStr), - emitc::OpaqueAttr::get(ctx, "")) - .getResult(); - auto addr = rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); - Value scratchAddr = - rewriter - .create(loc, emitc::OpaqueType::get(ctx, "uint64_t"), - "reinterpret_cast", ArrayAttr{}, addr, - ValueRange{scratch}) - .getResult(0); - rewriter.create(loc, TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{tile, scratchAddr}); - return tile; -} - -static FailureOr buildSyncAllWorkspaceTileValue( - ConversionPatternRewriter &rewriter, Location loc, Value originalWorkspace, - Value emittedWorkspace) { - Value workspace = peelUnrealized(emittedWorkspace); - if (auto opaqueTy = dyn_cast(workspace.getType())) { - StringRef typeStr = opaqueTy.getValue(); - if (typeStr.contains("Tile<") || typeStr.contains("ConvTile<")) - return workspace; - } - - auto memTy = dyn_cast(originalWorkspace.getType()); - if (!memTy) - return failure(); - if (!memTy.hasStaticShape()) - return failure(); - - ArrayRef rawShape = memTy.getShape(); - if (rawShape.empty() || rawShape.size() > 2) - return failure(); - - int64_t rows = 
rawShape.size() == 1 ? 1 : rawShape[0]; - int64_t cols = rawShape.size() == 1 ? rawShape[0] : rawShape[1]; - SmallVector shape{rows, cols}; - SmallVector validShape{rows, cols}; - - auto *ctx = rewriter.getContext(); - pto::TileBufConfigAttr configAttr = pto::TileBufConfigAttr::getDefault(ctx); - if (auto bind = originalWorkspace.getDefiningOp()) { - configAttr = bind.getConfig(); - } else if (auto cast = originalWorkspace.getDefiningOp()) { - if (auto config = cast.getConfig()) - configAttr = *config; - } - - Attribute memorySpace = memTy.getMemorySpace(); - if (!memorySpace) - return failure(); - - auto tileTy = pto::TileBufType::get(ctx, shape, memTy.getElementType(), - memorySpace, validShape, configAttr); - auto tileTypeString = getEmitCTileTypeString(tileTy); - if (!tileTypeString) - return failure(); - - auto tileEmitTy = emitc::OpaqueType::get(ctx, *tileTypeString); - Value tile = rewriter - .create(loc, tileEmitTy, - emitc::OpaqueAttr::get(ctx, "")) - .getResult(); - - Value rawPtr = workspace; - auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); - if (isSetFFTsPointerLikeType(rawPtr.getType())) { - auto rcU64 = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); - rawPtr = rewriter - .create(loc, u64Ty, "reinterpret_cast", - ArrayAttr{}, rcU64, - ValueRange{rawPtr}) - .getResult(0); - } else if (rawPtr.getType() != u64Ty) { - rawPtr = rewriter.create(loc, u64Ty, rawPtr).getResult(); - } - - rewriter.create(loc, TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{tile, rawPtr}); - return tile; -} - -//===----------------------------------------------------------------------===// -// pto.pointer_cast lowering -//===----------------------------------------------------------------------=== -struct PointerCastConversion : public OpConversionPattern { - static bool getIndexConst(Value v, int64_t &out) { - if (auto cst = v.getDefiningOp()) { - if (auto ia = dyn_cast(cst.getValue())) { - out = ia.getValue().getSExtValue(); - return 
true; - } - } - return false; - } - - using OpConversionPattern::OpConversionPattern; - - enum class TileRole { Vec, Mat, Left, Right, Acc, Bias, Scaling }; - - static void collectUserOpsThroughCasts(Value v, SmallVectorImpl &out) { - for (Operation *u : v.getUsers()) { - if (auto castOp = dyn_cast(u)) { - for (Value r : castOp.getResults()) - collectUserOpsThroughCasts(r, out); - continue; - } - out.push_back(u); - } - } - - static Value peelUnrealized(Value v) { - while (auto castOp = v.getDefiningOp()) { - v = castOp.getOperand(0); - } - return v; - } - - static TileRole inferRole(pto::PointerCastOp op) { - // 1. 优先检查 AddressSpace - if (auto memRefTy = dyn_cast(op.getType())) { - Attribute memorySpace = memRefTy.getMemorySpace(); - if (auto ptoAttr = dyn_cast_or_null(memorySpace)) { - switch (ptoAttr.getAddressSpace()) { - case pto::AddressSpace::LEFT: return TileRole::Left; - case pto::AddressSpace::RIGHT: return TileRole::Right; - case pto::AddressSpace::ACC: return TileRole::Acc; - case pto::AddressSpace::BIAS: return TileRole::Bias; - case pto::AddressSpace::MAT: return TileRole::Mat; - case pto::AddressSpace::SCALING: return TileRole::Scaling; - default: break; - } - } - } - - // 2. 
通过 Usage 推导 (Fallback) - SmallVector users; - collectUserOpsThroughCasts(op.getResult(), users); - - for (Operation *user : users) { - if (auto mm = dyn_cast(user)) { - if (mm.getDst() && peelUnrealized(mm.getDst()) == op.getResult()) return TileRole::Acc; - if (peelUnrealized(mm.getLhs()) == op.getResult()) return TileRole::Left; - if (peelUnrealized(mm.getRhs()) == op.getResult()) return TileRole::Right; - } - if (auto mmacc = dyn_cast(user)) { - if (mmacc.getDst() && peelUnrealized(mmacc.getDst()) == op.getResult()) return TileRole::Acc; - if (peelUnrealized(mmacc.getAccIn()) == op.getResult()) return TileRole::Acc; - if (peelUnrealized(mmacc.getLhs()) == op.getResult()) return TileRole::Left; - if (peelUnrealized(mmacc.getRhs()) == op.getResult()) return TileRole::Right; - } - } - - return TileRole::Vec; - } - - // [新增] 辅助函数:判断 Value 是否源自 arith.constant - static bool isConstant(Value v, int64_t &outVal) { - if (!v) return false; - if (auto cst = v.getDefiningOp()) { - if (auto attr = dyn_cast(cst.getValue())) { - outVal = attr.getInt(); - return true; - } - } - return false; - } - - LogicalResult matchAndRewrite(pto::PointerCastOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - auto selfType = mlir::cast(op.getType()); - ArrayRef shape = selfType.getShape(); - Type elemType = selfType.getElementType(); - - // 1. 推导 Tile Role - TileRole role = inferRole(op); - - // 2. 类型字符串生成 (elemTypeStr, dimStr) - std::string elemTypeStr = getEmitCScalarTypeToken(elemType); - - std::string dimStr; - pto::BLayout blayout = pto::BLayout::RowMajor; - auto dimToString = [&](int64_t dim, const char *symbol, - int dimIdx) -> std::string { - if (dim == ShapedType::kDynamic) - return std::string(symbol); - return std::to_string(renderTileTemplateDim(dim, elemType, blayout, - dimIdx)); - }; - - // 3. 
Role Token - const char *roleTok = "TileType::Vec"; - switch (role) { - case TileRole::Left: roleTok = "TileType::Left"; break; - case TileRole::Right: roleTok = "TileType::Right"; break; - case TileRole::Acc: roleTok = "TileType::Acc"; break; - case TileRole::Bias: roleTok = "TileType::Bias"; break; - case TileRole::Mat: roleTok = "TileType::Mat"; break; - case TileRole::Vec: roleTok = "TileType::Vec"; break; - case TileRole::Scaling: roleTok = "TileType::Scaling"; break; - } - - // 4. Config & Layout (support BLayoutAttr/SLayoutAttr/PadValueAttr after namespace change) - std::string layoutParams = "BLayout::RowMajor"; - std::string extraParams = ""; - if (auto configOpt = op.getConfig()) { - auto config = *configOpt; - int32_t blVal = 0; - if (auto attr = dyn_cast(config.getBLayout())) - blVal = static_cast(attr.getValue()); - - if (blVal == 1) layoutParams = "BLayout::ColMajor"; - blayout = blVal == 1 ? pto::BLayout::ColMajor : pto::BLayout::RowMajor; - - int32_t slVal = 0; - if (auto attr = dyn_cast(config.getSLayout())) - slVal = static_cast(attr.getValue()); - - std::string slStr = (slVal == 1) ? "SLayout::RowMajor" : (slVal == 2) ? 
"SLayout::ColMajor" : "SLayout::NoneBox"; - - int32_t frVal = 0; - if (auto attr = dyn_cast(config.getSFractalSize())) frVal = attr.getInt(); - - int32_t padVal = 0; - if (auto attr = dyn_cast(config.getPad())) - padVal = static_cast(attr.getValue()); - - std::string padStr = "PadValue::Null"; - switch (padVal) { - case 1: padStr = "PadValue::Zero"; break; - case 2: padStr = "PadValue::Max"; break; - case 3: padStr = "PadValue::Min"; break; - } - - int32_t compactVal = 0; - if (auto attr = dyn_cast(config.getCompactMode())) - compactVal = static_cast(attr.getValue()); - - std::string compactStr = "CompactMode::Null"; - switch (compactVal) { - case 1: compactStr = "CompactMode::Normal"; break; - case 2: compactStr = "CompactMode::RowPlusOne"; break; - } - - if (!slStr.empty()) { - extraParams += ", " + slStr + ", " + std::to_string(frVal) + ", " + - padStr + ", " + compactStr; - } - } else { - extraParams = ", SLayout::NoneBox, 512, PadValue::Null, CompactMode::Null"; - } - - if (role == TileRole::Left) - dimStr = dimToString(shape[0], "M", 0) + ", " + - dimToString(shape[1], "K", 1); - else if (role == TileRole::Right) - dimStr = dimToString(shape[0], "K", 0) + ", " + - dimToString(shape[1], "N", 1); - else if (role == TileRole::Bias) - dimStr = "1, " + dimToString(shape[1], "N", 1); - else - dimStr = dimToString(shape[0], "M", 0) + ", " + - dimToString(shape[1], "N", 1); - - // [核心修改] Valid Dims 处理逻辑 (支持混合静态/动态) - std::string vrowTok, vcolTok; - bool useConstructor = false; - - bool rowIsDynamic = false; - bool colIsDynamic = false; - - SmallVector constructorArgs; - - Value vRow = op.getValidRow(); - Value vCol = op.getValidCol(); - Value vRowEmitC = adaptor.getValidRow(); - Value vColEmitC = adaptor.getValidCol(); - bool forceDynamicValid = op->hasAttr(kForceDynamicValidShapeAttrName); - - int64_t cRow = 0, cCol = 0; - bool rowIsConst = vRow && isConstant(vRow, cRow); - bool colIsConst = vCol && isConstant(vCol, cCol); - - auto makeCtorDimValue = [&](Value 
emitted, int64_t fallback) -> Value { - if (emitted) - return emitted; - return makeEmitCIntConstant( - rewriter, loc, emitc::OpaqueType::get(ctx, "int32_t"), fallback); - }; - auto maybeScaleDynamicValid = [&](Value emitted, int dimIdx) -> Value { - if (!emitted || !pto::isPTOFloat4PackedType(elemType)) - return emitted; - int packedDim = blayout == pto::BLayout::ColMajor ? 0 : 1; - if (dimIdx != packedDim) - return emitted; - auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); - Value two = makeEmitCIntConstant(rewriter, loc, i32Ty, 2); - return rewriter.create(loc, i32Ty, emitted, two).getResult(); - }; - - if (forceDynamicValid) { - vrowTok = "-1"; - vcolTok = "-1"; - useConstructor = true; - constructorArgs.push_back( - makeCtorDimValue(maybeScaleDynamicValid(vRowEmitC, 0), - renderTileTemplateDim(rowIsConst ? cRow : shape[0], - elemType, blayout, 0))); - constructorArgs.push_back( - makeCtorDimValue(maybeScaleDynamicValid(vColEmitC, 1), - renderTileTemplateDim(colIsConst ? cCol : shape[1], - elemType, blayout, 1))); - } else { - if (rowIsConst) { - vrowTok = std::to_string( - renderTileTemplateDim(cRow, elemType, blayout, 0)); - } else if (vRow) { - vrowTok = "-1"; - rowIsDynamic = true; - useConstructor = true; - } else { - vrowTok = std::to_string( - renderTileTemplateDim(shape[0], elemType, blayout, 0)); - } - - if (colIsConst) { - vcolTok = std::to_string( - renderTileTemplateDim(cCol, elemType, blayout, 1)); - } else if (vCol) { - vcolTok = "-1"; - colIsDynamic = true; - useConstructor = true; - } else { - vcolTok = std::to_string( - renderTileTemplateDim(shape[1], elemType, blayout, 1)); - } - - if (useConstructor) { - if (rowIsDynamic && vRowEmitC) - constructorArgs.push_back(maybeScaleDynamicValid(vRowEmitC, 0)); - if (colIsDynamic && vColEmitC) - constructorArgs.push_back(maybeScaleDynamicValid(vColEmitC, 1)); - } - } - - // 5. 
生成 Tile 类型字符串 - std::string tileTypeStr = - std::string("Tile<") + roleTok + ", " + elemTypeStr + ", " + dimStr + ", " + - layoutParams + ", " + vrowTok + ", " + vcolTok + extraParams + ">"; - - auto tileType = emitc::OpaqueType::get(ctx, tileTypeStr); - Value resultValue; - - if (useConstructor) { - // 使用 CallOpaqueOp 生成构造函数调用 (Tile v = Tile(...)) - auto ctorOp = rewriter.create( - loc, - tileType, // Result Type - tileTypeStr, // Callee Name (类名) - ArrayAttr{}, // args - ArrayAttr{}, // template_args - ValueRange(constructorArgs) // operands - ); - resultValue = ctorOp.getResult(0); - } else { - // 静态情况 (Tile v;) - auto varOp = rewriter.create( - loc, - tileType, - emitc::OpaqueAttr::get(ctx, "") - ); - resultValue = varOp.getResult(); - } - - // TASSIGN: pto-isa expects an integral address. - Value addr = adaptor.getAddrs()[0]; - if (isa(addr.getType()) || - (isa(addr.getType()) && - cast(addr.getType()).getValue().ends_with("*"))) { - auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); - auto rcU64 = rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); - addr = rewriter.create( - loc, u64Ty, "reinterpret_cast", - /*args=*/ArrayAttr{}, /*templateArgs=*/rcU64, - /*operands=*/ValueRange{addr}) - .getResult(0); - } - - rewriter.create( - loc, TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{resultValue, addr}); - - rewriter.replaceOp(op, resultValue); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// pto.load_dps / pto.store_dps lowering (FIX: keep optional result) -//===----------------------------------------------------------------------=== - -struct PTOTLoadToTLOAD : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TLoadOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!op.getDst()) - return rewriter.notifyMatchFailure(op, "expected outs(dst) on pto.tload"); - 
- Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value srcArg = src; - if (auto srcMrTy = dyn_cast(op.getSrc().getType())) { - bool isGlobal = true; - if (auto asAttr = dyn_cast_or_null(srcMrTy.getMemorySpace())) { - auto as = asAttr.getAddressSpace(); - isGlobal = (as == pto::AddressSpace::GM || as == pto::AddressSpace::Zero); - } - if (isGlobal) { - if (Value gt = buildGlobalTensorFromMemref(rewriter, op.getLoc(), src, srcMrTy, - op.getOperation())) - srcArg = gt; - } - } - - rewriter.create( - op.getLoc(), TypeRange{}, "TLOAD", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, srcArg}); - - if (op->getNumResults() == 1) { - rewriter.replaceOp(op, dst); - } else { - rewriter.eraseOp(op); - } - return success(); - } -}; - -struct PTOTPrefetchToTPREFETCH : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPrefetchOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!op.getDst()) - return rewriter.notifyMatchFailure(op, "expected outs(dst) on pto.tprefetch"); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value srcArg = src; - if (auto srcMrTy = dyn_cast(op.getSrc().getType())) { - bool isGlobal = true; - if (auto asAttr = dyn_cast_or_null(srcMrTy.getMemorySpace())) { - auto as = asAttr.getAddressSpace(); - isGlobal = (as == pto::AddressSpace::GM || as == pto::AddressSpace::Zero); - } - if (isGlobal) { - if (Value gt = buildGlobalTensorFromMemref(rewriter, op.getLoc(), src, srcMrTy, - op.getOperation())) - srcArg = gt; - } - } - - rewriter.create( - op.getLoc(), TypeRange{}, "TPREFETCH", - ArrayAttr{}, ArrayAttr{}, ValueRange{dst, srcArg}); - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOTPrefetchAsyncToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPrefetchAsyncOp op, 
OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src = peelUnrealized(adaptor.getSrc()); - Value srcArg = src; - if (!isEmitCGlobalTensorLikeType(srcArg.getType())) { - auto srcMrTy = dyn_cast(op.getSrc().getType()); - if (!srcMrTy) - return rewriter.notifyMatchFailure( - op, "expected src to lower to GlobalTensor or memref"); - srcArg = buildGlobalTensorFromMemref(rewriter, op.getLoc(), src, srcMrTy, - op.getSrc().getDefiningOp() - ? op.getSrc().getDefiningOp() - : op.getOperation()); - } - if (!srcArg) - return rewriter.notifyMatchFailure(op, - "failed to build GlobalTensor src"); - - Value prefetchCtx = peelUnrealized(adaptor.getCtx()); - - Type eventTy = getTypeConverter()->convertType(op.getEvent().getType()); - if (!eventTy) - return rewriter.notifyMatchFailure( - op, "failed to convert tprefetch_async result type"); - - Value event = rewriter - .create( - op.getLoc(), TypeRange{eventTy}, "TPREFETCH_ASYNC", - ArrayAttr{}, ArrayAttr{}, - ValueRange{srcArg, prefetchCtx}) - .getResult(0); - - rewriter.replaceOp(op, ValueRange{event}); - return success(); - } -}; - -struct PTOMakePrefetchAsyncContextToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::MakePrefetchAsyncContextOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type ctxTy = getTypeConverter()->convertType(op.getCtx().getType()); - if (!ctxTy) - return rewriter.notifyMatchFailure( - op, "failed to convert make_prefetch_async_context result type"); - - Value workspace = peelUnrealized(adaptor.getWorkspace()); - workspace = castToGMBytePointer(rewriter, op.getLoc(), workspace); - - Value ctx = rewriter - .create( - op.getLoc(), TypeRange{ctxTy}, "pto::PrefetchAsyncContext", - ArrayAttr{}, ArrayAttr{}, ValueRange{workspace}) - .getResult(0); - - rewriter.replaceOp(op, ValueRange{ctx}); - return success(); - } -}; - -struct PTOGetPrefetchAsyncSessionToEmitC - : 
public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::GetPrefetchAsyncSessionOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type sessionTy = getTypeConverter()->convertType(op.getSession().getType()); - if (!sessionTy) - return rewriter.notifyMatchFailure( - op, "failed to convert get_prefetch_async_session result type"); - - Value ctx = peelUnrealized(adaptor.getCtx()); - Value session = rewriter - .create( - op.getLoc(), TypeRange{sessionTy}, - "PTOAS__PREFETCH_CTX_SESSION", ArrayAttr{}, - ArrayAttr{}, ValueRange{ctx}) - .getResult(0); - - rewriter.replaceOp(op, ValueRange{session}); - return success(); - } -}; - -struct PTOTStoreToTSTORE : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - static std::string stPhaseTok(pto::STPhase phase) { - switch (phase) { - case pto::STPhase::Unspecified: return "STPhase::Unspecified"; - case pto::STPhase::Partial: return "STPhase::Partial"; - case pto::STPhase::Final: return "STPhase::Final"; - } - return "STPhase::Unspecified"; - } - - static std::string atomicTypeTok(pto::AtomicType atomicType) { - switch (atomicType) { - case pto::AtomicType::AtomicNone: return "AtomicType::AtomicNone"; - case pto::AtomicType::AtomicAdd: return "AtomicType::AtomicAdd"; - } - return "AtomicType::AtomicNone"; - } - - static std::string reluPreModeTok(pto::ReluPreMode reluPreMode) { - switch (reluPreMode) { - case pto::ReluPreMode::NoRelu: return "ReluPreMode::NoRelu"; - case pto::ReluPreMode::NormalRelu: return "ReluPreMode::NormalRelu"; - } - return "ReluPreMode::NoRelu"; - } - - LogicalResult matchAndRewrite(pto::TStoreOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!op.getDst()) - return rewriter.notifyMatchFailure(op, "expected outs(dst) on pto.tstore"); - - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - Value src = 
peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value preQuantScalar; - if (op.getPreQuantScalar()) - preQuantScalar = peelUnrealized(adaptor.getPreQuantScalar()); - Value dstArg = dst; - if (auto dstMrTy = dyn_cast(op.getDst().getType())) { - bool isGlobal = true; - if (auto asAttr = dyn_cast_or_null(dstMrTy.getMemorySpace())) { - auto as = asAttr.getAddressSpace(); - isGlobal = (as == pto::AddressSpace::GM || as == pto::AddressSpace::Zero); - } - if (isGlobal) { - if (Value gt = buildGlobalTensorFromMemref(rewriter, op.getLoc(), dst, dstMrTy, - op.getOperation())) - dstArg = gt; - } - } - - const auto phase = op.getStPhase(); - const auto atomicType = op.getAtomicType(); - const auto reluPreMode = op.getReluPreMode(); - const bool hasPreQuantScalar = static_cast(preQuantScalar); - const bool phaseNonDefault = phase != pto::STPhase::Unspecified; - const bool atomicNonDefault = atomicType != pto::AtomicType::AtomicNone; - const bool reluNonDefault = reluPreMode != pto::ReluPreMode::NoRelu; - - auto getOpaqueTok = [&](Value v, StringRef name) -> FailureOr { - if (auto ot = mlir::dyn_cast(v.getType())) - return ot.getValue().str(); - return rewriter.notifyMatchFailure(op, (name + " must be emitc::OpaqueType").str()); - }; - - ArrayAttr targs; - // Map op attributes/operands to the exact TSTORE overload family: - // 1) TSTORE(dst, src) - // 2) TSTORE(dst, src) - // 3) TSTORE(dst, src) - // 4) TSTORE(dst, src) - // 5) TSTORE(dst, src) - // 6) TSTORE(dst, src) - // 7) TSTORE(dst, src, preQuant) - // 8) TSTORE(dst, src, preQuant) - if (!hasPreQuantScalar && !reluNonDefault && !atomicNonDefault) { - if (phaseNonDefault) { - targs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, stPhaseTok(phase)), - }); - } else { - targs = ArrayAttr{}; - } - } else { - auto srcTokOr = getOpaqueTok(src, "src"); - auto dstTokOr = getOpaqueTok(dstArg, "dst"); - if (failed(srcTokOr) || failed(dstTokOr)) - return failure(); - - // If there is no 
preQuant and relu stays default, emit the atomic-only - // overloads (#3/#4) without ReluPreMode template argument. - if (!hasPreQuantScalar && !reluNonDefault) { - if (phaseNonDefault) { - targs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, stPhaseTok(phase)), - emitc::OpaqueAttr::get(ctx, *srcTokOr), - emitc::OpaqueAttr::get(ctx, *dstTokOr), - emitc::OpaqueAttr::get(ctx, atomicTypeTok(atomicType)), - }); - } else { - targs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, *srcTokOr), - emitc::OpaqueAttr::get(ctx, *dstTokOr), - emitc::OpaqueAttr::get(ctx, atomicTypeTok(atomicType)), - }); - } - } else { - // Relu/preQuant families (#5/#6/#7/#8): keep AtomicType + ReluPreMode. - if (phaseNonDefault) { - targs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, stPhaseTok(phase)), - emitc::OpaqueAttr::get(ctx, *srcTokOr), - emitc::OpaqueAttr::get(ctx, *dstTokOr), - emitc::OpaqueAttr::get(ctx, atomicTypeTok(atomicType)), - emitc::OpaqueAttr::get(ctx, reluPreModeTok(reluPreMode)), - }); - } else { - targs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, *srcTokOr), - emitc::OpaqueAttr::get(ctx, *dstTokOr), - emitc::OpaqueAttr::get(ctx, atomicTypeTok(atomicType)), - emitc::OpaqueAttr::get(ctx, reluPreModeTok(reluPreMode)), - }); - } - } - } - - SmallVector operands{dstArg, src}; - if (hasPreQuantScalar) - operands.push_back(preQuantScalar); - - rewriter.create( - loc, TypeRange{}, "TSTORE", - /*args=*/ArrayAttr{}, /*templateArgs=*/targs, - /*operands=*/operands); - - if (op->getNumResults() == 1) { - rewriter.replaceOp(op, dst); - } else { - rewriter.eraseOp(op); - } - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// pto.matmul_dps lowering (Simplified: No internal copy/sync) -//===----------------------------------------------------------------------===// -// -// Render `pto.tmatmul` as one of three forms depending on the optional -// `acc_phase` attribute: -// * absent / 
Unspecified -> `TMATMUL(dst, lhs, rhs)` -// * Partial -> `TMATMUL(dst, lhs, rhs)` -// * Final -> `TMATMUL(dst, lhs, rhs)` -// The Unspecified default keeps backward compatibility with all upstream IR -// that does not yet emit an explicit phase attribute. -static ArrayAttr buildAccPhaseTemplateArgs(ConversionPatternRewriter &rewriter, - pto::AccPhase phase) { - StringRef tmpl; - switch (phase) { - case pto::AccPhase::Unspecified: - return ArrayAttr{}; - case pto::AccPhase::Partial: - tmpl = "AccPhase::Partial"; - break; - case pto::AccPhase::Final: - tmpl = "AccPhase::Final"; - break; - } - if (tmpl.empty()) - return ArrayAttr{}; - return rewriter.getArrayAttr( - {emitc::OpaqueAttr::get(rewriter.getContext(), tmpl)}); -} - -struct PTOTMatmulToTMATMUL : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMatmulOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // 1. 获取操作数 (剥离 Cast) - Value lhs = peelUnrealized(adaptor.getLhs()); // A (Left) - Value rhs = peelUnrealized(adaptor.getRhs()); // B (Right) - Value dst = peelUnrealized(adaptor.getDst()); // C (Acc) - - // 2. 根据 acc_phase 属性决定是否生成 TMATMUL(...) - ArrayAttr templateArgs = - buildAccPhaseTemplateArgs(rewriter, op.getAccPhase()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TMATMUL", - /*args=*/ArrayAttr{}, /*template_args=*/templateArgs, - ValueRange{dst, lhs, rhs}); - - // 3. 
处理 Op 替换/删除 - if (op->getNumResults() == 1) { - rewriter.replaceOp(op, dst); - } else { - rewriter.eraseOp(op); - } - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// pto.tgemv lowering -//===----------------------------------------------------------------------===// -struct PTOTGemvToTGEMV : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGemvOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // 1. 获取操作数 (剥离 Cast) - Value lhs = peelUnrealized(adaptor.getLhs()); // A (Matrix) - Value rhs = peelUnrealized(adaptor.getRhs()); // B (Vector) - Value dst = peelUnrealized(adaptor.getDst()); // C (Result) - - // 2. 直接生成函数调用 TGEMV(dst, lhs, rhs) - rewriter.create( - op.getLoc(), TypeRange{}, "TGEMV", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, lhs, rhs}); - - // 3. 处理 Op 替换/删除 - if (op->getNumResults() == 1) { - rewriter.replaceOp(op, dst); - } else { - rewriter.eraseOp(op); - } - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// pto.tgemv.acc lowering -//===----------------------------------------------------------------------===// -struct PTOTGemvAccToTGEMVACC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGemvAccOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!op.getDst()) - return rewriter.notifyMatchFailure(op, "expected outs(dst) for pto.tgemv.acc"); - - // 1. 获取操作数 - Value accIn = peelUnrealized(adaptor.getAccIn()); // AccOld - Value lhs = peelUnrealized(adaptor.getLhs()); // A (Matrix) - Value rhs = peelUnrealized(adaptor.getRhs()); // B (Vector) - Value dst = peelUnrealized(adaptor.getDst()); // AccNew - - // 2. 
直接生成函数调用 TGEMV_ACC(dst, accIn, lhs, rhs) - rewriter.create( - op.getLoc(), TypeRange{}, "TGEMV_ACC", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, accIn, lhs, rhs}); - - // 3. 处理 Op 替换/删除 - if (op->getNumResults() == 1) { - rewriter.replaceOp(op, dst); - } else { - rewriter.eraseOp(op); - } - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// pto.matmul_acc_dps lowering (Simplified: No internal copy/sync) -//===----------------------------------------------------------------------===// -struct PTOTMatmulAccToTMATMULACC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMatmulAccOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!op.getDst()) - return rewriter.notifyMatchFailure(op, "expected outs(dst) for pto.tmatmul.acc"); - - // 1. 获取操作数 - Value accIn = peelUnrealized(adaptor.getAccIn()); // AccOld - Value lhs = peelUnrealized(adaptor.getLhs()); // A (Left) - Value rhs = peelUnrealized(adaptor.getRhs()); // B (Right) - Value dst = peelUnrealized(adaptor.getDst()); // AccNew - - // 2. 根据 acc_phase 属性决定是否生成 TMATMUL_ACC(...) - ArrayAttr templateArgs = - buildAccPhaseTemplateArgs(rewriter, op.getAccPhase()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TMATMUL_ACC", - /*args=*/ArrayAttr{}, /*template_args=*/templateArgs, - ValueRange{dst, accIn, lhs, rhs}); - - // 3. 
处理 Op 替换/删除 - if (op->getNumResults() == 1) { - rewriter.replaceOp(op, dst); - } else { - rewriter.eraseOp(op); - } - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// Return lowering -//===----------------------------------------------------------------------=== - -static constexpr llvm::StringLiteral kAutoSyncTailPendingModeAttr = - "__pto.auto_sync_tail_mode"; - -struct ReturnToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (auto emitcFunc = op->getParentOfType()) { - if (auto modeAttr = - emitcFunc->getAttrOfType(kAutoSyncTailPendingModeAttr)) { - auto *ctx = rewriter.getContext(); - rewriter.setInsertionPoint(op); - auto args = rewriter.getArrayAttr( - {emitc::OpaqueAttr::get(ctx, modeAttr.getValue())}); - rewriter.create( - op.getLoc(), TypeRange{}, "ptoas_auto_sync_tail", - args, ArrayAttr{}, ValueRange{}); - } - } - - auto vals = adaptor.getOperands(); - if (vals.empty()) { - rewriter.replaceOpWithNewOp(op, Value{}); - return success(); - } - if (vals.size() == 1) { - rewriter.replaceOpWithNewOp(op, vals[0]); - return success(); - } - return rewriter.notifyMatchFailure(op, "EmitC cannot return multiple values"); - } -}; - -struct CallToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(func::CallOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (op.getNumResults() > 1) - return rewriter.notifyMatchFailure( - op, "EmitC cannot lower calls with multiple results"); - - SmallVector resultTypes; - if (failed( - getTypeConverter()->convertTypes(op.getResultTypes(), resultTypes))) - return rewriter.notifyMatchFailure(op, - "failed to convert call result types"); - - rewriter.replaceOpWithNewOp(op, op.getCalleeAttr(), - 
resultTypes, - adaptor.getOperands()); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// Sync lowering -//===----------------------------------------------------------------------=== - -static constexpr llvm::StringLiteral kAutoSyncTailBarrierAttr = - "pto.auto_sync_tail_barrier"; -static constexpr llvm::StringLiteral kAutoSyncTailHintAttr = - "pto.auto_sync_tail_hint"; -static constexpr llvm::StringLiteral kAutoSyncTailPolicyBarrierAll = - "barrier_all"; -static constexpr llvm::StringLiteral kAutoSyncTailPolicyMte3ToSEvent0 = - "setwait_mte3_to_s_event0"; -static constexpr llvm::StringLiteral kAutoSyncTailModeBarrierAllToken = - "PTOAutoSyncTailMode::kBarrierAll"; -static constexpr llvm::StringLiteral kAutoSyncTailModeMte3ToSEvent0Token = - "PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0"; - -static std::string getAutoSyncTailModeToken(Operation *op) { - if (op) { - if (auto hintAttr = op->getAttrOfType(kAutoSyncTailHintAttr)) { - if (hintAttr.getValue() == kAutoSyncTailPolicyBarrierAll) - return kAutoSyncTailModeBarrierAllToken.str(); - if (hintAttr.getValue() == kAutoSyncTailPolicyMte3ToSEvent0) - return kAutoSyncTailModeMte3ToSEvent0Token.str(); - } - } - - auto func = op ? op->getParentOfType() : func::FuncOp(); - if (!func) - return kAutoSyncTailModeBarrierAllToken.str(); - - auto hintAttr = func->getAttrOfType(kAutoSyncTailHintAttr); - if (!hintAttr) - return kAutoSyncTailModeBarrierAllToken.str(); - - if (hintAttr.getValue() == kAutoSyncTailPolicyBarrierAll) - return kAutoSyncTailModeBarrierAllToken.str(); - if (hintAttr.getValue() == kAutoSyncTailPolicyMte3ToSEvent0) - return kAutoSyncTailModeMte3ToSEvent0Token.str(); - - // Fallback to the conservative behavior when seeing unknown policies. 
- return kAutoSyncTailModeBarrierAllToken.str(); -} - -[[maybe_unused]] static std::string getPipeName(pto::PIPE pipe) { - switch (pipe) { - case pto::PIPE::PIPE_S: return "PIPE_S"; - case pto::PIPE::PIPE_V: return "PIPE_V"; - case pto::PIPE::PIPE_M: return "PIPE_M"; - case pto::PIPE::PIPE_MTE1: return "PIPE_MTE1"; - case pto::PIPE::PIPE_MTE2: return "PIPE_MTE2"; - case pto::PIPE::PIPE_MTE3: return "PIPE_MTE3"; - case pto::PIPE::PIPE_ALL: return "PIPE_ALL"; - case pto::PIPE::PIPE_MTE4: return "PIPE_MTE4"; - case pto::PIPE::PIPE_MTE5: return "PIPE_MTE5"; - case pto::PIPE::PIPE_V2: return "PIPE_V2"; - case pto::PIPE::PIPE_FIX: return "PIPE_FIX"; - case pto::PIPE::VIRTUAL_PIPE_MTE2_L1A: return "VIRTUAL_PIPE_MTE2_L1A"; - case pto::PIPE::VIRTUAL_PIPE_MTE2_L1B: return "VIRTUAL_PIPE_MTE2_L1B"; - // 默认回退 - default: return "PIPE_ALL"; - } -} - -//===----------------------------------------------------------------------===// -// pto.barrier lowering -> pipe_barrier(...) -//===----------------------------------------------------------------------===// -struct PTOBarrierToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::BarrierOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (op->hasAttr(kAutoSyncTailBarrierAttr)) { - auto modeAttr = rewriter.getStringAttr(getAutoSyncTailModeToken(op)); - if (auto emitcFunc = op->getParentOfType()) { - emitcFunc->setAttr(kAutoSyncTailPendingModeAttr, modeAttr); - } else if (auto funcOp = op->getParentOfType()) { - funcOp->setAttr(kAutoSyncTailPendingModeAttr, modeAttr); - } - rewriter.eraseOp(op); - return success(); - } - - // [FIX] op.getPipe() returns PipeAttr. - // We must call .getPipe() on the attribute to get the actual Enum value. 
- pto::PIPE pipeEnum = op.getPipe().getPipe(); - - // Convert Enum to String (e.g., PIPE_ALL -> "PIPE_ALL") - std::string pipeStr = pto::stringifyPIPE(pipeEnum).str(); - auto *ctx = rewriter.getContext(); - - auto args = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, pipeStr) - }); - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, // void return - "pipe_barrier", // function name - args, // arguments - ArrayAttr{}, // template args - ValueRange{} // operands - ); - - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// Sync lowering (robust for bracket form pto.set_flag[...] / pto.wait_flag[...]) -// Replace your PTOSyncToRuntimeCall with the code below. -//===----------------------------------------------------------------------===// - -static bool tryConvertPipeAttrToToken(Attribute attr, std::string &token) { - if (!attr) - return false; - if (auto pipe = dyn_cast(attr)) { - token = mlir::pto::stringifyPIPE(pipe.getPipe()).str(); - return true; - } - if (auto stringAttr = dyn_cast(attr)) { - token = stringAttr.getValue().str(); - return true; - } - return false; -} - -static bool tryConvertEventAttrToToken(Attribute attr, std::string &token) { - if (!attr) - return false; - if (auto event = dyn_cast(attr)) { - token = mlir::pto::stringifyEVENT(event.getEvent()).str(); - return true; - } - if (auto stringAttr = dyn_cast(attr)) { - token = stringAttr.getValue().str(); - return true; - } - return false; -} - -static bool tryAssignSyncTokens(Attribute srcAttr, Attribute dstAttr, - Attribute evtAttr, std::string &srcTok, - std::string &dstTok, std::string &evtTok) { - std::string localSrc; - std::string localDst; - std::string localEvt; - if (!tryConvertPipeAttrToToken(srcAttr, localSrc) || - !tryConvertPipeAttrToToken(dstAttr, localDst) || - !tryConvertEventAttrToToken(evtAttr, localEvt)) { - return false; - } - srcTok = std::move(localSrc); - dstTok = std::move(localDst); - evtTok = 
std::move(localEvt); - return true; -} - -static bool tryExtractSyncTokensFromNamedAttrs(Operation *op, - StringRef srcName, - StringRef dstName, - StringRef evtName, - std::string &srcTok, - std::string &dstTok, - std::string &evtTok) { - return tryAssignSyncTokens(op->getAttr(srcName), op->getAttr(dstName), - op->getAttr(evtName), srcTok, dstTok, evtTok); -} - -static bool tryExtractSyncTokensFromArrayAttr(Operation *op, StringRef attrName, - std::string &srcTok, - std::string &dstTok, - std::string &evtTok) { - auto arrayAttr = op->getAttrOfType(attrName); - if (!arrayAttr || arrayAttr.size() < 3) - return false; - return tryAssignSyncTokens(arrayAttr[0], arrayAttr[1], arrayAttr[2], srcTok, - dstTok, evtTok); -} - -static bool tryExtractFallbackSyncTokens(Operation *op, std::string &srcTok, - std::string &dstTok, - std::string &evtTok) { - SmallVector pipes; - std::string event; - for (NamedAttribute namedAttr : op->getAttrs()) { - std::string token; - if (tryConvertPipeAttrToToken(namedAttr.getValue(), token)) { - pipes.push_back(std::move(token)); - continue; - } - if (event.empty() && - tryConvertEventAttrToToken(namedAttr.getValue(), token)) { - event = std::move(token); - } - } - if (pipes.size() < 2 || event.empty()) - return false; - srcTok = pipes[0]; - dstTok = pipes[1]; - evtTok = event; - return true; -} - -static LogicalResult extractSyncTripletTokens(Operation *op, - std::string &srcTok, - std::string &dstTok, - std::string &evtTok, - ConversionPatternRewriter &rewriter) { - if (tryExtractSyncTokensFromNamedAttrs(op, "src_pipe", "dst_pipe", "event_id", - srcTok, dstTok, evtTok) || - tryExtractSyncTokensFromNamedAttrs(op, "srcPipe", "dstPipe", "eventId", - srcTok, dstTok, evtTok) || - tryExtractSyncTokensFromNamedAttrs(op, "src", "dst", "event", srcTok, - dstTok, evtTok)) { - return success(); - } - - for (StringRef attrName : {"args", "pipes", "sync", "triplet", "attrs"}) { - if (tryExtractSyncTokensFromArrayAttr(op, attrName, srcTok, dstTok, - 
evtTok)) { - return success(); - } - } - - if (tryExtractFallbackSyncTokens(op, srcTok, dstTok, evtTok)) - return success(); - return rewriter.notifyMatchFailure( - op, "cannot extract PIPE/PIPE/EVENT tokens from pto.{set,wait}_flag"); -} -static inline std::string pipeTokFromPipeEnum(mlir::pto::PIPE p) { - return mlir::pto::stringifyPIPE(p).str(); -} -[[maybe_unused]] static inline std::string evtTokFromEventEnum(mlir::pto::EVENT e) { - return mlir::pto::stringifyEVENT(e).str(); -} -static inline std::string pipeTokFromPipeAttr(mlir::pto::PipeAttr a) { - return mlir::pto::stringifyPIPE(a.getPipe()).str(); -} -static inline std::string evtTokFromEventAttr(mlir::pto::EventAttr a) { - return mlir::pto::stringifyEVENT(a.getEvent()).str(); -} - -template -struct HasGetSrcPipe : std::false_type {}; -template -struct HasGetSrcPipe().getSrcPipe())>> : std::true_type {}; - -template -struct HasGetDstPipe : std::false_type {}; -template -struct HasGetDstPipe().getDstPipe())>> : std::true_type {}; - -template -struct HasGetEventId : std::false_type {}; -template -struct HasGetEventId().getEventId())>> : std::true_type {}; - -template -struct HasGetSrcPipeAttr : std::false_type {}; -template -struct HasGetSrcPipeAttr().getSrcPipeAttr())>> : std::true_type {}; - -template -struct HasGetDstPipeAttr : std::false_type {}; -template -struct HasGetDstPipeAttr().getDstPipeAttr())>> : std::true_type {}; - -template -struct HasGetEventIdAttr : std::false_type {}; -template -struct HasGetEventIdAttr().getEventIdAttr())>> : std::true_type {}; - -template -static LogicalResult extractSyncTokens(SyncOpT op, - std::string &srcTok, - std::string &dstTok, - std::string &evtTok, - ConversionPatternRewriter &rewriter) { - if constexpr (HasGetSrcPipe::value && - HasGetDstPipe::value && - HasGetEventId::value) { - auto s = op.getSrcPipe(); - auto d = op.getDstPipe(); - auto e = op.getEventId(); - - if constexpr (std::is_same::value) srcTok = pipeTokFromPipeEnum(s); - else srcTok = 
pipeTokFromPipeAttr(s); - - if constexpr (std::is_same::value) dstTok = pipeTokFromPipeEnum(d); - else dstTok = pipeTokFromPipeAttr(d); - - if constexpr (std::is_same::value) evtTok = evtTokFromEventEnum(e); - else evtTok = evtTokFromEventAttr(e); - - return success(); - } - - if constexpr (HasGetSrcPipeAttr::value && - HasGetDstPipeAttr::value && - HasGetEventIdAttr::value) { - auto s = op.getSrcPipeAttr(); - auto d = op.getDstPipeAttr(); - auto e = op.getEventIdAttr(); - srcTok = pipeTokFromPipeAttr(s); - dstTok = pipeTokFromPipeAttr(d); - evtTok = evtTokFromEventAttr(e); - return success(); - } - - return extractSyncTripletTokens(op.getOperation(), srcTok, dstTok, evtTok, rewriter); -} -struct PTOSetFlagToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::SetFlagOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - (void)adaptor; - auto *ctx = rewriter.getContext(); - - std::string srcTok, dstTok, evtTok; - if (failed(extractSyncTokens(op, srcTok, dstTok, evtTok, rewriter))) - return failure(); - - auto argsAttr = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, srcTok), - emitc::OpaqueAttr::get(ctx, dstTok), - emitc::OpaqueAttr::get(ctx, evtTok), - }); - - rewriter.replaceOpWithNewOp( - op, TypeRange{}, "set_flag", - /*args=*/argsAttr, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{}); - return success(); - } -}; - -struct PTOWaitFlagToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::WaitFlagOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - (void)adaptor; - auto *ctx = rewriter.getContext(); - - std::string srcTok, dstTok, evtTok; - if (failed(extractSyncTokens(op, srcTok, dstTok, evtTok, rewriter))) - return failure(); - - auto argsAttr = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, srcTok), - 
emitc::OpaqueAttr::get(ctx, dstTok), - emitc::OpaqueAttr::get(ctx, evtTok), - }); - - rewriter.replaceOpWithNewOp( - op, TypeRange{}, "wait_flag", - /*args=*/argsAttr, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{}); - return success(); - } -}; - -struct PTOSyncToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::TSyncOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - SmallVector operands; - operands.reserve(adaptor.getEvents().size()); - for (Value event : adaptor.getEvents()) - operands.push_back(peelUnrealized(event)); - - rewriter.create( - op.getLoc(), TypeRange{}, "TSYNC", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange(operands)); - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOSyncAllToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - static StringRef coreTypeTok(pto::SyncCoreType coreType) { - switch (coreType) { - case pto::SyncCoreType::AIVOnly: - return "SyncCoreType::AIVOnly"; - case pto::SyncCoreType::AICOnly: - return "SyncCoreType::AICOnly"; - case pto::SyncCoreType::Mix: - return "SyncCoreType::Mix"; - } - llvm_unreachable("unhandled SyncCoreType"); - } - - LogicalResult matchAndRewrite(mlir::pto::SyncAllOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto mode = op.getMode().getValue(); - auto coreType = op.getCoreType().getValue(); - - auto buildGmWorkspace = [&]() -> FailureOr { - Value gm = peelUnrealized(adaptor.getGmWorkspace()); - if (isEmitCGlobalTensorLikeType(gm.getType())) - return gm; - - auto memTy = dyn_cast(op.getGmWorkspace().getType()); - if (!memTy) - return failure(); - - Value gt = buildGlobalTensorFromMemref(rewriter, op.getLoc(), gm, memTy, - op.getGmWorkspace().getDefiningOp() - ? 
op.getGmWorkspace().getDefiningOp() - : op.getOperation()); - if (!gt) - return failure(); - return gt; - }; - - if (mode == pto::SyncAllMode::Hard) { - std::string callee = "SYNCALL<" + coreTypeTok(coreType).str() + ">"; - rewriter.create(op.getLoc(), TypeRange{}, callee, - ArrayAttr{}, ArrayAttr{}, - ValueRange{}); - rewriter.eraseOp(op); - return success(); - } - - FailureOr gmWorkspace = buildGmWorkspace(); - if (failed(gmWorkspace)) - return rewriter.notifyMatchFailure(op, - "failed to build gm_workspace GlobalTensor"); - - auto i32Ty = emitc::OpaqueType::get(rewriter.getContext(), "int32_t"); - Value usedCores = adaptor.getUsedCores() - ? peelUnrealized(adaptor.getUsedCores()) - : makeEmitCIntConstant(rewriter, op.getLoc(), i32Ty, 0); - if (usedCores.getType() != i32Ty) - usedCores = rewriter.create(op.getLoc(), i32Ty, usedCores) - .getResult(); - - std::string callee = - "SYNCALL"; - - SmallVector operands{*gmWorkspace}; - switch (coreType) { - case pto::SyncCoreType::AIVOnly: { - FailureOr ubWorkspace = - buildSyncAllWorkspaceTileValue(rewriter, op.getLoc(), - op.getUbWorkspace(), - adaptor.getUbWorkspace()); - if (failed(ubWorkspace)) - return rewriter.notifyMatchFailure( - op, "failed to materialize ub_workspace tile"); - operands.push_back(*ubWorkspace); - break; - } - case pto::SyncCoreType::AICOnly: { - FailureOr l1Workspace = - buildSyncAllWorkspaceTileValue(rewriter, op.getLoc(), - op.getL1Workspace(), - adaptor.getL1Workspace()); - if (failed(l1Workspace)) - return rewriter.notifyMatchFailure( - op, "failed to materialize l1_workspace tile"); - operands.push_back(*l1Workspace); - break; - } - case pto::SyncCoreType::Mix: { - FailureOr ubWorkspace = - buildSyncAllWorkspaceTileValue(rewriter, op.getLoc(), - op.getUbWorkspace(), - adaptor.getUbWorkspace()); - FailureOr l1Workspace = - buildSyncAllWorkspaceTileValue(rewriter, op.getLoc(), - op.getL1Workspace(), - adaptor.getL1Workspace()); - if (failed(ubWorkspace) || failed(l1Workspace)) - return 
rewriter.notifyMatchFailure( - op, "failed to materialize mixed syncall workspace tiles"); - operands.push_back(*ubWorkspace); - operands.push_back(*l1Workspace); - break; - } - } - - operands.push_back(usedCores); - rewriter.create(op.getLoc(), TypeRange{}, callee, - ArrayAttr{}, ArrayAttr{}, - ValueRange(operands)); - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOSyncFlagDynToEmitC : public ConversionPattern { - PTOSyncFlagDynToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - StringRef opName, StringRef callee) - : ConversionPattern(typeConverter, opName, /*benefit=*/1, ctx), - callee(callee.str()) {} - - LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - if (operands.size() != 1) - return rewriter.notifyMatchFailure(op, "expected exactly one dynamic event-id operand"); - - auto srcAttr = op->getAttrOfType("src_pipe"); - auto dstAttr = op->getAttrOfType("dst_pipe"); - if (!srcAttr || !dstAttr) - return rewriter.notifyMatchFailure(op, "missing PipeAttr src_pipe/dst_pipe attrs"); - - auto *ctx = rewriter.getContext(); - std::string srcTok = pipeTokFromPipeAttr(srcAttr); - std::string dstTok = pipeTokFromPipeAttr(dstAttr); - - Value eventVal = operands.front(); - eventVal = - emitCCast(rewriter, op->getLoc(), emitc::OpaqueType::get(ctx, "event_t"), eventVal); - - auto argsAttr = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, srcTok), - emitc::OpaqueAttr::get(ctx, dstTok), - IntegerAttr::get(IndexType::get(ctx), 0), - }); - - rewriter.replaceOpWithNewOp( - op, TypeRange{}, callee, - /*args=*/argsAttr, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{eventVal}); - return success(); - } - -private: - std::string callee; -}; - -struct PTOGetBufToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::GetBufOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { 
- (void)adaptor; - auto *ctx = rewriter.getContext(); - - auto opTypeOr = parseSyncOpTypeLikeAttr(op.getOpTypeAttr()); - if (failed(opTypeOr)) - return rewriter.notifyMatchFailure(op, "get_buf expects pipe_event_type/sync_op_type attr"); - auto pipe = mapSyncOpTypeToPipe(*opTypeOr); - if (!isConcreteSyncPipe(pipe)) - return rewriter.notifyMatchFailure(op, "get_buf op_type cannot map to a concrete pipe"); - std::string pipeTok = pipeTokFromPipeEnum(pipe); - auto argsAttr = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, pipeTok), - op.getBufIdAttr(), - op.getModeAttr(), - }); - - rewriter.replaceOpWithNewOp( - op, TypeRange{}, "get_buf", - /*args=*/argsAttr, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{}); - return success(); - } -}; - -struct PTORlsBufToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::RlsBufOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - (void)adaptor; - auto *ctx = rewriter.getContext(); - - auto opTypeOr = parseSyncOpTypeLikeAttr(op.getOpTypeAttr()); - if (failed(opTypeOr)) - return rewriter.notifyMatchFailure(op, "rls_buf expects pipe_event_type/sync_op_type attr"); - auto pipe = mapSyncOpTypeToPipe(*opTypeOr); - if (!isConcreteSyncPipe(pipe)) - return rewriter.notifyMatchFailure(op, "rls_buf op_type cannot map to a concrete pipe"); - std::string pipeTok = pipeTokFromPipeEnum(pipe); - auto argsAttr = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, pipeTok), - op.getBufIdAttr(), - op.getModeAttr(), - }); - - rewriter.replaceOpWithNewOp( - op, TypeRange{}, "rls_buf", - /*args=*/argsAttr, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{}); - return success(); - } -}; - -struct PTOSetFFTsToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::SetFFTsOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const 
override { - auto *ctx = rewriter.getContext(); - auto loc = op.getLoc(); - - Value fftsAddr = peelUnrealized(adaptor.getFfts()); - auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); - - if (isSetFFTsPointerLikeType(fftsAddr.getType())) { - auto castTyAttr = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); - fftsAddr = - rewriter - .create(loc, u64Ty, "reinterpret_cast", - /*args=*/ArrayAttr{}, - /*templateArgs=*/castTyAttr, - /*operands=*/ValueRange{fftsAddr}) - .getResult(0); - } else if (fftsAddr.getType() != u64Ty) { - fftsAddr = - rewriter.create(loc, u64Ty, fftsAddr).getResult(); - } - - rewriter.replaceOpWithNewOp( - op, TypeRange{}, "set_ffts_base_addr", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{fftsAddr}); - return success(); - } -}; - -struct PTOSyncSetToEmitC : public OpConversionPattern { - PTOSyncSetToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - PTOArch targetArch) - : OpConversionPattern(typeConverter, ctx), - targetArch(targetArch) {} - - LogicalResult - matchAndRewrite(mlir::pto::SyncSetOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op->getLoc(); - auto *ctx = rewriter.getContext(); - IntegerAttr eventIdAttr = op.getEventIdAttr(); - Value eventIdDyn = adaptor.getEventIdDyn(); - int64_t fftsMode = 2; - if (IntegerAttr fftsModeAttr = op.getFftsModeAttr()) - fftsMode = fftsModeAttr.getInt(); - - if ((eventIdAttr != nullptr) == static_cast(eventIdDyn)) - return rewriter.notifyMatchFailure( - op, "expects exactly one of static event_id attr or dynamic event_id operand"); - - // A5 inter-core sync mirrors +16 only for cube-side producer (PIPE_FIX). - // Vec-side producer (PIPE_MTE3) emits a single set; hardware handles the - // subblock mapping in PTO-ISA custom flow. 
- if (targetArch == PTOArch::A5) { - pto::PIPE pipe = op.getPipe().getPipe(); - bool needsMirrorPlus16 = (pipe == pto::PIPE::PIPE_FIX); - std::string pipeTok = pipeTokFromPipeAttr(op.getPipe()); - auto emitSet = [&](Value eventOperand, IntegerAttr eventLiteral, - bool isDynamic) { - if (isDynamic) { - auto args = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, pipeTok), - IntegerAttr::get(IndexType::get(ctx), 0), - }); - rewriter.create(loc, TypeRange{}, "set_intra_block", - /*args=*/args, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{eventOperand}); - return; - } - auto args = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, pipeTok), - eventLiteral, - }); - rewriter.create(loc, TypeRange{}, "set_intra_block", - /*args=*/args, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{}); - }; - - if (eventIdAttr) { - emitSet(Value{}, eventIdAttr, /*isDynamic=*/false); - if (needsMirrorPlus16) { - auto plus16 = IntegerAttr::get(eventIdAttr.getType(), - eventIdAttr.getInt() + 16); - emitSet(Value{}, plus16, /*isDynamic=*/false); - } - } else { - Value eventI32 = castInterCoreEventIdToI32(rewriter, loc, eventIdDyn); - emitSet(eventI32, IntegerAttr{}, /*isDynamic=*/true); - if (needsMirrorPlus16) { - auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); - Value c16 = makeEmitCIntConstant(rewriter, loc, i32Ty, 16); - Value eventI32Plus16 = - rewriter.create(loc, i32Ty, eventI32, c16).getResult(); - emitSet(eventI32Plus16, IntegerAttr{}, /*isDynamic=*/true); - } - } - - rewriter.eraseOp(op); - return success(); - } - - InterCoreSyncCallDesc desc; - if (eventIdAttr) { - desc = buildInterCoreSyncSetCall(rewriter, loc, targetArch, op.getPipe(), - eventIdAttr, fftsMode); - } else { - desc = buildInterCoreSyncSetCallDyn(rewriter, loc, targetArch, op.getPipe(), - eventIdDyn, fftsMode); - } - rewriter.create(loc, TypeRange{}, desc.callee, - /*args=*/desc.args, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/desc.operands); - - rewriter.eraseOp(op); - 
return success(); - } - - PTOArch targetArch; -}; - -struct PTOSyncWaitToEmitC : public OpConversionPattern { - PTOSyncWaitToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - PTOArch targetArch) - : OpConversionPattern(typeConverter, ctx), - targetArch(targetArch) {} - - LogicalResult - matchAndRewrite(mlir::pto::SyncWaitOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op->getLoc(); - IntegerAttr eventIdAttr = op.getEventIdAttr(); - Value eventIdDyn = adaptor.getEventIdDyn(); - - if ((eventIdAttr != nullptr) == static_cast(eventIdDyn)) - return rewriter.notifyMatchFailure( - op, "expects exactly one of static event_id attr or dynamic event_id operand"); - - InterCoreSyncCallDesc desc; - if (eventIdAttr) { - desc = buildInterCoreSyncWaitCall(rewriter, targetArch, op.getPipe(), - eventIdAttr); - } else { - desc = buildInterCoreSyncWaitCallDyn(rewriter, loc, targetArch, op.getPipe(), - eventIdDyn); - } - rewriter.create(loc, TypeRange{}, desc.callee, - desc.args, ArrayAttr{}, desc.operands); - - rewriter.eraseOp(op); - return success(); - } - - PTOArch targetArch; -}; - -// GetBlockIdxOp Lowering (pto.get_block_idx -> get_block_idx()) -struct PTOGetBlockIdxToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::pto::GetBlockIdxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - - rewriter.replaceOpWithNewOp( - op, op.getType(), "get_block_idx", ValueRange{}, ArrayAttr{}, - ArrayAttr{}); - - return success(); - } -}; - -// GetBlockNumOp Lowering (pto.get_block_num -> get_block_num()) -struct PTOGetBlockNumToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::pto::GetBlockNumOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - - rewriter.replaceOpWithNewOp( - op, op.getType(), "get_block_num", 
ValueRange{}, ArrayAttr{}, - ArrayAttr{}); - - return success(); - } -}; - -// GetSubBlockIdxOp Lowering (pto.get_block_idx -> get_subblockid()) -struct PTOGetSubBlockIdxToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::pto::GetSubBlockIdxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - - rewriter.replaceOpWithNewOp( - op, op.getType(), "get_subblockid", ValueRange{}, ArrayAttr{}, - ArrayAttr{}); - - return success(); - } -}; - -// GetSubBlockNumOp Lowering. -struct PTOGetSubBlockNumToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::pto::GetSubBlockNumOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - - rewriter.replaceOpWithNewOp( - op, op.getType(), "get_subblockdim", ValueRange{}, ArrayAttr{}, - ArrayAttr{}); - - return success(); - } -}; - - -struct PTOMScatterToMSCATTER : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::MScatterOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto *ctx = rewriter.getContext(); - Value src = peelUnrealized(adaptor.getSrc()); - Value idx = peelUnrealized(adaptor.getIdx()); - Value mem = peelUnrealized(adaptor.getMem()); - - Value memArg = maybeWrapGlobalMemrefAsGlobalTensor( - rewriter, op.getLoc(), mem, op.getMem().getType(), op.getOperation()); - - auto scatterAtomicTok = [&](pto::ScatterAtomicOp atomic) -> StringRef { - switch (atomic) { - case pto::ScatterAtomicOp::None: - return "pto::ScatterAtomicOp::None"; - case pto::ScatterAtomicOp::Add: - return "pto::ScatterAtomicOp::Add"; - case pto::ScatterAtomicOp::Max: - return "pto::ScatterAtomicOp::Max"; - case pto::ScatterAtomicOp::Min: - return "pto::ScatterAtomicOp::Min"; - } - llvm_unreachable("unknown ScatterAtomicOp"); - }; - auto 
scatterOobTok = [&](pto::ScatterOOB mode) -> StringRef { - switch (mode) { - case pto::ScatterOOB::Undefined: - return "pto::ScatterOOB::Undefined"; - case pto::ScatterOOB::Skip: - return "pto::ScatterOOB::Skip"; - case pto::ScatterOOB::Clamp: - return "pto::ScatterOOB::Clamp"; - case pto::ScatterOOB::Wrap: - return "pto::ScatterOOB::Wrap"; - } - llvm_unreachable("unknown ScatterOOB"); - }; - - SmallVector templateArgVec; - const bool rowCoalesce = - isRowCoalescedMGatherIndexType(op.getSrc().getType(), op.getIdx().getType()); - templateArgVec.push_back(emitc::OpaqueAttr::get( - ctx, rowCoalesce ? "pto::Coalesce::Row" : "pto::Coalesce::Elem")); - if (op.getScatterAtomicOp() != pto::ScatterAtomicOp::None || - op.getScatterOob() != pto::ScatterOOB::Undefined) { - templateArgVec.push_back(emitc::OpaqueAttr::get( - ctx, scatterAtomicTok(op.getScatterAtomicOp()))); - if (op.getScatterOob() != pto::ScatterOOB::Undefined) - templateArgVec.push_back( - emitc::OpaqueAttr::get(ctx, scatterOobTok(op.getScatterOob()))); - } - ArrayAttr templateArgs = rewriter.getArrayAttr(templateArgVec); - - rewriter.create( - op.getLoc(), TypeRange{}, "MSCATTER", - ArrayAttr{}, templateArgs, - ValueRange{memArg, src, idx}); - - rewriter.eraseOp(op); - return success(); - } -}; -struct PTOSetValToSETVAL : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TSetValOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value dst = peelUnrealized(adaptor.getDst()); - Value val = peelUnrealized(adaptor.getVal()); - - // ---- offset: SSA index operand ---- - Value offset = peelUnrealized(adaptor.getOffset()); - - // Emit a marker call and let the ptoas post-processing step lower it to - // the corresponding tile setter. 
- rewriter.create( - op.getLoc(), TypeRange{}, "PTOAS__TILE_SET_VALUE", - ArrayAttr{}, ArrayAttr{}, ValueRange{dst, offset, val}); - - rewriter.eraseOp(op); - return success(); - } -}; -struct PTOGetValToGETVAL : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGetValOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src = peelUnrealized(adaptor.getSrc()); - - // ---- offset: SSA index operand ---- - Value offset = peelUnrealized(adaptor.getOffset()); - - // Emit a marker call and let the ptoas post-processing step lower it to - // the corresponding tile getter. - Type dstTy = getTypeConverter()->convertType(op.getDst().getType()); - if (!dstTy) - return failure(); - auto call = rewriter.create( - op.getLoc(), - TypeRange{dstTy}, - "PTOAS__TILE_GET_VALUE", - ArrayAttr{}, ArrayAttr{}, - ValueRange{src, offset}); - - rewriter.replaceOp(op, call.getResults()); - return success(); - } -}; - -struct PTOTAxpyToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAxpyOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value scalar = peelUnrealized(adaptor.getScalar()); - - rewriter.create( - loc, TypeRange{}, "TAXPY", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, scalar}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOHistogramToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::THistogramOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value idx = 
peelUnrealized(adaptor.getIdx()); - Value dst = peelUnrealized(adaptor.getDst()); - - auto templateArgs = rewriter.getArrayAttr({emitc::OpaqueAttr::get( - ctx, op.getIsMSB() ? "HistByte::BYTE_1" : "HistByte::BYTE_0")}); - rewriter.create( - loc, TypeRange{}, "THISTOGRAM", - /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, - /*operands=*/ValueRange{dst, src, idx}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOGetScaleAddrToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGetScaleAddrOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TGET_SCALE_ADDR", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOSetValidShapeToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::SetValidShapeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto peelAllCasts = [](Value v) { - while (auto castOp = v.getDefiningOp()) - v = castOp.getOperand(0); - if (auto castOp = v.getDefiningOp()) - v = castOp.getOperand(); - return v; - }; - auto isTileLike = [](Value v) -> bool { - auto ot = dyn_cast(v.getType()); - if (!ot) - return false; - StringRef s = ot.getValue(); - return s.contains("Tile<") || s.contains("ConvTile<"); - }; - - Value src = peelAllCasts(peelUnrealized(adaptor.getSource())); - Value row = peelUnrealized(adaptor.getValidRow()); - Value col = peelUnrealized(adaptor.getValidCol()); - - if (!isTileLike(src)) - return rewriter.notifyMatchFailure( - op, "set_validshape source must lower to a tile-like value"); - - rewriter.create( - op.getLoc(), TypeRange{}, 
"PTOAS__TILE_SET_VALIDSHAPE", ArrayAttr{}, - ArrayAttr{}, ValueRange{src, row, col}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOGetValidShapeToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::GetValidShapeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto peelAllCasts = [](Value v) { - while (auto castOp = v.getDefiningOp()) - v = castOp.getOperand(0); - if (auto castOp = v.getDefiningOp()) - v = castOp.getOperand(); - return v; - }; - auto isTileLike = [](Value v) -> bool { - auto ot = dyn_cast(v.getType()); - if (!ot) - return false; - StringRef s = ot.getValue(); - return s.contains("Tile<") || s.contains("ConvTile<"); - }; - - Value src = peelAllCasts(peelUnrealized(adaptor.getSource())); - if (!isTileLike(src)) - return rewriter.notifyMatchFailure( - op, "get_validshape source must lower to a tile-like value"); - - auto resultTy = getTypeConverter()->convertType(rewriter.getIndexType()); - if (!resultTy) - return failure(); - - Value row = rewriter - .create( - op.getLoc(), resultTy, - "PTOAS__TILE_GET_VALID_ROW", ArrayAttr{}, - ArrayAttr{}, ValueRange{src}) - .getResult(0); - Value col = rewriter - .create( - op.getLoc(), resultTy, - "PTOAS__TILE_GET_VALID_COL", ArrayAttr{}, - ArrayAttr{}, ValueRange{src}) - .getResult(0); - rewriter.replaceOp(op, ValueRange{row, col}); - return success(); - } -}; - -struct PTOTAssignToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAssignOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto peelAllCasts = [](Value v) { - while (auto castOp = v.getDefiningOp()) - v = castOp.getOperand(0); - if (auto castOp = v.getDefiningOp()) - v = castOp.getOperand(); - return v; - }; - auto isTileLike = [](Value v) -> bool { - auto ot = dyn_cast(v.getType()); - if (!ot) - return false; - 
StringRef s = ot.getValue(); - return s.contains("Tile<") || s.contains("ConvTile<"); - }; - - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value tile = peelAllCasts(peelUnrealized(adaptor.getTile())); - if (!isTileLike(tile)) - return rewriter.notifyMatchFailure( - op, "tassign tile must lower to a tile-like value"); - - Value addr = peelUnrealized(adaptor.getAddr()); - auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); - if (isa(addr.getType()) || - (isa(addr.getType()) && - cast(addr.getType()).getValue().ends_with("*"))) { - auto rcU64 = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); - addr = rewriter - .create(loc, u64Ty, "reinterpret_cast", - ArrayAttr{}, rcU64, - ValueRange{addr}) - .getResult(0); - } else if (addr.getType() != u64Ty) { - addr = rewriter.create(loc, u64Ty, addr).getResult(); - } - - rewriter.create(loc, TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{tile, addr}); - rewriter.replaceOp(op, tile); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// pto.load_scalar / pto.store_scalar lowering -> ptr[offset] -//===----------------------------------------------------------------------===// - -static Type getPointerLikeElementType(Type type) { - if (auto ptrTy = dyn_cast(type)) - return ptrTy.getElementType(); - if (auto memTy = dyn_cast(type)) - return memTy.getElementType(); - return Type(); -} - -struct PTOPtrToIntToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::PtrToIntOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value ptr = peelUnrealized(adaptor.getPtr()); - Type dstTy = getTypeConverter()->convertType(op.getResult().getType()); - if (!dstTy) - return failure(); - - auto dstOpaque = dyn_cast(dstTy); - if (!dstOpaque) - return failure(); - - auto templateArgs = - 
rewriter.getArrayAttr({emitc::OpaqueAttr::get(rewriter.getContext(), - dstOpaque.getValue())}); - auto cast = rewriter.create( - op.getLoc(), dstTy, "reinterpret_cast", ArrayAttr{}, templateArgs, - ValueRange{ptr}); - rewriter.replaceOp(op, cast.getResult(0)); - return success(); - } -}; - -struct PTOIntToPtrToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::IntToPtrOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value addr = peelUnrealized(adaptor.getAddr()); - Type dstTy = getTypeConverter()->convertType(op.getResult().getType()); - if (!dstTy) - return failure(); - - Type dstElemTy = getPointerLikeElementType(op.getResult().getType()); - if (!dstElemTy) - return failure(); - - std::string castType = - std::string("__gm__ ") + getEmitCScalarTypeToken(dstElemTy) + "*"; - auto templateArgs = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(rewriter.getContext(), - castType)}); - auto cast = rewriter.create( - op.getLoc(), dstTy, "reinterpret_cast", ArrayAttr{}, templateArgs, - ValueRange{addr}); - rewriter.replaceOp(op, cast.getResult(0)); - return success(); - } -}; - -struct PTOLoadScalarToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::LoadScalarOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value ptr = peelUnrealized(adaptor.getPtr()); - Value offset = peelUnrealized(adaptor.getOffset()); - - Type dstTy = getTypeConverter()->convertType(op.getValue().getType()); - if (!dstTy) - return failure(); - - auto call = rewriter.create( - op.getLoc(), TypeRange{dstTy}, "PTOAS__PTR_LOAD", - ArrayAttr{}, ArrayAttr{}, ValueRange{ptr, offset}); - - rewriter.replaceOp(op, call.getResults()); - return success(); - } -}; - -struct PTOStoreScalarToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult 
matchAndRewrite(pto::StoreScalarOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value ptr = peelUnrealized(adaptor.getPtr()); - Value offset = peelUnrealized(adaptor.getOffset()); - Value val = peelUnrealized(adaptor.getValue()); - - rewriter.create( - op.getLoc(), TypeRange{}, "PTOAS__PTR_STORE", - ArrayAttr{}, ArrayAttr{}, ValueRange{ptr, offset, val}); - rewriter.create( - op.getLoc(), TypeRange{}, "PTOAS__SCALAR_GM_STORE_FLUSH", - ArrayAttr{}, ArrayAttr{}, ValueRange{ptr}); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// pto.tabs lowering -> TABS(dst, src) -//===----------------------------------------------------------------------===// - -struct PTOTAbsToTABS : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAbsOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - // intrinsic: TABS(dst, src) - rewriter.create( - op.getLoc(), TypeRange{}, "TABS", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tadd lowering -> TADD(dst, src0, src1) -//===----------------------------------------------------------------------===// - -struct PTOTAddToTADD : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAddOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TADD", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src0, 
src1}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOInitializeL2G2LPipeToEmitC - : public OpConversionPattern { - PTOInitializeL2G2LPipeToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - PTOArch targetArch) - : OpConversionPattern(typeConverter, ctx), - targetArch(targetArch) {} - - LogicalResult matchAndRewrite(mlir::pto::InitializeL2G2LPipeOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto tpipeTok = buildTPipeTokenFromInitOp(op.getOperation(), targetArch); - if (failed(tpipeTok)) - return rewriter.notifyMatchFailure(op, "failed to build TPipe token"); - - auto *ctx = rewriter.getContext(); - auto emitPipeTy = - cast(getTypeConverter()->convertType(op.getPipe().getType())); - - Value gmAddr = peelUnrealized(adaptor.getGmAddr()); - gmAddr = materializeTensorViewDataPointer( - rewriter, op.getLoc(), gmAddr, op.getGmAddr().getType()); - Value localAddr = - op.getLocalAddr() ? peelUnrealized(adaptor.getLocalAddr()) : Value(); - auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); - Value zero = makeEmitCIntConstant(rewriter, op.getLoc(), i32Ty, 0); - - Value c2vBuf = zero; - Value v2cBuf = zero; - if (op.getDirMask() == 1) - c2vBuf = localAddr ? localAddr : zero; - else if (op.getDirMask() == 2) - v2cBuf = localAddr ? 
localAddr : zero; - else if (op.getDirMask() == 3) { - if (localAddr) { - if (!op.getPeerLocalAddr()) - return rewriter.notifyMatchFailure( - op, "bidirectional l2g2l pipe requires peer local buffer"); - c2vBuf = localAddr; - v2cBuf = peelUnrealized(adaptor.getPeerLocalAddr()); - } - } else - return rewriter.notifyMatchFailure(op, "unsupported dir_mask"); - - rewriter.replaceOpWithNewOp( - op, TypeRange{emitPipeTy}, *tpipeTok, ArrayAttr{}, ArrayAttr{}, - ValueRange{gmAddr, c2vBuf, v2cBuf}); - return success(); - } - - PTOArch targetArch; -}; - -struct PTOInitializeL2LPipeToEmitC - : public OpConversionPattern { - PTOInitializeL2LPipeToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - PTOArch targetArch) - : OpConversionPattern(typeConverter, ctx), - targetArch(targetArch) {} - - LogicalResult matchAndRewrite(mlir::pto::InitializeL2LPipeOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto tpipeTok = buildTPipeTokenFromInitOp(op.getOperation(), targetArch); - if (failed(tpipeTok)) - return rewriter.notifyMatchFailure(op, "failed to build TPipe token"); - - auto *ctx = rewriter.getContext(); - auto emitPipeTy = - cast(getTypeConverter()->convertType(op.getPipe().getType())); - - auto gmPtrTy = - emitc::PointerType::get(emitc::OpaqueType::get(ctx, "__gm__ void")); - Value nullGm = - makeEmitCOpaqueConstant(rewriter, op.getLoc(), gmPtrTy, "nullptr"); - auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); - Value zero = makeEmitCIntConstant(rewriter, op.getLoc(), i32Ty, 0); - Value localAddr = peelUnrealized(adaptor.getLocalAddr()); - - Value c2vBuf = zero; - Value v2cBuf = zero; - if (op.getDirMask() == 1) - c2vBuf = localAddr; - else if (op.getDirMask() == 2) - v2cBuf = localAddr; - else if (op.getDirMask() == 3) { - c2vBuf = localAddr; - v2cBuf = peelUnrealized(adaptor.getPeerLocalAddr()); - } else - return rewriter.notifyMatchFailure(op, "unsupported dir_mask"); - - rewriter.replaceOpWithNewOp( - op, TypeRange{emitPipeTy}, 
*tpipeTok, ArrayAttr{}, ArrayAttr{}, - ValueRange{nullGm, c2vBuf, v2cBuf}); - return success(); - } - - PTOArch targetArch; -}; - -struct PTOBuildAsyncSessionToEmitC - : public OpConversionPattern { - PTOBuildAsyncSessionToEmitC(TypeConverter &typeConverter, MLIRContext *ctx) - : OpConversionPattern(typeConverter, ctx) {} - - LogicalResult matchAndRewrite(mlir::pto::BuildAsyncSessionOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto *ctx = rewriter.getContext(); - Location loc = op.getLoc(); - - auto sessionTy = - dyn_cast(getTypeConverter()->convertType(op.getSession().getType())); - if (!sessionTy) - return rewriter.notifyMatchFailure(op, "failed to convert async session type"); - - FailureOr scratchTile = - buildAsyncScratchTileValue(rewriter, loc, op.getScratch(), - adaptor.getScratch()); - if (failed(scratchTile)) - return rewriter.notifyMatchFailure(op, "failed to materialize async scratch tile"); - - Value workspace = - castToGMBytePointer(rewriter, loc, peelUnrealized(adaptor.getWorkspace())); - - Value session = rewriter - .create( - loc, sessionTy, emitc::OpaqueAttr::get(ctx, "")) - .getResult(); - - auto u32Ty = emitc::OpaqueType::get(ctx, "uint32_t"); - - auto makeU32Const = [&](uint64_t value) -> Value { - return makeEmitCOpaqueConstant(rewriter, loc, u32Ty, - std::to_string(value) + "u"); - }; - uint64_t syncId = op.getSyncIdAttr() ? op.getSyncIdAttr().getInt() : 0; - uint64_t blockBytes = - op.getBlockBytesAttr() ? op.getBlockBytesAttr().getInt() : 32 * 1024; - uint64_t commBlockOffset = - op.getCommBlockOffsetAttr() ? op.getCommBlockOffsetAttr().getInt() : 0; - uint64_t queueNum = op.getQueueNumAttr() ? op.getQueueNumAttr().getInt() : 1; - uint64_t channelGroupIdx = op.getChannelGroupIdxAttr() - ? op.getChannelGroupIdxAttr().getInt() - : UINT32_MAX; - - Value syncIdVal = makeU32Const(syncId); - Value channelGroupIdxVal = - channelGroupIdx == UINT32_MAX - ? 
makeEmitCOpaqueConstant(rewriter, loc, u32Ty, "UINT32_MAX") - : makeU32Const(channelGroupIdx); - - auto baseConfigTy = - emitc::OpaqueType::get(ctx, "pto::comm::sdma::SdmaBaseConfig"); - Value baseConfig = - rewriter - .create( - loc, baseConfigTy, - emitc::OpaqueAttr::get( - ctx, "{" + std::to_string(blockBytes) + "ULL, " + - std::to_string(commBlockOffset) + "ULL, " + - std::to_string(queueNum) + "u}")) - .getResult(); - - rewriter.create( - loc, TypeRange{}, "pto::comm::BuildAsyncSession", - ArrayAttr{}, ArrayAttr{}, - ValueRange{*scratchTile, workspace, session, syncIdVal, baseConfig, - channelGroupIdxVal}); - - rewriter.replaceOp(op, session); - return success(); - } -}; - -template -struct PTOAsyncTransferToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - explicit PTOAsyncTransferToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - StringRef callee) - : OpConversionPattern(typeConverter, ctx), callee(callee.str()) {} - - LogicalResult matchAndRewrite(AsyncOp op, typename AsyncOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value dst = peelUnrealized(adaptor.getDst()); - Value src = peelUnrealized(adaptor.getSrc()); - Value dstGT = dst; - Value srcGT = src; - if (!isEmitCGlobalTensorLikeType(dstGT.getType())) { - auto dstMrTy = dyn_cast(op.getDst().getType()); - if (!dstMrTy) - return rewriter.notifyMatchFailure(op, "expected dst to lower to GlobalTensor or memref"); - dstGT = buildGlobalTensorFromMemref(rewriter, op.getLoc(), dst, dstMrTy, - op.getDst().getDefiningOp() - ? op.getDst().getDefiningOp() - : op.getOperation()); - } - if (!isEmitCGlobalTensorLikeType(srcGT.getType())) { - auto srcMrTy = dyn_cast(op.getSrc().getType()); - if (!srcMrTy) - return rewriter.notifyMatchFailure(op, "expected src to lower to GlobalTensor or memref"); - srcGT = buildGlobalTensorFromMemref(rewriter, op.getLoc(), src, srcMrTy, - op.getSrc().getDefiningOp() - ? 
op.getSrc().getDefiningOp() - : op.getOperation()); - } - if (!dstGT || !srcGT) - return rewriter.notifyMatchFailure(op, "failed to build GlobalTensor operands"); - - Type eventTy = this->getTypeConverter()->convertType(op.getEvent().getType()); - if (!eventTy) - return rewriter.notifyMatchFailure(op, "failed to convert async event type"); - - rewriter.replaceOpWithNewOp( - op, TypeRange{eventTy}, callee, ArrayAttr{}, ArrayAttr{}, - ValueRange{dstGT, srcGT, peelUnrealized(adaptor.getSession())}); - return success(); - } - - std::string callee; -}; - -template -struct PTOAsyncEventToEmitC : public OpConversionPattern { - explicit PTOAsyncEventToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - StringRef callee) - : OpConversionPattern(typeConverter, ctx), - callee(callee.str()) {} - - LogicalResult matchAndRewrite(AsyncEventOp op, - typename AsyncEventOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type resultTy = - this->getTypeConverter()->convertType(op.getCompleted().getType()); - if (!resultTy) - return rewriter.notifyMatchFailure(op, "failed to convert async event result type"); - - rewriter.replaceOpWithNewOp( - op, TypeRange{resultTy}, callee, ArrayAttr{}, ArrayAttr{}, - ValueRange{peelUnrealized(adaptor.getEvent()), - peelUnrealized(adaptor.getSession())}); - return success(); - } - - std::string callee; -}; - -static FailureOr buildCommGlobalTensorValue( - ConversionPatternRewriter &rewriter, Location loc, Value originalValue, - Value emittedValue, Operation *anchor) { - Value value = peelUnrealized(emittedValue); - if (isEmitCGlobalTensorLikeType(value.getType())) - return value; - - auto memTy = dyn_cast(originalValue.getType()); - if (!memTy) - return failure(); - - Value gt = buildGlobalTensorFromMemref(rewriter, loc, value, memTy, anchor); - if (!gt) - return failure(); - return gt; -} - -static FailureOr buildCommTileValue(ConversionPatternRewriter &rewriter, - Location loc, Value originalValue, - Value 
emittedValue) { - Value value = peelUnrealized(emittedValue); - if (auto opaqueTy = dyn_cast(value.getType())) { - StringRef typeStr = opaqueTy.getValue(); - if (typeStr.contains("Tile<") || typeStr.contains("ConvTile<")) - return value; - } - return buildAsyncScratchTileValue(rewriter, loc, originalValue, emittedValue); -} - -static FailureOr buildCollectiveParallelGroup( - ConversionPatternRewriter &rewriter, Location loc, - ArrayRef groupGTs, int64_t root) { - if (groupGTs.empty()) - return failure(); - - auto firstTy = dyn_cast(groupGTs.front().getType()); - if (!firstTy) - return failure(); - - auto *ctx = rewriter.getContext(); - auto arrayTy = emitc::ArrayType::get({static_cast(groupGTs.size())}, - firstTy); - auto groupArray = cast>( - rewriter - .create(loc, arrayTy, - emitc::OpaqueAttr::get(ctx, "{}")) - .getResult()); - - auto indexTy = emitc::OpaqueType::get(ctx, "int"); - for (auto [idx, groupVal] : llvm::enumerate(groupGTs)) { - Value idxVal = - makeEmitCIntConstant(rewriter, loc, indexTy, static_cast(idx)); - Value slot = - rewriter.create(loc, groupArray, ValueRange{idxVal}) - .getResult(); - rewriter.create(loc, slot, groupVal); - } - - std::string pgTypeStr = - (Twine("pto::comm::ParallelGroup<") + firstTy.getValue() + ">").str(); - auto pgTy = emitc::OpaqueType::get(ctx, pgTypeStr); - Value sizeVal = makeEmitCIntConstant(rewriter, loc, indexTy, - static_cast(groupGTs.size())); - Value rootVal = makeEmitCIntConstant(rewriter, loc, indexTy, root); - return rewriter - .create( - loc, TypeRange{pgTy}, (Twine(pgTypeStr) + "::Create").str(), - ArrayAttr{}, ArrayAttr{}, ValueRange{groupArray, sizeVal, rootVal}) - .getResult(0); -} - -static std::string notifyOpTok(pto::NotifyOp op) { - switch (op) { - case pto::NotifyOp::AtomicAdd: - return "pto::comm::NotifyOp::AtomicAdd"; - case pto::NotifyOp::Set: - return "pto::comm::NotifyOp::Set"; - } - return "pto::comm::NotifyOp::Set"; -} - -static std::string waitCmpTok(pto::WaitCmp cmp) { - switch (cmp) { - 
case pto::WaitCmp::EQ: - return "pto::comm::WaitCmp::EQ"; - case pto::WaitCmp::NE: - return "pto::comm::WaitCmp::NE"; - case pto::WaitCmp::GT: - return "pto::comm::WaitCmp::GT"; - case pto::WaitCmp::GE: - return "pto::comm::WaitCmp::GE"; - case pto::WaitCmp::LT: - return "pto::comm::WaitCmp::LT"; - case pto::WaitCmp::LE: - return "pto::comm::WaitCmp::LE"; - } - return "pto::comm::WaitCmp::EQ"; -} - -static std::string reduceOpTok(pto::ReduceOp op) { - switch (op) { - case pto::ReduceOp::Sum: - return "pto::comm::ReduceOp::Sum"; - case pto::ReduceOp::Max: - return "pto::comm::ReduceOp::Max"; - case pto::ReduceOp::Min: - return "pto::comm::ReduceOp::Min"; - } - return "pto::comm::ReduceOp::Sum"; -} - -template -static FailureOr> buildCommGroupGlobalTensors( - ConversionPatternRewriter &rewriter, Location loc, OpTy op, - ValueRange originalGroup, ValueRange emittedGroup) { - SmallVector groupGTs; - groupGTs.reserve(originalGroup.size()); - for (auto [orig, emitted] : llvm::zip(originalGroup, emittedGroup)) { - FailureOr gt = - buildCommGlobalTensorValue(rewriter, loc, orig, emitted, op.getOperation()); - if (failed(gt)) - return failure(); - groupGTs.push_back(*gt); - } - return groupGTs; -} - -template -struct PTOCommCollectiveToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - explicit PTOCommCollectiveToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - StringRef apiName) - : OpConversionPattern(typeConverter, ctx), - apiName(apiName.str()) {} - - LogicalResult matchAndRewrite(CollectiveOp op, typename CollectiveOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - - auto buildPong = [&](Value original, Value emitted, StringRef name) -> FailureOr { - if (!original) - return failure(); - return buildCommTileValue(rewriter, loc, original, emitted); - }; - - if constexpr (std::is_same_v) { - FailureOr srcGT = - buildCommGlobalTensorValue(rewriter, loc, op.getSrc(), 
adaptor.getSrc(), - op.getOperation()); - FailureOr pingTile = - buildCommTileValue(rewriter, loc, op.getPing(), adaptor.getPing()); - auto groupGTs = - buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); - if (failed(srcGT) || failed(pingTile) || failed(groupGTs)) - return rewriter.notifyMatchFailure(op, "failed to materialize broadcast operands"); - FailureOr pg = buildCollectiveParallelGroup(rewriter, loc, *groupGTs, op.getRoot()); - if (failed(pg)) - return rewriter.notifyMatchFailure(op, "failed to materialize broadcast group"); - if (op.getPong()) { - FailureOr pongTile = - buildPong(op.getPong(), adaptor.getPong(), "__pong"); - if (failed(pongTile)) - return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); - rewriter.create( - loc, TypeRange{}, "pto::comm::TBROADCAST", ArrayAttr{}, ArrayAttr{}, - ValueRange{*pg, *srcGT, *pingTile, *pongTile}); - } else { - rewriter.create( - loc, TypeRange{}, "pto::comm::TBROADCAST", ArrayAttr{}, ArrayAttr{}, - ValueRange{*pg, *srcGT, *pingTile}); - } - } else if constexpr (std::is_same_v) { - FailureOr dstGT = - buildCommGlobalTensorValue(rewriter, loc, op.getDst(), adaptor.getDst(), - op.getOperation()); - FailureOr pingTile = - buildCommTileValue(rewriter, loc, op.getPing(), adaptor.getPing()); - auto groupGTs = - buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); - if (failed(dstGT) || failed(pingTile) || failed(groupGTs)) - return rewriter.notifyMatchFailure(op, "failed to materialize gather operands"); - FailureOr pg = buildCollectiveParallelGroup(rewriter, loc, *groupGTs, op.getRoot()); - if (failed(pg)) - return rewriter.notifyMatchFailure(op, "failed to materialize gather group"); - if (op.getPong()) { - FailureOr pongTile = - buildPong(op.getPong(), adaptor.getPong(), "__pong"); - if (failed(pongTile)) - return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); - rewriter.create( - loc, TypeRange{}, 
"pto::comm::TGATHER", ArrayAttr{}, ArrayAttr{}, - ValueRange{*pg, *dstGT, *pingTile, *pongTile}); - } else { - rewriter.create( - loc, TypeRange{}, "pto::comm::TGATHER", ArrayAttr{}, ArrayAttr{}, - ValueRange{*pg, *dstGT, *pingTile}); - } - } else if constexpr (std::is_same_v) { - FailureOr srcGT = - buildCommGlobalTensorValue(rewriter, loc, op.getSrc(), adaptor.getSrc(), - op.getOperation()); - FailureOr pingTile = - buildCommTileValue(rewriter, loc, op.getPing(), adaptor.getPing()); - auto groupGTs = - buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); - if (failed(srcGT) || failed(pingTile) || failed(groupGTs)) - return rewriter.notifyMatchFailure(op, "failed to materialize scatter operands"); - FailureOr pg = buildCollectiveParallelGroup(rewriter, loc, *groupGTs, op.getRoot()); - if (failed(pg)) - return rewriter.notifyMatchFailure(op, "failed to materialize scatter group"); - if (op.getPong()) { - FailureOr pongTile = - buildPong(op.getPong(), adaptor.getPong(), "__pong"); - if (failed(pongTile)) - return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); - rewriter.create( - loc, TypeRange{}, "pto::comm::TSCATTER", ArrayAttr{}, ArrayAttr{}, - ValueRange{*pg, *srcGT, *pingTile, *pongTile}); - } else { - rewriter.create( - loc, TypeRange{}, "pto::comm::TSCATTER", ArrayAttr{}, ArrayAttr{}, - ValueRange{*pg, *srcGT, *pingTile}); - } - } else { - FailureOr dstGT = - buildCommGlobalTensorValue(rewriter, loc, op.getDst(), adaptor.getDst(), - op.getOperation()); - FailureOr accTile = - buildCommTileValue(rewriter, loc, op.getAcc(), adaptor.getAcc()); - FailureOr recvPing = - buildCommTileValue(rewriter, loc, op.getRecvPing(), adaptor.getRecvPing()); - auto groupGTs = - buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); - if (failed(dstGT) || failed(accTile) || failed(recvPing) || failed(groupGTs)) - return rewriter.notifyMatchFailure(op, "failed to materialize reduce operands"); - 
FailureOr pg = buildCollectiveParallelGroup(rewriter, loc, *groupGTs, op.getRoot()); - if (failed(pg)) - return rewriter.notifyMatchFailure(op, "failed to materialize reduce group"); - if (op.getRecvPong()) { - FailureOr recvPong = - buildPong(op.getRecvPong(), adaptor.getRecvPong(), "__recv_pong"); - if (failed(recvPong)) - return rewriter.notifyMatchFailure(op, "failed to materialize recv_pong"); - auto reduceTy = - emitc::OpaqueType::get(rewriter.getContext(), "pto::comm::ReduceOp"); - Value reduceOp = makeEmitCOpaqueConstant(rewriter, loc, reduceTy, - reduceOpTok(op.getReduceOp())); - rewriter.create( - loc, TypeRange{}, "pto::comm::TREDUCE", ArrayAttr{}, ArrayAttr{}, - ValueRange{*pg, *dstGT, *accTile, *recvPing, *recvPong, reduceOp}); - } else { - auto reduceTy = - emitc::OpaqueType::get(rewriter.getContext(), "pto::comm::ReduceOp"); - Value reduceOp = makeEmitCOpaqueConstant(rewriter, loc, reduceTy, - reduceOpTok(op.getReduceOp())); - rewriter.create( - loc, TypeRange{}, "pto::comm::TREDUCE", ArrayAttr{}, ArrayAttr{}, - ValueRange{*pg, *dstGT, *accTile, *recvPing, reduceOp}); - } - } - rewriter.eraseOp(op); - return success(); - } - - std::string apiName; -}; - -template -struct PTOP2PCommToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - explicit PTOP2PCommToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - StringRef callee) - : OpConversionPattern(typeConverter, ctx), callee(callee.str()) {} - - LogicalResult matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - FailureOr dstGT = - buildCommGlobalTensorValue(rewriter, op.getLoc(), op.getDst(), adaptor.getDst(), - op.getOperation()); - FailureOr srcGT = - buildCommGlobalTensorValue(rewriter, op.getLoc(), op.getSrc(), adaptor.getSrc(), - op.getOperation()); - FailureOr pingTile = - buildCommTileValue(rewriter, op.getLoc(), op.getPing(), adaptor.getPing()); - if (failed(dstGT) || failed(srcGT) || 
failed(pingTile)) - return rewriter.notifyMatchFailure(op, "failed to materialize p2p operands"); - - SmallVector operands{*dstGT, *srcGT, *pingTile}; - std::string actualCallee = callee; - if constexpr (std::is_same_v) { - if (op.getAtomicType() == pto::AtomicType::AtomicAdd) - actualCallee = "pto::comm::TPUT"; - } - if (op.getPong()) { - FailureOr pongTile = - buildCommTileValue(rewriter, op.getLoc(), op.getPong(), adaptor.getPong()); - if (failed(pongTile)) - return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); - operands.push_back(*pongTile); - } - - rewriter.create(op.getLoc(), TypeRange{}, actualCallee, - ArrayAttr{}, ArrayAttr{}, operands); - rewriter.eraseOp(op); - return success(); - } - - std::string callee; -}; - -template -struct PTOSignalCommToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - explicit PTOSignalCommToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - StringRef callee) - : OpConversionPattern(typeConverter, ctx), - callee(callee.str()) {} - - LogicalResult matchAndRewrite(SignalOp op, typename SignalOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - FailureOr signalGT = buildCommGlobalTensorValue( - rewriter, op.getLoc(), op.getSignal(), adaptor.getSignal(), op.getOperation()); - if (failed(signalGT)) - return rewriter.notifyMatchFailure(op, "failed to materialize signal operand"); - - if constexpr (std::is_same_v) { - auto notifyTy = - emitc::OpaqueType::get(rewriter.getContext(), "pto::comm::NotifyOp"); - Value notifyOp = makeEmitCOpaqueConstant( - rewriter, op.getLoc(), notifyTy, notifyOpTok(op.getNotifyOp())); - SmallVector operands{*signalGT, peelUnrealized(adaptor.getValue()), - notifyOp}; - rewriter.create(op.getLoc(), TypeRange{}, callee, - ArrayAttr{}, ArrayAttr{}, operands); - rewriter.eraseOp(op); - } else { - auto waitCmpTy = - emitc::OpaqueType::get(rewriter.getContext(), "pto::comm::WaitCmp"); - Value waitCmp = 
makeEmitCOpaqueConstant( - rewriter, op.getLoc(), waitCmpTy, waitCmpTok(op.getCmp())); - SmallVector operands{*signalGT, peelUnrealized(adaptor.getCmpValue()), - waitCmp}; - if constexpr (std::is_same_v) { - Type resultTy = this->getTypeConverter()->convertType(op.getResult().getType()); - if (!resultTy) - return rewriter.notifyMatchFailure(op, "failed to convert ttest result type"); - rewriter.replaceOpWithNewOp( - op, TypeRange{resultTy}, callee, ArrayAttr{}, ArrayAttr{}, operands); - } else { - rewriter.create(op.getLoc(), TypeRange{}, callee, - ArrayAttr{}, ArrayAttr{}, operands); - rewriter.eraseOp(op); - } - } - return success(); - } - - std::string callee; -}; - -struct PTODeclareTileMemRefToEmitC - : public OpConversionPattern { - using OpConversionPattern< - mlir::pto::DeclareTileMemRefOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::DeclareTileMemRefOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - (void)adaptor; - Type convertedType = getTypeConverter()->convertType(op.getResult().getType()); - if (!convertedType) - return rewriter.notifyMatchFailure( - op, "failed to convert declare_tile_memref result type"); - rewriter.replaceOp(op, makeEmitCOpaqueConstant(rewriter, op.getLoc(), - convertedType, "nullptr")); - return success(); - } -}; - -struct PTODeclareGlobalToEmitC - : public OpConversionPattern { - using OpConversionPattern< - mlir::pto::DeclareGlobalOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::DeclareGlobalOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - (void)adaptor; - Type convertedType = getTypeConverter()->convertType(op.getEntry().getType()); - if (!convertedType) - return rewriter.notifyMatchFailure( - op, "failed to convert declare_global result type"); - if (auto tvTy = dyn_cast(op.getEntry().getType())) { - if (auto stridesAttr = - op->getAttrOfType(kGlobalTensorStridesAttrName)) { - auto strides = 
stridesAttr.asArrayRef(); - if (strides.size() == static_cast(tvTy.getRank())) { - convertedType = emitc::OpaqueType::get( - rewriter.getContext(), - getGlobalTensorTypeStringFromShapeAndStrides( - tvTy.getElementType(), tvTy.getShape(), strides)); - } - } - } - auto var = rewriter.create( - op.getLoc(), convertedType, - emitc::OpaqueAttr::get(rewriter.getContext(), "")); - rewriter.replaceOp(op, var.getResult()); - return success(); - } -}; - -struct PTODeclareEventIdArrayToEmitC - : public OpConversionPattern { - using OpConversionPattern< - mlir::pto::DeclareEventIdArrayOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::DeclareEventIdArrayOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - (void)adaptor; - Type arrayTy = getTypeConverter()->convertType(op.getArray().getType()); - if (!arrayTy) - return rewriter.notifyMatchFailure(op, - "failed to map declared eventid_array type"); - - auto array = rewriter - .create( - op.getLoc(), arrayTy, - emitc::OpaqueAttr::get(rewriter.getContext(), "")) - .getResult(); - rewriter.replaceOp(op, array); - return success(); - } -}; - -struct PTOEventIdArrayGetToEmitC - : public OpConversionPattern { - using OpConversionPattern< - mlir::pto::EventIdArrayGetOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::EventIdArrayGetOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value array = peelUnrealized(adaptor.getArray()); - Value index = peelUnrealized(adaptor.getIndex()); - - Type resultTy = getTypeConverter()->convertType(op.getResult().getType()); - if (!resultTy) - return rewriter.notifyMatchFailure(op, - "failed to map eventid_array get result type"); - - auto load = - rewriter.create(op.getLoc(), resultTy, array, index); - rewriter.replaceOp(op, load.getResult()); - return success(); - } -}; - -struct PTOEventIdArraySetToEmitC - : public OpConversionPattern { - using OpConversionPattern< - 
mlir::pto::EventIdArraySetOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::EventIdArraySetOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value array = peelUnrealized(adaptor.getArray()); - Value index = peelUnrealized(adaptor.getIndex()); - Value value = peelUnrealized(adaptor.getValue()); - - rewriter.create( - op.getLoc(), TypeRange{}, "PTOAS__EVENTID_ARRAY_STORE", - ArrayAttr{}, ArrayAttr{}, ValueRange{array, index, value}); - rewriter.eraseOp(op); - return success(); - } -}; - -// pto.declare_local_array -> emitc.variable of !emitc.array<...>. -// Renders as `T a[D1][D2]...;` in the emitted C++. -struct PTODeclareLocalArrayToEmitC - : public OpConversionPattern { - using OpConversionPattern< - mlir::pto::DeclareLocalArrayOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::DeclareLocalArrayOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - (void)adaptor; - Type arrayTy = getTypeConverter()->convertType(op.getArray().getType()); - if (!arrayTy) - return rewriter.notifyMatchFailure(op, - "failed to map !pto.local_array type"); - - auto var = rewriter - .create( - op.getLoc(), arrayTy, - emitc::OpaqueAttr::get(rewriter.getContext(), "")) - .getResult(); - rewriter.replaceOp(op, var); - return success(); - } -}; - -// pto.local_array_get %a[%i0, %i1, ...] -> rvalue. -// Lowers to a single emitc.subscript with the full index pack; the C++ emitter -// prints it as `a[i0][i1]...`. The adaptor already exposes target-typed values -// (the type converter has remapped !pto.local_array -> !emitc.array and -// index/integer indices), so they're forwarded directly to the builder. 
-struct PTOLocalArrayGetToEmitC - : public OpConversionPattern { - using OpConversionPattern< - mlir::pto::LocalArrayGetOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::LocalArrayGetOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type resultTy = - getTypeConverter()->convertType(op.getResult().getType()); - if (!resultTy) - return rewriter.notifyMatchFailure( - op, "failed to map local_array element type"); - - auto sub = rewriter.create( - op.getLoc(), resultTy, adaptor.getArray(), adaptor.getIndices()); - rewriter.replaceOp(op, sub.getResult()); - return success(); - } -}; - -// pto.local_array_set %a[%i0, %i1, ...], %v -> emitc.assign to subscript slot. -// The C++ emitter prints this as `a[i0][i1]... = v;`. As above, adaptor values -// are already target-typed; pass them through directly. -struct PTOLocalArraySetToEmitC - : public OpConversionPattern { - using OpConversionPattern< - mlir::pto::LocalArraySetOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::LocalArraySetOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value value = adaptor.getValue(); - Type elemTy = value.getType(); - - Value slot = rewriter - .create(op.getLoc(), elemTy, - adaptor.getArray(), - adaptor.getIndices()) - .getResult(); - rewriter.create(op.getLoc(), slot, value); - rewriter.eraseOp(op); - return success(); - } -}; - -static std::optional getStaticIndexLikeValue(Value value) { - if (!value) - return std::nullopt; - if (auto cst = value.getDefiningOp()) - return cst.value(); - if (auto cst = value.getDefiningOp()) - return cst.value(); - if (auto cst = value.getDefiningOp()) { - if (auto intAttr = dyn_cast(cst.getValue())) - return intAttr.getInt(); - } - return std::nullopt; -} - -static FailureOr buildGlobalTensorViewFromPointer( - ConversionPatternRewriter &rewriter, Location loc, Value ptr, Type elemTy, - ArrayRef shape, ArrayRef strides = {}, - StringRef 
layoutEnum = "pto::Layout::ND") { - if (llvm::any_of(shape, [](int64_t dim) { - return dim == ShapedType::kDynamic; - })) - return failure(); - - auto *ctx = rewriter.getContext(); - SmallVector rowMajorStrides; - ArrayRef effectiveStrides = strides; - if (effectiveStrides.empty()) { - rowMajorStrides = buildRowMajorStrides(shape); - effectiveStrides = rowMajorStrides; - } - SmallVector shape5D; - SmallVector stride5D; - buildGlobalTensorShapeAndStride(shape, effectiveStrides, shape5D, stride5D); - - std::string shapeType = "pto::Shape<" + joinIntTemplateParams(shape5D) + ">"; - std::string strideType = - "pto::Stride<" + joinIntTemplateParams(stride5D) + ">"; - auto shapeVal = rewriter - .create( - loc, emitc::OpaqueType::get(ctx, shapeType), - shapeType, ArrayAttr{}, ArrayAttr{}, ValueRange{}) - .getResult(0); - auto strideVal = rewriter - .create( - loc, emitc::OpaqueType::get(ctx, strideType), - strideType, ArrayAttr{}, ArrayAttr{}, ValueRange{}) - .getResult(0); - - std::string gtTypeStr = - getGlobalTensorTypeStringFromShapeAndStrides(elemTy, shape, - effectiveStrides, - layoutEnum); - auto gtType = emitc::OpaqueType::get(ctx, gtTypeStr); - auto gt = rewriter.create( - loc, gtType, gtTypeStr, ArrayAttr{}, ArrayAttr{}, - ValueRange{ptr, shapeVal, strideVal}); - return gt.getResult(0); -} - -static bool parseIntegerTemplateList(StringRef token, StringRef marker, - SmallVectorImpl &values) { - size_t pos = token.find(marker); - if (pos == StringRef::npos) - return false; - pos += marker.size(); - size_t end = token.find('>', pos); - if (end == StringRef::npos) - return false; - - SmallVector parts; - token.slice(pos, end).split(parts, ','); - values.clear(); - for (StringRef part : parts) { - int64_t value = 0; - if (part.trim().getAsInteger(10, value)) - return false; - values.push_back(value); - } - return true; -} - -static LogicalResult getStaticTensorViewStrides( - Value source, Value convertedSource, pto::TensorViewType sourceType, - SmallVectorImpl 
&strides) { - int64_t rank = sourceType.getRank(); - strides.clear(); - - if (auto makeView = source.getDefiningOp()) { - if ((int64_t)makeView.getStrides().size() != rank) - return failure(); - for (Value strideValue : makeView.getStrides()) { - auto cst = getStaticIndexLikeValue(strideValue); - if (!cst) - return failure(); - strides.push_back(*cst); - } - return success(); - } - - Value src = peelUnrealized(convertedSource); - if (auto opaqueTy = dyn_cast(src.getType())) { - SmallVector stride5D; - StringRef token = opaqueTy.getValue(); - if ((parseIntegerTemplateList(token, "pto::Stride<", stride5D) || - parseIntegerTemplateList(token, "Stride<", stride5D)) && - (int64_t)stride5D.size() >= rank) { - strides.append(stride5D.end() - rank, stride5D.end()); - return success(); - } - } - - auto fallback = buildRowMajorStrides(sourceType.getShape()); - strides.append(fallback.begin(), fallback.end()); - return success(); -} - -struct PTOPartitionViewToEmitC - : public OpConversionPattern { - using OpConversionPattern< - mlir::pto::PartitionViewOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::PartitionViewOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto srcTy = dyn_cast(op.getSource().getType()); - auto resTy = dyn_cast(op.getResult().getType()); - if (!srcTy || !resTy) - return rewriter.notifyMatchFailure( - op, "expected tensor_view source and partition_tensor_view result"); - - if (op.getOffsets().size() != static_cast(srcTy.getRank()) || - op.getSizes().size() != static_cast(srcTy.getRank())) - return rewriter.notifyMatchFailure(op, "rank mismatch"); - - for (auto [idx, value] : llvm::enumerate(op.getSizes())) { - auto cst = getStaticIndexLikeValue(value); - if (!cst) - return rewriter.notifyMatchFailure( - op, "globaltensor partition_view requires static sizes"); - int64_t resultDim = resTy.getShape()[idx]; - if (resultDim != ShapedType::kDynamic && resultDim != *cst) - return 
rewriter.notifyMatchFailure( - op, "partition_view static size does not match result type"); - } - - SmallVector srcStrides; - if (failed(getStaticTensorViewStrides(op.getSource(), adaptor.getSource(), - srcTy, srcStrides))) - return rewriter.notifyMatchFailure( - op, "partition_view requires static source strides"); - int64_t staticLinearOffset = 0; - SmallVector> dynamicOffsetTerms; - for (auto [idx, values] : - llvm::enumerate(llvm::zip(op.getOffsets(), adaptor.getOffsets()))) { - Value originalOffset = std::get<0>(values); - Value convertedOffset = std::get<1>(values); - int64_t stride = srcStrides[idx]; - if (stride == ShapedType::kDynamic) - return rewriter.notifyMatchFailure( - op, "dynamic source stride is not supported"); - - if (auto cst = getStaticIndexLikeValue(originalOffset)) { - if (*cst != 0) - staticLinearOffset += (*cst) * stride; - continue; - } - dynamicOffsetTerms.push_back({convertedOffset, stride}); - } - - auto *ctx = rewriter.getContext(); - std::string elemTypeStr = getElemTypeStringForGT(srcTy.getElementType()); - auto ptrTy = emitc::PointerType::get( - emitc::OpaqueType::get(ctx, "__gm__ " + elemTypeStr)); - Value src = peelUnrealized(adaptor.getSource()); - auto data = rewriter - .create( - op.getLoc(), ptrTy, "PTOAS__GLOBAL_TENSOR_DATA", - ArrayAttr{}, ArrayAttr{}, ValueRange{src}) - .getResult(0); - Value ptr = data; - if (!dynamicOffsetTerms.empty()) { - Type u32Ty = emitc::OpaqueType::get(ctx, "unsigned"); - auto makeU32 = [&](int64_t value) { - return makeEmitCIntConstant(rewriter, op.getLoc(), u32Ty, value); - }; - auto asU32 = [&](Value value) -> Value { - if (value.getType() == u32Ty) - return value; - return rewriter.create(op.getLoc(), u32Ty, value) - .getResult(); - }; - - Value totalOffset = makeU32(staticLinearOffset); - for (auto [offsetValue, stride] : dynamicOffsetTerms) { - Value term = asU32(offsetValue); - if (stride != 1) { - Value strideValue = makeU32(stride); - term = rewriter - .create(op.getLoc(), u32Ty, term, - 
strideValue) - .getResult(); - } - totalOffset = rewriter - .create(op.getLoc(), u32Ty, - totalOffset, term) - .getResult(); - } - ptr = rewriter - .create(op.getLoc(), data.getType(), data, - totalOffset) - .getResult(); - } else { - ptr = applyStaticMemrefOffset(rewriter, op.getLoc(), data, - staticLinearOffset); - } - - auto resultOr = buildGlobalTensorViewFromPointer( - rewriter, op.getLoc(), ptr, resTy.getElementType(), resTy.getShape(), - srcStrides); - if (failed(resultOr)) - return rewriter.notifyMatchFailure( - op, "failed to materialize partition GlobalTensor"); - - rewriter.replaceOp(op, *resultOr); - return success(); - } -}; - -static FailureOr getPipeDataTypeToken(Value value) { - auto opaqueTy = dyn_cast(value.getType()); - if (!opaqueTy) - return failure(); - StringRef token = opaqueTy.getValue(); - if (!token.contains("Tile<") && !token.contains("GlobalTensor<")) - return failure(); - return token.str(); -} - -struct PTOTAllocToEmitC : public OpConversionPattern { - PTOTAllocToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - PTOArch targetArch) - : OpConversionPattern(typeConverter, ctx), - targetArch(targetArch) {} - - LogicalResult matchAndRewrite(mlir::pto::TAllocOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto pipeTok = getTPipeTokenFromValue(op.getPipeHandle(), targetArch); - if (failed(pipeTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve pipe token"); - Value entry = peelUnrealized(adaptor.getEntry()); - auto entryTok = getPipeDataTypeToken(entry); - if (failed(entryTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve entry token"); - auto splitTok = getTileSplitToken(op.getSplit()); - if (failed(splitTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve split token"); - - std::string callee = - "TALLOC<" + *pipeTok + ", " + *entryTok + ", " + *splitTok + ">"; - rewriter.replaceOpWithNewOp( - op, TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, - 
ValueRange{peelUnrealized(adaptor.getPipeHandle()), entry}); - return success(); - } - - PTOArch targetArch; -}; - -struct PTOTPushToEmitC : public OpConversionPattern { - PTOTPushToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - PTOArch targetArch) - : OpConversionPattern(typeConverter, ctx), - targetArch(targetArch) {} - - LogicalResult matchAndRewrite(mlir::pto::TPushOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto pipeTok = getTPipeTokenFromValue(op.getPipeHandle(), targetArch); - if (failed(pipeTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve pipe token"); - // Read the tile type token from the already-converted OpaqueType, which - // preserves the exact layout produced by BindTileOp / PointerCastOp EmitC. - Value convertedTile = peelUnrealized(adaptor.getTile()); - auto tileTok = getPipeDataTypeToken(convertedTile); - if (failed(tileTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve tile token"); - auto splitTok = getTileSplitToken(op.getSplit()); - if (failed(splitTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve split token"); - - std::string callee = - "TPUSH<" + *pipeTok + ", " + *tileTok + ", " + *splitTok + ">"; - rewriter.replaceOpWithNewOp( - op, TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, - ValueRange{peelUnrealized(adaptor.getPipeHandle()), convertedTile}); - return success(); - } - - PTOArch targetArch; -}; - -struct PTOTPopToEmitC : public OpConversionPattern { - PTOTPopToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - PTOArch targetArch) - : OpConversionPattern(typeConverter, ctx), - targetArch(targetArch) {} - - LogicalResult matchAndRewrite(mlir::pto::TPopOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto pipeTok = getTPipeTokenFromValue(op.getPipeHandle(), targetArch); - if (failed(pipeTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve pipe token"); - Value convertedTile = 
peelUnrealized(adaptor.getTile()); - auto tileTok = getPipeDataTypeToken(convertedTile); - if (failed(tileTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve tile token"); - auto splitTok = getTileSplitToken(op.getSplit()); - if (failed(splitTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve split token"); - - std::string callee = - "TPOP<" + *pipeTok + ", " + *tileTok + ", " + *splitTok + ">"; - rewriter.replaceOpWithNewOp( - op, TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, - ValueRange{peelUnrealized(adaptor.getPipeHandle()), convertedTile}); - return success(); - } - - PTOArch targetArch; -}; - -struct PTOTFreeToEmitC : public OpConversionPattern { - PTOTFreeToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - PTOArch targetArch) - : OpConversionPattern(typeConverter, ctx), - targetArch(targetArch) {} - - LogicalResult matchAndRewrite(mlir::pto::TFreeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto pipeTok = getTPipeTokenFromValue(op.getPipeHandle(), targetArch); - if (failed(pipeTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve pipe token"); - auto splitTok = getTileSplitToken(op.getSplit()); - if (failed(splitTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve split token"); - - SmallVector operands{peelUnrealized(adaptor.getPipeHandle())}; - std::string callee; - if (op.getEntry()) { - Value entry = peelUnrealized(adaptor.getEntry()); - auto entryTok = getPipeDataTypeToken(entry); - if (failed(entryTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve entry token"); - callee = "TFREE<" + *pipeTok + ", " + *entryTok + ", " + *splitTok + ">"; - operands.push_back(entry); - } else { - callee = "TFREE<" + *pipeTok + ", " + *splitTok + ">"; - } - rewriter.replaceOpWithNewOp( - op, TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, operands); - return success(); - } - - PTOArch targetArch; -}; - 
-//===----------------------------------------------------------------------===// -// populate patterns -//===----------------------------------------------------------------------=== -struct ReinterpretCastToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - auto resMrTy = dyn_cast(op.getType()); - if (!resMrTy) - return failure(); - - auto asAttr = dyn_cast_or_null(resMrTy.getMemorySpace()); - const bool isGm = (!asAttr || asAttr.getAddressSpace() == pto::AddressSpace::GM); - - bool emitAddPtrTrace = op->hasAttr("pto.addptr_trace"); - Value source = peelUnrealized(adaptor.getSource()); - auto offsets = adaptor.getOffsets(); - Value offsetVal = offsets.empty() ? Value() : offsets[0]; - - // GM: keep pointer arithmetic. - if (isGm) { - if (!offsetVal) { - rewriter.replaceOp(op, source); - return success(); - } - - Type resultType = getTypeConverter()->convertType(op.getType()); - if (!resultType) - return failure(); - - auto addOp = rewriter.create(loc, resultType, source, offsetVal); - if (emitAddPtrTrace) { - rewriter.setInsertionPointAfter(addOp); - rewriter.create( - loc, TypeRange{}, "PTOAS__ADDPTR_TRACE", - ArrayAttr{}, ArrayAttr{}, - ValueRange{addOp.getResult(), source, offsetVal}); - } - rewriter.replaceOp(op, addOp.getResult()); - return success(); - } - - // UB/L1/L0 tiles: materialize a new Tile view by assigning an adjusted - // underlying pointer (in elements). - pto::AddressSpace as = asAttr.getAddressSpace(); - - // Element type token. - Type elemTy = resMrTy.getElementType(); - std::string elemTok = getEmitCScalarTypeToken(elemTy); - int64_t elemBytes = getEmitCScalarByteWidth(elemTy); - - // Tile role. 
- const char *roleTok = "TileType::Vec"; - switch (as) { - case pto::AddressSpace::VEC: - roleTok = "TileType::Vec"; - break; - case pto::AddressSpace::MAT: - roleTok = "TileType::Mat"; - break; - case pto::AddressSpace::LEFT: - roleTok = "TileType::Left"; - break; - case pto::AddressSpace::RIGHT: - roleTok = "TileType::Right"; - break; - case pto::AddressSpace::ACC: - roleTok = "TileType::Acc"; - break; - case pto::AddressSpace::BIAS: - roleTok = "TileType::Bias"; - break; - case pto::AddressSpace::GM: - roleTok = "TileType::Vec"; - break; - case pto::AddressSpace::Zero: - roleTok = "TileType::Vec"; - break; - case pto::AddressSpace::SCALING: - roleTok = "TileType::Scaling"; - break; - } - - // Shape (fallback to 32x32). - int64_t rows = 32, cols = 32; - if (resMrTy.getRank() >= 2 && resMrTy.hasStaticShape()) { - rows = resMrTy.getDimSize(0); - cols = resMrTy.getDimSize(1); - } - int64_t templateRows = - renderTileTemplateDim(rows, elemTy, pto::BLayout::RowMajor, 0); - int64_t templateCols = - renderTileTemplateDim(cols, elemTy, pto::BLayout::RowMajor, 1); - - // Keep a conservative default config for now. - std::string tileTypeStr = - std::string("Tile<") + roleTok + ", " + elemTok + ", " + - std::to_string(templateRows) + ", " + std::to_string(templateCols) + - ", BLayout::RowMajor, " + std::to_string(templateRows) + ", " + - std::to_string(templateCols) + - ", SLayout::NoneBox, 512, PadValue::Null, CompactMode::Null>"; - - auto tileType = emitc::OpaqueType::get(ctx, tileTypeStr); - Value tile = rewriter - .create(loc, tileType, - emitc::OpaqueAttr::get(ctx, "")) - .getResult(); - - // Compute an integer address and assign it to the new tile. - // NOTE: pto-isa TASSIGN requires an integral address (not a pointer). - auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); - auto rcU64 = rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); - - // Non-GM reinterpret_cast operands come from UB/L1/L0 tiles. 
- // We need the underlying address, but `__cce_get_tile_ptr()` is only valid - // inside `__tf__` functions. Use `tile.data()` (via a post-processed marker) - // and compute the adjusted address in bytes. - Value rawPtr = source; - if (auto ot = dyn_cast(source.getType())) { - // Only Tiles have a `.data()` member. For plain address-space pointers - // (e.g. `__ubuf__ float*`), use the pointer value directly. - if (ot.getValue().starts_with("Tile<")) { - rawPtr = materializeTileDataValue(rewriter, loc, source, as, elemTok); - } - } - - Value baseAddr = rawPtr; - if (isSetFFTsPointerLikeType(rawPtr.getType())) { - baseAddr = rewriter - .create(loc, u64Ty, "reinterpret_cast", - /*args=*/ArrayAttr{}, - /*templateArgs=*/rcU64, - /*operands=*/ValueRange{rawPtr}) - .getResult(0); - } else if (rawPtr.getType() != u64Ty) { - baseAddr = rewriter.create(loc, u64Ty, rawPtr).getResult(); - } - - Value addr = baseAddr; - if (offsetVal) { - Value offU64 = offsetVal; - if (offU64.getType() != u64Ty) - offU64 = rewriter.create(loc, u64Ty, offU64).getResult(); - - auto bytesAttr = emitc::OpaqueAttr::get(ctx, std::to_string(elemBytes)); - Value bytesVal = rewriter.create(loc, u64Ty, bytesAttr); - Value byteOff = rewriter.create(loc, u64Ty, offU64, bytesVal); - addr = rewriter.create(loc, u64Ty, baseAddr, byteOff); - } - - rewriter.create(loc, TypeRange{}, "TASSIGN", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{tile, addr}); - - rewriter.replaceOp(op, tile); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.taddc lowering -> TADDC(dst, src0, src1, src2) -//===----------------------------------------------------------------------===// - -struct PTOTAddCToTADDC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAddCOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = 
op.getLoc(); - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value src2 = peelUnrealized(adaptor.getSrc2()); - Value dst = peelUnrealized(adaptor.getDst()); - - // pto-isa does not provide NPU implementation for TADDC yet. - // Decompose: dst = src0 + src1 + src2 - rewriter.create( - loc, TypeRange{}, "TADD", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src0, src1}); - rewriter.create( - loc, TypeRange{}, "TADD", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, dst, src2}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tadds lowering -> TADDS(dst, src, scalar) -//===----------------------------------------------------------------------===// - -struct PTOAddSToTADDS : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAddSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value scalar = peelUnrealized(adaptor.getScalar()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TADDS", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src, scalar}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.taddsc lowering -> TADDSC(dst, src0, scalar, src1) -//===----------------------------------------------------------------------===// - -struct PTOAddSCToTADDSC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAddSCOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value 
dst = peelUnrealized(adaptor.getDst()); - - // pto-isa does not provide NPU implementation for TADDSC yet. - // Decompose: dst = src0 + scalar + src1 - rewriter.create( - loc, TypeRange{}, "TADDS", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src0, scalar}); - rewriter.create( - loc, TypeRange{}, "TADD", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, dst, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; -struct PTOTAndToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAndOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value a = peelUnrealized(adaptor.getSrc0()); - Value b = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TAND", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, a, b}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOConcatToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TConcatOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TCONCAT", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src0, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; -struct PTOConcatidxToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TConcatidxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value src0Idx = peelUnrealized(adaptor.getSrc0Idx()); - Value src1Idx = peelUnrealized(adaptor.getSrc1Idx()); - Value dst = 
peelUnrealized(adaptor.getDst()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TCONCAT", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src0, src1, src0Idx, src1Idx}); - - rewriter.eraseOp(op); - return success(); - } -}; -struct PTOAndSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAndSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TANDS", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src, scalar}); - - rewriter.eraseOp(op); - return success(); - } -}; - - -struct PTOTCIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TCIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value S = peelUnrealized(adaptor.getOperands()[0]); - - // The TCI scalar template parameter should follow the original PTO IR - // scalar type, not the converted EmitC value type. - std::string scalarTok = "int32_t"; - if (auto it = dyn_cast(op->getOperand(0).getType())) { - bool isUnsigned = it.isUnsigned(); - if (it.getWidth() == 16) - scalarTok = isUnsigned ? "uint16_t" : "int16_t"; - else - scalarTok = isUnsigned ? "uint32_t" : "int32_t"; - } - - // descending -> "0"/"1" - std::string descTok = op.getDescending() ? 
"1" : "0"; - - ArrayAttr targs; - if (auto ot = mlir::dyn_cast(dst.getType())) { - std::string tileTok = ot.getValue().str(); // "Tile<...>" - targs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, tileTok), - emitc::OpaqueAttr::get(ctx, scalarTok), - emitc::OpaqueAttr::get(ctx, descTok), - }); - } else { - targs = rewriter.getArrayAttr({}); - } - - rewriter.create( - loc, TypeRange{}, "TCI", - /*args=*/ArrayAttr{}, - /*templateArgs=*/targs, - /*operands=*/ValueRange{dst, S}); - - rewriter.eraseOp(op); - return success(); - } -}; -static std::string cmpModeTok(pto::CmpModeAttr a) { - // 生成 "CmpMode::GT" 这种 token - auto m = a.getValue(); // 取 enum - switch (m) { - case pto::CmpMode::EQ: return "CmpMode::EQ"; - case pto::CmpMode::NE: return "CmpMode::NE"; - case pto::CmpMode::LT: return "CmpMode::LT"; - case pto::CmpMode::LE: return "CmpMode::LE"; - case pto::CmpMode::GT: return "CmpMode::GT"; - case pto::CmpMode::GE: return "CmpMode::GE"; - } - return "CmpMode::EQ"; -} -struct PTOColExpandToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColExpandOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value src = peelUnrealized(adaptor.getSrc()); - - rewriter.create( - loc, TypeRange{}, "TCOLEXPAND", - /*args=*/ArrayAttr(), - /*templateArgs=*/ArrayAttr(), - /*operands=*/ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColExpandMulToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColExpandMulOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - 
rewriter.create( - loc, TypeRange{}, "TCOLEXPANDMUL", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColExpandAddToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColExpandAddOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLEXPANDADD", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColExpandDivToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColExpandDivOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLEXPANDDIV", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColExpandExpdifToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColExpandExpdifOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLEXPANDEXPDIF", - /*args=*/ArrayAttr{}, 
- /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColExpandSubToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColExpandSubOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLEXPANDSUB", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColExpandMaxToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColExpandMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLEXPANDMAX", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColExpandMinToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColExpandMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLEXPANDMIN", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1}); - - 
rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOTTriToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TTriOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value diagonal = peelUnrealized(adaptor.getDiagonal()); - - ArrayAttr templateArgs; - if (auto dstOT = mlir::dyn_cast(dst.getType())) { - templateArgs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), - emitc::OpaqueAttr::get(ctx, std::to_string(op.getUpperOrLower())), - }); - } else { - templateArgs = ArrayAttr{}; - } - - SmallVector operands{dst, diagonal}; - rewriter.create( - loc, TypeRange{}, "TTRI", - /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOCmpToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TCmpOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - - std::string tok = "CmpMode::EQ"; - if (auto a = op.getCmpModeAttr()) - tok = cmpModeTok(a); - - auto modeTy = emitc::OpaqueType::get(ctx, "CmpMode"); - Value modeVal = rewriter.create( - loc, modeTy, emitc::OpaqueAttr::get(ctx, tok)); - - rewriter.create( - loc, - TypeRange{}, - "TCMP", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1, modeVal}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOCmpSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult 
matchAndRewrite(pto::TCmpSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - - // cmpMode -> token - auto cmpAttr = op.getCmpModeAttr(); // PTO_CmpModeAttr - std::string tok = cmpModeTok(cmpAttr); - - auto modeTy = emitc::OpaqueType::get(ctx, "CmpMode"); - Value modeVal = rewriter.create( - loc, modeTy, emitc::OpaqueAttr::get(ctx, tok)); - - rewriter.create( - loc, - TypeRange{}, - "TCMPS", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, scalar, modeVal}); - - rewriter.eraseOp(op); - return success(); - } -}; - - -struct PTOColMaxToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - // intrinsic: TCOLMAX(dst, src) - rewriter.create( - loc, TypeRange{}, "TCOLMAX", - /*args=*/ArrayAttr{}, // default: print all operands - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColArgMaxToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColArgMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLARGMAX", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, 
tmp}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColMinToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - // intrinsic: TCOLMIN(dst, src) - rewriter.create( - loc, TypeRange{}, "TCOLMIN", - /*args=*/ArrayAttr{}, // default: print all operands - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColArgMinToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColArgMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLARGMIN", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, tmp}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColSumToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColSumOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - // Check if tmp exists before accessing it - if (op.getTmp()) { - // Format 2: with tmp and isBinary - Value tmp = peelUnrealized(adaptor.getTmp()); - bool isBinary = false; - if (auto a = op.getIsBinaryAttr()) - isBinary = a.getValue(); - - auto boolTy = emitc::OpaqueType::get(ctx, 
"bool"); - auto tok = isBinary ? "true" : "false"; - Value isBinaryVal = rewriter.create( - loc, boolTy, emitc::OpaqueAttr::get(ctx, tok)); - - rewriter.create( - loc, TypeRange{}, "TCOLSUM", - /*args=*/ArrayAttr(), - /*templateArgs=*/ArrayAttr(), - /*operands=*/ValueRange{dst, src, tmp, isBinaryVal}); - } else { - // Format 1: without tmp and isBinary - rewriter.create( - loc, TypeRange{}, "TCOLSUM", - /*args=*/ArrayAttr(), - /*templateArgs=*/ArrayAttr(), - /*operands=*/ValueRange{dst, src}); - } - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColProdToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColProdOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLPROD", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; -static std::string roundModeTok(mlir::pto::RoundModeAttr attr) { - using RM = mlir::pto::RoundMode; - switch (attr.getValue()) { - case RM::NONE: return "RoundMode::CAST_NONE"; - case RM::RINT: return "RoundMode::CAST_RINT"; - case RM::ROUND: return "RoundMode::CAST_ROUND"; - case RM::FLOOR: return "RoundMode::CAST_FLOOR"; - case RM::CEIL: return "RoundMode::CAST_CEIL"; - case RM::TRUNC: return "RoundMode::CAST_TRUNC"; - case RM::ODD: return "RoundMode::CAST_ODD"; - case RM::CAST_RINT: return "RoundMode::CAST_RINT"; - } - return "RoundMode::CAST_RINT"; -} -static std::string saturationModeTok(mlir::pto::SaturationModeAttr attr) { - using SM = mlir::pto::SaturationMode; - switch (attr.getValue()) { - case SM::ON: return "SaturationMode::ON"; - case SM::OFF: return "SaturationMode::OFF"; - } - return "SaturationMode::OFF"; -} -struct PTOCvtToEmitC : public 
OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TCvtOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - pto::RoundModeAttr rmAttr = op.getRmodeAttr(); - std::string rmTok = rmAttr ? roundModeTok(rmAttr) - : std::string("RoundMode::CAST_RINT"); - auto rmodeTy = emitc::OpaqueType::get(ctx, "RoundMode"); - Value rmodeVal = rewriter.create( - loc, rmodeTy, emitc::OpaqueAttr::get(ctx, rmTok)); - - auto satModeTy = emitc::OpaqueType::get(ctx, "SaturationMode"); - auto satAttr = op.getSatModeAttr(); - std::string satTok = satAttr ? saturationModeTok(satAttr) - : std::string("SaturationMode::OFF"); - Value satModeVal = rewriter.create( - loc, satModeTy, emitc::OpaqueAttr::get(ctx, satTok)); - - SmallVector operands{dst, src, rmodeVal, satModeVal}; - - rewriter.create( - loc, TypeRange{}, "TCVT", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -struct PTORandomToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRandomOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - SmallVector operands{ - dst, - peelUnrealized(adaptor.getKey0()), - peelUnrealized(adaptor.getKey1()), - peelUnrealized(adaptor.getCounter0()), - peelUnrealized(adaptor.getCounter1()), - peelUnrealized(adaptor.getCounter2()), - peelUnrealized(adaptor.getCounter3()), - }; - ArrayAttr templateArgs = rewriter.getArrayAttr( - {emitc::OpaqueAttr::get(ctx, std::to_string(op.getRounds()))}); - - rewriter.create( - loc, TypeRange{}, "PTOAS__TRANDOM", - 
/*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, operands); - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tdiv lowering -> TDIV(dst, src0, src1) -//===----------------------------------------------------------------------===// - -struct PTODivToTDIV : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TDivOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TDIV", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src0, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tdivs lowering -> TDIVS(dst, src, scalar) or TDIVS(dst, scalar, src) -// Order is determined by operand types: if src is tile_buf, order is (tile, scalar) -// Otherwise, order is (scalar, tile) -//===----------------------------------------------------------------------===// - -struct PTODivSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TDivSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value dst = peelUnrealized(adaptor.getDst()); - // Preserve source order from textual parse: - // ins(tile, scalar) -> TDIVS(dst, tile, scalar) - // ins(scalar, tile) -> TDIVS(dst, scalar, tile) - rewriter.create( - loc, TypeRange{}, "TDIVS", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src, scalar}); - - rewriter.eraseOp(op); - return success(); - } -}; - 
-//===----------------------------------------------------------------------===// -// pto.tdivs (TDivSOp) lowering -> TDIVS(dst, src, scalar) or TDIVS(dst, scalar, src) -// Order is determined by operand types: if src is tile_buf, order is (tile, scalar) -// Otherwise, order is (scalar, tile) -//===----------------------------------------------------------------------===// - -struct PTOTDivSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TDivSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value dst = peelUnrealized(adaptor.getDst()); - rewriter.create( - loc, TypeRange{}, "TDIVS", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src, scalar}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.texp lowering -> TEXP(dst, src) -//===----------------------------------------------------------------------===// - -struct PTOExpToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TExpOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TEXP", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.texpands lowering -> TEXPANDS(dst, scalar) -//===----------------------------------------------------------------------===// - -struct PTOExpandsToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - 
LogicalResult matchAndRewrite(pto::TExpandsOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value scalar = peelUnrealized(adaptor.getScalar()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TEXPANDS", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, scalar}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.textract lowering -> TEXTRACT(dst, src, indexRow, indexCol) -//===----------------------------------------------------------------------===// - -struct PTOExtractToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TExtractOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value r0 = peelUnrealized(adaptor.getIndexRow()); - Value c0 = peelUnrealized(adaptor.getIndexCol()); - - rewriter.create( - loc, TypeRange{}, "TEXTRACT", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, r0, c0}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.textract_fp lowering -> TEXTRACT_FP(dst, src, fp, indexRow, indexCol) -//===----------------------------------------------------------------------===// - -struct PTOExtractFPToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TExtractFPOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value fp = peelUnrealized(adaptor.getFp()); - Value dst = peelUnrealized(adaptor.getDst()); - Value r0 = 
peelUnrealized(adaptor.getIndexRow()); - Value c0 = peelUnrealized(adaptor.getIndexCol()); - - rewriter.create( - loc, TypeRange{}, "TEXTRACT_FP", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, fp, r0, c0}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tinsert lowering -> TINSERT(dst, src, indexRow, indexCol) -// Keep lowering arch-agnostic and let PTO-ISA infer proper A5 path. -//===----------------------------------------------------------------------===// - -struct PTOInsertToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TInsertOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value r0 = peelUnrealized(adaptor.getIndexRow()); - Value c0 = peelUnrealized(adaptor.getIndexCol()); - - rewriter.create( - loc, TypeRange{}, "TINSERT", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, r0, c0}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tinsert_fp lowering -> TINSERT_FP(dst, src, fp, indexRow, indexCol) -//===----------------------------------------------------------------------===// - -struct PTOInsertFPToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TInsertFPOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value fp = peelUnrealized(adaptor.getFp()); - Value dst = peelUnrealized(adaptor.getDst()); - Value r0 = peelUnrealized(adaptor.getIndexRow()); - Value c0 = 
peelUnrealized(adaptor.getIndexCol()); - - rewriter.create( - loc, TypeRange{}, "TINSERT_FP", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, fp, r0, c0}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tfillpad lowering -> TFILLPAD(dst, src) -//===----------------------------------------------------------------------===// - -struct PTOFillPadToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TFillPadOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TFILLPAD", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tfillpad_inplace lowering -> TFILLPAD_INPLACE(dst, src) -//===----------------------------------------------------------------------===// - -struct PTOFillPadInplaceToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TFillPadInplaceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TFILLPAD_INPLACE", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tfillpad_expand lowering -> TFILLPAD_EXPAND(dst, src) 
-//===----------------------------------------------------------------------===// - -struct PTOFillPadExpandToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TFillPadExpandOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TFILLPAD_EXPAND", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tgather lowering -// - Index form : TGATHER(dst, src0, indices, tmp) -// - Compare form: TGATHER(dst, src0, kValue, cdst, tmp) -// - Mask form : TGATHER(dst, src0) -//===----------------------------------------------------------------------===// - -[[maybe_unused]] static std::string maskPatternTok(mlir::pto::MaskPatternAttr a) { - - auto v = a.getValue(); // enum - return (std::string("pto::MaskPattern::") + mlir::pto::stringifyMaskPattern(v).str()); -} - -struct PTOGatherToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGatherOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value src0 = peelUnrealized(adaptor.getSrc()); - - auto getOpaqueTok = [&](Value v, StringRef name) -> FailureOr { - if (auto ot = mlir::dyn_cast(v.getType())) - return ot.getValue().str(); - return rewriter.notifyMatchFailure(op, (name + " must be emitc::OpaqueType (tile)").str()); - }; - - // Case 1: index-based TGATHER(dst, src0, indices, tmp) - if (Value idx = adaptor.getIndices()) { - idx = peelUnrealized(idx); - Value tmp = 
peelUnrealized(adaptor.getTmp()); - - rewriter.create( - loc, TypeRange{}, "TGATHER", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, idx, tmp}); - - rewriter.eraseOp(op); - return success(); - } - - // Case 2: compare-based TGATHER( - // dst, src0, kValue, tmp, cdst, offset) - if (Value cdst = adaptor.getCdst()) { - cdst = peelUnrealized(cdst); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value kValue = peelUnrealized(adaptor.getKValue()); - - auto dstTokOr = getOpaqueTok(dst, "dst"); - auto srcTokOr = getOpaqueTok(src0, "src0"); - auto cdstTokOr = getOpaqueTok(cdst, "cdst"); - auto tmpTokOr = getOpaqueTok(tmp, "tmp"); - if (failed(dstTokOr) || failed(srcTokOr) || failed(cdstTokOr) || failed(tmpTokOr)) - return failure(); - - auto cmpAttr = op.getCmpModeAttr(); - std::string cmpTok = cmpAttr ? cmpModeTok(cmpAttr) : "CmpMode::EQ"; - int64_t offset = 0; - if (auto offsetAttr = op.getOffsetAttr()) - offset = offsetAttr.getInt(); - auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); - Value offsetVal = makeEmitCIntConstant(rewriter, loc, i32Ty, offset); - - auto targs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, *dstTokOr), - emitc::OpaqueAttr::get(ctx, *srcTokOr), - emitc::OpaqueAttr::get(ctx, *tmpTokOr), - emitc::OpaqueAttr::get(ctx, *cdstTokOr), - emitc::OpaqueAttr::get(ctx, cmpTok), - }); - - rewriter.create( - loc, TypeRange{}, "TGATHER", - /*args=*/ArrayAttr{}, /*templateArgs=*/targs, - /*operands=*/ValueRange{dst, src0, kValue, tmp, cdst, offsetVal}); - - rewriter.eraseOp(op); - return success(); - } - - // Case 3: mask-pattern TGATHER(dst, src0) - auto mp = op.getMaskPatternAttr(); - if (!mp) - return rewriter.notifyMatchFailure(op, "expected maskPattern, indices, or cdst on tgather"); - - auto dstTokOr = getOpaqueTok(dst, "dst"); - auto srcTokOr = getOpaqueTok(src0, "src0"); - if (failed(dstTokOr) || failed(srcTokOr)) - return failure(); - - // mp is an EnumAttr; stringify name is "P0101" etc. 
- // We emit MaskPattern::P0101 (because generated C++ has `using namespace pto;`) - std::string mpTok = std::string("MaskPattern::") + - mlir::pto::stringifyMaskPattern(mp.getValue()).str(); - - auto targs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, *dstTokOr), - emitc::OpaqueAttr::get(ctx, *srcTokOr), - emitc::OpaqueAttr::get(ctx, mpTok), - }); - - rewriter.create( - loc, TypeRange{}, "TGATHER", - /*args=*/ArrayAttr{}, - /*templateArgs=*/targs, - /*operands=*/ValueRange{dst, src0}); - - rewriter.eraseOp(op); - return success(); - } -}; - - -struct PTOGatherbToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGatherBOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value offsets = peelUnrealized(adaptor.getOffsets()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TGATHERB", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, offsets}); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// TLOG lowering to EmitC (PTOConvert.cpp) -//===----------------------------------------------------------------------===// - -struct PTOLogToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TLogOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src}; - rewriter.create( - loc, TypeRange{}, "TLOG", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - - - 
-//===----------------------------------------------------------------------===// -// TLRELU lowering to EmitC (PTOConvert.cpp) -//===----------------------------------------------------------------------===// - - struct PTOLReluToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TLReluOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value slope = peelUnrealized(adaptor.getSlope()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, slope}; - - rewriter.create( - loc, TypeRange{}, "TLRELU", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// TMAX lowering to EmitC (PTOConvert.cpp) -//===----------------------------------------------------------------------===// - -struct PTOMaxToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TMAX", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// TMAXS lowering to EmitC (PTOConvert.cpp) -//===----------------------------------------------------------------------===// - - struct PTOMaxSToEmitC : public OpConversionPattern { - using 
OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMaxSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, scalar}; - rewriter.create( - loc, TypeRange{}, "TMAXS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - - -//===----------------------------------------------------------------------===// -// TMIN lowering to EmitC (PTOConvert.cpp) -//===----------------------------------------------------------------------===// - -struct PTOMinToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TMIN", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// TMINS lowering to EmitC (PTOConvert.cpp) -//===----------------------------------------------------------------------===// - -//===----------------------------------------------------------------------===// -// TMINS lowering to EmitC (fix APFloat -> FloatAttr) (PTOConvert.cpp) -//===----------------------------------------------------------------------===// - -struct PTOMinsToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult 
matchAndRewrite(pto::TMinSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value scalar = peelUnrealized(adaptor.getScalar()); - - SmallVector operands{dst, src, scalar}; - rewriter.create( - loc, TypeRange{}, "TMINS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering for TMOV op -> EmitC) -//===----------------------------------------------------------------------===// - -struct PTOMovToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMovOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value fp; - if (op.getFp()) - fp = peelUnrealized(adaptor.getFp()); - Value preQuantScalar; - if (op.getPreQuantScalar()) - preQuantScalar = peelUnrealized(adaptor.getPreQuantScalar()); - - auto dstOT = mlir::dyn_cast(dst.getType()); - auto srcOT = mlir::dyn_cast(src.getType()); - if (!dstOT || !srcOT) - return rewriter.notifyMatchFailure( - op, "tmov lowering expects opaque dst/src types"); - - auto modeTok = [&](pto::AccToVecMode mode) -> StringRef { - switch (mode) { - case pto::AccToVecMode::SingleModeVec0: - return "pto::AccToVecMode::SingleModeVec0"; - case pto::AccToVecMode::SingleModeVec1: - return "pto::AccToVecMode::SingleModeVec1"; - case pto::AccToVecMode::DualModeSplitM: - return "pto::AccToVecMode::DualModeSplitM"; - case pto::AccToVecMode::DualModeSplitN: - return "pto::AccToVecMode::DualModeSplitN"; - } - llvm_unreachable("unknown AccToVecMode"); 
- }; - - auto modeAttr = op.getAccToVecModeAttr(); - auto reluTok = [&](pto::ReluPreMode mode) -> StringRef { - switch (mode) { - case pto::ReluPreMode::NoRelu: - return "ReluPreMode::NoRelu"; - case pto::ReluPreMode::NormalRelu: - return "ReluPreMode::NormalRelu"; - } - llvm_unreachable("unknown ReluPreMode"); - }; - - const bool hasFp = static_cast(fp); - const bool hasPreQuantScalar = static_cast(preQuantScalar); - const bool hasMode = static_cast(modeAttr); - const bool reluNonDefault = op.getReluPreMode() != pto::ReluPreMode::NoRelu; - - SmallVector operands{dst, src}; - SmallVector templateArgVec{ - emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), - emitc::OpaqueAttr::get(ctx, srcOT.getValue().str()), - }; - StringRef callee = "TMOV"; - - if (hasFp) { - auto fpOT = mlir::dyn_cast(fp.getType()); - if (!fpOT) - return rewriter.notifyMatchFailure( - op, "tmov fp lowering expects opaque fp type"); - operands.push_back(fp); - templateArgVec.push_back(emitc::OpaqueAttr::get(ctx, fpOT.getValue().str())); - if (hasMode) - templateArgVec.push_back( - emitc::OpaqueAttr::get(ctx, modeTok(modeAttr.getValue()))); - if (hasMode || reluNonDefault) - templateArgVec.push_back( - emitc::OpaqueAttr::get(ctx, reluTok(op.getReluPreMode()))); - callee = hasMode ? 
"TMOV" : "TMOV_FP"; - } else if (hasPreQuantScalar) { - operands.push_back(preQuantScalar); - if (hasMode) - templateArgVec.push_back( - emitc::OpaqueAttr::get(ctx, modeTok(modeAttr.getValue()))); - if (hasMode || reluNonDefault) - templateArgVec.push_back( - emitc::OpaqueAttr::get(ctx, reluTok(op.getReluPreMode()))); - } else if (hasMode) { - templateArgVec.push_back( - emitc::OpaqueAttr::get(ctx, modeTok(modeAttr.getValue()))); - templateArgVec.push_back( - emitc::OpaqueAttr::get(ctx, reluTok(op.getReluPreMode()))); - } else if (reluNonDefault) { - templateArgVec.push_back( - emitc::OpaqueAttr::get(ctx, reluTok(op.getReluPreMode()))); - } - - ArrayAttr templateArgs = - templateArgVec.size() == 2 && !hasFp && !hasPreQuantScalar && - !hasMode && !reluNonDefault - ? ArrayAttr{} - : rewriter.getArrayAttr(templateArgVec); - - rewriter.create( - loc, TypeRange{}, callee, - /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TMOV_FP DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOMovFPToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMovFPOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value src = peelUnrealized(adaptor.getSrc()); - Value fp = peelUnrealized(adaptor.getFp()); - - // TMOV_FP(dstTileData, cTile, fbTile) - ArrayAttr templateArgs; - auto dstOT = mlir::dyn_cast(dst.getType()); - auto srcOT = mlir::dyn_cast(src.getType()); - auto fpOT = mlir::dyn_cast(fp.getType()); - if (dstOT && srcOT && fpOT) { - templateArgs = rewriter.getArrayAttr({ - 
emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), - emitc::OpaqueAttr::get(ctx, srcOT.getValue().str()), - emitc::OpaqueAttr::get(ctx, fpOT.getValue().str()), - }); - } else { - templateArgs = ArrayAttr{}; - } - - SmallVector operands{dst, src, fp}; - rewriter.create( - loc, TypeRange{}, "TMOV_FP", - /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOQuantToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TQuantOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value src = peelUnrealized(adaptor.getSrc()); - Value fp = peelUnrealized(adaptor.getFp()); - - // Optional offset (INT8_ASYM only): passed as pointer (&offset) - Value offsetPtr; - if (op.getOffset()) { - Value offset = peelUnrealized(adaptor.getOffset()); - auto offsetOT = mlir::dyn_cast(offset.getType()); - if (offsetOT) { - offsetPtr = rewriter - .create( - loc, emitc::PointerType::get(offsetOT), "&", offset) - .getResult(); - } - } - - // TQUANT(dst, src, fp[, &offset]) - std::string quantTypeStr = - op.getQuantType() == pto::QuantType::INT8_SYM - ? 
"pto::QuantType::INT8_SYM" - : "pto::QuantType::INT8_ASYM"; - ArrayAttr templateArgs; - auto dstOT = mlir::dyn_cast(dst.getType()); - auto srcOT = mlir::dyn_cast(src.getType()); - auto fpOT = mlir::dyn_cast(fp.getType()); - if (dstOT && srcOT && fpOT) { - templateArgs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, quantTypeStr), - emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), - emitc::OpaqueAttr::get(ctx, srcOT.getValue().str()), - emitc::OpaqueAttr::get(ctx, fpOT.getValue().str()), - }); - } else { - templateArgs = ArrayAttr{}; - } - - SmallVector operands{dst, src, fp}; - if (offsetPtr) - operands.push_back(offsetPtr); - - rewriter.create( - loc, TypeRange{}, "TQUANT", - /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -struct PTODequantToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TDequantOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value src = peelUnrealized(adaptor.getSrc()); - Value scale = peelUnrealized(adaptor.getScale()); - Value offset = peelUnrealized(adaptor.getOffset()); - - // TDEQUANT(dst, src, scale, offset) - ArrayAttr templateArgs; - auto dstOT = mlir::dyn_cast(dst.getType()); - auto srcOT = mlir::dyn_cast(src.getType()); - auto scaleOT = mlir::dyn_cast(scale.getType()); - if (dstOT && srcOT && scaleOT) { - templateArgs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), - emitc::OpaqueAttr::get(ctx, srcOT.getValue().str()), - emitc::OpaqueAttr::get(ctx, scaleOT.getValue().str()), - }); - } else { - templateArgs = ArrayAttr{}; - } - - rewriter.create( - loc, TypeRange{}, "TDEQUANT", - /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, - /*operands=*/SmallVector{dst, src, scale, offset}); 
- - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TMRGSORT DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOMrgSortToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMrgSortOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - if (op.isFormat1()) { - Value src = peelUnrealized(adaptor.getSrcs().front()); - Value dst = peelUnrealized(adaptor.getDsts().front()); - Value blockLen = peelUnrealized(adaptor.getBlockLen()); - - SmallVector operands{dst, src, blockLen}; - rewriter.create( - loc, TypeRange{}, "TMRGSORT", - ArrayAttr{}, ArrayAttr{}, operands); - } else if (op.isFormat2()) { - // pto-isa API: - // TMRGSORT( - // dst, executedNumList, tmp, src0, src1[, src2[, src3]]); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDsts()[0]); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value excuted = peelUnrealized(adaptor.getExcuted()); - - SmallVector srcs; - srcs.reserve(adaptor.getSrcs().size()); - for (Value v : adaptor.getSrcs()) - srcs.push_back(peelUnrealized(v)); - - auto dstOT = mlir::dyn_cast(dst.getType()); - auto tmpOT = mlir::dyn_cast(tmp.getType()); - if (!dstOT || !tmpOT || srcs.size() < 2 || srcs.size() > 4) - return op.emitOpError("format2 expects dst/tmp tilebufs and 2 to 4 srcs"); - - SmallVector targs; - targs.reserve(2 + srcs.size() + 1); - targs.push_back(emitc::OpaqueAttr::get(ctx, dstOT.getValue().str())); - targs.push_back(emitc::OpaqueAttr::get(ctx, tmpOT.getValue().str())); - for (Value v : srcs) { - auto ot = mlir::dyn_cast(v.getType()); - if (!ot) - return op.emitOpError("format2 expects tilebuf srcs"); - targs.push_back(emitc::OpaqueAttr::get(ctx, ot.getValue().str())); - } - 
targs.push_back(emitc::OpaqueAttr::get(ctx, op.getExhausted() ? "true" : "false")); - ArrayAttr templateArgs = rewriter.getArrayAttr(targs); - - SmallVector operands{dst, excuted, tmp}; - operands.append(srcs.begin(), srcs.end()); - - rewriter.create( - loc, TypeRange{}, "TMRGSORT", - /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, operands); - } else { - return op.emitOpError("unsupported mrgsort_dps format"); - } - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TMUL DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOMulToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMulOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TMUL", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TMULS DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOMulsToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMulSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc0()); - Value dst = peelUnrealized(adaptor.getDst()); - Value scalar = peelUnrealized(adaptor.getScalar()); - - 
SmallVector operands{dst, src, scalar}; - rewriter.create( - loc, TypeRange{}, "TMULS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TNEG DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTONegToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TNegOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src}; - rewriter.create( - loc, TypeRange{}, "TNEG", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TNOT DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTONotToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TNotOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src}; - rewriter.create( - loc, TypeRange{}, "TNOT", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TOR DPS/memref op) 
-//===----------------------------------------------------------------------===// - -struct PTOOrToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TOrOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TOR", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TORS DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOOrsToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TOrSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - // NOTE: The conversion type system may materialize integers as emitc.opaque - // (e.g. "int32_t"). For EmitC call emission we can pass the scalar through - // directly without arith casts here. 
- Value s = adaptor.getScalar(); - - SmallVector operands{dst, src0, s}; - rewriter.create( - loc, TypeRange{}, "TORS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TPARTADD DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOPartAddToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPartAddOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TPARTADD", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TPARTMAX DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOPartMaxToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPartMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TPARTMAX", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return 
success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TPARTMIN DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOPartMinToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPartMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TPARTMIN", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOPartArgMaxToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPartArgMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value src0Idx = peelUnrealized(adaptor.getSrc0Idx()); - Value src1Idx = peelUnrealized(adaptor.getSrc1Idx()); - Value dst = peelUnrealized(adaptor.getDst()); - Value dstIdx = peelUnrealized(adaptor.getDstIdx()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TPARTARGMAX", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1, dstIdx, src0Idx, src1Idx}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOPartArgMinToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPartArgMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src0 = 
peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value src0Idx = peelUnrealized(adaptor.getSrc0Idx()); - Value src1Idx = peelUnrealized(adaptor.getSrc1Idx()); - Value dst = peelUnrealized(adaptor.getDst()); - Value dstIdx = peelUnrealized(adaptor.getDstIdx()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TPARTARGMIN", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1, dstIdx, src0Idx, src1Idx}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TPARTMUL DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOPartMulToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPartMulOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TPARTMUL", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TPRELU DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOPreluToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPReluOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = 
peelUnrealized(adaptor.getSrc1()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - // C++ interface: TPRELU(dst, src0, src1, tmp) — last parameter is tmp. - SmallVector operands{dst, src0, src1, tmp}; - rewriter.create( - loc, TypeRange{}, "TPRELU", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TRECIP DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORecipToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRecipOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src}; - rewriter.create( - loc, TypeRange{}, "TRECIP", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TRELU DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOReluToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TReluOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src}; - rewriter.create( - loc, TypeRange{}, "TRELU", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - 
/*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TREM DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORemToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRemOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - SmallVector operands{dst, src0, src1, tmp}; - rewriter.create( - loc, TypeRange{}, "TREM", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOFModToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TFModOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TFMOD", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TREMS DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORemSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult 
matchAndRewrite(pto::TRemSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - Value scalar = peelUnrealized(adaptor.getScalar()); - SmallVector operands{dst, src, scalar, tmp}; - rewriter.create( - loc, TypeRange{}, "TREMS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOFModSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TFModSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value scalar = peelUnrealized(adaptor.getScalar()); - - SmallVector operands{dst, src, scalar}; - rewriter.create( - loc, TypeRange{}, "TFMODS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TROWEXPAND DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORowExpandToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowExpandOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src}; - rewriter.create( - loc, TypeRange{}, "TROWEXPAND", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - 
rewriter.eraseOp(op); - return success(); - } -}; - -struct PTORowExpandAddToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowExpandAddOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TROWEXPANDADD", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTORowExpandExpdifToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowExpandExpdifOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value(); - - SmallVector operands; - if (tmp) - operands.assign({dst, src0, src1, tmp}); - else - operands.assign({dst, src0, src1}); - rewriter.create( - loc, TypeRange{}, "TROWEXPANDEXPDIF", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TROWEXPANDDIV DPS/memref op) -//===----------------------------------------------------------------------===// -// Helper: replace or erase based on whether op has results. 
-static void replaceOrEraseWithOpaqueCall(Operation *op, - StringRef callee, - ArrayRef args, - ConversionPatternRewriter &rewriter) { - TypeRange resultTypes = op->getResultTypes(); - auto call = rewriter.create( - op->getLoc(), resultTypes, callee, ArrayAttr{}, ArrayAttr{}, ValueRange(args)); - if (resultTypes.empty()) - rewriter.eraseOp(op); - else - rewriter.replaceOp(op, call.getResults()); -} - -static void replaceOrEraseWithOpaqueCallAndReturnDst(Operation *op, Value dst, - StringRef callee, - ArrayRef args, - ConversionPatternRewriter &rewriter) { - rewriter.create( - op->getLoc(), TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, ValueRange(args)); - if (op->getNumResults() == 1) - rewriter.replaceOp(op, dst); - else - rewriter.eraseOp(op); -} - -// ---------- TOp ---------- -struct PTOTGemvBiasToTGEMV_BIAS - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGemvBiasOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value a = peelUnrealized(adaptor.getA()); - Value b = peelUnrealized(adaptor.getB()); - Value bias = peelUnrealized(adaptor.getBias()); - Value dst = peelUnrealized(adaptor.getDst()); - - replaceOrEraseWithOpaqueCall(op.getOperation(), "TGEMV_BIAS", - {dst, a, b, bias}, rewriter); - return success(); - } -}; - -struct PTOTGemvMXToTGEMV_MX - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGemvMxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value a = peelUnrealized(adaptor.getA()); - Value aScale = peelUnrealized(adaptor.getAScale()); - Value b = peelUnrealized(adaptor.getB()); - Value bScale = peelUnrealized(adaptor.getBScale()); - Value dst = peelUnrealized(adaptor.getDst()); - - replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TGEMV_MX", - {dst, a, aScale, b, bScale}, rewriter); - return success(); - } -}; - -struct 
PTOTGemvMXAccToTGEMV_MX - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGemvMxAccOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value cIn = peelUnrealized(adaptor.getCIn()); - Value a = peelUnrealized(adaptor.getA()); - Value aScale = peelUnrealized(adaptor.getAScale()); - Value b = peelUnrealized(adaptor.getB()); - Value bScale = peelUnrealized(adaptor.getBScale()); - Value dst = peelUnrealized(adaptor.getDst()); - - replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TGEMV_MX", - {dst, cIn, a, aScale, b, bScale}, rewriter); - return success(); - } -}; - -struct PTOTGemvMXBiasToTGEMV_MX - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGemvMxBiasOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value a = peelUnrealized(adaptor.getA()); - Value aScale = peelUnrealized(adaptor.getAScale()); - Value b = peelUnrealized(adaptor.getB()); - Value bScale = peelUnrealized(adaptor.getBScale()); - Value bias = peelUnrealized(adaptor.getBias()); - Value dst = peelUnrealized(adaptor.getDst()); - - replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TGEMV_MX", - {dst, a, aScale, b, bScale, bias}, rewriter); - return success(); - } -}; - -struct PTOTMatmulBiasToTMATMUL_BIAS - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMatmulBiasOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value a = peelUnrealized(adaptor.getA()); - Value b = peelUnrealized(adaptor.getB()); - Value bias = peelUnrealized(adaptor.getBias()); - Value dst = peelUnrealized(adaptor.getDst()); - - replaceOrEraseWithOpaqueCall(op.getOperation(), "TMATMUL_BIAS", - {dst, a, b, bias}, rewriter); - return success(); - } -}; - -struct PTOTMatmulMXToTMATMUL_MX - : 
public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMatmulMxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value a = peelUnrealized(adaptor.getA()); - Value aScale = peelUnrealized(adaptor.getAScale()); - Value b = peelUnrealized(adaptor.getB()); - Value bScale = peelUnrealized(adaptor.getBScale()); - Value dst = peelUnrealized(adaptor.getDst()); - - replaceOrEraseWithOpaqueCall(op.getOperation(), "TMATMUL_MX", - {dst, a, aScale, b, bScale}, rewriter); - return success(); - } -}; - -struct PTOTMatmulMXAccToTMATMUL_MX_ACC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMatmulMxAccOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value cIn = peelUnrealized(adaptor.getCIn()); - Value a = peelUnrealized(adaptor.getA()); - Value aScale = peelUnrealized(adaptor.getAScale()); - Value b = peelUnrealized(adaptor.getB()); - Value bScale = peelUnrealized(adaptor.getBScale()); - Value dst = peelUnrealized(adaptor.getDst()); - - replaceOrEraseWithOpaqueCall(op.getOperation(), "TMATMUL_MX", - {dst, cIn, a, aScale, b, bScale}, rewriter); - return success(); - } -}; - -struct PTOTMatmulMXBiasToTMATMUL_MX_BIAS - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMatmulMxBiasOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value a = peelUnrealized(adaptor.getA()); - Value aScale = peelUnrealized(adaptor.getAScale()); - Value b = peelUnrealized(adaptor.getB()); - Value bScale = peelUnrealized(adaptor.getBScale()); - Value bias = peelUnrealized(adaptor.getBias()); - Value dst = peelUnrealized(adaptor.getDst()); - - replaceOrEraseWithOpaqueCall(op.getOperation(), "TMATMUL_MX", - {dst, a, aScale, b, bScale, bias}, rewriter); - return success(); - } -}; - -struct 
PTORowExpandDivToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowExpandDivOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value(); - - SmallVector operands; - if (tmp) - operands.assign({dst, src0, src1, tmp}); - else - operands.assign({dst, src0, src1}); - rewriter.create( - loc, TypeRange{}, "TROWEXPANDDIV", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TROWEXPANDMUL DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORowExpandMulToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowExpandMulOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - Value tmp = op.getTmp() ? 
peelUnrealized(adaptor.getTmp()) : Value(); - - SmallVector operands; - if (tmp) - operands.assign({dst, src0, src1, tmp}); - else - operands.assign({dst, src0, src1}); - rewriter.create( - loc, TypeRange{}, "TROWEXPANDMUL", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TROWEXPANDSUB DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORowExpandSubToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowExpandSubOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value(); - - SmallVector operands; - if (tmp) - operands.assign({dst, src0, src1, tmp}); - else - operands.assign({dst, src0, src1}); - rewriter.create( - loc, TypeRange{}, "TROWEXPANDSUB", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTORowExpandMaxToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowExpandMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - Value tmp = op.getTmp() ? 
peelUnrealized(adaptor.getTmp()) : Value(); - - SmallVector operands; - if (tmp) - operands.assign({dst, src0, src1, tmp}); - else - operands.assign({dst, src0, src1}); - rewriter.create( - loc, TypeRange{}, "TROWEXPANDMAX", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTORowExpandMinToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowExpandMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value(); - - SmallVector operands; - if (tmp) - operands.assign({dst, src0, src1, tmp}); - else - operands.assign({dst, src0, src1}); - rewriter.create( - loc, TypeRange{}, "TROWEXPANDMIN", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TROWMAX DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORowMaxToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, tmp}; - rewriter.create( - loc, TypeRange{}, "TROWMAX", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - 
rewriter.eraseOp(op); - return success(); - } -}; - -struct PTORowArgMaxToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowArgMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TROWARGMAX", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, tmp}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TROWMIN DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORowMinToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, tmp}; - rewriter.create( - loc, TypeRange{}, "TROWMIN", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTORowArgMinToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowArgMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, 
TypeRange{}, "TROWARGMIN", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, tmp}); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TROWSUM DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORowSumToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowSumOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, tmp}; - rewriter.create( - loc, TypeRange{}, "TROWSUM", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTORowProdToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowProdOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, tmp}; - rewriter.create( - loc, TypeRange{}, "TROWPROD", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TRSQRT DPS/memref op) -// - no-tmp form : TRSQRT(dst, src) -// - tmp form : TRSQRT(dst, src, tmp) 
-//===----------------------------------------------------------------------===// - -struct PTORsqrtToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRsqrtOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - SmallVector operands{dst, src}; - if (Value tmp = adaptor.getTmp()) - operands.push_back(peelUnrealized(tmp)); - rewriter.create( - loc, TypeRange{}, "TRSQRT", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSCATTER DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOScatterToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TScatterOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - const bool hasMaskPattern = static_cast(op.getMaskPatternAttr()); - const bool hasIndexes = static_cast(op.getIndexes()); - if (hasMaskPattern == hasIndexes) { - return rewriter.notifyMatchFailure( - op, "expected exactly one of indexes operand or maskPattern attribute"); - } - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - if (auto mp = op.getMaskPatternAttr()) { - auto *ctx = rewriter.getContext(); - auto targs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, maskPatternTok(mp)), - }); - rewriter.create( - loc, TypeRange{}, "TSCATTER", - /*args=*/ArrayAttr{}, /*templateArgs=*/targs, - /*operands=*/ValueRange{dst, src}); - } else { - Value idx = peelUnrealized(adaptor.getIndexes()); - 
rewriter.create( - loc, TypeRange{}, "TSCATTER", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, idx}); - } - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSEL DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOSelToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TSelOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value mask = peelUnrealized(adaptor.getMask()); - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, mask, src0, src1, tmp}; - rewriter.create( - loc, TypeRange{}, "TSEL", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSELS DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOSelSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TSelSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value mask = peelUnrealized(adaptor.getMask()); - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, mask, src, tmp, scalar}; - rewriter.create( 
- loc, TypeRange{}, "TSELS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSHL DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOShlSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TShlOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TSHL", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSHR DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOShrSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TShrOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TSHR", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp 
(add lowering for TSHLS/TSHRS DPS: shift by scalar) -//===----------------------------------------------------------------------===// - -struct PTOShlSConstToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TShlSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value dst = peelUnrealized(adaptor.getDst()); - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - SmallVector operands{dst, src, scalar}; - rewriter.create( - loc, TypeRange{}, "TSHLS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOShrSConstToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TShrSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value dst = peelUnrealized(adaptor.getDst()); - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - SmallVector operands{dst, src, scalar}; - rewriter.create( - loc, TypeRange{}, "TSHRS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (TSORT32 DPS/memref op: ins(src, idx[, tmp]) outs(dst)) -//===----------------------------------------------------------------------===// - -struct PTOSORT32SToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TSort32Op op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = 
peelUnrealized(adaptor.getDst()); - Value idx = peelUnrealized(adaptor.getIdx()); - Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value(); - - SmallVector operands; - if (tmp) - operands.assign({dst, src, idx, tmp}); - else - operands.assign({dst, src, idx}); - rewriter.create( - loc, TypeRange{}, "TSORT32", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSQRT DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOSqrtSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TSqrtOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src}; - rewriter.create( - loc, TypeRange{}, "TSQRT", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSTORE_FP DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOStoreFPSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TStoreFPOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value fp = peelUnrealized(adaptor.getFp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, fp}; - rewriter.create( - loc, 
TypeRange{}, "TSTORE_FP", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSUB DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOSubSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TSubOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TSUB", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSUBC DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOSubCSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TSubCOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value src2 = peelUnrealized(adaptor.getSrc2()); - Value dst = peelUnrealized(adaptor.getDst()); - - // pto-isa does not provide NPU implementation for TSUBC yet. 
- // Decompose: dst = src0 - src1 + src2 - rewriter.create( - loc, TypeRange{}, "TSUB", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1}); - rewriter.create( - loc, TypeRange{}, "TADD", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, dst, src2}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSUBS DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOSubSSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TSubSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, scalar}; - rewriter.create( - loc, TypeRange{}, "TSUBS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSUBSC DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOSubSCToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TSubSCOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - // pto-isa does not provide NPU 
implementation for TSUBSC yet. - // Decompose: dst = src0 - scalar + src1 - rewriter.create( - loc, TypeRange{}, "TSUBS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, scalar}); - rewriter.create( - loc, TypeRange{}, "TADD", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, dst, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; - - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TXOR DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOXORToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TXorOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - Value tmp = peelUnrealized(adaptor.getTmp()); - SmallVector operands{dst, src0, src1, tmp}; - rewriter.create( - loc, TypeRange{}, "TXOR", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -struct PTOTTransToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TTransOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, tmp}; - rewriter.create( - loc, TypeRange{}, "TTRANS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; 
-//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TXORS DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOXORSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TXorSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, scalar, tmp}; - rewriter.create( - loc, TypeRange{}, "TXORS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - struct PTOPrintToTPRINT : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPrintOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - - SmallVector operands{src}; - rewriter.create( - loc, TypeRange{}, "TPRINT", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -// pto.print "format", %scalar -> PRINTF("format", scalar) -struct PTOPrintOpToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::PrintOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - std::string fmt = op.getFormat().str(); - if (fmt.empty()) - fmt = "%f"; - std::string quoted = "\""; - for (char c : fmt) { - if (c == '"' || c == '\\') - quoted += '\\'; - else if (c == 
'\n') - quoted += "\\n"; - else if (c == '\t') - quoted += "\\t"; - else - quoted += c; - } - quoted += "\""; - - Value scalar = peelUnrealized(adaptor.getScalar()); - auto argsAttr = rewriter.getArrayAttr( - {emitc::OpaqueAttr::get(ctx, quoted), - IntegerAttr::get(IndexType::get(ctx), 0)}); - rewriter.create( - loc, TypeRange{}, "cce::printf", - /*args=*/argsAttr, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{scalar}); - - rewriter.eraseOp(op); - return success(); - } -}; - -// pto.trap -> TRAP() -struct PTOTrapOpToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TrapOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - rewriter.create( - loc, TypeRange{}, "trap", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{}); - - rewriter.eraseOp(op); - return success(); - } -}; - -// ============================================================================= -// 2. 
BindTileOp Lowering (FIX: Trace back to physical address) -// ============================================================================= -struct PTOBindTileToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - struct TileBuildSpec { - std::string tileTypeStr; - bool useConstructor = false; - SmallVector constructorArgs; - }; - - static bool getIndexConst(Value v, int64_t &out) { - if (!v) - return false; - if (auto cst = v.getDefiningOp()) { - if (auto ia = dyn_cast(cst.getValue())) { - out = ia.getValue().getSExtValue(); - return true; - } - } - return false; - } - - static bool getTilePointerStrides(pto::TileBufConfigAttr configAttr, - Type elemTy, int64_t rows, int64_t cols, - int64_t &rowStride, - int64_t &colStride) { - if (rows == ShapedType::kDynamic || cols == ShapedType::kDynamic) - return false; - - int32_t blVal = 0; - if (auto blAttr = dyn_cast(configAttr.getBLayout())) - blVal = static_cast(blAttr.getValue()); - else if (auto intAttr = dyn_cast(configAttr.getBLayout())) - blVal = static_cast(intAttr.getInt()); - - int32_t slVal = 0; - if (auto slAttr = dyn_cast(configAttr.getSLayout())) - slVal = static_cast(slAttr.getValue()); - else if (auto intAttr = dyn_cast(configAttr.getSLayout())) - slVal = static_cast(intAttr.getInt()); - - bool boxed = slVal != 0; - int64_t innerRows = 1; - int64_t innerCols = 1; - if (boxed) { - int32_t fractal = 512; - if (auto frAttr = dyn_cast(configAttr.getSFractalSize())) - fractal = static_cast(frAttr.getInt()); - - unsigned elemBytes = pto::getPTOStorageElemByteSize(elemTy); - if (elemBytes == 0) - return false; - - switch (fractal) { - case 1024: - innerRows = 16; - innerCols = 16; - break; - case 32: - innerRows = 16; - innerCols = 2; - break; - case 512: - if (slVal == 1) { - innerRows = 16; - innerCols = 32 / elemBytes; - } else if (slVal == 2) { - innerRows = 32 / elemBytes; - innerCols = 16; - } else { - return false; - } - break; - default: - return false; - } - if 
(innerRows <= 0 || innerCols <= 0) - return false; - } - - if (!boxed) { - if (blVal == 1) { - rowStride = 1; - colStride = rows; - } else { - rowStride = cols; - colStride = 1; - } - return true; - } - - if (blVal == 1) { - if (slVal != 1) - return false; - rowStride = innerCols; - colStride = rows; - return true; - } - - rowStride = cols; - colStride = innerRows; - return true; - } - - LogicalResult matchAndRewrite(pto::BindTileOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - auto configAttr = op.getConfigAttr(); - auto viewSemantics = op->getAttrOfType("pto.view_semantics"); - bool isSubView = viewSemantics && viewSemantics.getValue() == "subview"; - - auto peelAllCasts = [](Value v) { - while (auto castOp = v.getDefiningOp()) - v = castOp.getOperand(0); - if (auto castOp = v.getDefiningOp()) - v = castOp.getOperand(); - return v; - }; - auto isTileLike = [](Value v) -> bool { - auto ot = dyn_cast(v.getType()); - if (!ot) - return false; - StringRef s = ot.getValue(); - return s.contains("Tile<") || s.contains("ConvTile<"); - }; - auto buildTileSpec = [&]() -> FailureOr { - auto resMrTy = dyn_cast(op.getType()); - if (!resMrTy) - return failure(); - - const char *roleTok = "TileType::Vec"; - if (auto asAttr = - dyn_cast_or_null(resMrTy.getMemorySpace())) { - switch (asAttr.getAddressSpace()) { - case pto::AddressSpace::VEC: - roleTok = "TileType::Vec"; - break; - case pto::AddressSpace::MAT: - roleTok = "TileType::Mat"; - break; - case pto::AddressSpace::LEFT: - roleTok = "TileType::Left"; - break; - case pto::AddressSpace::RIGHT: - roleTok = "TileType::Right"; - break; - case pto::AddressSpace::ACC: - roleTok = "TileType::Acc"; - break; - case pto::AddressSpace::BIAS: - roleTok = "TileType::Bias"; - break; - case pto::AddressSpace::SCALING: - roleTok = "TileType::Scaling"; - break; - case pto::AddressSpace::GM: - case pto::AddressSpace::Zero: - roleTok = 
"TileType::Vec"; - break; - } - } - - Type elemTy = resMrTy.getElementType(); - Type emitElemTy = getTypeConverter()->convertType(elemTy); - if (!emitElemTy) - return failure(); - auto emitElemOpaque = dyn_cast(emitElemTy); - if (!emitElemOpaque) - return failure(); - std::string elemTypeStr = emitElemOpaque.getValue().str(); - - if (resMrTy.getRank() < 2) - return failure(); - int64_t rows = resMrTy.getDimSize(0); - int64_t cols = resMrTy.getDimSize(1); - if (rows == ShapedType::kDynamic || cols == ShapedType::kDynamic) - return failure(); - - std::string blTok = "BLayout::RowMajor"; - if (auto blAttr = dyn_cast(configAttr.getBLayout())) { - if (static_cast(blAttr.getValue()) == 1) - blTok = "BLayout::ColMajor"; - } - pto::BLayout blayout = getTileBufBLayoutValue(configAttr); - - if (isSubView) { - auto subMrTy = dyn_cast(op.getSource().getType()); - auto subViewOp = op.getSource().getDefiningOp(); - if (subMrTy && subMrTy.getRank() >= 2 && subViewOp) { - int64_t subRows = subMrTy.getDimSize(0); - int64_t subCols = subMrTy.getDimSize(1); - SmallVector inheritedStrides; - int64_t inheritedOffset = ShapedType::kDynamic; - - if (!pto::isPTOFloat4PackedType(elemTy) && - subRows != ShapedType::kDynamic && - subCols != ShapedType::kDynamic && - succeeded(getStridesAndOffset(subMrTy, inheritedStrides, - inheritedOffset)) && - inheritedStrides.size() >= 2) { - int64_t childRowStride = 0; - int64_t childColStride = 0; - bool sameStrides = getTilePointerStrides( - configAttr, elemTy, subRows, subCols, childRowStride, - childColStride); - sameStrides = sameStrides && - inheritedStrides[0] == childRowStride && - inheritedStrides[1] == childColStride; - if (sameStrides) { - rows = subRows; - cols = subCols; - } - } - } - } - - std::string slTok = "SLayout::NoneBox"; - if (auto slAttr = dyn_cast(configAttr.getSLayout())) { - int32_t slVal = static_cast(slAttr.getValue()); - slTok = (slVal == 1) ? "SLayout::RowMajor" - : (slVal == 2) ? 
"SLayout::ColMajor" - : "SLayout::NoneBox"; - } - - int32_t fractal = 512; - if (auto frAttr = dyn_cast(configAttr.getSFractalSize())) - fractal = frAttr.getInt(); - - std::string padTok = "PadValue::Null"; - if (auto padAttr = dyn_cast(configAttr.getPad())) { - switch (static_cast(padAttr.getValue())) { - case 1: - padTok = "PadValue::Zero"; - break; - case 2: - padTok = "PadValue::Max"; - break; - case 3: - padTok = "PadValue::Min"; - break; - default: - padTok = "PadValue::Null"; - break; - } - } - - std::string compactTok = "CompactMode::Null"; - if (auto compactAttr = dyn_cast(configAttr.getCompactMode())) { - switch (static_cast(compactAttr.getValue())) { - case 1: - compactTok = "CompactMode::Normal"; - break; - case 2: - compactTok = "CompactMode::RowPlusOne"; - break; - default: - compactTok = "CompactMode::Null"; - break; - } - } - - std::string vrowTok, vcolTok; - bool useConstructor = false; - bool rowIsDynamic = false; - bool colIsDynamic = false; - SmallVector constructorArgs; - - Value vRow = op.getValidRow(); - Value vCol = op.getValidCol(); - Value vRowEmitC = adaptor.getValidRow(); - Value vColEmitC = adaptor.getValidCol(); - bool forceDynamicValid = op->hasAttr(kForceDynamicValidShapeAttrName); - int64_t cRow = 0, cCol = 0; - bool rowIsConst = vRow && getIndexConst(vRow, cRow); - bool colIsConst = vCol && getIndexConst(vCol, cCol); - - auto makeCtorDimValue = [&](Value emitted, int64_t fallback) -> Value { - if (emitted) - return emitted; - return makeEmitCIntConstant( - rewriter, loc, emitc::OpaqueType::get(ctx, "int32_t"), fallback); - }; - auto maybeScaleDynamicValid = [&](Value emitted, int dimIdx) -> Value { - if (!emitted || !pto::isPTOFloat4PackedType(elemTy)) - return emitted; - int packedDim = blayout == pto::BLayout::ColMajor ? 
0 : 1; - if (dimIdx != packedDim) - return emitted; - auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); - Value two = makeEmitCIntConstant(rewriter, loc, i32Ty, 2); - return rewriter.create(loc, i32Ty, emitted, two).getResult(); - }; - - if (forceDynamicValid) { - vrowTok = "-1"; - vcolTok = "-1"; - useConstructor = true; - constructorArgs.push_back( - makeCtorDimValue(maybeScaleDynamicValid(vRowEmitC, 0), - renderTileTemplateDim(rowIsConst ? cRow : rows, - elemTy, blayout, 0))); - constructorArgs.push_back( - makeCtorDimValue(maybeScaleDynamicValid(vColEmitC, 1), - renderTileTemplateDim(colIsConst ? cCol : cols, - elemTy, blayout, 1))); - } else { - if (rowIsConst) { - vrowTok = std::to_string( - renderTileTemplateDim(cRow, elemTy, blayout, 0)); - } else if (vRow) { - vrowTok = "-1"; - rowIsDynamic = true; - useConstructor = true; - } else { - vrowTok = std::to_string( - renderTileTemplateDim(rows, elemTy, blayout, 0)); - } - - if (colIsConst) { - vcolTok = std::to_string( - renderTileTemplateDim(cCol, elemTy, blayout, 1)); - } else if (vCol) { - vcolTok = "-1"; - colIsDynamic = true; - useConstructor = true; - } else { - vcolTok = std::to_string( - renderTileTemplateDim(cols, elemTy, blayout, 1)); - } - - if (useConstructor) { - if (rowIsDynamic && vRowEmitC) - constructorArgs.push_back(maybeScaleDynamicValid(vRowEmitC, 0)); - if (colIsDynamic && vColEmitC) - constructorArgs.push_back(maybeScaleDynamicValid(vColEmitC, 1)); - } - } - - std::string tileTypeStr = std::string("Tile<") + roleTok + ", " + - elemTypeStr + ", " + - std::to_string(renderTileTemplateDim( - rows, elemTy, blayout, 0)) + - ", " + - std::to_string(renderTileTemplateDim( - cols, elemTy, blayout, 1)) + - ", " + blTok + - ", " + vrowTok + ", " + vcolTok + ", " + slTok + - ", " + std::to_string(fractal) + ", " + padTok + - ", " + compactTok + - ">"; - return TileBuildSpec{tileTypeStr, useConstructor, constructorArgs}; - }; - - auto buildTileValue = [&](const TileBuildSpec &spec, - bool 
forceDeclaration = false) -> Value { - auto tileType = emitc::OpaqueType::get(ctx, spec.tileTypeStr); - if (spec.useConstructor && !forceDeclaration) { - return rewriter - .create(loc, tileType, spec.tileTypeStr, - ArrayAttr{}, ArrayAttr{}, - ValueRange(spec.constructorArgs)) - .getResult(0); - } - - return rewriter - .create(loc, tileType, emitc::OpaqueAttr::get(ctx, "")) - .getResult(); - }; - - auto emitElemTypeToString = [&](Type elemTy) -> std::string { - return getEmitCScalarTypeToken(elemTy); - }; - - auto buildIntegralAddress = [&](Value sourceValue) -> FailureOr { - auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); - auto rcU64 = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); - - Value rawPtr = sourceValue; - if (auto ot = dyn_cast(sourceValue.getType())) { - StringRef tyStr = ot.getValue(); - if (tyStr.contains("Tile<") || tyStr.contains("ConvTile<")) { - auto srcMrTy = dyn_cast(op.getSource().getType()); - if (!srcMrTy) - return failure(); - std::string elemTok = emitElemTypeToString(srcMrTy.getElementType()); - pto::AddressSpace as = pto::AddressSpace::GM; - if (auto asAttr = - dyn_cast_or_null(srcMrTy.getMemorySpace())) - as = asAttr.getAddressSpace(); - rawPtr = materializeTileDataValue(rewriter, loc, sourceValue, as, - elemTok); - } - } - - if (isSetFFTsPointerLikeType(rawPtr.getType())) { - return rewriter - .create(loc, u64Ty, "reinterpret_cast", - ArrayAttr{}, rcU64, ValueRange{rawPtr}) - .getResult(0); - } - - if (rawPtr.getType() == u64Ty) - return rawPtr; - return rewriter.create(loc, u64Ty, rawPtr).getResult(); - }; - - if (op.getSource().getDefiningOp()) { - FailureOr tileSpec = buildTileSpec(); - if (failed(tileSpec)) - return failure(); - rewriter.replaceOp(op, buildTileValue(*tileSpec)); - return success(); - } - - Value tileCandidate = peelAllCasts(adaptor.getSource()); - if (viewSemantics && viewSemantics.getValue() == "bitcast" && - isTileLike(tileCandidate)) { - FailureOr tileSpec = buildTileSpec(); - if 
(failed(tileSpec)) - return failure(); - Value dstTile = buildTileValue(*tileSpec); - FailureOr addr = buildIntegralAddress(tileCandidate); - if (failed(addr)) - return failure(); - - rewriter.create(loc, TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dstTile, *addr}); - rewriter.replaceOp(op, dstTile); - return success(); - } - - if (viewSemantics && viewSemantics.getValue() == "treshape" && - isTileLike(tileCandidate)) { - FailureOr tileSpec = buildTileSpec(); - if (failed(tileSpec)) - return failure(); - Value dstTile = buildTileValue(*tileSpec, /*forceDeclaration=*/true); - - rewriter.create(loc, TypeRange{}, "TRESHAPE", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dstTile, tileCandidate}); - rewriter.replaceOp(op, dstTile); - return success(); - } - - // Subview origins are kept distinct from generic tile rebinding: - // even when source/destination C++ tile types match, subview may carry - // shifted base address semantics and should materialize a fresh handle. - if (isSubView) { - FailureOr tileSpec = buildTileSpec(); - if (failed(tileSpec)) - return failure(); - Value dstTile = buildTileValue(*tileSpec); - FailureOr addr = buildIntegralAddress(tileCandidate); - if (failed(addr)) - return failure(); - - rewriter.create(loc, TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dstTile, *addr}); - rewriter.replaceOp(op, dstTile); - return success(); - } - - // Generic tile-to-tile rebind path: preserve the same backing storage and - // rebuild a sibling tile with updated metadata/valid dims. 
- if (isTileLike(tileCandidate)) { - FailureOr tileSpec = buildTileSpec(); - if (failed(tileSpec)) - return failure(); - - if (!tileSpec->useConstructor) { - if (auto srcTy = dyn_cast(tileCandidate.getType())) { - if (srcTy.getValue() == tileSpec->tileTypeStr) { - rewriter.replaceOp(op, tileCandidate); - return success(); - } - } - } - - Value dstTile = buildTileValue(*tileSpec); - FailureOr addr = buildIntegralAddress(tileCandidate); - if (failed(addr)) - return failure(); - - rewriter.create(loc, TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dstTile, *addr}); - rewriter.replaceOp(op, dstTile); - return success(); - } - - SmallVector physAddrs; - Value source = op.getSource(); - - while (auto castOp = source.getDefiningOp()) - source = castOp.getOperand(0); - - if (auto upstreamCast = source.getDefiningOp()) { - auto upstreamOperands = upstreamCast.getAddrs(); - physAddrs.append(upstreamOperands.begin(), upstreamOperands.end()); - } else { - physAddrs.push_back(adaptor.getSource()); - } - - Value vRow = op.getValidRow(); - Value vCol = op.getValidCol(); - - auto newCast = rewriter.create( - loc, op.getType(), physAddrs, vRow ? vRow : Value(), - vCol ? 
vCol : Value(), configAttr); - if (viewSemantics) - newCast->setAttr("pto.view_semantics", viewSemantics); - if (op->hasAttr(kForceDynamicValidShapeAttrName)) - newCast->setAttr(kForceDynamicValidShapeAttrName, - op->getAttr(kForceDynamicValidShapeAttrName)); - rewriter.replaceOp(op, newCast.getResult()); - - return success(); - } -}; - -struct PTOAllocTileToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::AllocTileOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - MLIRContext *ctx = rewriter.getContext(); - auto tileTy = cast(op.getResult().getType()); - auto tileTypeString = getEmitCTileTypeString(tileTy); - if (!tileTypeString) - return rewriter.notifyMatchFailure( - op, "only rank-2 alloc_tile handles can be converted to EmitC"); - - Type convertedTy = getTypeConverter()->convertType(tileTy); - if (!convertedTy) - convertedTy = emitc::OpaqueType::get(ctx, *tileTypeString); - - auto validShape = tileTy.getValidShape(); - bool hasDynamicValidDim = - llvm::any_of(validShape, [](int64_t dim) { return dim < 0; }); - bool useConstructor = hasDynamicValidDim; - - SmallVector constructorArgs; - if (useConstructor) { - Type elemTy = tileTy.getElementType(); - pto::BLayout blayout = getTileBufBLayoutValue(tileTy.getConfigAttr()); - auto maybeScaleDynamicValid = [&](Value emitted, int dimIdx) -> Value { - if (!emitted || !pto::isPTOFloat4PackedType(elemTy)) - return emitted; - int packedDim = blayout == pto::BLayout::ColMajor ? 
0 : 1; - if (dimIdx != packedDim) - return emitted; - auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); - Value two = makeEmitCIntConstant(rewriter, loc, i32Ty, 2); - return rewriter.create(loc, i32Ty, emitted, two) - .getResult(); - }; - - if (validShape.size() > 0 && validShape[0] < 0) { - Value validRow = adaptor.getValidRow(); - if (!validRow) - return rewriter.notifyMatchFailure( - op, "dynamic alloc_tile valid row must have an operand"); - if (validRow) - validRow = peelUnrealized(validRow); - constructorArgs.push_back(maybeScaleDynamicValid(validRow, 0)); - } - if (validShape.size() > 1 && validShape[1] < 0) { - Value validCol = adaptor.getValidCol(); - if (!validCol) - return rewriter.notifyMatchFailure( - op, "dynamic alloc_tile valid col must have an operand"); - if (validCol) - validCol = peelUnrealized(validCol); - constructorArgs.push_back(maybeScaleDynamicValid(validCol, 1)); - } - } - - Value tile; - if (useConstructor) { - tile = rewriter - .create( - loc, convertedTy, *tileTypeString, ArrayAttr{}, - ArrayAttr{}, ValueRange(constructorArgs)) - .getResult(0); - } else { - tile = - rewriter - .create( - loc, convertedTy, emitc::OpaqueAttr::get(ctx, "")) - .getResult(); - } - - Value addr = adaptor.getAddr(); - if (addr) { - addr = peelUnrealized(addr); - auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); - if (isa(addr.getType()) || - (isa(addr.getType()) && - cast(addr.getType()).getValue().ends_with("*"))) { - auto rcU64 = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); - addr = rewriter - .create(loc, u64Ty, "reinterpret_cast", - ArrayAttr{}, rcU64, - ValueRange{addr}) - .getResult(0); - } else if (addr.getType() != u64Ty) { - addr = rewriter.create(loc, u64Ty, addr).getResult(); - } - - rewriter.create(loc, TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{tile, addr}); - } - - rewriter.replaceOp(op, tile); - return success(); - } -}; - -static FailureOr -createEmitCTileVariable(ConversionPatternRewriter 
&rewriter, Location loc, - const TypeConverter *typeConverter, - pto::TileBufType tileTy) { - auto tileTypeString = getEmitCTileTypeString(tileTy); - if (!tileTypeString) - return failure(); - - Type convertedTy = typeConverter->convertType(tileTy); - if (!convertedTy) - convertedTy = emitc::OpaqueType::get(rewriter.getContext(), *tileTypeString); - - return rewriter - .create( - loc, convertedTy, emitc::OpaqueAttr::get(rewriter.getContext(), "")) - .getResult(); -} - -struct PTOTReshapeToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TReshapeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto tileTy = dyn_cast(op.getResult().getType()); - if (!tileTy) - return failure(); - - FailureOr dst = - createEmitCTileVariable(rewriter, op.getLoc(), getTypeConverter(), tileTy); - if (failed(dst)) - return failure(); - - Value src = peelUnrealized(adaptor.getSrc()); - if (auto castOp = src.getDefiningOp()) - src = castOp.getOperand(); - - rewriter.create(op.getLoc(), TypeRange{}, "TRESHAPE", - ArrayAttr{}, ArrayAttr{}, - ValueRange{*dst, src}); - rewriter.replaceOp(op, *dst); - return success(); - } -}; - -struct PTOBitcastToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::BitcastOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto dstTy = dyn_cast(op.getResult().getType()); - auto srcTy = dyn_cast(op.getSrc().getType()); - if (!dstTy || !srcTy) - return failure(); - - FailureOr dst = - createEmitCTileVariable(rewriter, op.getLoc(), getTypeConverter(), dstTy); - if (failed(dst)) - return failure(); - - Value src = peelUnrealized(adaptor.getSrc()); - if (auto castOp = src.getDefiningOp()) - src = castOp.getOperand(); - - pto::AddressSpace as = pto::AddressSpace::GM; - if (auto asAttr = - dyn_cast_or_null(srcTy.getMemorySpace())) - as = 
asAttr.getAddressSpace(); - std::string elemTok = getEmitCScalarTypeToken(srcTy.getElementType()); - - Value rawPtr = materializeTileDataValue(rewriter, op.getLoc(), src, as, elemTok); - auto u64Ty = emitc::OpaqueType::get(rewriter.getContext(), "uint64_t"); - Value addr = rawPtr; - if (isSetFFTsPointerLikeType(rawPtr.getType())) { - auto rcU64 = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(rewriter.getContext(), - "uint64_t")}); - addr = rewriter - .create(op.getLoc(), u64Ty, - "reinterpret_cast", ArrayAttr{}, - rcU64, ValueRange{rawPtr}) - .getResult(0); - } else if (addr.getType() != u64Ty) { - addr = rewriter.create(op.getLoc(), u64Ty, addr).getResult(); - } - - rewriter.create(op.getLoc(), TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{*dst, addr}); - rewriter.replaceOp(op, *dst); - return success(); - } -}; - -struct PTOMaterializeTileToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - static bool isTileLike(Value v) { - auto ot = dyn_cast(v.getType()); - if (!ot) - return false; - StringRef s = ot.getValue(); - return s.contains("Tile<") || s.contains("ConvTile<"); - } - - LogicalResult matchAndRewrite(pto::MaterializeTileOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - MLIRContext *ctx = rewriter.getContext(); - auto tileTy = cast(op.getResult().getType()); - auto tileTypeString = getEmitCTileTypeString(tileTy); - if (!tileTypeString) - return rewriter.notifyMatchFailure( - op, "only rank-2 tile_buf handles can be materialized to EmitC"); - - Type convertedTy = getTypeConverter()->convertType(tileTy); - if (!convertedTy) - convertedTy = emitc::OpaqueType::get(ctx, *tileTypeString); - - Value source = peelUnrealized(adaptor.getSource()); - if (auto castOp = source.getDefiningOp()) - source = castOp.getOperand(); - - auto viewSemantics = op->getAttrOfType("pto.view_semantics"); - bool forceDynamicValid = 
op->hasAttr(kForceDynamicValidShapeAttrName); - bool isReshape = viewSemantics && viewSemantics.getValue() == "treshape"; - bool isSubview = viewSemantics && viewSemantics.getValue() == "subview"; - bool sourceIsDeclaredTile = - op.getSource().getDefiningOp(); - - auto createTileValue = [&]() -> Value { - SmallVector constructorArgs; - bool useConstructor = false; - pto::BLayout blayout = getTileBufBLayoutValue(tileTy.getConfigAttr()); - Type elemTy = tileTy.getElementType(); - auto shape = tileTy.getShape(); - auto validShape = tileTy.getValidShape(); - - auto makeCtorDimValue = [&](Value emitted, int64_t fallback) -> Value { - if (emitted) - return emitted; - return makeEmitCIntConstant( - rewriter, loc, emitc::OpaqueType::get(ctx, "int32_t"), fallback); - }; - auto maybeScaleDynamicValid = [&](Value emitted, int dimIdx) -> Value { - if (!emitted || !pto::isPTOFloat4PackedType(elemTy)) - return emitted; - int packedDim = blayout == pto::BLayout::ColMajor ? 0 : 1; - if (dimIdx != packedDim) - return emitted; - auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); - Value two = makeEmitCIntConstant(rewriter, loc, i32Ty, 2); - return rewriter.create(loc, i32Ty, emitted, two).getResult(); - }; - auto fallbackDim = [&](int dimIdx) { - return renderTileTemplateDim(shape[dimIdx], elemTy, blayout, dimIdx); - }; - - if (forceDynamicValid) { - useConstructor = true; - constructorArgs.push_back(makeCtorDimValue( - maybeScaleDynamicValid(adaptor.getValidRow(), 0), fallbackDim(0))); - constructorArgs.push_back(makeCtorDimValue( - maybeScaleDynamicValid(adaptor.getValidCol(), 1), fallbackDim(1))); - } else { - if (validShape[0] == ShapedType::kDynamic) { - useConstructor = true; - constructorArgs.push_back(makeCtorDimValue( - maybeScaleDynamicValid(adaptor.getValidRow(), 0), fallbackDim(0))); - } - if (validShape[1] == ShapedType::kDynamic) { - useConstructor = true; - constructorArgs.push_back(makeCtorDimValue( - maybeScaleDynamicValid(adaptor.getValidCol(), 1), 
fallbackDim(1))); - } - } - - if (useConstructor) { - return rewriter - .create(loc, convertedTy, *tileTypeString, - ArrayAttr{}, ArrayAttr{}, - ValueRange(constructorArgs)) - .getResult(0); - } - - return rewriter - .create(loc, convertedTy, - emitc::OpaqueAttr::get(ctx, "")) - .getResult(); - }; - - if (!isSubview && !forceDynamicValid && isTileLike(source)) { - if (auto srcTy = dyn_cast(source.getType())) { - if (srcTy.getValue() == *tileTypeString) { - rewriter.replaceOp(op, source); - return success(); - } - } - } - - Value tile = createTileValue(); - if (sourceIsDeclaredTile) { - rewriter.replaceOp(op, tile); - return success(); - } - - if (isReshape && isTileLike(source)) { - rewriter.create(loc, TypeRange{}, "TRESHAPE", - ArrayAttr{}, ArrayAttr{}, - ValueRange{tile, source}); - rewriter.replaceOp(op, tile); - return success(); - } - - pto::AddressSpace as = pto::AddressSpace::GM; - if (auto asAttr = - dyn_cast_or_null(tileTy.getMemorySpace())) - as = asAttr.getAddressSpace(); - std::string elemTok = getEmitCScalarTypeToken(tileTy.getElementType()); - - Value rawPtr = source; - if (isTileLike(rawPtr)) - rawPtr = materializeTileDataValue(rewriter, loc, rawPtr, as, elemTok); - - auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); - Value addr = rawPtr; - if (isSetFFTsPointerLikeType(rawPtr.getType())) { - auto rcU64 = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); - addr = rewriter - .create(loc, u64Ty, "reinterpret_cast", - ArrayAttr{}, rcU64, - ValueRange{rawPtr}) - .getResult(0); - } else if (rawPtr.getType() != u64Ty) { - addr = rewriter.create(loc, u64Ty, rawPtr).getResult(); - } - - rewriter.create(loc, TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{tile, addr}); - rewriter.replaceOp(op, tile); - return success(); - } -}; - -// ============================================================================= -// Arith CmpI -> EmitC Cmp -// ============================================================================= 
-class ArithCmpIToEmitC : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - // 将 arith.cmpi 转换为 emitc.cmp - // 映射 Predicate: eq -> equal, slt -> less, etc. - emitc::CmpPredicate emitcPred = emitc::CmpPredicate::eq; - const bool isUnsignedPred = - op.getPredicate() == arith::CmpIPredicate::ult || - op.getPredicate() == arith::CmpIPredicate::ule || - op.getPredicate() == arith::CmpIPredicate::ugt || - op.getPredicate() == arith::CmpIPredicate::uge; - switch (op.getPredicate()) { - case arith::CmpIPredicate::eq: emitcPred = emitc::CmpPredicate::eq; break; - case arith::CmpIPredicate::ne: emitcPred = emitc::CmpPredicate::ne; break; - case arith::CmpIPredicate::slt: emitcPred = emitc::CmpPredicate::lt; break; - case arith::CmpIPredicate::sle: emitcPred = emitc::CmpPredicate::le; break; - case arith::CmpIPredicate::sgt: emitcPred = emitc::CmpPredicate::gt; break; - case arith::CmpIPredicate::sge: emitcPred = emitc::CmpPredicate::ge; break; - // ... 处理无符号比较 (ult, ule 等) ... - case arith::CmpIPredicate::ult: emitcPred = emitc::CmpPredicate::lt; break; - case arith::CmpIPredicate::ule: emitcPred = emitc::CmpPredicate::le; break; - case arith::CmpIPredicate::ugt: emitcPred = emitc::CmpPredicate::gt; break; - case arith::CmpIPredicate::uge: emitcPred = emitc::CmpPredicate::ge; break; - } - - Type resTy = getTypeConverter()->convertType(op.getType()); - if (!resTy) - return failure(); - - Value lhs = adaptor.getLhs(); - Value rhs = adaptor.getRhs(); - if (isUnsignedPred) { - Type opTy = op.getLhs().getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure( - op, "expected scalar integer or index operands"); - - const unsigned bitWidth = - intTy ? 
intTy.getWidth() : static_cast(kPTOIndexBitWidth); - if (bitWidth != 1) { - lhs = castSignlessIntToUnsignedSameWidth(rewriter, loc, lhs, bitWidth); - rhs = castSignlessIntToUnsignedSameWidth(rewriter, loc, rhs, bitWidth); - } - } - - rewriter.replaceOpWithNewOp( - op, - /*resultType=*/resTy, // i1 -> bool/i1 - emitcPred, - lhs, - rhs - ); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// Section Op Lowering -//===----------------------------------------------------------------------===// -static bool isA5NoSplitPipeOp(Operation *op) { - if (auto talloc = dyn_cast(op)) - return talloc.getSplit() == 0; - if (auto tpush = dyn_cast(op)) - return tpush.getSplit() == 0; - if (auto tpop = dyn_cast(op)) - return tpop.getSplit() == 0; - if (auto tfree = dyn_cast(op)) - return tfree.getSplit() == 0; - if (auto tpush = dyn_cast(op)) - return tpush.getSplit() == 0; - if (auto tpush = dyn_cast(op)) - return tpush.getSplit() == 0; - if (auto talloc = dyn_cast(op)) - return talloc.getSplit() == 0; - if (auto talloc = dyn_cast(op)) - return talloc.getSplit() == 0; - if (auto tpop = dyn_cast(op)) - return tpop.getSplit() == 0; - if (auto tpop = dyn_cast(op)) - return tpop.getSplit() == 0; - if (auto tfree = dyn_cast(op)) - return tfree.getSplit() == 0; - if (auto tfree = dyn_cast(op)) - return tfree.getSplit() == 0; - return false; -} - -static bool hasExplicitSubblockControl(Operation *op) { - bool hasControl = false; - op->walk([&](Operation *nested) { - if (isa(nested)) { - hasControl = true; - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - return hasControl; -} - -static bool needsA5NoSplitVectorGuard(Operation *op) { - auto arch = getTargetArch(op); - if (arch != PTOArch::A5) - return false; - bool isVectorScope = isa(op); - if (auto func = dyn_cast(op)) { - if (auto kernelKindAttr = - func->getAttrOfType( - FunctionKernelKindAttr::name)) { - isVectorScope = - 
kernelKindAttr.getKernelKind() == FunctionKernelKind::Vector; - } - } - if (!isVectorScope) - return false; - if (hasExplicitSubblockControl(op)) - return false; - - bool hasNoSplitPipe = false; - op->walk([&](Operation *nested) { - if (!isA5NoSplitPipeOp(nested)) - return WalkResult::advance(); - hasNoSplitPipe = true; - return WalkResult::interrupt(); - }); - return hasNoSplitPipe; -} - -template -struct SectionToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - std::string getMacroName() const { - if (std::is_same::value) - return "__DAV_CUBE__"; - if (std::is_same::value) - return "__DAV_VEC__"; - return "UNKNOWN_MACRO"; - } - - LogicalResult - matchAndRewrite(SectionOpTy op, typename SectionOpTy::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - bool needsNoSplitGuard = needsA5NoSplitVectorGuard(op.getOperation()); - - std::string startMacro = "\n#if defined(" + getMacroName() + ")"; - rewriter.create(loc, startMacro); - - if constexpr (std::is_same_v) { - // Vector mask is a global HW state and may be modified by previous kernels - // (or earlier sections). Reset it to a well-defined state for deterministic - // execution of VEC ops. 
- rewriter.create(loc, "set_mask_norm();"); - rewriter.create(loc, "set_vector_mask(-1, -1);"); - } - - if (needsNoSplitGuard) { - rewriter.create( - loc, "if (get_subblockid() == 0) {"); - } - - Block &innerBlock = op.getBody().front(); - if (!innerBlock.empty()) { - rewriter.inlineBlockBefore(&innerBlock, op.getOperation(), ValueRange{}); - } - - if (needsNoSplitGuard) - rewriter.create(loc, "}"); - - std::string endMacro = "#endif // " + getMacroName() + "\n"; - rewriter.create(loc, endMacro); - - rewriter.eraseOp(op); - - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// SCF Control-Flow Pre-Lowering -// -// EmitC translation supports `emitc.for`/`emitc.if` plus CFG-style -// `cf.br`/`cf.cond_br`. Upstream SCFToEmitC patterns only cover `scf.for` and -// `scf.if`, so we pre-lower some SCF ops into those supported forms. -//===----------------------------------------------------------------------===// - -namespace { - -static bool isTriviallyInlineableExecuteRegion(scf::ExecuteRegionOp op) { - Region &r = op.getRegion(); - if (!r.hasOneBlock()) - return false; - Block &b = r.front(); - return isa_and_nonnull(b.getTerminator()); -} - -static bool needsWholeFunctionSCFToCF(func::FuncOp func) { - bool needs = false; - func.walk([&](Operation *op) { - if (!isa(op)) - return WalkResult::advance(); - Operation *parentOp = op->getParentOp(); - - // `scf.execute_region` can legally appear in single-block parents. Only - // require whole-function SCFToCF if we need to lower it into CFG blocks - // (multi-block region / non-trivial terminators). 
- if (auto exec = dyn_cast(op)) { - if (parentOp && parentOp->hasTrait() && - !isTriviallyInlineableExecuteRegion(exec)) { - needs = true; - return WalkResult::interrupt(); - } - return WalkResult::advance(); - } - - if (parentOp && parentOp->hasTrait()) { - needs = true; - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - return needs; -} - -// scf.execute_region is semantically just an inlined region producing results -// via scf.yield. Inline it to the parent block to avoid extra lowering needs. -struct SCFExecuteRegionInline - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(scf::ExecuteRegionOp op, - PatternRewriter &rewriter) const override { - if (op.getRegion().empty()) - return rewriter.notifyMatchFailure(op, "expected non-empty region"); - - Block &innerBlock = op.getRegion().front(); - auto yield = dyn_cast(innerBlock.getTerminator()); - if (!yield) - return rewriter.notifyMatchFailure(op, "expected scf.yield terminator"); - - // Move the body operations before the execute_region op. - rewriter.inlineBlockBefore(&innerBlock, op.getOperation(), ValueRange{}); - - // Replace execute_region results with yielded values, then erase the yield. - rewriter.replaceOp(op, yield.getOperands()); - rewriter.eraseOp(yield); - return success(); - } -}; - -// Lower scf.execute_region into CFG blocks with cf.br/cf.cond_br by inlining the -// region blocks into the parent region and rewriting scf.yield to branch into a -// continuation block carrying results. -// -// Note: This requires the parent region to allow multiple blocks (e.g. the -// function body CFG region). For execute_region nested in single-block regions -// (scf.for/scf.if), run SCFToCF first to eliminate the single-block constraint. 
-struct SCFExecuteRegionToCF : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(scf::ExecuteRegionOp op, - PatternRewriter &rewriter) const override { - if (isTriviallyInlineableExecuteRegion(op)) - return rewriter.notifyMatchFailure(op, "trivially inlineable"); - - Operation *parentOp = op->getParentOp(); - if (parentOp && parentOp->hasTrait()) { - return rewriter.notifyMatchFailure( - op, "cannot lower scf.execute_region inside a single-block parent region"); - } - - if (op.getRegion().empty()) - return rewriter.notifyMatchFailure(op, "expected non-empty region"); - - Location loc = op.getLoc(); - Block *curBlock = op->getBlock(); - Region *parentRegion = curBlock->getParent(); - - // Split the parent block so we can branch to a continuation block with phi - // arguments for the execute_region results. - auto execIt = Block::iterator(op.getOperation()); - Block *continueBlock = rewriter.splitBlock(curBlock, std::next(execIt)); - - SmallVector contArgs; - contArgs.reserve(op.getNumResults()); - for (Type t : op.getResultTypes()) - contArgs.push_back(continueBlock->addArgument(t, loc)); - - for (auto it : llvm::enumerate(op.getResults())) - it.value().replaceAllUsesWith(contArgs[it.index()]); - - // Capture blocks before moving the region. - SmallVector movedBlocks; - movedBlocks.reserve(op.getRegion().getBlocks().size()); - for (Block &b : op.getRegion()) - movedBlocks.push_back(&b); - Block *entryBlock = &op.getRegion().front(); - - // Inline the execute_region blocks into the parent region right before the - // continuation block. - rewriter.inlineRegionBefore(op.getRegion(), *parentRegion, - continueBlock->getIterator()); - - // Replace all scf.yield terminators with a branch to the continuation. 
- for (Block *b : movedBlocks) { - auto yield = dyn_cast(b->getTerminator()); - if (!yield) - continue; - rewriter.setInsertionPoint(yield); - rewriter.create(loc, continueBlock, yield.getOperands()); - rewriter.eraseOp(yield); - } - - // Replace execute_region itself with a branch to the inlined entry block. - rewriter.setInsertionPoint(op); - rewriter.create(loc, entryBlock, ValueRange{}); - rewriter.eraseOp(op); - return success(); - } -}; - -// Lower scf.index_switch into CFG blocks with cf.cond_br/cf.br so that we can -// avoid `scf.if` result materialization quirks (and avoid relying on cf.switch, -// which is not supported by EmitC C++ translation). -struct SCFIndexSwitchToCF : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - static LogicalResult cloneYieldingBlockAndBranchTo( - PatternRewriter &rewriter, Location loc, Block &srcBlock, Block *destBlock, - Block *continueBlock) { - rewriter.setInsertionPointToEnd(destBlock); - - IRMapping mapping; - for (Operation &inner : srcBlock.without_terminator()) - rewriter.clone(inner, mapping); - - auto yield = dyn_cast(srcBlock.getTerminator()); - if (!yield) - return failure(); - - SmallVector yieldOperands; - yieldOperands.reserve(yield.getNumOperands()); - for (Value v : yield.getOperands()) - yieldOperands.push_back(mapping.lookupOrDefault(v)); - - rewriter.create(loc, continueBlock, yieldOperands); - return success(); - } - - static Block *splitBlockForContinuation(PatternRewriter &rewriter, - scf::IndexSwitchOp op) { - auto switchIt = Block::iterator(op.getOperation()); - return rewriter.splitBlock(op->getBlock(), std::next(switchIt)); - } - - static void addContinuationArguments(PatternRewriter &rewriter, - scf::IndexSwitchOp op, Location loc, - Block *continueBlock) { - SmallVector contArgs; - contArgs.reserve(op.getNumResults()); - for (Type type : op.getResultTypes()) - contArgs.push_back(continueBlock->addArgument(type, loc)); - for (auto result : llvm::enumerate(op.getResults())) 
- result.value().replaceAllUsesWith(contArgs[result.index()]); - } - - static void createIndexSwitchBlocks(PatternRewriter &rewriter, - Region *parentRegion, - Region::iterator insertPt, - unsigned numCases, - SmallVectorImpl &checkBlocks, - Block *&defaultBlock, - SmallVectorImpl &caseBlocks) { - checkBlocks.reserve(numCases); - caseBlocks.reserve(numCases); - for (unsigned i = 0; i < numCases; ++i) - checkBlocks.push_back(rewriter.createBlock(parentRegion, insertPt)); - defaultBlock = rewriter.createBlock(parentRegion, insertPt); - for (unsigned i = 0; i < numCases; ++i) - caseBlocks.push_back(rewriter.createBlock(parentRegion, insertPt)); - } - - static void populateIndexSwitchCheckBlocks( - PatternRewriter &rewriter, Location loc, Value selector, - ArrayRef cases, ArrayRef checkBlocks, - ArrayRef caseBlocks, Block *defaultBlock) { - for (unsigned i = 0; i < checkBlocks.size(); ++i) { - rewriter.setInsertionPointToEnd(checkBlocks[i]); - Value caseVal = rewriter.create(loc, cases[i]); - Value cond = rewriter.create( - loc, arith::CmpIPredicate::eq, selector, caseVal); - Block *falseDest = - (i + 1 < checkBlocks.size()) ? 
checkBlocks[i + 1] : defaultBlock; - rewriter.create(loc, cond, caseBlocks[i], ValueRange{}, - falseDest, ValueRange{}); - } - } - - LogicalResult matchAndRewrite(scf::IndexSwitchOp op, - PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - Operation *parentOp = op->getParentOp(); - if (parentOp && parentOp->hasTrait()) { - return rewriter.notifyMatchFailure( - op, "cannot lower scf.index_switch inside a single-block parent region"); - } - - Block *curBlock = op->getBlock(); - Region *parentRegion = curBlock->getParent(); - Block *continueBlock = splitBlockForContinuation(rewriter, op); - addContinuationArguments(rewriter, op, loc, continueBlock); - - unsigned numCases = op.getCases().size(); - auto insertPt = continueBlock->getIterator(); - - SmallVector checkBlocks; - SmallVector caseBlocks; - Block *defaultBlock = nullptr; - createIndexSwitchBlocks(rewriter, parentRegion, insertPt, numCases, - checkBlocks, defaultBlock, caseBlocks); - - Value selector = op.getArg(); - auto cases = op.getCases(); - populateIndexSwitchCheckBlocks(rewriter, loc, selector, cases, checkBlocks, - caseBlocks, defaultBlock); - - // Fill case blocks and default block with cloned bodies + branch to cont. - for (unsigned i = 0; i < numCases; ++i) { - if (failed(cloneYieldingBlockAndBranchTo( - rewriter, loc, op.getCaseBlock(i), caseBlocks[i], continueBlock))) - return rewriter.notifyMatchFailure(op, "expected scf.yield terminator"); - } - if (failed(cloneYieldingBlockAndBranchTo(rewriter, loc, op.getDefaultBlock(), - defaultBlock, continueBlock))) - return rewriter.notifyMatchFailure(op, "expected scf.yield terminator"); - - // Replace the original switch op with a branch into the check chain. - Block *entryDest = numCases ? checkBlocks[0] : defaultBlock; - rewriter.setInsertionPointAfter(op); - rewriter.create(loc, entryDest, ValueRange{}); - rewriter.eraseOp(op); - return success(); - } -}; - -// Lower scf.while into CFG blocks with cf.br/cf.cond_br. 
-// -// Note: This requires the parent region to allow multiple blocks. In -// particular, scf.if/scf.for regions are single-block and cannot contain this -// lowering. -struct SCFWhileToCF : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - static LogicalResult validateWhileResultUses(scf::WhileOp op) { - Block *parentBlock = op->getBlock(); - for (Value result : op.getResults()) { - for (OpOperand &use : result.getUses()) { - if (use.getOwner()->getBlock() != parentBlock) - return failure(); - } - } - return success(); - } - - static Block *splitAfterWhileBlock(PatternRewriter &rewriter, - scf::WhileOp op) { - auto whileIt = Block::iterator(op.getOperation()); - return rewriter.splitBlock(op->getBlock(), std::next(whileIt)); - } - - static void addWhileExitArguments(PatternRewriter &rewriter, scf::WhileOp op, - Location loc, Block *afterWhileBlock) { - SmallVector exitArgs; - exitArgs.reserve(op.getNumResults()); - for (Type type : op.getResultTypes()) - exitArgs.push_back(afterWhileBlock->addArgument(type, loc)); - for (auto result : llvm::enumerate(op.getResults())) - result.value().replaceAllUsesWith(exitArgs[result.index()]); - } - - static Block *createWhileHeaderBlock(PatternRewriter &rewriter, - scf::WhileOp op, Location loc, - Block *afterWhileBlock) { - SmallVector headerArgTypes; - for (Value init : op.getInits()) - headerArgTypes.push_back(init.getType()); - SmallVector headerArgLocs(headerArgTypes.size(), loc); - return rewriter.createBlock(afterWhileBlock->getParent(), - afterWhileBlock->getIterator(), headerArgTypes, - headerArgLocs); - } - - static Block *createWhileBodyBlock(PatternRewriter &rewriter, scf::WhileOp op, - Location loc, Block *afterWhileBlock) { - Block &afterRegionBlock = op.getAfter().front(); - SmallVector bodyArgTypes(afterRegionBlock.getArgumentTypes().begin(), - afterRegionBlock.getArgumentTypes().end()); - SmallVector bodyArgLocs(bodyArgTypes.size(), loc); - return 
rewriter.createBlock(afterWhileBlock->getParent(), - afterWhileBlock->getIterator(), bodyArgTypes, - bodyArgLocs); - } - - static void rewriteWhileTerminators(PatternRewriter &rewriter, Location loc, - Block *headerBlock, Block *bodyBlock, - Block *afterWhileBlock) { - auto condOp = cast(headerBlock->getTerminator()); - rewriter.setInsertionPoint(condOp); - rewriter.create(loc, condOp.getCondition(), - /*trueDest=*/bodyBlock, - /*trueOperands=*/condOp.getArgs(), - /*falseDest=*/afterWhileBlock, - /*falseOperands=*/condOp.getArgs()); - rewriter.eraseOp(condOp); - - auto yieldOp = cast(bodyBlock->getTerminator()); - rewriter.setInsertionPoint(yieldOp); - rewriter.create(loc, headerBlock, yieldOp.getOperands()); - rewriter.eraseOp(yieldOp); - } - - LogicalResult matchAndRewrite(scf::WhileOp op, - PatternRewriter &rewriter) const override { - Operation *parentOp = op->getParentOp(); - if (parentOp && parentOp->hasTrait()) { - return rewriter.notifyMatchFailure( - op, "cannot lower scf.while inside a single-block parent region"); - } - - if (failed(validateWhileResultUses(op))) - return rewriter.notifyMatchFailure( - op, "unsupported: while results used outside the parent block"); - - auto loc = op.getLoc(); - Block *afterWhileBlock = splitAfterWhileBlock(rewriter, op); - addWhileExitArguments(rewriter, op, loc, afterWhileBlock); - Block *headerBlock = createWhileHeaderBlock(rewriter, op, loc, - afterWhileBlock); - Block *bodyBlock = createWhileBodyBlock(rewriter, op, loc, afterWhileBlock); - - // Move the before/after region bodies into the new CFG blocks. - Block &afterRegionBlock = op.getAfter().front(); - rewriter.mergeBlocks(&op.getBefore().front(), headerBlock, - headerBlock->getArguments()); - rewriter.mergeBlocks(&afterRegionBlock, bodyBlock, bodyBlock->getArguments()); - rewriteWhileTerminators(rewriter, loc, headerBlock, bodyBlock, - afterWhileBlock); - - // Replace scf.while itself with a branch to the header. 
- rewriter.setInsertionPoint(op); - rewriter.create(loc, headerBlock, op.getInits()); - rewriter.eraseOp(op); - return success(); - } -}; - -// Lower cf.switch into chained comparisons and cf.cond_br/cf.br. -// -// EmitC C++ translation currently supports cf.br/cf.cond_br, but not cf.switch. -struct CFSwitchToCondBr : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - static SmallVector> - collectSwitchCaseOperands(cf::SwitchOp op) { - SmallVector> caseOperands; - caseOperands.reserve(op.getCaseDestinations().size()); - for (auto range : op.getCaseOperands()) - caseOperands.emplace_back(range.begin(), range.end()); - return caseOperands; - } - - static SmallVector getSwitchCaseValues(cf::SwitchOp op) { - SmallVector caseValues; - if (auto caseValuesAttr = op.getCaseValues()) { - for (APInt value : caseValuesAttr->getValues()) - caseValues.push_back(value); - } - return caseValues; - } - - static SmallVector createSwitchCheckBlocks(PatternRewriter &rewriter, - Region *parentRegion, - Block *curBlock, - size_t numCases) { - auto insertPt = std::next(curBlock->getIterator()); - SmallVector checkBlocks; - checkBlocks.reserve(numCases); - for (size_t i = 0; i < numCases; ++i) - checkBlocks.push_back(rewriter.createBlock(parentRegion, insertPt)); - return checkBlocks; - } - - static LogicalResult populateSwitchCheckBlocks( - PatternRewriter &rewriter, Location loc, Value flag, IntegerType flagTy, - ArrayRef caseValues, ArrayRef caseDests, - ArrayRef> caseOperands, Block *defaultDest, - ValueRange defaultOperands, ArrayRef checkBlocks, - cf::SwitchOp op) { - for (size_t i = 0; i < caseDests.size(); ++i) { - rewriter.setInsertionPointToEnd(checkBlocks[i]); - APInt caseVal = caseValues[i]; - if (caseVal.getBitWidth() != flagTy.getWidth()) { - return rewriter.notifyMatchFailure( - op, "case value bitwidth doesn't match flag type"); - } - - Value caseConst = rewriter.create( - loc, flagTy, rewriter.getIntegerAttr(flagTy, caseVal)); - Value cond = 
rewriter.create( - loc, arith::CmpIPredicate::eq, flag, caseConst); - Block *falseDest = - (i + 1 < checkBlocks.size()) ? checkBlocks[i + 1] : defaultDest; - ValueRange falseOperands = - (i + 1 < checkBlocks.size()) ? ValueRange{} : defaultOperands; - rewriter.create(loc, cond, caseDests[i], - caseOperands[i], falseDest, - falseOperands); - } - return success(); - } - - LogicalResult matchAndRewrite(cf::SwitchOp op, - PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - Operation *parentOp = op->getParentOp(); - if (parentOp && parentOp->hasTrait()) { - return rewriter.notifyMatchFailure( - op, "cannot lower cf.switch inside a single-block parent region"); - } - - Block *curBlock = op->getBlock(); - Region *parentRegion = curBlock->getParent(); - - Value flag = op.getFlag(); - auto flagTy = dyn_cast(flag.getType()); - if (!flagTy) - return rewriter.notifyMatchFailure(op, "expected integer switch flag"); - - SmallVector defaultOperands(op.getDefaultOperands().begin(), - op.getDefaultOperands().end()); - Block *defaultDest = op.getDefaultDestination(); - - SmallVector caseDests(op.getCaseDestinations().begin(), - op.getCaseDestinations().end()); - SmallVector> caseOperands = collectSwitchCaseOperands(op); - - if (caseDests.empty()) { - rewriter.replaceOpWithNewOp(op, defaultDest, defaultOperands); - return success(); - } - - if (!op.getCaseValues()) - return rewriter.notifyMatchFailure(op, "missing case_values"); - SmallVector caseValues = getSwitchCaseValues(op); - - if (caseValues.size() != caseDests.size()) - return rewriter.notifyMatchFailure(op, "case_values/destinations mismatch"); - if (caseOperands.size() != caseDests.size()) - return rewriter.notifyMatchFailure(op, "case_operands/destinations mismatch"); - - SmallVector checkBlocks = - createSwitchCheckBlocks(rewriter, parentRegion, curBlock, - caseDests.size()); - if (failed(populateSwitchCheckBlocks(rewriter, loc, flag, flagTy, - caseValues, caseDests, caseOperands, - defaultDest, 
defaultOperands, - checkBlocks, op))) { - return failure(); - } - - // Replace the switch terminator with a branch into the first check block. - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp(op, checkBlocks.front(), - ValueRange{}); - return success(); - } -}; - -} // namespace - -static void populatePTOToEmitCPatterns(RewritePatternSet &patterns, - TypeConverter &typeConverter, - MLIRContext *ctx, - DataFlowSolver &solver, - PTOArch targetArch) { - (void)solver; - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx, "pto.set_flag_dyn", - "set_flag"); - patterns.add(typeConverter, ctx, "pto.wait_flag_dyn", - "wait_flag"); - // Backward-compatible aliases used in some downstream branches. - patterns.add(typeConverter, ctx, "pto.set_flag_d", - "set_flag"); - patterns.add(typeConverter, ctx, "pto.wait_flag_d", - "wait_flag"); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - 
patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add( - typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - 
patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add( - typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add>( - typeConverter, ctx); - patterns.add>( - typeConverter, ctx); - patterns.add>( - typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add>(typeConverter, - ctx); - patterns.add>(typeConverter, - ctx); - patterns.add>(typeConverter, - ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add>(typeConverter, 
ctx); - patterns.add>(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add>(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add>(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add>( - typeConverter, ctx, - "pto::comm::TPUT_ASYNC"); - patterns.add>( - typeConverter, ctx, - "pto::comm::TGET_ASYNC"); - patterns.add>(typeConverter, ctx, - "pto::comm::TPUT"); - patterns.add>(typeConverter, ctx, - "pto::comm::TGET"); - patterns.add>(typeConverter, ctx, - "pto::comm::TNOTIFY"); - patterns.add>(typeConverter, ctx, - "pto::comm::TWAIT"); - patterns.add>(typeConverter, ctx, - "pto::comm::TTEST"); - patterns.add>(typeConverter, ctx, - "TBROADCAST"); - patterns.add>(typeConverter, ctx, - "TGATHER"); - patterns.add>(typeConverter, ctx, - "TSCATTER"); - patterns.add>(typeConverter, ctx, - "TREDUCE"); - patterns.add>( - typeConverter, ctx, "PTOAS__ASYNC_EVENT_WAIT"); - patterns.add>( - typeConverter, ctx, "PTOAS__ASYNC_EVENT_TEST"); 
- patterns.add(typeConverter, ctx, targetArch); - patterns.add(typeConverter, ctx, targetArch); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx, targetArch); - patterns.add(typeConverter, ctx, targetArch); - patterns.add(typeConverter, ctx, targetArch); - patterns.add(typeConverter, ctx, targetArch); - patterns.add(typeConverter, ctx, targetArch); - patterns.add(typeConverter, ctx, targetArch); - patterns.add>(typeConverter, ctx); - patterns.add>(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add< - PTOTMatmulBiasToTMATMUL_BIAS, - PTOTMatmulMXToTMATMUL_MX, - PTOTMatmulMXAccToTMATMUL_MX_ACC, - PTOTMatmulMXBiasToTMATMUL_MX_BIAS, - PTOTMatmulBiasToTMATMUL_BIAS, - PTOTMatmulMXToTMATMUL_MX, - PTOTMatmulMXAccToTMATMUL_MX_ACC, - PTOTMatmulMXBiasToTMATMUL_MX_BIAS, - PTOTGemvBiasToTGEMV_BIAS, - PTOTGemvMXToTGEMV_MX, - PTOTGemvMXAccToTGEMV_MX, - PTOTGemvMXBiasToTGEMV_MX, - PTOBarrierToEmitC - >(typeConverter, ctx); - - patterns.add(typeConverter, ctx); - - populateSCFToEmitCConversionPatterns(patterns); - // Keep CFG-style branches type-consistent when block argument types are - // converted (e.g. after lowering scf.while to cf.br/cf.cond_br). 
- populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter); -} - -//===----------------------------------------------------------------------===// -// Pass -//===----------------------------------------------------------------------===// - -namespace { -struct EmitPTOManualPass - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(EmitPTOManualPass) - - PTOArch targetArch; - - EmitPTOManualPass() : targetArch(PTOArch::A3) {} - - explicit EmitPTOManualPass(PTOArch arch) : targetArch(arch) {} - - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - void runOnOperation() override { - LLVM_DEBUG(llvm::dbgs() << "DEBUG: Start PTOToEmitC Pass\n"); - MLIRContext *ctx = &getContext(); - ModuleOp mop = getOperation(); - - if (failed(pto::validatePTOEntryFunctions(mop))) - return signalPassFailure(); - pto::annotatePTOEntryFunctions(mop); - - // A3 requires explicit FFTS base setup for inter-core sync ops. - if (targetArch == PTOArch::A3) { - bool hasMissingSetFFTs = false; - for (auto func : mop.getOps()) { - if (!hasInterCoreSyncOp(func)) - continue; - if (hasSetFFTsOp(func)) - continue; - hasMissingSetFFTs = true; - func.emitError() - << "A3 inter-core sync requires explicit `pto.set_ffts` in the " - "same function when using `pto.sync.set`/`pto.sync.wait`"; - } - if (hasMissingSetFFTs) - return signalPassFailure(); - } - - bool needsEventIdArrayHelper = false; - bool needsTRandomHelper = false; - bool needsGlobalTensorDataHelper = false; - bool needsCommInclude = false; - mop.walk([&](Operation *op) { - if (isa(op)) - needsEventIdArrayHelper = true; - if (isa(op)) - needsTRandomHelper = true; - if (isa(op)) - needsGlobalTensorDataHelper = true; - if (isa(op)) - needsCommInclude = true; - }); - - // 1. 
插入头文件 - auto loc = mop->getLoc(); - OpBuilder builder(ctx); - builder.setInsertionPointToStart(mop.getBody()); - builder.create( - loc, "pto/pto-inst.hpp", /*is_standard_include=*/false); - if (needsCommInclude) { - builder.create( - loc, builder.getStringAttr(R"cpp( -#ifndef PIPE_FIX -#define PIPE_FIX PIPE_M -#endif -)cpp")); - builder.create( - loc, "pto/comm/pto_comm_inst.hpp", /*is_standard_include=*/false); - } - builder.create( - loc, builder.getStringAttr("using namespace pto;")); - if (needsGlobalTensorDataHelper) { - builder.create( - loc, builder.getStringAttr(R"cpp( -template -static AICORE inline auto PTOAS__GLOBAL_TENSOR_DATA(Tensor &tensor) - -> decltype(tensor.data()) { - return tensor.data(); -} -)cpp")); - } - if (needsEventIdArrayHelper) { - builder.create( - loc, builder.getStringAttr(R"cpp( -template -struct PTOAS_EventIdArray { - static_assert(N > 0, "PTOAS_EventIdArray requires a positive static size"); - int32_t data[N] = {}; - - AICORE inline int32_t &operator[](int32_t idx) { return data[idx]; } - AICORE inline const int32_t &operator[](int32_t idx) const { return data[idx]; } -}; -)cpp")); - } - if (needsTRandomHelper) { - builder.create( - loc, builder.getStringAttr(R"cpp( -template -static AICORE inline void PTOAS__TRANDOM( - DstTile &dst, uint32_t key0, uint32_t key1, uint32_t counter0, - uint32_t counter1, uint32_t counter2, uint32_t counter3) { - TRandomKey key = {key0, key1}; - TRandomCounter counter = {counter0, counter1, counter2, counter3}; - TRANDOM(dst, key, counter); -} -)cpp")); - } - builder.create( - loc, builder.getStringAttr(R"cpp( -enum class PTOAutoSyncTailMode : int { - kBarrierAll = 0, - kSetWaitMte3ToSEvent0 = 1, -}; - -static AICORE inline void ptoas_auto_sync_tail( - PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { - switch (mode) { - case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: - set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); - wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); - break; - case 
PTOAutoSyncTailMode::kBarrierAll: - default: - pipe_barrier(PIPE_ALL); - break; - } -} -)cpp")); - // Only inject the bitcast helper when we actually lower ops that need it - // (e.g. arith.bitcast or arith.maximumf/minimumf tie-breaking on zeros). - bool needsBitcastHelper = false; - mop.walk([&](Operation *op) { - if (isa(op)) { - needsBitcastHelper = true; - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - if (needsBitcastHelper) { - builder.create( - loc, builder.getStringAttr(R"cpp( - template - static inline To ptoas_bitcast(From from) { - static_assert(sizeof(To) == sizeof(From), "ptoas_bitcast: size mismatch"); - To to; - __builtin_memcpy(&to, &from, sizeof(To)); - return to; - } - )cpp")); - } - - // 1.5 Pre-lower SCF constructs not handled by SCFToEmitC. - { - // scf.while / scf.index_switch are lowered via CFG blocks. This is not - // possible inside ops that require single-block regions (e.g. scf.for / - // scf.if). If we see such nesting, lower the entire function to the - // ControlFlow dialect first. - bool needsAnySCFToCF = false; - for (auto func : mop.getOps()) { - if (needsWholeFunctionSCFToCF(func)) { - needsAnySCFToCF = true; - break; - } - } - if (needsAnySCFToCF) { - RewritePatternSet scfToCfPatterns(ctx); - populateSCFToControlFlowConversionPatterns(scfToCfPatterns); - FrozenRewritePatternSet frozenSCFToCF(std::move(scfToCfPatterns)); - - ConversionTarget scfToCfTarget(*ctx); - // Only eliminate the single-block SCF constructs; we'll pre-lower - // scf.while/index_switch/execute_region ourselves afterwards. 
- scfToCfTarget.addIllegalOp(); - scfToCfTarget.markUnknownOpDynamicallyLegal( - [](Operation *) { return true; }); - - for (auto func : mop.getOps()) { - if (!needsWholeFunctionSCFToCF(func)) - continue; - if (failed(applyPartialConversion(func, scfToCfTarget, - frozenSCFToCF))) { - func.emitError() - << "failed to lower nested SCF to ControlFlow (SCFToCF)"; - return signalPassFailure(); - } - } - } - - RewritePatternSet scfLoweringPatterns(ctx); - scfLoweringPatterns.add(ctx); - (void)applyPatternsAndFoldGreedily(mop, std::move(scfLoweringPatterns)); - - bool hasUnsupportedSCF = false; - mop.walk([&](Operation *op) { - if (isa(op)) { - hasUnsupportedSCF = true; - op->emitError() << "Unsupported SCF op remained after pre-lowering"; - return WalkResult::interrupt(); - } - if (isa(op)) { - hasUnsupportedSCF = true; - op->emitError() - << "Unsupported CF op remained after pre-lowering: cf.switch"; - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - if (hasUnsupportedSCF) - return signalPassFailure(); - } - - PTOToEmitCTypeConverter typeConverter(ctx, targetArch); - - // 2. Pre-convert SCF structural op types (e.g. scf.if/scf.for results) - // using the same type converter. This avoids creating emitc.variable with - // unsupported types such as memref. - { - RewritePatternSet scfTypePatterns(ctx); - ConversionTarget scfTypeTarget(*ctx); - scf::populateSCFStructuralTypeConversionsAndLegality( - typeConverter, scfTypePatterns, scfTypeTarget); - scfTypeTarget.markUnknownOpDynamicallyLegal( - [](Operation *) { return true; }); - - if (failed(applyPartialConversion(mop, scfTypeTarget, - std::move(scfTypePatterns)))) { - mop.emitError("failed to reconcile SCF structural types"); - return signalPassFailure(); - } - } - - // 3. 配置转换目标 - ConversionTarget target(*ctx); - - target.addIllegalDialect(); - target.addIllegalDialect(); - target.addIllegalDialect(); - target.addIllegalDialect(); - - // If we introduced CFG branches (e.g. 
from scf.while), make sure they are - // updated to use legalized operand types. - target.addDynamicallyLegalOp( - [&](Operation *op) { - return isLegalForBranchOpInterfaceTypeConversionPattern(op, - typeConverter); - }); - - // [关键] 允许 Cast 存在,最后统一清理 - target.addLegalOp(); - - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - - target.addLegalDialect(); - target.addLegalOp(); - - auto solver = std::make_unique(); - solver->load(); - solver->load(); - if (failed(solver->initializeAndRun(getOperation()))) - return signalPassFailure(); - - RewritePatternSet patterns(ctx); - populatePTOToEmitCPatterns(patterns, typeConverter, ctx, *solver, targetArch); - - // 4. 执行转换 - if (failed(applyPartialConversion(mop, target, std::move(patterns)))) { - llvm::errs() << "Conversion FAILED! Rolling back executed.\n"; - return signalPassFailure(); - } - - // ========================================================================= - // 5. [终极清理] - // 顺序至关重要: - // Step A: 先移除所有 Cast,让 Loop 的 Operand 类型变成底层类型 (如 int32) - // Step B: 再根据新的 Operand 类型,修复 Loop IV 的类型 - // ========================================================================= - - // --- Step A: 清理 UnrealizedConversionCastOp --- - // Prefer dropping redundant/unused casts; otherwise lower to emitc.cast - // so the C++ emitter can print it. 
- auto isEmitCTileLikeType = [](Type ty) { - auto opaqueTy = dyn_cast(ty); - if (!opaqueTy) - return false; - StringRef value = opaqueTy.getValue(); - return value.contains("Tile<") || value.contains("ConvTile<"); - }; - - llvm::SmallVector castsToErase; - bool castCleanupFailed = false; - mop.walk([&](UnrealizedConversionCastOp cast) { - if (castCleanupFailed) - return; - - if (cast->getNumOperands() != 1 || cast->getNumResults() != 1) { - cast.emitError() << "unsupported unrealized_conversion_cast shape"; - castCleanupFailed = true; - return; - } - - Value input = cast.getOperand(0); - Value output = cast.getResult(0); - Type inTy = input.getType(); - Type outTy = output.getType(); - - if (output.use_empty()) { - castsToErase.push_back(cast); - return; - } - - if (inTy == outTy) { - output.replaceAllUsesWith(input); - castsToErase.push_back(cast); - return; - } - - // SCF/CFG type conversion can transiently materialize pointer->memref - // bridge casts. At this stage, the producing value is already in the - // lowered EmitC pointer form; keep it and drop the bridge cast. - if (isEmitCPointerLikeType(inTy) && isa(outTy)) { - output.replaceAllUsesWith(input); - castsToErase.push_back(cast); - return; - } - - // SCF structural type conversion may leave a bridge from the converted - // EmitC tile value back to the original pto.tile_buf type for PTO op - // users. After PTO ops are lowered, the EmitC tile value is the value we - // want to keep. 
- if (isEmitCTileLikeType(inTy) && isa(outTy)) { - output.replaceAllUsesWith(input); - castsToErase.push_back(cast); - return; - } - - if (emitc::isSupportedEmitCType(inTy) && emitc::isSupportedEmitCType(outTy)) { - OpBuilder builder(cast); - auto c = builder.create(cast.getLoc(), outTy, input); - output.replaceAllUsesWith(c.getResult()); - castsToErase.push_back(cast); - return; - } - - cast.emitError() << "cannot lower unrealized_conversion_cast(" << inTy - << " -> " << outTy << ") to emitc.cast"; - castCleanupFailed = true; - }); - - for (auto cast : castsToErase) - cast.erase(); - - if (castCleanupFailed) - return signalPassFailure(); - - // --- Step A2: Sink casts of emitc.variable "reads" to their use sites --- - // - // SCFToEmitC lowers scf.if/scf.for results via mutable `emitc.variable` and - // `emitc.assign`. During type conversion, casts from the variable handle to - // the converted type may be materialized right after the variable - // declaration, effectively snapshotting the value *before* assignments. That - // produces wrong C++ (use-before-init / stale reads). - // - // Fix by re-materializing the cast at each use site so it reads the variable - // at the point of use. - { - SmallVector castOpsToSink; - mop.walk([&](emitc::CastOp castOp) { - if (castOp.getSource().getDefiningOp()) - castOpsToSink.push_back(castOp); - }); - - for (emitc::CastOp castOp : castOpsToSink) { - Value src = castOp.getSource(); - Type dstTy = castOp.getResult().getType(); - Value oldRes = castOp.getResult(); - - // Replace each use with a freshly inserted cast right before the user. 
- for (OpOperand &use : llvm::make_early_inc_range(oldRes.getUses())) { - Operation *user = use.getOwner(); - OpBuilder b(user); - b.setInsertionPoint(user); - auto newCast = b.create(castOp.getLoc(), dstTy, src); - use.set(newCast.getResult()); - } - - castOp.erase(); - } - } - - // --- Step B: 修复 Loop 归纳变量 (IV) --- - // 此时 emitc.for 的 operand 已经是 int32 了,我们检查 IV 是否匹配,不匹配则修正 - mop.walk([&](emitc::ForOp forOp) { - Type boundTy = forOp.getLowerBound().getType(); - BlockArgument iv = forOp.getBody()->getArgument(0); - - if (iv.getType() != boundTy) { - iv.setType(boundTy); // 强制将 IV 类型 (index) 修改为与边界一致 (int32) - } - }); - - // --- Step C: 消除冗余 Tile 变量 (Dead Code Elimination) [新增] --- - // 逻辑:如果一个 emitc.variable 没有被读取(use_empty), - // 那么它自己,以及给它赋值的 TASSIGN 都可以删除。 - // 注意:TASSIGN(v15, v9) 会把 v15 作为 Operand 0 使用,所以 v15 不是严格的 use_empty。 - // 我们需要检查:v15 是否除了 TASSIGN 之外没有其他 User。 - - llvm::SmallVector deadVars; - mop.walk([&](emitc::VariableOp varOp) { - // 检查该变量的所有 User - bool isRead = false; - for (Operation* user : varOp.getResult().getUsers()) { - // 如果 User 是 TASSIGN 且变量是第0个参数(dst),不算"读取" - if (auto call = dyn_cast(user)) { - if (call.getCallee() == "TASSIGN" && call.getOperand(0) == varOp.getResult()) { - continue; // 这是一个赋值操作,不算有效使用 - } - } - // 如果还有其他用途(如 TLOAD, TMOV, TMATMUL),则该变量有用 - isRead = true; - break; - } - - if (!isRead) { - deadVars.push_back(varOp); - } - }); - - for (auto varOp : deadVars) { - // 1. 先删除所有使用该变量的 TASSIGN - llvm::SmallVector usersToErase; - for (Operation* user : varOp.getResult().getUsers()) { - // 我们上面已经确认过,剩下的 user 只能是 TASSIGN - usersToErase.push_back(user); - } - for (auto u : usersToErase) u->erase(); - - // 2. 
删除变量定义本身 - varOp.erase(); - } - - llvm::SmallVector deadConsts; - mop.walk([&](emitc::ConstantOp constOp) { - if (constOp.getResult().use_empty()) - deadConsts.push_back(constOp); - }); - for (auto constOp : deadConsts) - constOp.erase(); - - // ========================================================================= - } - }; -} // namespace - -std::unique_ptr mlir::pto::createEmitPTOManualPass() { - return std::make_unique(); -} - -std::unique_ptr mlir::pto::createEmitPTOManualPass(PTOArch arch) { - return std::make_unique(arch); -} diff --git a/lib/PTO/Transforms/PTOViewToMemref.cpp b/lib/PTO/Transforms/PTOViewToMemref.cpp index 521032476..c21669b81 100644 --- a/lib/PTO/Transforms/PTOViewToMemref.cpp +++ b/lib/PTO/Transforms/PTOViewToMemref.cpp @@ -6,5 +6,3610 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. +//===- PTOViewToMemref.cpp ------------------------------------------------===// +//===----------------------------------------------------------------------===// +// +// Lower PTO tile/view operations to memref-based IR while preserving tile +// metadata through binding ops and SSA backtracking. 
-#include "PTOViewToMemref.def" +#include "PTO/IR/PTO.h" +#include "PTO/IR/PTOTypeUtils.h" +#include "PTO/Transforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" + +#include "mlir/IR/AsmState.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_PTOVIEWTOMEMREF +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/raw_ostream.h" +#include "Utils.h" // 假设包含一些通用的工具函数 + +#include +#include +#include + +#define DEBUG_TYPE "pto-view-to-memref" + +using namespace mlir; + +namespace mlir { +namespace pto { + +static constexpr llvm::StringLiteral kLoweredSetValidShapeAttrName = + "__pto.lowered_set_validshape"; +static constexpr llvm::StringLiteral kForceDynamicValidShapeAttrName = + "__pto.force_dynamic_valid_shape"; + +namespace { + +static void markForceDynamicValidShape(Operation *op, bool force, + MLIRContext *ctx); + +static Type convertPTOTypeToMemRef(Type t); + +constexpr size_t kTileRank2D = 2; +constexpr size_t kRowDimensionIndex = 0; +constexpr size_t kColumnDimensionIndex = 1; +constexpr unsigned kShapeVectorInlineCapacity = 4; +constexpr unsigned kOperationVectorInlineCapacity = 8; + +constexpr int64_t kElementBytes1 = 1; +constexpr int64_t kElementBytes2 = 2; +constexpr int64_t kElementBytes4 = 4; +constexpr int64_t kElementBytes8 = 8; +constexpr int64_t kElementBytes16 = 16; +constexpr int64_t kElementBytes32 = 32; + +constexpr int64_t kInnerExtent1 = 1; +constexpr int64_t kInnerExtent2 = 2; +constexpr int64_t kInnerExtent4 = 4; +constexpr int64_t kInnerExtent8 = 8; +constexpr int64_t kInnerExtent16 = 16; +constexpr int64_t 
kInnerExtent32 = 32; + +constexpr int32_t kFractalSize32 = 32; +constexpr int32_t kFractalSize512 = 512; +constexpr int32_t kFractalSize1024 = 1024; + +constexpr int32_t kBLayoutColMajor = + static_cast(BLayout::ColMajor); +constexpr int32_t kSLayoutNoneBox = + static_cast(SLayout::NoneBox); +constexpr int32_t kSLayoutRowMajor = + static_cast(SLayout::RowMajor); +constexpr int32_t kSLayoutColMajor = + static_cast(SLayout::ColMajor); +constexpr int32_t kCompactModeRowPlusOne = + static_cast(CompactMode::RowPlusOne); + +constexpr unsigned kThirdOperandIndex = 2; +constexpr unsigned kFourthOperandIndex = 3; +constexpr unsigned kFifthOperandIndex = 4; +constexpr unsigned kSixthOperandIndex = 5; + +template +using SmallInlineVector = SmallVector; + +template +using DefaultInlineVector = SmallVector; + +// ============================================================================= +// Helper: Metadata Backtracking (核心机制) +// ============================================================================= +// 从一个 MemRef Value 向上回溯,找到它绑定的 TileBufConfig。 +// 这解决了 "Type Erasure" 问题:memref 类型本身不包含 config,但 SSA 定义链包含。 +static mlir::pto::TileBufConfigAttr lookupConfig(Value v) { + // 1. 最直接的情况:它就是 bind_tile 的结果 + if (auto bind = v.getDefiningOp()) { + return bind.getConfig(); + } + // PointerCastOp can also carry tile metadata (used when alloc_tile specifies + // an explicit address). + if (auto pc = v.getDefiningOp()) { + if (auto cfg = pc.getConfig()) + return *cfg; + return {}; + } + + // 2. 
穿透 View 操作 (SubView, Cast 等) 向上查找 + if (auto subview = v.getDefiningOp()) { + return lookupConfig(subview.getSource()); + } + if (auto cast = v.getDefiningOp()) { + return lookupConfig(cast.getSource()); + } + if (auto cast = v.getDefiningOp()) { + return lookupConfig(cast.getSource()); + } + + // 如果追溯到 BlockArgument (函数参数) 或其他无法穿透的 Op,则返回空 + return {}; +} + +// ============================================================================= +// Helper: Valid dims backtracking (v_row / v_col) +// ============================================================================= +static void lookupValidDims(Value v, Value &vRow, Value &vCol) { + if (auto bind = v.getDefiningOp()) { + vRow = bind.getValidRow(); + vCol = bind.getValidCol(); + return; + } + if (auto pc = v.getDefiningOp()) { + vRow = pc.getValidRow(); + vCol = pc.getValidCol(); + return; + } + if (auto subview = v.getDefiningOp()) { + lookupValidDims(subview.getSource(), vRow, vCol); + return; + } + if (auto cast = v.getDefiningOp()) { + lookupValidDims(cast.getSource(), vRow, vCol); + return; + } + if (auto cast = v.getDefiningOp()) { + lookupValidDims(cast.getSource(), vRow, vCol); + return; + } + vRow = Value(); + vCol = Value(); +} + +// ============================================================================= +// Helper Functions for Layout Normalization +// ============================================================================= + +struct TileLayoutInfo { + int64_t rowStride = 1; + int64_t colStride = 1; + int64_t innerRows = 1; + int64_t innerCols = 1; + bool boxed = false; // slayout != NoneBox +}; + +struct TileLayoutConfig { + int32_t bLayout = 0; + int32_t sLayout = 0; + int32_t fractalSize = kFractalSize512; + int32_t compactMode = 0; +}; + +static int64_t getElemBytes(Type elemTy) { + unsigned bytes = getPTOStorageElemByteSize(elemTy); + return bytes == 0 ? 
-1 : static_cast(bytes); +} + +template +static bool readEnumAttrOrIntegerI32(Attribute attr, int32_t &out) { + if (auto enumAttr = dyn_cast(attr)) { + out = static_cast(enumAttr.getValue()); + return true; + } + if (auto intAttr = dyn_cast(attr)) { + out = static_cast(intAttr.getInt()); + return true; + } + return false; +} + +static bool readBLayoutI32(Attribute attr, int32_t &out) { + return readEnumAttrOrIntegerI32(attr, out); +} + +static bool readSLayoutI32(Attribute attr, int32_t &out) { + return readEnumAttrOrIntegerI32(attr, out); +} + +static bool readCompactModeI32(Attribute attr, int32_t &out) { + return readEnumAttrOrIntegerI32(attr, out); +} + +static Value peelIndexLikeCast(Value value) { + while (true) { + if (auto castOp = value.getDefiningOp()) { + value = castOp.getIn(); + continue; + } + if (auto extOp = value.getDefiningOp()) { + value = extOp.getIn(); + continue; + } + if (auto extOp = value.getDefiningOp()) { + value = extOp.getIn(); + continue; + } + if (auto truncOp = value.getDefiningOp()) { + value = truncOp.getIn(); + continue; + } + return value; + } +} + +static bool getConstIndexValue(Value value, int64_t &out) { + value = peelIndexLikeCast(value); + if (auto constIndex = value.getDefiningOp()) { + out = constIndex.value(); + return true; + } + if (auto constInt = value.getDefiningOp()) { + out = constInt.value(); + return true; + } + auto constOp = value.getDefiningOp(); + auto intAttr = + constOp ? 
dyn_cast(constOp.getValue()) : IntegerAttr(); + if (!intAttr) + return false; + out = intAttr.getInt(); + return true; +} + +static TileLayoutConfig getTileLayoutConfig(mlir::pto::TileBufConfigAttr cfg) { + TileLayoutConfig config; + (void)readBLayoutI32(cfg.getBLayout(), config.bLayout); + (void)readSLayoutI32(cfg.getSLayout(), config.sLayout); + if (auto attr = dyn_cast(cfg.getSFractalSize())) + config.fractalSize = static_cast(attr.getInt()); + (void)readCompactModeI32(cfg.getCompactMode(), config.compactMode); + return config; +} + +static bool getFractal512InnerExtent(int64_t elemBytes, int64_t &extent) { + switch (elemBytes) { + case kElementBytes1: + extent = kInnerExtent32; + return true; + case kElementBytes2: + extent = kInnerExtent16; + return true; + case kElementBytes4: + extent = kInnerExtent8; + return true; + case kElementBytes8: + extent = kInnerExtent4; + return true; + case kElementBytes16: + extent = kInnerExtent2; + return true; + case kElementBytes32: + extent = kInnerExtent1; + return true; + default: + return false; + } +} + +static bool computeBoxInnerShape(const TileLayoutConfig &config, Type elemTy, + TileLayoutInfo &info) { + info.boxed = config.sLayout != kSLayoutNoneBox; + if (!info.boxed) { + info.innerRows = kInnerExtent1; + info.innerCols = kInnerExtent1; + return true; + } + + int64_t elemBytes = getElemBytes(elemTy); + if (elemBytes <= 0) + return false; + + switch (config.fractalSize) { + case kFractalSize1024: + info.innerRows = kInnerExtent16; + info.innerCols = kInnerExtent16; + return true; + case kFractalSize32: + info.innerRows = kInnerExtent16; + info.innerCols = kInnerExtent2; + return true; + case kFractalSize512: + if (config.sLayout == kSLayoutRowMajor) { + info.innerRows = kInnerExtent16; + return getFractal512InnerExtent(elemBytes, info.innerCols); + } + if (config.sLayout == kSLayoutColMajor) { + if (!getFractal512InnerExtent(elemBytes, info.innerRows)) + return false; + info.innerCols = kInnerExtent16; + return 
true; + } + return false; + default: + return false; + } +} + +static bool computeTilePointerStrides(const TileLayoutConfig &config, + ArrayRef shape, + TileLayoutInfo &info) { + int64_t rows = shape[0]; + int64_t cols = shape[1]; + auto applyCompactToMajorStride = [&](int64_t majorStride) -> int64_t { + if (config.compactMode == kCompactModeRowPlusOne) + return majorStride + kInnerExtent1; + return majorStride; + }; + if (!info.boxed) { + if (config.bLayout == kBLayoutColMajor) { + info.rowStride = kInnerExtent1; + info.colStride = applyCompactToMajorStride(rows); + return true; + } + info.rowStride = applyCompactToMajorStride(cols); + info.colStride = kInnerExtent1; + return true; + } + + if (config.bLayout == kBLayoutColMajor) { + if (config.sLayout != kSLayoutRowMajor) + return false; + info.rowStride = info.innerCols; + info.colStride = applyCompactToMajorStride(rows); + return true; + } + + info.rowStride = applyCompactToMajorStride(cols); + info.colStride = info.innerRows; + return true; +} + +static bool computeTileLayoutInfo(mlir::pto::TileBufConfigAttr cfg, Type elemTy, + ArrayRef shape, + TileLayoutInfo &info) { + if (shape.size() != kTileRank2D || + llvm::is_contained(shape, ShapedType::kDynamic)) + return false; + + TileLayoutConfig config = getTileLayoutConfig(cfg); + return computeBoxInnerShape(config, elemTy, info) && + computeTilePointerStrides(config, shape, info); +} + +static void collectAffineAddTerms(AffineExpr root, + SmallVectorImpl &terms) { + SmallInlineVector pending{root}; + while (!pending.empty()) { + AffineExpr current = pending.pop_back_val(); + auto addExpr = llvm::dyn_cast(current); + if (!addExpr || addExpr.getKind() != AffineExprKind::Add) { + terms.push_back(current); + continue; + } + pending.push_back(addExpr.getRHS()); + pending.push_back(addExpr.getLHS()); + } +} + +static bool tryAssignAffineStride(AffineExpr expr, + MutableArrayRef strides) { + if (auto dim = llvm::dyn_cast(expr)) { + strides[dim.getPosition()] = 1; + 
return true; + } + + auto mulExpr = llvm::dyn_cast(expr); + if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul) + return false; + + auto assignStride = [&](AffineExpr dimExpr, + AffineExpr constantExpr) -> bool { + auto dim = llvm::dyn_cast(dimExpr); + auto constant = llvm::dyn_cast(constantExpr); + if (!dim || !constant) + return false; + strides[dim.getPosition()] = constant.getValue(); + return true; + }; + return assignStride(mulExpr.getLHS(), mulExpr.getRHS()) || + assignStride(mulExpr.getRHS(), mulExpr.getLHS()); +} + +[[maybe_unused]] static void decomposeStridedLayout(AffineMap map, + SmallVectorImpl &strides) { + strides.assign(map.getNumDims(), 0); + if (map.getNumResults() != 1) + return; + + SmallInlineVector terms; + collectAffineAddTerms(map.getResult(0), terms); + for (AffineExpr term : terms) + (void)tryAssignAffineStride(term, strides); +} + +static Value makeIndexConstant(IRRewriter &rewriter, Location loc, + int64_t value) { + return rewriter.create(loc, rewriter.getIndexType(), + rewriter.getIndexAttr(value)); +} + +static SmallVector computeCompactStrides(ArrayRef shape) { + SmallVector strides(shape.size(), 1); + int64_t stride = 1; + for (int i = static_cast(shape.size()) - 1; i >= 0; --i) { + strides[i] = stride; + if (shape[i] != ShapedType::kDynamic) + stride *= shape[i]; + } + return strides; +} + +static void materializeStaticValidDims(IRRewriter &rewriter, Location loc, + mlir::pto::TileBufType tbTy, Value &vRow, + Value &vCol) { + ArrayRef validShape = tbTy.getValidShape(); + if (tbTy.hasDynamicValid()) + return; + if (!validShape.empty() && validShape[kRowDimensionIndex] >= 0) + vRow = makeIndexConstant(rewriter, loc, validShape[kRowDimensionIndex]); + if (validShape.size() >= kTileRank2D && + validShape[kColumnDimensionIndex] >= 0) + vCol = makeIndexConstant(rewriter, loc, validShape[kColumnDimensionIndex]); +} + +static bool checkMultipleOf(Operation *op, int64_t value, int64_t divisor, + StringRef label) { + if (divisor <= 0) 
{ + op->emitError("boxed layout requires positive divisor for ") << label; + return false; + } + if (value % divisor == 0) + return true; + op->emitError("boxed layout requires ") + << label << " multiple of " << divisor << ", got " << value; + return false; +} + +// 确保 Value 是 Index 类型 +static Value ensureIndex(IRRewriter &rewriter, Location loc, Value v, + Operation *anchorOp) { + if (v.getType().isIndex()) + return v; + if (isa(v.getType())) + return rewriter.create(loc, rewriter.getIndexType(), v); + if (anchorOp) + anchorOp->emitError() << "expected index or integer, but got " << v.getType(); + return Value(); +} + +static bool tryGetIndexAttrFromValue(IRRewriter &rewriter, Value v, + IntegerAttr &constAttr) { + if (auto cOp = v.getDefiningOp()) { + constAttr = rewriter.getIndexAttr(cOp.value()); + return true; + } + if (auto cInt = v.getDefiningOp()) { + constAttr = rewriter.getIndexAttr(cInt.value()); + return true; + } + return false; +} + +static void appendMixedIndex(IRRewriter &rewriter, Location loc, Value v, + Operation *anchorOp, + SmallVectorImpl &mixedVals) { + IntegerAttr constAttr; + if (tryGetIndexAttrFromValue(rewriter, v, constAttr)) { + mixedVals.push_back(constAttr); + return; + } + mixedVals.push_back(ensureIndex(rewriter, loc, v, anchorOp)); +} + +static bool foldAddPtrChainIntoOffset(IRRewriter &rewriter, Location loc, + Value &base, Value &totalOffset) { + bool folded = false; + while (auto add = base.getDefiningOp()) { + folded = true; + Value off = ensureIndex(rewriter, loc, add.getOperand(1), add); + totalOffset = + totalOffset ? 
rewriter.create(loc, totalOffset, off) : off; + base = add.getOperand(0); + } + return folded; +} + +static Value clampSubViewValidDim(IRRewriter &rewriter, Location loc, + Value explicitValid, int64_t size, + Operation *anchorOp) { + Value sizeVal = rewriter.create(loc, size); + if (!explicitValid) + return sizeVal; + + int64_t cst = 0; + if (getConstIndexValue(explicitValid, cst)) + return rewriter.create(loc, std::min(cst, size)); + + Value v = ensureIndex(rewriter, loc, explicitValid, anchorOp); + Value lt = rewriter.create(loc, arith::CmpIPredicate::slt, v, + sizeVal); + return rewriter.create(loc, lt, v, sizeVal); +} + +[[maybe_unused]] static void dumpPretty(Operation *op, llvm::raw_ostream &os) { + OpPrintingFlags flags; + flags.useLocalScope(); + AsmState state(op, flags); + op->print(os, state); + os << "\n"; + os.flush(); +} + +// ============================================================================= +// Type Converter Logic +// ============================================================================= + +static SmallVector buildTileMemRefStrides(mlir::pto::TileBufType tbTy) { + TileLayoutInfo info; + if (computeTileLayoutInfo(tbTy.getConfigAttr(), tbTy.getElementType(), + tbTy.getShape(), info)) { + return {info.rowStride, info.colStride}; + } + return computeCompactStrides(tbTy.getShape()); +} + +static Type convertTileBufTypeToMemRef(mlir::pto::TileBufType tbTy) { + auto layoutAttr = StridedLayoutAttr::get(tbTy.getContext(), + ShapedType::kDynamic, + buildTileMemRefStrides(tbTy)); + return MemRefType::get(tbTy.getShape(), tbTy.getElementType(), layoutAttr, + tbTy.getMemorySpace()); +} + +static Type convertPTOTypeToMemRef(Type t) { + // 1. 处理 !pto.ptr + if (auto pty = dyn_cast(t)) { + return MemRefType::get({ShapedType::kDynamic}, pty.getElementType(), + MemRefLayoutAttrInterface(), Attribute()); + } + + // 2. 
处理 !pto.tile_buf<...> + if (auto tbTy = dyn_cast(t)) + return convertTileBufTypeToMemRef(tbTy); + if (auto tvTy = dyn_cast(t)) + return MemRefType::get(tvTy.getShape(), tvTy.getElementType(), + MemRefLayoutAttrInterface(), Attribute()); + if (auto partTy = dyn_cast(t)) + return MemRefType::get(partTy.getShape(), partTy.getElementType(), + MemRefLayoutAttrInterface(), Attribute()); + // 其他类型透传 + return t; +} + +// Ensure scf.if result types follow the rewritten yield operand types. +// PTOViewToMemref rewrites tile values to memref in branch bodies, but scf.if +// result types are not auto-updated by those op-local rewrites. +static LogicalResult reconcileSCFIfResultTypes(func::FuncOp func) { + DefaultInlineVector ifOps; + func.walk([&](scf::IfOp ifOp) { ifOps.push_back(ifOp); }); + + for (scf::IfOp ifOp : ifOps) { + if (ifOp.getNumResults() == 0) + continue; + + auto thenYield = dyn_cast(ifOp.thenBlock()->getTerminator()); + auto elseYield = dyn_cast(ifOp.elseBlock()->getTerminator()); + if (!thenYield || !elseYield) { + ifOp.emitError("result-bearing scf.if must end with scf.yield in both " + "then/else regions"); + return failure(); + } + + if (thenYield.getNumOperands() != ifOp.getNumResults() || + elseYield.getNumOperands() != ifOp.getNumResults()) { + ifOp.emitError("scf.if result count does not match yielded values"); + return failure(); + } + + for (unsigned i = 0; i < ifOp.getNumResults(); ++i) { + Type thenTy = thenYield.getOperand(i).getType(); + Type elseTy = elseYield.getOperand(i).getType(); + if (thenTy != elseTy) { + ifOp.emitError() << "scf.if branch yield type mismatch at result #" << i + << ": then=" << thenTy << ", else=" << elseTy; + return failure(); + } + + if (ifOp.getResult(i).getType() != thenTy) + ifOp.getResult(i).setType(thenTy); + } + } + + return success(); +} + +static LogicalResult reconcileSCFForResultTypes(func::FuncOp func) { + DefaultInlineVector forOps; + func.walk([&](scf::ForOp forOp) { forOps.push_back(forOp); }); + + for 
(scf::ForOp forOp : forOps) { + if (forOp.getNumResults() == 0) + continue; + + auto yield = dyn_cast(forOp.getBody()->getTerminator()); + if (!yield) { + forOp.emitError("result-bearing scf.for must end with scf.yield"); + return failure(); + } + + if (yield.getNumOperands() != forOp.getNumResults() || + forOp.getInitArgs().size() != forOp.getNumResults()) { + forOp.emitError("scf.for result count does not match iter/yield values"); + return failure(); + } + + for (unsigned i = 0; i < forOp.getNumResults(); ++i) { + Type initTy = forOp.getInitArgs()[i].getType(); + Type yieldTy = yield.getOperand(i).getType(); + if (initTy != yieldTy) { + forOp.emitError() << "scf.for init/yield type mismatch at result #" << i + << ": init=" << initTy << ", yield=" << yieldTy; + return failure(); + } + + BlockArgument iterArg = forOp.getRegionIterArg(i); + if (iterArg.getType() != initTy) + iterArg.setType(initTy); + if (forOp.getResult(i).getType() != initTy) + forOp.getResult(i).setType(initTy); + } + } + + return success(); +} + +static LogicalResult markLoweredSetValidShapeOps(func::FuncOp func, + MLIRContext *ctx) { + WalkResult result = func.walk([&](mlir::pto::SetValidShapeOp op) { + if (isa(op.getSource().getType())) { + if (!lookupConfig(op.getSource())) { + op.emitError( + "set_validshape requires a locally bound tile source; function " + "arguments/results are unsupported"); + return WalkResult::interrupt(); + } + op->setAttr(kLoweredSetValidShapeAttrName, UnitAttr::get(ctx)); + return WalkResult::advance(); + } + op->removeAttr(kLoweredSetValidShapeAttrName); + return WalkResult::advance(); + }); + return result.wasInterrupted() ? 
failure() : success(); +} + +static void markForceDynamicValidShape(Operation *op, bool force, + MLIRContext *ctx) { + if (force) { + op->setAttr(kForceDynamicValidShapeAttrName, UnitAttr::get(ctx)); + return; + } + op->removeAttr(kForceDynamicValidShapeAttrName); +} + +[[maybe_unused]] static void rewriteFunctionSignature(func::FuncOp func, MLIRContext *ctx) { + Block &entry = func.front(); + auto fnTy = func.getFunctionType(); + + SmallVector newInputs; + for (Type type : fnTy.getInputs()) + newInputs.push_back(convertPTOTypeToMemRef(type)); + + SmallVector newResults; + for (Type type : fnTy.getResults()) + newResults.push_back(convertPTOTypeToMemRef(type)); + + for (unsigned i = 0; i < entry.getNumArguments(); ++i) { + if (entry.getArgument(i).getType() != newInputs[i]) + entry.getArgument(i).setType(newInputs[i]); + } + func.setFunctionType(FunctionType::get(ctx, newInputs, newResults)); +} + +[[maybe_unused]] static LogicalResult lowerAllocTileOps(func::FuncOp func, MLIRContext *ctx) { + DefaultInlineVector allocTiles; + func.walk([&](mlir::pto::AllocTileOp op) { allocTiles.push_back(op); }); + + for (auto op : allocTiles) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); + + auto tbTy = dyn_cast(op.getResult().getType()); + if (!tbTy) + continue; + + SmallInlineVector shape(tbTy.getShape().begin(), + tbTy.getShape().end()); + Type elemTy = tbTy.getElementType(); + SmallVector strides = buildTileMemRefStrides(tbTy); + + auto targetLayout = + StridedLayoutAttr::get(ctx, ShapedType::kDynamic, strides); + auto targetType = + MemRefType::get(shape, elemTy, targetLayout, tbTy.getMemorySpace()); + + Value vRow = op.getValidRow(); + Value vCol = op.getValidCol(); + materializeStaticValidDims(rewriter, loc, tbTy, vRow, vCol); + + auto configAttr = tbTy.getConfigAttr(); + if (!configAttr) + configAttr = pto::TileBufConfigAttr::getDefault(ctx); + + if (Value addr = op.getAddr()) { + auto pc = rewriter.create( + loc, 
targetType, ValueRange{addr}, vRow ? vRow : Value(), + vCol ? vCol : Value(), configAttr); + markForceDynamicValidShape(pc, tbTy.hasDynamicValid(), ctx); + auto bindOp = rewriter.create( + loc, targetType, pc.getResult(), vRow ? vRow : Value(), + vCol ? vCol : Value(), configAttr); + markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); + rewriter.replaceOp(op, bindOp.getResult()); + continue; + } + + auto allocLayout = StridedLayoutAttr::get(ctx, 0, strides); + auto allocType = + MemRefType::get(shape, elemTy, allocLayout, tbTy.getMemorySpace()); + Value alloc = rewriter.create(loc, allocType); + auto bindOp = rewriter.create( + loc, targetType, alloc, vRow ? vRow : Value(), vCol ? vCol : Value(), + configAttr); + markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); + rewriter.replaceOp(op, bindOp.getResult()); + } + return success(); +} + +[[maybe_unused]] static LogicalResult lowerDeclareTileOps(func::FuncOp func, MLIRContext *ctx) { + DefaultInlineVector declaredTiles; + func.walk([&](mlir::pto::DeclareTileOp op) { declaredTiles.push_back(op); }); + + for (auto op : declaredTiles) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); + + auto tbTy = dyn_cast(op.getTile().getType()); + if (!tbTy) { + op.emitError("declare_tile result must be tile_buf type"); + return failure(); + } + + auto targetType = dyn_cast(convertPTOTypeToMemRef(tbTy)); + if (!targetType) { + op.emitError("failed to convert declare_tile result to memref type"); + return failure(); + } + + auto configAttr = tbTy.getConfigAttr(); + if (!configAttr) + configAttr = pto::TileBufConfigAttr::getDefault(ctx); + + Value vRow; + Value vCol; + materializeStaticValidDims(rewriter, loc, tbTy, vRow, vCol); + + auto declaredMemRef = + rewriter.create(loc, targetType); + auto bindOp = rewriter.create( + loc, targetType, declaredMemRef.getResult(), vRow ? vRow : Value(), + vCol ? 
vCol : Value(), configAttr); + markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); + rewriter.replaceOp(op, bindOp.getResult()); + } + return success(); +} + +static Value castIndexToI64(IRRewriter &rewriter, Location loc, Value value) { + Type i64Ty = rewriter.getI64Type(); + if (value.getType() == i64Ty) + return value; + return rewriter.create(loc, i64Ty, value).getResult(); +} + +static FailureOr +materializePtrToIntAddPtrAddress(IRRewriter &rewriter, Location loc, + mlir::pto::PtrToIntOp anchor, Value source) { + SmallVector addPtrChain; + Value base = source; + while (auto add = base.getDefiningOp()) { + addPtrChain.push_back(add); + base = add.getOperand(0); + } + + if (addPtrChain.empty()) + return failure(); + + auto baseMemTy = dyn_cast(base.getType()); + if (!baseMemTy) { + anchor.emitOpError( + "pto.addptr source base could not be lowered to a GM memref"); + return failure(); + } + + Value byteAddress = rewriter.create( + loc, rewriter.getI64Type(), base); + for (auto add : addPtrChain) { + auto addPtrTy = dyn_cast(add.getResult().getType()); + if (!addPtrTy) { + anchor.emitOpError("requires pto.addptr source to have !pto.ptr result " + "type before byte-address lowering"); + return failure(); + } + + unsigned elemBytes = + mlir::pto::getPTOStorageElemByteSize(addPtrTy.getElementType()); + if (elemBytes == 0) { + anchor.emitOpError("cannot lower pto.addptr source with unknown element " + "byte size to a byte address"); + return failure(); + } + + Value byteOffset = castIndexToI64(rewriter, loc, add.getOffset()); + if (elemBytes != 1) { + Value elemBytesValue = + rewriter.create(loc, elemBytes, 64); + byteOffset = + rewriter.create(loc, byteOffset, elemBytesValue) + .getResult(); + } + byteAddress = + rewriter.create(loc, byteAddress, byteOffset).getResult(); + } + + return byteAddress; +} + +static LogicalResult lowerIntToPtrOps(func::FuncOp func, MLIRContext *ctx) { + DefaultInlineVector intToPtrs; + func.walk([&](mlir::pto::IntToPtrOp op) 
{ intToPtrs.push_back(op); }); + + for (auto op : intToPtrs) { + if (!isa(op.getResult().getType())) + continue; + + auto targetTy = + dyn_cast(convertPTOTypeToMemRef(op.getResult().getType())); + if (!targetTy) { + op.emitError("failed to convert inttoptr result to memref type"); + return failure(); + } + + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + auto lowered = + rewriter.create(op.getLoc(), targetTy, + op.getAddr()); + lowered->setAttrs(op->getAttrs()); + rewriter.replaceOp(op, lowered.getResult()); + } + + return success(); +} + +static LogicalResult lowerPtrToIntOps(func::FuncOp func, MLIRContext *ctx) { + DefaultInlineVector ptrToInts; + func.walk([&](mlir::pto::PtrToIntOp op) { ptrToInts.push_back(op); }); + + for (auto op : ptrToInts) { + Value source = op.getPtr(); + if (source.getDefiningOp()) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + FailureOr byteAddress = + materializePtrToIntAddPtrAddress(rewriter, op.getLoc(), op, source); + if (failed(byteAddress)) + return failure(); + rewriter.replaceOp(op, *byteAddress); + continue; + } + + if (isa(source.getType())) + continue; + } + + DefaultInlineVector remaining; + func.walk([&](mlir::pto::PtrToIntOp op) { + if (isa(op.getPtr().getType())) + remaining.push_back(op); + }); + for (auto op : remaining) { + op.emitError("ptrtoint source could not be lowered to a GM memref"); + return failure(); + } + + return success(); +} + +[[maybe_unused]] static LogicalResult lowerMakeTensorViewOps(func::FuncOp func, MLIRContext *ctx) { + DefaultInlineVector makeViews; + func.walk([&](mlir::pto::MakeTensorViewOp op) { makeViews.push_back(op); }); + + for (auto op : makeViews) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); + + Value baseBuf = op.getOperand(0); + OpFoldResult off0 = rewriter.getIndexAttr(0); + bool foldedAddPtr = false; + { + Value cur = baseBuf; + Value totalOffset; + while (auto add = cur.getDefiningOp()) { + 
foldedAddPtr = true; + Value off = ensureIndex(rewriter, loc, add.getOperand(1), add); + totalOffset = totalOffset ? rewriter.create(loc, totalOffset, off) + : off; + cur = add.getOperand(0); + } + if (cur != baseBuf) { + baseBuf = cur; + off0 = totalOffset ? OpFoldResult(totalOffset) : off0; + } + } + + auto baseMr = dyn_cast(baseBuf.getType()); + if (!baseMr) { + op.emitError("make_tensor_view base must be memref"); + return failure(); + } + + size_t rank = op.getShape().size(); + int64_t dyn = ShapedType::kDynamic; + SmallVector dynStrides(rank, dyn); + auto layout = + StridedLayoutAttr::get(ctx, /*offset=*/dyn, /*strides=*/dynStrides); + SmallVector dynShape(rank, dyn); + auto mrTy = MemRefType::get(dynShape, baseMr.getElementType(), layout, + baseMr.getMemorySpace()); + + SmallInlineVector sizes; + for (Value value : op.getShape()) + sizes.push_back(ensureIndex(rewriter, loc, value, op)); + SmallInlineVector strides; + for (Value value : op.getStrides()) + strides.push_back(ensureIndex(rewriter, loc, value, op)); + + auto rc = rewriter.create(loc, mrTy, baseBuf, off0, + sizes, strides); + if (foldedAddPtr) + rc->setAttr("pto.addptr_trace", rewriter.getUnitAttr()); + if (auto layoutAttr = op.getLayoutAttr()) + rc->setAttr("layout", layoutAttr); + rewriter.replaceOp(op, rc.getResult()); + } + return success(); +} + +[[maybe_unused]] static LogicalResult lowerTensorViewDimOps(func::FuncOp func, MLIRContext *ctx) { + DefaultInlineVector tvDims; + func.walk([&](mlir::pto::GetTensorViewDimOp op) { tvDims.push_back(op); }); + + for (auto op : tvDims) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Value view = op.getTensorView(); + auto mrTy = dyn_cast(view.getType()); + if (!mrTy) + continue; + Value dim = rewriter.create(op.getLoc(), view, op.getDimIndex()); + rewriter.replaceOp(op, dim); + } + return success(); +} + +[[maybe_unused]] static LogicalResult foldAddPtrIntoScalarOps(func::FuncOp func, MLIRContext *ctx) { + DefaultInlineVector 
loadScalars; + func.walk([&](mlir::pto::LoadScalarOp op) { loadScalars.push_back(op); }); + for (auto op : loadScalars) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); + + Value base = op.getPtr(); + Value totalOffset = ensureIndex(rewriter, loc, op.getOffset(), op); + bool foldedAddPtr = foldAddPtrChainIntoOffset(rewriter, loc, base, totalOffset); + if (foldedAddPtr) { + auto newOp = + rewriter.create(loc, op.getValue().getType(), base, + totalOffset); + rewriter.replaceOp(op, newOp.getValue()); + } + } + + DefaultInlineVector storeScalars; + func.walk([&](mlir::pto::StoreScalarOp op) { storeScalars.push_back(op); }); + for (auto op : storeScalars) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); + + Value base = op.getPtr(); + Value totalOffset = ensureIndex(rewriter, loc, op.getOffset(), op); + bool foldedAddPtr = foldAddPtrChainIntoOffset(rewriter, loc, base, totalOffset); + if (foldedAddPtr) { + rewriter.create(loc, base, totalOffset, op.getValue()); + rewriter.eraseOp(op); + } + } + + DefaultInlineVector addPtrs; + func.walk([&](mlir::pto::AddPtrOp op) { addPtrs.push_back(op.getOperation()); }); + bool changed = true; + while (changed) { + changed = false; + for (auto &op : addPtrs) { + if (!op) + continue; + if (op->use_empty()) { + op->erase(); + op = nullptr; + changed = true; + } + } + } + for (Operation *op : addPtrs) { + if (!op) + continue; + op->emitError( + "addptr must feed make_tensor_view or load/store_scalar for lowering"); + return failure(); + } + return success(); +} + +static LogicalResult lowerPartitionViewOps(func::FuncOp func, MLIRContext *ctx) { + DefaultInlineVector partitionViews; + func.walk([&](mlir::pto::PartitionViewOp op) { partitionViews.push_back(op); }); + + for (auto op : partitionViews) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); + Value src = op.getOperand(0); + auto srcMrTy = 
dyn_cast(src.getType()); + if (!srcMrTy) + continue; + int64_t rank = srcMrTy.getRank(); + + SmallVector staticSizes; + SmallVector mixedSizes; + for (Value size : op.getSizes()) { + IntegerAttr constAttr; + bool isStatic = false; + if (auto cOp = size.getDefiningOp()) { + constAttr = rewriter.getIndexAttr(cOp.value()); + isStatic = true; + } else if (auto cInt = size.getDefiningOp()) { + constAttr = rewriter.getIndexAttr(cInt.value()); + isStatic = true; + } + + if (isStatic) { + mixedSizes.push_back(constAttr); + staticSizes.push_back(constAttr.getInt()); + } else { + mixedSizes.push_back(ensureIndex(rewriter, loc, size, op)); + staticSizes.push_back(ShapedType::kDynamic); + } + } + + SmallVector mixedOffsets; + for (Value offset : op.getOffsets()) { + appendMixedIndex(rewriter, loc, offset, op, mixedOffsets); + } + + int64_t dyn = ShapedType::kDynamic; + SmallVector dynStrides(rank, dyn); + auto layout = StridedLayoutAttr::get(ctx, dyn, dynStrides); + auto resTy = MemRefType::get(staticSizes, srcMrTy.getElementType(), layout, + srcMrTy.getMemorySpace()); + + SmallVector mixedStrides(rank, rewriter.getIndexAttr(1)); + auto sv = rewriter.create(loc, resTy, src, mixedOffsets, + mixedSizes, mixedStrides); + if (Operation *srcDef = src.getDefiningOp()) { + if (auto layoutAttr = srcDef->getAttrOfType("layout")) + sv->setAttr("layout", layoutAttr); + } + rewriter.replaceOp(op, sv.getResult()); + } + return success(); +} + +static LogicalResult lowerSubViewOps(func::FuncOp func, MLIRContext *ctx) { + DefaultInlineVector subViews; + func.walk([&](mlir::pto::SubViewOp op) { subViews.push_back(op); }); + + for (auto op : subViews) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); + auto resultTileTy = + dyn_cast(op.getResult().getType()); + Value src = op->getOperand(0); + auto srcMrTy = dyn_cast(src.getType()); + if (!srcMrTy) { + op.emitError("pto.subview source must be lowered to memref first"); + return failure(); + } + + 
ArrayAttr sizeAttr = op.getSizes(); + SmallVector staticSizes; + SmallVector mixedSizes; + for (Attribute attr : sizeAttr) { + int64_t size = cast(attr).getInt(); + staticSizes.push_back(size); + mixedSizes.push_back(rewriter.getIndexAttr(size)); + } + + SmallVector mixedOffsets; + for (Value offset : op.getOffsets()) { + appendMixedIndex(rewriter, loc, offset, op, mixedOffsets); + } + + auto configAttr = lookupConfig(src); + if (!configAttr) + configAttr = pto::TileBufConfigAttr::getDefault(ctx); + + TileLayoutInfo layoutInfo; + if (!computeTileLayoutInfo(configAttr, srcMrTy.getElementType(), + srcMrTy.getShape(), layoutInfo)) { + op.emitError("unsupported tile layout for pto.subview"); + return failure(); + } + + if (layoutInfo.boxed) { + if (staticSizes.size() != kTileRank2D || + op.getOffsets().size() != kTileRank2D) { + op.emitError("boxed layout subview expects 2D sizes/offsets"); + return failure(); + } + if (!checkMultipleOf(op, staticSizes[0], layoutInfo.innerRows, "row size") || + !checkMultipleOf(op, staticSizes[1], layoutInfo.innerCols, "col size")) { + return failure(); + } + + int64_t off0 = 0; + int64_t off1 = 0; + bool off0Const = getConstIndexValue(op.getOffsets()[0], off0); + bool off1Const = getConstIndexValue(op.getOffsets()[1], off1); + if (off0Const && + !checkMultipleOf(op, off0, layoutInfo.innerRows, "row offset")) { + return failure(); + } + if (off1Const && + !checkMultipleOf(op, off1, layoutInfo.innerCols, "col offset")) { + return failure(); + } + + } + + SmallVector srcStrides; + int64_t srcOffset = ShapedType::kDynamic; + if (failed(getStridesAndOffset(srcMrTy, srcStrides, srcOffset))) + srcStrides = computeCompactStrides(srcMrTy.getShape()); + + // Keep parent physical shape + strides for bound tile semantics. 
+ auto resultLayout = + StridedLayoutAttr::get(ctx, ShapedType::kDynamic, srcStrides); + auto parentShape = srcMrTy.getShape(); + auto resultMemRefType = + MemRefType::get(parentShape, srcMrTy.getElementType(), resultLayout, + srcMrTy.getMemorySpace()); + + // Intermediate memref.subview keeps logical subview size. + auto subViewMemRefType = + MemRefType::get(staticSizes, srcMrTy.getElementType(), resultLayout, + srcMrTy.getMemorySpace()); + + SmallVector mixedStrides(staticSizes.size(), + rewriter.getIndexAttr(1)); + auto sv = rewriter.create(loc, subViewMemRefType, src, + mixedOffsets, mixedSizes, + mixedStrides); + + Value vRow; + Value vCol; + if (!staticSizes.empty()) + vRow = clampSubViewValidDim(rewriter, loc, op.getValidRow(), + staticSizes[0], op); + if (staticSizes.size() > 1) + vCol = clampSubViewValidDim(rewriter, loc, op.getValidCol(), + staticSizes[1], op); + + auto bindOp = rewriter.create( + loc, resultMemRefType, sv.getResult(), vRow ? vRow : Value(), + vCol ? vCol : Value(), configAttr); + markForceDynamicValidShape(bindOp, + resultTileTy && resultTileTy.hasDynamicValid(), + ctx); + bindOp->setAttr("pto.view_semantics", rewriter.getStringAttr("subview")); + rewriter.replaceOp(op, bindOp.getResult()); + } + return success(); +} + +static Value buildTileBufViewLikeValue(Operation *anchorOp, Value src, + mlir::pto::TileBufType tbTy, + StringRef viewSemantics, + MLIRContext *ctx) { + Location loc = anchorOp->getLoc(); + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(anchorOp); + + auto srcMrTy = dyn_cast(src.getType()); + if (!srcMrTy) { + anchorOp->emitError("tile_buf view op src must be lowered to memref first"); + return Value(); + } + + auto targetType = dyn_cast(convertPTOTypeToMemRef(tbTy)); + if (!targetType) { + anchorOp->emitError("failed to convert tile_buf type to memref type"); + return Value(); + } + for (int64_t dim : targetType.getShape()) { + if (dim == ShapedType::kDynamic) { + anchorOp->emitError("dynamic shapes are not 
supported for tile_buf view ops"); + return Value(); + } + } + + Value parentVRow; + Value parentVCol; + lookupValidDims(src, parentVRow, parentVCol); + Value vRow = parentVRow; + Value vCol = parentVCol; + materializeStaticValidDims(rewriter, loc, tbTy, vRow, vCol); + + auto configAttr = tbTy.getConfigAttr(); + if (!configAttr) + configAttr = pto::TileBufConfigAttr::getDefault(ctx); + + auto bindOp = rewriter.create( + loc, targetType, src, vRow ? vRow : Value(), vCol ? vCol : Value(), + configAttr); + markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); + if (!viewSemantics.empty()) + bindOp->setAttr("pto.view_semantics", rewriter.getStringAttr(viewSemantics)); + return bindOp.getResult(); +} + +static LogicalResult lowerTileBufViewLikeOps(func::FuncOp func, MLIRContext *ctx) { + DefaultInlineVector reshapes; + func.walk([&](mlir::pto::TReshapeOp op) { reshapes.push_back(op); }); + for (auto op : reshapes) { + auto tbTy = dyn_cast(op.getResult().getType()); + if (!tbTy) { + op.emitError("treshape result must be tile_buf type"); + return failure(); + } + Value lowered = buildTileBufViewLikeValue(op, op->getOperand(0), tbTy, + "treshape", ctx); + if (!lowered) + return failure(); + IRRewriter rewriter(ctx); + rewriter.replaceOp(op, lowered); + } + + DefaultInlineVector bitcasts; + func.walk([&](mlir::pto::BitcastOp op) { bitcasts.push_back(op); }); + for (auto op : bitcasts) { + auto tbTy = dyn_cast(op.getResult().getType()); + if (!tbTy) { + op.emitError("bitcast result must be tile_buf type"); + return failure(); + } + Value lowered = buildTileBufViewLikeValue(op, op->getOperand(0), tbTy, + "bitcast", ctx); + if (!lowered) + return failure(); + IRRewriter rewriter(ctx); + rewriter.replaceOp(op, lowered); + } + return success(); +} + +// ============================================================================= +// The Pass Implementation +// ============================================================================= + +struct PTOViewToMemrefPass 
+ : public mlir::pto::impl::PTOViewToMemrefBase { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PTOViewToMemrefPass) + + void runOnOperation() override { + ModuleOp mod = getOperation(); + MLIRContext *ctx = &getContext(); + + for (auto func : mod.getOps()) { + if (func.isExternal()) continue; + + // ------------------------------------------------------------------ + // Stage 0: ensure inttoptr values remain scalar-load/store only. + // ------------------------------------------------------------------ + if (failed(validateIntToPtrUses(func))) { + signalPassFailure(); + return; + } + + Block &entry = func.front(); + auto fnTy = func.getFunctionType(); + + // ------------------------------------------------------------------ + // Stage 0.10: Rewrite Function Signature + // ------------------------------------------------------------------ + SmallVector newInputs; + for (Type t : fnTy.getInputs()) newInputs.push_back(convertPTOTypeToMemRef(t)); + + SmallVector newResults; + for (Type t : fnTy.getResults()) newResults.push_back(convertPTOTypeToMemRef(t)); + + // Update entry block arguments + for (unsigned i = 0; i < entry.getNumArguments(); ++i) { + if (entry.getArgument(i).getType() != newInputs[i]) { + entry.getArgument(i).setType(newInputs[i]); + } + } + + // Update function type + func.setFunctionType(FunctionType::get(ctx, newInputs, newResults)); + + // ------------------------------------------------------------------ + // Stage 0.20: lower pto.inttoptr result types to GM memrefs. + // ------------------------------------------------------------------ + if (failed(lowerIntToPtrOps(func, ctx))) { + signalPassFailure(); + return; + } + + // ------------------------------------------------------------------ + // Stage 0.30: materialize pto.ptrtoint(addptr ...) byte offsets. 
+ // ------------------------------------------------------------------ + if (failed(lowerPtrToIntOps(func, ctx))) { + signalPassFailure(); + return; + } + + // ------------------------------------------------------------------ + // Stage 0.5: lower pto.alloc_tile -> memref.alloc + pto.bind_tile + // ------------------------------------------------------------------ + DefaultInlineVector allocTiles; + func.walk([&](mlir::pto::AllocTileOp op) { allocTiles.push_back(op); }); + + for (auto op : allocTiles) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); + + auto tbTy = dyn_cast(op.getResult().getType()); + if (!tbTy) continue; + + // 1. 获取 Shape 和 ElementType + SmallInlineVector shape(tbTy.getShape().begin(), tbTy.getShape().end()); + Type elemTy = tbTy.getElementType(); + + // 2. 计算 Strides (layout-aware when possible) + SmallVector strides; + TileLayoutInfo info; + if (computeTileLayoutInfo(tbTy.getConfigAttr(), elemTy, shape, info)) { + strides = {info.rowStride, info.colStride}; + } else { + strides.resize(shape.size()); + int64_t s = 1; + for (int i = (int)shape.size() - 1; i >= 0; --i) { + strides[i] = s; + if (shape[i] != ShapedType::kDynamic) s *= shape[i]; + } + } + + // 3. 构造 [BindTile 输出] 的动态类型 (Offset: ?) + // 这必须与 convertPTOTypeToMemRef 返回的类型一致,以便与 Subview 兼容 + auto targetLayout = + StridedLayoutAttr::get(ctx, ShapedType::kDynamic, strides); // offset = ? + auto targetType = + MemRefType::get(shape, elemTy, targetLayout, tbTy.getMemorySpace()); + + // 4. Preserve tile valid dims (v_row / v_col). + // + // `pto.alloc_tile` encodes the valid shape in the result TileBufType + // (e.g. acc tile may be rows=16 but v_row=1). The alloc op itself does + // not necessarily carry explicit operands for static valid dims, so we + // must materialize them from the type to keep them through + // tile_buf -> memref lowering. 
+ // + // For dynamically valid tiles (validShape == [-1, -1]), preserve the + // runtime operands if present. + Value vRow = op.getValidRow(); + Value vCol = op.getValidCol(); + // TileBuf valid dims use a negative sentinel (e.g. '?' / -1), which is + // distinct from MLIR's ShapedType::kDynamic (INT64_MIN). Treat any + // negative value as dynamic here. + materializeStaticValidDims(rewriter, loc, tbTy, vRow, vCol); + + // 5. 获取 Config (保持不变) + auto configAttr = tbTy.getConfigAttr(); + if (!configAttr) configAttr = pto::TileBufConfigAttr::getDefault(ctx); + + // 6. If alloc_tile provides an explicit address, keep the original + // pointer_cast lowering intact and additionally rebind through + // pto.bind_tile. PointerCastOp continues to carry the tile metadata + // used by existing lowering paths, while BindTileOp provides the + // unified anchor EmitC uses to recover tile_buf information. + if (Value addr = op.getAddr()) { + auto pc = rewriter.create( + loc, targetType, ValueRange{addr}, vRow ? vRow : Value(), + vCol ? vCol : Value(), configAttr); + markForceDynamicValidShape(pc, tbTy.hasDynamicValid(), ctx); + auto bindOp = rewriter.create( + loc, targetType, pc.getResult(), vRow ? vRow : Value(), + vCol ? vCol : Value(), configAttr); + markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); + rewriter.replaceOp(op, bindOp.getResult()); + continue; + } + + // 7. Otherwise allocate a concrete memref buffer and bind tile. + // memref.alloc 要求明确的 layout,不能是动态 offset。 + auto allocLayout = StridedLayoutAttr::get(ctx, 0, strides); // offset = 0 + auto allocType = MemRefType::get(shape, elemTy, allocLayout, tbTy.getMemorySpace()); + Value alloc = rewriter.create(loc, allocType); + + // BindTileOp 的 Builder 会自动处理空的 Value,将其视为静态维度 + auto bindOp = rewriter.create( + loc, targetType, alloc, vRow ? vRow : Value(), vCol ? 
vCol : Value(), + configAttr); + markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); + + rewriter.replaceOp(op, bindOp.getResult()); + } + + // ------------------------------------------------------------------ + // Stage 0.75: lower pto.declare_tile -> pto.declare_tile_memref + + // pto.bind_tile + // ------------------------------------------------------------------ + DefaultInlineVector declaredTiles; + func.walk([&](mlir::pto::DeclareTileOp op) { declaredTiles.push_back(op); }); + + for (auto op : declaredTiles) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); + + auto tbTy = dyn_cast(op.getTile().getType()); + if (!tbTy) { + op.emitError("declare_tile result must be tile_buf type"); + signalPassFailure(); + return; + } + + auto targetType = dyn_cast(convertPTOTypeToMemRef(tbTy)); + if (!targetType) { + op.emitError("failed to convert declare_tile result to memref type"); + signalPassFailure(); + return; + } + + auto configAttr = tbTy.getConfigAttr(); + if (!configAttr) + configAttr = pto::TileBufConfigAttr::getDefault(ctx); + + Value vRow; + Value vCol; + materializeStaticValidDims(rewriter, loc, tbTy, vRow, vCol); + + auto declaredMemRef = + rewriter.create(loc, targetType); + auto bindOp = rewriter.create( + loc, targetType, declaredMemRef.getResult(), + vRow ? vRow : Value(), vCol ? vCol : Value(), configAttr); + markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); + + rewriter.replaceOp(op, bindOp.getResult()); + } + + // ------------------------------------------------------------------ + // Stage 0.8: normalize pto.tassign result type to match tile operand + // after tile_buf -> memref lowering (required for verifier consistency). 
+ // ------------------------------------------------------------------ + DefaultInlineVector tassignOps; + func.walk([&](mlir::pto::TAssignOp op) { tassignOps.push_back(op); }); + for (auto op : tassignOps) { + Type targetTy = op.getTile().getType(); + if (op.getResult().getType() == targetTy) + continue; + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + auto normalized = + rewriter.create(op.getLoc(), targetTy, op.getTile(), + op.getAddr()); + rewriter.replaceOp(op, normalized.getResult()); + } + + // ------------------------------------------------------------------ + // Stage 1: Lower pto.make_tensor_view -> memref.reinterpret_cast + // ------------------------------------------------------------------ + DefaultInlineVector makeViews; + func.walk([&](mlir::pto::MakeTensorViewOp op) { makeViews.push_back(op); }); + + for (auto op : makeViews) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); + + Value baseBuf = op.getOperand(0); + OpFoldResult off0 = rewriter.getIndexAttr(0); + + // Fold pto.addptr chains into the view base to avoid nested reinterpret_cast. + bool foldedAddPtr = false; + { + Value cur = baseBuf; + Value totalOffset; + while (auto add = cur.getDefiningOp()) { + foldedAddPtr = true; + Value off = ensureIndex(rewriter, loc, add.getOperand(1), add); + if (totalOffset) + totalOffset = rewriter.create(loc, totalOffset, off); + else + totalOffset = off; + cur = add.getOperand(0); + } + if (cur != baseBuf) { + baseBuf = cur; + off0 = totalOffset ? 
OpFoldResult(totalOffset) : off0; + } + } + + auto baseMr = dyn_cast(baseBuf.getType()); + if (!baseMr) { + op.emitError("make_tensor_view base must be memref"); signalPassFailure(); return; + } + + // [修复] 获取动态 Rank (根据 shape 输入的数量) + size_t rank = op.getShape().size(); + + // Construct target type with dynamic offset/strides + Type elemTy = baseMr.getElementType(); + int64_t dyn = ShapedType::kDynamic; + + // [修复] 构建 N 维 Strided Layout + // strides 数组长度必须等于 rank + SmallVector dynStrides(rank, dyn); + auto layout = StridedLayoutAttr::get(ctx, /*offset=*/dyn, /*strides=*/dynStrides); + + // [修复] 构建 N 维 Shape + SmallVector dynShape(rank, dyn); + auto mrTy = MemRefType::get(dynShape, elemTy, layout, baseMr.getMemorySpace()); + + SmallInlineVector sizes; + for (Value v : op.getShape()) sizes.push_back(ensureIndex(rewriter, loc, v, op)); + + SmallInlineVector strides; + for (Value v : op.getStrides()) strides.push_back(ensureIndex(rewriter, loc, v, op)); + + auto rc = rewriter.create( + loc, mrTy, baseBuf, off0, sizes, strides); + if (foldedAddPtr) { + rc->setAttr("pto.addptr_trace", rewriter.getUnitAttr()); + } + if (auto layoutAttr = op.getLayoutAttr()) { + rc->setAttr("layout", layoutAttr); + } + + rewriter.replaceOp(op, rc.getResult()); + } + + // ------------------------------------------------------------------ + // Stage 1.25: Lower pto.get_tensor_view_dim -> memref.dim + // ------------------------------------------------------------------ + DefaultInlineVector tvDims; + func.walk([&](mlir::pto::GetTensorViewDimOp op) { tvDims.push_back(op); }); + + for (auto op : tvDims) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); + + Value view = op.getTensorView(); + auto mrTy = dyn_cast(view.getType()); + if (!mrTy) + continue; // leave it to later passes if it hasn't been lowered yet + + Value dimIdx = op.getDimIndex(); + Value dim = rewriter.create(loc, view, dimIdx); + rewriter.replaceOp(op, dim); + } + + // 
------------------------------------------------------------------ + // Stage 1.3: Lower pto.partition_view -> memref.subview + // ------------------------------------------------------------------ + if (failed(lowerPartitionViewOps(func, ctx))) { + signalPassFailure(); + return; + } + + // ------------------------------------------------------------------ + // Stage 1.35: Lower pto.subview -> memref.subview + pto.bind_tile + // ------------------------------------------------------------------ + if (failed(lowerSubViewOps(func, ctx))) { + signalPassFailure(); + return; + } + + // ------------------------------------------------------------------ + // Stage 1.4: Lower tile_buf view-like ops (treshape/bitcast) + // ------------------------------------------------------------------ + if (failed(lowerTileBufViewLikeOps(func, ctx))) { + signalPassFailure(); + return; + } + + // ------------------------------------------------------------------ + // Stage 1.5: Fold pto.addptr chains into load/store_scalar. 
+ // ------------------------------------------------------------------ + DefaultInlineVector loadScalars; + func.walk([&](mlir::pto::LoadScalarOp op) { loadScalars.push_back(op); }); + + for (auto op : loadScalars) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); + + Value base = op.getPtr(); + Value totalOffset = ensureIndex(rewriter, loc, op.getOffset(), op); + + bool foldedAddPtr = false; + while (auto add = base.getDefiningOp()) { + foldedAddPtr = true; + Value off = ensureIndex(rewriter, loc, add.getOperand(1), add); + if (totalOffset) + totalOffset = rewriter.create(loc, totalOffset, off); + else + totalOffset = off; + base = add.getOperand(0); + } + + if (foldedAddPtr) { + auto newOp = rewriter.create( + loc, op.getValue().getType(), base, totalOffset); + rewriter.replaceOp(op, newOp.getValue()); + } + } + + DefaultInlineVector storeScalars; + func.walk([&](mlir::pto::StoreScalarOp op) { storeScalars.push_back(op); }); + + for (auto op : storeScalars) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); + + Value base = op.getPtr(); + Value totalOffset = ensureIndex(rewriter, loc, op.getOffset(), op); + + bool foldedAddPtr = false; + while (auto add = base.getDefiningOp()) { + foldedAddPtr = true; + Value off = ensureIndex(rewriter, loc, add.getOperand(1), add); + if (totalOffset) + totalOffset = rewriter.create(loc, totalOffset, off); + else + totalOffset = off; + base = add.getOperand(0); + } + + if (foldedAddPtr) { + rewriter.create( + loc, base, totalOffset, op.getValue()); + rewriter.eraseOp(op); + } + } + + // ------------------------------------------------------------------ + // Stage 1.75: Fold addptr used by initialize_l2g2l_pipe(gm_addr). + // This keeps IR well-typed after function arguments are rewritten from + // !pto.ptr to memref. 
+ // ------------------------------------------------------------------ + bool foldedPipeInitAddPtr = true; + while (foldedPipeInitAddPtr) { + foldedPipeInitAddPtr = false; + DefaultInlineVector addPtrsForPipeInit; + func.walk([&](mlir::pto::AddPtrOp op) { + bool eligible = !op->use_empty(); + for (Operation *user : op->getUsers()) { + auto init = dyn_cast(user); + if (!init || init.getGmAddr() != op->getResult(0)) { + eligible = false; + break; + } + } + if (eligible) + addPtrsForPipeInit.push_back(op); + }); + + for (auto op : addPtrsForPipeInit) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); + + Value base = op->getOperand(0); + Value totalOffset = ensureIndex(rewriter, loc, op->getOperand(1), op); + while (auto add = base.getDefiningOp()) { + Value off = ensureIndex(rewriter, loc, add->getOperand(1), add); + totalOffset = rewriter.create(loc, totalOffset, off); + base = add->getOperand(0); + } + + auto baseMrTy = dyn_cast(base.getType()); + if (!baseMrTy || baseMrTy.getRank() != 1) + continue; + + int64_t dyn = ShapedType::kDynamic; + auto layout = StridedLayoutAttr::get(ctx, dyn, {dyn}); + auto targetTy = MemRefType::get({dyn}, baseMrTy.getElementType(), layout, + baseMrTy.getMemorySpace()); + SmallVector sizes{rewriter.getIndexAttr(1)}; + SmallVector strides{rewriter.getIndexAttr(1)}; + auto rc = rewriter.create( + loc, targetTy, base, OpFoldResult(totalOffset), sizes, strides); + rc->setAttr("pto.addptr_trace", rewriter.getUnitAttr()); + rewriter.replaceOp(op, rc.getResult()); + foldedPipeInitAddPtr = true; + } + } + + // Clean up: addptr should be folded into make_tensor_view. 
+ DefaultInlineVector addPtrs; + func.walk([&](mlir::pto::AddPtrOp op) { addPtrs.push_back(op.getOperation()); }); + bool changed = true; + while (changed) { + changed = false; + for (auto &op : addPtrs) { + if (!op) + continue; + if (op->use_empty()) { + op->erase(); + op = nullptr; + changed = true; + } + } + } + for (auto *op : addPtrs) { + if (!op) + continue; + op->emitError("addptr must feed make_tensor_view, initialize_l2g2l_pipe(gm_addr) or load/store_scalar for lowering"); + signalPassFailure(); + return; + } + + // ------------------------------------------------------------------ + // Stage 3: Rewrite Compute Ops + // [Key] Use op->getOperand(i) throughout to avoid typed-accessor crashes + // ------------------------------------------------------------------ + + // --- TLoadOp [Src, Dst] --- + DefaultInlineVector loads; + func.walk([&](mlir::pto::TLoadOp op) { loads.push_back(op); }); + for (auto op : loads) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op->getOperand(0); + Value dst = op->getOperand(1); + + auto newOp = + rewriter.create(op.getLoc(), TypeRange{}, src, dst); + newOp->setAttrs(op->getAttrs()); + rewriter.replaceOp(op, newOp->getResults()); + } + + // --- TStoreOp [Src, Dst] --- + DefaultInlineVector storeops; + func.walk([&](mlir::pto::TStoreOp op) { storeops.push_back(op); }); + for (auto op : storeops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op->getOperand(0); + Value dst = op->getOperand(1); + Value preQuant = op.getPreQuantScalar(); + + pto::TStoreOp newOp; + if (preQuant) { + newOp = rewriter.create(op.getLoc(), TypeRange{}, + src, dst, preQuant); + } else { + newOp = rewriter.create(op.getLoc(), TypeRange{}, + src, dst, Value{}); + } + newOp->setAttrs(op->getAttrs()); + rewriter.replaceOp(op, newOp->getResults()); + } + + // --- TTransOp [Src, Tmp, Dst] --- + DefaultInlineVector trans; + func.walk([&](mlir::pto::TTransOp op) { trans.push_back(op); }); + for (auto op : trans) { + 
IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex)); + } + + // --- TExpOp [Src, Dst] --- + DefaultInlineVector exp; + func.walk([&](mlir::pto::TExpOp op) { exp.push_back(op); }); + for (auto op : exp) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, op->getOperand(0), op->getOperand(1)); + } + + // --- TMulOp [Src, Scalar, Dst] --- + DefaultInlineVector mul; + func.walk([&](mlir::pto::TMulOp op) { mul.push_back(op); }); + for (auto op : mul) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, op->getOperand(0), op.getOperand(1), + op->getOperand(kThirdOperandIndex)); + } + + // --- TMulSOp [Src, Scalar, Dst] --- + DefaultInlineVector muls; + func.walk([&](mlir::pto::TMulSOp op) { muls.push_back(op); }); + for (auto op : muls) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, op->getOperand(0), op.getScalar(), + op->getOperand(kThirdOperandIndex)); + } + + // --- TAddOp [Src0, Src1, Dst] --- + DefaultInlineVector addops; + func.walk([&](mlir::pto::TAddOp op) { addops.push_back(op); }); + for (auto op : addops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex)); + } + + // --- TMatmulOp [Lhs, Rhs, Dst] (no optional bias in ODS) --- + DefaultInlineVector matmuls; + func.walk([&](mlir::pto::TMatmulOp op) { matmuls.push_back(op); }); + for (auto op : matmuls) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Value lhs = op->getOperand(0); + Value rhs = op->getOperand(1); + Value dst = op->getOperand(kThirdOperandIndex); + + rewriter.replaceOpWithNewOp( + op, TypeRange{}, lhs, rhs, dst, op.getAccPhaseAttr()); + } + + // 
--- TMatmulAccOp [Acc, Lhs, Rhs, Dst] --- + DefaultInlineVector matmulAccs; + func.walk([&](mlir::pto::TMatmulAccOp op) { matmulAccs.push_back(op); }); + for (auto op : matmulAccs) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex), + op->getOperand(kFourthOperandIndex), op.getAccPhaseAttr()); + } + + // --- TMatmulBiasOp [Acc, Lhs, Rhs, Bias, Dst] --- + DefaultInlineVector matmulBiass; + func.walk([&](mlir::pto::TMatmulBiasOp op) { matmulBiass.push_back(op); }); + for (auto op : matmulBiass) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex), + op->getOperand(kFourthOperandIndex)); + } + + // --- TMatmulMxOp--- + DefaultInlineVector matmulMxs; + func.walk([&](mlir::pto::TMatmulMxOp op) { matmulMxs.push_back(op); }); + for (auto op : matmulMxs) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex), + op->getOperand(kFourthOperandIndex), + op->getOperand(kFifthOperandIndex)); + } + + // --- TMatmulMxAccOp --- + DefaultInlineVector matmulMxAccs; + func.walk([&](mlir::pto::TMatmulMxAccOp op) { matmulMxAccs.push_back(op); }); + for (auto op : matmulMxAccs) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex), + op->getOperand(kFourthOperandIndex), + op->getOperand(kFifthOperandIndex), + op->getOperand(kSixthOperandIndex)); + } + + // --- TMatmulMxBiasOp --- + DefaultInlineVector matmulMxBiass; + func.walk([&](mlir::pto::TMatmulMxBiasOp op) { matmulMxBiass.push_back(op); }); + for (auto op : matmulMxBiass) { + IRRewriter 
rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex), + op->getOperand(kFourthOperandIndex), + op->getOperand(kFifthOperandIndex), + op->getOperand(kSixthOperandIndex)); + } + + // --- TGemvOp [Lhs, Rhs, Dst] --- + DefaultInlineVector gemvs; + func.walk([&](mlir::pto::TGemvOp op) { gemvs.push_back(op); }); + for (auto op : gemvs) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value lhs = op->getOperand(0); + Value rhs = op->getOperand(1); + Value dst = op->getOperand(kThirdOperandIndex); + + rewriter.replaceOpWithNewOp( + op, TypeRange{}, lhs, rhs, dst); + } + + // --- TGemvAccOp [Acc, Lhs, Rhs, Dst] --- + DefaultInlineVector gemvAccs; + func.walk([&](mlir::pto::TGemvAccOp op) { gemvAccs.push_back(op); }); + for (auto op : gemvAccs) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex), + op->getOperand(kFourthOperandIndex)); + } + + // --- TGemvBiasOp [Acc, Lhs, Rhs, Bias, Dst] --- + DefaultInlineVector gemvBiass; + func.walk([&](mlir::pto::TGemvBiasOp op) { gemvBiass.push_back(op); }); + for (auto op : gemvBiass) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex), + op->getOperand(kFourthOperandIndex)); + } + + // --- TGemvMxOp [A, AScale, B, BScale, Dst] --- + DefaultInlineVector gemvMxs; + func.walk([&](mlir::pto::TGemvMxOp op) { gemvMxs.push_back(op); }); + for (auto op : gemvMxs) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex), + op->getOperand(kFourthOperandIndex), + op->getOperand(kFifthOperandIndex)); + 
} + + // --- TGemvMxAccOp [CIn, A, AScale, B, BScale, Dst] --- + DefaultInlineVector gemvMxAccs; + func.walk([&](mlir::pto::TGemvMxAccOp op) { gemvMxAccs.push_back(op); }); + for (auto op : gemvMxAccs) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex), + op->getOperand(kFourthOperandIndex), + op->getOperand(kFifthOperandIndex), + op->getOperand(kSixthOperandIndex)); + } + + // --- TGemvMxBiasOp [A, AScale, B, BScale, Bias, Dst] --- + DefaultInlineVector gemvMxBiass; + func.walk([&](mlir::pto::TGemvMxBiasOp op) { gemvMxBiass.push_back(op); }); + for (auto op : gemvMxBiass) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex), + op->getOperand(kFourthOperandIndex), + op->getOperand(kFifthOperandIndex), + op->getOperand(kSixthOperandIndex)); + } + + // --- TMovOp [Src, Dst] --- + DefaultInlineVector movs; + func.walk([&](mlir::pto::TMovOp op) { movs.push_back(op); }); + for (auto op : movs) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, op.getSrc(), op.getDst(), op.getFp(), + op.getPreQuantScalar(), op.getAccToVecModeAttr(), + op.getReluPreModeAttr()); + } + + DefaultInlineVector abseops; + func.walk([&](mlir::pto::TAbsOp op) { abseops.push_back(op); }); + + for (auto op : abseops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst); + } + + DefaultInlineVector addcops; + func.walk([&](mlir::pto::TAddCOp 
op) { addcops.push_back(op); }); + + for (auto op : addcops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value src2 = op.getSrc2(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto src2Ty = dyn_cast(src2.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !src2Ty ||!dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src0, + src1, + src2, + dst); + } + + DefaultInlineVector addsops; + func.walk([&](mlir::pto::TAddSOp op) { addsops.push_back(op); }); + + for (auto op : addsops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value scalar = op.getScalar(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + scalar, + dst); + } + + DefaultInlineVector addscops; + func.walk([&](mlir::pto::TAddSCOp op) { addscops.push_back(op); }); + + for (auto op : addscops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value scalar = op.getScalar(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src0, + scalar, + src1, + dst); + } + + DefaultInlineVector andops; + func.walk([&](mlir::pto::TAndOp op) { andops.push_back(op); }); + + for (auto op : andops) { + IRRewriter rewriter(ctx); 
+ rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src0, + src1, + dst); + } + + DefaultInlineVector concats; + func.walk([&](mlir::pto::TConcatOp op) { concats.push_back(op); }); + + for (auto op : concats) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src0, + src1, + dst); + } + + DefaultInlineVector concatIdxs; + func.walk([&](mlir::pto::TConcatidxOp op) { concatIdxs.push_back(op); }); + + IRRewriter rewriter(ctx); + for (auto op : concatIdxs) { + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value src0Idx = op.getSrc0Idx(); + Value src1Idx = op.getSrc1Idx(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto src0IdxTy = dyn_cast(src0Idx.getType()); + auto src1IdxTy = dyn_cast(src1Idx.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !src0IdxTy || !src1IdxTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src0, + src1, + src0Idx, + src1Idx, + dst); + } + + DefaultInlineVector andsops; + func.walk([&](mlir::pto::TAndSOp op) { andsops.push_back(op); 
}); + + for (auto op : andsops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value scalar = op.getScalar(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + scalar, + dst); + } + + DefaultInlineVector ciops; + func.walk([&](mlir::pto::TCIOp op) { ciops.push_back(op); }); + + for (auto op : ciops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value s = op->getOperand(0); + Value dst = op.getDst(); + bool descending = op.getDescending(); + + auto sTy = dyn_cast(s.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!sTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + s, + dst, + descending); + } + + DefaultInlineVector cmpops; + func.walk([&](mlir::pto::TCmpOp op) { cmpops.push_back(op); }); + + for (auto op : cmpops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + auto newOp = rewriter.create( + op.getLoc(), + TypeRange{}, + src0, + src1, + dst); + + if (auto a = op.getCmpModeAttr()) + newOp->setAttr("cmpMode", a); + + rewriter.replaceOp(op, newOp->getResults()); // 0 results -> OK + } + + DefaultInlineVector cmpsops; + func.walk([&](mlir::pto::TCmpSOp op) { cmpsops.push_back(op); }); + + for (auto op : cmpsops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value 
scalar = op.getScalar(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + auto scalarTy = scalar.getType(); + bool scalarOk = + isa(scalarTy); // ScalarType in ODS: int/float + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + if (!scalarOk) { + op.emitError("expects scalar to be an integer or float type"); + signalPassFailure(); + return; + } + + auto cmpMode = op.getCmpModeAttr(); + auto newOp = rewriter.create( + op.getLoc(), + TypeRange{}, + src, + scalar, + cmpMode, + dst); + + rewriter.replaceOp(op, newOp->getResults()); // 0 results -> OK + } + + DefaultInlineVector colexpand; + func.walk([&](mlir::pto::TColExpandOp op) { colexpand.push_back(op); }); + + for (auto op : colexpand) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if ( !srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst); + } + + DefaultInlineVector colmaxops; + func.walk([&](mlir::pto::TColMaxOp op) { colmaxops.push_back(op); }); + + for (auto op : colmaxops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if ( !srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst); + } + + DefaultInlineVector colminops; + func.walk([&](mlir::pto::TColMinOp op) { colminops.push_back(op); }); + + for (auto op : colminops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + + auto srcTy 
= dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if ( !srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst); + } + + DefaultInlineVector colexpandmulops; + func.walk([&](mlir::pto::TColExpandMulOp op) { + colexpandmulops.push_back(op); + }); + + for (auto op : colexpandmulops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src0, + src1, + dst); + } + + DefaultInlineVector colexpandmaxops; + func.walk([&](mlir::pto::TColExpandMaxOp op) { + colexpandmaxops.push_back(op); + }); + + for (auto op : colexpandmaxops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src0, + src1, + dst); + } + + DefaultInlineVector colexpandminops; + func.walk([&](mlir::pto::TColExpandMinOp op) { + colexpandminops.push_back(op); + }); + + for (auto op : colexpandminops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty 
|| !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src0, + src1, + dst); + } + + DefaultInlineVector colsumops; + func.walk([&](mlir::pto::TColSumOp op) { colsumops.push_back(op); }); + + for (auto op : colsumops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + Value tmp = op.getTmp(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !dstTy) { + op.emitError("src/dst are not memref yet"); + signalPassFailure(); + return; + } + + // If tmp exists, it must have isBinary attribute + if (tmp) { + auto tmpTy = dyn_cast(tmp.getType()); + if (!tmpTy) { + op.emitError("tmp is not memref yet"); + signalPassFailure(); + return; + } + + // Get isBinary attribute (should exist if tmp exists) + BoolAttr isBinaryAttr = op.getIsBinaryAttr(); + if (!isBinaryAttr) { + isBinaryAttr = BoolAttr::get(ctx, false); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + tmp, + dst, + isBinaryAttr); + } else { + // Format 1: no tmp, no isBinary + // Use generic builder to avoid adding default isBinary attribute + SmallVector operands = {src, dst}; + SmallVector attrs; + // Copy all attributes except isBinary + for (auto attr : op->getAttrs()) { + if (attr.getName() != "isBinary") { + attrs.push_back(attr); + } + } + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + operands, + attrs); + } + } + + DefaultInlineVector cvtops; + func.walk([&](mlir::pto::TCvtOp op) { cvtops.push_back(op); }); + + for (auto op : cvtops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + auto rmodeAttr = 
op.getRmodeAttr(); // PTO_RoundModeAttr + auto satModeAttr = op.getSatModeAttr(); + + auto newOp = rewriter.create( + op.getLoc(), + TypeRange{}, + src, + dst, + rmodeAttr, + satModeAttr); + + rewriter.replaceOp(op, newOp->getResults()); + } + + DefaultInlineVector divops; + func.walk([&](mlir::pto::TDivOp op) { divops.push_back(op); }); + + for (auto op : divops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src0, + src1, + dst); + } + + DefaultInlineVector divsops; + func.walk([&](mlir::pto::TDivSOp op) { divsops.push_back(op); }); + + for (auto op : divsops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value scale = op.getScalar(); + Value dst = op.getDst(); + + // Check types - they might still be TileBufType or already converted to MemRefType + auto srcTy = dyn_cast(src.getType()); + auto srcTileTy = dyn_cast(src.getType()); + auto scaleTileTy = dyn_cast(scale.getType()); + auto dstTy = dyn_cast(dst.getType()); + auto dstTileTy = dyn_cast(dst.getType()); + + // Determine which operand is tile-like and which is scalar-like. + // Keep the original operand order (set by parser textual form). 
+ // Check if src is memref/tensor/tile (not scalar) + bool srcIsMemref = (srcTy != nullptr || srcTileTy != nullptr || + isa(src.getType()) || + isa(src.getType())); + // Check if scale is memref/tensor/tile (not scalar) + bool scaleIsMemref = (isa(scale.getType()) || + scaleTileTy != nullptr || + isa(scale.getType()) || + isa(scale.getType())); + + // Type validation - ensure we have the right types + if (!srcIsMemref && !scaleIsMemref) { + op.emitError("at least one operand (src or scale) must be tile_buf or memref"); + signalPassFailure(); + return; + } + if (srcIsMemref && scaleIsMemref) { + op.emitError("exactly one operand (src or scale) must be tile_buf or memref, the other must be scalar"); + signalPassFailure(); + return; + } + + if (!dstTy && !dstTileTy) { + op.emitError("dst operand must be tile_buf or memref"); + signalPassFailure(); + return; + } + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + scale, + dst); + } + + DefaultInlineVector expandsops; + func.walk([&](mlir::pto::TExpandsOp op) { expandsops.push_back(op); }); + + for (auto op : expandsops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value scalar = op.getScalar(); + Value dst = op.getDst(); + + auto dstTy = dyn_cast(dst.getType()); + if (!dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + scalar, + dst); + } + + DefaultInlineVector extractops; + func.walk([&](mlir::pto::TExtractOp op) { extractops.push_back(op); }); + + for (auto op : extractops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value indexRow = op.getIndexRow(); + Value indexCol = op.getIndexCol(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto indexRowTy = dyn_cast(indexRow.getType()); + auto indexColTy = dyn_cast(indexCol.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !indexRowTy || !indexColTy || 
!dstTy) { + op.emitError("ins/outs are not correct yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + indexRow, + indexCol, + dst); + } + + DefaultInlineVector fillpadops; + func.walk([&](mlir::pto::TFillPadOp op) { fillpadops.push_back(op); }); + + for (auto op : fillpadops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst); + } + + DefaultInlineVector fillpadInplaceOps; + func.walk( + [&](mlir::pto::TFillPadInplaceOp op) { fillpadInplaceOps.push_back(op); }); + + for (auto op : fillpadInplaceOps) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst); + } + + // --- TSetValOp [Dst, Offset, Val] --- + // Lower tile-world scalar write to memref-world SETVAL DPS op. 
+ DefaultInlineVector tsetvalops; + func.walk([&](mlir::pto::TSetValOp op) { tsetvalops.push_back(op); }); + + for (auto op : tsetvalops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value dst = op.getDst(); + Value offset = op.getOffset(); + Value val = op.getVal(); + + auto dstTy = dyn_cast(dst.getType()); + if (!dstTy) { + op.emitError("dst is not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + dst, + offset, + val); + } + + // --- TGetValOp [Src, Offset] -> Scalar --- + // Lower tile-world scalar read to memref-world GETVAL DPS op. + DefaultInlineVector tgetvalops; + func.walk([&](mlir::pto::TGetValOp op) { tgetvalops.push_back(op); }); + + for (auto op : tgetvalops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value offset = op.getOffset(); + Type dstType = op.getDst().getType(); + + auto srcTy = dyn_cast(src.getType()); + if (!srcTy) { + op.emitError("src is not memref yet"); + signalPassFailure(); + return; + } + + auto newOp = rewriter.create( + op.getLoc(), + dstType, + src, + offset); + rewriter.replaceOp(op, newOp.getDst()); + } + + DefaultInlineVector gatherops; + func.walk([&](mlir::pto::TGatherOp op) { gatherops.push_back(op); }); + + for (auto op : gatherops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + Value cdst = op.getCdst(); + Value indices = op.getIndices(); + Value tmp = op.getTmp(); + Value kValue = op.getKValue(); + auto maskPattern = op.getMaskPatternAttr(); + auto cmpMode = op.getCmpModeAttr(); + auto offset = op.getOffsetAttr(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + if (maskPattern) { + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst, + /*cdst=*/Value(), + 
/*indices=*/Value(), + /*tmp=*/Value(), + /*kValue=*/Value(), + /*maskPattern=*/maskPattern, + /*cmpMode=*/pto::CmpModeAttr(), + /*offset=*/IntegerAttr()); + continue; + } + + if (cdst || kValue) { + auto cdstTy = dyn_cast(cdst.getType()); + auto tmpTy = dyn_cast(tmp.getType()); + if (!cdstTy || !tmpTy) { + op.emitError("compare-form tgather expects cdst/tmp to be memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst, + cdst, + /*indices=*/Value(), + tmp, + kValue, + /*maskPattern=*/pto::MaskPatternAttr(), + cmpMode, + offset); + continue; + } + + if (indices || tmp) { + auto indicesTy = dyn_cast(indices.getType()); + auto tmpTy = dyn_cast(tmp.getType()); + if (!indicesTy || !tmpTy) { + op.emitError("index-form tgather expects indices/tmp to be memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst, + /*cdst=*/Value(), + indices, + tmp, + /*kValue=*/Value(), + /*maskPattern=*/pto::MaskPatternAttr(), + /*cmpMode=*/pto::CmpModeAttr(), + /*offset=*/IntegerAttr()); + continue; + } + + op.emitError("expects tgather to be in mask, index+tmp, or compare+tmp form"); + signalPassFailure(); + return; + } + + DefaultInlineVector gatherbops; + func.walk([&](mlir::pto::TGatherBOp op) { gatherbops.push_back(op); }); + + for (auto op : gatherbops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value offsets = op.getOffsets(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto offsetsTy = dyn_cast(offsets.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !offsetsTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + offsets, + dst); + } + + DefaultInlineVector logops; + func.walk([&](mlir::pto::TLogOp op) { logops.push_back(op); }); + + for (auto op : logops) { 
+ IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst); + } + + DefaultInlineVector lreluops; + func.walk([&](mlir::pto::TLReluOp op) { lreluops.push_back(op); }); + + for (auto op : lreluops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value slope = op.getSlope(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto slopeTy = dyn_cast(slope.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !slopeTy || !dstTy) { + op.emitError("ins/outs are not correct type yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + slope, + dst); + } + + DefaultInlineVector maxops; + func.walk([&](mlir::pto::TMaxOp op) { maxops.push_back(op); }); + + for (auto op : maxops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src0, + src1, + dst); + } + + DefaultInlineVector maxsops; + func.walk([&](mlir::pto::TMaxSOp op) { maxsops.push_back(op); }); + + for (auto op : maxsops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value scalar = op.getScalar(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + bool scalarIsScalar = 
isa(scalar.getType()); + if (!srcTy || !scalarIsScalar || !dstTy) { + op.emitError("expects src/dst to be memref and scalar to be integer/float"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + scalar, + dst); + } + + DefaultInlineVector minops; + func.walk([&](mlir::pto::TMinOp op) { minops.push_back(op); }); + + for (auto op : minops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src0, + src1, + dst); + } + + DefaultInlineVector minsops; + func.walk([&](mlir::pto::TMinSOp op) { minsops.push_back(op); }); + + for (auto op : minsops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value scalar = op.getScalar(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + bool scalarIsScalar = isa(scalar.getType()); + if (!srcTy || !scalarIsScalar || !dstTy) { + op.emitError("expects src/dst to be memref and scalar to be integer/float"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + scalar, + dst); + } + + DefaultInlineVector movfpops; + func.walk([&](mlir::pto::TMovFPOp op) { movfpops.push_back(op); }); + + for (auto op : movfpops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value fp = op.getFp(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto fpTy = dyn_cast(fp.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !fpTy || !dstTy) { + op.emitError("ins/outs are not 
memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + fp, + dst); + } + + DefaultInlineVector quantops; + func.walk([&](mlir::pto::TQuantOp op) { quantops.push_back(op); }); + + for (auto op : quantops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value fp = op.getFp(); + Value offset = op.getOffset(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto fpTy = dyn_cast(fp.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !fpTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + if (offset && !dyn_cast(offset.getType())) { + op.emitError("offset is not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + fp, + offset, + dst, + op.getQuantTypeAttr()); + } + + DefaultInlineVector mrgsortops; + func.walk([&](mlir::pto::TMrgSortOp op) { mrgsortops.push_back(op); }); + + for (auto op : mrgsortops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + if (op.isFormat1()) { + Value src = op.getSrc(); + Value dst = op.getDst(); + Value blockLenVal = op.getBlockLen(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + ValueRange{src}, + blockLenVal, + ValueRange{dst}, + Value() /*tmp*/, + Value() /*excuted*/, + op.getExhaustedAttr()); + } else if (op.isFormat2()) { + bool allMemRef = true; + for (Value v : op.getSrcs()) + if (!dyn_cast(v.getType())) { allMemRef = false; break; } + if (!allMemRef) { + op.emitError("format2 ins/outs are not memref yet"); + signalPassFailure(); + return; + } + if (op.getDsts().size() != 1u || !op.getTmp()) { + op.emitError("format2 expects outs(dst) and ins(tmp)"); + 
signalPassFailure(); + return; + } + + Value dst = op.getDst(); + Value tmp = op.getTmp(); + Value excuted = op.getExcuted(); + if (!dyn_cast(dst.getType()) || !dyn_cast(tmp.getType())) { + op.emitError("format2 dst/tmp must be memref"); + signalPassFailure(); + return; + } + if (!dyn_cast(excuted.getType())) { + op.emitError("format2 outs(excuted) must be vector"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + op.getSrcs(), + Value() /*blockLen*/, + ValueRange{dst}, + tmp, + excuted, + op.getExhaustedAttr()); + } else { + op.emitError("tmrgsort must be format1 or format2"); + signalPassFailure(); + return; + } + } + + DefaultInlineVector negops; + func.walk([&](mlir::pto::TNegOp op) { negops.push_back(op); }); + + for (auto op : negops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst); + } + + DefaultInlineVector notops; + func.walk([&](mlir::pto::TNotOp op) { notops.push_back(op); }); + + for (auto op : notops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst); + } + + DefaultInlineVector orops; + func.walk([&](mlir::pto::TOrOp op) { orops.push_back(op); }); + + for (auto op : orops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = 
dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + src0, + src1, + dst); + } + + DefaultInlineVector orsops; + func.walk([&](mlir::pto::TOrSOp op) { orsops.push_back(op); }); + + for (auto op : orsops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value scalar = op.getScalar(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto scalarTy = dyn_cast(scalar.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !scalarTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + scalar, + dst); + } + + DefaultInlineVector partaddops; + func.walk([&](mlir::pto::TPartAddOp op) { partaddops.push_back(op); }); + + for (auto op : partaddops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + src0, + src1, + dst); + } + + DefaultInlineVector partmulops; + func.walk([&](mlir::pto::TPartMulOp op) { partmulops.push_back(op); }); + + for (auto op : partmulops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs 
are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + src0, + src1, + dst); + } + + DefaultInlineVector mgatherops; + func.walk([&](mlir::pto::MGatherOp op) { mgatherops.push_back(op); }); + + for (auto op : mgatherops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value dst = op.getDst(); + Value idx = op.getIdx(); + Value mem = op.getMem(); + + auto dstTy = dyn_cast(dst.getType()); + auto idxTy = dyn_cast(idx.getType()); + auto memTy = dyn_cast(mem.getType()); + if (!dstTy || !idxTy || !memTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + mem, + idx, + dst, + op.getGatherOobAttr()); + } + + DefaultInlineVector mascatterops; + func.walk([&](mlir::pto::MScatterOp op) { mascatterops.push_back(op); }); + + for (auto op : mascatterops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value idx = op.getIdx(); + Value mem = op.getMem(); + + auto srcTy = dyn_cast(src.getType()); + auto idxTy = dyn_cast(idx.getType()); + auto memTy = dyn_cast(mem.getType()); + if (!srcTy || !idxTy || !memTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + idx, + mem, + op.getScatterAtomicOpAttr(), + op.getScatterOobAttr()); + } + DefaultInlineVector printops; + func.walk([&](mlir::pto::TPrintOp op) { printops.push_back(op); }); + + for (auto op : printops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + + auto srcTy = dyn_cast(src.getType()); + if (!srcTy) { + op.emitError("ins/outs are not memref yet"); + signalPassFailure(); + return; + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src); + } + + // ------------------------------------------------------------------ + // Stage 4: Reconcile control-flow result types + // 
------------------------------------------------------------------ + if (failed(reconcileSCFIfResultTypes(func))) { + signalPassFailure(); + return; + } + if (failed(reconcileSCFForResultTypes(func))) { + signalPassFailure(); + return; + } + + // Mark memref-form set_validshape only after control-flow result-type + // reconciliation. Values such as scf.if results can stay tile_buf until + // this late stage. + if (failed(markLoweredSetValidShapeOps(func, ctx))) { + signalPassFailure(); + return; + } + } + + // Debug Output + LLVM_DEBUG(llvm::dbgs() << mod.getOperation()); + } +}; + +} // namespace + +std::unique_ptr createPTOViewToMemrefPass() { + return std::make_unique(); +} + +} // namespace pto +} // namespace mlir diff --git a/lib/PTO/Transforms/PTOViewToMemref.def b/lib/PTO/Transforms/PTOViewToMemref.def deleted file mode 100644 index c21669b81..000000000 --- a/lib/PTO/Transforms/PTOViewToMemref.def +++ /dev/null @@ -1,3615 +0,0 @@ -// Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - -//===- PTOViewToMemref.cpp ------------------------------------------------===// -//===----------------------------------------------------------------------===// -// -// Lower PTO tile/view operations to memref-based IR while preserving tile -// metadata through binding ops and SSA backtracking. 
- -#include "PTO/IR/PTO.h" -#include "PTO/IR/PTOTypeUtils.h" -#include "PTO/Transforms/Passes.h" - -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" - -#include "mlir/IR/AsmState.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/SymbolTable.h" -#include "mlir/Pass/Pass.h" - -namespace mlir { -namespace pto { -#define GEN_PASS_DEF_PTOVIEWTOMEMREF -#include "PTO/Transforms/Passes.h.inc" -} // namespace pto -} // namespace mlir - -#include "llvm/ADT/SmallVector.h" -#include "llvm/Support/raw_ostream.h" -#include "Utils.h" // 假设包含一些通用的工具函数 - -#include -#include -#include - -#define DEBUG_TYPE "pto-view-to-memref" - -using namespace mlir; - -namespace mlir { -namespace pto { - -static constexpr llvm::StringLiteral kLoweredSetValidShapeAttrName = - "__pto.lowered_set_validshape"; -static constexpr llvm::StringLiteral kForceDynamicValidShapeAttrName = - "__pto.force_dynamic_valid_shape"; - -namespace { - -static void markForceDynamicValidShape(Operation *op, bool force, - MLIRContext *ctx); - -static Type convertPTOTypeToMemRef(Type t); - -constexpr size_t kTileRank2D = 2; -constexpr size_t kRowDimensionIndex = 0; -constexpr size_t kColumnDimensionIndex = 1; -constexpr unsigned kShapeVectorInlineCapacity = 4; -constexpr unsigned kOperationVectorInlineCapacity = 8; - -constexpr int64_t kElementBytes1 = 1; -constexpr int64_t kElementBytes2 = 2; -constexpr int64_t kElementBytes4 = 4; -constexpr int64_t kElementBytes8 = 8; -constexpr int64_t kElementBytes16 = 16; -constexpr int64_t kElementBytes32 = 32; - -constexpr int64_t kInnerExtent1 = 1; -constexpr int64_t kInnerExtent2 = 2; -constexpr int64_t kInnerExtent4 = 4; -constexpr int64_t kInnerExtent8 = 8; -constexpr int64_t kInnerExtent16 = 16; -constexpr int64_t kInnerExtent32 = 32; - -constexpr 
int32_t kFractalSize32 = 32; -constexpr int32_t kFractalSize512 = 512; -constexpr int32_t kFractalSize1024 = 1024; - -constexpr int32_t kBLayoutColMajor = - static_cast(BLayout::ColMajor); -constexpr int32_t kSLayoutNoneBox = - static_cast(SLayout::NoneBox); -constexpr int32_t kSLayoutRowMajor = - static_cast(SLayout::RowMajor); -constexpr int32_t kSLayoutColMajor = - static_cast(SLayout::ColMajor); -constexpr int32_t kCompactModeRowPlusOne = - static_cast(CompactMode::RowPlusOne); - -constexpr unsigned kThirdOperandIndex = 2; -constexpr unsigned kFourthOperandIndex = 3; -constexpr unsigned kFifthOperandIndex = 4; -constexpr unsigned kSixthOperandIndex = 5; - -template -using SmallInlineVector = SmallVector; - -template -using DefaultInlineVector = SmallVector; - -// ============================================================================= -// Helper: Metadata Backtracking (核心机制) -// ============================================================================= -// 从一个 MemRef Value 向上回溯,找到它绑定的 TileBufConfig。 -// 这解决了 "Type Erasure" 问题:memref 类型本身不包含 config,但 SSA 定义链包含。 -static mlir::pto::TileBufConfigAttr lookupConfig(Value v) { - // 1. 最直接的情况:它就是 bind_tile 的结果 - if (auto bind = v.getDefiningOp()) { - return bind.getConfig(); - } - // PointerCastOp can also carry tile metadata (used when alloc_tile specifies - // an explicit address). - if (auto pc = v.getDefiningOp()) { - if (auto cfg = pc.getConfig()) - return *cfg; - return {}; - } - - // 2. 
穿透 View 操作 (SubView, Cast 等) 向上查找 - if (auto subview = v.getDefiningOp()) { - return lookupConfig(subview.getSource()); - } - if (auto cast = v.getDefiningOp()) { - return lookupConfig(cast.getSource()); - } - if (auto cast = v.getDefiningOp()) { - return lookupConfig(cast.getSource()); - } - - // 如果追溯到 BlockArgument (函数参数) 或其他无法穿透的 Op,则返回空 - return {}; -} - -// ============================================================================= -// Helper: Valid dims backtracking (v_row / v_col) -// ============================================================================= -static void lookupValidDims(Value v, Value &vRow, Value &vCol) { - if (auto bind = v.getDefiningOp()) { - vRow = bind.getValidRow(); - vCol = bind.getValidCol(); - return; - } - if (auto pc = v.getDefiningOp()) { - vRow = pc.getValidRow(); - vCol = pc.getValidCol(); - return; - } - if (auto subview = v.getDefiningOp()) { - lookupValidDims(subview.getSource(), vRow, vCol); - return; - } - if (auto cast = v.getDefiningOp()) { - lookupValidDims(cast.getSource(), vRow, vCol); - return; - } - if (auto cast = v.getDefiningOp()) { - lookupValidDims(cast.getSource(), vRow, vCol); - return; - } - vRow = Value(); - vCol = Value(); -} - -// ============================================================================= -// Helper Functions for Layout Normalization -// ============================================================================= - -struct TileLayoutInfo { - int64_t rowStride = 1; - int64_t colStride = 1; - int64_t innerRows = 1; - int64_t innerCols = 1; - bool boxed = false; // slayout != NoneBox -}; - -struct TileLayoutConfig { - int32_t bLayout = 0; - int32_t sLayout = 0; - int32_t fractalSize = kFractalSize512; - int32_t compactMode = 0; -}; - -static int64_t getElemBytes(Type elemTy) { - unsigned bytes = getPTOStorageElemByteSize(elemTy); - return bytes == 0 ? 
-1 : static_cast(bytes); -} - -template -static bool readEnumAttrOrIntegerI32(Attribute attr, int32_t &out) { - if (auto enumAttr = dyn_cast(attr)) { - out = static_cast(enumAttr.getValue()); - return true; - } - if (auto intAttr = dyn_cast(attr)) { - out = static_cast(intAttr.getInt()); - return true; - } - return false; -} - -static bool readBLayoutI32(Attribute attr, int32_t &out) { - return readEnumAttrOrIntegerI32(attr, out); -} - -static bool readSLayoutI32(Attribute attr, int32_t &out) { - return readEnumAttrOrIntegerI32(attr, out); -} - -static bool readCompactModeI32(Attribute attr, int32_t &out) { - return readEnumAttrOrIntegerI32(attr, out); -} - -static Value peelIndexLikeCast(Value value) { - while (true) { - if (auto castOp = value.getDefiningOp()) { - value = castOp.getIn(); - continue; - } - if (auto extOp = value.getDefiningOp()) { - value = extOp.getIn(); - continue; - } - if (auto extOp = value.getDefiningOp()) { - value = extOp.getIn(); - continue; - } - if (auto truncOp = value.getDefiningOp()) { - value = truncOp.getIn(); - continue; - } - return value; - } -} - -static bool getConstIndexValue(Value value, int64_t &out) { - value = peelIndexLikeCast(value); - if (auto constIndex = value.getDefiningOp()) { - out = constIndex.value(); - return true; - } - if (auto constInt = value.getDefiningOp()) { - out = constInt.value(); - return true; - } - auto constOp = value.getDefiningOp(); - auto intAttr = - constOp ? 
dyn_cast(constOp.getValue()) : IntegerAttr(); - if (!intAttr) - return false; - out = intAttr.getInt(); - return true; -} - -static TileLayoutConfig getTileLayoutConfig(mlir::pto::TileBufConfigAttr cfg) { - TileLayoutConfig config; - (void)readBLayoutI32(cfg.getBLayout(), config.bLayout); - (void)readSLayoutI32(cfg.getSLayout(), config.sLayout); - if (auto attr = dyn_cast(cfg.getSFractalSize())) - config.fractalSize = static_cast(attr.getInt()); - (void)readCompactModeI32(cfg.getCompactMode(), config.compactMode); - return config; -} - -static bool getFractal512InnerExtent(int64_t elemBytes, int64_t &extent) { - switch (elemBytes) { - case kElementBytes1: - extent = kInnerExtent32; - return true; - case kElementBytes2: - extent = kInnerExtent16; - return true; - case kElementBytes4: - extent = kInnerExtent8; - return true; - case kElementBytes8: - extent = kInnerExtent4; - return true; - case kElementBytes16: - extent = kInnerExtent2; - return true; - case kElementBytes32: - extent = kInnerExtent1; - return true; - default: - return false; - } -} - -static bool computeBoxInnerShape(const TileLayoutConfig &config, Type elemTy, - TileLayoutInfo &info) { - info.boxed = config.sLayout != kSLayoutNoneBox; - if (!info.boxed) { - info.innerRows = kInnerExtent1; - info.innerCols = kInnerExtent1; - return true; - } - - int64_t elemBytes = getElemBytes(elemTy); - if (elemBytes <= 0) - return false; - - switch (config.fractalSize) { - case kFractalSize1024: - info.innerRows = kInnerExtent16; - info.innerCols = kInnerExtent16; - return true; - case kFractalSize32: - info.innerRows = kInnerExtent16; - info.innerCols = kInnerExtent2; - return true; - case kFractalSize512: - if (config.sLayout == kSLayoutRowMajor) { - info.innerRows = kInnerExtent16; - return getFractal512InnerExtent(elemBytes, info.innerCols); - } - if (config.sLayout == kSLayoutColMajor) { - if (!getFractal512InnerExtent(elemBytes, info.innerRows)) - return false; - info.innerCols = kInnerExtent16; - return 
true; - } - return false; - default: - return false; - } -} - -static bool computeTilePointerStrides(const TileLayoutConfig &config, - ArrayRef shape, - TileLayoutInfo &info) { - int64_t rows = shape[0]; - int64_t cols = shape[1]; - auto applyCompactToMajorStride = [&](int64_t majorStride) -> int64_t { - if (config.compactMode == kCompactModeRowPlusOne) - return majorStride + kInnerExtent1; - return majorStride; - }; - if (!info.boxed) { - if (config.bLayout == kBLayoutColMajor) { - info.rowStride = kInnerExtent1; - info.colStride = applyCompactToMajorStride(rows); - return true; - } - info.rowStride = applyCompactToMajorStride(cols); - info.colStride = kInnerExtent1; - return true; - } - - if (config.bLayout == kBLayoutColMajor) { - if (config.sLayout != kSLayoutRowMajor) - return false; - info.rowStride = info.innerCols; - info.colStride = applyCompactToMajorStride(rows); - return true; - } - - info.rowStride = applyCompactToMajorStride(cols); - info.colStride = info.innerRows; - return true; -} - -static bool computeTileLayoutInfo(mlir::pto::TileBufConfigAttr cfg, Type elemTy, - ArrayRef shape, - TileLayoutInfo &info) { - if (shape.size() != kTileRank2D || - llvm::is_contained(shape, ShapedType::kDynamic)) - return false; - - TileLayoutConfig config = getTileLayoutConfig(cfg); - return computeBoxInnerShape(config, elemTy, info) && - computeTilePointerStrides(config, shape, info); -} - -static void collectAffineAddTerms(AffineExpr root, - SmallVectorImpl &terms) { - SmallInlineVector pending{root}; - while (!pending.empty()) { - AffineExpr current = pending.pop_back_val(); - auto addExpr = llvm::dyn_cast(current); - if (!addExpr || addExpr.getKind() != AffineExprKind::Add) { - terms.push_back(current); - continue; - } - pending.push_back(addExpr.getRHS()); - pending.push_back(addExpr.getLHS()); - } -} - -static bool tryAssignAffineStride(AffineExpr expr, - MutableArrayRef strides) { - if (auto dim = llvm::dyn_cast(expr)) { - strides[dim.getPosition()] = 1; - 
return true; - } - - auto mulExpr = llvm::dyn_cast(expr); - if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul) - return false; - - auto assignStride = [&](AffineExpr dimExpr, - AffineExpr constantExpr) -> bool { - auto dim = llvm::dyn_cast(dimExpr); - auto constant = llvm::dyn_cast(constantExpr); - if (!dim || !constant) - return false; - strides[dim.getPosition()] = constant.getValue(); - return true; - }; - return assignStride(mulExpr.getLHS(), mulExpr.getRHS()) || - assignStride(mulExpr.getRHS(), mulExpr.getLHS()); -} - -[[maybe_unused]] static void decomposeStridedLayout(AffineMap map, - SmallVectorImpl &strides) { - strides.assign(map.getNumDims(), 0); - if (map.getNumResults() != 1) - return; - - SmallInlineVector terms; - collectAffineAddTerms(map.getResult(0), terms); - for (AffineExpr term : terms) - (void)tryAssignAffineStride(term, strides); -} - -static Value makeIndexConstant(IRRewriter &rewriter, Location loc, - int64_t value) { - return rewriter.create(loc, rewriter.getIndexType(), - rewriter.getIndexAttr(value)); -} - -static SmallVector computeCompactStrides(ArrayRef shape) { - SmallVector strides(shape.size(), 1); - int64_t stride = 1; - for (int i = static_cast(shape.size()) - 1; i >= 0; --i) { - strides[i] = stride; - if (shape[i] != ShapedType::kDynamic) - stride *= shape[i]; - } - return strides; -} - -static void materializeStaticValidDims(IRRewriter &rewriter, Location loc, - mlir::pto::TileBufType tbTy, Value &vRow, - Value &vCol) { - ArrayRef validShape = tbTy.getValidShape(); - if (tbTy.hasDynamicValid()) - return; - if (!validShape.empty() && validShape[kRowDimensionIndex] >= 0) - vRow = makeIndexConstant(rewriter, loc, validShape[kRowDimensionIndex]); - if (validShape.size() >= kTileRank2D && - validShape[kColumnDimensionIndex] >= 0) - vCol = makeIndexConstant(rewriter, loc, validShape[kColumnDimensionIndex]); -} - -static bool checkMultipleOf(Operation *op, int64_t value, int64_t divisor, - StringRef label) { - if (divisor <= 0) 
{ - op->emitError("boxed layout requires positive divisor for ") << label; - return false; - } - if (value % divisor == 0) - return true; - op->emitError("boxed layout requires ") - << label << " multiple of " << divisor << ", got " << value; - return false; -} - -// 确保 Value 是 Index 类型 -static Value ensureIndex(IRRewriter &rewriter, Location loc, Value v, - Operation *anchorOp) { - if (v.getType().isIndex()) - return v; - if (isa(v.getType())) - return rewriter.create(loc, rewriter.getIndexType(), v); - if (anchorOp) - anchorOp->emitError() << "expected index or integer, but got " << v.getType(); - return Value(); -} - -static bool tryGetIndexAttrFromValue(IRRewriter &rewriter, Value v, - IntegerAttr &constAttr) { - if (auto cOp = v.getDefiningOp()) { - constAttr = rewriter.getIndexAttr(cOp.value()); - return true; - } - if (auto cInt = v.getDefiningOp()) { - constAttr = rewriter.getIndexAttr(cInt.value()); - return true; - } - return false; -} - -static void appendMixedIndex(IRRewriter &rewriter, Location loc, Value v, - Operation *anchorOp, - SmallVectorImpl &mixedVals) { - IntegerAttr constAttr; - if (tryGetIndexAttrFromValue(rewriter, v, constAttr)) { - mixedVals.push_back(constAttr); - return; - } - mixedVals.push_back(ensureIndex(rewriter, loc, v, anchorOp)); -} - -static bool foldAddPtrChainIntoOffset(IRRewriter &rewriter, Location loc, - Value &base, Value &totalOffset) { - bool folded = false; - while (auto add = base.getDefiningOp()) { - folded = true; - Value off = ensureIndex(rewriter, loc, add.getOperand(1), add); - totalOffset = - totalOffset ? 
rewriter.create(loc, totalOffset, off) : off; - base = add.getOperand(0); - } - return folded; -} - -static Value clampSubViewValidDim(IRRewriter &rewriter, Location loc, - Value explicitValid, int64_t size, - Operation *anchorOp) { - Value sizeVal = rewriter.create(loc, size); - if (!explicitValid) - return sizeVal; - - int64_t cst = 0; - if (getConstIndexValue(explicitValid, cst)) - return rewriter.create(loc, std::min(cst, size)); - - Value v = ensureIndex(rewriter, loc, explicitValid, anchorOp); - Value lt = rewriter.create(loc, arith::CmpIPredicate::slt, v, - sizeVal); - return rewriter.create(loc, lt, v, sizeVal); -} - -[[maybe_unused]] static void dumpPretty(Operation *op, llvm::raw_ostream &os) { - OpPrintingFlags flags; - flags.useLocalScope(); - AsmState state(op, flags); - op->print(os, state); - os << "\n"; - os.flush(); -} - -// ============================================================================= -// Type Converter Logic -// ============================================================================= - -static SmallVector buildTileMemRefStrides(mlir::pto::TileBufType tbTy) { - TileLayoutInfo info; - if (computeTileLayoutInfo(tbTy.getConfigAttr(), tbTy.getElementType(), - tbTy.getShape(), info)) { - return {info.rowStride, info.colStride}; - } - return computeCompactStrides(tbTy.getShape()); -} - -static Type convertTileBufTypeToMemRef(mlir::pto::TileBufType tbTy) { - auto layoutAttr = StridedLayoutAttr::get(tbTy.getContext(), - ShapedType::kDynamic, - buildTileMemRefStrides(tbTy)); - return MemRefType::get(tbTy.getShape(), tbTy.getElementType(), layoutAttr, - tbTy.getMemorySpace()); -} - -static Type convertPTOTypeToMemRef(Type t) { - // 1. 处理 !pto.ptr - if (auto pty = dyn_cast(t)) { - return MemRefType::get({ShapedType::kDynamic}, pty.getElementType(), - MemRefLayoutAttrInterface(), Attribute()); - } - - // 2. 
处理 !pto.tile_buf<...> - if (auto tbTy = dyn_cast(t)) - return convertTileBufTypeToMemRef(tbTy); - if (auto tvTy = dyn_cast(t)) - return MemRefType::get(tvTy.getShape(), tvTy.getElementType(), - MemRefLayoutAttrInterface(), Attribute()); - if (auto partTy = dyn_cast(t)) - return MemRefType::get(partTy.getShape(), partTy.getElementType(), - MemRefLayoutAttrInterface(), Attribute()); - // 其他类型透传 - return t; -} - -// Ensure scf.if result types follow the rewritten yield operand types. -// PTOViewToMemref rewrites tile values to memref in branch bodies, but scf.if -// result types are not auto-updated by those op-local rewrites. -static LogicalResult reconcileSCFIfResultTypes(func::FuncOp func) { - DefaultInlineVector ifOps; - func.walk([&](scf::IfOp ifOp) { ifOps.push_back(ifOp); }); - - for (scf::IfOp ifOp : ifOps) { - if (ifOp.getNumResults() == 0) - continue; - - auto thenYield = dyn_cast(ifOp.thenBlock()->getTerminator()); - auto elseYield = dyn_cast(ifOp.elseBlock()->getTerminator()); - if (!thenYield || !elseYield) { - ifOp.emitError("result-bearing scf.if must end with scf.yield in both " - "then/else regions"); - return failure(); - } - - if (thenYield.getNumOperands() != ifOp.getNumResults() || - elseYield.getNumOperands() != ifOp.getNumResults()) { - ifOp.emitError("scf.if result count does not match yielded values"); - return failure(); - } - - for (unsigned i = 0; i < ifOp.getNumResults(); ++i) { - Type thenTy = thenYield.getOperand(i).getType(); - Type elseTy = elseYield.getOperand(i).getType(); - if (thenTy != elseTy) { - ifOp.emitError() << "scf.if branch yield type mismatch at result #" << i - << ": then=" << thenTy << ", else=" << elseTy; - return failure(); - } - - if (ifOp.getResult(i).getType() != thenTy) - ifOp.getResult(i).setType(thenTy); - } - } - - return success(); -} - -static LogicalResult reconcileSCFForResultTypes(func::FuncOp func) { - DefaultInlineVector forOps; - func.walk([&](scf::ForOp forOp) { forOps.push_back(forOp); }); - - for 
(scf::ForOp forOp : forOps) { - if (forOp.getNumResults() == 0) - continue; - - auto yield = dyn_cast(forOp.getBody()->getTerminator()); - if (!yield) { - forOp.emitError("result-bearing scf.for must end with scf.yield"); - return failure(); - } - - if (yield.getNumOperands() != forOp.getNumResults() || - forOp.getInitArgs().size() != forOp.getNumResults()) { - forOp.emitError("scf.for result count does not match iter/yield values"); - return failure(); - } - - for (unsigned i = 0; i < forOp.getNumResults(); ++i) { - Type initTy = forOp.getInitArgs()[i].getType(); - Type yieldTy = yield.getOperand(i).getType(); - if (initTy != yieldTy) { - forOp.emitError() << "scf.for init/yield type mismatch at result #" << i - << ": init=" << initTy << ", yield=" << yieldTy; - return failure(); - } - - BlockArgument iterArg = forOp.getRegionIterArg(i); - if (iterArg.getType() != initTy) - iterArg.setType(initTy); - if (forOp.getResult(i).getType() != initTy) - forOp.getResult(i).setType(initTy); - } - } - - return success(); -} - -static LogicalResult markLoweredSetValidShapeOps(func::FuncOp func, - MLIRContext *ctx) { - WalkResult result = func.walk([&](mlir::pto::SetValidShapeOp op) { - if (isa(op.getSource().getType())) { - if (!lookupConfig(op.getSource())) { - op.emitError( - "set_validshape requires a locally bound tile source; function " - "arguments/results are unsupported"); - return WalkResult::interrupt(); - } - op->setAttr(kLoweredSetValidShapeAttrName, UnitAttr::get(ctx)); - return WalkResult::advance(); - } - op->removeAttr(kLoweredSetValidShapeAttrName); - return WalkResult::advance(); - }); - return result.wasInterrupted() ? 
failure() : success(); -} - -static void markForceDynamicValidShape(Operation *op, bool force, - MLIRContext *ctx) { - if (force) { - op->setAttr(kForceDynamicValidShapeAttrName, UnitAttr::get(ctx)); - return; - } - op->removeAttr(kForceDynamicValidShapeAttrName); -} - -[[maybe_unused]] static void rewriteFunctionSignature(func::FuncOp func, MLIRContext *ctx) { - Block &entry = func.front(); - auto fnTy = func.getFunctionType(); - - SmallVector newInputs; - for (Type type : fnTy.getInputs()) - newInputs.push_back(convertPTOTypeToMemRef(type)); - - SmallVector newResults; - for (Type type : fnTy.getResults()) - newResults.push_back(convertPTOTypeToMemRef(type)); - - for (unsigned i = 0; i < entry.getNumArguments(); ++i) { - if (entry.getArgument(i).getType() != newInputs[i]) - entry.getArgument(i).setType(newInputs[i]); - } - func.setFunctionType(FunctionType::get(ctx, newInputs, newResults)); -} - -[[maybe_unused]] static LogicalResult lowerAllocTileOps(func::FuncOp func, MLIRContext *ctx) { - DefaultInlineVector allocTiles; - func.walk([&](mlir::pto::AllocTileOp op) { allocTiles.push_back(op); }); - - for (auto op : allocTiles) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - Location loc = op.getLoc(); - - auto tbTy = dyn_cast(op.getResult().getType()); - if (!tbTy) - continue; - - SmallInlineVector shape(tbTy.getShape().begin(), - tbTy.getShape().end()); - Type elemTy = tbTy.getElementType(); - SmallVector strides = buildTileMemRefStrides(tbTy); - - auto targetLayout = - StridedLayoutAttr::get(ctx, ShapedType::kDynamic, strides); - auto targetType = - MemRefType::get(shape, elemTy, targetLayout, tbTy.getMemorySpace()); - - Value vRow = op.getValidRow(); - Value vCol = op.getValidCol(); - materializeStaticValidDims(rewriter, loc, tbTy, vRow, vCol); - - auto configAttr = tbTy.getConfigAttr(); - if (!configAttr) - configAttr = pto::TileBufConfigAttr::getDefault(ctx); - - if (Value addr = op.getAddr()) { - auto pc = rewriter.create( - loc, 
targetType, ValueRange{addr}, vRow ? vRow : Value(), - vCol ? vCol : Value(), configAttr); - markForceDynamicValidShape(pc, tbTy.hasDynamicValid(), ctx); - auto bindOp = rewriter.create( - loc, targetType, pc.getResult(), vRow ? vRow : Value(), - vCol ? vCol : Value(), configAttr); - markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); - rewriter.replaceOp(op, bindOp.getResult()); - continue; - } - - auto allocLayout = StridedLayoutAttr::get(ctx, 0, strides); - auto allocType = - MemRefType::get(shape, elemTy, allocLayout, tbTy.getMemorySpace()); - Value alloc = rewriter.create(loc, allocType); - auto bindOp = rewriter.create( - loc, targetType, alloc, vRow ? vRow : Value(), vCol ? vCol : Value(), - configAttr); - markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); - rewriter.replaceOp(op, bindOp.getResult()); - } - return success(); -} - -[[maybe_unused]] static LogicalResult lowerDeclareTileOps(func::FuncOp func, MLIRContext *ctx) { - DefaultInlineVector declaredTiles; - func.walk([&](mlir::pto::DeclareTileOp op) { declaredTiles.push_back(op); }); - - for (auto op : declaredTiles) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - Location loc = op.getLoc(); - - auto tbTy = dyn_cast(op.getTile().getType()); - if (!tbTy) { - op.emitError("declare_tile result must be tile_buf type"); - return failure(); - } - - auto targetType = dyn_cast(convertPTOTypeToMemRef(tbTy)); - if (!targetType) { - op.emitError("failed to convert declare_tile result to memref type"); - return failure(); - } - - auto configAttr = tbTy.getConfigAttr(); - if (!configAttr) - configAttr = pto::TileBufConfigAttr::getDefault(ctx); - - Value vRow; - Value vCol; - materializeStaticValidDims(rewriter, loc, tbTy, vRow, vCol); - - auto declaredMemRef = - rewriter.create(loc, targetType); - auto bindOp = rewriter.create( - loc, targetType, declaredMemRef.getResult(), vRow ? vRow : Value(), - vCol ? 
vCol : Value(), configAttr); - markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); - rewriter.replaceOp(op, bindOp.getResult()); - } - return success(); -} - -static Value castIndexToI64(IRRewriter &rewriter, Location loc, Value value) { - Type i64Ty = rewriter.getI64Type(); - if (value.getType() == i64Ty) - return value; - return rewriter.create(loc, i64Ty, value).getResult(); -} - -static FailureOr -materializePtrToIntAddPtrAddress(IRRewriter &rewriter, Location loc, - mlir::pto::PtrToIntOp anchor, Value source) { - SmallVector addPtrChain; - Value base = source; - while (auto add = base.getDefiningOp()) { - addPtrChain.push_back(add); - base = add.getOperand(0); - } - - if (addPtrChain.empty()) - return failure(); - - auto baseMemTy = dyn_cast(base.getType()); - if (!baseMemTy) { - anchor.emitOpError( - "pto.addptr source base could not be lowered to a GM memref"); - return failure(); - } - - Value byteAddress = rewriter.create( - loc, rewriter.getI64Type(), base); - for (auto add : addPtrChain) { - auto addPtrTy = dyn_cast(add.getResult().getType()); - if (!addPtrTy) { - anchor.emitOpError("requires pto.addptr source to have !pto.ptr result " - "type before byte-address lowering"); - return failure(); - } - - unsigned elemBytes = - mlir::pto::getPTOStorageElemByteSize(addPtrTy.getElementType()); - if (elemBytes == 0) { - anchor.emitOpError("cannot lower pto.addptr source with unknown element " - "byte size to a byte address"); - return failure(); - } - - Value byteOffset = castIndexToI64(rewriter, loc, add.getOffset()); - if (elemBytes != 1) { - Value elemBytesValue = - rewriter.create(loc, elemBytes, 64); - byteOffset = - rewriter.create(loc, byteOffset, elemBytesValue) - .getResult(); - } - byteAddress = - rewriter.create(loc, byteAddress, byteOffset).getResult(); - } - - return byteAddress; -} - -static LogicalResult lowerIntToPtrOps(func::FuncOp func, MLIRContext *ctx) { - DefaultInlineVector intToPtrs; - func.walk([&](mlir::pto::IntToPtrOp op) 
{ intToPtrs.push_back(op); }); - - for (auto op : intToPtrs) { - if (!isa(op.getResult().getType())) - continue; - - auto targetTy = - dyn_cast(convertPTOTypeToMemRef(op.getResult().getType())); - if (!targetTy) { - op.emitError("failed to convert inttoptr result to memref type"); - return failure(); - } - - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - auto lowered = - rewriter.create(op.getLoc(), targetTy, - op.getAddr()); - lowered->setAttrs(op->getAttrs()); - rewriter.replaceOp(op, lowered.getResult()); - } - - return success(); -} - -static LogicalResult lowerPtrToIntOps(func::FuncOp func, MLIRContext *ctx) { - DefaultInlineVector ptrToInts; - func.walk([&](mlir::pto::PtrToIntOp op) { ptrToInts.push_back(op); }); - - for (auto op : ptrToInts) { - Value source = op.getPtr(); - if (source.getDefiningOp()) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - FailureOr byteAddress = - materializePtrToIntAddPtrAddress(rewriter, op.getLoc(), op, source); - if (failed(byteAddress)) - return failure(); - rewriter.replaceOp(op, *byteAddress); - continue; - } - - if (isa(source.getType())) - continue; - } - - DefaultInlineVector remaining; - func.walk([&](mlir::pto::PtrToIntOp op) { - if (isa(op.getPtr().getType())) - remaining.push_back(op); - }); - for (auto op : remaining) { - op.emitError("ptrtoint source could not be lowered to a GM memref"); - return failure(); - } - - return success(); -} - -[[maybe_unused]] static LogicalResult lowerMakeTensorViewOps(func::FuncOp func, MLIRContext *ctx) { - DefaultInlineVector makeViews; - func.walk([&](mlir::pto::MakeTensorViewOp op) { makeViews.push_back(op); }); - - for (auto op : makeViews) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - Location loc = op.getLoc(); - - Value baseBuf = op.getOperand(0); - OpFoldResult off0 = rewriter.getIndexAttr(0); - bool foldedAddPtr = false; - { - Value cur = baseBuf; - Value totalOffset; - while (auto add = cur.getDefiningOp()) { - 
foldedAddPtr = true; - Value off = ensureIndex(rewriter, loc, add.getOperand(1), add); - totalOffset = totalOffset ? rewriter.create(loc, totalOffset, off) - : off; - cur = add.getOperand(0); - } - if (cur != baseBuf) { - baseBuf = cur; - off0 = totalOffset ? OpFoldResult(totalOffset) : off0; - } - } - - auto baseMr = dyn_cast(baseBuf.getType()); - if (!baseMr) { - op.emitError("make_tensor_view base must be memref"); - return failure(); - } - - size_t rank = op.getShape().size(); - int64_t dyn = ShapedType::kDynamic; - SmallVector dynStrides(rank, dyn); - auto layout = - StridedLayoutAttr::get(ctx, /*offset=*/dyn, /*strides=*/dynStrides); - SmallVector dynShape(rank, dyn); - auto mrTy = MemRefType::get(dynShape, baseMr.getElementType(), layout, - baseMr.getMemorySpace()); - - SmallInlineVector sizes; - for (Value value : op.getShape()) - sizes.push_back(ensureIndex(rewriter, loc, value, op)); - SmallInlineVector strides; - for (Value value : op.getStrides()) - strides.push_back(ensureIndex(rewriter, loc, value, op)); - - auto rc = rewriter.create(loc, mrTy, baseBuf, off0, - sizes, strides); - if (foldedAddPtr) - rc->setAttr("pto.addptr_trace", rewriter.getUnitAttr()); - if (auto layoutAttr = op.getLayoutAttr()) - rc->setAttr("layout", layoutAttr); - rewriter.replaceOp(op, rc.getResult()); - } - return success(); -} - -[[maybe_unused]] static LogicalResult lowerTensorViewDimOps(func::FuncOp func, MLIRContext *ctx) { - DefaultInlineVector tvDims; - func.walk([&](mlir::pto::GetTensorViewDimOp op) { tvDims.push_back(op); }); - - for (auto op : tvDims) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - Value view = op.getTensorView(); - auto mrTy = dyn_cast(view.getType()); - if (!mrTy) - continue; - Value dim = rewriter.create(op.getLoc(), view, op.getDimIndex()); - rewriter.replaceOp(op, dim); - } - return success(); -} - -[[maybe_unused]] static LogicalResult foldAddPtrIntoScalarOps(func::FuncOp func, MLIRContext *ctx) { - DefaultInlineVector 
loadScalars; - func.walk([&](mlir::pto::LoadScalarOp op) { loadScalars.push_back(op); }); - for (auto op : loadScalars) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - Location loc = op.getLoc(); - - Value base = op.getPtr(); - Value totalOffset = ensureIndex(rewriter, loc, op.getOffset(), op); - bool foldedAddPtr = foldAddPtrChainIntoOffset(rewriter, loc, base, totalOffset); - if (foldedAddPtr) { - auto newOp = - rewriter.create(loc, op.getValue().getType(), base, - totalOffset); - rewriter.replaceOp(op, newOp.getValue()); - } - } - - DefaultInlineVector storeScalars; - func.walk([&](mlir::pto::StoreScalarOp op) { storeScalars.push_back(op); }); - for (auto op : storeScalars) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - Location loc = op.getLoc(); - - Value base = op.getPtr(); - Value totalOffset = ensureIndex(rewriter, loc, op.getOffset(), op); - bool foldedAddPtr = foldAddPtrChainIntoOffset(rewriter, loc, base, totalOffset); - if (foldedAddPtr) { - rewriter.create(loc, base, totalOffset, op.getValue()); - rewriter.eraseOp(op); - } - } - - DefaultInlineVector addPtrs; - func.walk([&](mlir::pto::AddPtrOp op) { addPtrs.push_back(op.getOperation()); }); - bool changed = true; - while (changed) { - changed = false; - for (auto &op : addPtrs) { - if (!op) - continue; - if (op->use_empty()) { - op->erase(); - op = nullptr; - changed = true; - } - } - } - for (Operation *op : addPtrs) { - if (!op) - continue; - op->emitError( - "addptr must feed make_tensor_view or load/store_scalar for lowering"); - return failure(); - } - return success(); -} - -static LogicalResult lowerPartitionViewOps(func::FuncOp func, MLIRContext *ctx) { - DefaultInlineVector partitionViews; - func.walk([&](mlir::pto::PartitionViewOp op) { partitionViews.push_back(op); }); - - for (auto op : partitionViews) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - Location loc = op.getLoc(); - Value src = op.getOperand(0); - auto srcMrTy = 
dyn_cast(src.getType()); - if (!srcMrTy) - continue; - int64_t rank = srcMrTy.getRank(); - - SmallVector staticSizes; - SmallVector mixedSizes; - for (Value size : op.getSizes()) { - IntegerAttr constAttr; - bool isStatic = false; - if (auto cOp = size.getDefiningOp()) { - constAttr = rewriter.getIndexAttr(cOp.value()); - isStatic = true; - } else if (auto cInt = size.getDefiningOp()) { - constAttr = rewriter.getIndexAttr(cInt.value()); - isStatic = true; - } - - if (isStatic) { - mixedSizes.push_back(constAttr); - staticSizes.push_back(constAttr.getInt()); - } else { - mixedSizes.push_back(ensureIndex(rewriter, loc, size, op)); - staticSizes.push_back(ShapedType::kDynamic); - } - } - - SmallVector mixedOffsets; - for (Value offset : op.getOffsets()) { - appendMixedIndex(rewriter, loc, offset, op, mixedOffsets); - } - - int64_t dyn = ShapedType::kDynamic; - SmallVector dynStrides(rank, dyn); - auto layout = StridedLayoutAttr::get(ctx, dyn, dynStrides); - auto resTy = MemRefType::get(staticSizes, srcMrTy.getElementType(), layout, - srcMrTy.getMemorySpace()); - - SmallVector mixedStrides(rank, rewriter.getIndexAttr(1)); - auto sv = rewriter.create(loc, resTy, src, mixedOffsets, - mixedSizes, mixedStrides); - if (Operation *srcDef = src.getDefiningOp()) { - if (auto layoutAttr = srcDef->getAttrOfType("layout")) - sv->setAttr("layout", layoutAttr); - } - rewriter.replaceOp(op, sv.getResult()); - } - return success(); -} - -static LogicalResult lowerSubViewOps(func::FuncOp func, MLIRContext *ctx) { - DefaultInlineVector subViews; - func.walk([&](mlir::pto::SubViewOp op) { subViews.push_back(op); }); - - for (auto op : subViews) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - Location loc = op.getLoc(); - auto resultTileTy = - dyn_cast(op.getResult().getType()); - Value src = op->getOperand(0); - auto srcMrTy = dyn_cast(src.getType()); - if (!srcMrTy) { - op.emitError("pto.subview source must be lowered to memref first"); - return failure(); - } - - 
ArrayAttr sizeAttr = op.getSizes(); - SmallVector staticSizes; - SmallVector mixedSizes; - for (Attribute attr : sizeAttr) { - int64_t size = cast(attr).getInt(); - staticSizes.push_back(size); - mixedSizes.push_back(rewriter.getIndexAttr(size)); - } - - SmallVector mixedOffsets; - for (Value offset : op.getOffsets()) { - appendMixedIndex(rewriter, loc, offset, op, mixedOffsets); - } - - auto configAttr = lookupConfig(src); - if (!configAttr) - configAttr = pto::TileBufConfigAttr::getDefault(ctx); - - TileLayoutInfo layoutInfo; - if (!computeTileLayoutInfo(configAttr, srcMrTy.getElementType(), - srcMrTy.getShape(), layoutInfo)) { - op.emitError("unsupported tile layout for pto.subview"); - return failure(); - } - - if (layoutInfo.boxed) { - if (staticSizes.size() != kTileRank2D || - op.getOffsets().size() != kTileRank2D) { - op.emitError("boxed layout subview expects 2D sizes/offsets"); - return failure(); - } - if (!checkMultipleOf(op, staticSizes[0], layoutInfo.innerRows, "row size") || - !checkMultipleOf(op, staticSizes[1], layoutInfo.innerCols, "col size")) { - return failure(); - } - - int64_t off0 = 0; - int64_t off1 = 0; - bool off0Const = getConstIndexValue(op.getOffsets()[0], off0); - bool off1Const = getConstIndexValue(op.getOffsets()[1], off1); - if (off0Const && - !checkMultipleOf(op, off0, layoutInfo.innerRows, "row offset")) { - return failure(); - } - if (off1Const && - !checkMultipleOf(op, off1, layoutInfo.innerCols, "col offset")) { - return failure(); - } - - } - - SmallVector srcStrides; - int64_t srcOffset = ShapedType::kDynamic; - if (failed(getStridesAndOffset(srcMrTy, srcStrides, srcOffset))) - srcStrides = computeCompactStrides(srcMrTy.getShape()); - - // Keep parent physical shape + strides for bound tile semantics. 
- auto resultLayout = - StridedLayoutAttr::get(ctx, ShapedType::kDynamic, srcStrides); - auto parentShape = srcMrTy.getShape(); - auto resultMemRefType = - MemRefType::get(parentShape, srcMrTy.getElementType(), resultLayout, - srcMrTy.getMemorySpace()); - - // Intermediate memref.subview keeps logical subview size. - auto subViewMemRefType = - MemRefType::get(staticSizes, srcMrTy.getElementType(), resultLayout, - srcMrTy.getMemorySpace()); - - SmallVector mixedStrides(staticSizes.size(), - rewriter.getIndexAttr(1)); - auto sv = rewriter.create(loc, subViewMemRefType, src, - mixedOffsets, mixedSizes, - mixedStrides); - - Value vRow; - Value vCol; - if (!staticSizes.empty()) - vRow = clampSubViewValidDim(rewriter, loc, op.getValidRow(), - staticSizes[0], op); - if (staticSizes.size() > 1) - vCol = clampSubViewValidDim(rewriter, loc, op.getValidCol(), - staticSizes[1], op); - - auto bindOp = rewriter.create( - loc, resultMemRefType, sv.getResult(), vRow ? vRow : Value(), - vCol ? vCol : Value(), configAttr); - markForceDynamicValidShape(bindOp, - resultTileTy && resultTileTy.hasDynamicValid(), - ctx); - bindOp->setAttr("pto.view_semantics", rewriter.getStringAttr("subview")); - rewriter.replaceOp(op, bindOp.getResult()); - } - return success(); -} - -static Value buildTileBufViewLikeValue(Operation *anchorOp, Value src, - mlir::pto::TileBufType tbTy, - StringRef viewSemantics, - MLIRContext *ctx) { - Location loc = anchorOp->getLoc(); - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(anchorOp); - - auto srcMrTy = dyn_cast(src.getType()); - if (!srcMrTy) { - anchorOp->emitError("tile_buf view op src must be lowered to memref first"); - return Value(); - } - - auto targetType = dyn_cast(convertPTOTypeToMemRef(tbTy)); - if (!targetType) { - anchorOp->emitError("failed to convert tile_buf type to memref type"); - return Value(); - } - for (int64_t dim : targetType.getShape()) { - if (dim == ShapedType::kDynamic) { - anchorOp->emitError("dynamic shapes are not 
supported for tile_buf view ops"); - return Value(); - } - } - - Value parentVRow; - Value parentVCol; - lookupValidDims(src, parentVRow, parentVCol); - Value vRow = parentVRow; - Value vCol = parentVCol; - materializeStaticValidDims(rewriter, loc, tbTy, vRow, vCol); - - auto configAttr = tbTy.getConfigAttr(); - if (!configAttr) - configAttr = pto::TileBufConfigAttr::getDefault(ctx); - - auto bindOp = rewriter.create( - loc, targetType, src, vRow ? vRow : Value(), vCol ? vCol : Value(), - configAttr); - markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); - if (!viewSemantics.empty()) - bindOp->setAttr("pto.view_semantics", rewriter.getStringAttr(viewSemantics)); - return bindOp.getResult(); -} - -static LogicalResult lowerTileBufViewLikeOps(func::FuncOp func, MLIRContext *ctx) { - DefaultInlineVector reshapes; - func.walk([&](mlir::pto::TReshapeOp op) { reshapes.push_back(op); }); - for (auto op : reshapes) { - auto tbTy = dyn_cast(op.getResult().getType()); - if (!tbTy) { - op.emitError("treshape result must be tile_buf type"); - return failure(); - } - Value lowered = buildTileBufViewLikeValue(op, op->getOperand(0), tbTy, - "treshape", ctx); - if (!lowered) - return failure(); - IRRewriter rewriter(ctx); - rewriter.replaceOp(op, lowered); - } - - DefaultInlineVector bitcasts; - func.walk([&](mlir::pto::BitcastOp op) { bitcasts.push_back(op); }); - for (auto op : bitcasts) { - auto tbTy = dyn_cast(op.getResult().getType()); - if (!tbTy) { - op.emitError("bitcast result must be tile_buf type"); - return failure(); - } - Value lowered = buildTileBufViewLikeValue(op, op->getOperand(0), tbTy, - "bitcast", ctx); - if (!lowered) - return failure(); - IRRewriter rewriter(ctx); - rewriter.replaceOp(op, lowered); - } - return success(); -} - -// ============================================================================= -// The Pass Implementation -// ============================================================================= - -struct PTOViewToMemrefPass 
- : public mlir::pto::impl::PTOViewToMemrefBase { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PTOViewToMemrefPass) - - void runOnOperation() override { - ModuleOp mod = getOperation(); - MLIRContext *ctx = &getContext(); - - for (auto func : mod.getOps()) { - if (func.isExternal()) continue; - - // ------------------------------------------------------------------ - // Stage 0: ensure inttoptr values remain scalar-load/store only. - // ------------------------------------------------------------------ - if (failed(validateIntToPtrUses(func))) { - signalPassFailure(); - return; - } - - Block &entry = func.front(); - auto fnTy = func.getFunctionType(); - - // ------------------------------------------------------------------ - // Stage 0.10: Rewrite Function Signature - // ------------------------------------------------------------------ - SmallVector newInputs; - for (Type t : fnTy.getInputs()) newInputs.push_back(convertPTOTypeToMemRef(t)); - - SmallVector newResults; - for (Type t : fnTy.getResults()) newResults.push_back(convertPTOTypeToMemRef(t)); - - // Update entry block arguments - for (unsigned i = 0; i < entry.getNumArguments(); ++i) { - if (entry.getArgument(i).getType() != newInputs[i]) { - entry.getArgument(i).setType(newInputs[i]); - } - } - - // Update function type - func.setFunctionType(FunctionType::get(ctx, newInputs, newResults)); - - // ------------------------------------------------------------------ - // Stage 0.20: lower pto.inttoptr result types to GM memrefs. - // ------------------------------------------------------------------ - if (failed(lowerIntToPtrOps(func, ctx))) { - signalPassFailure(); - return; - } - - // ------------------------------------------------------------------ - // Stage 0.30: materialize pto.ptrtoint(addptr ...) byte offsets. 
- // ------------------------------------------------------------------ - if (failed(lowerPtrToIntOps(func, ctx))) { - signalPassFailure(); - return; - } - - // ------------------------------------------------------------------ - // Stage 0.5: lower pto.alloc_tile -> memref.alloc + pto.bind_tile - // ------------------------------------------------------------------ - DefaultInlineVector allocTiles; - func.walk([&](mlir::pto::AllocTileOp op) { allocTiles.push_back(op); }); - - for (auto op : allocTiles) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - Location loc = op.getLoc(); - - auto tbTy = dyn_cast(op.getResult().getType()); - if (!tbTy) continue; - - // 1. 获取 Shape 和 ElementType - SmallInlineVector shape(tbTy.getShape().begin(), tbTy.getShape().end()); - Type elemTy = tbTy.getElementType(); - - // 2. 计算 Strides (layout-aware when possible) - SmallVector strides; - TileLayoutInfo info; - if (computeTileLayoutInfo(tbTy.getConfigAttr(), elemTy, shape, info)) { - strides = {info.rowStride, info.colStride}; - } else { - strides.resize(shape.size()); - int64_t s = 1; - for (int i = (int)shape.size() - 1; i >= 0; --i) { - strides[i] = s; - if (shape[i] != ShapedType::kDynamic) s *= shape[i]; - } - } - - // 3. 构造 [BindTile 输出] 的动态类型 (Offset: ?) - // 这必须与 convertPTOTypeToMemRef 返回的类型一致,以便与 Subview 兼容 - auto targetLayout = - StridedLayoutAttr::get(ctx, ShapedType::kDynamic, strides); // offset = ? - auto targetType = - MemRefType::get(shape, elemTy, targetLayout, tbTy.getMemorySpace()); - - // 4. Preserve tile valid dims (v_row / v_col). - // - // `pto.alloc_tile` encodes the valid shape in the result TileBufType - // (e.g. acc tile may be rows=16 but v_row=1). The alloc op itself does - // not necessarily carry explicit operands for static valid dims, so we - // must materialize them from the type to keep them through - // tile_buf -> memref lowering. 
- // - // For dynamically valid tiles (validShape == [-1, -1]), preserve the - // runtime operands if present. - Value vRow = op.getValidRow(); - Value vCol = op.getValidCol(); - // TileBuf valid dims use a negative sentinel (e.g. '?' / -1), which is - // distinct from MLIR's ShapedType::kDynamic (INT64_MIN). Treat any - // negative value as dynamic here. - materializeStaticValidDims(rewriter, loc, tbTy, vRow, vCol); - - // 5. 获取 Config (保持不变) - auto configAttr = tbTy.getConfigAttr(); - if (!configAttr) configAttr = pto::TileBufConfigAttr::getDefault(ctx); - - // 6. If alloc_tile provides an explicit address, keep the original - // pointer_cast lowering intact and additionally rebind through - // pto.bind_tile. PointerCastOp continues to carry the tile metadata - // used by existing lowering paths, while BindTileOp provides the - // unified anchor EmitC uses to recover tile_buf information. - if (Value addr = op.getAddr()) { - auto pc = rewriter.create( - loc, targetType, ValueRange{addr}, vRow ? vRow : Value(), - vCol ? vCol : Value(), configAttr); - markForceDynamicValidShape(pc, tbTy.hasDynamicValid(), ctx); - auto bindOp = rewriter.create( - loc, targetType, pc.getResult(), vRow ? vRow : Value(), - vCol ? vCol : Value(), configAttr); - markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); - rewriter.replaceOp(op, bindOp.getResult()); - continue; - } - - // 7. Otherwise allocate a concrete memref buffer and bind tile. - // memref.alloc 要求明确的 layout,不能是动态 offset。 - auto allocLayout = StridedLayoutAttr::get(ctx, 0, strides); // offset = 0 - auto allocType = MemRefType::get(shape, elemTy, allocLayout, tbTy.getMemorySpace()); - Value alloc = rewriter.create(loc, allocType); - - // BindTileOp 的 Builder 会自动处理空的 Value,将其视为静态维度 - auto bindOp = rewriter.create( - loc, targetType, alloc, vRow ? vRow : Value(), vCol ? 
vCol : Value(), - configAttr); - markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); - - rewriter.replaceOp(op, bindOp.getResult()); - } - - // ------------------------------------------------------------------ - // Stage 0.75: lower pto.declare_tile -> pto.declare_tile_memref + - // pto.bind_tile - // ------------------------------------------------------------------ - DefaultInlineVector declaredTiles; - func.walk([&](mlir::pto::DeclareTileOp op) { declaredTiles.push_back(op); }); - - for (auto op : declaredTiles) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - Location loc = op.getLoc(); - - auto tbTy = dyn_cast(op.getTile().getType()); - if (!tbTy) { - op.emitError("declare_tile result must be tile_buf type"); - signalPassFailure(); - return; - } - - auto targetType = dyn_cast(convertPTOTypeToMemRef(tbTy)); - if (!targetType) { - op.emitError("failed to convert declare_tile result to memref type"); - signalPassFailure(); - return; - } - - auto configAttr = tbTy.getConfigAttr(); - if (!configAttr) - configAttr = pto::TileBufConfigAttr::getDefault(ctx); - - Value vRow; - Value vCol; - materializeStaticValidDims(rewriter, loc, tbTy, vRow, vCol); - - auto declaredMemRef = - rewriter.create(loc, targetType); - auto bindOp = rewriter.create( - loc, targetType, declaredMemRef.getResult(), - vRow ? vRow : Value(), vCol ? vCol : Value(), configAttr); - markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); - - rewriter.replaceOp(op, bindOp.getResult()); - } - - // ------------------------------------------------------------------ - // Stage 0.8: normalize pto.tassign result type to match tile operand - // after tile_buf -> memref lowering (required for verifier consistency). 
- // ------------------------------------------------------------------ - DefaultInlineVector tassignOps; - func.walk([&](mlir::pto::TAssignOp op) { tassignOps.push_back(op); }); - for (auto op : tassignOps) { - Type targetTy = op.getTile().getType(); - if (op.getResult().getType() == targetTy) - continue; - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - auto normalized = - rewriter.create(op.getLoc(), targetTy, op.getTile(), - op.getAddr()); - rewriter.replaceOp(op, normalized.getResult()); - } - - // ------------------------------------------------------------------ - // Stage 1: Lower pto.make_tensor_view -> memref.reinterpret_cast - // ------------------------------------------------------------------ - DefaultInlineVector makeViews; - func.walk([&](mlir::pto::MakeTensorViewOp op) { makeViews.push_back(op); }); - - for (auto op : makeViews) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - Location loc = op.getLoc(); - - Value baseBuf = op.getOperand(0); - OpFoldResult off0 = rewriter.getIndexAttr(0); - - // Fold pto.addptr chains into the view base to avoid nested reinterpret_cast. - bool foldedAddPtr = false; - { - Value cur = baseBuf; - Value totalOffset; - while (auto add = cur.getDefiningOp()) { - foldedAddPtr = true; - Value off = ensureIndex(rewriter, loc, add.getOperand(1), add); - if (totalOffset) - totalOffset = rewriter.create(loc, totalOffset, off); - else - totalOffset = off; - cur = add.getOperand(0); - } - if (cur != baseBuf) { - baseBuf = cur; - off0 = totalOffset ? 
OpFoldResult(totalOffset) : off0; - } - } - - auto baseMr = dyn_cast(baseBuf.getType()); - if (!baseMr) { - op.emitError("make_tensor_view base must be memref"); signalPassFailure(); return; - } - - // [修复] 获取动态 Rank (根据 shape 输入的数量) - size_t rank = op.getShape().size(); - - // Construct target type with dynamic offset/strides - Type elemTy = baseMr.getElementType(); - int64_t dyn = ShapedType::kDynamic; - - // [修复] 构建 N 维 Strided Layout - // strides 数组长度必须等于 rank - SmallVector dynStrides(rank, dyn); - auto layout = StridedLayoutAttr::get(ctx, /*offset=*/dyn, /*strides=*/dynStrides); - - // [修复] 构建 N 维 Shape - SmallVector dynShape(rank, dyn); - auto mrTy = MemRefType::get(dynShape, elemTy, layout, baseMr.getMemorySpace()); - - SmallInlineVector sizes; - for (Value v : op.getShape()) sizes.push_back(ensureIndex(rewriter, loc, v, op)); - - SmallInlineVector strides; - for (Value v : op.getStrides()) strides.push_back(ensureIndex(rewriter, loc, v, op)); - - auto rc = rewriter.create( - loc, mrTy, baseBuf, off0, sizes, strides); - if (foldedAddPtr) { - rc->setAttr("pto.addptr_trace", rewriter.getUnitAttr()); - } - if (auto layoutAttr = op.getLayoutAttr()) { - rc->setAttr("layout", layoutAttr); - } - - rewriter.replaceOp(op, rc.getResult()); - } - - // ------------------------------------------------------------------ - // Stage 1.25: Lower pto.get_tensor_view_dim -> memref.dim - // ------------------------------------------------------------------ - DefaultInlineVector tvDims; - func.walk([&](mlir::pto::GetTensorViewDimOp op) { tvDims.push_back(op); }); - - for (auto op : tvDims) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - Location loc = op.getLoc(); - - Value view = op.getTensorView(); - auto mrTy = dyn_cast(view.getType()); - if (!mrTy) - continue; // leave it to later passes if it hasn't been lowered yet - - Value dimIdx = op.getDimIndex(); - Value dim = rewriter.create(loc, view, dimIdx); - rewriter.replaceOp(op, dim); - } - - // 
------------------------------------------------------------------ - // Stage 1.3: Lower pto.partition_view -> memref.subview - // ------------------------------------------------------------------ - if (failed(lowerPartitionViewOps(func, ctx))) { - signalPassFailure(); - return; - } - - // ------------------------------------------------------------------ - // Stage 1.35: Lower pto.subview -> memref.subview + pto.bind_tile - // ------------------------------------------------------------------ - if (failed(lowerSubViewOps(func, ctx))) { - signalPassFailure(); - return; - } - - // ------------------------------------------------------------------ - // Stage 1.4: Lower tile_buf view-like ops (treshape/bitcast) - // ------------------------------------------------------------------ - if (failed(lowerTileBufViewLikeOps(func, ctx))) { - signalPassFailure(); - return; - } - - // ------------------------------------------------------------------ - // Stage 1.5: Fold pto.addptr chains into load/store_scalar. 
- // ------------------------------------------------------------------ - DefaultInlineVector loadScalars; - func.walk([&](mlir::pto::LoadScalarOp op) { loadScalars.push_back(op); }); - - for (auto op : loadScalars) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - Location loc = op.getLoc(); - - Value base = op.getPtr(); - Value totalOffset = ensureIndex(rewriter, loc, op.getOffset(), op); - - bool foldedAddPtr = false; - while (auto add = base.getDefiningOp()) { - foldedAddPtr = true; - Value off = ensureIndex(rewriter, loc, add.getOperand(1), add); - if (totalOffset) - totalOffset = rewriter.create(loc, totalOffset, off); - else - totalOffset = off; - base = add.getOperand(0); - } - - if (foldedAddPtr) { - auto newOp = rewriter.create( - loc, op.getValue().getType(), base, totalOffset); - rewriter.replaceOp(op, newOp.getValue()); - } - } - - DefaultInlineVector storeScalars; - func.walk([&](mlir::pto::StoreScalarOp op) { storeScalars.push_back(op); }); - - for (auto op : storeScalars) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - Location loc = op.getLoc(); - - Value base = op.getPtr(); - Value totalOffset = ensureIndex(rewriter, loc, op.getOffset(), op); - - bool foldedAddPtr = false; - while (auto add = base.getDefiningOp()) { - foldedAddPtr = true; - Value off = ensureIndex(rewriter, loc, add.getOperand(1), add); - if (totalOffset) - totalOffset = rewriter.create(loc, totalOffset, off); - else - totalOffset = off; - base = add.getOperand(0); - } - - if (foldedAddPtr) { - rewriter.create( - loc, base, totalOffset, op.getValue()); - rewriter.eraseOp(op); - } - } - - // ------------------------------------------------------------------ - // Stage 1.75: Fold addptr used by initialize_l2g2l_pipe(gm_addr). - // This keeps IR well-typed after function arguments are rewritten from - // !pto.ptr to memref. 
- // ------------------------------------------------------------------ - bool foldedPipeInitAddPtr = true; - while (foldedPipeInitAddPtr) { - foldedPipeInitAddPtr = false; - DefaultInlineVector addPtrsForPipeInit; - func.walk([&](mlir::pto::AddPtrOp op) { - bool eligible = !op->use_empty(); - for (Operation *user : op->getUsers()) { - auto init = dyn_cast(user); - if (!init || init.getGmAddr() != op->getResult(0)) { - eligible = false; - break; - } - } - if (eligible) - addPtrsForPipeInit.push_back(op); - }); - - for (auto op : addPtrsForPipeInit) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - Location loc = op.getLoc(); - - Value base = op->getOperand(0); - Value totalOffset = ensureIndex(rewriter, loc, op->getOperand(1), op); - while (auto add = base.getDefiningOp()) { - Value off = ensureIndex(rewriter, loc, add->getOperand(1), add); - totalOffset = rewriter.create(loc, totalOffset, off); - base = add->getOperand(0); - } - - auto baseMrTy = dyn_cast(base.getType()); - if (!baseMrTy || baseMrTy.getRank() != 1) - continue; - - int64_t dyn = ShapedType::kDynamic; - auto layout = StridedLayoutAttr::get(ctx, dyn, {dyn}); - auto targetTy = MemRefType::get({dyn}, baseMrTy.getElementType(), layout, - baseMrTy.getMemorySpace()); - SmallVector sizes{rewriter.getIndexAttr(1)}; - SmallVector strides{rewriter.getIndexAttr(1)}; - auto rc = rewriter.create( - loc, targetTy, base, OpFoldResult(totalOffset), sizes, strides); - rc->setAttr("pto.addptr_trace", rewriter.getUnitAttr()); - rewriter.replaceOp(op, rc.getResult()); - foldedPipeInitAddPtr = true; - } - } - - // Clean up: addptr should be folded into make_tensor_view. 
- DefaultInlineVector addPtrs; - func.walk([&](mlir::pto::AddPtrOp op) { addPtrs.push_back(op.getOperation()); }); - bool changed = true; - while (changed) { - changed = false; - for (auto &op : addPtrs) { - if (!op) - continue; - if (op->use_empty()) { - op->erase(); - op = nullptr; - changed = true; - } - } - } - for (auto *op : addPtrs) { - if (!op) - continue; - op->emitError("addptr must feed make_tensor_view, initialize_l2g2l_pipe(gm_addr) or load/store_scalar for lowering"); - signalPassFailure(); - return; - } - - // ------------------------------------------------------------------ - // Stage 3: Rewrite Compute Ops - // [关键] 全面使用 op->getOperand(i) 避免 Typed Accessor Crash - // ------------------------------------------------------------------ - - // --- TLoadOp [Src, Dst] --- - DefaultInlineVector loads; - func.walk([&](mlir::pto::TLoadOp op) { loads.push_back(op); }); - for (auto op : loads) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op->getOperand(0); - Value dst = op->getOperand(1); - - auto newOp = - rewriter.create(op.getLoc(), TypeRange{}, src, dst); - newOp->setAttrs(op->getAttrs()); - rewriter.replaceOp(op, newOp->getResults()); - } - - // --- TStoreOp [Src, Dst] --- - DefaultInlineVector storeops; - func.walk([&](mlir::pto::TStoreOp op) { storeops.push_back(op); }); - for (auto op : storeops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op->getOperand(0); - Value dst = op->getOperand(1); - Value preQuant = op.getPreQuantScalar(); - - pto::TStoreOp newOp; - if (preQuant) { - newOp = rewriter.create(op.getLoc(), TypeRange{}, - src, dst, preQuant); - } else { - newOp = rewriter.create(op.getLoc(), TypeRange{}, - src, dst, Value{}); - } - newOp->setAttrs(op->getAttrs()); - rewriter.replaceOp(op, newOp->getResults()); - } - - // --- TTransOp [Src, Tmp, Dst] --- - DefaultInlineVector trans; - func.walk([&](mlir::pto::TTransOp op) { trans.push_back(op); }); - for (auto op : trans) { - 
IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex)); - } - - // --- TExpOp [Src, Dst] --- - DefaultInlineVector exp; - func.walk([&](mlir::pto::TExpOp op) { exp.push_back(op); }); - for (auto op : exp) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, op->getOperand(0), op->getOperand(1)); - } - - // --- TMulOp [Src, Scalar, Dst] --- - DefaultInlineVector mul; - func.walk([&](mlir::pto::TMulOp op) { mul.push_back(op); }); - for (auto op : mul) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, op->getOperand(0), op.getOperand(1), - op->getOperand(kThirdOperandIndex)); - } - - // --- TMulSOp [Src, Scalar, Dst] --- - DefaultInlineVector muls; - func.walk([&](mlir::pto::TMulSOp op) { muls.push_back(op); }); - for (auto op : muls) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, op->getOperand(0), op.getScalar(), - op->getOperand(kThirdOperandIndex)); - } - - // --- TAddOp [Src0, Src1, Dst] --- - DefaultInlineVector addops; - func.walk([&](mlir::pto::TAddOp op) { addops.push_back(op); }); - for (auto op : addops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex)); - } - - // --- TMatmulOp [Lhs, Rhs, Dst] (no optional bias in ODS) --- - DefaultInlineVector matmuls; - func.walk([&](mlir::pto::TMatmulOp op) { matmuls.push_back(op); }); - for (auto op : matmuls) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - Value lhs = op->getOperand(0); - Value rhs = op->getOperand(1); - Value dst = op->getOperand(kThirdOperandIndex); - - rewriter.replaceOpWithNewOp( - op, TypeRange{}, lhs, rhs, dst, op.getAccPhaseAttr()); - } - - // 
--- TMatmulAccOp [Acc, Lhs, Rhs, Dst] --- - DefaultInlineVector matmulAccs; - func.walk([&](mlir::pto::TMatmulAccOp op) { matmulAccs.push_back(op); }); - for (auto op : matmulAccs) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex), op.getAccPhaseAttr()); - } - - // --- TMatmulBiasOp [Acc, Lhs, Rhs, Bias, Dst] --- - DefaultInlineVector matmulBiass; - func.walk([&](mlir::pto::TMatmulBiasOp op) { matmulBiass.push_back(op); }); - for (auto op : matmulBiass) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex)); - } - - // --- TMatmulMxOp--- - DefaultInlineVector matmulMxs; - func.walk([&](mlir::pto::TMatmulMxOp op) { matmulMxs.push_back(op); }); - for (auto op : matmulMxs) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex), - op->getOperand(kFifthOperandIndex)); - } - - // --- TMatmulMxAccOp --- - DefaultInlineVector matmulMxAccs; - func.walk([&](mlir::pto::TMatmulMxAccOp op) { matmulMxAccs.push_back(op); }); - for (auto op : matmulMxAccs) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex), - op->getOperand(kFifthOperandIndex), - op->getOperand(kSixthOperandIndex)); - } - - // --- TMatmulMxBiasOp --- - DefaultInlineVector matmulMxBiass; - func.walk([&](mlir::pto::TMatmulMxBiasOp op) { matmulMxBiass.push_back(op); }); - for (auto op : matmulMxBiass) { - IRRewriter 
rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex), - op->getOperand(kFifthOperandIndex), - op->getOperand(kSixthOperandIndex)); - } - - // --- TGemvOp [Lhs, Rhs, Dst] --- - DefaultInlineVector gemvs; - func.walk([&](mlir::pto::TGemvOp op) { gemvs.push_back(op); }); - for (auto op : gemvs) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value lhs = op->getOperand(0); - Value rhs = op->getOperand(1); - Value dst = op->getOperand(kThirdOperandIndex); - - rewriter.replaceOpWithNewOp( - op, TypeRange{}, lhs, rhs, dst); - } - - // --- TGemvAccOp [Acc, Lhs, Rhs, Dst] --- - DefaultInlineVector gemvAccs; - func.walk([&](mlir::pto::TGemvAccOp op) { gemvAccs.push_back(op); }); - for (auto op : gemvAccs) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex)); - } - - // --- TGemvBiasOp [Acc, Lhs, Rhs, Bias, Dst] --- - DefaultInlineVector gemvBiass; - func.walk([&](mlir::pto::TGemvBiasOp op) { gemvBiass.push_back(op); }); - for (auto op : gemvBiass) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex)); - } - - // --- TGemvMxOp [A, AScale, B, BScale, Dst] --- - DefaultInlineVector gemvMxs; - func.walk([&](mlir::pto::TGemvMxOp op) { gemvMxs.push_back(op); }); - for (auto op : gemvMxs) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex), - op->getOperand(kFifthOperandIndex)); - 
} - - // --- TGemvMxAccOp [CIn, A, AScale, B, BScale, Dst] --- - DefaultInlineVector gemvMxAccs; - func.walk([&](mlir::pto::TGemvMxAccOp op) { gemvMxAccs.push_back(op); }); - for (auto op : gemvMxAccs) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex), - op->getOperand(kFifthOperandIndex), - op->getOperand(kSixthOperandIndex)); - } - - // --- TGemvMxBiasOp [A, AScale, B, BScale, Bias, Dst] --- - DefaultInlineVector gemvMxBiass; - func.walk([&](mlir::pto::TGemvMxBiasOp op) { gemvMxBiass.push_back(op); }); - for (auto op : gemvMxBiass) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex), - op->getOperand(kFifthOperandIndex), - op->getOperand(kSixthOperandIndex)); - } - - // --- TMovOp [Src, Dst] --- - DefaultInlineVector movs; - func.walk([&](mlir::pto::TMovOp op) { movs.push_back(op); }); - for (auto op : movs) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, op.getSrc(), op.getDst(), op.getFp(), - op.getPreQuantScalar(), op.getAccToVecModeAttr(), - op.getReluPreModeAttr()); - } - - DefaultInlineVector abseops; - func.walk([&](mlir::pto::TAbsOp op) { abseops.push_back(op); }); - - for (auto op : abseops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst); - } - - DefaultInlineVector addcops; - func.walk([&](mlir::pto::TAddCOp 
op) { addcops.push_back(op); }); - - for (auto op : addcops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value src2 = op.getSrc2(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto src2Ty = dyn_cast(src2.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !src2Ty ||!dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - src1, - src2, - dst); - } - - DefaultInlineVector addsops; - func.walk([&](mlir::pto::TAddSOp op) { addsops.push_back(op); }); - - for (auto op : addsops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value scalar = op.getScalar(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - scalar, - dst); - } - - DefaultInlineVector addscops; - func.walk([&](mlir::pto::TAddSCOp op) { addscops.push_back(op); }); - - for (auto op : addscops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value scalar = op.getScalar(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - scalar, - src1, - dst); - } - - DefaultInlineVector andops; - func.walk([&](mlir::pto::TAndOp op) { andops.push_back(op); }); - - for (auto op : andops) { - IRRewriter rewriter(ctx); 
- rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - src1, - dst); - } - - DefaultInlineVector concats; - func.walk([&](mlir::pto::TConcatOp op) { concats.push_back(op); }); - - for (auto op : concats) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - src1, - dst); - } - - DefaultInlineVector concatIdxs; - func.walk([&](mlir::pto::TConcatidxOp op) { concatIdxs.push_back(op); }); - - IRRewriter rewriter(ctx); - for (auto op : concatIdxs) { - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value src0Idx = op.getSrc0Idx(); - Value src1Idx = op.getSrc1Idx(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto src0IdxTy = dyn_cast(src0Idx.getType()); - auto src1IdxTy = dyn_cast(src1Idx.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !src0IdxTy || !src1IdxTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - src1, - src0Idx, - src1Idx, - dst); - } - - DefaultInlineVector andsops; - func.walk([&](mlir::pto::TAndSOp op) { andsops.push_back(op); 
}); - - for (auto op : andsops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value scalar = op.getScalar(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - scalar, - dst); - } - - DefaultInlineVector ciops; - func.walk([&](mlir::pto::TCIOp op) { ciops.push_back(op); }); - - for (auto op : ciops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value s = op->getOperand(0); - Value dst = op.getDst(); - bool descending = op.getDescending(); - - auto sTy = dyn_cast(s.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!sTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - s, - dst, - descending); - } - - DefaultInlineVector cmpops; - func.walk([&](mlir::pto::TCmpOp op) { cmpops.push_back(op); }); - - for (auto op : cmpops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - auto newOp = rewriter.create( - op.getLoc(), - TypeRange{}, - src0, - src1, - dst); - - if (auto a = op.getCmpModeAttr()) - newOp->setAttr("cmpMode", a); - - rewriter.replaceOp(op, newOp->getResults()); // 0 results -> OK - } - - DefaultInlineVector cmpsops; - func.walk([&](mlir::pto::TCmpSOp op) { cmpsops.push_back(op); }); - - for (auto op : cmpsops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value 
scalar = op.getScalar(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - auto scalarTy = scalar.getType(); - bool scalarOk = - isa(scalarTy); // ScalarType in ODS: int/float - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - if (!scalarOk) { - op.emitError("expects scalar to be an integer or float type"); - signalPassFailure(); - return; - } - - auto cmpMode = op.getCmpModeAttr(); - auto newOp = rewriter.create( - op.getLoc(), - TypeRange{}, - src, - scalar, - cmpMode, - dst); - - rewriter.replaceOp(op, newOp->getResults()); // 0 results -> OK - } - - DefaultInlineVector colexpand; - func.walk([&](mlir::pto::TColExpandOp op) { colexpand.push_back(op); }); - - for (auto op : colexpand) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if ( !srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst); - } - - DefaultInlineVector colmaxops; - func.walk([&](mlir::pto::TColMaxOp op) { colmaxops.push_back(op); }); - - for (auto op : colmaxops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if ( !srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst); - } - - DefaultInlineVector colminops; - func.walk([&](mlir::pto::TColMinOp op) { colminops.push_back(op); }); - - for (auto op : colminops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - - auto srcTy 
= dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if ( !srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst); - } - - DefaultInlineVector colexpandmulops; - func.walk([&](mlir::pto::TColExpandMulOp op) { - colexpandmulops.push_back(op); - }); - - for (auto op : colexpandmulops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - src1, - dst); - } - - DefaultInlineVector colexpandmaxops; - func.walk([&](mlir::pto::TColExpandMaxOp op) { - colexpandmaxops.push_back(op); - }); - - for (auto op : colexpandmaxops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - src1, - dst); - } - - DefaultInlineVector colexpandminops; - func.walk([&](mlir::pto::TColExpandMinOp op) { - colexpandminops.push_back(op); - }); - - for (auto op : colexpandminops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty 
|| !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - src1, - dst); - } - - DefaultInlineVector colsumops; - func.walk([&](mlir::pto::TColSumOp op) { colsumops.push_back(op); }); - - for (auto op : colsumops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - Value tmp = op.getTmp(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !dstTy) { - op.emitError("src/dst are not memref yet"); - signalPassFailure(); - return; - } - - // If tmp exists, it must have isBinary attribute - if (tmp) { - auto tmpTy = dyn_cast(tmp.getType()); - if (!tmpTy) { - op.emitError("tmp is not memref yet"); - signalPassFailure(); - return; - } - - // Get isBinary attribute (should exist if tmp exists) - BoolAttr isBinaryAttr = op.getIsBinaryAttr(); - if (!isBinaryAttr) { - isBinaryAttr = BoolAttr::get(ctx, false); - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - tmp, - dst, - isBinaryAttr); - } else { - // Format 1: no tmp, no isBinary - // Use generic builder to avoid adding default isBinary attribute - SmallVector operands = {src, dst}; - SmallVector attrs; - // Copy all attributes except isBinary - for (auto attr : op->getAttrs()) { - if (attr.getName() != "isBinary") { - attrs.push_back(attr); - } - } - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - operands, - attrs); - } - } - - DefaultInlineVector cvtops; - func.walk([&](mlir::pto::TCvtOp op) { cvtops.push_back(op); }); - - for (auto op : cvtops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - auto rmodeAttr = 
op.getRmodeAttr(); // PTO_RoundModeAttr - auto satModeAttr = op.getSatModeAttr(); - - auto newOp = rewriter.create( - op.getLoc(), - TypeRange{}, - src, - dst, - rmodeAttr, - satModeAttr); - - rewriter.replaceOp(op, newOp->getResults()); - } - - DefaultInlineVector divops; - func.walk([&](mlir::pto::TDivOp op) { divops.push_back(op); }); - - for (auto op : divops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - src1, - dst); - } - - DefaultInlineVector divsops; - func.walk([&](mlir::pto::TDivSOp op) { divsops.push_back(op); }); - - for (auto op : divsops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value scale = op.getScalar(); - Value dst = op.getDst(); - - // Check types - they might still be TileBufType or already converted to MemRefType - auto srcTy = dyn_cast(src.getType()); - auto srcTileTy = dyn_cast(src.getType()); - auto scaleTileTy = dyn_cast(scale.getType()); - auto dstTy = dyn_cast(dst.getType()); - auto dstTileTy = dyn_cast(dst.getType()); - - // Determine which operand is tile-like and which is scalar-like. - // Keep the original operand order (set by parser textual form). 
- // Check if src is memref/tensor/tile (not scalar) - bool srcIsMemref = (srcTy != nullptr || srcTileTy != nullptr || - isa(src.getType()) || - isa(src.getType())); - // Check if scale is memref/tensor/tile (not scalar) - bool scaleIsMemref = (isa(scale.getType()) || - scaleTileTy != nullptr || - isa(scale.getType()) || - isa(scale.getType())); - - // Type validation - ensure we have the right types - if (!srcIsMemref && !scaleIsMemref) { - op.emitError("at least one operand (src or scale) must be tile_buf or memref"); - signalPassFailure(); - return; - } - if (srcIsMemref && scaleIsMemref) { - op.emitError("exactly one operand (src or scale) must be tile_buf or memref, the other must be scalar"); - signalPassFailure(); - return; - } - - if (!dstTy && !dstTileTy) { - op.emitError("dst operand must be tile_buf or memref"); - signalPassFailure(); - return; - } - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - scale, - dst); - } - - DefaultInlineVector expandsops; - func.walk([&](mlir::pto::TExpandsOp op) { expandsops.push_back(op); }); - - for (auto op : expandsops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value scalar = op.getScalar(); - Value dst = op.getDst(); - - auto dstTy = dyn_cast(dst.getType()); - if (!dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - scalar, - dst); - } - - DefaultInlineVector extractops; - func.walk([&](mlir::pto::TExtractOp op) { extractops.push_back(op); }); - - for (auto op : extractops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value indexRow = op.getIndexRow(); - Value indexCol = op.getIndexCol(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto indexRowTy = dyn_cast(indexRow.getType()); - auto indexColTy = dyn_cast(indexCol.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !indexRowTy || !indexColTy || 
!dstTy) { - op.emitError("ins/outs are not correct yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - indexRow, - indexCol, - dst); - } - - DefaultInlineVector fillpadops; - func.walk([&](mlir::pto::TFillPadOp op) { fillpadops.push_back(op); }); - - for (auto op : fillpadops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst); - } - - DefaultInlineVector fillpadInplaceOps; - func.walk( - [&](mlir::pto::TFillPadInplaceOp op) { fillpadInplaceOps.push_back(op); }); - - for (auto op : fillpadInplaceOps) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst); - } - - // --- TSetValOp [Dst, Offset, Val] --- - // Lower tile-world scalar write to memref-world SETVAL DPS op. 
- DefaultInlineVector tsetvalops; - func.walk([&](mlir::pto::TSetValOp op) { tsetvalops.push_back(op); }); - - for (auto op : tsetvalops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value dst = op.getDst(); - Value offset = op.getOffset(); - Value val = op.getVal(); - - auto dstTy = dyn_cast(dst.getType()); - if (!dstTy) { - op.emitError("dst is not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - dst, - offset, - val); - } - - // --- TGetValOp [Src, Offset] -> Scalar --- - // Lower tile-world scalar read to memref-world GETVAL DPS op. - DefaultInlineVector tgetvalops; - func.walk([&](mlir::pto::TGetValOp op) { tgetvalops.push_back(op); }); - - for (auto op : tgetvalops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value offset = op.getOffset(); - Type dstType = op.getDst().getType(); - - auto srcTy = dyn_cast(src.getType()); - if (!srcTy) { - op.emitError("src is not memref yet"); - signalPassFailure(); - return; - } - - auto newOp = rewriter.create( - op.getLoc(), - dstType, - src, - offset); - rewriter.replaceOp(op, newOp.getDst()); - } - - DefaultInlineVector gatherops; - func.walk([&](mlir::pto::TGatherOp op) { gatherops.push_back(op); }); - - for (auto op : gatherops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - Value cdst = op.getCdst(); - Value indices = op.getIndices(); - Value tmp = op.getTmp(); - Value kValue = op.getKValue(); - auto maskPattern = op.getMaskPatternAttr(); - auto cmpMode = op.getCmpModeAttr(); - auto offset = op.getOffsetAttr(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - if (maskPattern) { - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst, - /*cdst=*/Value(), - 
/*indices=*/Value(), - /*tmp=*/Value(), - /*kValue=*/Value(), - /*maskPattern=*/maskPattern, - /*cmpMode=*/pto::CmpModeAttr(), - /*offset=*/IntegerAttr()); - continue; - } - - if (cdst || kValue) { - auto cdstTy = dyn_cast(cdst.getType()); - auto tmpTy = dyn_cast(tmp.getType()); - if (!cdstTy || !tmpTy) { - op.emitError("compare-form tgather expects cdst/tmp to be memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst, - cdst, - /*indices=*/Value(), - tmp, - kValue, - /*maskPattern=*/pto::MaskPatternAttr(), - cmpMode, - offset); - continue; - } - - if (indices || tmp) { - auto indicesTy = dyn_cast(indices.getType()); - auto tmpTy = dyn_cast(tmp.getType()); - if (!indicesTy || !tmpTy) { - op.emitError("index-form tgather expects indices/tmp to be memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst, - /*cdst=*/Value(), - indices, - tmp, - /*kValue=*/Value(), - /*maskPattern=*/pto::MaskPatternAttr(), - /*cmpMode=*/pto::CmpModeAttr(), - /*offset=*/IntegerAttr()); - continue; - } - - op.emitError("expects tgather to be in mask, index+tmp, or compare+tmp form"); - signalPassFailure(); - return; - } - - DefaultInlineVector gatherbops; - func.walk([&](mlir::pto::TGatherBOp op) { gatherbops.push_back(op); }); - - for (auto op : gatherbops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value offsets = op.getOffsets(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto offsetsTy = dyn_cast(offsets.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !offsetsTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - offsets, - dst); - } - - DefaultInlineVector logops; - func.walk([&](mlir::pto::TLogOp op) { logops.push_back(op); }); - - for (auto op : logops) { 
- IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst); - } - - DefaultInlineVector lreluops; - func.walk([&](mlir::pto::TLReluOp op) { lreluops.push_back(op); }); - - for (auto op : lreluops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value slope = op.getSlope(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto slopeTy = dyn_cast(slope.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !slopeTy || !dstTy) { - op.emitError("ins/outs are not correct type yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - slope, - dst); - } - - DefaultInlineVector maxops; - func.walk([&](mlir::pto::TMaxOp op) { maxops.push_back(op); }); - - for (auto op : maxops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - src1, - dst); - } - - DefaultInlineVector maxsops; - func.walk([&](mlir::pto::TMaxSOp op) { maxsops.push_back(op); }); - - for (auto op : maxsops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value scalar = op.getScalar(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - bool scalarIsScalar = 
isa(scalar.getType()); - if (!srcTy || !scalarIsScalar || !dstTy) { - op.emitError("expects src/dst to be memref and scalar to be integer/float"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - scalar, - dst); - } - - DefaultInlineVector minops; - func.walk([&](mlir::pto::TMinOp op) { minops.push_back(op); }); - - for (auto op : minops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - src1, - dst); - } - - DefaultInlineVector minsops; - func.walk([&](mlir::pto::TMinSOp op) { minsops.push_back(op); }); - - for (auto op : minsops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value scalar = op.getScalar(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - bool scalarIsScalar = isa(scalar.getType()); - if (!srcTy || !scalarIsScalar || !dstTy) { - op.emitError("expects src/dst to be memref and scalar to be integer/float"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - scalar, - dst); - } - - DefaultInlineVector movfpops; - func.walk([&](mlir::pto::TMovFPOp op) { movfpops.push_back(op); }); - - for (auto op : movfpops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value fp = op.getFp(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto fpTy = dyn_cast(fp.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !fpTy || !dstTy) { - op.emitError("ins/outs are not 
memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - fp, - dst); - } - - DefaultInlineVector quantops; - func.walk([&](mlir::pto::TQuantOp op) { quantops.push_back(op); }); - - for (auto op : quantops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value fp = op.getFp(); - Value offset = op.getOffset(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto fpTy = dyn_cast(fp.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !fpTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - if (offset && !dyn_cast(offset.getType())) { - op.emitError("offset is not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - fp, - offset, - dst, - op.getQuantTypeAttr()); - } - - DefaultInlineVector mrgsortops; - func.walk([&](mlir::pto::TMrgSortOp op) { mrgsortops.push_back(op); }); - - for (auto op : mrgsortops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - if (op.isFormat1()) { - Value src = op.getSrc(); - Value dst = op.getDst(); - Value blockLenVal = op.getBlockLen(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - ValueRange{src}, - blockLenVal, - ValueRange{dst}, - Value() /*tmp*/, - Value() /*excuted*/, - op.getExhaustedAttr()); - } else if (op.isFormat2()) { - bool allMemRef = true; - for (Value v : op.getSrcs()) - if (!dyn_cast(v.getType())) { allMemRef = false; break; } - if (!allMemRef) { - op.emitError("format2 ins/outs are not memref yet"); - signalPassFailure(); - return; - } - if (op.getDsts().size() != 1u || !op.getTmp()) { - op.emitError("format2 expects outs(dst) and ins(tmp)"); - 
signalPassFailure(); - return; - } - - Value dst = op.getDst(); - Value tmp = op.getTmp(); - Value excuted = op.getExcuted(); - if (!dyn_cast(dst.getType()) || !dyn_cast(tmp.getType())) { - op.emitError("format2 dst/tmp must be memref"); - signalPassFailure(); - return; - } - if (!dyn_cast(excuted.getType())) { - op.emitError("format2 outs(excuted) must be vector"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - op.getSrcs(), - Value() /*blockLen*/, - ValueRange{dst}, - tmp, - excuted, - op.getExhaustedAttr()); - } else { - op.emitError("tmrgsort must be format1 or format2"); - signalPassFailure(); - return; - } - } - - DefaultInlineVector negops; - func.walk([&](mlir::pto::TNegOp op) { negops.push_back(op); }); - - for (auto op : negops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst); - } - - DefaultInlineVector notops; - func.walk([&](mlir::pto::TNotOp op) { notops.push_back(op); }); - - for (auto op : notops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst); - } - - DefaultInlineVector orops; - func.walk([&](mlir::pto::TOrOp op) { orops.push_back(op); }); - - for (auto op : orops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = 
dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - src0, - src1, - dst); - } - - DefaultInlineVector orsops; - func.walk([&](mlir::pto::TOrSOp op) { orsops.push_back(op); }); - - for (auto op : orsops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value scalar = op.getScalar(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto scalarTy = dyn_cast(scalar.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !scalarTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - scalar, - dst); - } - - DefaultInlineVector partaddops; - func.walk([&](mlir::pto::TPartAddOp op) { partaddops.push_back(op); }); - - for (auto op : partaddops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - src0, - src1, - dst); - } - - DefaultInlineVector partmulops; - func.walk([&](mlir::pto::TPartMulOp op) { partmulops.push_back(op); }); - - for (auto op : partmulops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs 
are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - src0, - src1, - dst); - } - - DefaultInlineVector mgatherops; - func.walk([&](mlir::pto::MGatherOp op) { mgatherops.push_back(op); }); - - for (auto op : mgatherops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value dst = op.getDst(); - Value idx = op.getIdx(); - Value mem = op.getMem(); - - auto dstTy = dyn_cast(dst.getType()); - auto idxTy = dyn_cast(idx.getType()); - auto memTy = dyn_cast(mem.getType()); - if (!dstTy || !idxTy || !memTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - mem, - idx, - dst, - op.getGatherOobAttr()); - } - - DefaultInlineVector mascatterops; - func.walk([&](mlir::pto::MScatterOp op) { mascatterops.push_back(op); }); - - for (auto op : mascatterops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value idx = op.getIdx(); - Value mem = op.getMem(); - - auto srcTy = dyn_cast(src.getType()); - auto idxTy = dyn_cast(idx.getType()); - auto memTy = dyn_cast(mem.getType()); - if (!srcTy || !idxTy || !memTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - idx, - mem, - op.getScatterAtomicOpAttr(), - op.getScatterOobAttr()); - } - DefaultInlineVector printops; - func.walk([&](mlir::pto::TPrintOp op) { printops.push_back(op); }); - - for (auto op : printops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - - auto srcTy = dyn_cast(src.getType()); - if (!srcTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src); - } - - // ------------------------------------------------------------------ - // Stage 4: Reconcile control-flow result types - // 
------------------------------------------------------------------ - if (failed(reconcileSCFIfResultTypes(func))) { - signalPassFailure(); - return; - } - if (failed(reconcileSCFForResultTypes(func))) { - signalPassFailure(); - return; - } - - // Mark memref-form set_validshape only after control-flow result-type - // reconciliation. Values such as scf.if results can stay tile_buf until - // this late stage. - if (failed(markLoweredSetValidShapeOps(func, ctx))) { - signalPassFailure(); - return; - } - } - - // Debug Output - LLVM_DEBUG(llvm::dbgs() << mod.getOperation()); - } -}; - -} // namespace - -std::unique_ptr createPTOViewToMemrefPass() { - return std::make_unique(); -} - -} // namespace pto -} // namespace mlir diff --git a/tools/ptobc/generated/ptobc_opcodes_v0.def b/tools/ptobc/generated/ptobc_opcodes_v0.def deleted file mode 100644 index 8303e1261..000000000 --- a/tools/ptobc/generated/ptobc_opcodes_v0.def +++ /dev/null @@ -1,722 +0,0 @@ -// Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. 
- -// Generated by docs/bytecode/tools/gen_v0_tables.py -#pragma once - -#include -#include - -#include -#include - -namespace ptobc::v0 { - -inline constexpr uint8_t kVariantDefault = 0; -inline constexpr uint8_t kVariantAcc = 1; -inline constexpr uint8_t kVariantBias = 2; -inline constexpr uint8_t kVariantMx = 3; -inline constexpr uint8_t kVariantMxAcc = 4; -inline constexpr uint8_t kVariantMxBias = 5; -inline constexpr uint8_t kSectionCubeVariant = 0; -inline constexpr uint8_t kSectionVectorVariant = 1; -inline constexpr uint8_t kHasVariant = 1; -inline constexpr uint16_t kTscatterMaskOpcode = 0x109C; - -inline constexpr int kTgemvOperandCount = 3; -inline constexpr int kTgemvAccOperandCount = 4; -inline constexpr int kTgemvBiasOperandCount = 4; -inline constexpr int kTgemvMxOperandCount = 5; -inline constexpr int kTgemvMxAccOperandCount = 6; -inline constexpr int kTgemvMxBiasOperandCount = 6; -inline constexpr int kTmatmulOperandCount = 3; -inline constexpr int kTmatmulAccOperandCount = 4; -inline constexpr int kTmatmulBiasOperandCount = 4; -inline constexpr int kTmatmulMxOperandCount = 5; -inline constexpr int kTmatmulMxAccOperandCount = 6; -inline constexpr int kTmatmulMxBiasOperandCount = 6; - -struct OpInfo { - uint16_t opcode; - const char *name; - uint8_t has_variant_u8; - uint8_t result_type_mode; - uint8_t operand_mode; - uint16_t num_operands; - uint16_t num_results; - uint16_t num_regions; - uint8_t imm_kind; -}; - -inline constexpr OpInfo kOpTable[] = { - {0x0000, "pto.get_block_idx", 0, 0x01, 0x00, 0, 1, 0, 0x00}, - {0x0001, "pto.get_block_num", 0, 0x01, 0x00, 0, 1, 0, 0x00}, - {0x0002, "pto.get_subblock_idx", 0, 0x01, 0x00, 0, 1, 0, 0x00}, - {0x0003, "pto.get_subblock_num", 0, 0x01, 0x00, 0, 1, 0, 0x00}, - {0x0004, "pto.make_tensor_view", 0, 0x01, 0x03, 1, 1, 0, 0x06}, - {0x0005, "pto.partition_view", 0, 0x01, 0x03, 1, 1, 0, 0x07}, - {0x0006, "pto.section", 1, 0x00, 0x00, 0, 0, 1, 0x00}, - {0x1000, "pto.addptr", 0, 0x01, 0x00, 2, 1, 0, 0x00}, - 
{0x1001, "pto.alloc_tile", 0, 0x01, 0x04, 0, 1, 0, 0x08}, - {0x1002, "pto.barrier", 0, 0x00, 0x00, 0, 0, 0, 0x00}, - {0x1003, "pto.mgather", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1004, "pto.mscatter", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1005, "pto.record_event", 0, 0x00, 0x00, 0, 0, 0, 0x02}, - {0x1006, "pto.tabs", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1007, "pto.tadd", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1008, "pto.taddc", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x1009, "pto.tadds", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x100A, "pto.taddsc", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x100B, "pto.tand", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x100C, "pto.tands", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x100D, "pto.tci", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x100E, "pto.tcmp", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x100F, "pto.tcmps", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1010, "pto.tcolexpand", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1011, "pto.tcolexpandadd", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1012, "pto.tcolexpanddiv", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1013, "pto.tcolexpandexpdif", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1014, "pto.tcolexpandmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1015, "pto.tcolexpandmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1016, "pto.tcolexpandmul", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1017, "pto.tcolexpandsub", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1018, "pto.tcolmax", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1019, "pto.tcolmin", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x101A, "pto.tcolprod", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x101B, "pto.tcolsum", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x101C, "pto.tcvt", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x101D, "pto.tdiv", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x101E, "pto.tdivs", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x101F, "pto.texp", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1020, "pto.texpands", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1021, "pto.textract", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x1022, "pto.textract_fp", 0, 0x00, 0x00, 5, 0, 0, 0x00}, - {0x1023, "pto.tfillpad", 0, 0x00, 
0x00, 2, 0, 0, 0x00}, - {0x1024, "pto.tfillpad_expand", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1025, "pto.tfillpad_inplace", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1026, "pto.tfmod", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1027, "pto.tfmods", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1028, "pto.tgather", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1029, "pto.tgatherb", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x102A, "pto.tgemv", 1, 0x00, 0x01, 0, 0, 0, 0x00}, - {0x102B, "pto.tgetval", 0, 0x01, 0x00, 2, 1, 0, 0x00}, - {0x102C, "pto.timg2col", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x102D, "pto.tinsert", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x102E, "pto.tinsert_fp", 0, 0x00, 0x00, 5, 0, 0, 0x00}, - {0x102F, "pto.tload", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1030, "pto.tlog", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1031, "pto.tlrelu", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1032, "pto.tmatmul", 1, 0x00, 0x01, 0, 0, 0, 0x00}, - {0x1033, "pto.tmatmul.mx", 1, 0x00, 0x01, 0, 0, 0, 0x00}, - {0x1034, "pto.tmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1035, "pto.tmaxs", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1036, "pto.tmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1037, "pto.tmins", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1038, "pto.tmov", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1039, "pto.tmov.fp", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x103A, "pto.tmrgsort", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x103B, "pto.tmul", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x103C, "pto.tmuls", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x103D, "pto.tneg", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x103E, "pto.tnot", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x103F, "pto.tor", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1040, "pto.tors", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1041, "pto.tpartadd", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1042, "pto.tpartmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1043, "pto.tpartmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1044, "pto.tpartmul", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1045, "pto.tprefetch", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1046, "pto.tprelu", 0, 0x00, 0x00, 4, 0, 0, 
0x00}, - {0x1047, "pto.tquant", 0, 0x00, 0x02, 3, 0, 0, 0x00}, - {0x1048, "pto.trecip", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1049, "pto.trelu", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x104A, "pto.trem", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x104B, "pto.trems", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x104C, "pto.treshape", 0, 0x01, 0x00, 1, 1, 0, 0x00}, - {0x104D, "pto.trowexpand", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x104E, "pto.trowexpandadd", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x104F, "pto.trowexpandexpdif", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1050, "pto.trowexpandmax", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1051, "pto.trowexpandmin", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1052, "pto.trowmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1053, "pto.trowmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1054, "pto.trowsum", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1055, "pto.trsqrt", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1056, "pto.tscatter", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1057, "pto.tsel", 0, 0x00, 0x00, 5, 0, 0, 0x00}, - {0x1058, "pto.tsels", 0, 0x00, 0x00, 5, 0, 0, 0x00}, - {0x1059, "pto.tset_img2col_padding", 0, 0x00, 0x00, 1, 0, 0, 0x00}, - {0x105A, "pto.tset_img2col_rpt", 0, 0x00, 0x00, 1, 0, 0, 0x00}, - {0x105B, "pto.tsetfmatrix", 0, 0x00, 0x00, 1, 0, 0, 0x00}, - {0x105C, "pto.tsethf32mode", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x105D, "pto.tsettf32mode", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x105E, "pto.tsetval", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x105F, "pto.tshl", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1060, "pto.tshls", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1061, "pto.tshr", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1062, "pto.tshrs", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1063, "pto.tsort32", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1064, "pto.tsqrt", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1065, "pto.tstore", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1066, "pto.tstore_fp", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1067, "pto.tsub", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1068, "pto.tsubc", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x1069, 
"pto.tsubs", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x106A, "pto.tsubsc", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x106B, "pto.trowexpandsub", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x106C, "pto.ttrans", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x106D, "pto.ttri", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x106E, "pto.txor", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x106F, "pto.txors", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x1070, "pto.wait_event", 0, 0x00, 0x00, 0, 0, 0, 0x02}, - {0x1071, "pto.tprint", 0, 0x00, 0x00, 1, 0, 0, 0x00}, - {0x1072, "pto.subview", 0, 0x01, 0x02, 0, 1, 0, 0x00}, - {0x1073, "pto.trowexpanddiv", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1074, "pto.trowexpandmul", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1075, "pto.tdequant", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x1076, "pto.taxpy", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1077, "pto.thistogram", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1078, "pto.tget_scale_addr", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1079, "pto.trowargmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x107A, "pto.trowargmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x107B, "pto.tcolargmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x107C, "pto.tcolargmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x107D, "pto.tsync", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x107E, "pto.reserve_buffer", 0, 0x01, 0x00, 0, 1, 0, 0x00}, - {0x107F, "pto.import_reserved_buffer", 0, 0x01, 0x00, 0, 1, 0, 0x00}, - {0x1080, "pto.aic_initialize_pipe", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1081, "pto.aiv_initialize_pipe", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1082, "pto.tpush_to_aiv", 0, 0x00, 0x00, 1, 0, 0, 0x00}, - {0x1083, "pto.tpush_to_aic", 0, 0x00, 0x00, 1, 0, 0, 0x00}, - {0x1084, "pto.tpop_from_aic", 0, 0x01, 0x02, 0, 1, 0, 0x00}, - {0x1085, "pto.tpop_from_aiv", 0, 0x01, 0x02, 0, 1, 0, 0x00}, - {0x1086, "pto.tfree_from_aic", 0, 0x00, 0x00, 0, 0, 0, 0x00}, - {0x1087, "pto.tfree_from_aiv", 0, 0x00, 0x00, 0, 0, 0, 0x00}, - {0x1088, "pto.set_validshape", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1089, "pto.tconcat", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x108A, 
"pto.trowprod", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x108B, "pto.initialize_l2g2l_pipe", 0, 0x01, 0x02, 0, 1, 0, 0x00}, - {0x108C, "pto.initialize_l2l_pipe", 0, 0x01, 0x02, 0, 1, 0, 0x00}, - {0x108D, "pto.tpush", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x108E, "pto.declare_tile", 0, 0x01, 0x00, 0, 1, 0, 0x00}, - {0x108F, "pto.tpop", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1090, "pto.tfree", 0, 0x00, 0x00, 1, 0, 0, 0x00}, - {0x1091, "pto.comm.tput", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1092, "pto.comm.tget", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1093, "pto.comm.tnotify", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1094, "pto.comm.twait", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1095, "pto.comm.ttest", 0, 0x01, 0x02, 0, 1, 0, 0x00}, - {0x1096, "pto.comm.tbroadcast", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1097, "pto.comm.tgather", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1098, "pto.comm.tscatter", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1099, "pto.comm.treduce", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x109A, "pto.tpartargmax", 0, 0x00, 0x00, 6, 0, 0, 0x00}, - {0x109B, "pto.tpartargmin", 0, 0x00, 0x00, 6, 0, 0, 0x00}, - {0x109C, "pto.tscatter.maskpattern", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x2000, "arith.addi", 0, 0x01, 0x00, 2, 1, 0, 0x00}, - {0x2001, "arith.ceildivsi", 0, 0x01, 0x00, 2, 1, 0, 0x00}, - {0x2002, "arith.cmpi", 0, 0x01, 0x00, 2, 1, 0, 0x01}, - {0x2003, "arith.constant", 0, 0x01, 0x00, 0, 1, 0, 0x05}, - {0x2004, "arith.index_cast", 0, 0x01, 0x00, 1, 1, 0, 0x00}, - {0x2005, "arith.minui", 0, 0x01, 0x00, 2, 1, 0, 0x00}, - {0x2006, "arith.muli", 0, 0x01, 0x00, 2, 1, 0, 0x00}, - {0x2007, "arith.select", 0, 0x01, 0x00, 3, 1, 0, 0x00}, - {0x2008, "arith.subi", 0, 0x01, 0x00, 2, 1, 0, 0x00}, - {0x4000, "scf.for", 0, 0x00, 0x00, 3, 0, 1, 0x00}, - {0x4001, "scf.if", 0, 0x00, 0x00, 1, 0, 2, 0x00}, - {0x4002, "scf.yield", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x6000, "func.func", 0, 0x00, 0x00, 0, 0, 0, 0x00}, - {0x6001, "func.return", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x6002, "func.call", 0, 0x02, 0x02, 0, 0, 0, 
0x00}, -}; - -inline const OpInfo *lookupByOpcode(uint16_t opcode) { - // Binary search on kOpTable (sorted by opcode). - size_t lo = 0, hi = sizeof(kOpTable) / sizeof(kOpTable[0]); - while (lo < hi) { - size_t mid = lo + (hi - lo) / 2; - uint16_t v = kOpTable[mid].opcode; - if (v == opcode) return &kOpTable[mid]; - if (v < opcode) lo = mid + 1; else hi = mid; - } - return nullptr; -} - -inline std::optional lookupOpcodeByName(llvm::StringRef name) { - uint16_t v = llvm::StringSwitch(name) - .Case("arith.addi", 0x2000) - .Case("arith.ceildivsi", 0x2001) - .Case("arith.cmpi", 0x2002) - .Case("arith.constant", 0x2003) - .Case("arith.index_cast", 0x2004) - .Case("arith.minui", 0x2005) - .Case("arith.muli", 0x2006) - .Case("arith.select", 0x2007) - .Case("arith.subi", 0x2008) - .Case("func.func", 0x6000) - .Case("func.return", 0x6001) - .Case("func.call", 0x6002) - .Case("pto.addptr", 0x1000) - .Case("pto.alloc_tile", 0x1001) - .Case("pto.barrier", 0x1002) - .Case("pto.get_block_idx", 0x0000) - .Case("pto.get_block_num", 0x0001) - .Case("pto.get_subblock_idx", 0x0002) - .Case("pto.get_subblock_num", 0x0003) - .Case("pto.make_tensor_view", 0x0004) - .Case("pto.mgather", 0x1003) - .Case("pto.mscatter", 0x1004) - .Case("pto.partition_view", 0x0005) - .Case("pto.record_event", 0x1005) - .Case("pto.section", 0x0006) - .Case("pto.tabs", 0x1006) - .Case("pto.tadd", 0x1007) - .Case("pto.taddc", 0x1008) - .Case("pto.tadds", 0x1009) - .Case("pto.taddsc", 0x100A) - .Case("pto.tand", 0x100B) - .Case("pto.tands", 0x100C) - .Case("pto.tci", 0x100D) - .Case("pto.tcmp", 0x100E) - .Case("pto.tcmps", 0x100F) - .Case("pto.tcolexpand", 0x1010) - .Case("pto.tcolexpandadd", 0x1011) - .Case("pto.tcolexpanddiv", 0x1012) - .Case("pto.tcolexpandexpdif", 0x1013) - .Case("pto.tcolexpandmax", 0x1014) - .Case("pto.tcolexpandmin", 0x1015) - .Case("pto.tcolexpandmul", 0x1016) - .Case("pto.tcolexpandsub", 0x1017) - .Case("pto.tcolmax", 0x1018) - .Case("pto.tcolmin", 0x1019) - .Case("pto.tcolprod", 
0x101A) - .Case("pto.tcolsum", 0x101B) - .Case("pto.tcvt", 0x101C) - .Case("pto.tdiv", 0x101D) - .Case("pto.tdivs", 0x101E) - .Case("pto.texp", 0x101F) - .Case("pto.texpands", 0x1020) - .Case("pto.textract", 0x1021) - .Case("pto.textract_fp", 0x1022) - .Case("pto.tfillpad", 0x1023) - .Case("pto.tfillpad_expand", 0x1024) - .Case("pto.tfillpad_inplace", 0x1025) - .Case("pto.tfmod", 0x1026) - .Case("pto.tfmods", 0x1027) - .Case("pto.tgather", 0x1028) - .Case("pto.tgatherb", 0x1029) - .Case("pto.tgemv", 0x102A) - .Case("pto.tgetval", 0x102B) - .Case("pto.timg2col", 0x102C) - .Case("pto.tinsert", 0x102D) - .Case("pto.tinsert_fp", 0x102E) - .Case("pto.tload", 0x102F) - .Case("pto.tlog", 0x1030) - .Case("pto.tlrelu", 0x1031) - .Case("pto.tmatmul", 0x1032) - .Case("pto.tmatmul.mx", 0x1033) - .Case("pto.tmax", 0x1034) - .Case("pto.tmaxs", 0x1035) - .Case("pto.tmin", 0x1036) - .Case("pto.tmins", 0x1037) - .Case("pto.tmov", 0x1038) - .Case("pto.tmov.fp", 0x1039) - .Case("pto.tmrgsort", 0x103A) - .Case("pto.tmul", 0x103B) - .Case("pto.tmuls", 0x103C) - .Case("pto.tneg", 0x103D) - .Case("pto.tnot", 0x103E) - .Case("pto.tor", 0x103F) - .Case("pto.tors", 0x1040) - .Case("pto.tpartadd", 0x1041) - .Case("pto.tpartmax", 0x1042) - .Case("pto.tpartmin", 0x1043) - .Case("pto.tpartmul", 0x1044) - .Case("pto.tprefetch", 0x1045) - .Case("pto.tprelu", 0x1046) - .Case("pto.tquant", 0x1047) - .Case("pto.trecip", 0x1048) - .Case("pto.trelu", 0x1049) - .Case("pto.trem", 0x104A) - .Case("pto.trems", 0x104B) - .Case("pto.treshape", 0x104C) - .Case("pto.trowexpand", 0x104D) - .Case("pto.trowexpandadd", 0x104E) - .Case("pto.trowexpandexpdif", 0x104F) - .Case("pto.trowexpandmax", 0x1050) - .Case("pto.trowexpandmin", 0x1051) - .Case("pto.trowmax", 0x1052) - .Case("pto.trowmin", 0x1053) - .Case("pto.trowsum", 0x1054) - .Case("pto.trsqrt", 0x1055) - .Case("pto.tscatter", 0x1056) - .Case("pto.tsel", 0x1057) - .Case("pto.tsels", 0x1058) - .Case("pto.tset_img2col_padding", 0x1059) - 
.Case("pto.tset_img2col_rpt", 0x105A) - .Case("pto.tsetfmatrix", 0x105B) - .Case("pto.tsethf32mode", 0x105C) - .Case("pto.tsettf32mode", 0x105D) - .Case("pto.tsetval", 0x105E) - .Case("pto.tshl", 0x105F) - .Case("pto.tshls", 0x1060) - .Case("pto.tshr", 0x1061) - .Case("pto.tshrs", 0x1062) - .Case("pto.tsort32", 0x1063) - .Case("pto.tsqrt", 0x1064) - .Case("pto.tstore", 0x1065) - .Case("pto.tstore_fp", 0x1066) - .Case("pto.tsub", 0x1067) - .Case("pto.tsubc", 0x1068) - .Case("pto.tsubs", 0x1069) - .Case("pto.tsubsc", 0x106A) - .Case("pto.trowexpandsub", 0x106B) - .Case("pto.ttrans", 0x106C) - .Case("pto.ttri", 0x106D) - .Case("pto.txor", 0x106E) - .Case("pto.txors", 0x106F) - .Case("pto.wait_event", 0x1070) - .Case("pto.tprint", 0x1071) - .Case("pto.subview", 0x1072) - .Case("pto.trowexpanddiv", 0x1073) - .Case("pto.trowexpandmul", 0x1074) - .Case("pto.tdequant", 0x1075) - .Case("pto.taxpy", 0x1076) - .Case("pto.thistogram", 0x1077) - .Case("pto.tget_scale_addr", 0x1078) - .Case("pto.trowargmax", 0x1079) - .Case("pto.trowargmin", 0x107A) - .Case("pto.tcolargmax", 0x107B) - .Case("pto.tcolargmin", 0x107C) - .Case("pto.tsync", 0x107D) - .Case("pto.reserve_buffer", 0x107E) - .Case("pto.import_reserved_buffer", 0x107F) - .Case("pto.aic_initialize_pipe", 0x1080) - .Case("pto.aiv_initialize_pipe", 0x1081) - .Case("pto.tpush_to_aiv", 0x1082) - .Case("pto.tpush_to_aic", 0x1083) - .Case("pto.tpop_from_aic", 0x1084) - .Case("pto.tpop_from_aiv", 0x1085) - .Case("pto.tfree_from_aic", 0x1086) - .Case("pto.tfree_from_aiv", 0x1087) - .Case("pto.set_validshape", 0x1088) - .Case("pto.tconcat", 0x1089) - .Case("pto.trowprod", 0x108A) - .Case("pto.initialize_l2g2l_pipe", 0x108B) - .Case("pto.initialize_l2l_pipe", 0x108C) - .Case("pto.tpush", 0x108D) - .Case("pto.declare_tile", 0x108E) - .Case("pto.tpop", 0x108F) - .Case("pto.tfree", 0x1090) - .Case("pto.comm.tput", 0x1091) - .Case("pto.comm.tget", 0x1092) - .Case("pto.comm.tnotify", 0x1093) - .Case("pto.comm.twait", 0x1094) - 
.Case("pto.comm.ttest", 0x1095) - .Case("pto.comm.tbroadcast", 0x1096) - .Case("pto.comm.tgather", 0x1097) - .Case("pto.comm.tscatter", 0x1098) - .Case("pto.comm.treduce", 0x1099) - .Case("pto.tpartargmax", 0x109A) - .Case("pto.tpartargmin", 0x109B) - .Case("scf.for", 0x4000) - .Case("scf.if", 0x4001) - .Case("scf.yield", 0x4002) - .Default(0xFFFF); - if (v == 0xFFFF) return std::nullopt; - return v; -} - -inline const OpInfo *lookupByName(llvm::StringRef name) { - auto o = lookupOpcodeByName(name); - if (!o) return nullptr; - return lookupByOpcode(*o); -} - -struct OpcodeAndVariant { uint16_t opcode; uint8_t hasVariant; uint8_t variant; }; - -inline std::optional lookupOpcodeAndVariantByFullName(llvm::StringRef fullName) { - // For non-family ops, variant is 0. For family ops, variant is the assigned u8. - // NOTE: `pto.section` is not a real op name; use `pto.section.cube`/`pto.section.vector`. - return llvm::StringSwitch>(fullName) - .Case("arith.addi", OpcodeAndVariant{0x2000, 0, 0}) - .Case("arith.ceildivsi", OpcodeAndVariant{0x2001, 0, 0}) - .Case("arith.cmpi", OpcodeAndVariant{0x2002, 0, 0}) - .Case("arith.constant", OpcodeAndVariant{0x2003, 0, 0}) - .Case("arith.index_cast", OpcodeAndVariant{0x2004, 0, 0}) - .Case("arith.minui", OpcodeAndVariant{0x2005, 0, 0}) - .Case("arith.muli", OpcodeAndVariant{0x2006, 0, 0}) - .Case("arith.select", OpcodeAndVariant{0x2007, 0, 0}) - .Case("arith.subi", OpcodeAndVariant{0x2008, 0, 0}) - .Case("func.func", OpcodeAndVariant{0x6000, 0, 0}) - .Case("func.return", OpcodeAndVariant{0x6001, 0, 0}) - .Case("func.call", OpcodeAndVariant{0x6002, 0, 0}) - .Case("pto.addptr", OpcodeAndVariant{0x1000, 0, 0}) - .Case("pto.alloc_tile", OpcodeAndVariant{0x1001, 0, 0}) - .Case("pto.barrier", OpcodeAndVariant{0x1002, 0, 0}) - .Case("pto.get_block_idx", OpcodeAndVariant{0x0000, 0, 0}) - .Case("pto.get_block_num", OpcodeAndVariant{0x0001, 0, 0}) - .Case("pto.get_subblock_idx", OpcodeAndVariant{0x0002, 0, 0}) - .Case("pto.get_subblock_num", 
OpcodeAndVariant{0x0003, 0, 0}) - .Case("pto.make_tensor_view", OpcodeAndVariant{0x0004, 0, 0}) - .Case("pto.mgather", OpcodeAndVariant{0x1003, 0, 0}) - .Case("pto.mscatter", OpcodeAndVariant{0x1004, 0, 0}) - .Case("pto.partition_view", OpcodeAndVariant{0x0005, 0, 0}) - .Case("pto.record_event", OpcodeAndVariant{0x1005, 0, 0}) - .Case("pto.tabs", OpcodeAndVariant{0x1006, 0, 0}) - .Case("pto.tadd", OpcodeAndVariant{0x1007, 0, 0}) - .Case("pto.taddc", OpcodeAndVariant{0x1008, 0, 0}) - .Case("pto.tadds", OpcodeAndVariant{0x1009, 0, 0}) - .Case("pto.taddsc", OpcodeAndVariant{0x100A, 0, 0}) - .Case("pto.tand", OpcodeAndVariant{0x100B, 0, 0}) - .Case("pto.tands", OpcodeAndVariant{0x100C, 0, 0}) - .Case("pto.tci", OpcodeAndVariant{0x100D, 0, 0}) - .Case("pto.tcmp", OpcodeAndVariant{0x100E, 0, 0}) - .Case("pto.tcmps", OpcodeAndVariant{0x100F, 0, 0}) - .Case("pto.tcolexpand", OpcodeAndVariant{0x1010, 0, 0}) - .Case("pto.tcolexpandadd", OpcodeAndVariant{0x1011, 0, 0}) - .Case("pto.tcolexpanddiv", OpcodeAndVariant{0x1012, 0, 0}) - .Case("pto.tcolexpandexpdif", OpcodeAndVariant{0x1013, 0, 0}) - .Case("pto.tcolexpandmax", OpcodeAndVariant{0x1014, 0, 0}) - .Case("pto.tcolexpandmin", OpcodeAndVariant{0x1015, 0, 0}) - .Case("pto.tcolexpandmul", OpcodeAndVariant{0x1016, 0, 0}) - .Case("pto.tcolexpandsub", OpcodeAndVariant{0x1017, 0, 0}) - .Case("pto.tcolmax", OpcodeAndVariant{0x1018, 0, 0}) - .Case("pto.tcolmin", OpcodeAndVariant{0x1019, 0, 0}) - .Case("pto.tcolprod", OpcodeAndVariant{0x101A, 0, 0}) - .Case("pto.tcolsum", OpcodeAndVariant{0x101B, 0, 0}) - .Case("pto.tcvt", OpcodeAndVariant{0x101C, 0, 0}) - .Case("pto.tdiv", OpcodeAndVariant{0x101D, 0, 0}) - .Case("pto.tdivs", OpcodeAndVariant{0x101E, 0, 0}) - .Case("pto.texp", OpcodeAndVariant{0x101F, 0, 0}) - .Case("pto.texpands", OpcodeAndVariant{0x1020, 0, 0}) - .Case("pto.textract", OpcodeAndVariant{0x1021, 0, 0}) - .Case("pto.textract_fp", OpcodeAndVariant{0x1022, 0, 0}) - .Case("pto.tfillpad", OpcodeAndVariant{0x1023, 0, 0}) 
- .Case("pto.tfillpad_expand", OpcodeAndVariant{0x1024, 0, 0}) - .Case("pto.tfillpad_inplace", OpcodeAndVariant{0x1025, 0, 0}) - .Case("pto.tfmod", OpcodeAndVariant{0x1026, 0, 0}) - .Case("pto.tfmods", OpcodeAndVariant{0x1027, 0, 0}) - .Case("pto.tgather", OpcodeAndVariant{0x1028, 0, 0}) - .Case("pto.tgatherb", OpcodeAndVariant{0x1029, 0, 0}) - .Case("pto.tgetval", OpcodeAndVariant{0x102B, 0, 0}) - .Case("pto.timg2col", OpcodeAndVariant{0x102C, 0, 0}) - .Case("pto.tinsert", OpcodeAndVariant{0x102D, 0, 0}) - .Case("pto.tinsert_fp", OpcodeAndVariant{0x102E, 0, 0}) - .Case("pto.tload", OpcodeAndVariant{0x102F, 0, 0}) - .Case("pto.tlog", OpcodeAndVariant{0x1030, 0, 0}) - .Case("pto.tlrelu", OpcodeAndVariant{0x1031, 0, 0}) - .Case("pto.tmax", OpcodeAndVariant{0x1034, 0, 0}) - .Case("pto.tmaxs", OpcodeAndVariant{0x1035, 0, 0}) - .Case("pto.tmin", OpcodeAndVariant{0x1036, 0, 0}) - .Case("pto.tmins", OpcodeAndVariant{0x1037, 0, 0}) - .Case("pto.tmov", OpcodeAndVariant{0x1038, 0, 0}) - .Case("pto.tmov.fp", OpcodeAndVariant{0x1039, 0, 0}) - .Case("pto.tmrgsort", OpcodeAndVariant{0x103A, 0, 0}) - .Case("pto.tmul", OpcodeAndVariant{0x103B, 0, 0}) - .Case("pto.tmuls", OpcodeAndVariant{0x103C, 0, 0}) - .Case("pto.tneg", OpcodeAndVariant{0x103D, 0, 0}) - .Case("pto.tnot", OpcodeAndVariant{0x103E, 0, 0}) - .Case("pto.tor", OpcodeAndVariant{0x103F, 0, 0}) - .Case("pto.tors", OpcodeAndVariant{0x1040, 0, 0}) - .Case("pto.tpartadd", OpcodeAndVariant{0x1041, 0, 0}) - .Case("pto.tpartmax", OpcodeAndVariant{0x1042, 0, 0}) - .Case("pto.tpartmin", OpcodeAndVariant{0x1043, 0, 0}) - .Case("pto.tpartmul", OpcodeAndVariant{0x1044, 0, 0}) - .Case("pto.tprefetch", OpcodeAndVariant{0x1045, 0, 0}) - .Case("pto.tprelu", OpcodeAndVariant{0x1046, 0, 0}) - .Case("pto.tquant", OpcodeAndVariant{0x1047, 0, 0}) - .Case("pto.trecip", OpcodeAndVariant{0x1048, 0, 0}) - .Case("pto.trelu", OpcodeAndVariant{0x1049, 0, 0}) - .Case("pto.trem", OpcodeAndVariant{0x104A, 0, 0}) - .Case("pto.trems", 
OpcodeAndVariant{0x104B, 0, 0}) - .Case("pto.treshape", OpcodeAndVariant{0x104C, 0, 0}) - .Case("pto.trowexpand", OpcodeAndVariant{0x104D, 0, 0}) - .Case("pto.trowexpandadd", OpcodeAndVariant{0x104E, 0, 0}) - .Case("pto.trowexpandexpdif", OpcodeAndVariant{0x104F, 0, 0}) - .Case("pto.trowexpandmax", OpcodeAndVariant{0x1050, 0, 0}) - .Case("pto.trowexpandmin", OpcodeAndVariant{0x1051, 0, 0}) - .Case("pto.trowmax", OpcodeAndVariant{0x1052, 0, 0}) - .Case("pto.trowmin", OpcodeAndVariant{0x1053, 0, 0}) - .Case("pto.trowsum", OpcodeAndVariant{0x1054, 0, 0}) - .Case("pto.trsqrt", OpcodeAndVariant{0x1055, 0, 0}) - .Case("pto.tscatter", OpcodeAndVariant{0x1056, 0, 0}) - .Case("pto.tsel", OpcodeAndVariant{0x1057, 0, 0}) - .Case("pto.tsels", OpcodeAndVariant{0x1058, 0, 0}) - .Case("pto.tset_img2col_padding", OpcodeAndVariant{0x1059, 0, 0}) - .Case("pto.tset_img2col_rpt", OpcodeAndVariant{0x105A, 0, 0}) - .Case("pto.tsetfmatrix", OpcodeAndVariant{0x105B, 0, 0}) - .Case("pto.tsethf32mode", OpcodeAndVariant{0x105C, 0, 0}) - .Case("pto.tsettf32mode", OpcodeAndVariant{0x105D, 0, 0}) - .Case("pto.tsetval", OpcodeAndVariant{0x105E, 0, 0}) - .Case("pto.tshl", OpcodeAndVariant{0x105F, 0, 0}) - .Case("pto.tshls", OpcodeAndVariant{0x1060, 0, 0}) - .Case("pto.tshr", OpcodeAndVariant{0x1061, 0, 0}) - .Case("pto.tshrs", OpcodeAndVariant{0x1062, 0, 0}) - .Case("pto.tsort32", OpcodeAndVariant{0x1063, 0, 0}) - .Case("pto.tsqrt", OpcodeAndVariant{0x1064, 0, 0}) - .Case("pto.tstore", OpcodeAndVariant{0x1065, 0, 0}) - .Case("pto.tstore_fp", OpcodeAndVariant{0x1066, 0, 0}) - .Case("pto.tsub", OpcodeAndVariant{0x1067, 0, 0}) - .Case("pto.tsubc", OpcodeAndVariant{0x1068, 0, 0}) - .Case("pto.tsubs", OpcodeAndVariant{0x1069, 0, 0}) - .Case("pto.tsubsc", OpcodeAndVariant{0x106A, 0, 0}) - .Case("pto.trowexpandsub", OpcodeAndVariant{0x106B, 0, 0}) - .Case("pto.ttrans", OpcodeAndVariant{0x106C, 0, 0}) - .Case("pto.ttri", OpcodeAndVariant{0x106D, 0, 0}) - .Case("pto.txor", OpcodeAndVariant{0x106E, 0, 0}) 
- .Case("pto.txors", OpcodeAndVariant{0x106F, 0, 0}) - .Case("pto.wait_event", OpcodeAndVariant{0x1070, 0, 0}) - .Case("pto.tprint", OpcodeAndVariant{0x1071, 0, 0}) - .Case("pto.subview", OpcodeAndVariant{0x1072, 0, 0}) - .Case("pto.trowexpanddiv", OpcodeAndVariant{0x1073, 0, 0}) - .Case("pto.trowexpandmul", OpcodeAndVariant{0x1074, 0, 0}) - .Case("pto.tdequant", OpcodeAndVariant{0x1075, 0, 0}) - .Case("pto.taxpy", OpcodeAndVariant{0x1076, 0, 0}) - .Case("pto.thistogram", OpcodeAndVariant{0x1077, 0, 0}) - .Case("pto.tget_scale_addr", OpcodeAndVariant{0x1078, 0, 0}) - .Case("pto.trowargmax", OpcodeAndVariant{0x1079, 0, 0}) - .Case("pto.trowargmin", OpcodeAndVariant{0x107A, 0, 0}) - .Case("pto.tcolargmax", OpcodeAndVariant{0x107B, 0, 0}) - .Case("pto.tcolargmin", OpcodeAndVariant{0x107C, 0, 0}) - .Case("pto.tsync", OpcodeAndVariant{0x107D, 0, 0}) - .Case("pto.reserve_buffer", OpcodeAndVariant{0x107E, 0, 0}) - .Case("pto.import_reserved_buffer", OpcodeAndVariant{0x107F, 0, 0}) - .Case("pto.aic_initialize_pipe", OpcodeAndVariant{0x1080, 0, 0}) - .Case("pto.aiv_initialize_pipe", OpcodeAndVariant{0x1081, 0, 0}) - .Case("pto.tpush_to_aiv", OpcodeAndVariant{0x1082, 0, 0}) - .Case("pto.tpush_to_aic", OpcodeAndVariant{0x1083, 0, 0}) - .Case("pto.tpop_from_aic", OpcodeAndVariant{0x1084, 0, 0}) - .Case("pto.tpop_from_aiv", OpcodeAndVariant{0x1085, 0, 0}) - .Case("pto.tfree_from_aic", OpcodeAndVariant{0x1086, 0, 0}) - .Case("pto.tfree_from_aiv", OpcodeAndVariant{0x1087, 0, 0}) - .Case("pto.set_validshape", OpcodeAndVariant{0x1088, 0, 0}) - .Case("pto.tconcat", OpcodeAndVariant{0x1089, 0, 0}) - .Case("pto.trowprod", OpcodeAndVariant{0x108A, 0, 0}) - .Case("pto.initialize_l2g2l_pipe", OpcodeAndVariant{0x108B, 0, 0}) - .Case("pto.initialize_l2l_pipe", OpcodeAndVariant{0x108C, 0, 0}) - .Case("pto.tpush", OpcodeAndVariant{0x108D, 0, 0}) - .Case("pto.declare_tile", OpcodeAndVariant{0x108E, 0, 0}) - .Case("pto.tpop", OpcodeAndVariant{0x108F, 0, 0}) - .Case("pto.tfree", 
OpcodeAndVariant{0x1090, 0, 0}) - .Case("pto.comm.tput", OpcodeAndVariant{0x1091, 0, 0}) - .Case("pto.comm.tget", OpcodeAndVariant{0x1092, 0, 0}) - .Case("pto.comm.tnotify", OpcodeAndVariant{0x1093, 0, 0}) - .Case("pto.comm.twait", OpcodeAndVariant{0x1094, 0, 0}) - .Case("pto.comm.ttest", OpcodeAndVariant{0x1095, 0, 0}) - .Case("pto.comm.tbroadcast", OpcodeAndVariant{0x1096, 0, 0}) - .Case("pto.comm.tgather", OpcodeAndVariant{0x1097, 0, 0}) - .Case("pto.comm.tscatter", OpcodeAndVariant{0x1098, 0, 0}) - .Case("pto.comm.treduce", OpcodeAndVariant{0x1099, 0, 0}) - .Case("pto.tpartargmax", OpcodeAndVariant{0x109A, 0, 0}) - .Case("pto.tpartargmin", OpcodeAndVariant{0x109B, 0, 0}) - .Case("scf.for", OpcodeAndVariant{0x4000, 0, 0}) - .Case("scf.if", OpcodeAndVariant{0x4001, 0, 0}) - .Case("scf.yield", OpcodeAndVariant{0x4002, 0, 0}) - .Case("pto.section.cube", - OpcodeAndVariant{0x0006, kHasVariant, kSectionCubeVariant}) - .Case("pto.section.vector", - OpcodeAndVariant{0x0006, kHasVariant, kSectionVectorVariant}) - .Case("pto.tgemv", - OpcodeAndVariant{0x102A, kHasVariant, kVariantDefault}) - .Case("pto.tgemv.acc", - OpcodeAndVariant{0x102A, kHasVariant, kVariantAcc}) - .Case("pto.tgemv.bias", - OpcodeAndVariant{0x102A, kHasVariant, kVariantBias}) - .Case("pto.tgemv.mx", - OpcodeAndVariant{0x102A, kHasVariant, kVariantMx}) - .Case("pto.tgemv.mx.acc", - OpcodeAndVariant{0x102A, kHasVariant, kVariantMxAcc}) - .Case("pto.tgemv.mx.bias", - OpcodeAndVariant{0x102A, kHasVariant, kVariantMxBias}) - .Case("pto.tmatmul", - OpcodeAndVariant{0x1032, kHasVariant, kVariantDefault}) - .Case("pto.tmatmul.acc", - OpcodeAndVariant{0x1032, kHasVariant, kVariantAcc}) - .Case("pto.tmatmul.bias", - OpcodeAndVariant{0x1032, kHasVariant, kVariantBias}) - .Case("pto.tmatmul.mx", - OpcodeAndVariant{0x1033, kHasVariant, kVariantDefault}) - .Case("pto.tmatmul.mx.acc", - OpcodeAndVariant{0x1033, kHasVariant, kVariantAcc}) - .Case("pto.tmatmul.mx.bias", - OpcodeAndVariant{0x1033, kHasVariant, 
kVariantBias}) - .Default(std::nullopt); -} - -inline const char *fullNameFromOpcodeVariant(uint16_t opcode, uint8_t variant) { - const OpInfo *info = lookupByOpcode(opcode); - if (!info) return nullptr; - if (opcode == kTscatterMaskOpcode) return "pto.tscatter"; - if (!info->has_variant_u8) return info->name; - switch (opcode) { - case 0x0006: - switch (variant) { - case kSectionCubeVariant: return "pto.section.cube"; - case kSectionVectorVariant: return "pto.section.vector"; - default: return info->name; - } - case 0x102A: - switch (variant) { - case kVariantDefault: return "pto.tgemv"; - case kVariantAcc: return "pto.tgemv.acc"; - case kVariantBias: return "pto.tgemv.bias"; - case kVariantMx: return "pto.tgemv.mx"; - case kVariantMxAcc: return "pto.tgemv.mx.acc"; - case kVariantMxBias: return "pto.tgemv.mx.bias"; - default: return info->name; - } - case 0x1032: - switch (variant) { - case kVariantDefault: return "pto.tmatmul"; - case kVariantAcc: return "pto.tmatmul.acc"; - case kVariantBias: return "pto.tmatmul.bias"; - default: return info->name; - } - case 0x1033: - switch (variant) { - case kVariantDefault: return "pto.tmatmul.mx"; - case kVariantAcc: return "pto.tmatmul.mx.acc"; - case kVariantBias: return "pto.tmatmul.mx.bias"; - default: return info->name; - } - default: return info->name; - } -} - -inline std::optional lookupOperandsByVariant(uint16_t opcode, uint8_t variant) { - switch (opcode) { - case 0x102A: - switch (variant) { - case kVariantDefault: return kTgemvOperandCount; - case kVariantAcc: return kTgemvAccOperandCount; - case kVariantBias: return kTgemvBiasOperandCount; - case kVariantMx: return kTgemvMxOperandCount; - case kVariantMxAcc: return kTgemvMxAccOperandCount; - case kVariantMxBias: return kTgemvMxBiasOperandCount; - default: return std::nullopt; - } - case 0x1032: - switch (variant) { - case kVariantDefault: return kTmatmulOperandCount; - case kVariantAcc: return kTmatmulAccOperandCount; - case kVariantBias: return 
kTmatmulBiasOperandCount; - default: return std::nullopt; - } - case 0x1033: - switch (variant) { - case kVariantDefault: return kTmatmulMxOperandCount; - case kVariantAcc: return kTmatmulMxAccOperandCount; - case kVariantBias: return kTmatmulMxBiasOperandCount; - default: return std::nullopt; - } - default: return std::nullopt; - } -} - -} // namespace ptobc::v0 diff --git a/tools/ptobc/generated/ptobc_opcodes_v0.h b/tools/ptobc/generated/ptobc_opcodes_v0.h index 3f0faf5f1..8303e1261 100644 --- a/tools/ptobc/generated/ptobc_opcodes_v0.h +++ b/tools/ptobc/generated/ptobc_opcodes_v0.h @@ -6,5 +6,717 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. +// Generated by docs/bytecode/tools/gen_v0_tables.py +#pragma once -#include "ptobc_opcodes_v0.def" +#include +#include + +#include +#include + +namespace ptobc::v0 { + +inline constexpr uint8_t kVariantDefault = 0; +inline constexpr uint8_t kVariantAcc = 1; +inline constexpr uint8_t kVariantBias = 2; +inline constexpr uint8_t kVariantMx = 3; +inline constexpr uint8_t kVariantMxAcc = 4; +inline constexpr uint8_t kVariantMxBias = 5; +inline constexpr uint8_t kSectionCubeVariant = 0; +inline constexpr uint8_t kSectionVectorVariant = 1; +inline constexpr uint8_t kHasVariant = 1; +inline constexpr uint16_t kTscatterMaskOpcode = 0x109C; + +inline constexpr int kTgemvOperandCount = 3; +inline constexpr int kTgemvAccOperandCount = 4; +inline constexpr int kTgemvBiasOperandCount = 4; +inline constexpr int kTgemvMxOperandCount = 5; +inline constexpr int kTgemvMxAccOperandCount = 6; +inline constexpr int kTgemvMxBiasOperandCount = 6; +inline constexpr int kTmatmulOperandCount = 3; +inline constexpr int kTmatmulAccOperandCount = 4; +inline constexpr int kTmatmulBiasOperandCount = 4; +inline constexpr int kTmatmulMxOperandCount = 5; +inline constexpr int kTmatmulMxAccOperandCount = 6; 
+inline constexpr int kTmatmulMxBiasOperandCount = 6; + +struct OpInfo { + uint16_t opcode; + const char *name; + uint8_t has_variant_u8; + uint8_t result_type_mode; + uint8_t operand_mode; + uint16_t num_operands; + uint16_t num_results; + uint16_t num_regions; + uint8_t imm_kind; +}; + +inline constexpr OpInfo kOpTable[] = { + {0x0000, "pto.get_block_idx", 0, 0x01, 0x00, 0, 1, 0, 0x00}, + {0x0001, "pto.get_block_num", 0, 0x01, 0x00, 0, 1, 0, 0x00}, + {0x0002, "pto.get_subblock_idx", 0, 0x01, 0x00, 0, 1, 0, 0x00}, + {0x0003, "pto.get_subblock_num", 0, 0x01, 0x00, 0, 1, 0, 0x00}, + {0x0004, "pto.make_tensor_view", 0, 0x01, 0x03, 1, 1, 0, 0x06}, + {0x0005, "pto.partition_view", 0, 0x01, 0x03, 1, 1, 0, 0x07}, + {0x0006, "pto.section", 1, 0x00, 0x00, 0, 0, 1, 0x00}, + {0x1000, "pto.addptr", 0, 0x01, 0x00, 2, 1, 0, 0x00}, + {0x1001, "pto.alloc_tile", 0, 0x01, 0x04, 0, 1, 0, 0x08}, + {0x1002, "pto.barrier", 0, 0x00, 0x00, 0, 0, 0, 0x00}, + {0x1003, "pto.mgather", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1004, "pto.mscatter", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1005, "pto.record_event", 0, 0x00, 0x00, 0, 0, 0, 0x02}, + {0x1006, "pto.tabs", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1007, "pto.tadd", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1008, "pto.taddc", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x1009, "pto.tadds", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x100A, "pto.taddsc", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x100B, "pto.tand", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x100C, "pto.tands", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x100D, "pto.tci", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x100E, "pto.tcmp", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x100F, "pto.tcmps", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1010, "pto.tcolexpand", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1011, "pto.tcolexpandadd", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1012, "pto.tcolexpanddiv", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1013, "pto.tcolexpandexpdif", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1014, "pto.tcolexpandmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1015, 
"pto.tcolexpandmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1016, "pto.tcolexpandmul", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1017, "pto.tcolexpandsub", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1018, "pto.tcolmax", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1019, "pto.tcolmin", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x101A, "pto.tcolprod", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x101B, "pto.tcolsum", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x101C, "pto.tcvt", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x101D, "pto.tdiv", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x101E, "pto.tdivs", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x101F, "pto.texp", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1020, "pto.texpands", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1021, "pto.textract", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x1022, "pto.textract_fp", 0, 0x00, 0x00, 5, 0, 0, 0x00}, + {0x1023, "pto.tfillpad", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1024, "pto.tfillpad_expand", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1025, "pto.tfillpad_inplace", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1026, "pto.tfmod", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1027, "pto.tfmods", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1028, "pto.tgather", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1029, "pto.tgatherb", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x102A, "pto.tgemv", 1, 0x00, 0x01, 0, 0, 0, 0x00}, + {0x102B, "pto.tgetval", 0, 0x01, 0x00, 2, 1, 0, 0x00}, + {0x102C, "pto.timg2col", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x102D, "pto.tinsert", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x102E, "pto.tinsert_fp", 0, 0x00, 0x00, 5, 0, 0, 0x00}, + {0x102F, "pto.tload", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1030, "pto.tlog", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1031, "pto.tlrelu", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1032, "pto.tmatmul", 1, 0x00, 0x01, 0, 0, 0, 0x00}, + {0x1033, "pto.tmatmul.mx", 1, 0x00, 0x01, 0, 0, 0, 0x00}, + {0x1034, "pto.tmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1035, "pto.tmaxs", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1036, "pto.tmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1037, "pto.tmins", 0, 0x00, 0x00, 3, 0, 0, 
0x00}, + {0x1038, "pto.tmov", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1039, "pto.tmov.fp", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x103A, "pto.tmrgsort", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x103B, "pto.tmul", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x103C, "pto.tmuls", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x103D, "pto.tneg", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x103E, "pto.tnot", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x103F, "pto.tor", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1040, "pto.tors", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1041, "pto.tpartadd", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1042, "pto.tpartmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1043, "pto.tpartmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1044, "pto.tpartmul", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1045, "pto.tprefetch", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1046, "pto.tprelu", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x1047, "pto.tquant", 0, 0x00, 0x02, 3, 0, 0, 0x00}, + {0x1048, "pto.trecip", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1049, "pto.trelu", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x104A, "pto.trem", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x104B, "pto.trems", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x104C, "pto.treshape", 0, 0x01, 0x00, 1, 1, 0, 0x00}, + {0x104D, "pto.trowexpand", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x104E, "pto.trowexpandadd", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x104F, "pto.trowexpandexpdif", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1050, "pto.trowexpandmax", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1051, "pto.trowexpandmin", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1052, "pto.trowmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1053, "pto.trowmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1054, "pto.trowsum", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1055, "pto.trsqrt", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1056, "pto.tscatter", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1057, "pto.tsel", 0, 0x00, 0x00, 5, 0, 0, 0x00}, + {0x1058, "pto.tsels", 0, 0x00, 0x00, 5, 0, 0, 0x00}, + {0x1059, "pto.tset_img2col_padding", 0, 0x00, 0x00, 1, 0, 0, 0x00}, + {0x105A, "pto.tset_img2col_rpt", 0, 
0x00, 0x00, 1, 0, 0, 0x00}, + {0x105B, "pto.tsetfmatrix", 0, 0x00, 0x00, 1, 0, 0, 0x00}, + {0x105C, "pto.tsethf32mode", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x105D, "pto.tsettf32mode", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x105E, "pto.tsetval", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x105F, "pto.tshl", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1060, "pto.tshls", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1061, "pto.tshr", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1062, "pto.tshrs", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1063, "pto.tsort32", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1064, "pto.tsqrt", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1065, "pto.tstore", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1066, "pto.tstore_fp", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1067, "pto.tsub", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1068, "pto.tsubc", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x1069, "pto.tsubs", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x106A, "pto.tsubsc", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x106B, "pto.trowexpandsub", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x106C, "pto.ttrans", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x106D, "pto.ttri", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x106E, "pto.txor", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x106F, "pto.txors", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x1070, "pto.wait_event", 0, 0x00, 0x00, 0, 0, 0, 0x02}, + {0x1071, "pto.tprint", 0, 0x00, 0x00, 1, 0, 0, 0x00}, + {0x1072, "pto.subview", 0, 0x01, 0x02, 0, 1, 0, 0x00}, + {0x1073, "pto.trowexpanddiv", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1074, "pto.trowexpandmul", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1075, "pto.tdequant", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x1076, "pto.taxpy", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1077, "pto.thistogram", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1078, "pto.tget_scale_addr", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1079, "pto.trowargmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x107A, "pto.trowargmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x107B, "pto.tcolargmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x107C, "pto.tcolargmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + 
{0x107D, "pto.tsync", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x107E, "pto.reserve_buffer", 0, 0x01, 0x00, 0, 1, 0, 0x00}, + {0x107F, "pto.import_reserved_buffer", 0, 0x01, 0x00, 0, 1, 0, 0x00}, + {0x1080, "pto.aic_initialize_pipe", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1081, "pto.aiv_initialize_pipe", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1082, "pto.tpush_to_aiv", 0, 0x00, 0x00, 1, 0, 0, 0x00}, + {0x1083, "pto.tpush_to_aic", 0, 0x00, 0x00, 1, 0, 0, 0x00}, + {0x1084, "pto.tpop_from_aic", 0, 0x01, 0x02, 0, 1, 0, 0x00}, + {0x1085, "pto.tpop_from_aiv", 0, 0x01, 0x02, 0, 1, 0, 0x00}, + {0x1086, "pto.tfree_from_aic", 0, 0x00, 0x00, 0, 0, 0, 0x00}, + {0x1087, "pto.tfree_from_aiv", 0, 0x00, 0x00, 0, 0, 0, 0x00}, + {0x1088, "pto.set_validshape", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1089, "pto.tconcat", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x108A, "pto.trowprod", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x108B, "pto.initialize_l2g2l_pipe", 0, 0x01, 0x02, 0, 1, 0, 0x00}, + {0x108C, "pto.initialize_l2l_pipe", 0, 0x01, 0x02, 0, 1, 0, 0x00}, + {0x108D, "pto.tpush", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x108E, "pto.declare_tile", 0, 0x01, 0x00, 0, 1, 0, 0x00}, + {0x108F, "pto.tpop", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1090, "pto.tfree", 0, 0x00, 0x00, 1, 0, 0, 0x00}, + {0x1091, "pto.comm.tput", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1092, "pto.comm.tget", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1093, "pto.comm.tnotify", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1094, "pto.comm.twait", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1095, "pto.comm.ttest", 0, 0x01, 0x02, 0, 1, 0, 0x00}, + {0x1096, "pto.comm.tbroadcast", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1097, "pto.comm.tgather", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1098, "pto.comm.tscatter", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1099, "pto.comm.treduce", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x109A, "pto.tpartargmax", 0, 0x00, 0x00, 6, 0, 0, 0x00}, + {0x109B, "pto.tpartargmin", 0, 0x00, 0x00, 6, 0, 0, 0x00}, + {0x109C, "pto.tscatter.maskpattern", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + 
{0x2000, "arith.addi", 0, 0x01, 0x00, 2, 1, 0, 0x00}, + {0x2001, "arith.ceildivsi", 0, 0x01, 0x00, 2, 1, 0, 0x00}, + {0x2002, "arith.cmpi", 0, 0x01, 0x00, 2, 1, 0, 0x01}, + {0x2003, "arith.constant", 0, 0x01, 0x00, 0, 1, 0, 0x05}, + {0x2004, "arith.index_cast", 0, 0x01, 0x00, 1, 1, 0, 0x00}, + {0x2005, "arith.minui", 0, 0x01, 0x00, 2, 1, 0, 0x00}, + {0x2006, "arith.muli", 0, 0x01, 0x00, 2, 1, 0, 0x00}, + {0x2007, "arith.select", 0, 0x01, 0x00, 3, 1, 0, 0x00}, + {0x2008, "arith.subi", 0, 0x01, 0x00, 2, 1, 0, 0x00}, + {0x4000, "scf.for", 0, 0x00, 0x00, 3, 0, 1, 0x00}, + {0x4001, "scf.if", 0, 0x00, 0x00, 1, 0, 2, 0x00}, + {0x4002, "scf.yield", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x6000, "func.func", 0, 0x00, 0x00, 0, 0, 0, 0x00}, + {0x6001, "func.return", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x6002, "func.call", 0, 0x02, 0x02, 0, 0, 0, 0x00}, +}; + +inline const OpInfo *lookupByOpcode(uint16_t opcode) { + // Binary search on kOpTable (sorted by opcode). + size_t lo = 0, hi = sizeof(kOpTable) / sizeof(kOpTable[0]); + while (lo < hi) { + size_t mid = lo + (hi - lo) / 2; + uint16_t v = kOpTable[mid].opcode; + if (v == opcode) return &kOpTable[mid]; + if (v < opcode) lo = mid + 1; else hi = mid; + } + return nullptr; +} + +inline std::optional lookupOpcodeByName(llvm::StringRef name) { + uint16_t v = llvm::StringSwitch(name) + .Case("arith.addi", 0x2000) + .Case("arith.ceildivsi", 0x2001) + .Case("arith.cmpi", 0x2002) + .Case("arith.constant", 0x2003) + .Case("arith.index_cast", 0x2004) + .Case("arith.minui", 0x2005) + .Case("arith.muli", 0x2006) + .Case("arith.select", 0x2007) + .Case("arith.subi", 0x2008) + .Case("func.func", 0x6000) + .Case("func.return", 0x6001) + .Case("func.call", 0x6002) + .Case("pto.addptr", 0x1000) + .Case("pto.alloc_tile", 0x1001) + .Case("pto.barrier", 0x1002) + .Case("pto.get_block_idx", 0x0000) + .Case("pto.get_block_num", 0x0001) + .Case("pto.get_subblock_idx", 0x0002) + .Case("pto.get_subblock_num", 0x0003) + .Case("pto.make_tensor_view", 
0x0004) + .Case("pto.mgather", 0x1003) + .Case("pto.mscatter", 0x1004) + .Case("pto.partition_view", 0x0005) + .Case("pto.record_event", 0x1005) + .Case("pto.section", 0x0006) + .Case("pto.tabs", 0x1006) + .Case("pto.tadd", 0x1007) + .Case("pto.taddc", 0x1008) + .Case("pto.tadds", 0x1009) + .Case("pto.taddsc", 0x100A) + .Case("pto.tand", 0x100B) + .Case("pto.tands", 0x100C) + .Case("pto.tci", 0x100D) + .Case("pto.tcmp", 0x100E) + .Case("pto.tcmps", 0x100F) + .Case("pto.tcolexpand", 0x1010) + .Case("pto.tcolexpandadd", 0x1011) + .Case("pto.tcolexpanddiv", 0x1012) + .Case("pto.tcolexpandexpdif", 0x1013) + .Case("pto.tcolexpandmax", 0x1014) + .Case("pto.tcolexpandmin", 0x1015) + .Case("pto.tcolexpandmul", 0x1016) + .Case("pto.tcolexpandsub", 0x1017) + .Case("pto.tcolmax", 0x1018) + .Case("pto.tcolmin", 0x1019) + .Case("pto.tcolprod", 0x101A) + .Case("pto.tcolsum", 0x101B) + .Case("pto.tcvt", 0x101C) + .Case("pto.tdiv", 0x101D) + .Case("pto.tdivs", 0x101E) + .Case("pto.texp", 0x101F) + .Case("pto.texpands", 0x1020) + .Case("pto.textract", 0x1021) + .Case("pto.textract_fp", 0x1022) + .Case("pto.tfillpad", 0x1023) + .Case("pto.tfillpad_expand", 0x1024) + .Case("pto.tfillpad_inplace", 0x1025) + .Case("pto.tfmod", 0x1026) + .Case("pto.tfmods", 0x1027) + .Case("pto.tgather", 0x1028) + .Case("pto.tgatherb", 0x1029) + .Case("pto.tgemv", 0x102A) + .Case("pto.tgetval", 0x102B) + .Case("pto.timg2col", 0x102C) + .Case("pto.tinsert", 0x102D) + .Case("pto.tinsert_fp", 0x102E) + .Case("pto.tload", 0x102F) + .Case("pto.tlog", 0x1030) + .Case("pto.tlrelu", 0x1031) + .Case("pto.tmatmul", 0x1032) + .Case("pto.tmatmul.mx", 0x1033) + .Case("pto.tmax", 0x1034) + .Case("pto.tmaxs", 0x1035) + .Case("pto.tmin", 0x1036) + .Case("pto.tmins", 0x1037) + .Case("pto.tmov", 0x1038) + .Case("pto.tmov.fp", 0x1039) + .Case("pto.tmrgsort", 0x103A) + .Case("pto.tmul", 0x103B) + .Case("pto.tmuls", 0x103C) + .Case("pto.tneg", 0x103D) + .Case("pto.tnot", 0x103E) + .Case("pto.tor", 0x103F) + 
.Case("pto.tors", 0x1040) + .Case("pto.tpartadd", 0x1041) + .Case("pto.tpartmax", 0x1042) + .Case("pto.tpartmin", 0x1043) + .Case("pto.tpartmul", 0x1044) + .Case("pto.tprefetch", 0x1045) + .Case("pto.tprelu", 0x1046) + .Case("pto.tquant", 0x1047) + .Case("pto.trecip", 0x1048) + .Case("pto.trelu", 0x1049) + .Case("pto.trem", 0x104A) + .Case("pto.trems", 0x104B) + .Case("pto.treshape", 0x104C) + .Case("pto.trowexpand", 0x104D) + .Case("pto.trowexpandadd", 0x104E) + .Case("pto.trowexpandexpdif", 0x104F) + .Case("pto.trowexpandmax", 0x1050) + .Case("pto.trowexpandmin", 0x1051) + .Case("pto.trowmax", 0x1052) + .Case("pto.trowmin", 0x1053) + .Case("pto.trowsum", 0x1054) + .Case("pto.trsqrt", 0x1055) + .Case("pto.tscatter", 0x1056) + .Case("pto.tsel", 0x1057) + .Case("pto.tsels", 0x1058) + .Case("pto.tset_img2col_padding", 0x1059) + .Case("pto.tset_img2col_rpt", 0x105A) + .Case("pto.tsetfmatrix", 0x105B) + .Case("pto.tsethf32mode", 0x105C) + .Case("pto.tsettf32mode", 0x105D) + .Case("pto.tsetval", 0x105E) + .Case("pto.tshl", 0x105F) + .Case("pto.tshls", 0x1060) + .Case("pto.tshr", 0x1061) + .Case("pto.tshrs", 0x1062) + .Case("pto.tsort32", 0x1063) + .Case("pto.tsqrt", 0x1064) + .Case("pto.tstore", 0x1065) + .Case("pto.tstore_fp", 0x1066) + .Case("pto.tsub", 0x1067) + .Case("pto.tsubc", 0x1068) + .Case("pto.tsubs", 0x1069) + .Case("pto.tsubsc", 0x106A) + .Case("pto.trowexpandsub", 0x106B) + .Case("pto.ttrans", 0x106C) + .Case("pto.ttri", 0x106D) + .Case("pto.txor", 0x106E) + .Case("pto.txors", 0x106F) + .Case("pto.wait_event", 0x1070) + .Case("pto.tprint", 0x1071) + .Case("pto.subview", 0x1072) + .Case("pto.trowexpanddiv", 0x1073) + .Case("pto.trowexpandmul", 0x1074) + .Case("pto.tdequant", 0x1075) + .Case("pto.taxpy", 0x1076) + .Case("pto.thistogram", 0x1077) + .Case("pto.tget_scale_addr", 0x1078) + .Case("pto.trowargmax", 0x1079) + .Case("pto.trowargmin", 0x107A) + .Case("pto.tcolargmax", 0x107B) + .Case("pto.tcolargmin", 0x107C) + .Case("pto.tsync", 0x107D) + 
.Case("pto.reserve_buffer", 0x107E) + .Case("pto.import_reserved_buffer", 0x107F) + .Case("pto.aic_initialize_pipe", 0x1080) + .Case("pto.aiv_initialize_pipe", 0x1081) + .Case("pto.tpush_to_aiv", 0x1082) + .Case("pto.tpush_to_aic", 0x1083) + .Case("pto.tpop_from_aic", 0x1084) + .Case("pto.tpop_from_aiv", 0x1085) + .Case("pto.tfree_from_aic", 0x1086) + .Case("pto.tfree_from_aiv", 0x1087) + .Case("pto.set_validshape", 0x1088) + .Case("pto.tconcat", 0x1089) + .Case("pto.trowprod", 0x108A) + .Case("pto.initialize_l2g2l_pipe", 0x108B) + .Case("pto.initialize_l2l_pipe", 0x108C) + .Case("pto.tpush", 0x108D) + .Case("pto.declare_tile", 0x108E) + .Case("pto.tpop", 0x108F) + .Case("pto.tfree", 0x1090) + .Case("pto.comm.tput", 0x1091) + .Case("pto.comm.tget", 0x1092) + .Case("pto.comm.tnotify", 0x1093) + .Case("pto.comm.twait", 0x1094) + .Case("pto.comm.ttest", 0x1095) + .Case("pto.comm.tbroadcast", 0x1096) + .Case("pto.comm.tgather", 0x1097) + .Case("pto.comm.tscatter", 0x1098) + .Case("pto.comm.treduce", 0x1099) + .Case("pto.tpartargmax", 0x109A) + .Case("pto.tpartargmin", 0x109B) + .Case("scf.for", 0x4000) + .Case("scf.if", 0x4001) + .Case("scf.yield", 0x4002) + .Default(0xFFFF); + if (v == 0xFFFF) return std::nullopt; + return v; +} + +inline const OpInfo *lookupByName(llvm::StringRef name) { + auto o = lookupOpcodeByName(name); + if (!o) return nullptr; + return lookupByOpcode(*o); +} + +struct OpcodeAndVariant { uint16_t opcode; uint8_t hasVariant; uint8_t variant; }; + +inline std::optional lookupOpcodeAndVariantByFullName(llvm::StringRef fullName) { + // For non-family ops, variant is 0. For family ops, variant is the assigned u8. + // NOTE: `pto.section` is not a real op name; use `pto.section.cube`/`pto.section.vector`. 
+ return llvm::StringSwitch>(fullName) + .Case("arith.addi", OpcodeAndVariant{0x2000, 0, 0}) + .Case("arith.ceildivsi", OpcodeAndVariant{0x2001, 0, 0}) + .Case("arith.cmpi", OpcodeAndVariant{0x2002, 0, 0}) + .Case("arith.constant", OpcodeAndVariant{0x2003, 0, 0}) + .Case("arith.index_cast", OpcodeAndVariant{0x2004, 0, 0}) + .Case("arith.minui", OpcodeAndVariant{0x2005, 0, 0}) + .Case("arith.muli", OpcodeAndVariant{0x2006, 0, 0}) + .Case("arith.select", OpcodeAndVariant{0x2007, 0, 0}) + .Case("arith.subi", OpcodeAndVariant{0x2008, 0, 0}) + .Case("func.func", OpcodeAndVariant{0x6000, 0, 0}) + .Case("func.return", OpcodeAndVariant{0x6001, 0, 0}) + .Case("func.call", OpcodeAndVariant{0x6002, 0, 0}) + .Case("pto.addptr", OpcodeAndVariant{0x1000, 0, 0}) + .Case("pto.alloc_tile", OpcodeAndVariant{0x1001, 0, 0}) + .Case("pto.barrier", OpcodeAndVariant{0x1002, 0, 0}) + .Case("pto.get_block_idx", OpcodeAndVariant{0x0000, 0, 0}) + .Case("pto.get_block_num", OpcodeAndVariant{0x0001, 0, 0}) + .Case("pto.get_subblock_idx", OpcodeAndVariant{0x0002, 0, 0}) + .Case("pto.get_subblock_num", OpcodeAndVariant{0x0003, 0, 0}) + .Case("pto.make_tensor_view", OpcodeAndVariant{0x0004, 0, 0}) + .Case("pto.mgather", OpcodeAndVariant{0x1003, 0, 0}) + .Case("pto.mscatter", OpcodeAndVariant{0x1004, 0, 0}) + .Case("pto.partition_view", OpcodeAndVariant{0x0005, 0, 0}) + .Case("pto.record_event", OpcodeAndVariant{0x1005, 0, 0}) + .Case("pto.tabs", OpcodeAndVariant{0x1006, 0, 0}) + .Case("pto.tadd", OpcodeAndVariant{0x1007, 0, 0}) + .Case("pto.taddc", OpcodeAndVariant{0x1008, 0, 0}) + .Case("pto.tadds", OpcodeAndVariant{0x1009, 0, 0}) + .Case("pto.taddsc", OpcodeAndVariant{0x100A, 0, 0}) + .Case("pto.tand", OpcodeAndVariant{0x100B, 0, 0}) + .Case("pto.tands", OpcodeAndVariant{0x100C, 0, 0}) + .Case("pto.tci", OpcodeAndVariant{0x100D, 0, 0}) + .Case("pto.tcmp", OpcodeAndVariant{0x100E, 0, 0}) + .Case("pto.tcmps", OpcodeAndVariant{0x100F, 0, 0}) + .Case("pto.tcolexpand", OpcodeAndVariant{0x1010, 0, 
0}) + .Case("pto.tcolexpandadd", OpcodeAndVariant{0x1011, 0, 0}) + .Case("pto.tcolexpanddiv", OpcodeAndVariant{0x1012, 0, 0}) + .Case("pto.tcolexpandexpdif", OpcodeAndVariant{0x1013, 0, 0}) + .Case("pto.tcolexpandmax", OpcodeAndVariant{0x1014, 0, 0}) + .Case("pto.tcolexpandmin", OpcodeAndVariant{0x1015, 0, 0}) + .Case("pto.tcolexpandmul", OpcodeAndVariant{0x1016, 0, 0}) + .Case("pto.tcolexpandsub", OpcodeAndVariant{0x1017, 0, 0}) + .Case("pto.tcolmax", OpcodeAndVariant{0x1018, 0, 0}) + .Case("pto.tcolmin", OpcodeAndVariant{0x1019, 0, 0}) + .Case("pto.tcolprod", OpcodeAndVariant{0x101A, 0, 0}) + .Case("pto.tcolsum", OpcodeAndVariant{0x101B, 0, 0}) + .Case("pto.tcvt", OpcodeAndVariant{0x101C, 0, 0}) + .Case("pto.tdiv", OpcodeAndVariant{0x101D, 0, 0}) + .Case("pto.tdivs", OpcodeAndVariant{0x101E, 0, 0}) + .Case("pto.texp", OpcodeAndVariant{0x101F, 0, 0}) + .Case("pto.texpands", OpcodeAndVariant{0x1020, 0, 0}) + .Case("pto.textract", OpcodeAndVariant{0x1021, 0, 0}) + .Case("pto.textract_fp", OpcodeAndVariant{0x1022, 0, 0}) + .Case("pto.tfillpad", OpcodeAndVariant{0x1023, 0, 0}) + .Case("pto.tfillpad_expand", OpcodeAndVariant{0x1024, 0, 0}) + .Case("pto.tfillpad_inplace", OpcodeAndVariant{0x1025, 0, 0}) + .Case("pto.tfmod", OpcodeAndVariant{0x1026, 0, 0}) + .Case("pto.tfmods", OpcodeAndVariant{0x1027, 0, 0}) + .Case("pto.tgather", OpcodeAndVariant{0x1028, 0, 0}) + .Case("pto.tgatherb", OpcodeAndVariant{0x1029, 0, 0}) + .Case("pto.tgetval", OpcodeAndVariant{0x102B, 0, 0}) + .Case("pto.timg2col", OpcodeAndVariant{0x102C, 0, 0}) + .Case("pto.tinsert", OpcodeAndVariant{0x102D, 0, 0}) + .Case("pto.tinsert_fp", OpcodeAndVariant{0x102E, 0, 0}) + .Case("pto.tload", OpcodeAndVariant{0x102F, 0, 0}) + .Case("pto.tlog", OpcodeAndVariant{0x1030, 0, 0}) + .Case("pto.tlrelu", OpcodeAndVariant{0x1031, 0, 0}) + .Case("pto.tmax", OpcodeAndVariant{0x1034, 0, 0}) + .Case("pto.tmaxs", OpcodeAndVariant{0x1035, 0, 0}) + .Case("pto.tmin", OpcodeAndVariant{0x1036, 0, 0}) + .Case("pto.tmins", 
OpcodeAndVariant{0x1037, 0, 0}) + .Case("pto.tmov", OpcodeAndVariant{0x1038, 0, 0}) + .Case("pto.tmov.fp", OpcodeAndVariant{0x1039, 0, 0}) + .Case("pto.tmrgsort", OpcodeAndVariant{0x103A, 0, 0}) + .Case("pto.tmul", OpcodeAndVariant{0x103B, 0, 0}) + .Case("pto.tmuls", OpcodeAndVariant{0x103C, 0, 0}) + .Case("pto.tneg", OpcodeAndVariant{0x103D, 0, 0}) + .Case("pto.tnot", OpcodeAndVariant{0x103E, 0, 0}) + .Case("pto.tor", OpcodeAndVariant{0x103F, 0, 0}) + .Case("pto.tors", OpcodeAndVariant{0x1040, 0, 0}) + .Case("pto.tpartadd", OpcodeAndVariant{0x1041, 0, 0}) + .Case("pto.tpartmax", OpcodeAndVariant{0x1042, 0, 0}) + .Case("pto.tpartmin", OpcodeAndVariant{0x1043, 0, 0}) + .Case("pto.tpartmul", OpcodeAndVariant{0x1044, 0, 0}) + .Case("pto.tprefetch", OpcodeAndVariant{0x1045, 0, 0}) + .Case("pto.tprelu", OpcodeAndVariant{0x1046, 0, 0}) + .Case("pto.tquant", OpcodeAndVariant{0x1047, 0, 0}) + .Case("pto.trecip", OpcodeAndVariant{0x1048, 0, 0}) + .Case("pto.trelu", OpcodeAndVariant{0x1049, 0, 0}) + .Case("pto.trem", OpcodeAndVariant{0x104A, 0, 0}) + .Case("pto.trems", OpcodeAndVariant{0x104B, 0, 0}) + .Case("pto.treshape", OpcodeAndVariant{0x104C, 0, 0}) + .Case("pto.trowexpand", OpcodeAndVariant{0x104D, 0, 0}) + .Case("pto.trowexpandadd", OpcodeAndVariant{0x104E, 0, 0}) + .Case("pto.trowexpandexpdif", OpcodeAndVariant{0x104F, 0, 0}) + .Case("pto.trowexpandmax", OpcodeAndVariant{0x1050, 0, 0}) + .Case("pto.trowexpandmin", OpcodeAndVariant{0x1051, 0, 0}) + .Case("pto.trowmax", OpcodeAndVariant{0x1052, 0, 0}) + .Case("pto.trowmin", OpcodeAndVariant{0x1053, 0, 0}) + .Case("pto.trowsum", OpcodeAndVariant{0x1054, 0, 0}) + .Case("pto.trsqrt", OpcodeAndVariant{0x1055, 0, 0}) + .Case("pto.tscatter", OpcodeAndVariant{0x1056, 0, 0}) + .Case("pto.tsel", OpcodeAndVariant{0x1057, 0, 0}) + .Case("pto.tsels", OpcodeAndVariant{0x1058, 0, 0}) + .Case("pto.tset_img2col_padding", OpcodeAndVariant{0x1059, 0, 0}) + .Case("pto.tset_img2col_rpt", OpcodeAndVariant{0x105A, 0, 0}) + 
.Case("pto.tsetfmatrix", OpcodeAndVariant{0x105B, 0, 0}) + .Case("pto.tsethf32mode", OpcodeAndVariant{0x105C, 0, 0}) + .Case("pto.tsettf32mode", OpcodeAndVariant{0x105D, 0, 0}) + .Case("pto.tsetval", OpcodeAndVariant{0x105E, 0, 0}) + .Case("pto.tshl", OpcodeAndVariant{0x105F, 0, 0}) + .Case("pto.tshls", OpcodeAndVariant{0x1060, 0, 0}) + .Case("pto.tshr", OpcodeAndVariant{0x1061, 0, 0}) + .Case("pto.tshrs", OpcodeAndVariant{0x1062, 0, 0}) + .Case("pto.tsort32", OpcodeAndVariant{0x1063, 0, 0}) + .Case("pto.tsqrt", OpcodeAndVariant{0x1064, 0, 0}) + .Case("pto.tstore", OpcodeAndVariant{0x1065, 0, 0}) + .Case("pto.tstore_fp", OpcodeAndVariant{0x1066, 0, 0}) + .Case("pto.tsub", OpcodeAndVariant{0x1067, 0, 0}) + .Case("pto.tsubc", OpcodeAndVariant{0x1068, 0, 0}) + .Case("pto.tsubs", OpcodeAndVariant{0x1069, 0, 0}) + .Case("pto.tsubsc", OpcodeAndVariant{0x106A, 0, 0}) + .Case("pto.trowexpandsub", OpcodeAndVariant{0x106B, 0, 0}) + .Case("pto.ttrans", OpcodeAndVariant{0x106C, 0, 0}) + .Case("pto.ttri", OpcodeAndVariant{0x106D, 0, 0}) + .Case("pto.txor", OpcodeAndVariant{0x106E, 0, 0}) + .Case("pto.txors", OpcodeAndVariant{0x106F, 0, 0}) + .Case("pto.wait_event", OpcodeAndVariant{0x1070, 0, 0}) + .Case("pto.tprint", OpcodeAndVariant{0x1071, 0, 0}) + .Case("pto.subview", OpcodeAndVariant{0x1072, 0, 0}) + .Case("pto.trowexpanddiv", OpcodeAndVariant{0x1073, 0, 0}) + .Case("pto.trowexpandmul", OpcodeAndVariant{0x1074, 0, 0}) + .Case("pto.tdequant", OpcodeAndVariant{0x1075, 0, 0}) + .Case("pto.taxpy", OpcodeAndVariant{0x1076, 0, 0}) + .Case("pto.thistogram", OpcodeAndVariant{0x1077, 0, 0}) + .Case("pto.tget_scale_addr", OpcodeAndVariant{0x1078, 0, 0}) + .Case("pto.trowargmax", OpcodeAndVariant{0x1079, 0, 0}) + .Case("pto.trowargmin", OpcodeAndVariant{0x107A, 0, 0}) + .Case("pto.tcolargmax", OpcodeAndVariant{0x107B, 0, 0}) + .Case("pto.tcolargmin", OpcodeAndVariant{0x107C, 0, 0}) + .Case("pto.tsync", OpcodeAndVariant{0x107D, 0, 0}) + .Case("pto.reserve_buffer", 
OpcodeAndVariant{0x107E, 0, 0}) + .Case("pto.import_reserved_buffer", OpcodeAndVariant{0x107F, 0, 0}) + .Case("pto.aic_initialize_pipe", OpcodeAndVariant{0x1080, 0, 0}) + .Case("pto.aiv_initialize_pipe", OpcodeAndVariant{0x1081, 0, 0}) + .Case("pto.tpush_to_aiv", OpcodeAndVariant{0x1082, 0, 0}) + .Case("pto.tpush_to_aic", OpcodeAndVariant{0x1083, 0, 0}) + .Case("pto.tpop_from_aic", OpcodeAndVariant{0x1084, 0, 0}) + .Case("pto.tpop_from_aiv", OpcodeAndVariant{0x1085, 0, 0}) + .Case("pto.tfree_from_aic", OpcodeAndVariant{0x1086, 0, 0}) + .Case("pto.tfree_from_aiv", OpcodeAndVariant{0x1087, 0, 0}) + .Case("pto.set_validshape", OpcodeAndVariant{0x1088, 0, 0}) + .Case("pto.tconcat", OpcodeAndVariant{0x1089, 0, 0}) + .Case("pto.trowprod", OpcodeAndVariant{0x108A, 0, 0}) + .Case("pto.initialize_l2g2l_pipe", OpcodeAndVariant{0x108B, 0, 0}) + .Case("pto.initialize_l2l_pipe", OpcodeAndVariant{0x108C, 0, 0}) + .Case("pto.tpush", OpcodeAndVariant{0x108D, 0, 0}) + .Case("pto.declare_tile", OpcodeAndVariant{0x108E, 0, 0}) + .Case("pto.tpop", OpcodeAndVariant{0x108F, 0, 0}) + .Case("pto.tfree", OpcodeAndVariant{0x1090, 0, 0}) + .Case("pto.comm.tput", OpcodeAndVariant{0x1091, 0, 0}) + .Case("pto.comm.tget", OpcodeAndVariant{0x1092, 0, 0}) + .Case("pto.comm.tnotify", OpcodeAndVariant{0x1093, 0, 0}) + .Case("pto.comm.twait", OpcodeAndVariant{0x1094, 0, 0}) + .Case("pto.comm.ttest", OpcodeAndVariant{0x1095, 0, 0}) + .Case("pto.comm.tbroadcast", OpcodeAndVariant{0x1096, 0, 0}) + .Case("pto.comm.tgather", OpcodeAndVariant{0x1097, 0, 0}) + .Case("pto.comm.tscatter", OpcodeAndVariant{0x1098, 0, 0}) + .Case("pto.comm.treduce", OpcodeAndVariant{0x1099, 0, 0}) + .Case("pto.tpartargmax", OpcodeAndVariant{0x109A, 0, 0}) + .Case("pto.tpartargmin", OpcodeAndVariant{0x109B, 0, 0}) + .Case("scf.for", OpcodeAndVariant{0x4000, 0, 0}) + .Case("scf.if", OpcodeAndVariant{0x4001, 0, 0}) + .Case("scf.yield", OpcodeAndVariant{0x4002, 0, 0}) + .Case("pto.section.cube", + OpcodeAndVariant{0x0006, 
kHasVariant, kSectionCubeVariant}) + .Case("pto.section.vector", + OpcodeAndVariant{0x0006, kHasVariant, kSectionVectorVariant}) + .Case("pto.tgemv", + OpcodeAndVariant{0x102A, kHasVariant, kVariantDefault}) + .Case("pto.tgemv.acc", + OpcodeAndVariant{0x102A, kHasVariant, kVariantAcc}) + .Case("pto.tgemv.bias", + OpcodeAndVariant{0x102A, kHasVariant, kVariantBias}) + .Case("pto.tgemv.mx", + OpcodeAndVariant{0x102A, kHasVariant, kVariantMx}) + .Case("pto.tgemv.mx.acc", + OpcodeAndVariant{0x102A, kHasVariant, kVariantMxAcc}) + .Case("pto.tgemv.mx.bias", + OpcodeAndVariant{0x102A, kHasVariant, kVariantMxBias}) + .Case("pto.tmatmul", + OpcodeAndVariant{0x1032, kHasVariant, kVariantDefault}) + .Case("pto.tmatmul.acc", + OpcodeAndVariant{0x1032, kHasVariant, kVariantAcc}) + .Case("pto.tmatmul.bias", + OpcodeAndVariant{0x1032, kHasVariant, kVariantBias}) + .Case("pto.tmatmul.mx", + OpcodeAndVariant{0x1033, kHasVariant, kVariantDefault}) + .Case("pto.tmatmul.mx.acc", + OpcodeAndVariant{0x1033, kHasVariant, kVariantAcc}) + .Case("pto.tmatmul.mx.bias", + OpcodeAndVariant{0x1033, kHasVariant, kVariantBias}) + .Default(std::nullopt); +} + +inline const char *fullNameFromOpcodeVariant(uint16_t opcode, uint8_t variant) { + const OpInfo *info = lookupByOpcode(opcode); + if (!info) return nullptr; + if (opcode == kTscatterMaskOpcode) return "pto.tscatter"; + if (!info->has_variant_u8) return info->name; + switch (opcode) { + case 0x0006: + switch (variant) { + case kSectionCubeVariant: return "pto.section.cube"; + case kSectionVectorVariant: return "pto.section.vector"; + default: return info->name; + } + case 0x102A: + switch (variant) { + case kVariantDefault: return "pto.tgemv"; + case kVariantAcc: return "pto.tgemv.acc"; + case kVariantBias: return "pto.tgemv.bias"; + case kVariantMx: return "pto.tgemv.mx"; + case kVariantMxAcc: return "pto.tgemv.mx.acc"; + case kVariantMxBias: return "pto.tgemv.mx.bias"; + default: return info->name; + } + case 0x1032: + switch (variant) { + 
case kVariantDefault: return "pto.tmatmul"; + case kVariantAcc: return "pto.tmatmul.acc"; + case kVariantBias: return "pto.tmatmul.bias"; + default: return info->name; + } + case 0x1033: + switch (variant) { + case kVariantDefault: return "pto.tmatmul.mx"; + case kVariantAcc: return "pto.tmatmul.mx.acc"; + case kVariantBias: return "pto.tmatmul.mx.bias"; + default: return info->name; + } + default: return info->name; + } +} + +inline std::optional lookupOperandsByVariant(uint16_t opcode, uint8_t variant) { + switch (opcode) { + case 0x102A: + switch (variant) { + case kVariantDefault: return kTgemvOperandCount; + case kVariantAcc: return kTgemvAccOperandCount; + case kVariantBias: return kTgemvBiasOperandCount; + case kVariantMx: return kTgemvMxOperandCount; + case kVariantMxAcc: return kTgemvMxAccOperandCount; + case kVariantMxBias: return kTgemvMxBiasOperandCount; + default: return std::nullopt; + } + case 0x1032: + switch (variant) { + case kVariantDefault: return kTmatmulOperandCount; + case kVariantAcc: return kTmatmulAccOperandCount; + case kVariantBias: return kTmatmulBiasOperandCount; + default: return std::nullopt; + } + case 0x1033: + switch (variant) { + case kVariantDefault: return kTmatmulMxOperandCount; + case kVariantAcc: return kTmatmulMxAccOperandCount; + case kVariantBias: return kTmatmulMxBiasOperandCount; + default: return std::nullopt; + } + default: return std::nullopt; + } +} + +} // namespace ptobc::v0 From b49334128c5678061a719e253e7db011a9040ae0 Mon Sep 17 00:00:00 2001 From: HecreReed <821896444@qq.com> Date: Tue, 26 May 2026 12:25:13 +0800 Subject: [PATCH 3/8] refactor: split small nbnc hotspots without def includes --- lib/PTO/Transforms/CMakeLists.txt | 1 + .../Transforms/GraphSyncSolver/SyncSolver.cpp | 663 ---------------- .../GraphSyncSolver/SyncSolverMerge.cpp | 705 ++++++++++++++++++ tools/ptobc/CMakeLists.txt | 1 + tools/ptobc/generated/ptobc_opcodes_v0.cpp | 679 +++++++++++++++++ tools/ptobc/generated/ptobc_opcodes_v0.h | 668 
+---------------- 6 files changed, 1393 insertions(+), 1324 deletions(-) create mode 100644 lib/PTO/Transforms/GraphSyncSolver/SyncSolverMerge.cpp create mode 100644 tools/ptobc/generated/ptobc_opcodes_v0.cpp diff --git a/lib/PTO/Transforms/CMakeLists.txt b/lib/PTO/Transforms/CMakeLists.txt index b8194d0a4..7f0b1b850 100644 --- a/lib/PTO/Transforms/CMakeLists.txt +++ b/lib/PTO/Transforms/CMakeLists.txt @@ -50,6 +50,7 @@ add_mlir_dialect_library(PTOTransforms GraphSyncSolver/GraphSolver.cpp GraphSyncSolver/EventIdSolver.cpp GraphSyncSolver/SyncSolver.cpp + GraphSyncSolver/SyncSolverMerge.cpp GraphSyncSolver/SyncSolverCodeGen.cpp LoweringSyncToPipe.cpp PTOVerifyTFreePass.cpp diff --git a/lib/PTO/Transforms/GraphSyncSolver/SyncSolver.cpp b/lib/PTO/Transforms/GraphSyncSolver/SyncSolver.cpp index 23a4032a6..c920fd69e 100644 --- a/lib/PTO/Transforms/GraphSyncSolver/SyncSolver.cpp +++ b/lib/PTO/Transforms/GraphSyncSolver/SyncSolver.cpp @@ -1911,666 +1911,3 @@ void Solver::handleConflict(Occurrence *occ1, Occurrence *occ2, } } -void Solver::calcAllEventIds() { - for (auto &[pipes, eventIdSolver] : eventIdSolver) { - assert(eventIdSolver != nullptr); - - [[maybe_unused]] auto result = - eventIdSolver->shrinkEventIdMaxToEventIdNum(); - assert(llvm::succeeded(result)); - assert(eventIdSolver->isColorable()); - } -} - -void Solver::collectBackwardSyncEventIds() { - LLVM_DEBUG(llvm::dbgs() << "collectBackwardSyncEventIds\n";); - for (auto &conflictPair : chosenConflictedPairs) { - if (!conflictPair->isUseless && conflictPair->isInnerBackward && - conflictPair->eventIdNode != nullptr) { - LLVM_DEBUG(llvm::dbgs() << " " << conflictPair->str() << "\n";); - for (auto eventId : conflictPair->eventIdNode->getEventIds()) { - auto &e = backwardSyncEvents[conflictPair->backwardSyncLoopOp] - [{conflictPair->setCorePipeInfo, - conflictPair->waitCorePipeInfo}][eventId]; - e = std::max(e, conflictPair->eventIdInfo.eventIdRepeatNum); - } - } - } -} - -void 
Solver::resetAndBuildSetWaitOpIndex(const SyncMap &syncMapBefore, - const SyncMap &syncMapAfter) { - globalSetWaitIndex = 0; - setWaitStartIndex.clear(); - setWaitEndIndex.clear(); - setWaitStartIndexInclusive.clear(); - setWaitEndIndexInclusive.clear(); - setWaitFlagOpsIndex.clear(); - collectSetWaitOpsIndexes(funcIr.get(), syncMapBefore, syncMapAfter); -} - -std::set> & -Solver::getSetWaitOpsIndexRef(pto::PIPE pipeSrc, pto::PIPE pipeDst, - int64_t eventId) { - auto key = std::make_tuple(pipeSrc, pipeDst, eventId); - return setWaitFlagOpsIndex[key]; -} - -// Collect indices for all Set/Wait ops to facilitate merging decisions. -void Solver::collectSetWaitOpsIndexes(OperationBase *op, - const SyncMap &syncMapBefore, - const SyncMap &syncMapAfter) { - assert(op != nullptr); - setWaitStartIndexInclusive[op] = globalSetWaitIndex++; - if (syncMapBefore.count(op)) { - auto *it = syncMapBefore.find(op); - assert(it != syncMapBefore.end()); - for (auto &syncOp : it->second) { - if (auto *setWaitOp = llvm::dyn_cast(syncOp.get())) { - for (auto eventId : setWaitOp->eventIds) { - auto &index = getSetWaitOpsIndexRef(setWaitOp->pipeSrc, - setWaitOp->pipeDst, eventId); - index.insert({globalSetWaitIndex++, setWaitOp}); - } - } - } - } - setWaitStartIndex[op] = globalSetWaitIndex++; - if (auto *scopeOp = llvm::dyn_cast(op)) { - for (auto &childOp : scopeOp->body) { - collectSetWaitOpsIndexes(childOp.get(), syncMapBefore, syncMapAfter); - } - } - setWaitEndIndex[op] = globalSetWaitIndex++; - if (syncMapAfter.count(op)) { - auto *it = syncMapAfter.find(op); - assert(it != syncMapAfter.end()); - for (auto &syncOp : it->second) { - if (auto *setWaitOp = llvm::dyn_cast(syncOp.get())) { - for (auto eventId : setWaitOp->eventIds) { - auto &index = getSetWaitOpsIndexRef(setWaitOp->pipeSrc, - setWaitOp->pipeDst, eventId); - index.insert({globalSetWaitIndex++, setWaitOp}); - } - } - } - } - setWaitEndIndexInclusive[op] = globalSetWaitIndex++; -} - -bool 
Solver::checkBackwardSyncEventsContains(OperationBase *op, - CorePipeInfo corePipeSrc, - CorePipeInfo corePipeDst, - int64_t eventId) { - auto *it1 = backwardSyncEvents.find(op); - if (it1 == backwardSyncEvents.end()) { - return false; - } - auto it2 = it1->second.find({corePipeSrc, corePipeDst}); - if (it2 == it1->second.end()) { - return false; - } - return it2->second.contains(eventId); -} - -bool Solver::checkBackwardSyncEventsContainsAfterMerge( - OperationBase *op, CorePipeInfo corePipeSrc, CorePipeInfo corePipeDst) { - auto *it1 = backwardSyncEventsAfterMerge.find(op); - if (it1 == backwardSyncEventsAfterMerge.end()) { - return false; - } - return it1->second.contains({corePipeSrc, corePipeDst}); -} - -// Check whether a backward-sync event id can be merged at scope level. -bool Solver::checkMergeable(Scope *scopeOp, CorePipeInfo corePipeSrc, - CorePipeInfo corePipeDst, int64_t eventId, - bool shouldBeUsedAtleastOnce) { - auto &index = - getSetWaitOpsIndexRef(corePipeSrc.pipe, corePipeDst.pipe, eventId); - if (shouldBeUsedAtleastOnce) { - auto it = index.lower_bound({setWaitStartIndexInclusive[scopeOp], nullptr}); - bool usedAtleastOnce = - it != index.end() && it->first < setWaitEndIndexInclusive[scopeOp]; - if (!usedAtleastOnce) { - return false; - } - } - { - auto it1 = - index.lower_bound({setWaitStartIndexInclusive[scopeOp], nullptr}); - auto it2 = index.lower_bound({setWaitEndIndex[scopeOp], nullptr}); - bool usedBefore = - it1 != index.end() && it1->first < setWaitStartIndex[scopeOp]; - bool usedAfter = - it2 != index.end() && it2->first < setWaitEndIndexInclusive[scopeOp]; - if (usedBefore || usedAfter) { - return false; - } - } - if (auto *conditionOp = llvm::dyn_cast(scopeOp)) { - if (!conditionOp->hasFalseScope()) { - return false; - } - return checkMergeable(conditionOp->getTrueScope(), corePipeSrc, corePipeDst, - eventId, true) && - checkMergeable(conditionOp->getFalseScope(), corePipeSrc, - corePipeDst, eventId, true); - } - if (auto *loopOp = 
llvm::dyn_cast(scopeOp)) { - for (auto &childOp : loopOp->body) { - if (auto *childScopeOp = llvm::dyn_cast(childOp.get())) { - if (!checkMergeable(childScopeOp, corePipeSrc, corePipeDst, eventId, - false)) { - return false; - } - } - } - for (auto &childOp : loopOp->body) { - if (auto *childScopeOp = llvm::dyn_cast(childOp.get())) { - if (checkMergeable(childScopeOp, corePipeSrc, corePipeDst, eventId, - true)) { - return true; - } - } - } - return false; - } - for (auto &childOp : scopeOp->body) { - auto it1 = - index.lower_bound({setWaitStartIndexInclusive[childOp.get()], nullptr}); - auto it2 = index.lower_bound({setWaitEndIndex[childOp.get()], nullptr}); - bool usedAtleastOnce = it1 != index.end() && - it1->first < setWaitEndIndexInclusive[childOp.get()]; - if (!usedAtleastOnce) { - continue; - } - bool before = - it1 != index.end() && it1->first < setWaitStartIndex[childOp.get()]; - bool after = it2 != index.end() && - it2->first < setWaitEndIndexInclusive[childOp.get()]; - if (before || after) { - return false; - } - if (!checkBackwardSyncEventsContains(childOp.get(), corePipeSrc, - corePipeDst, eventId)) { - return false; - } - if (checkBackwardSyncEventsContainsAfterMerge(childOp.get(), corePipeSrc, - corePipeDst)) { - return false; - } - } - return true; -} - -// Attempt to merge backward sync events across children and prune duplicates. 
-void Solver::mergeBackwardSyncEventIds(OperationBase *op) { - auto *scopeOp = llvm::dyn_cast_if_present(op); - if (scopeOp == nullptr) { - return; - } - for (auto &op : scopeOp->body) { - mergeBackwardSyncEventIds(op.get()); - } - - if (llvm::isa_and_present(op)) { - return; - } - if (llvm::isa_and_present(op->parentOp)) { - return; - } - - auto *conditionOp = llvm::dyn_cast(op); - if (conditionOp != nullptr) { - if (!conditionOp->hasFalseScope()) { - return; - } - } - - llvm::DenseSet> toBeErased; - - llvm::SmallVector coreTypes; - if (options.isCrossCoreMode()) { - coreTypes = {pto::TCoreType::VECTOR, pto::TCoreType::CUBE}; - } else { - coreTypes = {pto::TCoreType::CUBE_OR_VECTOR}; - } - size_t pipeNumMax = static_cast(pto::PIPE::PIPE_NUM); - const int64_t eventIdMax = getHWAvailableEventIdNum(options.syncMode); - - for (int64_t eventId = 0; eventId < eventIdMax; ++eventId) { - for (auto coreSrc : coreTypes) { - for (auto coreDst : coreTypes) { - for (size_t pipeSrcInt = 0; pipeSrcInt < pipeNumMax; pipeSrcInt++) { - for (size_t pipeDstInt = 0; pipeDstInt < pipeNumMax; pipeDstInt++) { - auto pipeSrc = static_cast(pipeSrcInt); - auto pipeDst = static_cast(pipeDstInt); - auto corePipeSrc = CorePipeInfo(coreSrc, pipeSrc); - auto corePipeDst = CorePipeInfo(coreDst, pipeDst); - if (checkBackwardSyncEventsContains(scopeOp, corePipeSrc, - corePipeDst, eventId)) { - continue; - } - if (checkMergeable(scopeOp, corePipeSrc, corePipeDst, eventId)) { - toBeErased.insert({corePipeSrc, corePipeDst, eventId}); - backwardSyncEvents[scopeOp][{corePipeSrc, corePipeDst}].insert( - {eventId, 1}); - } - } - } - } - } - } - - if (isa(scopeOp)) { - for (auto &op : scopeOp->body) { - if (auto *block = llvm::dyn_cast(op.get())) { - for (auto &childOp : block->body) { - if (auto *childScopeOp = llvm::dyn_cast(childOp.get())) { - for (auto [corePipeSrc, corePipeDst, eventId] : toBeErased) { - if (checkBackwardSyncEventsContains(childScopeOp, corePipeSrc, - corePipeDst, eventId)) { - auto 
key = std::make_tuple(corePipeSrc, corePipeDst); - backwardSyncEvents[childScopeOp][key].erase(eventId); - if (backwardSyncEvents[childScopeOp][key].empty()) { - backwardSyncEvents[childScopeOp].erase(key); - } - } - } - } - } - } - } - } else { - for (auto &childOp : scopeOp->body) { - if (auto *childScopeOp = llvm::dyn_cast(childOp.get())) { - for (auto [corePipeSrc, corePipeDst, eventId] : toBeErased) { - if (checkBackwardSyncEventsContains(childScopeOp, corePipeSrc, - corePipeDst, eventId)) { - auto key = std::make_tuple(corePipeSrc, corePipeDst); - backwardSyncEvents[childScopeOp][key].erase(eventId); - if (backwardSyncEvents[childScopeOp][key].empty()) { - backwardSyncEvents[childScopeOp].erase(key); - } - } - } - } - } - } -} - -void Solver::mergeBackwardSyncPairs(SyncMap &syncMapBefore, - SyncMap &syncMapAfter) { - if (!options.moveOutAndMergeBackwardSyncPairs) { - return; - } - if (options.isIntraCoreMode()) { - resetAndBuildSetWaitOpIndex(syncMapBefore, syncMapAfter); - auto *scopeOp = llvm::dyn_cast(funcIr.get()); - assert(scopeOp != nullptr && scopeOp->body.front() != nullptr); - mergeBackwardSyncEventIds(scopeOp->body.front().get()); - } -} - -SyncBeforeAfterMap Solver::getBeforeAfterSyncMaps() { - calcAllEventIds(); - SyncMap syncMapBefore, syncMapAfter; - std::vector conflictPairs; - for (auto &conflictPair : chosenConflictedPairs) { - conflictPairs.push_back(conflictPair.get()); - } - for (auto &conflictPair : persistentChosenConflictedPairs) { - conflictPairs.push_back(conflictPair.get()); - } - - for (auto *conflictPair : conflictPairs) { - if (conflictPair->isUseless) { - continue; - } - if (conflictPair->replacedWithUnitFlag) { - continue; - } - assert(conflictPair->setOp != nullptr && conflictPair->waitOp != nullptr); - if (conflictPair->isBarrier()) { - auto barrierOp = std::make_unique( - conflictPair->waitOp->op, conflictPair->waitOp->parentOp, - conflictPair->waitCorePipeInfo.pipe); - LLVM_DEBUG(barrierOp->debugId = conflictPair->id); - 
syncMapBefore[conflictPair->waitOp].push_back(std::move(barrierOp)); - } else { - assert(conflictPair->eventIdNode != nullptr); - auto setOp = std::make_unique( - conflictPair->setOp->op, conflictPair->setOp->parentOp, - conflictPair->eventIdNode->getEventIds(), - conflictPair->setCorePipeInfo.pipe, - conflictPair->waitCorePipeInfo.pipe); - auto waitOp = std::make_unique( - conflictPair->waitOp->op, conflictPair->waitOp->parentOp, - conflictPair->eventIdNode->getEventIds(), - conflictPair->setCorePipeInfo.pipe, - conflictPair->waitCorePipeInfo.pipe); - if (options.isCrossCoreMode()) { - setOp->coreType = conflictPair->setCorePipeInfo.coreType; - waitOp->coreType = conflictPair->waitCorePipeInfo.coreType; - } - setOp->eventIdInfo = conflictPair->eventIdInfo; - waitOp->eventIdInfo = conflictPair->eventIdInfo; - setOp->checkLastIter = conflictPair->setOnLastIterOnly; - waitOp->checkFirstIter = conflictPair->waitOnFirstIterOnly; - LLVM_DEBUG({ - setOp->debugId = conflictPair->id; - waitOp->debugId = conflictPair->id; - }); - assert(setOp != nullptr && waitOp != nullptr); - syncMapAfter[conflictPair->setOp].push_back(std::move(setOp)); - syncMapBefore[conflictPair->waitOp].push_front(std::move(waitOp)); - } - } - - collectBackwardSyncEventIds(); - mergeBackwardSyncPairs(syncMapBefore, syncMapAfter); - - for (auto &[op, mp] : backwardSyncEvents) { - if (mp.empty()) { - continue; - } - auto *scopeOp = llvm::dyn_cast(op); - assert(scopeOp != nullptr); - for (auto [setWaitCorePipes, eventIdsMp] : mp) { - if (eventIdsMp.empty()) { - continue; - } - llvm::SmallVector eventIds; - for (auto [eventId, repeatNum] : eventIdsMp) { - llvm::SmallVector curEventIds(repeatNum, eventId); - llvm::append_range(eventIds, curEventIds); - } - llvm::sort(eventIds); - auto [corePipeSrc, corePipeDst] = setWaitCorePipes; - auto setOp = - std::make_unique(scopeOp->op, scopeOp->parentOp, eventIds, - corePipeSrc.pipe, corePipeDst.pipe); - auto waitOp = - std::make_unique(scopeOp->op, 
scopeOp->parentOp, eventIds, - corePipeSrc.pipe, corePipeDst.pipe); - setOp->allAtOnce = true; - waitOp->allAtOnce = true; - if (options.isCrossCoreMode()) { - setOp->coreType = corePipeSrc.coreType; - waitOp->coreType = corePipeDst.coreType; - } - assert(setOp != nullptr && waitOp != nullptr); - syncMapBefore[scopeOp].push_back(std::move(setOp)); - syncMapAfter[scopeOp].push_front(std::move(waitOp)); - } - } - return std::make_pair(std::move(syncMapBefore), std::move(syncMapAfter)); -} - -void Solver::processConflict(Occurrence *occ1, Occurrence *occ2, - RWOperation *rwOp1, RWOperation *rwOp2, - bool isUseless) { - for (auto [corePipeSrc, corePipeDst] : checkMemoryConflicts(rwOp1, rwOp2)) { - if (options.alwaysUsePipeSAsWaitingPipe) { - corePipeDst.pipe = pto::PIPE::PIPE_S; - } - auto eventIdInfo = - getEventIdInfo(occ1, occ2, rwOp1, rwOp2, corePipeSrc, corePipeDst); - handleConflict(occ1, occ2, rwOp1, rwOp2, corePipeSrc, corePipeDst, - eventIdInfo, isUseless); - } -} - -// Main processing loop that iterates processingOrders and attempts to -// discover and record conflicts. -void Solver::processOrders() { - for (auto &[occ1, occ2, rwOp1, rwOp2, isUseless] : processingOrders) { - assert(occ1 != occ2); - assert(occ1->syncIrIndex < occ2->syncIrIndex); - if (checkVisited(occ1, occ2)) { - assert(false && "expected to not check a pair more than once."); - continue; - } - if (checkImpossibleOccPair(occ1, occ2) || checkAlreadySynced(occ1, occ2) || - skipMMad1DecomposedLoopOpt(occ1, occ2) || - checkSkipParallelLoop(occ1, occ2) || - checkSkipCrossCorePair(occ1, occ2)) { - continue; - } - DEBUG_WITH_TYPE("gss-sync-solver-checking", { - llvm::dbgs() << "checking: " << (isUseless ? 
"is-useless\n" : "\n"); - llvm::dbgs() << occ1->syncIrIndex << ' ' << occ1->startIndex << ' ' - << occ1->endIndex << ' ' << occ1->op->str(0, false) << '\n'; - llvm::dbgs() << occ2->syncIrIndex << ' ' << occ2->startIndex << ' ' - << occ2->endIndex << ' ' << occ2->op->str(0, false) << '\n'; - }); - if (checkAlreadySyncedWithUnitFlag(occ1, occ2)) { - continue; - } - processConflict(occ1, occ2, rwOp1, rwOp2, isUseless); - } -} - -void Solver::insertMergedBackwardSyncPairs() { - for (auto &[scopeOp, st] : backwardSyncEventsAfterMerge) { - for (auto &corePipeInfoPair : st) { - auto [corePipeSrc, corePipeDst] = corePipeInfoPair; - for (auto *scopeOcc : opAllOccurrences[scopeOp]) { - auto *parentScopeOcc = scopeOcc->parentOcc; - assert(parentScopeOcc != nullptr); - Occurrence *setOcc = nullptr; - Occurrence *waitOcc = nullptr; - auto startIndex = scopeOcc->startIndex; - auto endIndex = scopeOcc->endIndex; - if (isa(scopeOp)) { - setOcc = getBeforePlaceHolderOcc(scopeOcc); - waitOcc = getAfterPlaceHolderOcc(scopeOcc); - startIndex = setOcc->endIndex; - endIndex = waitOcc->startIndex; - } - auto conflictPair = std::make_unique( - nullptr, nullptr, nullptr, nullptr, setOcc, waitOcc, corePipeSrc, - corePipeDst, startIndex, endIndex); - assert(conflictPair->startIndex <= conflictPair->endIndex); - conflictPair->isUseless = true; - conflictPair->dontReuse = true; - conflictPair->dontCheckForConflict = true; - conflictPair->couldNotRun = false; // notice this - LLVM_DEBUG({ - llvm::dbgs() << "consider-merged-backward-pair: " - << scopeOp->str(0, false) << ' ' << conflictPair->str() - << "\n"; - }); - scopeOccChosenConflicts[parentScopeOcc].insert(conflictPair.get()); - chosenConflictedPairs.push_back(std::move(conflictPair)); - } - } - } -} - -llvm::LogicalResult Solver::considerOuterBackwardSyncPairs() { - if (!options.considerOuterBackwardSyncPairs) { - return llvm::failure(); - } - bool backwardPairsPositionChanged = false; - for (auto &[scopeOp, st] : 
backwardSyncEventsAfterMerge) { - SmallVector> toBeErased; - for (auto &corePipeInfoPair : st) { - if (!backwardSyncEvents.contains(scopeOp) || - !backwardSyncEvents[scopeOp].contains(corePipeInfoPair)) { - toBeErased.push_back(corePipeInfoPair); - } - } - if (!toBeErased.empty()) { - backwardPairsPositionChanged = true; - for (auto &corePipeInfoPair : toBeErased) { - st.erase(corePipeInfoPair); - } - } - } - int chosenOpsDepth = -1; - SmallVector chosenOps; - for (auto &[scopeOp, mp] : backwardSyncEvents) { - if (backwardSyncEventsAfterMerge.contains(scopeOp)) { - continue; - } - int scopeOpDepth = scopeOp->getDepth(); - if (chosenOpsDepth == scopeOpDepth) { - chosenOps.push_back(scopeOp); - } else if (chosenOpsDepth == -1 || chosenOpsDepth < scopeOpDepth) { - chosenOps.clear(); - chosenOps.push_back(scopeOp); - chosenOpsDepth = scopeOpDepth; - } - } - if (chosenOps.empty()) { - return llvm::failure(); - } - bool newPairIsInserted = false; - for (auto *chosenOp : chosenOps) { - for (auto &[corePipeInfoPair, eventIdsMp] : backwardSyncEvents[chosenOp]) { - assert(!eventIdsMp.empty()); - if (!eventIdsMp.empty()) { - auto [it, isInserted] = - backwardSyncEventsAfterMerge[chosenOp].insert(corePipeInfoPair); - newPairIsInserted |= isInserted; - } - } - } - return llvm::success(backwardPairsPositionChanged || newPairIsInserted); -} - -llvm::LogicalResult Solver::reuseSyncPairToSaveEventIds() { - if (!options.reuseSyncPairToSaveEventIds || barrierAllPairs.empty()) { - return llvm::failure(); - } - bool limitReached = true; - for (auto [corePipeSrc, corePipeDst] : barrierAllPairs) { - if (reusePairs[{corePipeSrc, corePipeDst}] < maxReuseNum) { - if (reusePairs[{corePipeSrc, corePipeDst}] <= - reusedPairs[{corePipeSrc, corePipeDst}]) { - reusePairs[{corePipeSrc, corePipeDst}] += 1; - limitReached = false; - } - } - } - DEBUG_WITH_TYPE("gss-sync-solver-reuse", { - llvm::dbgs() << "reusePairs: \n"; - for (auto [pipeCorePairs, cnt] : reusePairs) { - llvm::dbgs() << 
get<0>(pipeCorePairs).pipe << ' ' - << get<1>(pipeCorePairs).pipe << ' ' << cnt << '\n'; - } - }); - return llvm::success(!limitReached); -} - -llvm::LogicalResult Solver::disableMultiEventIdForBarrierAllPairs() { - if (!options.disableMultiEventIdForBarrierAllPairs || - barrierAllPairs.empty()) { - return llvm::failure(); - } - bool newPairIsInserted = false; - for (auto corePipeInfoPair : barrierAllPairs) { - auto [it, isInserted] = disabledMultiEventIdPairs.insert(corePipeInfoPair); - newPairIsInserted |= isInserted; - } - LLVM_DEBUG({ - if (newPairIsInserted) { - llvm::dbgs() << "disabled-multi-event-id-pairs: \n"; - for (auto &[corePipeSrc, corePipeDst] : disabledMultiEventIdPairs) { - llvm::dbgs() << corePipeSrc.coreType << ' ' << corePipeSrc.pipe << ' ' - << corePipeDst.coreType << ' ' << corePipeDst.pipe << '\n'; - } - } - }); - return llvm::success(newPairIsInserted); -} - -llvm::LogicalResult Solver::tryMovingOutBackwardSyncPairsToOuterLoops() { - if (!options.moveOutAndMergeBackwardSyncPairs || !options.isCrossCoreMode() || - dontMoveBackwardSyncPairsToOutmostLoop) { - return llvm::failure(); - } - if (!moveBackwardSyncPairsToOutmostLoop) { - moveBackwardSyncPairsToOutmostLoop = true; - return llvm::success(); - } - if (!barrierAllPairs.empty()) { - moveBackwardSyncPairsToOutmostLoop = false; - dontMoveBackwardSyncPairsToOutmostLoop = true; - return llvm::success(); - } - return llvm::failure(); -} - -// High-level solve orchestration with multiple passes and optional merging -// iterations. 
-llvm::LogicalResult Solver::runSolver(bool enableOpts1, bool enableOpts2) { - reset(/*resetEventIdRanOutOpts=*/true); - - int64_t runNum = 0; - while (runNum++ < maxRunNum) { - LLVM_DEBUG(llvm::dbgs() << "runNum: " << runNum << '\n'); - - reset(); - insertMergedBackwardSyncPairs(); - processOrders(); - - if (llvm::succeeded(tryMovingOutBackwardSyncPairsToOuterLoops())) { - continue; - } - - if (enableOpts1) { - if (options.considerOuterBackwardSyncPairs) { - getBeforeAfterSyncMaps(); - if (llvm::succeeded(considerOuterBackwardSyncPairs())) { - continue; - } - if (!barrierAllPairs.empty()) { - backwardSyncEventsAfterMerge.clear(); - } - } - } - - if (enableOpts2) { - if (!barrierAllPairs.empty()) { - if (llvm::succeeded(reuseSyncPairToSaveEventIds())) { - continue; - } - if (llvm::succeeded(disableMultiEventIdForBarrierAllPairs())) { - continue; - } - } - } - - if (!barrierAllPairs.empty()) { - pickAndInsertABarrierAll(); - reset(/*resetEventIdRanOutOpts=*/true); - continue; - } - break; - } - - reset(); - insertMergedBackwardSyncPairs(); - processOrders(); - - return llvm::success(runNum < maxRunNum); -} - -void Solver::solve() { - if (llvm::succeeded(runSolver())) { - return; - } - if (!options.isTestMode()) { - if (llvm::succeeded(runSolver(/*enableOpts1=*/false))) { - return; - } - if (llvm::succeeded( - runSolver(/*enableOpts1=*/false, /*enableOpts2=*/false))) { - return; - } - } - llvm_unreachable("GSS: runSolver() failed."); -} diff --git a/lib/PTO/Transforms/GraphSyncSolver/SyncSolverMerge.cpp b/lib/PTO/Transforms/GraphSyncSolver/SyncSolverMerge.cpp new file mode 100644 index 000000000..b35a37b79 --- /dev/null +++ b/lib/PTO/Transforms/GraphSyncSolver/SyncSolverMerge.cpp @@ -0,0 +1,705 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). 
+// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===--------- SyncSolver.cpp ------- Graph Sync Solver -------------------===// +//===----------------------------------------------------------------------===// + +#include "PTO/Transforms/GraphSyncSolver/SyncSolver.h" +#include "PTO/Transforms/GraphSyncSolver/GraphSolver.h" +#include "PTO/Transforms/GraphSyncSolver/MemInfo.h" +#include "PTO/Transforms/GraphSyncSolver/SyncSolverIR.h" +#include "PTO/Transforms/GraphSyncSolver/Utility.h" + +#include "PTO/IR/PTO.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Value.h" +#include "mlir/Interfaces/LoopLikeInterface.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/LogicalResult.h" +#include +#include +#include +#include +#include +#include +#include + +#define DEBUG_TYPE "PTO-gss-solver" + +using namespace mlir; +using namespace pto::syncsolver; + +void Solver::calcAllEventIds() { + for (auto &[pipes, eventIdSolver] : eventIdSolver) { + assert(eventIdSolver != nullptr); + + [[maybe_unused]] auto result = + eventIdSolver->shrinkEventIdMaxToEventIdNum(); + assert(llvm::succeeded(result)); + assert(eventIdSolver->isColorable()); + } +} + +void Solver::collectBackwardSyncEventIds() { + LLVM_DEBUG(llvm::dbgs() << "collectBackwardSyncEventIds\n";); + for (auto &conflictPair : chosenConflictedPairs) { + if (!conflictPair->isUseless && conflictPair->isInnerBackward && + 
conflictPair->eventIdNode != nullptr) { + LLVM_DEBUG(llvm::dbgs() << " " << conflictPair->str() << "\n";); + for (auto eventId : conflictPair->eventIdNode->getEventIds()) { + auto &e = backwardSyncEvents[conflictPair->backwardSyncLoopOp] + [{conflictPair->setCorePipeInfo, + conflictPair->waitCorePipeInfo}][eventId]; + e = std::max(e, conflictPair->eventIdInfo.eventIdRepeatNum); + } + } + } +} + +void Solver::resetAndBuildSetWaitOpIndex(const SyncMap &syncMapBefore, + const SyncMap &syncMapAfter) { + globalSetWaitIndex = 0; + setWaitStartIndex.clear(); + setWaitEndIndex.clear(); + setWaitStartIndexInclusive.clear(); + setWaitEndIndexInclusive.clear(); + setWaitFlagOpsIndex.clear(); + collectSetWaitOpsIndexes(funcIr.get(), syncMapBefore, syncMapAfter); +} + +std::set> & +Solver::getSetWaitOpsIndexRef(pto::PIPE pipeSrc, pto::PIPE pipeDst, + int64_t eventId) { + auto key = std::make_tuple(pipeSrc, pipeDst, eventId); + return setWaitFlagOpsIndex[key]; +} + +// Collect indices for all Set/Wait ops to facilitate merging decisions. 
+void Solver::collectSetWaitOpsIndexes(OperationBase *op, + const SyncMap &syncMapBefore, + const SyncMap &syncMapAfter) { + assert(op != nullptr); + setWaitStartIndexInclusive[op] = globalSetWaitIndex++; + if (syncMapBefore.count(op)) { + auto *it = syncMapBefore.find(op); + assert(it != syncMapBefore.end()); + for (auto &syncOp : it->second) { + if (auto *setWaitOp = llvm::dyn_cast(syncOp.get())) { + for (auto eventId : setWaitOp->eventIds) { + auto &index = getSetWaitOpsIndexRef(setWaitOp->pipeSrc, + setWaitOp->pipeDst, eventId); + index.insert({globalSetWaitIndex++, setWaitOp}); + } + } + } + } + setWaitStartIndex[op] = globalSetWaitIndex++; + if (auto *scopeOp = llvm::dyn_cast(op)) { + for (auto &childOp : scopeOp->body) { + collectSetWaitOpsIndexes(childOp.get(), syncMapBefore, syncMapAfter); + } + } + setWaitEndIndex[op] = globalSetWaitIndex++; + if (syncMapAfter.count(op)) { + auto *it = syncMapAfter.find(op); + assert(it != syncMapAfter.end()); + for (auto &syncOp : it->second) { + if (auto *setWaitOp = llvm::dyn_cast(syncOp.get())) { + for (auto eventId : setWaitOp->eventIds) { + auto &index = getSetWaitOpsIndexRef(setWaitOp->pipeSrc, + setWaitOp->pipeDst, eventId); + index.insert({globalSetWaitIndex++, setWaitOp}); + } + } + } + } + setWaitEndIndexInclusive[op] = globalSetWaitIndex++; +} + +bool Solver::checkBackwardSyncEventsContains(OperationBase *op, + CorePipeInfo corePipeSrc, + CorePipeInfo corePipeDst, + int64_t eventId) { + auto *it1 = backwardSyncEvents.find(op); + if (it1 == backwardSyncEvents.end()) { + return false; + } + auto it2 = it1->second.find({corePipeSrc, corePipeDst}); + if (it2 == it1->second.end()) { + return false; + } + return it2->second.contains(eventId); +} + +bool Solver::checkBackwardSyncEventsContainsAfterMerge( + OperationBase *op, CorePipeInfo corePipeSrc, CorePipeInfo corePipeDst) { + auto *it1 = backwardSyncEventsAfterMerge.find(op); + if (it1 == backwardSyncEventsAfterMerge.end()) { + return false; + } + return 
it1->second.contains({corePipeSrc, corePipeDst}); +} + +// Check whether a backward-sync event id can be merged at scope level. +bool Solver::checkMergeable(Scope *scopeOp, CorePipeInfo corePipeSrc, + CorePipeInfo corePipeDst, int64_t eventId, + bool shouldBeUsedAtleastOnce) { + auto &index = + getSetWaitOpsIndexRef(corePipeSrc.pipe, corePipeDst.pipe, eventId); + if (shouldBeUsedAtleastOnce) { + auto it = index.lower_bound({setWaitStartIndexInclusive[scopeOp], nullptr}); + bool usedAtleastOnce = + it != index.end() && it->first < setWaitEndIndexInclusive[scopeOp]; + if (!usedAtleastOnce) { + return false; + } + } + { + auto it1 = + index.lower_bound({setWaitStartIndexInclusive[scopeOp], nullptr}); + auto it2 = index.lower_bound({setWaitEndIndex[scopeOp], nullptr}); + bool usedBefore = + it1 != index.end() && it1->first < setWaitStartIndex[scopeOp]; + bool usedAfter = + it2 != index.end() && it2->first < setWaitEndIndexInclusive[scopeOp]; + if (usedBefore || usedAfter) { + return false; + } + } + if (auto *conditionOp = llvm::dyn_cast(scopeOp)) { + if (!conditionOp->hasFalseScope()) { + return false; + } + return checkMergeable(conditionOp->getTrueScope(), corePipeSrc, corePipeDst, + eventId, true) && + checkMergeable(conditionOp->getFalseScope(), corePipeSrc, + corePipeDst, eventId, true); + } + if (auto *loopOp = llvm::dyn_cast(scopeOp)) { + for (auto &childOp : loopOp->body) { + if (auto *childScopeOp = llvm::dyn_cast(childOp.get())) { + if (!checkMergeable(childScopeOp, corePipeSrc, corePipeDst, eventId, + false)) { + return false; + } + } + } + for (auto &childOp : loopOp->body) { + if (auto *childScopeOp = llvm::dyn_cast(childOp.get())) { + if (checkMergeable(childScopeOp, corePipeSrc, corePipeDst, eventId, + true)) { + return true; + } + } + } + return false; + } + for (auto &childOp : scopeOp->body) { + auto it1 = + index.lower_bound({setWaitStartIndexInclusive[childOp.get()], nullptr}); + auto it2 = index.lower_bound({setWaitEndIndex[childOp.get()], 
nullptr}); + bool usedAtleastOnce = it1 != index.end() && + it1->first < setWaitEndIndexInclusive[childOp.get()]; + if (!usedAtleastOnce) { + continue; + } + bool before = + it1 != index.end() && it1->first < setWaitStartIndex[childOp.get()]; + bool after = it2 != index.end() && + it2->first < setWaitEndIndexInclusive[childOp.get()]; + if (before || after) { + return false; + } + if (!checkBackwardSyncEventsContains(childOp.get(), corePipeSrc, + corePipeDst, eventId)) { + return false; + } + if (checkBackwardSyncEventsContainsAfterMerge(childOp.get(), corePipeSrc, + corePipeDst)) { + return false; + } + } + return true; +} + +// Attempt to merge backward sync events across children and prune duplicates. +void Solver::mergeBackwardSyncEventIds(OperationBase *op) { + auto *scopeOp = llvm::dyn_cast_if_present(op); + if (scopeOp == nullptr) { + return; + } + for (auto &op : scopeOp->body) { + mergeBackwardSyncEventIds(op.get()); + } + + if (llvm::isa_and_present(op)) { + return; + } + if (llvm::isa_and_present(op->parentOp)) { + return; + } + + auto *conditionOp = llvm::dyn_cast(op); + if (conditionOp != nullptr) { + if (!conditionOp->hasFalseScope()) { + return; + } + } + + llvm::DenseSet> toBeErased; + + llvm::SmallVector coreTypes; + if (options.isCrossCoreMode()) { + coreTypes = {pto::TCoreType::VECTOR, pto::TCoreType::CUBE}; + } else { + coreTypes = {pto::TCoreType::CUBE_OR_VECTOR}; + } + size_t pipeNumMax = static_cast(pto::PIPE::PIPE_NUM); + const int64_t eventIdMax = getHWAvailableEventIdNum(options.syncMode); + + for (int64_t eventId = 0; eventId < eventIdMax; ++eventId) { + for (auto coreSrc : coreTypes) { + for (auto coreDst : coreTypes) { + for (size_t pipeSrcInt = 0; pipeSrcInt < pipeNumMax; pipeSrcInt++) { + for (size_t pipeDstInt = 0; pipeDstInt < pipeNumMax; pipeDstInt++) { + auto pipeSrc = static_cast(pipeSrcInt); + auto pipeDst = static_cast(pipeDstInt); + auto corePipeSrc = CorePipeInfo(coreSrc, pipeSrc); + auto corePipeDst = CorePipeInfo(coreDst, 
pipeDst); + if (checkBackwardSyncEventsContains(scopeOp, corePipeSrc, + corePipeDst, eventId)) { + continue; + } + if (checkMergeable(scopeOp, corePipeSrc, corePipeDst, eventId)) { + toBeErased.insert({corePipeSrc, corePipeDst, eventId}); + backwardSyncEvents[scopeOp][{corePipeSrc, corePipeDst}].insert( + {eventId, 1}); + } + } + } + } + } + } + + if (isa(scopeOp)) { + for (auto &op : scopeOp->body) { + if (auto *block = llvm::dyn_cast(op.get())) { + for (auto &childOp : block->body) { + if (auto *childScopeOp = llvm::dyn_cast(childOp.get())) { + for (auto [corePipeSrc, corePipeDst, eventId] : toBeErased) { + if (checkBackwardSyncEventsContains(childScopeOp, corePipeSrc, + corePipeDst, eventId)) { + auto key = std::make_tuple(corePipeSrc, corePipeDst); + backwardSyncEvents[childScopeOp][key].erase(eventId); + if (backwardSyncEvents[childScopeOp][key].empty()) { + backwardSyncEvents[childScopeOp].erase(key); + } + } + } + } + } + } + } + } else { + for (auto &childOp : scopeOp->body) { + if (auto *childScopeOp = llvm::dyn_cast(childOp.get())) { + for (auto [corePipeSrc, corePipeDst, eventId] : toBeErased) { + if (checkBackwardSyncEventsContains(childScopeOp, corePipeSrc, + corePipeDst, eventId)) { + auto key = std::make_tuple(corePipeSrc, corePipeDst); + backwardSyncEvents[childScopeOp][key].erase(eventId); + if (backwardSyncEvents[childScopeOp][key].empty()) { + backwardSyncEvents[childScopeOp].erase(key); + } + } + } + } + } + } +} + +void Solver::mergeBackwardSyncPairs(SyncMap &syncMapBefore, + SyncMap &syncMapAfter) { + if (!options.moveOutAndMergeBackwardSyncPairs) { + return; + } + if (options.isIntraCoreMode()) { + resetAndBuildSetWaitOpIndex(syncMapBefore, syncMapAfter); + auto *scopeOp = llvm::dyn_cast(funcIr.get()); + assert(scopeOp != nullptr && scopeOp->body.front() != nullptr); + mergeBackwardSyncEventIds(scopeOp->body.front().get()); + } +} + +SyncBeforeAfterMap Solver::getBeforeAfterSyncMaps() { + calcAllEventIds(); + SyncMap syncMapBefore, 
syncMapAfter; + std::vector conflictPairs; + for (auto &conflictPair : chosenConflictedPairs) { + conflictPairs.push_back(conflictPair.get()); + } + for (auto &conflictPair : persistentChosenConflictedPairs) { + conflictPairs.push_back(conflictPair.get()); + } + + for (auto *conflictPair : conflictPairs) { + if (conflictPair->isUseless) { + continue; + } + if (conflictPair->replacedWithUnitFlag) { + continue; + } + assert(conflictPair->setOp != nullptr && conflictPair->waitOp != nullptr); + if (conflictPair->isBarrier()) { + auto barrierOp = std::make_unique( + conflictPair->waitOp->op, conflictPair->waitOp->parentOp, + conflictPair->waitCorePipeInfo.pipe); + LLVM_DEBUG(barrierOp->debugId = conflictPair->id); + syncMapBefore[conflictPair->waitOp].push_back(std::move(barrierOp)); + } else { + assert(conflictPair->eventIdNode != nullptr); + auto setOp = std::make_unique( + conflictPair->setOp->op, conflictPair->setOp->parentOp, + conflictPair->eventIdNode->getEventIds(), + conflictPair->setCorePipeInfo.pipe, + conflictPair->waitCorePipeInfo.pipe); + auto waitOp = std::make_unique( + conflictPair->waitOp->op, conflictPair->waitOp->parentOp, + conflictPair->eventIdNode->getEventIds(), + conflictPair->setCorePipeInfo.pipe, + conflictPair->waitCorePipeInfo.pipe); + if (options.isCrossCoreMode()) { + setOp->coreType = conflictPair->setCorePipeInfo.coreType; + waitOp->coreType = conflictPair->waitCorePipeInfo.coreType; + } + setOp->eventIdInfo = conflictPair->eventIdInfo; + waitOp->eventIdInfo = conflictPair->eventIdInfo; + setOp->checkLastIter = conflictPair->setOnLastIterOnly; + waitOp->checkFirstIter = conflictPair->waitOnFirstIterOnly; + LLVM_DEBUG({ + setOp->debugId = conflictPair->id; + waitOp->debugId = conflictPair->id; + }); + assert(setOp != nullptr && waitOp != nullptr); + syncMapAfter[conflictPair->setOp].push_back(std::move(setOp)); + syncMapBefore[conflictPair->waitOp].push_front(std::move(waitOp)); + } + } + + collectBackwardSyncEventIds(); + 
mergeBackwardSyncPairs(syncMapBefore, syncMapAfter); + + for (auto &[op, mp] : backwardSyncEvents) { + if (mp.empty()) { + continue; + } + auto *scopeOp = llvm::dyn_cast(op); + assert(scopeOp != nullptr); + for (auto [setWaitCorePipes, eventIdsMp] : mp) { + if (eventIdsMp.empty()) { + continue; + } + llvm::SmallVector eventIds; + for (auto [eventId, repeatNum] : eventIdsMp) { + llvm::SmallVector curEventIds(repeatNum, eventId); + llvm::append_range(eventIds, curEventIds); + } + llvm::sort(eventIds); + auto [corePipeSrc, corePipeDst] = setWaitCorePipes; + auto setOp = + std::make_unique(scopeOp->op, scopeOp->parentOp, eventIds, + corePipeSrc.pipe, corePipeDst.pipe); + auto waitOp = + std::make_unique(scopeOp->op, scopeOp->parentOp, eventIds, + corePipeSrc.pipe, corePipeDst.pipe); + setOp->allAtOnce = true; + waitOp->allAtOnce = true; + if (options.isCrossCoreMode()) { + setOp->coreType = corePipeSrc.coreType; + waitOp->coreType = corePipeDst.coreType; + } + assert(setOp != nullptr && waitOp != nullptr); + syncMapBefore[scopeOp].push_back(std::move(setOp)); + syncMapAfter[scopeOp].push_front(std::move(waitOp)); + } + } + return std::make_pair(std::move(syncMapBefore), std::move(syncMapAfter)); +} + +void Solver::processConflict(Occurrence *occ1, Occurrence *occ2, + RWOperation *rwOp1, RWOperation *rwOp2, + bool isUseless) { + for (auto [corePipeSrc, corePipeDst] : checkMemoryConflicts(rwOp1, rwOp2)) { + if (options.alwaysUsePipeSAsWaitingPipe) { + corePipeDst.pipe = pto::PIPE::PIPE_S; + } + auto eventIdInfo = + getEventIdInfo(occ1, occ2, rwOp1, rwOp2, corePipeSrc, corePipeDst); + handleConflict(occ1, occ2, rwOp1, rwOp2, corePipeSrc, corePipeDst, + eventIdInfo, isUseless); + } +} + +// Main processing loop that iterates processingOrders and attempts to +// discover and record conflicts. 
+void Solver::processOrders() { + for (auto &[occ1, occ2, rwOp1, rwOp2, isUseless] : processingOrders) { + assert(occ1 != occ2); + assert(occ1->syncIrIndex < occ2->syncIrIndex); + if (checkVisited(occ1, occ2)) { + assert(false && "expected to not check a pair more than once."); + continue; + } + if (checkImpossibleOccPair(occ1, occ2) || checkAlreadySynced(occ1, occ2) || + skipMMad1DecomposedLoopOpt(occ1, occ2) || + checkSkipParallelLoop(occ1, occ2) || + checkSkipCrossCorePair(occ1, occ2)) { + continue; + } + DEBUG_WITH_TYPE("gss-sync-solver-checking", { + llvm::dbgs() << "checking: " << (isUseless ? "is-useless\n" : "\n"); + llvm::dbgs() << occ1->syncIrIndex << ' ' << occ1->startIndex << ' ' + << occ1->endIndex << ' ' << occ1->op->str(0, false) << '\n'; + llvm::dbgs() << occ2->syncIrIndex << ' ' << occ2->startIndex << ' ' + << occ2->endIndex << ' ' << occ2->op->str(0, false) << '\n'; + }); + if (checkAlreadySyncedWithUnitFlag(occ1, occ2)) { + continue; + } + processConflict(occ1, occ2, rwOp1, rwOp2, isUseless); + } +} + +void Solver::insertMergedBackwardSyncPairs() { + for (auto &[scopeOp, st] : backwardSyncEventsAfterMerge) { + for (auto &corePipeInfoPair : st) { + auto [corePipeSrc, corePipeDst] = corePipeInfoPair; + for (auto *scopeOcc : opAllOccurrences[scopeOp]) { + auto *parentScopeOcc = scopeOcc->parentOcc; + assert(parentScopeOcc != nullptr); + Occurrence *setOcc = nullptr; + Occurrence *waitOcc = nullptr; + auto startIndex = scopeOcc->startIndex; + auto endIndex = scopeOcc->endIndex; + if (isa(scopeOp)) { + setOcc = getBeforePlaceHolderOcc(scopeOcc); + waitOcc = getAfterPlaceHolderOcc(scopeOcc); + startIndex = setOcc->endIndex; + endIndex = waitOcc->startIndex; + } + auto conflictPair = std::make_unique( + nullptr, nullptr, nullptr, nullptr, setOcc, waitOcc, corePipeSrc, + corePipeDst, startIndex, endIndex); + assert(conflictPair->startIndex <= conflictPair->endIndex); + conflictPair->isUseless = true; + conflictPair->dontReuse = true; + 
conflictPair->dontCheckForConflict = true; + conflictPair->couldNotRun = false; // notice this + LLVM_DEBUG({ + llvm::dbgs() << "consider-merged-backward-pair: " + << scopeOp->str(0, false) << ' ' << conflictPair->str() + << "\n"; + }); + scopeOccChosenConflicts[parentScopeOcc].insert(conflictPair.get()); + chosenConflictedPairs.push_back(std::move(conflictPair)); + } + } + } +} + +llvm::LogicalResult Solver::considerOuterBackwardSyncPairs() { + if (!options.considerOuterBackwardSyncPairs) { + return llvm::failure(); + } + bool backwardPairsPositionChanged = false; + for (auto &[scopeOp, st] : backwardSyncEventsAfterMerge) { + SmallVector> toBeErased; + for (auto &corePipeInfoPair : st) { + if (!backwardSyncEvents.contains(scopeOp) || + !backwardSyncEvents[scopeOp].contains(corePipeInfoPair)) { + toBeErased.push_back(corePipeInfoPair); + } + } + if (!toBeErased.empty()) { + backwardPairsPositionChanged = true; + for (auto &corePipeInfoPair : toBeErased) { + st.erase(corePipeInfoPair); + } + } + } + int chosenOpsDepth = -1; + SmallVector chosenOps; + for (auto &[scopeOp, mp] : backwardSyncEvents) { + if (backwardSyncEventsAfterMerge.contains(scopeOp)) { + continue; + } + int scopeOpDepth = scopeOp->getDepth(); + if (chosenOpsDepth == scopeOpDepth) { + chosenOps.push_back(scopeOp); + } else if (chosenOpsDepth == -1 || chosenOpsDepth < scopeOpDepth) { + chosenOps.clear(); + chosenOps.push_back(scopeOp); + chosenOpsDepth = scopeOpDepth; + } + } + if (chosenOps.empty()) { + return llvm::failure(); + } + bool newPairIsInserted = false; + for (auto *chosenOp : chosenOps) { + for (auto &[corePipeInfoPair, eventIdsMp] : backwardSyncEvents[chosenOp]) { + assert(!eventIdsMp.empty()); + if (!eventIdsMp.empty()) { + auto [it, isInserted] = + backwardSyncEventsAfterMerge[chosenOp].insert(corePipeInfoPair); + newPairIsInserted |= isInserted; + } + } + } + return llvm::success(backwardPairsPositionChanged || newPairIsInserted); +} + +llvm::LogicalResult 
Solver::reuseSyncPairToSaveEventIds() { + if (!options.reuseSyncPairToSaveEventIds || barrierAllPairs.empty()) { + return llvm::failure(); + } + bool limitReached = true; + for (auto [corePipeSrc, corePipeDst] : barrierAllPairs) { + if (reusePairs[{corePipeSrc, corePipeDst}] < maxReuseNum) { + if (reusePairs[{corePipeSrc, corePipeDst}] <= + reusedPairs[{corePipeSrc, corePipeDst}]) { + reusePairs[{corePipeSrc, corePipeDst}] += 1; + limitReached = false; + } + } + } + DEBUG_WITH_TYPE("gss-sync-solver-reuse", { + llvm::dbgs() << "reusePairs: \n"; + for (auto [pipeCorePairs, cnt] : reusePairs) { + llvm::dbgs() << get<0>(pipeCorePairs).pipe << ' ' + << get<1>(pipeCorePairs).pipe << ' ' << cnt << '\n'; + } + }); + return llvm::success(!limitReached); +} + +llvm::LogicalResult Solver::disableMultiEventIdForBarrierAllPairs() { + if (!options.disableMultiEventIdForBarrierAllPairs || + barrierAllPairs.empty()) { + return llvm::failure(); + } + bool newPairIsInserted = false; + for (auto corePipeInfoPair : barrierAllPairs) { + auto [it, isInserted] = disabledMultiEventIdPairs.insert(corePipeInfoPair); + newPairIsInserted |= isInserted; + } + LLVM_DEBUG({ + if (newPairIsInserted) { + llvm::dbgs() << "disabled-multi-event-id-pairs: \n"; + for (auto &[corePipeSrc, corePipeDst] : disabledMultiEventIdPairs) { + llvm::dbgs() << corePipeSrc.coreType << ' ' << corePipeSrc.pipe << ' ' + << corePipeDst.coreType << ' ' << corePipeDst.pipe << '\n'; + } + } + }); + return llvm::success(newPairIsInserted); +} + +llvm::LogicalResult Solver::tryMovingOutBackwardSyncPairsToOuterLoops() { + if (!options.moveOutAndMergeBackwardSyncPairs || !options.isCrossCoreMode() || + dontMoveBackwardSyncPairsToOutmostLoop) { + return llvm::failure(); + } + if (!moveBackwardSyncPairsToOutmostLoop) { + moveBackwardSyncPairsToOutmostLoop = true; + return llvm::success(); + } + if (!barrierAllPairs.empty()) { + moveBackwardSyncPairsToOutmostLoop = false; + dontMoveBackwardSyncPairsToOutmostLoop = true; + return 
llvm::success(); + } + return llvm::failure(); +} + +// High-level solve orchestration with multiple passes and optional merging +// iterations. +llvm::LogicalResult Solver::runSolver(bool enableOpts1, bool enableOpts2) { + reset(/*resetEventIdRanOutOpts=*/true); + + int64_t runNum = 0; + while (runNum++ < maxRunNum) { + LLVM_DEBUG(llvm::dbgs() << "runNum: " << runNum << '\n'); + + reset(); + insertMergedBackwardSyncPairs(); + processOrders(); + + if (llvm::succeeded(tryMovingOutBackwardSyncPairsToOuterLoops())) { + continue; + } + + if (enableOpts1) { + if (options.considerOuterBackwardSyncPairs) { + getBeforeAfterSyncMaps(); + if (llvm::succeeded(considerOuterBackwardSyncPairs())) { + continue; + } + if (!barrierAllPairs.empty()) { + backwardSyncEventsAfterMerge.clear(); + } + } + } + + if (enableOpts2) { + if (!barrierAllPairs.empty()) { + if (llvm::succeeded(reuseSyncPairToSaveEventIds())) { + continue; + } + if (llvm::succeeded(disableMultiEventIdForBarrierAllPairs())) { + continue; + } + } + } + + if (!barrierAllPairs.empty()) { + pickAndInsertABarrierAll(); + reset(/*resetEventIdRanOutOpts=*/true); + continue; + } + break; + } + + reset(); + insertMergedBackwardSyncPairs(); + processOrders(); + + return llvm::success(runNum < maxRunNum); +} + +void Solver::solve() { + if (llvm::succeeded(runSolver())) { + return; + } + if (!options.isTestMode()) { + if (llvm::succeeded(runSolver(/*enableOpts1=*/false))) { + return; + } + if (llvm::succeeded( + runSolver(/*enableOpts1=*/false, /*enableOpts2=*/false))) { + return; + } + } + llvm_unreachable("GSS: runSolver() failed."); +} diff --git a/tools/ptobc/CMakeLists.txt b/tools/ptobc/CMakeLists.txt index 8224cd637..e21e7bc3d 100644 --- a/tools/ptobc/CMakeLists.txt +++ b/tools/ptobc/CMakeLists.txt @@ -23,6 +23,7 @@ add_library(ptobc_lib STATIC src/mlir_encode.cpp src/canonical_printer.cpp src/ptobc_decode_print.cpp + generated/ptobc_opcodes_v0.cpp ) target_include_directories(ptobc_lib PUBLIC diff --git 
a/tools/ptobc/generated/ptobc_opcodes_v0.cpp b/tools/ptobc/generated/ptobc_opcodes_v0.cpp new file mode 100644 index 000000000..233854ead --- /dev/null +++ b/tools/ptobc/generated/ptobc_opcodes_v0.cpp @@ -0,0 +1,679 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Generated by docs/bytecode/tools/gen_v0_tables.py + +#include "ptobc_opcodes_v0.h" + +namespace ptobc::v0 { + +const OpInfo kOpTable[] = { + {0x0000, "pto.get_block_idx", 0, 0x01, 0x00, 0, 1, 0, 0x00}, + {0x0001, "pto.get_block_num", 0, 0x01, 0x00, 0, 1, 0, 0x00}, + {0x0002, "pto.get_subblock_idx", 0, 0x01, 0x00, 0, 1, 0, 0x00}, + {0x0003, "pto.get_subblock_num", 0, 0x01, 0x00, 0, 1, 0, 0x00}, + {0x0004, "pto.make_tensor_view", 0, 0x01, 0x03, 1, 1, 0, 0x06}, + {0x0005, "pto.partition_view", 0, 0x01, 0x03, 1, 1, 0, 0x07}, + {0x0006, "pto.section", 1, 0x00, 0x00, 0, 0, 1, 0x00}, + {0x1000, "pto.addptr", 0, 0x01, 0x00, 2, 1, 0, 0x00}, + {0x1001, "pto.alloc_tile", 0, 0x01, 0x04, 0, 1, 0, 0x08}, + {0x1002, "pto.barrier", 0, 0x00, 0x00, 0, 0, 0, 0x00}, + {0x1003, "pto.mgather", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1004, "pto.mscatter", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1005, "pto.record_event", 0, 0x00, 0x00, 0, 0, 0, 0x02}, + {0x1006, "pto.tabs", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1007, "pto.tadd", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1008, "pto.taddc", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x1009, "pto.tadds", 0, 
0x00, 0x00, 3, 0, 0, 0x00}, + {0x100A, "pto.taddsc", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x100B, "pto.tand", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x100C, "pto.tands", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x100D, "pto.tci", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x100E, "pto.tcmp", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x100F, "pto.tcmps", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1010, "pto.tcolexpand", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1011, "pto.tcolexpandadd", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1012, "pto.tcolexpanddiv", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1013, "pto.tcolexpandexpdif", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1014, "pto.tcolexpandmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1015, "pto.tcolexpandmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1016, "pto.tcolexpandmul", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1017, "pto.tcolexpandsub", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1018, "pto.tcolmax", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1019, "pto.tcolmin", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x101A, "pto.tcolprod", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x101B, "pto.tcolsum", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x101C, "pto.tcvt", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x101D, "pto.tdiv", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x101E, "pto.tdivs", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x101F, "pto.texp", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1020, "pto.texpands", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1021, "pto.textract", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x1022, "pto.textract_fp", 0, 0x00, 0x00, 5, 0, 0, 0x00}, + {0x1023, "pto.tfillpad", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1024, "pto.tfillpad_expand", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1025, "pto.tfillpad_inplace", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1026, "pto.tfmod", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1027, "pto.tfmods", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1028, "pto.tgather", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1029, "pto.tgatherb", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x102A, "pto.tgemv", 1, 0x00, 0x01, 0, 0, 0, 0x00}, + {0x102B, "pto.tgetval", 0, 0x01, 0x00, 2, 1, 
0, 0x00}, + {0x102C, "pto.timg2col", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x102D, "pto.tinsert", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x102E, "pto.tinsert_fp", 0, 0x00, 0x00, 5, 0, 0, 0x00}, + {0x102F, "pto.tload", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1030, "pto.tlog", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1031, "pto.tlrelu", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1032, "pto.tmatmul", 1, 0x00, 0x01, 0, 0, 0, 0x00}, + {0x1033, "pto.tmatmul.mx", 1, 0x00, 0x01, 0, 0, 0, 0x00}, + {0x1034, "pto.tmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1035, "pto.tmaxs", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1036, "pto.tmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1037, "pto.tmins", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1038, "pto.tmov", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1039, "pto.tmov.fp", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x103A, "pto.tmrgsort", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x103B, "pto.tmul", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x103C, "pto.tmuls", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x103D, "pto.tneg", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x103E, "pto.tnot", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x103F, "pto.tor", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1040, "pto.tors", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1041, "pto.tpartadd", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1042, "pto.tpartmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1043, "pto.tpartmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1044, "pto.tpartmul", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1045, "pto.tprefetch", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1046, "pto.tprelu", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x1047, "pto.tquant", 0, 0x00, 0x02, 3, 0, 0, 0x00}, + {0x1048, "pto.trecip", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1049, "pto.trelu", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x104A, "pto.trem", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x104B, "pto.trems", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x104C, "pto.treshape", 0, 0x01, 0x00, 1, 1, 0, 0x00}, + {0x104D, "pto.trowexpand", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x104E, "pto.trowexpandadd", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x104F, 
"pto.trowexpandexpdif", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1050, "pto.trowexpandmax", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1051, "pto.trowexpandmin", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1052, "pto.trowmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1053, "pto.trowmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1054, "pto.trowsum", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1055, "pto.trsqrt", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1056, "pto.tscatter", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1057, "pto.tsel", 0, 0x00, 0x00, 5, 0, 0, 0x00}, + {0x1058, "pto.tsels", 0, 0x00, 0x00, 5, 0, 0, 0x00}, + {0x1059, "pto.tset_img2col_padding", 0, 0x00, 0x00, 1, 0, 0, 0x00}, + {0x105A, "pto.tset_img2col_rpt", 0, 0x00, 0x00, 1, 0, 0, 0x00}, + {0x105B, "pto.tsetfmatrix", 0, 0x00, 0x00, 1, 0, 0, 0x00}, + {0x105C, "pto.tsethf32mode", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x105D, "pto.tsettf32mode", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x105E, "pto.tsetval", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x105F, "pto.tshl", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1060, "pto.tshls", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1061, "pto.tshr", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1062, "pto.tshrs", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1063, "pto.tsort32", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1064, "pto.tsqrt", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1065, "pto.tstore", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1066, "pto.tstore_fp", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1067, "pto.tsub", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1068, "pto.tsubc", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x1069, "pto.tsubs", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x106A, "pto.tsubsc", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x106B, "pto.trowexpandsub", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x106C, "pto.ttrans", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x106D, "pto.ttri", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x106E, "pto.txor", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x106F, "pto.txors", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x1070, "pto.wait_event", 0, 0x00, 0x00, 0, 0, 0, 0x02}, + {0x1071, "pto.tprint", 0, 0x00, 0x00, 
1, 0, 0, 0x00}, + {0x1072, "pto.subview", 0, 0x01, 0x02, 0, 1, 0, 0x00}, + {0x1073, "pto.trowexpanddiv", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1074, "pto.trowexpandmul", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1075, "pto.tdequant", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x1076, "pto.taxpy", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1077, "pto.thistogram", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1078, "pto.tget_scale_addr", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1079, "pto.trowargmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x107A, "pto.trowargmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x107B, "pto.tcolargmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x107C, "pto.tcolargmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x107D, "pto.tsync", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x107E, "pto.reserve_buffer", 0, 0x01, 0x00, 0, 1, 0, 0x00}, + {0x107F, "pto.import_reserved_buffer", 0, 0x01, 0x00, 0, 1, 0, 0x00}, + {0x1080, "pto.aic_initialize_pipe", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1081, "pto.aiv_initialize_pipe", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1082, "pto.tpush_to_aiv", 0, 0x00, 0x00, 1, 0, 0, 0x00}, + {0x1083, "pto.tpush_to_aic", 0, 0x00, 0x00, 1, 0, 0, 0x00}, + {0x1084, "pto.tpop_from_aic", 0, 0x01, 0x02, 0, 1, 0, 0x00}, + {0x1085, "pto.tpop_from_aiv", 0, 0x01, 0x02, 0, 1, 0, 0x00}, + {0x1086, "pto.tfree_from_aic", 0, 0x00, 0x00, 0, 0, 0, 0x00}, + {0x1087, "pto.tfree_from_aiv", 0, 0x00, 0x00, 0, 0, 0, 0x00}, + {0x1088, "pto.set_validshape", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1089, "pto.tconcat", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x108A, "pto.trowprod", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x108B, "pto.initialize_l2g2l_pipe", 0, 0x01, 0x02, 0, 1, 0, 0x00}, + {0x108C, "pto.initialize_l2l_pipe", 0, 0x01, 0x02, 0, 1, 0, 0x00}, + {0x108D, "pto.tpush", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x108E, "pto.declare_tile", 0, 0x01, 0x00, 0, 1, 0, 0x00}, + {0x108F, "pto.tpop", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1090, "pto.tfree", 0, 0x00, 0x00, 1, 0, 0, 0x00}, + {0x1091, "pto.comm.tput", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1092, 
"pto.comm.tget", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1093, "pto.comm.tnotify", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1094, "pto.comm.twait", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1095, "pto.comm.ttest", 0, 0x01, 0x02, 0, 1, 0, 0x00}, + {0x1096, "pto.comm.tbroadcast", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1097, "pto.comm.tgather", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1098, "pto.comm.tscatter", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1099, "pto.comm.treduce", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x109A, "pto.tpartargmax", 0, 0x00, 0x00, 6, 0, 0, 0x00}, + {0x109B, "pto.tpartargmin", 0, 0x00, 0x00, 6, 0, 0, 0x00}, + {0x109C, "pto.tscatter.maskpattern", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x2000, "arith.addi", 0, 0x01, 0x00, 2, 1, 0, 0x00}, + {0x2001, "arith.ceildivsi", 0, 0x01, 0x00, 2, 1, 0, 0x00}, + {0x2002, "arith.cmpi", 0, 0x01, 0x00, 2, 1, 0, 0x01}, + {0x2003, "arith.constant", 0, 0x01, 0x00, 0, 1, 0, 0x05}, + {0x2004, "arith.index_cast", 0, 0x01, 0x00, 1, 1, 0, 0x00}, + {0x2005, "arith.minui", 0, 0x01, 0x00, 2, 1, 0, 0x00}, + {0x2006, "arith.muli", 0, 0x01, 0x00, 2, 1, 0, 0x00}, + {0x2007, "arith.select", 0, 0x01, 0x00, 3, 1, 0, 0x00}, + {0x2008, "arith.subi", 0, 0x01, 0x00, 2, 1, 0, 0x00}, + {0x4000, "scf.for", 0, 0x00, 0x00, 3, 0, 1, 0x00}, + {0x4001, "scf.if", 0, 0x00, 0x00, 1, 0, 2, 0x00}, + {0x4002, "scf.yield", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x6000, "func.func", 0, 0x00, 0x00, 0, 0, 0, 0x00}, + {0x6001, "func.return", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x6002, "func.call", 0, 0x02, 0x02, 0, 0, 0, 0x00}, +}; + +const OpInfo *lookupByOpcode(uint16_t opcode) { + // Binary search on kOpTable (sorted by opcode). 
+ size_t lo = 0, hi = sizeof(kOpTable) / sizeof(kOpTable[0]); + while (lo < hi) { + size_t mid = lo + (hi - lo) / 2; + uint16_t v = kOpTable[mid].opcode; + if (v == opcode) return &kOpTable[mid]; + if (v < opcode) lo = mid + 1; else hi = mid; + } + return nullptr; +} + +std::optional lookupOpcodeByName(llvm::StringRef name) { + uint16_t v = llvm::StringSwitch(name) + .Case("arith.addi", 0x2000) + .Case("arith.ceildivsi", 0x2001) + .Case("arith.cmpi", 0x2002) + .Case("arith.constant", 0x2003) + .Case("arith.index_cast", 0x2004) + .Case("arith.minui", 0x2005) + .Case("arith.muli", 0x2006) + .Case("arith.select", 0x2007) + .Case("arith.subi", 0x2008) + .Case("func.func", 0x6000) + .Case("func.return", 0x6001) + .Case("func.call", 0x6002) + .Case("pto.addptr", 0x1000) + .Case("pto.alloc_tile", 0x1001) + .Case("pto.barrier", 0x1002) + .Case("pto.get_block_idx", 0x0000) + .Case("pto.get_block_num", 0x0001) + .Case("pto.get_subblock_idx", 0x0002) + .Case("pto.get_subblock_num", 0x0003) + .Case("pto.make_tensor_view", 0x0004) + .Case("pto.mgather", 0x1003) + .Case("pto.mscatter", 0x1004) + .Case("pto.partition_view", 0x0005) + .Case("pto.record_event", 0x1005) + .Case("pto.section", 0x0006) + .Case("pto.tabs", 0x1006) + .Case("pto.tadd", 0x1007) + .Case("pto.taddc", 0x1008) + .Case("pto.tadds", 0x1009) + .Case("pto.taddsc", 0x100A) + .Case("pto.tand", 0x100B) + .Case("pto.tands", 0x100C) + .Case("pto.tci", 0x100D) + .Case("pto.tcmp", 0x100E) + .Case("pto.tcmps", 0x100F) + .Case("pto.tcolexpand", 0x1010) + .Case("pto.tcolexpandadd", 0x1011) + .Case("pto.tcolexpanddiv", 0x1012) + .Case("pto.tcolexpandexpdif", 0x1013) + .Case("pto.tcolexpandmax", 0x1014) + .Case("pto.tcolexpandmin", 0x1015) + .Case("pto.tcolexpandmul", 0x1016) + .Case("pto.tcolexpandsub", 0x1017) + .Case("pto.tcolmax", 0x1018) + .Case("pto.tcolmin", 0x1019) + .Case("pto.tcolprod", 0x101A) + .Case("pto.tcolsum", 0x101B) + .Case("pto.tcvt", 0x101C) + .Case("pto.tdiv", 0x101D) + .Case("pto.tdivs", 0x101E) + 
.Case("pto.texp", 0x101F) + .Case("pto.texpands", 0x1020) + .Case("pto.textract", 0x1021) + .Case("pto.textract_fp", 0x1022) + .Case("pto.tfillpad", 0x1023) + .Case("pto.tfillpad_expand", 0x1024) + .Case("pto.tfillpad_inplace", 0x1025) + .Case("pto.tfmod", 0x1026) + .Case("pto.tfmods", 0x1027) + .Case("pto.tgather", 0x1028) + .Case("pto.tgatherb", 0x1029) + .Case("pto.tgemv", 0x102A) + .Case("pto.tgetval", 0x102B) + .Case("pto.timg2col", 0x102C) + .Case("pto.tinsert", 0x102D) + .Case("pto.tinsert_fp", 0x102E) + .Case("pto.tload", 0x102F) + .Case("pto.tlog", 0x1030) + .Case("pto.tlrelu", 0x1031) + .Case("pto.tmatmul", 0x1032) + .Case("pto.tmatmul.mx", 0x1033) + .Case("pto.tmax", 0x1034) + .Case("pto.tmaxs", 0x1035) + .Case("pto.tmin", 0x1036) + .Case("pto.tmins", 0x1037) + .Case("pto.tmov", 0x1038) + .Case("pto.tmov.fp", 0x1039) + .Case("pto.tmrgsort", 0x103A) + .Case("pto.tmul", 0x103B) + .Case("pto.tmuls", 0x103C) + .Case("pto.tneg", 0x103D) + .Case("pto.tnot", 0x103E) + .Case("pto.tor", 0x103F) + .Case("pto.tors", 0x1040) + .Case("pto.tpartadd", 0x1041) + .Case("pto.tpartmax", 0x1042) + .Case("pto.tpartmin", 0x1043) + .Case("pto.tpartmul", 0x1044) + .Case("pto.tprefetch", 0x1045) + .Case("pto.tprelu", 0x1046) + .Case("pto.tquant", 0x1047) + .Case("pto.trecip", 0x1048) + .Case("pto.trelu", 0x1049) + .Case("pto.trem", 0x104A) + .Case("pto.trems", 0x104B) + .Case("pto.treshape", 0x104C) + .Case("pto.trowexpand", 0x104D) + .Case("pto.trowexpandadd", 0x104E) + .Case("pto.trowexpandexpdif", 0x104F) + .Case("pto.trowexpandmax", 0x1050) + .Case("pto.trowexpandmin", 0x1051) + .Case("pto.trowmax", 0x1052) + .Case("pto.trowmin", 0x1053) + .Case("pto.trowsum", 0x1054) + .Case("pto.trsqrt", 0x1055) + .Case("pto.tscatter", 0x1056) + .Case("pto.tsel", 0x1057) + .Case("pto.tsels", 0x1058) + .Case("pto.tset_img2col_padding", 0x1059) + .Case("pto.tset_img2col_rpt", 0x105A) + .Case("pto.tsetfmatrix", 0x105B) + .Case("pto.tsethf32mode", 0x105C) + .Case("pto.tsettf32mode", 0x105D) + 
.Case("pto.tsetval", 0x105E) + .Case("pto.tshl", 0x105F) + .Case("pto.tshls", 0x1060) + .Case("pto.tshr", 0x1061) + .Case("pto.tshrs", 0x1062) + .Case("pto.tsort32", 0x1063) + .Case("pto.tsqrt", 0x1064) + .Case("pto.tstore", 0x1065) + .Case("pto.tstore_fp", 0x1066) + .Case("pto.tsub", 0x1067) + .Case("pto.tsubc", 0x1068) + .Case("pto.tsubs", 0x1069) + .Case("pto.tsubsc", 0x106A) + .Case("pto.trowexpandsub", 0x106B) + .Case("pto.ttrans", 0x106C) + .Case("pto.ttri", 0x106D) + .Case("pto.txor", 0x106E) + .Case("pto.txors", 0x106F) + .Case("pto.wait_event", 0x1070) + .Case("pto.tprint", 0x1071) + .Case("pto.subview", 0x1072) + .Case("pto.trowexpanddiv", 0x1073) + .Case("pto.trowexpandmul", 0x1074) + .Case("pto.tdequant", 0x1075) + .Case("pto.taxpy", 0x1076) + .Case("pto.thistogram", 0x1077) + .Case("pto.tget_scale_addr", 0x1078) + .Case("pto.trowargmax", 0x1079) + .Case("pto.trowargmin", 0x107A) + .Case("pto.tcolargmax", 0x107B) + .Case("pto.tcolargmin", 0x107C) + .Case("pto.tsync", 0x107D) + .Case("pto.reserve_buffer", 0x107E) + .Case("pto.import_reserved_buffer", 0x107F) + .Case("pto.aic_initialize_pipe", 0x1080) + .Case("pto.aiv_initialize_pipe", 0x1081) + .Case("pto.tpush_to_aiv", 0x1082) + .Case("pto.tpush_to_aic", 0x1083) + .Case("pto.tpop_from_aic", 0x1084) + .Case("pto.tpop_from_aiv", 0x1085) + .Case("pto.tfree_from_aic", 0x1086) + .Case("pto.tfree_from_aiv", 0x1087) + .Case("pto.set_validshape", 0x1088) + .Case("pto.tconcat", 0x1089) + .Case("pto.trowprod", 0x108A) + .Case("pto.initialize_l2g2l_pipe", 0x108B) + .Case("pto.initialize_l2l_pipe", 0x108C) + .Case("pto.tpush", 0x108D) + .Case("pto.declare_tile", 0x108E) + .Case("pto.tpop", 0x108F) + .Case("pto.tfree", 0x1090) + .Case("pto.comm.tput", 0x1091) + .Case("pto.comm.tget", 0x1092) + .Case("pto.comm.tnotify", 0x1093) + .Case("pto.comm.twait", 0x1094) + .Case("pto.comm.ttest", 0x1095) + .Case("pto.comm.tbroadcast", 0x1096) + .Case("pto.comm.tgather", 0x1097) + .Case("pto.comm.tscatter", 0x1098) + 
.Case("pto.comm.treduce", 0x1099) + .Case("pto.tpartargmax", 0x109A) + .Case("pto.tpartargmin", 0x109B) + .Case("scf.for", 0x4000) + .Case("scf.if", 0x4001) + .Case("scf.yield", 0x4002) + .Default(0xFFFF); + if (v == 0xFFFF) return std::nullopt; + return v; +} + +const OpInfo *lookupByName(llvm::StringRef name) { + auto o = lookupOpcodeByName(name); + if (!o) return nullptr; + return lookupByOpcode(*o); +} + +std::optional lookupOpcodeAndVariantByFullName(llvm::StringRef fullName) { + // For non-family ops, variant is 0. For family ops, variant is the assigned u8. + // NOTE: `pto.section` is not a real op name; use `pto.section.cube`/`pto.section.vector`. + return llvm::StringSwitch>(fullName) + .Case("arith.addi", OpcodeAndVariant{0x2000, 0, 0}) + .Case("arith.ceildivsi", OpcodeAndVariant{0x2001, 0, 0}) + .Case("arith.cmpi", OpcodeAndVariant{0x2002, 0, 0}) + .Case("arith.constant", OpcodeAndVariant{0x2003, 0, 0}) + .Case("arith.index_cast", OpcodeAndVariant{0x2004, 0, 0}) + .Case("arith.minui", OpcodeAndVariant{0x2005, 0, 0}) + .Case("arith.muli", OpcodeAndVariant{0x2006, 0, 0}) + .Case("arith.select", OpcodeAndVariant{0x2007, 0, 0}) + .Case("arith.subi", OpcodeAndVariant{0x2008, 0, 0}) + .Case("func.func", OpcodeAndVariant{0x6000, 0, 0}) + .Case("func.return", OpcodeAndVariant{0x6001, 0, 0}) + .Case("func.call", OpcodeAndVariant{0x6002, 0, 0}) + .Case("pto.addptr", OpcodeAndVariant{0x1000, 0, 0}) + .Case("pto.alloc_tile", OpcodeAndVariant{0x1001, 0, 0}) + .Case("pto.barrier", OpcodeAndVariant{0x1002, 0, 0}) + .Case("pto.get_block_idx", OpcodeAndVariant{0x0000, 0, 0}) + .Case("pto.get_block_num", OpcodeAndVariant{0x0001, 0, 0}) + .Case("pto.get_subblock_idx", OpcodeAndVariant{0x0002, 0, 0}) + .Case("pto.get_subblock_num", OpcodeAndVariant{0x0003, 0, 0}) + .Case("pto.make_tensor_view", OpcodeAndVariant{0x0004, 0, 0}) + .Case("pto.mgather", OpcodeAndVariant{0x1003, 0, 0}) + .Case("pto.mscatter", OpcodeAndVariant{0x1004, 0, 0}) + .Case("pto.partition_view", 
OpcodeAndVariant{0x0005, 0, 0}) + .Case("pto.record_event", OpcodeAndVariant{0x1005, 0, 0}) + .Case("pto.tabs", OpcodeAndVariant{0x1006, 0, 0}) + .Case("pto.tadd", OpcodeAndVariant{0x1007, 0, 0}) + .Case("pto.taddc", OpcodeAndVariant{0x1008, 0, 0}) + .Case("pto.tadds", OpcodeAndVariant{0x1009, 0, 0}) + .Case("pto.taddsc", OpcodeAndVariant{0x100A, 0, 0}) + .Case("pto.tand", OpcodeAndVariant{0x100B, 0, 0}) + .Case("pto.tands", OpcodeAndVariant{0x100C, 0, 0}) + .Case("pto.tci", OpcodeAndVariant{0x100D, 0, 0}) + .Case("pto.tcmp", OpcodeAndVariant{0x100E, 0, 0}) + .Case("pto.tcmps", OpcodeAndVariant{0x100F, 0, 0}) + .Case("pto.tcolexpand", OpcodeAndVariant{0x1010, 0, 0}) + .Case("pto.tcolexpandadd", OpcodeAndVariant{0x1011, 0, 0}) + .Case("pto.tcolexpanddiv", OpcodeAndVariant{0x1012, 0, 0}) + .Case("pto.tcolexpandexpdif", OpcodeAndVariant{0x1013, 0, 0}) + .Case("pto.tcolexpandmax", OpcodeAndVariant{0x1014, 0, 0}) + .Case("pto.tcolexpandmin", OpcodeAndVariant{0x1015, 0, 0}) + .Case("pto.tcolexpandmul", OpcodeAndVariant{0x1016, 0, 0}) + .Case("pto.tcolexpandsub", OpcodeAndVariant{0x1017, 0, 0}) + .Case("pto.tcolmax", OpcodeAndVariant{0x1018, 0, 0}) + .Case("pto.tcolmin", OpcodeAndVariant{0x1019, 0, 0}) + .Case("pto.tcolprod", OpcodeAndVariant{0x101A, 0, 0}) + .Case("pto.tcolsum", OpcodeAndVariant{0x101B, 0, 0}) + .Case("pto.tcvt", OpcodeAndVariant{0x101C, 0, 0}) + .Case("pto.tdiv", OpcodeAndVariant{0x101D, 0, 0}) + .Case("pto.tdivs", OpcodeAndVariant{0x101E, 0, 0}) + .Case("pto.texp", OpcodeAndVariant{0x101F, 0, 0}) + .Case("pto.texpands", OpcodeAndVariant{0x1020, 0, 0}) + .Case("pto.textract", OpcodeAndVariant{0x1021, 0, 0}) + .Case("pto.textract_fp", OpcodeAndVariant{0x1022, 0, 0}) + .Case("pto.tfillpad", OpcodeAndVariant{0x1023, 0, 0}) + .Case("pto.tfillpad_expand", OpcodeAndVariant{0x1024, 0, 0}) + .Case("pto.tfillpad_inplace", OpcodeAndVariant{0x1025, 0, 0}) + .Case("pto.tfmod", OpcodeAndVariant{0x1026, 0, 0}) + .Case("pto.tfmods", OpcodeAndVariant{0x1027, 0, 0}) + 
.Case("pto.tgather", OpcodeAndVariant{0x1028, 0, 0}) + .Case("pto.tgatherb", OpcodeAndVariant{0x1029, 0, 0}) + .Case("pto.tgetval", OpcodeAndVariant{0x102B, 0, 0}) + .Case("pto.timg2col", OpcodeAndVariant{0x102C, 0, 0}) + .Case("pto.tinsert", OpcodeAndVariant{0x102D, 0, 0}) + .Case("pto.tinsert_fp", OpcodeAndVariant{0x102E, 0, 0}) + .Case("pto.tload", OpcodeAndVariant{0x102F, 0, 0}) + .Case("pto.tlog", OpcodeAndVariant{0x1030, 0, 0}) + .Case("pto.tlrelu", OpcodeAndVariant{0x1031, 0, 0}) + .Case("pto.tmax", OpcodeAndVariant{0x1034, 0, 0}) + .Case("pto.tmaxs", OpcodeAndVariant{0x1035, 0, 0}) + .Case("pto.tmin", OpcodeAndVariant{0x1036, 0, 0}) + .Case("pto.tmins", OpcodeAndVariant{0x1037, 0, 0}) + .Case("pto.tmov", OpcodeAndVariant{0x1038, 0, 0}) + .Case("pto.tmov.fp", OpcodeAndVariant{0x1039, 0, 0}) + .Case("pto.tmrgsort", OpcodeAndVariant{0x103A, 0, 0}) + .Case("pto.tmul", OpcodeAndVariant{0x103B, 0, 0}) + .Case("pto.tmuls", OpcodeAndVariant{0x103C, 0, 0}) + .Case("pto.tneg", OpcodeAndVariant{0x103D, 0, 0}) + .Case("pto.tnot", OpcodeAndVariant{0x103E, 0, 0}) + .Case("pto.tor", OpcodeAndVariant{0x103F, 0, 0}) + .Case("pto.tors", OpcodeAndVariant{0x1040, 0, 0}) + .Case("pto.tpartadd", OpcodeAndVariant{0x1041, 0, 0}) + .Case("pto.tpartmax", OpcodeAndVariant{0x1042, 0, 0}) + .Case("pto.tpartmin", OpcodeAndVariant{0x1043, 0, 0}) + .Case("pto.tpartmul", OpcodeAndVariant{0x1044, 0, 0}) + .Case("pto.tprefetch", OpcodeAndVariant{0x1045, 0, 0}) + .Case("pto.tprelu", OpcodeAndVariant{0x1046, 0, 0}) + .Case("pto.tquant", OpcodeAndVariant{0x1047, 0, 0}) + .Case("pto.trecip", OpcodeAndVariant{0x1048, 0, 0}) + .Case("pto.trelu", OpcodeAndVariant{0x1049, 0, 0}) + .Case("pto.trem", OpcodeAndVariant{0x104A, 0, 0}) + .Case("pto.trems", OpcodeAndVariant{0x104B, 0, 0}) + .Case("pto.treshape", OpcodeAndVariant{0x104C, 0, 0}) + .Case("pto.trowexpand", OpcodeAndVariant{0x104D, 0, 0}) + .Case("pto.trowexpandadd", OpcodeAndVariant{0x104E, 0, 0}) + .Case("pto.trowexpandexpdif", 
OpcodeAndVariant{0x104F, 0, 0}) + .Case("pto.trowexpandmax", OpcodeAndVariant{0x1050, 0, 0}) + .Case("pto.trowexpandmin", OpcodeAndVariant{0x1051, 0, 0}) + .Case("pto.trowmax", OpcodeAndVariant{0x1052, 0, 0}) + .Case("pto.trowmin", OpcodeAndVariant{0x1053, 0, 0}) + .Case("pto.trowsum", OpcodeAndVariant{0x1054, 0, 0}) + .Case("pto.trsqrt", OpcodeAndVariant{0x1055, 0, 0}) + .Case("pto.tscatter", OpcodeAndVariant{0x1056, 0, 0}) + .Case("pto.tsel", OpcodeAndVariant{0x1057, 0, 0}) + .Case("pto.tsels", OpcodeAndVariant{0x1058, 0, 0}) + .Case("pto.tset_img2col_padding", OpcodeAndVariant{0x1059, 0, 0}) + .Case("pto.tset_img2col_rpt", OpcodeAndVariant{0x105A, 0, 0}) + .Case("pto.tsetfmatrix", OpcodeAndVariant{0x105B, 0, 0}) + .Case("pto.tsethf32mode", OpcodeAndVariant{0x105C, 0, 0}) + .Case("pto.tsettf32mode", OpcodeAndVariant{0x105D, 0, 0}) + .Case("pto.tsetval", OpcodeAndVariant{0x105E, 0, 0}) + .Case("pto.tshl", OpcodeAndVariant{0x105F, 0, 0}) + .Case("pto.tshls", OpcodeAndVariant{0x1060, 0, 0}) + .Case("pto.tshr", OpcodeAndVariant{0x1061, 0, 0}) + .Case("pto.tshrs", OpcodeAndVariant{0x1062, 0, 0}) + .Case("pto.tsort32", OpcodeAndVariant{0x1063, 0, 0}) + .Case("pto.tsqrt", OpcodeAndVariant{0x1064, 0, 0}) + .Case("pto.tstore", OpcodeAndVariant{0x1065, 0, 0}) + .Case("pto.tstore_fp", OpcodeAndVariant{0x1066, 0, 0}) + .Case("pto.tsub", OpcodeAndVariant{0x1067, 0, 0}) + .Case("pto.tsubc", OpcodeAndVariant{0x1068, 0, 0}) + .Case("pto.tsubs", OpcodeAndVariant{0x1069, 0, 0}) + .Case("pto.tsubsc", OpcodeAndVariant{0x106A, 0, 0}) + .Case("pto.trowexpandsub", OpcodeAndVariant{0x106B, 0, 0}) + .Case("pto.ttrans", OpcodeAndVariant{0x106C, 0, 0}) + .Case("pto.ttri", OpcodeAndVariant{0x106D, 0, 0}) + .Case("pto.txor", OpcodeAndVariant{0x106E, 0, 0}) + .Case("pto.txors", OpcodeAndVariant{0x106F, 0, 0}) + .Case("pto.wait_event", OpcodeAndVariant{0x1070, 0, 0}) + .Case("pto.tprint", OpcodeAndVariant{0x1071, 0, 0}) + .Case("pto.subview", OpcodeAndVariant{0x1072, 0, 0}) + 
.Case("pto.trowexpanddiv", OpcodeAndVariant{0x1073, 0, 0}) + .Case("pto.trowexpandmul", OpcodeAndVariant{0x1074, 0, 0}) + .Case("pto.tdequant", OpcodeAndVariant{0x1075, 0, 0}) + .Case("pto.taxpy", OpcodeAndVariant{0x1076, 0, 0}) + .Case("pto.thistogram", OpcodeAndVariant{0x1077, 0, 0}) + .Case("pto.tget_scale_addr", OpcodeAndVariant{0x1078, 0, 0}) + .Case("pto.trowargmax", OpcodeAndVariant{0x1079, 0, 0}) + .Case("pto.trowargmin", OpcodeAndVariant{0x107A, 0, 0}) + .Case("pto.tcolargmax", OpcodeAndVariant{0x107B, 0, 0}) + .Case("pto.tcolargmin", OpcodeAndVariant{0x107C, 0, 0}) + .Case("pto.tsync", OpcodeAndVariant{0x107D, 0, 0}) + .Case("pto.reserve_buffer", OpcodeAndVariant{0x107E, 0, 0}) + .Case("pto.import_reserved_buffer", OpcodeAndVariant{0x107F, 0, 0}) + .Case("pto.aic_initialize_pipe", OpcodeAndVariant{0x1080, 0, 0}) + .Case("pto.aiv_initialize_pipe", OpcodeAndVariant{0x1081, 0, 0}) + .Case("pto.tpush_to_aiv", OpcodeAndVariant{0x1082, 0, 0}) + .Case("pto.tpush_to_aic", OpcodeAndVariant{0x1083, 0, 0}) + .Case("pto.tpop_from_aic", OpcodeAndVariant{0x1084, 0, 0}) + .Case("pto.tpop_from_aiv", OpcodeAndVariant{0x1085, 0, 0}) + .Case("pto.tfree_from_aic", OpcodeAndVariant{0x1086, 0, 0}) + .Case("pto.tfree_from_aiv", OpcodeAndVariant{0x1087, 0, 0}) + .Case("pto.set_validshape", OpcodeAndVariant{0x1088, 0, 0}) + .Case("pto.tconcat", OpcodeAndVariant{0x1089, 0, 0}) + .Case("pto.trowprod", OpcodeAndVariant{0x108A, 0, 0}) + .Case("pto.initialize_l2g2l_pipe", OpcodeAndVariant{0x108B, 0, 0}) + .Case("pto.initialize_l2l_pipe", OpcodeAndVariant{0x108C, 0, 0}) + .Case("pto.tpush", OpcodeAndVariant{0x108D, 0, 0}) + .Case("pto.declare_tile", OpcodeAndVariant{0x108E, 0, 0}) + .Case("pto.tpop", OpcodeAndVariant{0x108F, 0, 0}) + .Case("pto.tfree", OpcodeAndVariant{0x1090, 0, 0}) + .Case("pto.comm.tput", OpcodeAndVariant{0x1091, 0, 0}) + .Case("pto.comm.tget", OpcodeAndVariant{0x1092, 0, 0}) + .Case("pto.comm.tnotify", OpcodeAndVariant{0x1093, 0, 0}) + .Case("pto.comm.twait", 
OpcodeAndVariant{0x1094, 0, 0}) + .Case("pto.comm.ttest", OpcodeAndVariant{0x1095, 0, 0}) + .Case("pto.comm.tbroadcast", OpcodeAndVariant{0x1096, 0, 0}) + .Case("pto.comm.tgather", OpcodeAndVariant{0x1097, 0, 0}) + .Case("pto.comm.tscatter", OpcodeAndVariant{0x1098, 0, 0}) + .Case("pto.comm.treduce", OpcodeAndVariant{0x1099, 0, 0}) + .Case("pto.tpartargmax", OpcodeAndVariant{0x109A, 0, 0}) + .Case("pto.tpartargmin", OpcodeAndVariant{0x109B, 0, 0}) + .Case("scf.for", OpcodeAndVariant{0x4000, 0, 0}) + .Case("scf.if", OpcodeAndVariant{0x4001, 0, 0}) + .Case("scf.yield", OpcodeAndVariant{0x4002, 0, 0}) + .Case("pto.section.cube", + OpcodeAndVariant{0x0006, kHasVariant, kSectionCubeVariant}) + .Case("pto.section.vector", + OpcodeAndVariant{0x0006, kHasVariant, kSectionVectorVariant}) + .Case("pto.tgemv", + OpcodeAndVariant{0x102A, kHasVariant, kVariantDefault}) + .Case("pto.tgemv.acc", + OpcodeAndVariant{0x102A, kHasVariant, kVariantAcc}) + .Case("pto.tgemv.bias", + OpcodeAndVariant{0x102A, kHasVariant, kVariantBias}) + .Case("pto.tgemv.mx", + OpcodeAndVariant{0x102A, kHasVariant, kVariantMx}) + .Case("pto.tgemv.mx.acc", + OpcodeAndVariant{0x102A, kHasVariant, kVariantMxAcc}) + .Case("pto.tgemv.mx.bias", + OpcodeAndVariant{0x102A, kHasVariant, kVariantMxBias}) + .Case("pto.tmatmul", + OpcodeAndVariant{0x1032, kHasVariant, kVariantDefault}) + .Case("pto.tmatmul.acc", + OpcodeAndVariant{0x1032, kHasVariant, kVariantAcc}) + .Case("pto.tmatmul.bias", + OpcodeAndVariant{0x1032, kHasVariant, kVariantBias}) + .Case("pto.tmatmul.mx", + OpcodeAndVariant{0x1033, kHasVariant, kVariantDefault}) + .Case("pto.tmatmul.mx.acc", + OpcodeAndVariant{0x1033, kHasVariant, kVariantAcc}) + .Case("pto.tmatmul.mx.bias", + OpcodeAndVariant{0x1033, kHasVariant, kVariantBias}) + .Default(std::nullopt); +} + +const char *fullNameFromOpcodeVariant(uint16_t opcode, uint8_t variant) { + const OpInfo *info = lookupByOpcode(opcode); + if (!info) return nullptr; + if (opcode == kTscatterMaskOpcode) 
return "pto.tscatter"; + if (!info->has_variant_u8) return info->name; + switch (opcode) { + case 0x0006: + switch (variant) { + case kSectionCubeVariant: return "pto.section.cube"; + case kSectionVectorVariant: return "pto.section.vector"; + default: return info->name; + } + case 0x102A: + switch (variant) { + case kVariantDefault: return "pto.tgemv"; + case kVariantAcc: return "pto.tgemv.acc"; + case kVariantBias: return "pto.tgemv.bias"; + case kVariantMx: return "pto.tgemv.mx"; + case kVariantMxAcc: return "pto.tgemv.mx.acc"; + case kVariantMxBias: return "pto.tgemv.mx.bias"; + default: return info->name; + } + case 0x1032: + switch (variant) { + case kVariantDefault: return "pto.tmatmul"; + case kVariantAcc: return "pto.tmatmul.acc"; + case kVariantBias: return "pto.tmatmul.bias"; + default: return info->name; + } + case 0x1033: + switch (variant) { + case kVariantDefault: return "pto.tmatmul.mx"; + case kVariantAcc: return "pto.tmatmul.mx.acc"; + case kVariantBias: return "pto.tmatmul.mx.bias"; + default: return info->name; + } + default: return info->name; + } +} + +std::optional lookupOperandsByVariant(uint16_t opcode, uint8_t variant) { + switch (opcode) { + case 0x102A: + switch (variant) { + case kVariantDefault: return kTgemvOperandCount; + case kVariantAcc: return kTgemvAccOperandCount; + case kVariantBias: return kTgemvBiasOperandCount; + case kVariantMx: return kTgemvMxOperandCount; + case kVariantMxAcc: return kTgemvMxAccOperandCount; + case kVariantMxBias: return kTgemvMxBiasOperandCount; + default: return std::nullopt; + } + case 0x1032: + switch (variant) { + case kVariantDefault: return kTmatmulOperandCount; + case kVariantAcc: return kTmatmulAccOperandCount; + case kVariantBias: return kTmatmulBiasOperandCount; + default: return std::nullopt; + } + case 0x1033: + switch (variant) { + case kVariantDefault: return kTmatmulMxOperandCount; + case kVariantAcc: return kTmatmulMxAccOperandCount; + case kVariantBias: return kTmatmulMxBiasOperandCount; 
+ default: return std::nullopt; + } + default: return std::nullopt; + } +} + +} // namespace ptobc::v0 diff --git a/tools/ptobc/generated/ptobc_opcodes_v0.h b/tools/ptobc/generated/ptobc_opcodes_v0.h index 8303e1261..9c7a9a1d0 100644 --- a/tools/ptobc/generated/ptobc_opcodes_v0.h +++ b/tools/ptobc/generated/ptobc_opcodes_v0.h @@ -53,670 +53,16 @@ struct OpInfo { uint8_t imm_kind; }; -inline constexpr OpInfo kOpTable[] = { - {0x0000, "pto.get_block_idx", 0, 0x01, 0x00, 0, 1, 0, 0x00}, - {0x0001, "pto.get_block_num", 0, 0x01, 0x00, 0, 1, 0, 0x00}, - {0x0002, "pto.get_subblock_idx", 0, 0x01, 0x00, 0, 1, 0, 0x00}, - {0x0003, "pto.get_subblock_num", 0, 0x01, 0x00, 0, 1, 0, 0x00}, - {0x0004, "pto.make_tensor_view", 0, 0x01, 0x03, 1, 1, 0, 0x06}, - {0x0005, "pto.partition_view", 0, 0x01, 0x03, 1, 1, 0, 0x07}, - {0x0006, "pto.section", 1, 0x00, 0x00, 0, 0, 1, 0x00}, - {0x1000, "pto.addptr", 0, 0x01, 0x00, 2, 1, 0, 0x00}, - {0x1001, "pto.alloc_tile", 0, 0x01, 0x04, 0, 1, 0, 0x08}, - {0x1002, "pto.barrier", 0, 0x00, 0x00, 0, 0, 0, 0x00}, - {0x1003, "pto.mgather", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1004, "pto.mscatter", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1005, "pto.record_event", 0, 0x00, 0x00, 0, 0, 0, 0x02}, - {0x1006, "pto.tabs", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1007, "pto.tadd", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1008, "pto.taddc", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x1009, "pto.tadds", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x100A, "pto.taddsc", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x100B, "pto.tand", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x100C, "pto.tands", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x100D, "pto.tci", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x100E, "pto.tcmp", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x100F, "pto.tcmps", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1010, "pto.tcolexpand", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1011, "pto.tcolexpandadd", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1012, "pto.tcolexpanddiv", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1013, "pto.tcolexpandexpdif", 0, 0x00, 
0x00, 3, 0, 0, 0x00}, - {0x1014, "pto.tcolexpandmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1015, "pto.tcolexpandmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1016, "pto.tcolexpandmul", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1017, "pto.tcolexpandsub", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1018, "pto.tcolmax", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1019, "pto.tcolmin", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x101A, "pto.tcolprod", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x101B, "pto.tcolsum", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x101C, "pto.tcvt", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x101D, "pto.tdiv", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x101E, "pto.tdivs", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x101F, "pto.texp", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1020, "pto.texpands", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1021, "pto.textract", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x1022, "pto.textract_fp", 0, 0x00, 0x00, 5, 0, 0, 0x00}, - {0x1023, "pto.tfillpad", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1024, "pto.tfillpad_expand", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1025, "pto.tfillpad_inplace", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1026, "pto.tfmod", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1027, "pto.tfmods", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1028, "pto.tgather", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1029, "pto.tgatherb", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x102A, "pto.tgemv", 1, 0x00, 0x01, 0, 0, 0, 0x00}, - {0x102B, "pto.tgetval", 0, 0x01, 0x00, 2, 1, 0, 0x00}, - {0x102C, "pto.timg2col", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x102D, "pto.tinsert", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x102E, "pto.tinsert_fp", 0, 0x00, 0x00, 5, 0, 0, 0x00}, - {0x102F, "pto.tload", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1030, "pto.tlog", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1031, "pto.tlrelu", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1032, "pto.tmatmul", 1, 0x00, 0x01, 0, 0, 0, 0x00}, - {0x1033, "pto.tmatmul.mx", 1, 0x00, 0x01, 0, 0, 0, 0x00}, - {0x1034, "pto.tmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1035, "pto.tmaxs", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - 
{0x1036, "pto.tmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1037, "pto.tmins", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1038, "pto.tmov", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1039, "pto.tmov.fp", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x103A, "pto.tmrgsort", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x103B, "pto.tmul", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x103C, "pto.tmuls", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x103D, "pto.tneg", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x103E, "pto.tnot", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x103F, "pto.tor", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1040, "pto.tors", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1041, "pto.tpartadd", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1042, "pto.tpartmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1043, "pto.tpartmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1044, "pto.tpartmul", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1045, "pto.tprefetch", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1046, "pto.tprelu", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x1047, "pto.tquant", 0, 0x00, 0x02, 3, 0, 0, 0x00}, - {0x1048, "pto.trecip", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1049, "pto.trelu", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x104A, "pto.trem", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x104B, "pto.trems", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x104C, "pto.treshape", 0, 0x01, 0x00, 1, 1, 0, 0x00}, - {0x104D, "pto.trowexpand", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x104E, "pto.trowexpandadd", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x104F, "pto.trowexpandexpdif", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1050, "pto.trowexpandmax", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1051, "pto.trowexpandmin", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1052, "pto.trowmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1053, "pto.trowmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1054, "pto.trowsum", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1055, "pto.trsqrt", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1056, "pto.tscatter", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1057, "pto.tsel", 0, 0x00, 0x00, 5, 0, 0, 0x00}, - {0x1058, "pto.tsels", 0, 0x00, 0x00, 5, 0, 0, 0x00}, - 
{0x1059, "pto.tset_img2col_padding", 0, 0x00, 0x00, 1, 0, 0, 0x00}, - {0x105A, "pto.tset_img2col_rpt", 0, 0x00, 0x00, 1, 0, 0, 0x00}, - {0x105B, "pto.tsetfmatrix", 0, 0x00, 0x00, 1, 0, 0, 0x00}, - {0x105C, "pto.tsethf32mode", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x105D, "pto.tsettf32mode", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x105E, "pto.tsetval", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x105F, "pto.tshl", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1060, "pto.tshls", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1061, "pto.tshr", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1062, "pto.tshrs", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1063, "pto.tsort32", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1064, "pto.tsqrt", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1065, "pto.tstore", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1066, "pto.tstore_fp", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1067, "pto.tsub", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1068, "pto.tsubc", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x1069, "pto.tsubs", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x106A, "pto.tsubsc", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x106B, "pto.trowexpandsub", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x106C, "pto.ttrans", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x106D, "pto.ttri", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x106E, "pto.txor", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x106F, "pto.txors", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x1070, "pto.wait_event", 0, 0x00, 0x00, 0, 0, 0, 0x02}, - {0x1071, "pto.tprint", 0, 0x00, 0x00, 1, 0, 0, 0x00}, - {0x1072, "pto.subview", 0, 0x01, 0x02, 0, 1, 0, 0x00}, - {0x1073, "pto.trowexpanddiv", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1074, "pto.trowexpandmul", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1075, "pto.tdequant", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x1076, "pto.taxpy", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1077, "pto.thistogram", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1078, "pto.tget_scale_addr", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1079, "pto.trowargmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x107A, "pto.trowargmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x107B, 
"pto.tcolargmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x107C, "pto.tcolargmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x107D, "pto.tsync", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x107E, "pto.reserve_buffer", 0, 0x01, 0x00, 0, 1, 0, 0x00}, - {0x107F, "pto.import_reserved_buffer", 0, 0x01, 0x00, 0, 1, 0, 0x00}, - {0x1080, "pto.aic_initialize_pipe", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1081, "pto.aiv_initialize_pipe", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1082, "pto.tpush_to_aiv", 0, 0x00, 0x00, 1, 0, 0, 0x00}, - {0x1083, "pto.tpush_to_aic", 0, 0x00, 0x00, 1, 0, 0, 0x00}, - {0x1084, "pto.tpop_from_aic", 0, 0x01, 0x02, 0, 1, 0, 0x00}, - {0x1085, "pto.tpop_from_aiv", 0, 0x01, 0x02, 0, 1, 0, 0x00}, - {0x1086, "pto.tfree_from_aic", 0, 0x00, 0x00, 0, 0, 0, 0x00}, - {0x1087, "pto.tfree_from_aiv", 0, 0x00, 0x00, 0, 0, 0, 0x00}, - {0x1088, "pto.set_validshape", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1089, "pto.tconcat", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x108A, "pto.trowprod", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x108B, "pto.initialize_l2g2l_pipe", 0, 0x01, 0x02, 0, 1, 0, 0x00}, - {0x108C, "pto.initialize_l2l_pipe", 0, 0x01, 0x02, 0, 1, 0, 0x00}, - {0x108D, "pto.tpush", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x108E, "pto.declare_tile", 0, 0x01, 0x00, 0, 1, 0, 0x00}, - {0x108F, "pto.tpop", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1090, "pto.tfree", 0, 0x00, 0x00, 1, 0, 0, 0x00}, - {0x1091, "pto.comm.tput", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1092, "pto.comm.tget", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1093, "pto.comm.tnotify", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1094, "pto.comm.twait", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1095, "pto.comm.ttest", 0, 0x01, 0x02, 0, 1, 0, 0x00}, - {0x1096, "pto.comm.tbroadcast", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1097, "pto.comm.tgather", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1098, "pto.comm.tscatter", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1099, "pto.comm.treduce", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x109A, "pto.tpartargmax", 0, 0x00, 0x00, 6, 0, 0, 0x00}, - {0x109B, 
"pto.tpartargmin", 0, 0x00, 0x00, 6, 0, 0, 0x00}, - {0x109C, "pto.tscatter.maskpattern", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x2000, "arith.addi", 0, 0x01, 0x00, 2, 1, 0, 0x00}, - {0x2001, "arith.ceildivsi", 0, 0x01, 0x00, 2, 1, 0, 0x00}, - {0x2002, "arith.cmpi", 0, 0x01, 0x00, 2, 1, 0, 0x01}, - {0x2003, "arith.constant", 0, 0x01, 0x00, 0, 1, 0, 0x05}, - {0x2004, "arith.index_cast", 0, 0x01, 0x00, 1, 1, 0, 0x00}, - {0x2005, "arith.minui", 0, 0x01, 0x00, 2, 1, 0, 0x00}, - {0x2006, "arith.muli", 0, 0x01, 0x00, 2, 1, 0, 0x00}, - {0x2007, "arith.select", 0, 0x01, 0x00, 3, 1, 0, 0x00}, - {0x2008, "arith.subi", 0, 0x01, 0x00, 2, 1, 0, 0x00}, - {0x4000, "scf.for", 0, 0x00, 0x00, 3, 0, 1, 0x00}, - {0x4001, "scf.if", 0, 0x00, 0x00, 1, 0, 2, 0x00}, - {0x4002, "scf.yield", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x6000, "func.func", 0, 0x00, 0x00, 0, 0, 0, 0x00}, - {0x6001, "func.return", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x6002, "func.call", 0, 0x02, 0x02, 0, 0, 0, 0x00}, -}; - -inline const OpInfo *lookupByOpcode(uint16_t opcode) { - // Binary search on kOpTable (sorted by opcode). 
- size_t lo = 0, hi = sizeof(kOpTable) / sizeof(kOpTable[0]); - while (lo < hi) { - size_t mid = lo + (hi - lo) / 2; - uint16_t v = kOpTable[mid].opcode; - if (v == opcode) return &kOpTable[mid]; - if (v < opcode) lo = mid + 1; else hi = mid; - } - return nullptr; -} - -inline std::optional lookupOpcodeByName(llvm::StringRef name) { - uint16_t v = llvm::StringSwitch(name) - .Case("arith.addi", 0x2000) - .Case("arith.ceildivsi", 0x2001) - .Case("arith.cmpi", 0x2002) - .Case("arith.constant", 0x2003) - .Case("arith.index_cast", 0x2004) - .Case("arith.minui", 0x2005) - .Case("arith.muli", 0x2006) - .Case("arith.select", 0x2007) - .Case("arith.subi", 0x2008) - .Case("func.func", 0x6000) - .Case("func.return", 0x6001) - .Case("func.call", 0x6002) - .Case("pto.addptr", 0x1000) - .Case("pto.alloc_tile", 0x1001) - .Case("pto.barrier", 0x1002) - .Case("pto.get_block_idx", 0x0000) - .Case("pto.get_block_num", 0x0001) - .Case("pto.get_subblock_idx", 0x0002) - .Case("pto.get_subblock_num", 0x0003) - .Case("pto.make_tensor_view", 0x0004) - .Case("pto.mgather", 0x1003) - .Case("pto.mscatter", 0x1004) - .Case("pto.partition_view", 0x0005) - .Case("pto.record_event", 0x1005) - .Case("pto.section", 0x0006) - .Case("pto.tabs", 0x1006) - .Case("pto.tadd", 0x1007) - .Case("pto.taddc", 0x1008) - .Case("pto.tadds", 0x1009) - .Case("pto.taddsc", 0x100A) - .Case("pto.tand", 0x100B) - .Case("pto.tands", 0x100C) - .Case("pto.tci", 0x100D) - .Case("pto.tcmp", 0x100E) - .Case("pto.tcmps", 0x100F) - .Case("pto.tcolexpand", 0x1010) - .Case("pto.tcolexpandadd", 0x1011) - .Case("pto.tcolexpanddiv", 0x1012) - .Case("pto.tcolexpandexpdif", 0x1013) - .Case("pto.tcolexpandmax", 0x1014) - .Case("pto.tcolexpandmin", 0x1015) - .Case("pto.tcolexpandmul", 0x1016) - .Case("pto.tcolexpandsub", 0x1017) - .Case("pto.tcolmax", 0x1018) - .Case("pto.tcolmin", 0x1019) - .Case("pto.tcolprod", 0x101A) - .Case("pto.tcolsum", 0x101B) - .Case("pto.tcvt", 0x101C) - .Case("pto.tdiv", 0x101D) - .Case("pto.tdivs", 0x101E) 
- .Case("pto.texp", 0x101F) - .Case("pto.texpands", 0x1020) - .Case("pto.textract", 0x1021) - .Case("pto.textract_fp", 0x1022) - .Case("pto.tfillpad", 0x1023) - .Case("pto.tfillpad_expand", 0x1024) - .Case("pto.tfillpad_inplace", 0x1025) - .Case("pto.tfmod", 0x1026) - .Case("pto.tfmods", 0x1027) - .Case("pto.tgather", 0x1028) - .Case("pto.tgatherb", 0x1029) - .Case("pto.tgemv", 0x102A) - .Case("pto.tgetval", 0x102B) - .Case("pto.timg2col", 0x102C) - .Case("pto.tinsert", 0x102D) - .Case("pto.tinsert_fp", 0x102E) - .Case("pto.tload", 0x102F) - .Case("pto.tlog", 0x1030) - .Case("pto.tlrelu", 0x1031) - .Case("pto.tmatmul", 0x1032) - .Case("pto.tmatmul.mx", 0x1033) - .Case("pto.tmax", 0x1034) - .Case("pto.tmaxs", 0x1035) - .Case("pto.tmin", 0x1036) - .Case("pto.tmins", 0x1037) - .Case("pto.tmov", 0x1038) - .Case("pto.tmov.fp", 0x1039) - .Case("pto.tmrgsort", 0x103A) - .Case("pto.tmul", 0x103B) - .Case("pto.tmuls", 0x103C) - .Case("pto.tneg", 0x103D) - .Case("pto.tnot", 0x103E) - .Case("pto.tor", 0x103F) - .Case("pto.tors", 0x1040) - .Case("pto.tpartadd", 0x1041) - .Case("pto.tpartmax", 0x1042) - .Case("pto.tpartmin", 0x1043) - .Case("pto.tpartmul", 0x1044) - .Case("pto.tprefetch", 0x1045) - .Case("pto.tprelu", 0x1046) - .Case("pto.tquant", 0x1047) - .Case("pto.trecip", 0x1048) - .Case("pto.trelu", 0x1049) - .Case("pto.trem", 0x104A) - .Case("pto.trems", 0x104B) - .Case("pto.treshape", 0x104C) - .Case("pto.trowexpand", 0x104D) - .Case("pto.trowexpandadd", 0x104E) - .Case("pto.trowexpandexpdif", 0x104F) - .Case("pto.trowexpandmax", 0x1050) - .Case("pto.trowexpandmin", 0x1051) - .Case("pto.trowmax", 0x1052) - .Case("pto.trowmin", 0x1053) - .Case("pto.trowsum", 0x1054) - .Case("pto.trsqrt", 0x1055) - .Case("pto.tscatter", 0x1056) - .Case("pto.tsel", 0x1057) - .Case("pto.tsels", 0x1058) - .Case("pto.tset_img2col_padding", 0x1059) - .Case("pto.tset_img2col_rpt", 0x105A) - .Case("pto.tsetfmatrix", 0x105B) - .Case("pto.tsethf32mode", 0x105C) - .Case("pto.tsettf32mode", 0x105D) 
- .Case("pto.tsetval", 0x105E) - .Case("pto.tshl", 0x105F) - .Case("pto.tshls", 0x1060) - .Case("pto.tshr", 0x1061) - .Case("pto.tshrs", 0x1062) - .Case("pto.tsort32", 0x1063) - .Case("pto.tsqrt", 0x1064) - .Case("pto.tstore", 0x1065) - .Case("pto.tstore_fp", 0x1066) - .Case("pto.tsub", 0x1067) - .Case("pto.tsubc", 0x1068) - .Case("pto.tsubs", 0x1069) - .Case("pto.tsubsc", 0x106A) - .Case("pto.trowexpandsub", 0x106B) - .Case("pto.ttrans", 0x106C) - .Case("pto.ttri", 0x106D) - .Case("pto.txor", 0x106E) - .Case("pto.txors", 0x106F) - .Case("pto.wait_event", 0x1070) - .Case("pto.tprint", 0x1071) - .Case("pto.subview", 0x1072) - .Case("pto.trowexpanddiv", 0x1073) - .Case("pto.trowexpandmul", 0x1074) - .Case("pto.tdequant", 0x1075) - .Case("pto.taxpy", 0x1076) - .Case("pto.thistogram", 0x1077) - .Case("pto.tget_scale_addr", 0x1078) - .Case("pto.trowargmax", 0x1079) - .Case("pto.trowargmin", 0x107A) - .Case("pto.tcolargmax", 0x107B) - .Case("pto.tcolargmin", 0x107C) - .Case("pto.tsync", 0x107D) - .Case("pto.reserve_buffer", 0x107E) - .Case("pto.import_reserved_buffer", 0x107F) - .Case("pto.aic_initialize_pipe", 0x1080) - .Case("pto.aiv_initialize_pipe", 0x1081) - .Case("pto.tpush_to_aiv", 0x1082) - .Case("pto.tpush_to_aic", 0x1083) - .Case("pto.tpop_from_aic", 0x1084) - .Case("pto.tpop_from_aiv", 0x1085) - .Case("pto.tfree_from_aic", 0x1086) - .Case("pto.tfree_from_aiv", 0x1087) - .Case("pto.set_validshape", 0x1088) - .Case("pto.tconcat", 0x1089) - .Case("pto.trowprod", 0x108A) - .Case("pto.initialize_l2g2l_pipe", 0x108B) - .Case("pto.initialize_l2l_pipe", 0x108C) - .Case("pto.tpush", 0x108D) - .Case("pto.declare_tile", 0x108E) - .Case("pto.tpop", 0x108F) - .Case("pto.tfree", 0x1090) - .Case("pto.comm.tput", 0x1091) - .Case("pto.comm.tget", 0x1092) - .Case("pto.comm.tnotify", 0x1093) - .Case("pto.comm.twait", 0x1094) - .Case("pto.comm.ttest", 0x1095) - .Case("pto.comm.tbroadcast", 0x1096) - .Case("pto.comm.tgather", 0x1097) - .Case("pto.comm.tscatter", 0x1098) - 
.Case("pto.comm.treduce", 0x1099) - .Case("pto.tpartargmax", 0x109A) - .Case("pto.tpartargmin", 0x109B) - .Case("scf.for", 0x4000) - .Case("scf.if", 0x4001) - .Case("scf.yield", 0x4002) - .Default(0xFFFF); - if (v == 0xFFFF) return std::nullopt; - return v; -} +extern const OpInfo kOpTable[]; -inline const OpInfo *lookupByName(llvm::StringRef name) { - auto o = lookupOpcodeByName(name); - if (!o) return nullptr; - return lookupByOpcode(*o); -} +const OpInfo *lookupByOpcode(uint16_t opcode); +std::optional lookupOpcodeByName(llvm::StringRef name); +const OpInfo *lookupByName(llvm::StringRef name); struct OpcodeAndVariant { uint16_t opcode; uint8_t hasVariant; uint8_t variant; }; -inline std::optional lookupOpcodeAndVariantByFullName(llvm::StringRef fullName) { - // For non-family ops, variant is 0. For family ops, variant is the assigned u8. - // NOTE: `pto.section` is not a real op name; use `pto.section.cube`/`pto.section.vector`. - return llvm::StringSwitch>(fullName) - .Case("arith.addi", OpcodeAndVariant{0x2000, 0, 0}) - .Case("arith.ceildivsi", OpcodeAndVariant{0x2001, 0, 0}) - .Case("arith.cmpi", OpcodeAndVariant{0x2002, 0, 0}) - .Case("arith.constant", OpcodeAndVariant{0x2003, 0, 0}) - .Case("arith.index_cast", OpcodeAndVariant{0x2004, 0, 0}) - .Case("arith.minui", OpcodeAndVariant{0x2005, 0, 0}) - .Case("arith.muli", OpcodeAndVariant{0x2006, 0, 0}) - .Case("arith.select", OpcodeAndVariant{0x2007, 0, 0}) - .Case("arith.subi", OpcodeAndVariant{0x2008, 0, 0}) - .Case("func.func", OpcodeAndVariant{0x6000, 0, 0}) - .Case("func.return", OpcodeAndVariant{0x6001, 0, 0}) - .Case("func.call", OpcodeAndVariant{0x6002, 0, 0}) - .Case("pto.addptr", OpcodeAndVariant{0x1000, 0, 0}) - .Case("pto.alloc_tile", OpcodeAndVariant{0x1001, 0, 0}) - .Case("pto.barrier", OpcodeAndVariant{0x1002, 0, 0}) - .Case("pto.get_block_idx", OpcodeAndVariant{0x0000, 0, 0}) - .Case("pto.get_block_num", OpcodeAndVariant{0x0001, 0, 0}) - .Case("pto.get_subblock_idx", OpcodeAndVariant{0x0002, 0, 
0}) - .Case("pto.get_subblock_num", OpcodeAndVariant{0x0003, 0, 0}) - .Case("pto.make_tensor_view", OpcodeAndVariant{0x0004, 0, 0}) - .Case("pto.mgather", OpcodeAndVariant{0x1003, 0, 0}) - .Case("pto.mscatter", OpcodeAndVariant{0x1004, 0, 0}) - .Case("pto.partition_view", OpcodeAndVariant{0x0005, 0, 0}) - .Case("pto.record_event", OpcodeAndVariant{0x1005, 0, 0}) - .Case("pto.tabs", OpcodeAndVariant{0x1006, 0, 0}) - .Case("pto.tadd", OpcodeAndVariant{0x1007, 0, 0}) - .Case("pto.taddc", OpcodeAndVariant{0x1008, 0, 0}) - .Case("pto.tadds", OpcodeAndVariant{0x1009, 0, 0}) - .Case("pto.taddsc", OpcodeAndVariant{0x100A, 0, 0}) - .Case("pto.tand", OpcodeAndVariant{0x100B, 0, 0}) - .Case("pto.tands", OpcodeAndVariant{0x100C, 0, 0}) - .Case("pto.tci", OpcodeAndVariant{0x100D, 0, 0}) - .Case("pto.tcmp", OpcodeAndVariant{0x100E, 0, 0}) - .Case("pto.tcmps", OpcodeAndVariant{0x100F, 0, 0}) - .Case("pto.tcolexpand", OpcodeAndVariant{0x1010, 0, 0}) - .Case("pto.tcolexpandadd", OpcodeAndVariant{0x1011, 0, 0}) - .Case("pto.tcolexpanddiv", OpcodeAndVariant{0x1012, 0, 0}) - .Case("pto.tcolexpandexpdif", OpcodeAndVariant{0x1013, 0, 0}) - .Case("pto.tcolexpandmax", OpcodeAndVariant{0x1014, 0, 0}) - .Case("pto.tcolexpandmin", OpcodeAndVariant{0x1015, 0, 0}) - .Case("pto.tcolexpandmul", OpcodeAndVariant{0x1016, 0, 0}) - .Case("pto.tcolexpandsub", OpcodeAndVariant{0x1017, 0, 0}) - .Case("pto.tcolmax", OpcodeAndVariant{0x1018, 0, 0}) - .Case("pto.tcolmin", OpcodeAndVariant{0x1019, 0, 0}) - .Case("pto.tcolprod", OpcodeAndVariant{0x101A, 0, 0}) - .Case("pto.tcolsum", OpcodeAndVariant{0x101B, 0, 0}) - .Case("pto.tcvt", OpcodeAndVariant{0x101C, 0, 0}) - .Case("pto.tdiv", OpcodeAndVariant{0x101D, 0, 0}) - .Case("pto.tdivs", OpcodeAndVariant{0x101E, 0, 0}) - .Case("pto.texp", OpcodeAndVariant{0x101F, 0, 0}) - .Case("pto.texpands", OpcodeAndVariant{0x1020, 0, 0}) - .Case("pto.textract", OpcodeAndVariant{0x1021, 0, 0}) - .Case("pto.textract_fp", OpcodeAndVariant{0x1022, 0, 0}) - 
.Case("pto.tfillpad", OpcodeAndVariant{0x1023, 0, 0}) - .Case("pto.tfillpad_expand", OpcodeAndVariant{0x1024, 0, 0}) - .Case("pto.tfillpad_inplace", OpcodeAndVariant{0x1025, 0, 0}) - .Case("pto.tfmod", OpcodeAndVariant{0x1026, 0, 0}) - .Case("pto.tfmods", OpcodeAndVariant{0x1027, 0, 0}) - .Case("pto.tgather", OpcodeAndVariant{0x1028, 0, 0}) - .Case("pto.tgatherb", OpcodeAndVariant{0x1029, 0, 0}) - .Case("pto.tgetval", OpcodeAndVariant{0x102B, 0, 0}) - .Case("pto.timg2col", OpcodeAndVariant{0x102C, 0, 0}) - .Case("pto.tinsert", OpcodeAndVariant{0x102D, 0, 0}) - .Case("pto.tinsert_fp", OpcodeAndVariant{0x102E, 0, 0}) - .Case("pto.tload", OpcodeAndVariant{0x102F, 0, 0}) - .Case("pto.tlog", OpcodeAndVariant{0x1030, 0, 0}) - .Case("pto.tlrelu", OpcodeAndVariant{0x1031, 0, 0}) - .Case("pto.tmax", OpcodeAndVariant{0x1034, 0, 0}) - .Case("pto.tmaxs", OpcodeAndVariant{0x1035, 0, 0}) - .Case("pto.tmin", OpcodeAndVariant{0x1036, 0, 0}) - .Case("pto.tmins", OpcodeAndVariant{0x1037, 0, 0}) - .Case("pto.tmov", OpcodeAndVariant{0x1038, 0, 0}) - .Case("pto.tmov.fp", OpcodeAndVariant{0x1039, 0, 0}) - .Case("pto.tmrgsort", OpcodeAndVariant{0x103A, 0, 0}) - .Case("pto.tmul", OpcodeAndVariant{0x103B, 0, 0}) - .Case("pto.tmuls", OpcodeAndVariant{0x103C, 0, 0}) - .Case("pto.tneg", OpcodeAndVariant{0x103D, 0, 0}) - .Case("pto.tnot", OpcodeAndVariant{0x103E, 0, 0}) - .Case("pto.tor", OpcodeAndVariant{0x103F, 0, 0}) - .Case("pto.tors", OpcodeAndVariant{0x1040, 0, 0}) - .Case("pto.tpartadd", OpcodeAndVariant{0x1041, 0, 0}) - .Case("pto.tpartmax", OpcodeAndVariant{0x1042, 0, 0}) - .Case("pto.tpartmin", OpcodeAndVariant{0x1043, 0, 0}) - .Case("pto.tpartmul", OpcodeAndVariant{0x1044, 0, 0}) - .Case("pto.tprefetch", OpcodeAndVariant{0x1045, 0, 0}) - .Case("pto.tprelu", OpcodeAndVariant{0x1046, 0, 0}) - .Case("pto.tquant", OpcodeAndVariant{0x1047, 0, 0}) - .Case("pto.trecip", OpcodeAndVariant{0x1048, 0, 0}) - .Case("pto.trelu", OpcodeAndVariant{0x1049, 0, 0}) - .Case("pto.trem", 
OpcodeAndVariant{0x104A, 0, 0}) - .Case("pto.trems", OpcodeAndVariant{0x104B, 0, 0}) - .Case("pto.treshape", OpcodeAndVariant{0x104C, 0, 0}) - .Case("pto.trowexpand", OpcodeAndVariant{0x104D, 0, 0}) - .Case("pto.trowexpandadd", OpcodeAndVariant{0x104E, 0, 0}) - .Case("pto.trowexpandexpdif", OpcodeAndVariant{0x104F, 0, 0}) - .Case("pto.trowexpandmax", OpcodeAndVariant{0x1050, 0, 0}) - .Case("pto.trowexpandmin", OpcodeAndVariant{0x1051, 0, 0}) - .Case("pto.trowmax", OpcodeAndVariant{0x1052, 0, 0}) - .Case("pto.trowmin", OpcodeAndVariant{0x1053, 0, 0}) - .Case("pto.trowsum", OpcodeAndVariant{0x1054, 0, 0}) - .Case("pto.trsqrt", OpcodeAndVariant{0x1055, 0, 0}) - .Case("pto.tscatter", OpcodeAndVariant{0x1056, 0, 0}) - .Case("pto.tsel", OpcodeAndVariant{0x1057, 0, 0}) - .Case("pto.tsels", OpcodeAndVariant{0x1058, 0, 0}) - .Case("pto.tset_img2col_padding", OpcodeAndVariant{0x1059, 0, 0}) - .Case("pto.tset_img2col_rpt", OpcodeAndVariant{0x105A, 0, 0}) - .Case("pto.tsetfmatrix", OpcodeAndVariant{0x105B, 0, 0}) - .Case("pto.tsethf32mode", OpcodeAndVariant{0x105C, 0, 0}) - .Case("pto.tsettf32mode", OpcodeAndVariant{0x105D, 0, 0}) - .Case("pto.tsetval", OpcodeAndVariant{0x105E, 0, 0}) - .Case("pto.tshl", OpcodeAndVariant{0x105F, 0, 0}) - .Case("pto.tshls", OpcodeAndVariant{0x1060, 0, 0}) - .Case("pto.tshr", OpcodeAndVariant{0x1061, 0, 0}) - .Case("pto.tshrs", OpcodeAndVariant{0x1062, 0, 0}) - .Case("pto.tsort32", OpcodeAndVariant{0x1063, 0, 0}) - .Case("pto.tsqrt", OpcodeAndVariant{0x1064, 0, 0}) - .Case("pto.tstore", OpcodeAndVariant{0x1065, 0, 0}) - .Case("pto.tstore_fp", OpcodeAndVariant{0x1066, 0, 0}) - .Case("pto.tsub", OpcodeAndVariant{0x1067, 0, 0}) - .Case("pto.tsubc", OpcodeAndVariant{0x1068, 0, 0}) - .Case("pto.tsubs", OpcodeAndVariant{0x1069, 0, 0}) - .Case("pto.tsubsc", OpcodeAndVariant{0x106A, 0, 0}) - .Case("pto.trowexpandsub", OpcodeAndVariant{0x106B, 0, 0}) - .Case("pto.ttrans", OpcodeAndVariant{0x106C, 0, 0}) - .Case("pto.ttri", OpcodeAndVariant{0x106D, 0, 0}) 
- .Case("pto.txor", OpcodeAndVariant{0x106E, 0, 0}) - .Case("pto.txors", OpcodeAndVariant{0x106F, 0, 0}) - .Case("pto.wait_event", OpcodeAndVariant{0x1070, 0, 0}) - .Case("pto.tprint", OpcodeAndVariant{0x1071, 0, 0}) - .Case("pto.subview", OpcodeAndVariant{0x1072, 0, 0}) - .Case("pto.trowexpanddiv", OpcodeAndVariant{0x1073, 0, 0}) - .Case("pto.trowexpandmul", OpcodeAndVariant{0x1074, 0, 0}) - .Case("pto.tdequant", OpcodeAndVariant{0x1075, 0, 0}) - .Case("pto.taxpy", OpcodeAndVariant{0x1076, 0, 0}) - .Case("pto.thistogram", OpcodeAndVariant{0x1077, 0, 0}) - .Case("pto.tget_scale_addr", OpcodeAndVariant{0x1078, 0, 0}) - .Case("pto.trowargmax", OpcodeAndVariant{0x1079, 0, 0}) - .Case("pto.trowargmin", OpcodeAndVariant{0x107A, 0, 0}) - .Case("pto.tcolargmax", OpcodeAndVariant{0x107B, 0, 0}) - .Case("pto.tcolargmin", OpcodeAndVariant{0x107C, 0, 0}) - .Case("pto.tsync", OpcodeAndVariant{0x107D, 0, 0}) - .Case("pto.reserve_buffer", OpcodeAndVariant{0x107E, 0, 0}) - .Case("pto.import_reserved_buffer", OpcodeAndVariant{0x107F, 0, 0}) - .Case("pto.aic_initialize_pipe", OpcodeAndVariant{0x1080, 0, 0}) - .Case("pto.aiv_initialize_pipe", OpcodeAndVariant{0x1081, 0, 0}) - .Case("pto.tpush_to_aiv", OpcodeAndVariant{0x1082, 0, 0}) - .Case("pto.tpush_to_aic", OpcodeAndVariant{0x1083, 0, 0}) - .Case("pto.tpop_from_aic", OpcodeAndVariant{0x1084, 0, 0}) - .Case("pto.tpop_from_aiv", OpcodeAndVariant{0x1085, 0, 0}) - .Case("pto.tfree_from_aic", OpcodeAndVariant{0x1086, 0, 0}) - .Case("pto.tfree_from_aiv", OpcodeAndVariant{0x1087, 0, 0}) - .Case("pto.set_validshape", OpcodeAndVariant{0x1088, 0, 0}) - .Case("pto.tconcat", OpcodeAndVariant{0x1089, 0, 0}) - .Case("pto.trowprod", OpcodeAndVariant{0x108A, 0, 0}) - .Case("pto.initialize_l2g2l_pipe", OpcodeAndVariant{0x108B, 0, 0}) - .Case("pto.initialize_l2l_pipe", OpcodeAndVariant{0x108C, 0, 0}) - .Case("pto.tpush", OpcodeAndVariant{0x108D, 0, 0}) - .Case("pto.declare_tile", OpcodeAndVariant{0x108E, 0, 0}) - .Case("pto.tpop", 
OpcodeAndVariant{0x108F, 0, 0}) - .Case("pto.tfree", OpcodeAndVariant{0x1090, 0, 0}) - .Case("pto.comm.tput", OpcodeAndVariant{0x1091, 0, 0}) - .Case("pto.comm.tget", OpcodeAndVariant{0x1092, 0, 0}) - .Case("pto.comm.tnotify", OpcodeAndVariant{0x1093, 0, 0}) - .Case("pto.comm.twait", OpcodeAndVariant{0x1094, 0, 0}) - .Case("pto.comm.ttest", OpcodeAndVariant{0x1095, 0, 0}) - .Case("pto.comm.tbroadcast", OpcodeAndVariant{0x1096, 0, 0}) - .Case("pto.comm.tgather", OpcodeAndVariant{0x1097, 0, 0}) - .Case("pto.comm.tscatter", OpcodeAndVariant{0x1098, 0, 0}) - .Case("pto.comm.treduce", OpcodeAndVariant{0x1099, 0, 0}) - .Case("pto.tpartargmax", OpcodeAndVariant{0x109A, 0, 0}) - .Case("pto.tpartargmin", OpcodeAndVariant{0x109B, 0, 0}) - .Case("scf.for", OpcodeAndVariant{0x4000, 0, 0}) - .Case("scf.if", OpcodeAndVariant{0x4001, 0, 0}) - .Case("scf.yield", OpcodeAndVariant{0x4002, 0, 0}) - .Case("pto.section.cube", - OpcodeAndVariant{0x0006, kHasVariant, kSectionCubeVariant}) - .Case("pto.section.vector", - OpcodeAndVariant{0x0006, kHasVariant, kSectionVectorVariant}) - .Case("pto.tgemv", - OpcodeAndVariant{0x102A, kHasVariant, kVariantDefault}) - .Case("pto.tgemv.acc", - OpcodeAndVariant{0x102A, kHasVariant, kVariantAcc}) - .Case("pto.tgemv.bias", - OpcodeAndVariant{0x102A, kHasVariant, kVariantBias}) - .Case("pto.tgemv.mx", - OpcodeAndVariant{0x102A, kHasVariant, kVariantMx}) - .Case("pto.tgemv.mx.acc", - OpcodeAndVariant{0x102A, kHasVariant, kVariantMxAcc}) - .Case("pto.tgemv.mx.bias", - OpcodeAndVariant{0x102A, kHasVariant, kVariantMxBias}) - .Case("pto.tmatmul", - OpcodeAndVariant{0x1032, kHasVariant, kVariantDefault}) - .Case("pto.tmatmul.acc", - OpcodeAndVariant{0x1032, kHasVariant, kVariantAcc}) - .Case("pto.tmatmul.bias", - OpcodeAndVariant{0x1032, kHasVariant, kVariantBias}) - .Case("pto.tmatmul.mx", - OpcodeAndVariant{0x1033, kHasVariant, kVariantDefault}) - .Case("pto.tmatmul.mx.acc", - OpcodeAndVariant{0x1033, kHasVariant, kVariantAcc}) - 
.Case("pto.tmatmul.mx.bias", - OpcodeAndVariant{0x1033, kHasVariant, kVariantBias}) - .Default(std::nullopt); -} - -inline const char *fullNameFromOpcodeVariant(uint16_t opcode, uint8_t variant) { - const OpInfo *info = lookupByOpcode(opcode); - if (!info) return nullptr; - if (opcode == kTscatterMaskOpcode) return "pto.tscatter"; - if (!info->has_variant_u8) return info->name; - switch (opcode) { - case 0x0006: - switch (variant) { - case kSectionCubeVariant: return "pto.section.cube"; - case kSectionVectorVariant: return "pto.section.vector"; - default: return info->name; - } - case 0x102A: - switch (variant) { - case kVariantDefault: return "pto.tgemv"; - case kVariantAcc: return "pto.tgemv.acc"; - case kVariantBias: return "pto.tgemv.bias"; - case kVariantMx: return "pto.tgemv.mx"; - case kVariantMxAcc: return "pto.tgemv.mx.acc"; - case kVariantMxBias: return "pto.tgemv.mx.bias"; - default: return info->name; - } - case 0x1032: - switch (variant) { - case kVariantDefault: return "pto.tmatmul"; - case kVariantAcc: return "pto.tmatmul.acc"; - case kVariantBias: return "pto.tmatmul.bias"; - default: return info->name; - } - case 0x1033: - switch (variant) { - case kVariantDefault: return "pto.tmatmul.mx"; - case kVariantAcc: return "pto.tmatmul.mx.acc"; - case kVariantBias: return "pto.tmatmul.mx.bias"; - default: return info->name; - } - default: return info->name; - } -} - -inline std::optional lookupOperandsByVariant(uint16_t opcode, uint8_t variant) { - switch (opcode) { - case 0x102A: - switch (variant) { - case kVariantDefault: return kTgemvOperandCount; - case kVariantAcc: return kTgemvAccOperandCount; - case kVariantBias: return kTgemvBiasOperandCount; - case kVariantMx: return kTgemvMxOperandCount; - case kVariantMxAcc: return kTgemvMxAccOperandCount; - case kVariantMxBias: return kTgemvMxBiasOperandCount; - default: return std::nullopt; - } - case 0x1032: - switch (variant) { - case kVariantDefault: return kTmatmulOperandCount; - case kVariantAcc: return 
kTmatmulAccOperandCount; - case kVariantBias: return kTmatmulBiasOperandCount; - default: return std::nullopt; - } - case 0x1033: - switch (variant) { - case kVariantDefault: return kTmatmulMxOperandCount; - case kVariantAcc: return kTmatmulMxAccOperandCount; - case kVariantBias: return kTmatmulMxBiasOperandCount; - default: return std::nullopt; - } - default: return std::nullopt; - } -} +std::optional lookupOpcodeAndVariantByFullName(llvm::StringRef fullName); +const char *fullNameFromOpcodeVariant(uint16_t opcode, uint8_t variant); +std::optional lookupOperandsByVariant(uint16_t opcode, uint8_t variant); } // namespace ptobc::v0 From e9cc2060f6b72225713f8ccb3fd7104bea65773e Mon Sep 17 00:00:00 2001 From: HecreReed <821896444@qq.com> Date: Tue, 26 May 2026 12:33:28 +0800 Subject: [PATCH 4/8] refactor: split view lowering and emitc arith patterns --- lib/PTO/Transforms/CMakeLists.txt | 2 + lib/PTO/Transforms/PTOToEmitC.cpp | 1589 +-------------- lib/PTO/Transforms/PTOToEmitCArith.cpp | 1782 +++++++++++++++++ lib/PTO/Transforms/PTOToEmitCInternal.h | 24 + lib/PTO/Transforms/PTOViewToMemref.cpp | 1777 +--------------- lib/PTO/Transforms/PTOViewToMemrefCompute.cpp | 1760 ++++++++++++++++ lib/PTO/Transforms/PTOViewToMemrefInternal.h | 25 + 7 files changed, 3602 insertions(+), 3357 deletions(-) create mode 100644 lib/PTO/Transforms/PTOToEmitCArith.cpp create mode 100644 lib/PTO/Transforms/PTOToEmitCInternal.h create mode 100644 lib/PTO/Transforms/PTOViewToMemrefCompute.cpp create mode 100644 lib/PTO/Transforms/PTOViewToMemrefInternal.h diff --git a/lib/PTO/Transforms/CMakeLists.txt b/lib/PTO/Transforms/CMakeLists.txt index 7f0b1b850..efe728827 100644 --- a/lib/PTO/Transforms/CMakeLists.txt +++ b/lib/PTO/Transforms/CMakeLists.txt @@ -16,8 +16,10 @@ add_mlir_dialect_library(PTOTransforms PTOInjectBarrierAllSync.cpp InsertSync/InsertSyncDebug.cpp PTOViewToMemref.cpp + PTOViewToMemrefCompute.cpp PTOValidateIntToPtrUses.cpp PTOToEmitC.cpp + PTOToEmitCArith.cpp Utils.cpp 
OptMemPlanForPipeline.cpp AllocToPointerCast.cpp diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index ea9466da1..a841c889e 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -19,6 +19,7 @@ #include "PTO/IR/PTOTypeUtils.h" #include "PTO/IR/PTOSyncUtils.h" #include "PTO/Transforms/Passes.h" +#include "PTOToEmitCInternal.h" #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" @@ -968,1273 +969,7 @@ static bool hasSetFFTsOp(func::FuncOp func) { return found; } -//===----------------------------------------------------------------------===// -// Arith -> EmitC (full dialect coverage for scalar ops) -//===----------------------------------------------------------------------===// - -template -struct ArithSimpleBinaryToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type dstTy = this->getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getOperands()); - return success(); - } -}; - -// Integer bitwise ops (andi/ori/xori) on signless integers: perform in unsigned -// to avoid signedness pitfalls, then cast back. -template -struct ArithUnsignedBitwiseBinaryToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - const unsigned bitWidth = - intTy ? 
intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - Type dstTy = this->getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - if (bitWidth == 1) { - rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getLhs(), - adaptor.getRhs()); - return success(); - } - - auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value resU = rewriter.create(loc, uTy, lhsU, rhsU); - Value result = emitCCast(rewriter, loc, dstTy, resU); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithDivUIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::DivUIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - const unsigned bitWidth = - intTy ? 
intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value divU = rewriter.create(loc, uTy, lhsU, rhsU); - Value result = emitCCast(rewriter, loc, dstTy, divU); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithRemUIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::RemUIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - const unsigned bitWidth = - intTy ? 
intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value remU = rewriter.create(loc, uTy, lhsU, rhsU); - Value result = emitCCast(rewriter, loc, dstTy, remU); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithCeilDivUIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::CeilDivUIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - const unsigned bitWidth = - intTy ? 
intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value one = makeEmitCIntConstant(rewriter, loc, uTy, 1); - Value rhsMinusOne = rewriter.create(loc, uTy, rhsU, one); - Value num = rewriter.create(loc, uTy, lhsU, rhsMinusOne); - Value divU = rewriter.create(loc, uTy, num, rhsU); - Value result = emitCCast(rewriter, loc, dstTy, divU); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithCeilDivSIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::CeilDivSIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - Value zero = makeEmitCIntConstant(rewriter, loc, dstTy, 0); - Value one = makeEmitCIntConstant(rewriter, loc, dstTy, 1); - - Value q0 = rewriter.create(loc, dstTy, adaptor.getLhs(), - adaptor.getRhs()); - Value r = rewriter.create(loc, dstTy, adaptor.getLhs(), - adaptor.getRhs()); - - Value rNeZero = rewriter.create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::ne, r, - zero); - Value lhsLt0 = - rewriter.create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, adaptor.getLhs(), - zero); - Value rhsLt0 = - rewriter.create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, adaptor.getRhs(), - zero); - Value signsSame = - rewriter.create(loc, 
rewriter.getI1Type(), - emitc::CmpPredicate::eq, lhsLt0, rhsLt0); - Value adjust = - rewriter.create(loc, rewriter.getI1Type(), - rNeZero, signsSame); - - Value qPlusOne = rewriter.create(loc, dstTy, q0, one); - Value result = rewriter.create(loc, dstTy, adjust, - qPlusOne, q0); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithFloorDivSIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::FloorDivSIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - Value zero = makeEmitCIntConstant(rewriter, loc, dstTy, 0); - Value one = makeEmitCIntConstant(rewriter, loc, dstTy, 1); - - Value q0 = rewriter.create(loc, dstTy, adaptor.getLhs(), - adaptor.getRhs()); - Value r = rewriter.create(loc, dstTy, adaptor.getLhs(), - adaptor.getRhs()); - - Value rNeZero = rewriter.create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::ne, r, - zero); - Value lhsLt0 = - rewriter.create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, adaptor.getLhs(), - zero); - Value rhsLt0 = - rewriter.create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, adaptor.getRhs(), - zero); - Value signsDifferent = - rewriter.create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::ne, lhsLt0, rhsLt0); - Value adjust = - rewriter.create(loc, rewriter.getI1Type(), - rNeZero, signsDifferent); - - Value qMinusOne = rewriter.create(loc, dstTy, q0, one); - Value result = rewriter.create(loc, dstTy, adjust, - qMinusOne, q0); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithShiftLeftToEmitC : public OpConversionPattern { - using 
OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::ShLIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - const unsigned bitWidth = - intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - if (bitWidth == 1) { - // Compute on u8 and truncate to i1. - auto u8Ty = getUnsignedIntOpaqueType(rewriter.getContext(), 8); - Value lhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getLhs()); - Value rhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getRhs()); - Value sh = rewriter.create(loc, u8Ty, lhsU8, - rhsU8); - Value masked = - rewriter.create(loc, u8Ty, sh, - makeEmitCIntConstant(rewriter, loc, - u8Ty, 1)); - rewriter.replaceOp(op, emitCCast(rewriter, loc, dstTy, masked)); - return success(); - } - - auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value shU = - rewriter.create(loc, uTy, lhsU, rhsU); - Value result = emitCCast(rewriter, loc, dstTy, shU); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithShiftRightUIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::ShRUIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - 
const unsigned bitWidth = - intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - if (bitWidth == 1) { - // (x >> y) on i1 is either x (y==0) or 0 (y!=0); approximate in u8. - auto u8Ty = getUnsignedIntOpaqueType(rewriter.getContext(), 8); - Value lhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getLhs()); - Value rhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getRhs()); - Value sh = rewriter.create(loc, u8Ty, lhsU8, - rhsU8); - Value masked = - rewriter.create(loc, u8Ty, sh, - makeEmitCIntConstant(rewriter, loc, - u8Ty, 1)); - rewriter.replaceOp(op, emitCCast(rewriter, loc, dstTy, masked)); - return success(); - } - - auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value shU = - rewriter.create(loc, uTy, lhsU, rhsU); - Value result = emitCCast(rewriter, loc, dstTy, shU); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithShiftRightSIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::ShRSIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - const unsigned bitWidth = - intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - if (bitWidth == 1) { - // (x >> y) on i1 is either x (y==0) or 0 (y!=0); approximate in u8. 
- auto u8Ty = getUnsignedIntOpaqueType(rewriter.getContext(), 8); - Value lhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getLhs()); - Value rhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getRhs()); - Value sh = rewriter.create(loc, u8Ty, lhsU8, - rhsU8); - Value masked = - rewriter.create(loc, u8Ty, sh, - makeEmitCIntConstant(rewriter, loc, - u8Ty, 1)); - rewriter.replaceOp(op, emitCCast(rewriter, loc, dstTy, masked)); - return success(); - } - - // Signed arithmetic shift; cast RHS to unsigned to interpret shift amount. - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value sh = - rewriter.create(loc, dstTy, adaptor.getLhs(), - rhsU); - rewriter.replaceOp(op, sh); - return success(); - } -}; - -struct ArithNegFToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::NegFOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type dstTy = getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getOperand()); - return success(); - } -}; - -struct ArithRemFToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::RemFOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type dstTy = getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - - // Use builtin `fmod` when possible. For f16, compute in float and cast back. 
- Type callTy = dstTy; - Value lhs = adaptor.getLhs(); - Value rhs = adaptor.getRhs(); - - if (auto opFloatTy = dyn_cast(op.getType())) { - if (opFloatTy.isF16()) { - auto f32Ty = emitc::OpaqueType::get(rewriter.getContext(), "float"); - lhs = emitCCast(rewriter, loc, f32Ty, lhs); - rhs = emitCCast(rewriter, loc, f32Ty, rhs); - callTy = f32Ty; - } - } - - // Prefer `__builtin_fmod*` to avoid relying on extra headers. - llvm::StringRef callee = "__builtin_fmod"; - if (auto opFloatTy = dyn_cast(op.getType())) { - if (opFloatTy.isF32() || opFloatTy.isF16()) - callee = "__builtin_fmodf"; - else if (opFloatTy.isF64()) - callee = "__builtin_fmod"; - } - - auto call = rewriter.create( - loc, TypeRange{callTy}, callee, ValueRange{lhs, rhs}, - /*args=*/ArrayAttr{}, /*template_args=*/ArrayAttr{}); - Value result = call.getResult(0); - if (callTy != dstTy) - result = emitCCast(rewriter, loc, dstTy, result); - - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithSelectToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!op.getCondition().getType().isInteger(1)) - return rewriter.notifyMatchFailure( - op, "only scalar i1 conditions supported for arith.select"); - - Type dstTy = getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - - auto cond = - rewriter.create(op.getLoc(), dstTy, - adaptor.getCondition(), - adaptor.getTrueValue(), - adaptor.getFalseValue()); - rewriter.replaceOp(op, cond.getResult()); - return success(); - } -}; - -struct ArithExtUIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto dstIntTy = dyn_cast(op.getType()); - auto srcIntTy = 
dyn_cast(op.getIn().getType()); - if (!dstIntTy || !srcIntTy) - return rewriter.notifyMatchFailure(op, "expected scalar integer types"); - - Type dstTy = getTypeConverter()->convertType(dstIntTy); - if (!dstTy) - return failure(); - - // i1 -> iN: bool to integer already behaves as 0/1. - if (srcIntTy.getWidth() == 1) { - rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); - return success(); - } - - auto uDstTy = - getUnsignedIntOpaqueType(rewriter.getContext(), dstIntTy.getWidth()); - Value srcU = - castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getIn(), - srcIntTy.getWidth()); - Value extU = emitCCast(rewriter, loc, uDstTy, srcU); - Value result = emitCCast(rewriter, loc, dstTy, extU); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithExtSIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto dstIntTy = dyn_cast(op.getType()); - auto srcIntTy = dyn_cast(op.getIn().getType()); - if (!dstIntTy || !srcIntTy) - return rewriter.notifyMatchFailure(op, "expected scalar integer types"); - - Type dstTy = getTypeConverter()->convertType(dstIntTy); - if (!dstTy) - return failure(); - - // i1 sign-extension: 0 -> 0, 1 -> -1. 
- if (srcIntTy.getWidth() == 1) { - Value zero = makeEmitCIntConstant(rewriter, loc, dstTy, 0); - Value asInt = emitCCast(rewriter, loc, dstTy, adaptor.getIn()); - Value neg = rewriter.create(loc, dstTy, zero, asInt).getResult(); - rewriter.replaceOp(op, neg); - return success(); - } - - rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); - return success(); - } -}; - -template -struct ArithCastToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type dstTy = this->getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); - return success(); - } -}; - -struct ArithIndexCastUIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::IndexCastUIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type dstTy = getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - - // MemRef casts are handled elsewhere; for safety, fall back to emitc.cast. 
- if (isa(op.getIn().getType()) || isa(op.getType())) { - rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); - return success(); - } - - auto getBW = [](Type t) -> std::optional { - if (auto i = dyn_cast(t)) - return i.getWidth(); - if (isa(t)) - return kPTOIndexBitWidth; - return std::nullopt; - }; - - auto srcBW = getBW(op.getIn().getType()); - auto dstBW = getBW(op.getType()); - if (!srcBW || !dstBW) - return rewriter.notifyMatchFailure(op, "unsupported index_castui types"); - - if (*dstBW <= *srcBW) { - rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); - return success(); - } - - auto uSrcTy = getUnsignedIntOpaqueType(rewriter.getContext(), *srcBW); - auto uDstTy = getUnsignedIntOpaqueType(rewriter.getContext(), *dstBW); - Value srcU = emitCCast(rewriter, loc, uSrcTy, adaptor.getIn()); - Value extU = emitCCast(rewriter, loc, uDstTy, srcU); - Value result = emitCCast(rewriter, loc, dstTy, extU); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithUIToFPToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto srcIntTy = dyn_cast(op.getIn().getType()); - if (!srcIntTy) - return rewriter.notifyMatchFailure(op, "expected scalar integer input"); - - Type dstTy = getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - - // Convert via an unsigned integer type of the same width. 
- if (srcIntTy.getWidth() == 1) { - rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); - return success(); - } - Value srcU = - castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getIn(), - srcIntTy.getWidth()); - Value fp = rewriter.create(loc, dstTy, srcU).getResult(); - rewriter.replaceOp(op, fp); - return success(); - } -}; - -struct ArithFPToUIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::FPToUIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto dstIntTy = dyn_cast(op.getType()); - if (!dstIntTy) - return rewriter.notifyMatchFailure(op, "expected scalar integer result"); - - Type dstTy = getTypeConverter()->convertType(dstIntTy); - if (!dstTy) - return failure(); - - auto uDstTy = - getUnsignedIntOpaqueType(rewriter.getContext(), dstIntTy.getWidth()); - Value asU = rewriter.create(loc, uDstTy, adaptor.getIn()).getResult(); - Value result = emitCCast(rewriter, loc, dstTy, asU); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithBitcastToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type dstTy = getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - - // For pointer-like types, a regular cast is fine. - if (isa(dstTy)) { - rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); - return success(); - } - - // Only support scalar int/float/index bitcasts here. 
- auto srcTy = op.getIn().getType(); - auto dstOrigTy = op.getType(); - - auto getBitWidth = [](Type t) -> std::optional { - if (auto it = dyn_cast(t)) - return it.getWidth(); - if (auto ft = dyn_cast(t)) - return ft.getWidth(); - if (isa(t)) - return kPTOIndexBitWidth; - return std::nullopt; - }; - auto srcBW = getBitWidth(srcTy); - auto dstBW = getBitWidth(dstOrigTy); - if (!srcBW || !dstBW || *srcBW != *dstBW) - return rewriter.notifyMatchFailure(op, "bitcast requires equal bitwidth"); - - // Determine the template argument from the destination type string. - auto dstOpaque = dyn_cast(dstTy); - if (!dstOpaque) - return rewriter.notifyMatchFailure(op, "expected emitc opaque dest type"); - - auto templateArgs = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(rewriter.getContext(), - dstOpaque.getValue())}); - auto call = rewriter.create( - loc, TypeRange{dstTy}, "ptoas_bitcast", /*operands=*/ValueRange{adaptor.getIn()}, - /*args=*/ArrayAttr{}, /*template_args=*/templateArgs); - rewriter.replaceOp(op, call.getResult(0)); - return success(); - } -}; - -// arith.cmpf lowering with ordered/unordered semantics. 
-struct ArithCmpFToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - struct CmpFConfig { - bool unordered = false; - emitc::CmpPredicate predicate = emitc::CmpPredicate::eq; - }; - - static Value isNaN(ConversionPatternRewriter &rewriter, Location loc, - Value v) { - return rewriter - .create(loc, rewriter.getI1Type(), emitc::CmpPredicate::ne, - v, v) - .getResult(); - } - - static Value isNotNaN(ConversionPatternRewriter &rewriter, Location loc, - Value v) { - return rewriter - .create(loc, rewriter.getI1Type(), emitc::CmpPredicate::eq, - v, v) - .getResult(); - } - - static std::optional buildSpecialCmpFResult( - arith::CmpFPredicate predicate, ConversionPatternRewriter &rewriter, - Location loc, Type i1Ty, Value lhs, Value rhs) { - switch (predicate) { - case arith::CmpFPredicate::AlwaysFalse: - return makeEmitCOpaqueConstant(rewriter, loc, i1Ty, "false"); - case arith::CmpFPredicate::AlwaysTrue: - return makeEmitCOpaqueConstant(rewriter, loc, i1Ty, "true"); - case arith::CmpFPredicate::ORD: - return rewriter.create( - loc, i1Ty, isNotNaN(rewriter, loc, lhs), - isNotNaN(rewriter, loc, rhs)) - .getResult(); - case arith::CmpFPredicate::UNO: - return rewriter.create( - loc, i1Ty, isNaN(rewriter, loc, lhs), - isNaN(rewriter, loc, rhs)) - .getResult(); - default: - return std::nullopt; - } - } - - static std::optional - getCmpFConfig(arith::CmpFPredicate predicate) { - switch (predicate) { - case arith::CmpFPredicate::OEQ: - return CmpFConfig{false, emitc::CmpPredicate::eq}; - case arith::CmpFPredicate::OGT: - return CmpFConfig{false, emitc::CmpPredicate::gt}; - case arith::CmpFPredicate::OGE: - return CmpFConfig{false, emitc::CmpPredicate::ge}; - case arith::CmpFPredicate::OLT: - return CmpFConfig{false, emitc::CmpPredicate::lt}; - case arith::CmpFPredicate::OLE: - return CmpFConfig{false, emitc::CmpPredicate::le}; - case arith::CmpFPredicate::ONE: - return CmpFConfig{false, emitc::CmpPredicate::ne}; - case 
arith::CmpFPredicate::UEQ: - return CmpFConfig{true, emitc::CmpPredicate::eq}; - case arith::CmpFPredicate::UGT: - return CmpFConfig{true, emitc::CmpPredicate::gt}; - case arith::CmpFPredicate::UGE: - return CmpFConfig{true, emitc::CmpPredicate::ge}; - case arith::CmpFPredicate::ULT: - return CmpFConfig{true, emitc::CmpPredicate::lt}; - case arith::CmpFPredicate::ULE: - return CmpFConfig{true, emitc::CmpPredicate::le}; - case arith::CmpFPredicate::UNE: - return CmpFConfig{true, emitc::CmpPredicate::ne}; - default: - return std::nullopt; - } - } - - static Value buildCmpFResult(const CmpFConfig &config, - ConversionPatternRewriter &rewriter, - Location loc, Type i1Ty, Value lhs, Value rhs) { - Value cmp = rewriter - .create(loc, i1Ty, config.predicate, lhs, rhs) - .getResult(); - Value unord = rewriter.create( - loc, i1Ty, isNaN(rewriter, loc, lhs), isNaN(rewriter, loc, rhs)); - if (config.unordered) - return rewriter - .create(loc, i1Ty, unord, cmp) - .getResult(); - Value ord = rewriter.create( - loc, i1Ty, isNotNaN(rewriter, loc, lhs), isNotNaN(rewriter, loc, rhs)); - return rewriter - .create(loc, i1Ty, ord, cmp) - .getResult(); - } - - LogicalResult matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!isa(op.getLhs().getType())) - return rewriter.notifyMatchFailure(op, "cmpf only supported on scalar floats"); - - auto loc = op.getLoc(); - auto i1Ty = rewriter.getI1Type(); - if (auto special = buildSpecialCmpFResult(op.getPredicate(), rewriter, loc, - i1Ty, adaptor.getLhs(), - adaptor.getRhs())) { - rewriter.replaceOp(op, *special); - return success(); - } - - auto config = getCmpFConfig(op.getPredicate()); - if (!config) - return rewriter.notifyMatchFailure(op, "unsupported cmpf predicate"); - rewriter.replaceOp(op, buildCmpFResult(*config, rewriter, loc, i1Ty, - adaptor.getLhs(), adaptor.getRhs())); - return success(); - } -}; - -struct ArithAddUIExtendedToEmitC - : public OpConversionPattern { - 
using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getSum().getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, - "expected scalar integer or index operands"); - - const unsigned bitWidth = - intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - SmallVector newResultTypes; - if (failed(getTypeConverter()->convertTypes(op->getResultTypes(), - newResultTypes))) - return failure(); - if (newResultTypes.size() != 2) - return failure(); - - Type sumDstTy = newResultTypes[0]; - Type overflowDstTy = newResultTypes[1]; - - auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - auto wideTy = getWiderUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - - Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value lhsWide = emitCCast(rewriter, loc, wideTy, lhsU); - Value rhsWide = emitCCast(rewriter, loc, wideTy, rhsU); - Value sumWide = - rewriter.create(loc, wideTy, lhsWide, rhsWide).getResult(); - - Value sumN = emitCCast(rewriter, loc, uTy, sumWide); - Value sum = emitCCast(rewriter, loc, sumDstTy, sumN); - - Value shiftAmt = makeEmitCIntConstant(rewriter, loc, wideTy, bitWidth); - Value high = rewriter - .create(loc, wideTy, sumWide, - shiftAmt) - .getResult(); - Value zeroWide = makeEmitCIntConstant(rewriter, loc, wideTy, 0); - Value overflow = - rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::ne, high, zeroWide) - .getResult(); - overflow = emitCCast(rewriter, loc, overflowDstTy, overflow); - - rewriter.replaceOp(op, {sum, overflow}); - return success(); - } -}; - -template -struct ArithMulExtendedToEmitC : public 
OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getResult(0).getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, - "expected scalar integer or index operands"); - - const unsigned bitWidth = - intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - SmallVector newResultTypes; - if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), - newResultTypes))) - return failure(); - if (newResultTypes.size() != 2) - return failure(); - - Type lowDstTy = newResultTypes[0]; - Type highDstTy = newResultTypes[1]; - - Type wideTy = isUnsigned ? (Type)getWiderUnsignedIntOpaqueType(rewriter.getContext(), - bitWidth) - : (Type)getWiderSignedIntOpaqueType(rewriter.getContext(), - bitWidth); - - Value lhsWide; - Value rhsWide; - if constexpr (isUnsigned) { - Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - lhsWide = emitCCast(rewriter, loc, wideTy, lhsU); - rhsWide = emitCCast(rewriter, loc, wideTy, rhsU); - } else { - lhsWide = emitCCast(rewriter, loc, wideTy, adaptor.getLhs()); - rhsWide = emitCCast(rewriter, loc, wideTy, adaptor.getRhs()); - } - - Value prodWide = - rewriter.create(loc, wideTy, lhsWide, rhsWide).getResult(); - Value low = emitCCast(rewriter, loc, lowDstTy, prodWide); - - Value shiftAmt = makeEmitCIntConstant(rewriter, loc, wideTy, bitWidth); - Value highWide = rewriter - .create(loc, wideTy, prodWide, - shiftAmt) - .getResult(); - Value high = emitCCast(rewriter, loc, highDstTy, highWide); - - rewriter.replaceOp(op, {low, high}); - return success(); - } -}; - -using ArithMulSIExtendedToEmitC = - 
ArithMulExtendedToEmitC; -using ArithMulUIExtendedToEmitC = - ArithMulExtendedToEmitC; - -struct ArithMinMaxIToEmitCBase { - static Value makeSelect(ConversionPatternRewriter &rewriter, Location loc, - Type dstTy, Value cond, Value trueV, Value falseV) { - return rewriter - .create(loc, dstTy, cond, trueV, falseV) - .getResult(); - } -}; - -struct ArithMaxSIToEmitC : public OpConversionPattern, - ArithMinMaxIToEmitCBase { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::MaxSIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type dstTy = getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - Value cond = rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, - adaptor.getLhs(), adaptor.getRhs()) - .getResult(); - Value res = makeSelect(rewriter, loc, dstTy, cond, adaptor.getRhs(), - adaptor.getLhs()); - rewriter.replaceOp(op, res); - return success(); - } -}; - -struct ArithMinSIToEmitC : public OpConversionPattern, - ArithMinMaxIToEmitCBase { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::MinSIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type dstTy = getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - Value cond = rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, - adaptor.getLhs(), adaptor.getRhs()) - .getResult(); - Value res = makeSelect(rewriter, loc, dstTy, cond, adaptor.getLhs(), - adaptor.getRhs()); - rewriter.replaceOp(op, res); - return success(); - } -}; - -struct ArithMaxUIToEmitC : public OpConversionPattern, - ArithMinMaxIToEmitCBase { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::MaxUIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = 
op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - const unsigned bitWidth = - intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - Value lhsU = - castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = - castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value cond = rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, lhsU, rhsU) - .getResult(); - Value res = makeSelect(rewriter, loc, dstTy, cond, adaptor.getRhs(), - adaptor.getLhs()); - rewriter.replaceOp(op, res); - return success(); - } -}; - -struct ArithMinUIToEmitC : public OpConversionPattern, - ArithMinMaxIToEmitCBase { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::MinUIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - const unsigned bitWidth = - intTy ? 
intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - Value lhsU = - castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = - castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value cond = rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, lhsU, rhsU) - .getResult(); - Value res = makeSelect(rewriter, loc, dstTy, cond, adaptor.getLhs(), - adaptor.getRhs()); - rewriter.replaceOp(op, res); - return success(); - } -}; - -// Floating-point max/min variants. -struct ArithFloatMinMaxToEmitCBase { - static Value isNaN(ConversionPatternRewriter &rewriter, Location loc, - Value v) { - return rewriter - .create(loc, rewriter.getI1Type(), emitc::CmpPredicate::ne, - v, v) - .getResult(); - } - - static Value makeFZero(ConversionPatternRewriter &rewriter, Location loc, - Type ty) { - return makeEmitCOpaqueConstant(rewriter, loc, ty, "0.0f"); - } -}; - -struct ArithMaxNumFToEmitC : public OpConversionPattern, - ArithFloatMinMaxToEmitCBase { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::MaxNumFOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type dstTy = getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - - Value lhsNaN = isNaN(rewriter, loc, adaptor.getLhs()); - Value rhsNaN = isNaN(rewriter, loc, adaptor.getRhs()); - - Value cmpLt = rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, - adaptor.getLhs(), adaptor.getRhs()) - .getResult(); - Value maxNoNaN = - rewriter - .create(loc, dstTy, cmpLt, adaptor.getRhs(), - adaptor.getLhs()) - .getResult(); - - Value rhsOrMax = - rewriter - .create(loc, dstTy, rhsNaN, adaptor.getLhs(), - maxNoNaN) - .getResult(); - Value res = - rewriter - .create(loc, dstTy, lhsNaN, adaptor.getRhs(), - rhsOrMax) - 
.getResult(); - rewriter.replaceOp(op, res); - return success(); - } -}; - -struct ArithMinNumFToEmitC : public OpConversionPattern, - ArithFloatMinMaxToEmitCBase { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::MinNumFOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type dstTy = getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - - Value lhsNaN = isNaN(rewriter, loc, adaptor.getLhs()); - Value rhsNaN = isNaN(rewriter, loc, adaptor.getRhs()); - - Value cmpLt = rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, - adaptor.getLhs(), adaptor.getRhs()) - .getResult(); - Value minNoNaN = - rewriter - .create(loc, dstTy, cmpLt, adaptor.getLhs(), - adaptor.getRhs()) - .getResult(); - - Value rhsOrMin = - rewriter - .create(loc, dstTy, rhsNaN, adaptor.getLhs(), - minNoNaN) - .getResult(); - Value res = - rewriter - .create(loc, dstTy, lhsNaN, adaptor.getRhs(), - rhsOrMin) - .getResult(); - rewriter.replaceOp(op, res); - return success(); - } -}; - -template -struct ArithMinMaxFPropagateNaNToEmitC : public OpConversionPattern, - ArithFloatMinMaxToEmitCBase { - using OpConversionPattern::OpConversionPattern; - - static Value buildPrimaryCandidate(ConversionPatternRewriter &rewriter, - Location loc, Type dstTy, Value lhs, - Value rhs) { - Value cmpLt = - rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, lhs, rhs) - .getResult(); - return rewriter - .create( - loc, dstTy, cmpLt, isMaximum ? rhs : lhs, isMaximum ? 
lhs : rhs) - .getResult(); - } - - static Value buildSignBitValue(ConversionPatternRewriter &rewriter, - Location loc, Value lhs, FloatType floatTy) { - auto bitsTy = - getUnsignedIntOpaqueType(rewriter.getContext(), floatTy.getWidth()); - auto templateArgs = rewriter.getArrayAttr({emitc::OpaqueAttr::get( - rewriter.getContext(), cast(bitsTy).getValue())}); - Value lhsBits = - rewriter - .create(loc, TypeRange{bitsTy}, "ptoas_bitcast", - ValueRange{lhs}, ArrayAttr{}, - templateArgs) - .getResult(0); - Value oneBits = makeEmitCIntConstant(rewriter, loc, bitsTy, 1); - Value shiftAmount = - makeEmitCIntConstant(rewriter, loc, bitsTy, floatTy.getWidth() - 1); - Value signMask = rewriter - .create(loc, bitsTy, oneBits, - shiftAmount) - .getResult(); - return rewriter - .create(loc, bitsTy, lhsBits, signMask) - .getResult(); - } - - static Value buildSignedZeroCandidate(ConversionPatternRewriter &rewriter, - Location loc, Type dstTy, Value lhs, - Value rhs, FloatType floatTy) { - Value zero = makeFZero(rewriter, loc, dstTy); - Value equal = rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::eq, lhs, rhs) - .getResult(); - Value lhsZero = rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::eq, lhs, - zero) - .getResult(); - Value bothZero = rewriter - .create(loc, rewriter.getI1Type(), - equal, lhsZero) - .getResult(); - auto bitsTy = - getUnsignedIntOpaqueType(rewriter.getContext(), floatTy.getWidth()); - Value zeroBits = makeEmitCIntConstant(rewriter, loc, bitsTy, 0); - Value lhsIsNegZero = - rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::ne, - buildSignBitValue(rewriter, loc, lhs, floatTy), - zeroBits) - .getResult(); - Value tie = rewriter - .create( - loc, dstTy, lhsIsNegZero, isMaximum ? rhs : lhs, - isMaximum ? 
lhs : rhs) - .getResult(); - return rewriter - .create(loc, dstTy, bothZero, tie, - buildPrimaryCandidate(rewriter, loc, dstTy, - lhs, rhs)) - .getResult(); - } - - static Value buildNaNPropagatingResult(ConversionPatternRewriter &rewriter, - Location loc, Type dstTy, Value lhs, - Value rhs, FloatType floatTy) { - Value lhsNaN = isNaN(rewriter, loc, lhs); - Value rhsNaN = isNaN(rewriter, loc, rhs); - Value noNaN = - buildSignedZeroCandidate(rewriter, loc, dstTy, lhs, rhs, floatTy); - Value rhsOrNoNaN = rewriter - .create(loc, dstTy, rhsNaN, rhs, - noNaN) - .getResult(); - return rewriter - .create(loc, dstTy, lhsNaN, lhs, rhsOrNoNaN) - .getResult(); - } - - LogicalResult - matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!isa(op.getType())) - return rewriter.notifyMatchFailure(op, "expected scalar float type"); - - auto loc = op.getLoc(); - Type dstTy = this->getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - - auto floatTy = cast(op.getType()); - rewriter.replaceOp(op, buildNaNPropagatingResult( - rewriter, loc, dstTy, adaptor.getLhs(), - adaptor.getRhs(), floatTy)); - return success(); - } -}; - -using ArithMaximumFToEmitC = - ArithMinMaxFPropagateNaNToEmitC; -using ArithMinimumFToEmitC = - ArithMinMaxFPropagateNaNToEmitC; +// Arith/Affine conversion patterns live in PTOToEmitCArith.cpp. 
//===----------------------------------------------------------------------===// // Arith -> EmitC helpers @@ -2284,7 +1019,7 @@ static emitc::OpaqueType getUnsignedIntOpaqueType(MLIRContext *ctx, } } -static emitc::OpaqueType getWiderSignedIntOpaqueType(MLIRContext *ctx, +[[maybe_unused]] static emitc::OpaqueType getWiderSignedIntOpaqueType(MLIRContext *ctx, unsigned bitWidth) { switch (bitWidth) { case 1: @@ -2301,7 +1036,7 @@ static emitc::OpaqueType getWiderSignedIntOpaqueType(MLIRContext *ctx, } } -static emitc::OpaqueType getWiderUnsignedIntOpaqueType(MLIRContext *ctx, +[[maybe_unused]] static emitc::OpaqueType getWiderUnsignedIntOpaqueType(MLIRContext *ctx, unsigned bitWidth) { switch (bitWidth) { case 1: @@ -2330,7 +1065,7 @@ static Value makeEmitCIntConstant(ConversionPatternRewriter &rewriter, return makeEmitCOpaqueConstant(rewriter, loc, type, std::to_string(value)); } -static FailureOr buildEmitCOpaqueConstantLiteral(Type targetType, +[[maybe_unused]] static FailureOr buildEmitCOpaqueConstantLiteral(Type targetType, Attribute valueAttr) { auto opaqueTy = dyn_cast(targetType); if (!opaqueTy) @@ -2382,264 +1117,6 @@ static Value castSignlessIntToUnsignedSameWidth(ConversionPatternRewriter &rewri return emitCCast(rewriter, loc, uTy, v); } -struct ArithMulIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(arith::MulIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - const unsigned bitWidth = - intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - // i1 mul is equivalent to bitwise AND (mod 2 arithmetic). 
- if (bitWidth == 1) { - rewriter.replaceOpWithNewOp(op, opTy, adaptor.getLhs(), - adaptor.getRhs()); - return success(); - } - - auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value mulU = rewriter.create(loc, uTy, lhsU, rhsU); - Value result = emitCCast(rewriter, loc, dstTy, mulU); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithAddIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(arith::AddIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - const unsigned bitWidth = - intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - // i1 add is equivalent to XOR (mod 2 arithmetic). 
- if (bitWidth == 1) { - rewriter.replaceOpWithNewOp(op, opTy, adaptor.getLhs(), - adaptor.getRhs()); - return success(); - } - - auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value addU = rewriter.create(loc, uTy, lhsU, rhsU); - Value result = emitCCast(rewriter, loc, dstTy, addU); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithCastOPToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type newTy = getTypeConverter()->convertType(op.getType()); - if (!newTy) - return failure(); - if (adaptor.getIn().getType() == newTy) { - rewriter.replaceOp(op, adaptor.getIn()); - return success(); - } - rewriter.replaceOpWithNewOp(op, newTy, adaptor.getIn()); - return success(); - } -}; - -struct ArithSubIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(arith::SubIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - const unsigned bitWidth = - intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - // i1 sub is equivalent to XOR (mod 2 arithmetic). 
- if (bitWidth == 1) { - rewriter.replaceOpWithNewOp(op, opTy, adaptor.getLhs(), - adaptor.getRhs()); - return success(); - } - - auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value subU = rewriter.create(loc, uTy, lhsU, rhsU); - Value result = emitCCast(rewriter, loc, dstTy, subU); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithDivSIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(arith::DivSIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type newTy = getTypeConverter()->convertType(op.getType()); - if (!newTy) - return failure(); - rewriter.replaceOpWithNewOp(op, newTy, adaptor.getLhs(), - adaptor.getRhs()); - return success(); - } -}; - -struct ArithRemSIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type newTy = getTypeConverter()->convertType(op.getType()); - if (!newTy) - return failure(); - rewriter.replaceOpWithNewOp(op, newTy, adaptor.getLhs(), - adaptor.getRhs()); - return success(); - } -}; - -struct ArithTruncIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - auto dstIntTy = dyn_cast(op.getType()); - auto srcIntTy = dyn_cast(op.getIn().getType()); - if (!dstIntTy || !srcIntTy) - return rewriter.notifyMatchFailure(op, "expected scalar integer types"); - - Type dstTy = getTypeConverter()->convertType(dstIntTy); - if (!dstTy) - return 
failure(); - - // to-i1 conversions: Arith wants truncation to the low bit, while C/C++ - // casts to bool are equivalent to `v != 0`. Implement as `(bool)(v & 1)`. - if (dstIntTy.getWidth() == 1) { - if (srcIntTy.getWidth() == 1) { - rewriter.replaceOp(op, adaptor.getIn()); - return success(); - } - - auto uSrcTy = - getUnsignedIntOpaqueType(rewriter.getContext(), srcIntTy.getWidth()); - Value inU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getIn(), - srcIntTy.getWidth()); - Value one = makeEmitCIntConstant(rewriter, loc, uSrcTy, 1); - Value masked = - rewriter.create(loc, uSrcTy, inU, one); - Value asBool = emitCCast(rewriter, loc, dstTy, masked); - rewriter.replaceOp(op, asBool); - return success(); - } - - rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); - return success(); - } -}; - -struct ArithConstantToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type newType = getTypeConverter()->convertType(op.getType()); - if (!newType) - return failure(); - - // `adaptor.getValue()` may be null if attribute conversion isn't defined. - // Use the original attribute as fallback and always cast null-safely. - Attribute valueAttr = adaptor.getValue(); - if (!valueAttr) - valueAttr = op.getValue(); - - if (auto opaqueLiteral = buildEmitCOpaqueConstantLiteral(newType, valueAttr); - succeeded(opaqueLiteral)) { - auto constAttr = emitc::OpaqueAttr::get(rewriter.getContext(), *opaqueLiteral); - rewriter.replaceOpWithNewOp(op, newType, constAttr); - return success(); - } - - if (auto floatAttr = dyn_cast_or_null(valueAttr)) { - SmallString<32> valStr; - floatAttr.getValue().toString(valStr); - llvm::StringRef s(valStr); - // Ensure the literal parses as a floating-point constant in C/C++. - // `APFloat::toString` may emit "1" for integral values; make it "1.0". 
- const bool hasFloatMarker = - s.contains('.') || s.contains('e') || s.contains('E') || - s.contains('p') || s.contains('P') || s.starts_with("0x") || - s.starts_with("0X") || s.starts_with("nan") || - s.starts_with("-nan") || s.starts_with("inf") || - s.starts_with("-inf"); - if (!hasFloatMarker) - valStr.append(".0"); - // Suffix: keep `f` for f16/f32; omit for f64. - if (!floatAttr.getType().isF64()) - valStr.append("f"); - auto constAttr = emitc::OpaqueAttr::get(rewriter.getContext(), valStr); - rewriter.replaceOpWithNewOp(op, newType, constAttr); - return success(); - } - - if (auto intAttr = dyn_cast_or_null(valueAttr)) { - std::string valStr = std::to_string(intAttr.getValue().getSExtValue()); - auto constAttr = emitc::OpaqueAttr::get(rewriter.getContext(), valStr); - rewriter.replaceOpWithNewOp(op, newType, constAttr); - return success(); - } - - return failure(); - } -}; //===----------------------------------------------------------------------===// // pto.mgather lowering -> MGATHER(dst, src, indexes) (pto-isa) //===----------------------------------------------------------------------===// @@ -12144,6 +10621,8 @@ static void populatePTOToEmitCPatterns(RewritePatternSet &patterns, PTOArch targetArch) { (void)solver; patterns.add(typeConverter, ctx); + populatePTOToEmitCArithPatterns(patterns, typeConverter, ctx); + patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); @@ -12237,11 +10716,6 @@ static void populatePTOToEmitCPatterns(RewritePatternSet &patterns, patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); @@ -12266,53 +10740,6 @@ static void 
populatePTOToEmitCPatterns(RewritePatternSet &patterns, patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add>( - typeConverter, ctx); - patterns.add>( - typeConverter, ctx); - patterns.add>( - typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add>(typeConverter, - ctx); - patterns.add>(typeConverter, - ctx); - patterns.add>(typeConverter, - ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add>(typeConverter, ctx); - patterns.add>(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add>(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add>(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); @@ -12337,8 +10764,6 @@ static void populatePTOToEmitCPatterns(RewritePatternSet &patterns, patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - 
patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add>( typeConverter, ctx, diff --git a/lib/PTO/Transforms/PTOToEmitCArith.cpp b/lib/PTO/Transforms/PTOToEmitCArith.cpp new file mode 100644 index 000000000..19a0717a7 --- /dev/null +++ b/lib/PTO/Transforms/PTOToEmitCArith.cpp @@ -0,0 +1,1782 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- PTOToEmitCArith.cpp ------------------------------------------------===// +//===----------------------------------------------------------------------===// + +#include "PTOToEmitCInternal.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/APSInt.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include + +using namespace mlir; + +namespace mlir::pto { +namespace { + +static constexpr unsigned kPTOIndexBitWidth = 32; + +static emitc::OpaqueType getSignedIntOpaqueType(MLIRContext *ctx, unsigned bitWidth); +static emitc::OpaqueType getUnsignedIntOpaqueType(MLIRContext *ctx, unsigned bitWidth); +static emitc::OpaqueType getWiderSignedIntOpaqueType(MLIRContext *ctx, unsigned bitWidth); +static emitc::OpaqueType 
getWiderUnsignedIntOpaqueType(MLIRContext *ctx, unsigned bitWidth); +static Value makeEmitCOpaqueConstant(ConversionPatternRewriter &rewriter, + Location loc, Type type, + llvm::StringRef literal); +static Value makeEmitCIntConstant(ConversionPatternRewriter &rewriter, + Location loc, Type type, int64_t value); +static FailureOr buildEmitCOpaqueConstantLiteral(Type targetType, + Attribute valueAttr); +static Value emitCCast(ConversionPatternRewriter &rewriter, Location loc, + Type dstType, Value src); +static Value castSignlessIntToUnsignedSameWidth(ConversionPatternRewriter &rewriter, + Location loc, Value v, + unsigned bitWidth); + +//===----------------------------------------------------------------------===// +// Arith -> EmitC (full dialect coverage for scalar ops) +//===----------------------------------------------------------------------===// + +template +struct ArithSimpleBinaryToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type dstTy = this->getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getOperands()); + return success(); + } +}; + +// Integer bitwise ops (andi/ori/xori) on signless integers: perform in unsigned +// to avoid signedness pitfalls, then cast back. +template +struct ArithUnsignedBitwiseBinaryToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const unsigned bitWidth = + intTy ? 
intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = this->getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + if (bitWidth == 1) { + rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getLhs(), + adaptor.getRhs()); + return success(); + } + + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value resU = rewriter.create(loc, uTy, lhsU, rhsU); + Value result = emitCCast(rewriter, loc, dstTy, resU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithDivUIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::DivUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const unsigned bitWidth = + intTy ? 
intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value divU = rewriter.create(loc, uTy, lhsU, rhsU); + Value result = emitCCast(rewriter, loc, dstTy, divU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithRemUIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::RemUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const unsigned bitWidth = + intTy ? 
intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value remU = rewriter.create(loc, uTy, lhsU, rhsU); + Value result = emitCCast(rewriter, loc, dstTy, remU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithCeilDivUIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::CeilDivUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const unsigned bitWidth = + intTy ? 
intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value one = makeEmitCIntConstant(rewriter, loc, uTy, 1); + Value rhsMinusOne = rewriter.create(loc, uTy, rhsU, one); + Value num = rewriter.create(loc, uTy, lhsU, rhsMinusOne); + Value divU = rewriter.create(loc, uTy, num, rhsU); + Value result = emitCCast(rewriter, loc, dstTy, divU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithCeilDivSIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::CeilDivSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + Value zero = makeEmitCIntConstant(rewriter, loc, dstTy, 0); + Value one = makeEmitCIntConstant(rewriter, loc, dstTy, 1); + + Value q0 = rewriter.create(loc, dstTy, adaptor.getLhs(), + adaptor.getRhs()); + Value r = rewriter.create(loc, dstTy, adaptor.getLhs(), + adaptor.getRhs()); + + Value rNeZero = rewriter.create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::ne, r, + zero); + Value lhsLt0 = + rewriter.create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, adaptor.getLhs(), + zero); + Value rhsLt0 = + rewriter.create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, adaptor.getRhs(), + zero); + Value signsSame = + rewriter.create(loc, 
rewriter.getI1Type(), + emitc::CmpPredicate::eq, lhsLt0, rhsLt0); + Value adjust = + rewriter.create(loc, rewriter.getI1Type(), + rNeZero, signsSame); + + Value qPlusOne = rewriter.create(loc, dstTy, q0, one); + Value result = rewriter.create(loc, dstTy, adjust, + qPlusOne, q0); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithFloorDivSIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::FloorDivSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + Value zero = makeEmitCIntConstant(rewriter, loc, dstTy, 0); + Value one = makeEmitCIntConstant(rewriter, loc, dstTy, 1); + + Value q0 = rewriter.create(loc, dstTy, adaptor.getLhs(), + adaptor.getRhs()); + Value r = rewriter.create(loc, dstTy, adaptor.getLhs(), + adaptor.getRhs()); + + Value rNeZero = rewriter.create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::ne, r, + zero); + Value lhsLt0 = + rewriter.create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, adaptor.getLhs(), + zero); + Value rhsLt0 = + rewriter.create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, adaptor.getRhs(), + zero); + Value signsDifferent = + rewriter.create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::ne, lhsLt0, rhsLt0); + Value adjust = + rewriter.create(loc, rewriter.getI1Type(), + rNeZero, signsDifferent); + + Value qMinusOne = rewriter.create(loc, dstTy, q0, one); + Value result = rewriter.create(loc, dstTy, adjust, + qMinusOne, q0); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithShiftLeftToEmitC : public OpConversionPattern { + using 
OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::ShLIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const unsigned bitWidth = + intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + if (bitWidth == 1) { + // Compute on u8 and truncate to i1. + auto u8Ty = getUnsignedIntOpaqueType(rewriter.getContext(), 8); + Value lhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getLhs()); + Value rhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getRhs()); + Value sh = rewriter.create(loc, u8Ty, lhsU8, + rhsU8); + Value masked = + rewriter.create(loc, u8Ty, sh, + makeEmitCIntConstant(rewriter, loc, + u8Ty, 1)); + rewriter.replaceOp(op, emitCCast(rewriter, loc, dstTy, masked)); + return success(); + } + + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value shU = + rewriter.create(loc, uTy, lhsU, rhsU); + Value result = emitCCast(rewriter, loc, dstTy, shU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithShiftRightUIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::ShRUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + 
const unsigned bitWidth = + intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + if (bitWidth == 1) { + // (x >> y) on i1 is either x (y==0) or 0 (y!=0); approximate in u8. + auto u8Ty = getUnsignedIntOpaqueType(rewriter.getContext(), 8); + Value lhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getLhs()); + Value rhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getRhs()); + Value sh = rewriter.create(loc, u8Ty, lhsU8, + rhsU8); + Value masked = + rewriter.create(loc, u8Ty, sh, + makeEmitCIntConstant(rewriter, loc, + u8Ty, 1)); + rewriter.replaceOp(op, emitCCast(rewriter, loc, dstTy, masked)); + return success(); + } + + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value shU = + rewriter.create(loc, uTy, lhsU, rhsU); + Value result = emitCCast(rewriter, loc, dstTy, shU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithShiftRightSIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::ShRSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const unsigned bitWidth = + intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + if (bitWidth == 1) { + // (x >> y) on i1 is either x (y==0) or 0 (y!=0); approximate in u8. 
+ auto u8Ty = getUnsignedIntOpaqueType(rewriter.getContext(), 8); + Value lhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getLhs()); + Value rhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getRhs()); + Value sh = rewriter.create(loc, u8Ty, lhsU8, + rhsU8); + Value masked = + rewriter.create(loc, u8Ty, sh, + makeEmitCIntConstant(rewriter, loc, + u8Ty, 1)); + rewriter.replaceOp(op, emitCCast(rewriter, loc, dstTy, masked)); + return success(); + } + + // Signed arithmetic shift; cast RHS to unsigned to interpret shift amount. + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value sh = + rewriter.create(loc, dstTy, adaptor.getLhs(), + rhsU); + rewriter.replaceOp(op, sh); + return success(); + } +}; + +struct ArithNegFToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::NegFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type dstTy = getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getOperand()); + return success(); + } +}; + +struct ArithRemFToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::RemFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type dstTy = getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + + // Use builtin `fmod` when possible. For f16, compute in float and cast back. 
+ Type callTy = dstTy; + Value lhs = adaptor.getLhs(); + Value rhs = adaptor.getRhs(); + + if (auto opFloatTy = dyn_cast(op.getType())) { + if (opFloatTy.isF16()) { + auto f32Ty = emitc::OpaqueType::get(rewriter.getContext(), "float"); + lhs = emitCCast(rewriter, loc, f32Ty, lhs); + rhs = emitCCast(rewriter, loc, f32Ty, rhs); + callTy = f32Ty; + } + } + + // Prefer `__builtin_fmod*` to avoid relying on extra headers. + llvm::StringRef callee = "__builtin_fmod"; + if (auto opFloatTy = dyn_cast(op.getType())) { + if (opFloatTy.isF32() || opFloatTy.isF16()) + callee = "__builtin_fmodf"; + else if (opFloatTy.isF64()) + callee = "__builtin_fmod"; + } + + auto call = rewriter.create( + loc, TypeRange{callTy}, callee, ValueRange{lhs, rhs}, + /*args=*/ArrayAttr{}, /*template_args=*/ArrayAttr{}); + Value result = call.getResult(0); + if (callTy != dstTy) + result = emitCCast(rewriter, loc, dstTy, result); + + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithSelectToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!op.getCondition().getType().isInteger(1)) + return rewriter.notifyMatchFailure( + op, "only scalar i1 conditions supported for arith.select"); + + Type dstTy = getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + + auto cond = + rewriter.create(op.getLoc(), dstTy, + adaptor.getCondition(), + adaptor.getTrueValue(), + adaptor.getFalseValue()); + rewriter.replaceOp(op, cond.getResult()); + return success(); + } +}; + +struct ArithExtUIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto dstIntTy = dyn_cast(op.getType()); + auto srcIntTy = 
dyn_cast(op.getIn().getType()); + if (!dstIntTy || !srcIntTy) + return rewriter.notifyMatchFailure(op, "expected scalar integer types"); + + Type dstTy = getTypeConverter()->convertType(dstIntTy); + if (!dstTy) + return failure(); + + // i1 -> iN: bool to integer already behaves as 0/1. + if (srcIntTy.getWidth() == 1) { + rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); + return success(); + } + + auto uDstTy = + getUnsignedIntOpaqueType(rewriter.getContext(), dstIntTy.getWidth()); + Value srcU = + castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getIn(), + srcIntTy.getWidth()); + Value extU = emitCCast(rewriter, loc, uDstTy, srcU); + Value result = emitCCast(rewriter, loc, dstTy, extU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithExtSIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto dstIntTy = dyn_cast(op.getType()); + auto srcIntTy = dyn_cast(op.getIn().getType()); + if (!dstIntTy || !srcIntTy) + return rewriter.notifyMatchFailure(op, "expected scalar integer types"); + + Type dstTy = getTypeConverter()->convertType(dstIntTy); + if (!dstTy) + return failure(); + + // i1 sign-extension: 0 -> 0, 1 -> -1. 
+ if (srcIntTy.getWidth() == 1) { + Value zero = makeEmitCIntConstant(rewriter, loc, dstTy, 0); + Value asInt = emitCCast(rewriter, loc, dstTy, adaptor.getIn()); + Value neg = rewriter.create(loc, dstTy, zero, asInt).getResult(); + rewriter.replaceOp(op, neg); + return success(); + } + + rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); + return success(); + } +}; + +template +struct ArithCastToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type dstTy = this->getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); + return success(); + } +}; + +struct ArithIndexCastUIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::IndexCastUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type dstTy = getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + + // MemRef casts are handled elsewhere; for safety, fall back to emitc.cast. 
+ if (isa(op.getIn().getType()) || isa(op.getType())) { + rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); + return success(); + } + + auto getBW = [](Type t) -> std::optional { + if (auto i = dyn_cast(t)) + return i.getWidth(); + if (isa(t)) + return kPTOIndexBitWidth; + return std::nullopt; + }; + + auto srcBW = getBW(op.getIn().getType()); + auto dstBW = getBW(op.getType()); + if (!srcBW || !dstBW) + return rewriter.notifyMatchFailure(op, "unsupported index_castui types"); + + if (*dstBW <= *srcBW) { + rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); + return success(); + } + + auto uSrcTy = getUnsignedIntOpaqueType(rewriter.getContext(), *srcBW); + auto uDstTy = getUnsignedIntOpaqueType(rewriter.getContext(), *dstBW); + Value srcU = emitCCast(rewriter, loc, uSrcTy, adaptor.getIn()); + Value extU = emitCCast(rewriter, loc, uDstTy, srcU); + Value result = emitCCast(rewriter, loc, dstTy, extU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithUIToFPToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto srcIntTy = dyn_cast(op.getIn().getType()); + if (!srcIntTy) + return rewriter.notifyMatchFailure(op, "expected scalar integer input"); + + Type dstTy = getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + + // Convert via an unsigned integer type of the same width. 
+ if (srcIntTy.getWidth() == 1) { + rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); + return success(); + } + Value srcU = + castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getIn(), + srcIntTy.getWidth()); + Value fp = rewriter.create(loc, dstTy, srcU).getResult(); + rewriter.replaceOp(op, fp); + return success(); + } +}; + +struct ArithFPToUIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::FPToUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto dstIntTy = dyn_cast(op.getType()); + if (!dstIntTy) + return rewriter.notifyMatchFailure(op, "expected scalar integer result"); + + Type dstTy = getTypeConverter()->convertType(dstIntTy); + if (!dstTy) + return failure(); + + auto uDstTy = + getUnsignedIntOpaqueType(rewriter.getContext(), dstIntTy.getWidth()); + Value asU = rewriter.create(loc, uDstTy, adaptor.getIn()).getResult(); + Value result = emitCCast(rewriter, loc, dstTy, asU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithBitcastToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type dstTy = getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + + // For pointer-like types, a regular cast is fine. + if (isa(dstTy)) { + rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); + return success(); + } + + // Only support scalar int/float/index bitcasts here. 
+ auto srcTy = op.getIn().getType(); + auto dstOrigTy = op.getType(); + + auto getBitWidth = [](Type t) -> std::optional { + if (auto it = dyn_cast(t)) + return it.getWidth(); + if (auto ft = dyn_cast(t)) + return ft.getWidth(); + if (isa(t)) + return kPTOIndexBitWidth; + return std::nullopt; + }; + auto srcBW = getBitWidth(srcTy); + auto dstBW = getBitWidth(dstOrigTy); + if (!srcBW || !dstBW || *srcBW != *dstBW) + return rewriter.notifyMatchFailure(op, "bitcast requires equal bitwidth"); + + // Determine the template argument from the destination type string. + auto dstOpaque = dyn_cast(dstTy); + if (!dstOpaque) + return rewriter.notifyMatchFailure(op, "expected emitc opaque dest type"); + + auto templateArgs = + rewriter.getArrayAttr({emitc::OpaqueAttr::get(rewriter.getContext(), + dstOpaque.getValue())}); + auto call = rewriter.create( + loc, TypeRange{dstTy}, "ptoas_bitcast", /*operands=*/ValueRange{adaptor.getIn()}, + /*args=*/ArrayAttr{}, /*template_args=*/templateArgs); + rewriter.replaceOp(op, call.getResult(0)); + return success(); + } +}; + +// arith.cmpf lowering with ordered/unordered semantics. 
+struct ArithCmpFToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + struct CmpFConfig { + bool unordered = false; + emitc::CmpPredicate predicate = emitc::CmpPredicate::eq; + }; + + static Value isNaN(ConversionPatternRewriter &rewriter, Location loc, + Value v) { + return rewriter + .create(loc, rewriter.getI1Type(), emitc::CmpPredicate::ne, + v, v) + .getResult(); + } + + static Value isNotNaN(ConversionPatternRewriter &rewriter, Location loc, + Value v) { + return rewriter + .create(loc, rewriter.getI1Type(), emitc::CmpPredicate::eq, + v, v) + .getResult(); + } + + static std::optional buildSpecialCmpFResult( + arith::CmpFPredicate predicate, ConversionPatternRewriter &rewriter, + Location loc, Type i1Ty, Value lhs, Value rhs) { + switch (predicate) { + case arith::CmpFPredicate::AlwaysFalse: + return makeEmitCOpaqueConstant(rewriter, loc, i1Ty, "false"); + case arith::CmpFPredicate::AlwaysTrue: + return makeEmitCOpaqueConstant(rewriter, loc, i1Ty, "true"); + case arith::CmpFPredicate::ORD: + return rewriter.create( + loc, i1Ty, isNotNaN(rewriter, loc, lhs), + isNotNaN(rewriter, loc, rhs)) + .getResult(); + case arith::CmpFPredicate::UNO: + return rewriter.create( + loc, i1Ty, isNaN(rewriter, loc, lhs), + isNaN(rewriter, loc, rhs)) + .getResult(); + default: + return std::nullopt; + } + } + + static std::optional + getCmpFConfig(arith::CmpFPredicate predicate) { + switch (predicate) { + case arith::CmpFPredicate::OEQ: + return CmpFConfig{false, emitc::CmpPredicate::eq}; + case arith::CmpFPredicate::OGT: + return CmpFConfig{false, emitc::CmpPredicate::gt}; + case arith::CmpFPredicate::OGE: + return CmpFConfig{false, emitc::CmpPredicate::ge}; + case arith::CmpFPredicate::OLT: + return CmpFConfig{false, emitc::CmpPredicate::lt}; + case arith::CmpFPredicate::OLE: + return CmpFConfig{false, emitc::CmpPredicate::le}; + case arith::CmpFPredicate::ONE: + return CmpFConfig{false, emitc::CmpPredicate::ne}; + case 
arith::CmpFPredicate::UEQ: + return CmpFConfig{true, emitc::CmpPredicate::eq}; + case arith::CmpFPredicate::UGT: + return CmpFConfig{true, emitc::CmpPredicate::gt}; + case arith::CmpFPredicate::UGE: + return CmpFConfig{true, emitc::CmpPredicate::ge}; + case arith::CmpFPredicate::ULT: + return CmpFConfig{true, emitc::CmpPredicate::lt}; + case arith::CmpFPredicate::ULE: + return CmpFConfig{true, emitc::CmpPredicate::le}; + case arith::CmpFPredicate::UNE: + return CmpFConfig{true, emitc::CmpPredicate::ne}; + default: + return std::nullopt; + } + } + + static Value buildCmpFResult(const CmpFConfig &config, + ConversionPatternRewriter &rewriter, + Location loc, Type i1Ty, Value lhs, Value rhs) { + Value cmp = rewriter + .create(loc, i1Ty, config.predicate, lhs, rhs) + .getResult(); + Value unord = rewriter.create( + loc, i1Ty, isNaN(rewriter, loc, lhs), isNaN(rewriter, loc, rhs)); + if (config.unordered) + return rewriter + .create(loc, i1Ty, unord, cmp) + .getResult(); + Value ord = rewriter.create( + loc, i1Ty, isNotNaN(rewriter, loc, lhs), isNotNaN(rewriter, loc, rhs)); + return rewriter + .create(loc, i1Ty, ord, cmp) + .getResult(); + } + + LogicalResult matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!isa(op.getLhs().getType())) + return rewriter.notifyMatchFailure(op, "cmpf only supported on scalar floats"); + + auto loc = op.getLoc(); + auto i1Ty = rewriter.getI1Type(); + if (auto special = buildSpecialCmpFResult(op.getPredicate(), rewriter, loc, + i1Ty, adaptor.getLhs(), + adaptor.getRhs())) { + rewriter.replaceOp(op, *special); + return success(); + } + + auto config = getCmpFConfig(op.getPredicate()); + if (!config) + return rewriter.notifyMatchFailure(op, "unsupported cmpf predicate"); + rewriter.replaceOp(op, buildCmpFResult(*config, rewriter, loc, i1Ty, + adaptor.getLhs(), adaptor.getRhs())); + return success(); + } +}; + +struct ArithAddUIExtendedToEmitC + : public OpConversionPattern { + 
using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getSum().getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, + "expected scalar integer or index operands"); + + const unsigned bitWidth = + intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + SmallVector newResultTypes; + if (failed(getTypeConverter()->convertTypes(op->getResultTypes(), + newResultTypes))) + return failure(); + if (newResultTypes.size() != 2) + return failure(); + + Type sumDstTy = newResultTypes[0]; + Type overflowDstTy = newResultTypes[1]; + + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + auto wideTy = getWiderUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + + Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value lhsWide = emitCCast(rewriter, loc, wideTy, lhsU); + Value rhsWide = emitCCast(rewriter, loc, wideTy, rhsU); + Value sumWide = + rewriter.create(loc, wideTy, lhsWide, rhsWide).getResult(); + + Value sumN = emitCCast(rewriter, loc, uTy, sumWide); + Value sum = emitCCast(rewriter, loc, sumDstTy, sumN); + + Value shiftAmt = makeEmitCIntConstant(rewriter, loc, wideTy, bitWidth); + Value high = rewriter + .create(loc, wideTy, sumWide, + shiftAmt) + .getResult(); + Value zeroWide = makeEmitCIntConstant(rewriter, loc, wideTy, 0); + Value overflow = + rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::ne, high, zeroWide) + .getResult(); + overflow = emitCCast(rewriter, loc, overflowDstTy, overflow); + + rewriter.replaceOp(op, {sum, overflow}); + return success(); + } +}; + +template +struct ArithMulExtendedToEmitC : public 
OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getResult(0).getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, + "expected scalar integer or index operands"); + + const unsigned bitWidth = + intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + SmallVector newResultTypes; + if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), + newResultTypes))) + return failure(); + if (newResultTypes.size() != 2) + return failure(); + + Type lowDstTy = newResultTypes[0]; + Type highDstTy = newResultTypes[1]; + + Type wideTy = isUnsigned ? (Type)getWiderUnsignedIntOpaqueType(rewriter.getContext(), + bitWidth) + : (Type)getWiderSignedIntOpaqueType(rewriter.getContext(), + bitWidth); + + Value lhsWide; + Value rhsWide; + if constexpr (isUnsigned) { + Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + lhsWide = emitCCast(rewriter, loc, wideTy, lhsU); + rhsWide = emitCCast(rewriter, loc, wideTy, rhsU); + } else { + lhsWide = emitCCast(rewriter, loc, wideTy, adaptor.getLhs()); + rhsWide = emitCCast(rewriter, loc, wideTy, adaptor.getRhs()); + } + + Value prodWide = + rewriter.create(loc, wideTy, lhsWide, rhsWide).getResult(); + Value low = emitCCast(rewriter, loc, lowDstTy, prodWide); + + Value shiftAmt = makeEmitCIntConstant(rewriter, loc, wideTy, bitWidth); + Value highWide = rewriter + .create(loc, wideTy, prodWide, + shiftAmt) + .getResult(); + Value high = emitCCast(rewriter, loc, highDstTy, highWide); + + rewriter.replaceOp(op, {low, high}); + return success(); + } +}; + +using ArithMulSIExtendedToEmitC = + 
ArithMulExtendedToEmitC; +using ArithMulUIExtendedToEmitC = + ArithMulExtendedToEmitC; + +struct ArithMinMaxIToEmitCBase { + static Value makeSelect(ConversionPatternRewriter &rewriter, Location loc, + Type dstTy, Value cond, Value trueV, Value falseV) { + return rewriter + .create(loc, dstTy, cond, trueV, falseV) + .getResult(); + } +}; + +struct ArithMaxSIToEmitC : public OpConversionPattern, + ArithMinMaxIToEmitCBase { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::MaxSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type dstTy = getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + Value cond = rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, + adaptor.getLhs(), adaptor.getRhs()) + .getResult(); + Value res = makeSelect(rewriter, loc, dstTy, cond, adaptor.getRhs(), + adaptor.getLhs()); + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct ArithMinSIToEmitC : public OpConversionPattern, + ArithMinMaxIToEmitCBase { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::MinSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type dstTy = getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + Value cond = rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, + adaptor.getLhs(), adaptor.getRhs()) + .getResult(); + Value res = makeSelect(rewriter, loc, dstTy, cond, adaptor.getLhs(), + adaptor.getRhs()); + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct ArithMaxUIToEmitC : public OpConversionPattern, + ArithMinMaxIToEmitCBase { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::MaxUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = 
op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const unsigned bitWidth = + intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + Value lhsU = + castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = + castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value cond = rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, lhsU, rhsU) + .getResult(); + Value res = makeSelect(rewriter, loc, dstTy, cond, adaptor.getRhs(), + adaptor.getLhs()); + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct ArithMinUIToEmitC : public OpConversionPattern, + ArithMinMaxIToEmitCBase { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::MinUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const unsigned bitWidth = + intTy ? 
intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + Value lhsU = + castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = + castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value cond = rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, lhsU, rhsU) + .getResult(); + Value res = makeSelect(rewriter, loc, dstTy, cond, adaptor.getLhs(), + adaptor.getRhs()); + rewriter.replaceOp(op, res); + return success(); + } +}; + +// Floating-point max/min variants. +struct ArithFloatMinMaxToEmitCBase { + static Value isNaN(ConversionPatternRewriter &rewriter, Location loc, + Value v) { + return rewriter + .create(loc, rewriter.getI1Type(), emitc::CmpPredicate::ne, + v, v) + .getResult(); + } + + static Value makeFZero(ConversionPatternRewriter &rewriter, Location loc, + Type ty) { + return makeEmitCOpaqueConstant(rewriter, loc, ty, "0.0f"); + } +}; + +struct ArithMaxNumFToEmitC : public OpConversionPattern, + ArithFloatMinMaxToEmitCBase { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::MaxNumFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type dstTy = getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + + Value lhsNaN = isNaN(rewriter, loc, adaptor.getLhs()); + Value rhsNaN = isNaN(rewriter, loc, adaptor.getRhs()); + + Value cmpLt = rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, + adaptor.getLhs(), adaptor.getRhs()) + .getResult(); + Value maxNoNaN = + rewriter + .create(loc, dstTy, cmpLt, adaptor.getRhs(), + adaptor.getLhs()) + .getResult(); + + Value rhsOrMax = + rewriter + .create(loc, dstTy, rhsNaN, adaptor.getLhs(), + maxNoNaN) + .getResult(); + Value res = + rewriter + .create(loc, dstTy, lhsNaN, adaptor.getRhs(), + rhsOrMax) + 
.getResult(); + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct ArithMinNumFToEmitC : public OpConversionPattern, + ArithFloatMinMaxToEmitCBase { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::MinNumFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type dstTy = getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + + Value lhsNaN = isNaN(rewriter, loc, adaptor.getLhs()); + Value rhsNaN = isNaN(rewriter, loc, adaptor.getRhs()); + + Value cmpLt = rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, + adaptor.getLhs(), adaptor.getRhs()) + .getResult(); + Value minNoNaN = + rewriter + .create(loc, dstTy, cmpLt, adaptor.getLhs(), + adaptor.getRhs()) + .getResult(); + + Value rhsOrMin = + rewriter + .create(loc, dstTy, rhsNaN, adaptor.getLhs(), + minNoNaN) + .getResult(); + Value res = + rewriter + .create(loc, dstTy, lhsNaN, adaptor.getRhs(), + rhsOrMin) + .getResult(); + rewriter.replaceOp(op, res); + return success(); + } +}; + +template +struct ArithMinMaxFPropagateNaNToEmitC : public OpConversionPattern, + ArithFloatMinMaxToEmitCBase { + using OpConversionPattern::OpConversionPattern; + + static Value buildPrimaryCandidate(ConversionPatternRewriter &rewriter, + Location loc, Type dstTy, Value lhs, + Value rhs) { + Value cmpLt = + rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, lhs, rhs) + .getResult(); + return rewriter + .create( + loc, dstTy, cmpLt, isMaximum ? rhs : lhs, isMaximum ? 
lhs : rhs) + .getResult(); + } + + static Value buildSignBitValue(ConversionPatternRewriter &rewriter, + Location loc, Value lhs, FloatType floatTy) { + auto bitsTy = + getUnsignedIntOpaqueType(rewriter.getContext(), floatTy.getWidth()); + auto templateArgs = rewriter.getArrayAttr({emitc::OpaqueAttr::get( + rewriter.getContext(), cast(bitsTy).getValue())}); + Value lhsBits = + rewriter + .create(loc, TypeRange{bitsTy}, "ptoas_bitcast", + ValueRange{lhs}, ArrayAttr{}, + templateArgs) + .getResult(0); + Value oneBits = makeEmitCIntConstant(rewriter, loc, bitsTy, 1); + Value shiftAmount = + makeEmitCIntConstant(rewriter, loc, bitsTy, floatTy.getWidth() - 1); + Value signMask = rewriter + .create(loc, bitsTy, oneBits, + shiftAmount) + .getResult(); + return rewriter + .create(loc, bitsTy, lhsBits, signMask) + .getResult(); + } + + static Value buildSignedZeroCandidate(ConversionPatternRewriter &rewriter, + Location loc, Type dstTy, Value lhs, + Value rhs, FloatType floatTy) { + Value zero = makeFZero(rewriter, loc, dstTy); + Value equal = rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::eq, lhs, rhs) + .getResult(); + Value lhsZero = rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::eq, lhs, + zero) + .getResult(); + Value bothZero = rewriter + .create(loc, rewriter.getI1Type(), + equal, lhsZero) + .getResult(); + auto bitsTy = + getUnsignedIntOpaqueType(rewriter.getContext(), floatTy.getWidth()); + Value zeroBits = makeEmitCIntConstant(rewriter, loc, bitsTy, 0); + Value lhsIsNegZero = + rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::ne, + buildSignBitValue(rewriter, loc, lhs, floatTy), + zeroBits) + .getResult(); + Value tie = rewriter + .create( + loc, dstTy, lhsIsNegZero, isMaximum ? rhs : lhs, + isMaximum ? 
lhs : rhs) + .getResult(); + return rewriter + .create(loc, dstTy, bothZero, tie, + buildPrimaryCandidate(rewriter, loc, dstTy, + lhs, rhs)) + .getResult(); + } + + static Value buildNaNPropagatingResult(ConversionPatternRewriter &rewriter, + Location loc, Type dstTy, Value lhs, + Value rhs, FloatType floatTy) { + Value lhsNaN = isNaN(rewriter, loc, lhs); + Value rhsNaN = isNaN(rewriter, loc, rhs); + Value noNaN = + buildSignedZeroCandidate(rewriter, loc, dstTy, lhs, rhs, floatTy); + Value rhsOrNoNaN = rewriter + .create(loc, dstTy, rhsNaN, rhs, + noNaN) + .getResult(); + return rewriter + .create(loc, dstTy, lhsNaN, lhs, rhsOrNoNaN) + .getResult(); + } + + LogicalResult + matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!isa(op.getType())) + return rewriter.notifyMatchFailure(op, "expected scalar float type"); + + auto loc = op.getLoc(); + Type dstTy = this->getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + + auto floatTy = cast(op.getType()); + rewriter.replaceOp(op, buildNaNPropagatingResult( + rewriter, loc, dstTy, adaptor.getLhs(), + adaptor.getRhs(), floatTy)); + return success(); + } +}; + +using ArithMaximumFToEmitC = + ArithMinMaxFPropagateNaNToEmitC; +using ArithMinimumFToEmitC = + ArithMinMaxFPropagateNaNToEmitC; + +//===----------------------------------------------------------------------===// +// Arith -> EmitC helpers +//===----------------------------------------------------------------------===// + +static emitc::OpaqueType getSignedIntOpaqueType(MLIRContext *ctx, + unsigned bitWidth) { + switch (bitWidth) { + case 1: + return emitc::OpaqueType::get(ctx, "int8_t"); + case 8: + return emitc::OpaqueType::get(ctx, "int8_t"); + case 16: + return emitc::OpaqueType::get(ctx, "int16_t"); + case 32: + return emitc::OpaqueType::get(ctx, "int32_t"); + case 64: + return emitc::OpaqueType::get(ctx, "int64_t"); + case 128: + return 
emitc::OpaqueType::get(ctx, "__int128"); + default: + llvm::errs() << "[Debug] Unsupported signed integer bitwidth: " << bitWidth + << "\n"; + return emitc::OpaqueType::get(ctx, "int64_t"); + } +} + +static emitc::OpaqueType getUnsignedIntOpaqueType(MLIRContext *ctx, + unsigned bitWidth) { + switch (bitWidth) { + case 1: + return emitc::OpaqueType::get(ctx, "uint8_t"); + case 8: + return emitc::OpaqueType::get(ctx, "uint8_t"); + case 16: + return emitc::OpaqueType::get(ctx, "uint16_t"); + case 32: + return emitc::OpaqueType::get(ctx, "uint32_t"); + case 64: + return emitc::OpaqueType::get(ctx, "uint64_t"); + case 128: + return emitc::OpaqueType::get(ctx, "unsigned __int128"); + default: + llvm::errs() << "[Debug] Unsupported unsigned integer bitwidth: " + << bitWidth << "\n"; + return emitc::OpaqueType::get(ctx, "uint64_t"); + } +} + +static emitc::OpaqueType getWiderSignedIntOpaqueType(MLIRContext *ctx, + unsigned bitWidth) { + switch (bitWidth) { + case 1: + case 8: + return getSignedIntOpaqueType(ctx, 16); + case 16: + return getSignedIntOpaqueType(ctx, 32); + case 32: + return getSignedIntOpaqueType(ctx, 64); + case 64: + return getSignedIntOpaqueType(ctx, 128); + default: + return getSignedIntOpaqueType(ctx, 128); + } +} + +static emitc::OpaqueType getWiderUnsignedIntOpaqueType(MLIRContext *ctx, + unsigned bitWidth) { + switch (bitWidth) { + case 1: + case 8: + return getUnsignedIntOpaqueType(ctx, 16); + case 16: + return getUnsignedIntOpaqueType(ctx, 32); + case 32: + return getUnsignedIntOpaqueType(ctx, 64); + case 64: + return getUnsignedIntOpaqueType(ctx, 128); + default: + return getUnsignedIntOpaqueType(ctx, 128); + } +} + +static Value makeEmitCOpaqueConstant(ConversionPatternRewriter &rewriter, + Location loc, Type type, + llvm::StringRef literal) { + auto attr = emitc::OpaqueAttr::get(rewriter.getContext(), literal); + return rewriter.create(loc, type, attr); +} + +static Value makeEmitCIntConstant(ConversionPatternRewriter &rewriter, + Location loc, 
Type type, int64_t value) { + return makeEmitCOpaqueConstant(rewriter, loc, type, std::to_string(value)); +} + +static FailureOr buildEmitCOpaqueConstantLiteral(Type targetType, + Attribute valueAttr) { + auto opaqueTy = dyn_cast(targetType); + if (!opaqueTy) + return failure(); + + if (opaqueTy.getValue() == "pto::MrgSortExecutedNumList") { + auto dense = dyn_cast_or_null(valueAttr); + if (!dense) + return failure(); + + auto vecTy = dyn_cast(dense.getType()); + if (!vecTy || vecTy.getRank() != 1 || vecTy.getNumElements() != 4 || + !vecTy.getElementType().isInteger(16)) + return failure(); + + std::string literal; + llvm::raw_string_ostream os(literal); + os << "pto::MrgSortExecutedNumList{"; + bool first = true; + for (APInt elem : dense.getValues()) { + if (!first) + os << ", "; + first = false; + os << elem.getZExtValue(); + } + os << "}"; + os.flush(); + return literal; + } + + return failure(); +} + +static Value emitCCast(ConversionPatternRewriter &rewriter, Location loc, + Type dstType, Value src) { + if (src.getType() == dstType) + return src; + return rewriter.createOrFold(loc, dstType, src); +} + +// For signless iN integers lowered to signed C++ types, this creates a value +// representing the same N-bit pattern in an unsigned C++ type of the same +// width. This avoids incorrect sign-extension when later widening to a larger +// unsigned type. 
+static Value castSignlessIntToUnsignedSameWidth(ConversionPatternRewriter &rewriter, + Location loc, Value v, + unsigned bitWidth) { + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + return emitCCast(rewriter, loc, uTy, v); +} + +struct ArithMulIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(arith::MulIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const unsigned bitWidth = + intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + // i1 mul is equivalent to bitwise AND (mod 2 arithmetic). + if (bitWidth == 1) { + rewriter.replaceOpWithNewOp(op, opTy, adaptor.getLhs(), + adaptor.getRhs()); + return success(); + } + + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value mulU = rewriter.create(loc, uTy, lhsU, rhsU); + Value result = emitCCast(rewriter, loc, dstTy, mulU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithAddIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(arith::AddIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const 
unsigned bitWidth = + intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + // i1 add is equivalent to XOR (mod 2 arithmetic). + if (bitWidth == 1) { + rewriter.replaceOpWithNewOp(op, opTy, adaptor.getLhs(), + adaptor.getRhs()); + return success(); + } + + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value addU = rewriter.create(loc, uTy, lhsU, rhsU); + Value result = emitCCast(rewriter, loc, dstTy, addU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithCastOPToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type newTy = getTypeConverter()->convertType(op.getType()); + if (!newTy) + return failure(); + if (adaptor.getIn().getType() == newTy) { + rewriter.replaceOp(op, adaptor.getIn()); + return success(); + } + rewriter.replaceOpWithNewOp(op, newTy, adaptor.getIn()); + return success(); + } +}; + +struct ArithSubIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(arith::SubIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const unsigned bitWidth = + intTy ? 
intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + // i1 sub is equivalent to XOR (mod 2 arithmetic). + if (bitWidth == 1) { + rewriter.replaceOpWithNewOp(op, opTy, adaptor.getLhs(), + adaptor.getRhs()); + return success(); + } + + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value subU = rewriter.create(loc, uTy, lhsU, rhsU); + Value result = emitCCast(rewriter, loc, dstTy, subU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithDivSIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(arith::DivSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type newTy = getTypeConverter()->convertType(op.getType()); + if (!newTy) + return failure(); + rewriter.replaceOpWithNewOp(op, newTy, adaptor.getLhs(), + adaptor.getRhs()); + return success(); + } +}; + +struct ArithRemSIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type newTy = getTypeConverter()->convertType(op.getType()); + if (!newTy) + return failure(); + rewriter.replaceOpWithNewOp(op, newTy, adaptor.getLhs(), + adaptor.getRhs()); + return success(); + } +}; + +struct ArithTruncIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + auto dstIntTy = dyn_cast(op.getType()); + auto srcIntTy = 
dyn_cast(op.getIn().getType()); + if (!dstIntTy || !srcIntTy) + return rewriter.notifyMatchFailure(op, "expected scalar integer types"); + + Type dstTy = getTypeConverter()->convertType(dstIntTy); + if (!dstTy) + return failure(); + + // to-i1 conversions: Arith wants truncation to the low bit, while C/C++ + // casts to bool are equivalent to `v != 0`. Implement as `(bool)(v & 1)`. + if (dstIntTy.getWidth() == 1) { + if (srcIntTy.getWidth() == 1) { + rewriter.replaceOp(op, adaptor.getIn()); + return success(); + } + + auto uSrcTy = + getUnsignedIntOpaqueType(rewriter.getContext(), srcIntTy.getWidth()); + Value inU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getIn(), + srcIntTy.getWidth()); + Value one = makeEmitCIntConstant(rewriter, loc, uSrcTy, 1); + Value masked = + rewriter.create(loc, uSrcTy, inU, one); + Value asBool = emitCCast(rewriter, loc, dstTy, masked); + rewriter.replaceOp(op, asBool); + return success(); + } + + rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); + return success(); + } +}; + +struct ArithConstantToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type newType = getTypeConverter()->convertType(op.getType()); + if (!newType) + return failure(); + + // `adaptor.getValue()` may be null if attribute conversion isn't defined. + // Use the original attribute as fallback and always cast null-safely. 
+ Attribute valueAttr = adaptor.getValue(); + if (!valueAttr) + valueAttr = op.getValue(); + + if (auto opaqueLiteral = buildEmitCOpaqueConstantLiteral(newType, valueAttr); + succeeded(opaqueLiteral)) { + auto constAttr = emitc::OpaqueAttr::get(rewriter.getContext(), *opaqueLiteral); + rewriter.replaceOpWithNewOp(op, newType, constAttr); + return success(); + } + + if (auto floatAttr = dyn_cast_or_null(valueAttr)) { + SmallString<32> valStr; + floatAttr.getValue().toString(valStr); + llvm::StringRef s(valStr); + // Ensure the literal parses as a floating-point constant in C/C++. + // `APFloat::toString` may emit "1" for integral values; make it "1.0". + const bool hasFloatMarker = + s.contains('.') || s.contains('e') || s.contains('E') || + s.contains('p') || s.contains('P') || s.starts_with("0x") || + s.starts_with("0X") || s.starts_with("nan") || + s.starts_with("-nan") || s.starts_with("inf") || + s.starts_with("-inf"); + if (!hasFloatMarker) + valStr.append(".0"); + // Suffix: keep `f` for f16/f32; omit for f64. 
+ if (!floatAttr.getType().isF64()) + valStr.append("f"); + auto constAttr = emitc::OpaqueAttr::get(rewriter.getContext(), valStr); + rewriter.replaceOpWithNewOp(op, newType, constAttr); + return success(); + } + + if (auto intAttr = dyn_cast_or_null(valueAttr)) { + std::string valStr = std::to_string(intAttr.getValue().getSExtValue()); + auto constAttr = emitc::OpaqueAttr::get(rewriter.getContext(), valStr); + rewriter.replaceOpWithNewOp(op, newType, constAttr); + return success(); + } + + return failure(); + } +}; + +} // namespace + +void populatePTOToEmitCArithPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter, + MLIRContext *ctx) { + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add>( + typeConverter, ctx); + patterns.add>( + typeConverter, ctx); + patterns.add>( + typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add>(typeConverter, ctx); + patterns.add>(typeConverter, ctx); + patterns.add>(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + 
patterns.add(typeConverter, ctx); + patterns.add>(typeConverter, ctx); + patterns.add>(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add>(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add>(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); +} + +} // namespace mlir::pto diff --git a/lib/PTO/Transforms/PTOToEmitCInternal.h b/lib/PTO/Transforms/PTOToEmitCInternal.h new file mode 100644 index 000000000..0d43b8a1b --- /dev/null +++ b/lib/PTO/Transforms/PTOToEmitCInternal.h @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#ifndef MLIR_DIALECT_PTO_TRANSFORMS_PTOTOEMITCINTERNAL_H +#define MLIR_DIALECT_PTO_TRANSFORMS_PTOTOEMITCINTERNAL_H + +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir::pto { + +void populatePTOToEmitCArithPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter, + MLIRContext *ctx); + +} // namespace mlir::pto + +#endif // MLIR_DIALECT_PTO_TRANSFORMS_PTOTOEMITCINTERNAL_H diff --git a/lib/PTO/Transforms/PTOViewToMemref.cpp b/lib/PTO/Transforms/PTOViewToMemref.cpp index c21669b81..63f2c8687 100644 --- a/lib/PTO/Transforms/PTOViewToMemref.cpp +++ b/lib/PTO/Transforms/PTOViewToMemref.cpp @@ -15,6 +15,7 @@ #include "PTO/IR/PTO.h" #include "PTO/IR/PTOTypeUtils.h" #include "PTO/Transforms/Passes.h" +#include "PTOViewToMemrefInternal.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -98,11 +99,6 @@ constexpr int32_t kSLayoutColMajor = constexpr int32_t kCompactModeRowPlusOne = static_cast(CompactMode::RowPlusOne); -constexpr unsigned kThirdOperandIndex = 2; -constexpr unsigned kFourthOperandIndex = 3; -constexpr unsigned kFifthOperandIndex = 4; -constexpr unsigned kSixthOperandIndex = 5; - template using SmallInlineVector = SmallVector; @@ -1804,1781 +1800,12 @@ struct PTOViewToMemrefPass // ------------------------------------------------------------------ // Stage 3: Rewrite Compute Ops - // [关键] 全面使用 op->getOperand(i) 避免 Typed Accessor Crash // ------------------------------------------------------------------ - - // --- TLoadOp [Src, Dst] --- - DefaultInlineVector loads; - func.walk([&](mlir::pto::TLoadOp op) { loads.push_back(op); }); - for (auto op : loads) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op->getOperand(0); - Value dst = op->getOperand(1); - - auto newOp = - rewriter.create(op.getLoc(), TypeRange{}, src, dst); - newOp->setAttrs(op->getAttrs()); - rewriter.replaceOp(op, 
newOp->getResults()); - } - - // --- TStoreOp [Src, Dst] --- - DefaultInlineVector storeops; - func.walk([&](mlir::pto::TStoreOp op) { storeops.push_back(op); }); - for (auto op : storeops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op->getOperand(0); - Value dst = op->getOperand(1); - Value preQuant = op.getPreQuantScalar(); - - pto::TStoreOp newOp; - if (preQuant) { - newOp = rewriter.create(op.getLoc(), TypeRange{}, - src, dst, preQuant); - } else { - newOp = rewriter.create(op.getLoc(), TypeRange{}, - src, dst, Value{}); - } - newOp->setAttrs(op->getAttrs()); - rewriter.replaceOp(op, newOp->getResults()); - } - - // --- TTransOp [Src, Tmp, Dst] --- - DefaultInlineVector trans; - func.walk([&](mlir::pto::TTransOp op) { trans.push_back(op); }); - for (auto op : trans) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex)); - } - - // --- TExpOp [Src, Dst] --- - DefaultInlineVector exp; - func.walk([&](mlir::pto::TExpOp op) { exp.push_back(op); }); - for (auto op : exp) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, op->getOperand(0), op->getOperand(1)); - } - - // --- TMulOp [Src, Scalar, Dst] --- - DefaultInlineVector mul; - func.walk([&](mlir::pto::TMulOp op) { mul.push_back(op); }); - for (auto op : mul) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, op->getOperand(0), op.getOperand(1), - op->getOperand(kThirdOperandIndex)); - } - - // --- TMulSOp [Src, Scalar, Dst] --- - DefaultInlineVector muls; - func.walk([&](mlir::pto::TMulSOp op) { muls.push_back(op); }); - for (auto op : muls) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, op->getOperand(0), op.getScalar(), - op->getOperand(kThirdOperandIndex)); - } - - // --- 
TAddOp [Src0, Src1, Dst] --- - DefaultInlineVector addops; - func.walk([&](mlir::pto::TAddOp op) { addops.push_back(op); }); - for (auto op : addops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex)); - } - - // --- TMatmulOp [Lhs, Rhs, Dst] (no optional bias in ODS) --- - DefaultInlineVector matmuls; - func.walk([&](mlir::pto::TMatmulOp op) { matmuls.push_back(op); }); - for (auto op : matmuls) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - Value lhs = op->getOperand(0); - Value rhs = op->getOperand(1); - Value dst = op->getOperand(kThirdOperandIndex); - - rewriter.replaceOpWithNewOp( - op, TypeRange{}, lhs, rhs, dst, op.getAccPhaseAttr()); - } - - // --- TMatmulAccOp [Acc, Lhs, Rhs, Dst] --- - DefaultInlineVector matmulAccs; - func.walk([&](mlir::pto::TMatmulAccOp op) { matmulAccs.push_back(op); }); - for (auto op : matmulAccs) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex), op.getAccPhaseAttr()); - } - - // --- TMatmulBiasOp [Acc, Lhs, Rhs, Bias, Dst] --- - DefaultInlineVector matmulBiass; - func.walk([&](mlir::pto::TMatmulBiasOp op) { matmulBiass.push_back(op); }); - for (auto op : matmulBiass) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex)); - } - - // --- TMatmulMxOp--- - DefaultInlineVector matmulMxs; - func.walk([&](mlir::pto::TMatmulMxOp op) { matmulMxs.push_back(op); }); - for (auto op : matmulMxs) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), 
op->getOperand(1), - op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex), - op->getOperand(kFifthOperandIndex)); - } - - // --- TMatmulMxAccOp --- - DefaultInlineVector matmulMxAccs; - func.walk([&](mlir::pto::TMatmulMxAccOp op) { matmulMxAccs.push_back(op); }); - for (auto op : matmulMxAccs) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex), - op->getOperand(kFifthOperandIndex), - op->getOperand(kSixthOperandIndex)); - } - - // --- TMatmulMxBiasOp --- - DefaultInlineVector matmulMxBiass; - func.walk([&](mlir::pto::TMatmulMxBiasOp op) { matmulMxBiass.push_back(op); }); - for (auto op : matmulMxBiass) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex), - op->getOperand(kFifthOperandIndex), - op->getOperand(kSixthOperandIndex)); - } - - // --- TGemvOp [Lhs, Rhs, Dst] --- - DefaultInlineVector gemvs; - func.walk([&](mlir::pto::TGemvOp op) { gemvs.push_back(op); }); - for (auto op : gemvs) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value lhs = op->getOperand(0); - Value rhs = op->getOperand(1); - Value dst = op->getOperand(kThirdOperandIndex); - - rewriter.replaceOpWithNewOp( - op, TypeRange{}, lhs, rhs, dst); - } - - // --- TGemvAccOp [Acc, Lhs, Rhs, Dst] --- - DefaultInlineVector gemvAccs; - func.walk([&](mlir::pto::TGemvAccOp op) { gemvAccs.push_back(op); }); - for (auto op : gemvAccs) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex)); - } - - // --- TGemvBiasOp [Acc, Lhs, Rhs, Bias, Dst] 
--- - DefaultInlineVector gemvBiass; - func.walk([&](mlir::pto::TGemvBiasOp op) { gemvBiass.push_back(op); }); - for (auto op : gemvBiass) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex)); - } - - // --- TGemvMxOp [A, AScale, B, BScale, Dst] --- - DefaultInlineVector gemvMxs; - func.walk([&](mlir::pto::TGemvMxOp op) { gemvMxs.push_back(op); }); - for (auto op : gemvMxs) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex), - op->getOperand(kFifthOperandIndex)); - } - - // --- TGemvMxAccOp [CIn, A, AScale, B, BScale, Dst] --- - DefaultInlineVector gemvMxAccs; - func.walk([&](mlir::pto::TGemvMxAccOp op) { gemvMxAccs.push_back(op); }); - for (auto op : gemvMxAccs) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex), - op->getOperand(kFifthOperandIndex), - op->getOperand(kSixthOperandIndex)); - } - - // --- TGemvMxBiasOp [A, AScale, B, BScale, Bias, Dst] --- - DefaultInlineVector gemvMxBiass; - func.walk([&](mlir::pto::TGemvMxBiasOp op) { gemvMxBiass.push_back(op); }); - for (auto op : gemvMxBiass) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex), - op->getOperand(kFifthOperandIndex), - op->getOperand(kSixthOperandIndex)); - } - - // --- TMovOp [Src, Dst] --- - DefaultInlineVector movs; - func.walk([&](mlir::pto::TMovOp op) { movs.push_back(op); }); - for (auto op 
: movs) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, op.getSrc(), op.getDst(), op.getFp(), - op.getPreQuantScalar(), op.getAccToVecModeAttr(), - op.getReluPreModeAttr()); - } - - DefaultInlineVector abseops; - func.walk([&](mlir::pto::TAbsOp op) { abseops.push_back(op); }); - - for (auto op : abseops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst); - } - - DefaultInlineVector addcops; - func.walk([&](mlir::pto::TAddCOp op) { addcops.push_back(op); }); - - for (auto op : addcops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value src2 = op.getSrc2(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto src2Ty = dyn_cast(src2.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !src2Ty ||!dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - src1, - src2, - dst); - } - - DefaultInlineVector addsops; - func.walk([&](mlir::pto::TAddSOp op) { addsops.push_back(op); }); - - for (auto op : addsops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value scalar = op.getScalar(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - 
scalar, - dst); - } - - DefaultInlineVector addscops; - func.walk([&](mlir::pto::TAddSCOp op) { addscops.push_back(op); }); - - for (auto op : addscops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value scalar = op.getScalar(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - scalar, - src1, - dst); - } - - DefaultInlineVector andops; - func.walk([&](mlir::pto::TAndOp op) { andops.push_back(op); }); - - for (auto op : andops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - src1, - dst); - } - - DefaultInlineVector concats; - func.walk([&](mlir::pto::TConcatOp op) { concats.push_back(op); }); - - for (auto op : concats) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - src1, - dst); - } - - DefaultInlineVector concatIdxs; - func.walk([&](mlir::pto::TConcatidxOp op) { 
concatIdxs.push_back(op); }); - - IRRewriter rewriter(ctx); - for (auto op : concatIdxs) { - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value src0Idx = op.getSrc0Idx(); - Value src1Idx = op.getSrc1Idx(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto src0IdxTy = dyn_cast(src0Idx.getType()); - auto src1IdxTy = dyn_cast(src1Idx.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !src0IdxTy || !src1IdxTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - src1, - src0Idx, - src1Idx, - dst); - } - - DefaultInlineVector andsops; - func.walk([&](mlir::pto::TAndSOp op) { andsops.push_back(op); }); - - for (auto op : andsops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value scalar = op.getScalar(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - scalar, - dst); - } - - DefaultInlineVector ciops; - func.walk([&](mlir::pto::TCIOp op) { ciops.push_back(op); }); - - for (auto op : ciops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value s = op->getOperand(0); - Value dst = op.getDst(); - bool descending = op.getDescending(); - - auto sTy = dyn_cast(s.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!sTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - s, - dst, - descending); - } - - DefaultInlineVector cmpops; - func.walk([&](mlir::pto::TCmpOp op) { cmpops.push_back(op); }); - - for (auto op : cmpops) { - 
IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - auto newOp = rewriter.create( - op.getLoc(), - TypeRange{}, - src0, - src1, - dst); - - if (auto a = op.getCmpModeAttr()) - newOp->setAttr("cmpMode", a); - - rewriter.replaceOp(op, newOp->getResults()); // 0 results -> OK - } - - DefaultInlineVector cmpsops; - func.walk([&](mlir::pto::TCmpSOp op) { cmpsops.push_back(op); }); - - for (auto op : cmpsops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value scalar = op.getScalar(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - auto scalarTy = scalar.getType(); - bool scalarOk = - isa(scalarTy); // ScalarType in ODS: int/float - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - if (!scalarOk) { - op.emitError("expects scalar to be an integer or float type"); - signalPassFailure(); - return; - } - - auto cmpMode = op.getCmpModeAttr(); - auto newOp = rewriter.create( - op.getLoc(), - TypeRange{}, - src, - scalar, - cmpMode, - dst); - - rewriter.replaceOp(op, newOp->getResults()); // 0 results -> OK - } - - DefaultInlineVector colexpand; - func.walk([&](mlir::pto::TColExpandOp op) { colexpand.push_back(op); }); - - for (auto op : colexpand) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if ( !srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - 
rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst); - } - - DefaultInlineVector colmaxops; - func.walk([&](mlir::pto::TColMaxOp op) { colmaxops.push_back(op); }); - - for (auto op : colmaxops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if ( !srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst); - } - - DefaultInlineVector colminops; - func.walk([&](mlir::pto::TColMinOp op) { colminops.push_back(op); }); - - for (auto op : colminops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if ( !srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst); - } - - DefaultInlineVector colexpandmulops; - func.walk([&](mlir::pto::TColExpandMulOp op) { - colexpandmulops.push_back(op); - }); - - for (auto op : colexpandmulops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - src1, - dst); - } - - DefaultInlineVector colexpandmaxops; - func.walk([&](mlir::pto::TColExpandMaxOp op) { - colexpandmaxops.push_back(op); - }); - - for (auto op : colexpandmaxops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value 
src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - src1, - dst); - } - - DefaultInlineVector colexpandminops; - func.walk([&](mlir::pto::TColExpandMinOp op) { - colexpandminops.push_back(op); - }); - - for (auto op : colexpandminops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - src1, - dst); - } - - DefaultInlineVector colsumops; - func.walk([&](mlir::pto::TColSumOp op) { colsumops.push_back(op); }); - - for (auto op : colsumops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - Value tmp = op.getTmp(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !dstTy) { - op.emitError("src/dst are not memref yet"); - signalPassFailure(); - return; - } - - // If tmp exists, it must have isBinary attribute - if (tmp) { - auto tmpTy = dyn_cast(tmp.getType()); - if (!tmpTy) { - op.emitError("tmp is not memref yet"); - signalPassFailure(); - return; - } - - // Get isBinary attribute (should exist if tmp exists) - BoolAttr isBinaryAttr = op.getIsBinaryAttr(); - if (!isBinaryAttr) { - isBinaryAttr = BoolAttr::get(ctx, false); - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - tmp, - dst, - isBinaryAttr); 
- } else { - // Format 1: no tmp, no isBinary - // Use generic builder to avoid adding default isBinary attribute - SmallVector operands = {src, dst}; - SmallVector attrs; - // Copy all attributes except isBinary - for (auto attr : op->getAttrs()) { - if (attr.getName() != "isBinary") { - attrs.push_back(attr); - } - } - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - operands, - attrs); - } - } - - DefaultInlineVector cvtops; - func.walk([&](mlir::pto::TCvtOp op) { cvtops.push_back(op); }); - - for (auto op : cvtops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - auto rmodeAttr = op.getRmodeAttr(); // PTO_RoundModeAttr - auto satModeAttr = op.getSatModeAttr(); - - auto newOp = rewriter.create( - op.getLoc(), - TypeRange{}, - src, - dst, - rmodeAttr, - satModeAttr); - - rewriter.replaceOp(op, newOp->getResults()); - } - - DefaultInlineVector divops; - func.walk([&](mlir::pto::TDivOp op) { divops.push_back(op); }); - - for (auto op : divops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - src1, - dst); - } - - DefaultInlineVector divsops; - func.walk([&](mlir::pto::TDivSOp op) { divsops.push_back(op); }); - - for (auto op : divsops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value scale = op.getScalar(); - Value dst = op.getDst(); - - // 
Check types - they might still be TileBufType or already converted to MemRefType - auto srcTy = dyn_cast(src.getType()); - auto srcTileTy = dyn_cast(src.getType()); - auto scaleTileTy = dyn_cast(scale.getType()); - auto dstTy = dyn_cast(dst.getType()); - auto dstTileTy = dyn_cast(dst.getType()); - - // Determine which operand is tile-like and which is scalar-like. - // Keep the original operand order (set by parser textual form). - // Check if src is memref/tensor/tile (not scalar) - bool srcIsMemref = (srcTy != nullptr || srcTileTy != nullptr || - isa(src.getType()) || - isa(src.getType())); - // Check if scale is memref/tensor/tile (not scalar) - bool scaleIsMemref = (isa(scale.getType()) || - scaleTileTy != nullptr || - isa(scale.getType()) || - isa(scale.getType())); - - // Type validation - ensure we have the right types - if (!srcIsMemref && !scaleIsMemref) { - op.emitError("at least one operand (src or scale) must be tile_buf or memref"); - signalPassFailure(); - return; - } - if (srcIsMemref && scaleIsMemref) { - op.emitError("exactly one operand (src or scale) must be tile_buf or memref, the other must be scalar"); - signalPassFailure(); - return; - } - - if (!dstTy && !dstTileTy) { - op.emitError("dst operand must be tile_buf or memref"); - signalPassFailure(); - return; - } - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - scale, - dst); - } - - DefaultInlineVector expandsops; - func.walk([&](mlir::pto::TExpandsOp op) { expandsops.push_back(op); }); - - for (auto op : expandsops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value scalar = op.getScalar(); - Value dst = op.getDst(); - - auto dstTy = dyn_cast(dst.getType()); - if (!dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - scalar, - dst); - } - - DefaultInlineVector extractops; - func.walk([&](mlir::pto::TExtractOp op) { extractops.push_back(op); }); - - for (auto op : 
extractops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value indexRow = op.getIndexRow(); - Value indexCol = op.getIndexCol(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto indexRowTy = dyn_cast(indexRow.getType()); - auto indexColTy = dyn_cast(indexCol.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !indexRowTy || !indexColTy || !dstTy) { - op.emitError("ins/outs are not correct yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - indexRow, - indexCol, - dst); - } - - DefaultInlineVector fillpadops; - func.walk([&](mlir::pto::TFillPadOp op) { fillpadops.push_back(op); }); - - for (auto op : fillpadops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst); - } - - DefaultInlineVector fillpadInplaceOps; - func.walk( - [&](mlir::pto::TFillPadInplaceOp op) { fillpadInplaceOps.push_back(op); }); - - for (auto op : fillpadInplaceOps) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst); - } - - // --- TSetValOp [Dst, Offset, Val] --- - // Lower tile-world scalar write to memref-world SETVAL DPS op. 
- DefaultInlineVector tsetvalops; - func.walk([&](mlir::pto::TSetValOp op) { tsetvalops.push_back(op); }); - - for (auto op : tsetvalops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value dst = op.getDst(); - Value offset = op.getOffset(); - Value val = op.getVal(); - - auto dstTy = dyn_cast(dst.getType()); - if (!dstTy) { - op.emitError("dst is not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - dst, - offset, - val); - } - - // --- TGetValOp [Src, Offset] -> Scalar --- - // Lower tile-world scalar read to memref-world GETVAL DPS op. - DefaultInlineVector tgetvalops; - func.walk([&](mlir::pto::TGetValOp op) { tgetvalops.push_back(op); }); - - for (auto op : tgetvalops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value offset = op.getOffset(); - Type dstType = op.getDst().getType(); - - auto srcTy = dyn_cast(src.getType()); - if (!srcTy) { - op.emitError("src is not memref yet"); - signalPassFailure(); - return; - } - - auto newOp = rewriter.create( - op.getLoc(), - dstType, - src, - offset); - rewriter.replaceOp(op, newOp.getDst()); - } - - DefaultInlineVector gatherops; - func.walk([&](mlir::pto::TGatherOp op) { gatherops.push_back(op); }); - - for (auto op : gatherops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - Value cdst = op.getCdst(); - Value indices = op.getIndices(); - Value tmp = op.getTmp(); - Value kValue = op.getKValue(); - auto maskPattern = op.getMaskPatternAttr(); - auto cmpMode = op.getCmpModeAttr(); - auto offset = op.getOffsetAttr(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - if (maskPattern) { - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst, - /*cdst=*/Value(), - 
/*indices=*/Value(), - /*tmp=*/Value(), - /*kValue=*/Value(), - /*maskPattern=*/maskPattern, - /*cmpMode=*/pto::CmpModeAttr(), - /*offset=*/IntegerAttr()); - continue; - } - - if (cdst || kValue) { - auto cdstTy = dyn_cast(cdst.getType()); - auto tmpTy = dyn_cast(tmp.getType()); - if (!cdstTy || !tmpTy) { - op.emitError("compare-form tgather expects cdst/tmp to be memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst, - cdst, - /*indices=*/Value(), - tmp, - kValue, - /*maskPattern=*/pto::MaskPatternAttr(), - cmpMode, - offset); - continue; - } - - if (indices || tmp) { - auto indicesTy = dyn_cast(indices.getType()); - auto tmpTy = dyn_cast(tmp.getType()); - if (!indicesTy || !tmpTy) { - op.emitError("index-form tgather expects indices/tmp to be memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst, - /*cdst=*/Value(), - indices, - tmp, - /*kValue=*/Value(), - /*maskPattern=*/pto::MaskPatternAttr(), - /*cmpMode=*/pto::CmpModeAttr(), - /*offset=*/IntegerAttr()); - continue; - } - - op.emitError("expects tgather to be in mask, index+tmp, or compare+tmp form"); + if (failed(lowerViewToMemrefComputeOps(func, ctx))) { signalPassFailure(); return; } - DefaultInlineVector gatherbops; - func.walk([&](mlir::pto::TGatherBOp op) { gatherbops.push_back(op); }); - - for (auto op : gatherbops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value offsets = op.getOffsets(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto offsetsTy = dyn_cast(offsets.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !offsetsTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - offsets, - dst); - } - - DefaultInlineVector logops; - func.walk([&](mlir::pto::TLogOp op) { 
logops.push_back(op); }); - - for (auto op : logops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst); - } - - DefaultInlineVector lreluops; - func.walk([&](mlir::pto::TLReluOp op) { lreluops.push_back(op); }); - - for (auto op : lreluops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value slope = op.getSlope(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto slopeTy = dyn_cast(slope.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !slopeTy || !dstTy) { - op.emitError("ins/outs are not correct type yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - slope, - dst); - } - - DefaultInlineVector maxops; - func.walk([&](mlir::pto::TMaxOp op) { maxops.push_back(op); }); - - for (auto op : maxops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - src1, - dst); - } - - DefaultInlineVector maxsops; - func.walk([&](mlir::pto::TMaxSOp op) { maxsops.push_back(op); }); - - for (auto op : maxsops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value scalar = op.getScalar(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto 
dstTy = dyn_cast(dst.getType()); - bool scalarIsScalar = isa(scalar.getType()); - if (!srcTy || !scalarIsScalar || !dstTy) { - op.emitError("expects src/dst to be memref and scalar to be integer/float"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - scalar, - dst); - } - - DefaultInlineVector minops; - func.walk([&](mlir::pto::TMinOp op) { minops.push_back(op); }); - - for (auto op : minops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - src1, - dst); - } - - DefaultInlineVector minsops; - func.walk([&](mlir::pto::TMinSOp op) { minsops.push_back(op); }); - - for (auto op : minsops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value scalar = op.getScalar(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - bool scalarIsScalar = isa(scalar.getType()); - if (!srcTy || !scalarIsScalar || !dstTy) { - op.emitError("expects src/dst to be memref and scalar to be integer/float"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - scalar, - dst); - } - - DefaultInlineVector movfpops; - func.walk([&](mlir::pto::TMovFPOp op) { movfpops.push_back(op); }); - - for (auto op : movfpops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value fp = op.getFp(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto fpTy = dyn_cast(fp.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || 
!fpTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - fp, - dst); - } - - DefaultInlineVector quantops; - func.walk([&](mlir::pto::TQuantOp op) { quantops.push_back(op); }); - - for (auto op : quantops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value fp = op.getFp(); - Value offset = op.getOffset(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto fpTy = dyn_cast(fp.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !fpTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - if (offset && !dyn_cast(offset.getType())) { - op.emitError("offset is not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - fp, - offset, - dst, - op.getQuantTypeAttr()); - } - - DefaultInlineVector mrgsortops; - func.walk([&](mlir::pto::TMrgSortOp op) { mrgsortops.push_back(op); }); - - for (auto op : mrgsortops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - if (op.isFormat1()) { - Value src = op.getSrc(); - Value dst = op.getDst(); - Value blockLenVal = op.getBlockLen(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - ValueRange{src}, - blockLenVal, - ValueRange{dst}, - Value() /*tmp*/, - Value() /*excuted*/, - op.getExhaustedAttr()); - } else if (op.isFormat2()) { - bool allMemRef = true; - for (Value v : op.getSrcs()) - if (!dyn_cast(v.getType())) { allMemRef = false; break; } - if (!allMemRef) { - op.emitError("format2 ins/outs are not memref yet"); - signalPassFailure(); - return; - } - if (op.getDsts().size() != 1u || !op.getTmp()) { - op.emitError("format2 
expects outs(dst) and ins(tmp)"); - signalPassFailure(); - return; - } - - Value dst = op.getDst(); - Value tmp = op.getTmp(); - Value excuted = op.getExcuted(); - if (!dyn_cast(dst.getType()) || !dyn_cast(tmp.getType())) { - op.emitError("format2 dst/tmp must be memref"); - signalPassFailure(); - return; - } - if (!dyn_cast(excuted.getType())) { - op.emitError("format2 outs(excuted) must be vector"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - op.getSrcs(), - Value() /*blockLen*/, - ValueRange{dst}, - tmp, - excuted, - op.getExhaustedAttr()); - } else { - op.emitError("tmrgsort must be format1 or format2"); - signalPassFailure(); - return; - } - } - - DefaultInlineVector negops; - func.walk([&](mlir::pto::TNegOp op) { negops.push_back(op); }); - - for (auto op : negops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst); - } - - DefaultInlineVector notops; - func.walk([&](mlir::pto::TNotOp op) { notops.push_back(op); }); - - for (auto op : notops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst); - } - - DefaultInlineVector orops; - func.walk([&](mlir::pto::TOrOp op) { orops.push_back(op); }); - - for (auto op : orops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - 
- auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - src0, - src1, - dst); - } - - DefaultInlineVector orsops; - func.walk([&](mlir::pto::TOrSOp op) { orsops.push_back(op); }); - - for (auto op : orsops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value scalar = op.getScalar(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto scalarTy = dyn_cast(scalar.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !scalarTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - scalar, - dst); - } - - DefaultInlineVector partaddops; - func.walk([&](mlir::pto::TPartAddOp op) { partaddops.push_back(op); }); - - for (auto op : partaddops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - src0, - src1, - dst); - } - - DefaultInlineVector partmulops; - func.walk([&](mlir::pto::TPartMulOp op) { partmulops.push_back(op); }); - - for (auto op : partmulops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - 
op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - src0, - src1, - dst); - } - - DefaultInlineVector mgatherops; - func.walk([&](mlir::pto::MGatherOp op) { mgatherops.push_back(op); }); - - for (auto op : mgatherops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value dst = op.getDst(); - Value idx = op.getIdx(); - Value mem = op.getMem(); - - auto dstTy = dyn_cast(dst.getType()); - auto idxTy = dyn_cast(idx.getType()); - auto memTy = dyn_cast(mem.getType()); - if (!dstTy || !idxTy || !memTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - mem, - idx, - dst, - op.getGatherOobAttr()); - } - - DefaultInlineVector mascatterops; - func.walk([&](mlir::pto::MScatterOp op) { mascatterops.push_back(op); }); - - for (auto op : mascatterops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value idx = op.getIdx(); - Value mem = op.getMem(); - - auto srcTy = dyn_cast(src.getType()); - auto idxTy = dyn_cast(idx.getType()); - auto memTy = dyn_cast(mem.getType()); - if (!srcTy || !idxTy || !memTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - idx, - mem, - op.getScatterAtomicOpAttr(), - op.getScatterOobAttr()); - } - DefaultInlineVector printops; - func.walk([&](mlir::pto::TPrintOp op) { printops.push_back(op); }); - - for (auto op : printops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - - auto srcTy = dyn_cast(src.getType()); - if (!srcTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src); - } - // ------------------------------------------------------------------ // Stage 4: Reconcile control-flow 
result types // ------------------------------------------------------------------ diff --git a/lib/PTO/Transforms/PTOViewToMemrefCompute.cpp b/lib/PTO/Transforms/PTOViewToMemrefCompute.cpp new file mode 100644 index 000000000..47558fda0 --- /dev/null +++ b/lib/PTO/Transforms/PTOViewToMemrefCompute.cpp @@ -0,0 +1,1760 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- PTOViewToMemrefCompute.cpp ----------------------------------------===// +//===----------------------------------------------------------------------===// + +#include "PTOViewToMemrefInternal.h" + +#include "PTO/IR/PTO.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Support/LLVM.h" + +using namespace mlir; + +namespace mlir::pto { + +namespace { + +template +using DefaultInlineVector = SmallVector; + +constexpr unsigned kThirdOperandIndex = 2; +constexpr unsigned kFourthOperandIndex = 3; +constexpr unsigned kFifthOperandIndex = 4; +constexpr unsigned kSixthOperandIndex = 5; + +} // namespace + +LogicalResult lowerViewToMemrefComputeOps(func::FuncOp func, MLIRContext *ctx) { +// ------------------------------------------------------------------ +// Stage 3: Rewrite Compute Ops +// [关键] 全面使用 op->getOperand(i) 避免 Typed Accessor Crash +// 
------------------------------------------------------------------ + +// --- TLoadOp [Src, Dst] --- +DefaultInlineVector loads; +func.walk([&](mlir::pto::TLoadOp op) { loads.push_back(op); }); +for (auto op : loads) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op->getOperand(0); + Value dst = op->getOperand(1); + + auto newOp = + rewriter.create(op.getLoc(), TypeRange{}, src, dst); + newOp->setAttrs(op->getAttrs()); + rewriter.replaceOp(op, newOp->getResults()); +} + +// --- TStoreOp [Src, Dst] --- +DefaultInlineVector storeops; +func.walk([&](mlir::pto::TStoreOp op) { storeops.push_back(op); }); +for (auto op : storeops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op->getOperand(0); + Value dst = op->getOperand(1); + Value preQuant = op.getPreQuantScalar(); + + pto::TStoreOp newOp; + if (preQuant) { + newOp = rewriter.create(op.getLoc(), TypeRange{}, + src, dst, preQuant); + } else { + newOp = rewriter.create(op.getLoc(), TypeRange{}, + src, dst, Value{}); + } + newOp->setAttrs(op->getAttrs()); + rewriter.replaceOp(op, newOp->getResults()); +} + + // --- TTransOp [Src, Tmp, Dst] --- +DefaultInlineVector trans; +func.walk([&](mlir::pto::TTransOp op) { trans.push_back(op); }); +for (auto op : trans) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex)); +} + +// --- TExpOp [Src, Dst] --- +DefaultInlineVector exp; +func.walk([&](mlir::pto::TExpOp op) { exp.push_back(op); }); +for (auto op : exp) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, op->getOperand(0), op->getOperand(1)); +} + +// --- TMulOp [Src, Scalar, Dst] --- +DefaultInlineVector mul; +func.walk([&](mlir::pto::TMulOp op) { mul.push_back(op); }); +for (auto op : mul) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + 
rewriter.replaceOpWithNewOp( + op, op->getOperand(0), op.getOperand(1), + op->getOperand(kThirdOperandIndex)); +} + +// --- TMulSOp [Src, Scalar, Dst] --- +DefaultInlineVector muls; +func.walk([&](mlir::pto::TMulSOp op) { muls.push_back(op); }); +for (auto op : muls) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, op->getOperand(0), op.getScalar(), + op->getOperand(kThirdOperandIndex)); +} + +// --- TAddOp [Src0, Src1, Dst] --- +DefaultInlineVector addops; +func.walk([&](mlir::pto::TAddOp op) { addops.push_back(op); }); +for (auto op : addops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex)); +} + +// --- TMatmulOp [Lhs, Rhs, Dst] (no optional bias in ODS) --- +DefaultInlineVector matmuls; +func.walk([&](mlir::pto::TMatmulOp op) { matmuls.push_back(op); }); +for (auto op : matmuls) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Value lhs = op->getOperand(0); + Value rhs = op->getOperand(1); + Value dst = op->getOperand(kThirdOperandIndex); + + rewriter.replaceOpWithNewOp( + op, TypeRange{}, lhs, rhs, dst, op.getAccPhaseAttr()); +} + +// --- TMatmulAccOp [Acc, Lhs, Rhs, Dst] --- +DefaultInlineVector matmulAccs; +func.walk([&](mlir::pto::TMatmulAccOp op) { matmulAccs.push_back(op); }); +for (auto op : matmulAccs) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex), + op->getOperand(kFourthOperandIndex), op.getAccPhaseAttr()); +} + +// --- TMatmulBiasOp [Acc, Lhs, Rhs, Bias, Dst] --- +DefaultInlineVector matmulBiass; +func.walk([&](mlir::pto::TMatmulBiasOp op) { matmulBiass.push_back(op); }); +for (auto op : matmulBiass) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + 
op, TypeRange{}, + op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex), + op->getOperand(kFourthOperandIndex)); +} + +// --- TMatmulMxOp--- +DefaultInlineVector matmulMxs; +func.walk([&](mlir::pto::TMatmulMxOp op) { matmulMxs.push_back(op); }); +for (auto op : matmulMxs) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex), + op->getOperand(kFourthOperandIndex), + op->getOperand(kFifthOperandIndex)); +} + +// --- TMatmulMxAccOp --- +DefaultInlineVector matmulMxAccs; +func.walk([&](mlir::pto::TMatmulMxAccOp op) { matmulMxAccs.push_back(op); }); +for (auto op : matmulMxAccs) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex), + op->getOperand(kFourthOperandIndex), + op->getOperand(kFifthOperandIndex), + op->getOperand(kSixthOperandIndex)); +} + +// --- TMatmulMxBiasOp --- +DefaultInlineVector matmulMxBiass; +func.walk([&](mlir::pto::TMatmulMxBiasOp op) { matmulMxBiass.push_back(op); }); +for (auto op : matmulMxBiass) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex), + op->getOperand(kFourthOperandIndex), + op->getOperand(kFifthOperandIndex), + op->getOperand(kSixthOperandIndex)); +} + +// --- TGemvOp [Lhs, Rhs, Dst] --- +DefaultInlineVector gemvs; +func.walk([&](mlir::pto::TGemvOp op) { gemvs.push_back(op); }); +for (auto op : gemvs) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value lhs = op->getOperand(0); + Value rhs = op->getOperand(1); + Value dst = op->getOperand(kThirdOperandIndex); + + rewriter.replaceOpWithNewOp( + op, TypeRange{}, lhs, rhs, dst); +} + +// --- TGemvAccOp [Acc, Lhs, Rhs, Dst] --- 
+DefaultInlineVector gemvAccs; +func.walk([&](mlir::pto::TGemvAccOp op) { gemvAccs.push_back(op); }); +for (auto op : gemvAccs) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex), + op->getOperand(kFourthOperandIndex)); +} + +// --- TGemvBiasOp [Acc, Lhs, Rhs, Bias, Dst] --- +DefaultInlineVector gemvBiass; +func.walk([&](mlir::pto::TGemvBiasOp op) { gemvBiass.push_back(op); }); +for (auto op : gemvBiass) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex), + op->getOperand(kFourthOperandIndex)); +} + +// --- TGemvMxOp [A, AScale, B, BScale, Dst] --- +DefaultInlineVector gemvMxs; +func.walk([&](mlir::pto::TGemvMxOp op) { gemvMxs.push_back(op); }); +for (auto op : gemvMxs) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex), + op->getOperand(kFourthOperandIndex), + op->getOperand(kFifthOperandIndex)); +} + +// --- TGemvMxAccOp [CIn, A, AScale, B, BScale, Dst] --- +DefaultInlineVector gemvMxAccs; +func.walk([&](mlir::pto::TGemvMxAccOp op) { gemvMxAccs.push_back(op); }); +for (auto op : gemvMxAccs) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex), + op->getOperand(kFourthOperandIndex), + op->getOperand(kFifthOperandIndex), + op->getOperand(kSixthOperandIndex)); +} + +// --- TGemvMxBiasOp [A, AScale, B, BScale, Bias, Dst] --- +DefaultInlineVector gemvMxBiass; +func.walk([&](mlir::pto::TGemvMxBiasOp op) { gemvMxBiass.push_back(op); }); +for (auto op : gemvMxBiass) { + IRRewriter rewriter(ctx); + 
rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex), + op->getOperand(kFourthOperandIndex), + op->getOperand(kFifthOperandIndex), + op->getOperand(kSixthOperandIndex)); +} + +// --- TMovOp [Src, Dst] --- +DefaultInlineVector movs; +func.walk([&](mlir::pto::TMovOp op) { movs.push_back(op); }); +for (auto op : movs) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, op.getSrc(), op.getDst(), op.getFp(), + op.getPreQuantScalar(), op.getAccToVecModeAttr(), + op.getReluPreModeAttr()); +} + +DefaultInlineVector abseops; +func.walk([&](mlir::pto::TAbsOp op) { abseops.push_back(op); }); + +for (auto op : abseops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst); +} + +DefaultInlineVector addcops; +func.walk([&](mlir::pto::TAddCOp op) { addcops.push_back(op); }); + +for (auto op : addcops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value src2 = op.getSrc2(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto src2Ty = dyn_cast(src2.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !src2Ty ||!dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src0, + src1, + src2, + dst); +} + +DefaultInlineVector addsops; +func.walk([&](mlir::pto::TAddSOp op) { addsops.push_back(op); }); + +for (auto op : addsops) { + IRRewriter rewriter(ctx); + 
rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value scalar = op.getScalar(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + scalar, + dst); +} + +DefaultInlineVector addscops; +func.walk([&](mlir::pto::TAddSCOp op) { addscops.push_back(op); }); + +for (auto op : addscops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value scalar = op.getScalar(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src0, + scalar, + src1, + dst); +} + +DefaultInlineVector andops; +func.walk([&](mlir::pto::TAndOp op) { andops.push_back(op); }); + +for (auto op : andops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src0, + src1, + dst); +} + +DefaultInlineVector concats; +func.walk([&](mlir::pto::TConcatOp op) { concats.push_back(op); }); + +for (auto op : concats) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = 
dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src0, + src1, + dst); +} + +DefaultInlineVector concatIdxs; +func.walk([&](mlir::pto::TConcatidxOp op) { concatIdxs.push_back(op); }); + +IRRewriter rewriter(ctx); +for (auto op : concatIdxs) { + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value src0Idx = op.getSrc0Idx(); + Value src1Idx = op.getSrc1Idx(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto src0IdxTy = dyn_cast(src0Idx.getType()); + auto src1IdxTy = dyn_cast(src1Idx.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !src0IdxTy || !src1IdxTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src0, + src1, + src0Idx, + src1Idx, + dst); +} + +DefaultInlineVector andsops; +func.walk([&](mlir::pto::TAndSOp op) { andsops.push_back(op); }); + +for (auto op : andsops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value scalar = op.getScalar(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + scalar, + dst); +} + +DefaultInlineVector ciops; +func.walk([&](mlir::pto::TCIOp op) { ciops.push_back(op); }); + +for (auto op : ciops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value s = op->getOperand(0); + Value dst = op.getDst(); + bool descending = op.getDescending(); + + auto sTy = dyn_cast(s.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!sTy || !dstTy) { + op.emitError("ins/outs are not 
memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + s, + dst, + descending); +} + +DefaultInlineVector cmpops; +func.walk([&](mlir::pto::TCmpOp op) { cmpops.push_back(op); }); + +for (auto op : cmpops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + auto newOp = rewriter.create( + op.getLoc(), + TypeRange{}, + src0, + src1, + dst); + + if (auto a = op.getCmpModeAttr()) + newOp->setAttr("cmpMode", a); + + rewriter.replaceOp(op, newOp->getResults()); // 0 results -> OK +} + +DefaultInlineVector cmpsops; +func.walk([&](mlir::pto::TCmpSOp op) { cmpsops.push_back(op); }); + +for (auto op : cmpsops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value scalar = op.getScalar(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + auto scalarTy = scalar.getType(); + bool scalarOk = + isa(scalarTy); // ScalarType in ODS: int/float + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + if (!scalarOk) { + op.emitError("expects scalar to be an integer or float type"); + return failure(); + } + + auto cmpMode = op.getCmpModeAttr(); + auto newOp = rewriter.create( + op.getLoc(), + TypeRange{}, + src, + scalar, + cmpMode, + dst); + + rewriter.replaceOp(op, newOp->getResults()); // 0 results -> OK +} + +DefaultInlineVector colexpand; +func.walk([&](mlir::pto::TColExpandOp op) { colexpand.push_back(op); }); + +for (auto op : colexpand) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + + 
auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if ( !srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst); +} + +DefaultInlineVector colmaxops; +func.walk([&](mlir::pto::TColMaxOp op) { colmaxops.push_back(op); }); + +for (auto op : colmaxops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if ( !srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst); +} + +DefaultInlineVector colminops; +func.walk([&](mlir::pto::TColMinOp op) { colminops.push_back(op); }); + +for (auto op : colminops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if ( !srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst); +} + +DefaultInlineVector colexpandmulops; +func.walk([&](mlir::pto::TColExpandMulOp op) { + colexpandmulops.push_back(op); +}); + +for (auto op : colexpandmulops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src0, + src1, + dst); +} + +DefaultInlineVector colexpandmaxops; +func.walk([&](mlir::pto::TColExpandMaxOp op) { + 
colexpandmaxops.push_back(op); +}); + +for (auto op : colexpandmaxops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src0, + src1, + dst); +} + +DefaultInlineVector colexpandminops; +func.walk([&](mlir::pto::TColExpandMinOp op) { + colexpandminops.push_back(op); +}); + +for (auto op : colexpandminops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src0, + src1, + dst); +} + +DefaultInlineVector colsumops; +func.walk([&](mlir::pto::TColSumOp op) { colsumops.push_back(op); }); + +for (auto op : colsumops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + Value tmp = op.getTmp(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !dstTy) { + op.emitError("src/dst are not memref yet"); + return failure(); + } + + // Format with tmp: tmp must be a memref and isBinary is forwarded with it + if (tmp) { + auto tmpTy = dyn_cast(tmp.getType()); + if (!tmpTy) { + op.emitError("tmp is not memref yet"); + return failure(); + } + + // Read the isBinary attribute; fall back to false when it is absent + BoolAttr isBinaryAttr = op.getIsBinaryAttr(); + if (!isBinaryAttr) { + isBinaryAttr = BoolAttr::get(ctx, false); + } + + 
rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + tmp, + dst, + isBinaryAttr); + } else { + // Format 1: no tmp, no isBinary + // Use generic builder to avoid adding default isBinary attribute + SmallVector operands = {src, dst}; + SmallVector attrs; + // Copy all attributes except isBinary + for (auto attr : op->getAttrs()) { + if (attr.getName() != "isBinary") { + attrs.push_back(attr); + } + } + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + operands, + attrs); + } +} + +DefaultInlineVector cvtops; +func.walk([&](mlir::pto::TCvtOp op) { cvtops.push_back(op); }); + +for (auto op : cvtops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + auto rmodeAttr = op.getRmodeAttr(); // PTO_RoundModeAttr + auto satModeAttr = op.getSatModeAttr(); + + auto newOp = rewriter.create( + op.getLoc(), + TypeRange{}, + src, + dst, + rmodeAttr, + satModeAttr); + + rewriter.replaceOp(op, newOp->getResults()); +} + +DefaultInlineVector divops; +func.walk([&](mlir::pto::TDivOp op) { divops.push_back(op); }); + +for (auto op : divops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src0, + src1, + dst); +} + +DefaultInlineVector divsops; +func.walk([&](mlir::pto::TDivSOp op) { divsops.push_back(op); }); + +for (auto op : divsops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value scale = 
op.getScalar(); + Value dst = op.getDst(); + + // Check types - they might still be TileBufType or already converted to MemRefType + auto srcTy = dyn_cast(src.getType()); + auto srcTileTy = dyn_cast(src.getType()); + auto scaleTileTy = dyn_cast(scale.getType()); + auto dstTy = dyn_cast(dst.getType()); + auto dstTileTy = dyn_cast(dst.getType()); + + // Determine which operand is tile-like and which is scalar-like. + // Keep the original operand order (set by parser textual form). + // Check if src is memref/tensor/tile (not scalar) + bool srcIsMemref = (srcTy != nullptr || srcTileTy != nullptr || + isa(src.getType()) || + isa(src.getType())); + // Check if scale is memref/tensor/tile (not scalar) + bool scaleIsMemref = (isa(scale.getType()) || + scaleTileTy != nullptr || + isa(scale.getType()) || + isa(scale.getType())); + + // Type validation - ensure we have the right types + if (!srcIsMemref && !scaleIsMemref) { + op.emitError("at least one operand (src or scale) must be tile_buf or memref"); + return failure(); + } + if (srcIsMemref && scaleIsMemref) { + op.emitError("exactly one operand (src or scale) must be tile_buf or memref, the other must be scalar"); + return failure(); + } + + if (!dstTy && !dstTileTy) { + op.emitError("dst operand must be tile_buf or memref"); + return failure(); + } + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + scale, + dst); +} + +DefaultInlineVector expandsops; +func.walk([&](mlir::pto::TExpandsOp op) { expandsops.push_back(op); }); + +for (auto op : expandsops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value scalar = op.getScalar(); + Value dst = op.getDst(); + + auto dstTy = dyn_cast(dst.getType()); + if (!dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + scalar, + dst); +} + +DefaultInlineVector extractops; +func.walk([&](mlir::pto::TExtractOp op) { extractops.push_back(op); }); + +for (auto op : 
extractops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value indexRow = op.getIndexRow(); + Value indexCol = op.getIndexCol(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto indexRowTy = dyn_cast(indexRow.getType()); + auto indexColTy = dyn_cast(indexCol.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !indexRowTy || !indexColTy || !dstTy) { + op.emitError("ins/outs are not correct yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + indexRow, + indexCol, + dst); +} + +DefaultInlineVector fillpadops; +func.walk([&](mlir::pto::TFillPadOp op) { fillpadops.push_back(op); }); + +for (auto op : fillpadops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst); +} + +DefaultInlineVector fillpadInplaceOps; +func.walk( + [&](mlir::pto::TFillPadInplaceOp op) { fillpadInplaceOps.push_back(op); }); + +for (auto op : fillpadInplaceOps) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst); +} + +// --- TSetValOp [Dst, Offset, Val] --- +// Lower tile-world scalar write to memref-world SETVAL DPS op. 
+DefaultInlineVector tsetvalops; +func.walk([&](mlir::pto::TSetValOp op) { tsetvalops.push_back(op); }); + +for (auto op : tsetvalops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value dst = op.getDst(); + Value offset = op.getOffset(); + Value val = op.getVal(); + + auto dstTy = dyn_cast(dst.getType()); + if (!dstTy) { + op.emitError("dst is not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + dst, + offset, + val); +} + +// --- TGetValOp [Src, Offset] -> Scalar --- +// Lower tile-world scalar read to memref-world GETVAL DPS op. +DefaultInlineVector tgetvalops; +func.walk([&](mlir::pto::TGetValOp op) { tgetvalops.push_back(op); }); + +for (auto op : tgetvalops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value offset = op.getOffset(); + Type dstType = op.getDst().getType(); + + auto srcTy = dyn_cast(src.getType()); + if (!srcTy) { + op.emitError("src is not memref yet"); + return failure(); + } + + auto newOp = rewriter.create( + op.getLoc(), + dstType, + src, + offset); + rewriter.replaceOp(op, newOp.getDst()); +} + +DefaultInlineVector gatherops; +func.walk([&](mlir::pto::TGatherOp op) { gatherops.push_back(op); }); + +for (auto op : gatherops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + Value cdst = op.getCdst(); + Value indices = op.getIndices(); + Value tmp = op.getTmp(); + Value kValue = op.getKValue(); + auto maskPattern = op.getMaskPatternAttr(); + auto cmpMode = op.getCmpModeAttr(); + auto offset = op.getOffsetAttr(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + if (maskPattern) { + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst, + /*cdst=*/Value(), + /*indices=*/Value(), + /*tmp=*/Value(), + 
/*kValue=*/Value(), + /*maskPattern=*/maskPattern, + /*cmpMode=*/pto::CmpModeAttr(), + /*offset=*/IntegerAttr()); + continue; + } + + if (cdst || kValue) { + auto cdstTy = dyn_cast(cdst.getType()); + auto tmpTy = dyn_cast(tmp.getType()); + if (!cdstTy || !tmpTy) { + op.emitError("compare-form tgather expects cdst/tmp to be memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst, + cdst, + /*indices=*/Value(), + tmp, + kValue, + /*maskPattern=*/pto::MaskPatternAttr(), + cmpMode, + offset); + continue; + } + + if (indices || tmp) { + auto indicesTy = dyn_cast(indices.getType()); + auto tmpTy = dyn_cast(tmp.getType()); + if (!indicesTy || !tmpTy) { + op.emitError("index-form tgather expects indices/tmp to be memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst, + /*cdst=*/Value(), + indices, + tmp, + /*kValue=*/Value(), + /*maskPattern=*/pto::MaskPatternAttr(), + /*cmpMode=*/pto::CmpModeAttr(), + /*offset=*/IntegerAttr()); + continue; + } + + op.emitError("expects tgather to be in mask, index+tmp, or compare+tmp form"); + return failure(); +} + +DefaultInlineVector gatherbops; +func.walk([&](mlir::pto::TGatherBOp op) { gatherbops.push_back(op); }); + +for (auto op : gatherbops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value offsets = op.getOffsets(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto offsetsTy = dyn_cast(offsets.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !offsetsTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + offsets, + dst); +} + +DefaultInlineVector logops; +func.walk([&](mlir::pto::TLogOp op) { logops.push_back(op); }); + +for (auto op : logops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value 
dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst); +} + +DefaultInlineVector lreluops; +func.walk([&](mlir::pto::TLReluOp op) { lreluops.push_back(op); }); + +for (auto op : lreluops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value slope = op.getSlope(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto slopeTy = dyn_cast(slope.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !slopeTy || !dstTy) { + op.emitError("ins/outs are not correct type yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + slope, + dst); +} + +DefaultInlineVector maxops; +func.walk([&](mlir::pto::TMaxOp op) { maxops.push_back(op); }); + +for (auto op : maxops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src0, + src1, + dst); +} + +DefaultInlineVector maxsops; +func.walk([&](mlir::pto::TMaxSOp op) { maxsops.push_back(op); }); + +for (auto op : maxsops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value scalar = op.getScalar(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + bool scalarIsScalar = isa(scalar.getType()); + if (!srcTy || !scalarIsScalar || !dstTy) { + op.emitError("expects src/dst to be memref and scalar to be integer/float"); + 
return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + scalar, + dst); +} + +DefaultInlineVector minops; +func.walk([&](mlir::pto::TMinOp op) { minops.push_back(op); }); + +for (auto op : minops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src0, + src1, + dst); +} + +DefaultInlineVector minsops; +func.walk([&](mlir::pto::TMinSOp op) { minsops.push_back(op); }); + +for (auto op : minsops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value scalar = op.getScalar(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + bool scalarIsScalar = isa(scalar.getType()); + if (!srcTy || !scalarIsScalar || !dstTy) { + op.emitError("expects src/dst to be memref and scalar to be integer/float"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + scalar, + dst); +} + +DefaultInlineVector movfpops; +func.walk([&](mlir::pto::TMovFPOp op) { movfpops.push_back(op); }); + +for (auto op : movfpops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value fp = op.getFp(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto fpTy = dyn_cast(fp.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !fpTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + fp, + dst); +} + +DefaultInlineVector quantops; +func.walk([&](mlir::pto::TQuantOp op) { 
quantops.push_back(op); }); + +for (auto op : quantops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value fp = op.getFp(); + Value offset = op.getOffset(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto fpTy = dyn_cast(fp.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !fpTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + if (offset && !dyn_cast(offset.getType())) { + op.emitError("offset is not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + fp, + offset, + dst, + op.getQuantTypeAttr()); +} + +DefaultInlineVector mrgsortops; +func.walk([&](mlir::pto::TMrgSortOp op) { mrgsortops.push_back(op); }); + +for (auto op : mrgsortops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + if (op.isFormat1()) { + Value src = op.getSrc(); + Value dst = op.getDst(); + Value blockLenVal = op.getBlockLen(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + ValueRange{src}, + blockLenVal, + ValueRange{dst}, + Value() /*tmp*/, + Value() /*excuted*/, + op.getExhaustedAttr()); + } else if (op.isFormat2()) { + bool allMemRef = true; + for (Value v : op.getSrcs()) + if (!dyn_cast(v.getType())) { allMemRef = false; break; } + if (!allMemRef) { + op.emitError("format2 ins/outs are not memref yet"); + return failure(); + } + if (op.getDsts().size() != 1u || !op.getTmp()) { + op.emitError("format2 expects outs(dst) and ins(tmp)"); + return failure(); + } + + Value dst = op.getDst(); + Value tmp = op.getTmp(); + Value excuted = op.getExcuted(); + if (!dyn_cast(dst.getType()) || !dyn_cast(tmp.getType())) { + op.emitError("format2 dst/tmp must be memref"); + return failure(); + } + if 
(!dyn_cast(excuted.getType())) { + op.emitError("format2 outs(excuted) must be vector"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + op.getSrcs(), + Value() /*blockLen*/, + ValueRange{dst}, + tmp, + excuted, + op.getExhaustedAttr()); + } else { + op.emitError("tmrgsort must be format1 or format2"); + return failure(); + } +} + +DefaultInlineVector negops; +func.walk([&](mlir::pto::TNegOp op) { negops.push_back(op); }); + +for (auto op : negops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst); +} + +DefaultInlineVector notops; +func.walk([&](mlir::pto::TNotOp op) { notops.push_back(op); }); + +for (auto op : notops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst); +} + +DefaultInlineVector orops; +func.walk([&](mlir::pto::TOrOp op) { orops.push_back(op); }); + +for (auto op : orops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + src0, + src1, + dst); +} + +DefaultInlineVector orsops; +func.walk([&](mlir::pto::TOrSOp op) { 
orsops.push_back(op); }); + +for (auto op : orsops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value scalar = op.getScalar(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto scalarTy = dyn_cast(scalar.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !scalarTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + scalar, + dst); +} + +DefaultInlineVector partaddops; +func.walk([&](mlir::pto::TPartAddOp op) { partaddops.push_back(op); }); + +for (auto op : partaddops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + src0, + src1, + dst); +} + +DefaultInlineVector partmulops; +func.walk([&](mlir::pto::TPartMulOp op) { partmulops.push_back(op); }); + +for (auto op : partmulops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + src0, + src1, + dst); +} + +DefaultInlineVector mgatherops; +func.walk([&](mlir::pto::MGatherOp op) { mgatherops.push_back(op); }); + +for (auto op : mgatherops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value dst = op.getDst(); + Value idx = op.getIdx(); + Value mem = op.getMem(); + + auto 
dstTy = dyn_cast(dst.getType()); + auto idxTy = dyn_cast(idx.getType()); + auto memTy = dyn_cast(mem.getType()); + if (!dstTy || !idxTy || !memTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + mem, + idx, + dst, + op.getGatherOobAttr()); +} + +DefaultInlineVector mascatterops; +func.walk([&](mlir::pto::MScatterOp op) { mascatterops.push_back(op); }); + +for (auto op : mascatterops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value idx = op.getIdx(); + Value mem = op.getMem(); + + auto srcTy = dyn_cast(src.getType()); + auto idxTy = dyn_cast(idx.getType()); + auto memTy = dyn_cast(mem.getType()); + if (!srcTy || !idxTy || !memTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + idx, + mem, + op.getScatterAtomicOpAttr(), + op.getScatterOobAttr()); +} +DefaultInlineVector printops; +func.walk([&](mlir::pto::TPrintOp op) { printops.push_back(op); }); + +for (auto op : printops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + + auto srcTy = dyn_cast(src.getType()); + if (!srcTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src); +} + + + return success(); +} + +} // namespace mlir::pto diff --git a/lib/PTO/Transforms/PTOViewToMemrefInternal.h b/lib/PTO/Transforms/PTOViewToMemrefInternal.h new file mode 100644 index 000000000..8cfb80d2c --- /dev/null +++ b/lib/PTO/Transforms/PTOViewToMemrefInternal.h @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. 
You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef MLIR_DIALECT_PTO_TRANSFORMS_PTOVIEWTOMEMREFINTERNAL_H +#define MLIR_DIALECT_PTO_TRANSFORMS_PTOVIEWTOMEMREFINTERNAL_H + +#include "mlir/IR/MLIRContext.h" +#include "mlir/Support/LogicalResult.h" + +namespace mlir::func { +class FuncOp; +} // namespace mlir::func + +namespace mlir::pto { + +LogicalResult lowerViewToMemrefComputeOps(func::FuncOp func, MLIRContext *ctx); + +} // namespace mlir::pto + +#endif // MLIR_DIALECT_PTO_TRANSFORMS_PTOVIEWTOMEMREFINTERNAL_H From 9cf2a1cdc6e207159ce6e43e297bed05b20ea835 Mon Sep 17 00:00:00 2001 From: HecreReed <821896444@qq.com> Date: Tue, 26 May 2026 13:09:19 +0800 Subject: [PATCH 5/8] refactor: split emitc lowering hotspots --- lib/PTO/Transforms/CMakeLists.txt | 10 + lib/PTO/Transforms/PTOToEmitC.cpp | 9154 +---------------- lib/PTO/Transforms/PTOToEmitCComm.cpp | 889 ++ lib/PTO/Transforms/PTOToEmitCControlFlow.cpp | 717 ++ lib/PTO/Transforms/PTOToEmitCInternal.h | 124 + lib/PTO/Transforms/PTOToEmitCKernelOps.cpp | 516 + lib/PTO/Transforms/PTOToEmitCMemoryOps.cpp | 597 ++ lib/PTO/Transforms/PTOToEmitCRuntimeOps.cpp | 736 ++ lib/PTO/Transforms/PTOToEmitCSimpleOps.cpp | 563 + lib/PTO/Transforms/PTOToEmitCSync.cpp | 1046 ++ .../PTOToEmitCTileMaterialization.cpp | 923 ++ lib/PTO/Transforms/PTOToEmitCTilePatterns.cpp | 1438 +++ .../PTOToEmitCTilePatternsExtra.cpp | 1819 ++++ 13 files changed, 9524 insertions(+), 9008 deletions(-) create mode 100644 lib/PTO/Transforms/PTOToEmitCComm.cpp create mode 100644 lib/PTO/Transforms/PTOToEmitCControlFlow.cpp create mode 100644 lib/PTO/Transforms/PTOToEmitCKernelOps.cpp create mode 100644 
lib/PTO/Transforms/PTOToEmitCMemoryOps.cpp create mode 100644 lib/PTO/Transforms/PTOToEmitCRuntimeOps.cpp create mode 100644 lib/PTO/Transforms/PTOToEmitCSimpleOps.cpp create mode 100644 lib/PTO/Transforms/PTOToEmitCSync.cpp create mode 100644 lib/PTO/Transforms/PTOToEmitCTileMaterialization.cpp create mode 100644 lib/PTO/Transforms/PTOToEmitCTilePatterns.cpp create mode 100644 lib/PTO/Transforms/PTOToEmitCTilePatternsExtra.cpp diff --git a/lib/PTO/Transforms/CMakeLists.txt b/lib/PTO/Transforms/CMakeLists.txt index efe728827..35ea68387 100644 --- a/lib/PTO/Transforms/CMakeLists.txt +++ b/lib/PTO/Transforms/CMakeLists.txt @@ -20,6 +20,16 @@ add_mlir_dialect_library(PTOTransforms PTOValidateIntToPtrUses.cpp PTOToEmitC.cpp PTOToEmitCArith.cpp + PTOToEmitCTilePatterns.cpp + PTOToEmitCTilePatternsExtra.cpp + PTOToEmitCTileMaterialization.cpp + PTOToEmitCSync.cpp + PTOToEmitCComm.cpp + PTOToEmitCKernelOps.cpp + PTOToEmitCControlFlow.cpp + PTOToEmitCSimpleOps.cpp + PTOToEmitCRuntimeOps.cpp + PTOToEmitCMemoryOps.cpp Utils.cpp OptMemPlanForPipeline.cpp AllocToPointerCast.cpp diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index a841c889e..c8e15b51e 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -70,24 +70,14 @@ namespace mlir { using namespace mlir; using namespace mlir::pto; -static std::string getElemTypeStringForGT(Type elemTy); static bool getStaticMemrefLayout(MemRefType mrTy, SmallVectorImpl &strides, int64_t &offset); static int64_t multiplyOrDynamic(int64_t lhs, int64_t rhs); -static void buildGlobalTensorShapeAndStride(ArrayRef shape, - ArrayRef strides, - SmallVectorImpl &shape5D, - SmallVectorImpl &stride5D); -static std::string joinIntTemplateParams(ArrayRef values); -static SmallVector buildRowMajorStrides(ArrayRef shape); static std::string getGlobalTensorTypeStringFromShape(Type elemTy, ArrayRef shape, StringRef layoutEnum = "pto::Layout::ND"); -static std::string 
getGlobalTensorTypeStringFromShapeAndStrides( - Type elemTy, ArrayRef shape, ArrayRef strides, - StringRef layoutEnum = "pto::Layout::ND"); static emitc::OpaqueType getGlobalTensorOpaqueTypeFromShape( MLIRContext *ctx, Type elemTy, ArrayRef shape, StringRef layoutEnum = "pto::Layout::ND"); @@ -122,20 +112,17 @@ static const char *addrSpaceQualifier(pto::AddressSpace as) { "__pto.lowered_set_validshape"; [[maybe_unused]] static constexpr llvm::StringLiteral kLoweredSetValidShapeConfigAttrName = "__pto.lowered_set_validshape_config"; -static constexpr llvm::StringLiteral kForceDynamicValidShapeAttrName = +[[maybe_unused]] static constexpr llvm::StringLiteral kForceDynamicValidShapeAttrName = "__pto.force_dynamic_valid_shape"; -static constexpr llvm::StringLiteral kGlobalTensorStridesAttrName = +[[maybe_unused]] static constexpr llvm::StringLiteral kGlobalTensorStridesAttrName = "__pto.globaltensor_strides"; -static Value peelUnrealized(Value v) { +Value mlir::pto::peelUnrealized(Value v) { if (auto castOp = v.getDefiningOp()) return castOp.getOperand(0); return v; } -static Value buildGlobalTensorFromMemref(ConversionPatternRewriter &rewriter, - Location loc, Value basePtr, - MemRefType mrTy, Operation *anchor); static Value maybeWrapGlobalMemrefAsGlobalTensor( ConversionPatternRewriter &rewriter, Location loc, Value loweredValue, @@ -259,12 +246,12 @@ static std::string layoutToEmitCString(mlir::pto::Layout layout) { return "pto::Layout::ND"; } -static bool isEmitCGlobalTensorLikeType(Type ty) { +bool mlir::pto::isEmitCGlobalTensorLikeType(Type ty) { auto opaqueTy = dyn_cast(ty); return opaqueTy && opaqueTy.getValue().contains("GlobalTensor<"); } -static std::string getEmitCScalarTypeToken(Type elemTy) { +std::string mlir::pto::getEmitCScalarTypeToken(Type elemTy) { if (pto::isPTOFloat8Type(elemTy) && (elemTy.isFloat8E4M3() || elemTy.isFloat8E4M3FN() || elemTy.isFloat8E4M3FNUZ() || elemTy.isFloat8E4M3B11FNUZ())) @@ -321,7 +308,7 @@ static bool 
isEmitCPointerLikeType(Type ty) { return false; } -static int64_t getEmitCScalarByteWidth(Type elemTy) { +[[maybe_unused]] static int64_t getEmitCScalarByteWidth(Type elemTy) { if (pto::getPTOStorageElemByteSize(elemTy) == 1) return 1; if (elemTy.isF16() || elemTy.isBF16() || elemTy.isInteger(16)) @@ -336,8 +323,8 @@ static int64_t getEmitCScalarByteWidth(Type elemTy) { static std::string tileBufBLayoutToken(pto::TileBufConfigAttr configAttr); static std::string tileBufSLayoutToken(pto::TileBufConfigAttr configAttr); static std::string tileBufPadToken(pto::TileBufConfigAttr configAttr); -static pto::BLayout getTileBufBLayoutValue(pto::TileBufConfigAttr configAttr); -static int64_t renderTileTemplateDim(int64_t rawDim, Type elemTy, +pto::BLayout mlir::pto::getTileBufBLayoutValue(pto::TileBufConfigAttr configAttr); +int64_t mlir::pto::renderTileTemplateDim(int64_t rawDim, Type elemTy, pto::BLayout blayout, int dimIdx); static const char *tileRoleToken(Attribute memorySpace) { @@ -383,7 +370,7 @@ static std::string tileBufCompactToken(pto::TileBufConfigAttr configAttr) { return compactTok; } -static std::optional getEmitCTileTypeString(pto::TileBufType type) { +std::optional mlir::pto::getEmitCTileTypeString(pto::TileBufType type) { if (type.getRank() != 2) return std::nullopt; auto validShape = type.getValidShape(); @@ -643,11 +630,10 @@ class PTOToEmitCTypeConverter : public TypeConverter { } }; -static constexpr unsigned kPTOIndexBitWidth = +[[maybe_unused]] static constexpr unsigned kPTOIndexBitWidth = 32; // keep consistent with IndexType conversion // Forward declarations (definitions below). 
-static inline std::string pipeTokFromPipeAttr(mlir::pto::PipeAttr a); static emitc::OpaqueType getSignedIntOpaqueType(MLIRContext *ctx, unsigned bitWidth); static emitc::OpaqueType getUnsignedIntOpaqueType(MLIRContext *ctx, @@ -656,107 +642,10 @@ static emitc::OpaqueType getWiderSignedIntOpaqueType(MLIRContext *ctx, unsigned bitWidth); static emitc::OpaqueType getWiderUnsignedIntOpaqueType(MLIRContext *ctx, unsigned bitWidth); -static Value makeEmitCOpaqueConstant(ConversionPatternRewriter &rewriter, - Location loc, Type type, - llvm::StringRef literal); -static Value makeEmitCIntConstant(ConversionPatternRewriter &rewriter, - Location loc, Type type, int64_t value); -static Value emitCCast(ConversionPatternRewriter &rewriter, Location loc, - Type dstType, Value src); static FailureOr buildEmitCOpaqueConstantLiteral(Type targetType, Attribute valueAttr); -static Value castSignlessIntToUnsignedSameWidth(ConversionPatternRewriter &rewriter, - Location loc, Value v, - unsigned bitWidth); -static bool needsA5NoSplitVectorGuard(Operation *op); - -static FailureOr getTileSplitToken(int64_t split) { - switch (split) { - case 0: - return std::string("TileSplitAxis::TILE_NO_SPLIT"); - case 1: - return std::string("TileSplitAxis::TILE_UP_DOWN"); - case 2: - return std::string("TileSplitAxis::TILE_LEFT_RIGHT"); - default: - return failure(); - } -} - -static FailureOr -getTPipeDirectionToken(bool isL2G2L, int8_t dirMask, PTOArch targetArch) { - if (dirMask == 1) { - if (isL2G2L && targetArch == PTOArch::A5) - return std::string("Direction::DIR_C2V_GM"); - return std::string("Direction::DIR_C2V"); - } - if (dirMask == 2) { - if (isL2G2L && targetArch == PTOArch::A5) - return std::string("Direction::DIR_V2C_GM"); - return std::string("Direction::DIR_V2C"); - } - if (dirMask == 3) - return std::string("Direction::DIR_BOTH"); - return failure(); -} - -static std::string buildTPipeToken(int32_t flagBase, llvm::StringRef dirTok, - int32_t slotSize, int32_t slotNum, - int32_t 
localSlotNum, bool nosplit) { - std::string token = "TPipe<" + std::to_string(flagBase) + ", " + dirTok.str() + - ", " + std::to_string(slotSize) + ", " + - std::to_string(slotNum); - token += ", " + std::to_string(localSlotNum); - token += nosplit ? ", true" : ", false"; - token += ">"; - return token; -} - -static FailureOr buildTPipeTokenFromInitOp(Operation *op, - PTOArch targetArch) { - if (auto initOp = dyn_cast(op)) { - if (!initOp.getFlagBaseAttr()) - return failure(); - auto dirTok = - getTPipeDirectionToken(/*isL2G2L=*/true, initOp.getDirMask(), targetArch); - if (failed(dirTok)) - return failure(); - int32_t localSlotNum = initOp.getLocalSlotNumAttr() - ? initOp.getLocalSlotNumAttr().getInt() - : initOp.getSlotNum(); - return buildTPipeToken(initOp.getFlagBaseAttr().getInt(), *dirTok, - initOp.getSlotSize(), initOp.getSlotNum(), - localSlotNum, - initOp.getNosplitAttr() && - initOp.getNosplitAttr().getValue()); - } - - if (auto initOp = dyn_cast(op)) { - if (!initOp.getFlagBaseAttr()) - return failure(); - auto dirTok = - getTPipeDirectionToken(/*isL2G2L=*/false, initOp.getDirMask(), targetArch); - if (failed(dirTok)) - return failure(); - return buildTPipeToken(initOp.getFlagBaseAttr().getInt(), *dirTok, - initOp.getSlotSize(), initOp.getSlotNum(), 2, - initOp.getNosplitAttr() && - initOp.getNosplitAttr().getValue()); - } - - return failure(); -} -static FailureOr getTPipeTokenFromValue(Value pipeHandle, - PTOArch targetArch) { - pipeHandle = peelUnrealized(pipeHandle); - Operation *def = pipeHandle.getDefiningOp(); - if (!def) - return failure(); - return buildTPipeTokenFromInitOp(def, targetArch); -} - -static bool isSetFFTsPointerLikeType(Type ty) { +bool mlir::pto::isSetFFTsPointerLikeType(Type ty) { return isEmitCPointerLikeType(ty); } @@ -771,7 +660,7 @@ static Type getTileDataResultType(MLIRContext *ctx, pto::AddressSpace as, return getEmitCPointerType(ctx, addrSpaceQualifier(as), elemTok); } -static Value 
materializeTileDataValue(ConversionPatternRewriter &rewriter, +Value mlir::pto::materializeTileDataValue(ConversionPatternRewriter &rewriter, Location loc, Value tile, pto::AddressSpace as, StringRef elemTok) { @@ -783,7 +672,7 @@ static Value materializeTileDataValue(ConversionPatternRewriter &rewriter, .getResult(0); } -static Value materializeAddressAsPointer(ConversionPatternRewriter &rewriter, +Value mlir::pto::materializeAddressAsPointer(ConversionPatternRewriter &rewriter, Location loc, Value addr, pto::AddressSpace as, StringRef elemTok) { @@ -805,146 +694,6 @@ static Value materializeAddressAsPointer(ConversionPatternRewriter &rewriter, .getResult(0); } -struct InterCoreSyncCallDesc { - const char *callee = nullptr; - ArrayAttr args; - SmallVector operands; -}; - -static Value castInterCoreEventIdToI32(ConversionPatternRewriter &rewriter, - Location loc, Value eventId) { - auto i32Ty = emitc::OpaqueType::get(rewriter.getContext(), "int32_t"); - if (eventId.getType() == i32Ty) - return eventId; - return emitCCast(rewriter, loc, i32Ty, eventId); -} - -static Attribute getFFTSModeCodegenArg(ConversionPatternRewriter &rewriter, - int64_t fftsMode) { - auto *ctx = rewriter.getContext(); - if (fftsMode == 2) - return emitc::OpaqueAttr::get(ctx, "FFTS_MODE_VAL"); - return emitc::OpaqueAttr::get(ctx, std::to_string(fftsMode)); -} - -static Value createFFTSMsg(ConversionPatternRewriter &rewriter, Location loc, - Value eventI32, int64_t fftsMode) { - auto *ctx = rewriter.getContext(); - auto msgTy = emitc::OpaqueType::get(ctx, "uint16_t"); - auto msgArgs = rewriter.getArrayAttr({ - getFFTSModeCodegenArg(rewriter, fftsMode), - IntegerAttr::get(IndexType::get(ctx), 0), - }); - return rewriter - .create(loc, msgTy, "getFFTSMsg", - /*args=*/msgArgs, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{eventI32}) - .getResult(0); -} - -static InterCoreSyncCallDesc buildInterCoreSyncSetCall( - ConversionPatternRewriter &rewriter, Location loc, PTOArch targetArch, - 
pto::PipeAttr pipeAttr, IntegerAttr eventIdAttr, int64_t fftsMode) { - auto *ctx = rewriter.getContext(); - std::string pipeTok = pipeTokFromPipeAttr(pipeAttr); - - if (targetArch == PTOArch::A3) { - auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); - Value eventVal = - makeEmitCIntConstant(rewriter, loc, i32Ty, eventIdAttr.getInt()); - Value msgVal = createFFTSMsg(rewriter, loc, eventVal, fftsMode); - - InterCoreSyncCallDesc desc; - desc.callee = "ffts_cross_core_sync"; - desc.args = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, pipeTok), - IntegerAttr::get(IndexType::get(ctx), 0), - }); - desc.operands.push_back(msgVal); - return desc; - } - - InterCoreSyncCallDesc desc; - desc.callee = "set_intra_block"; - desc.args = rewriter.getArrayAttr( - {emitc::OpaqueAttr::get(ctx, pipeTok), eventIdAttr}); - return desc; -} - -static InterCoreSyncCallDesc buildInterCoreSyncSetCallDyn( - ConversionPatternRewriter &rewriter, Location loc, PTOArch targetArch, - pto::PipeAttr pipeAttr, Value eventIdVal, int64_t fftsMode) { - auto *ctx = rewriter.getContext(); - std::string pipeTok = pipeTokFromPipeAttr(pipeAttr); - Value eventI32 = castInterCoreEventIdToI32(rewriter, loc, eventIdVal); - - if (targetArch == PTOArch::A3) { - Value msgVal = createFFTSMsg(rewriter, loc, eventI32, fftsMode); - - InterCoreSyncCallDesc desc; - desc.callee = "ffts_cross_core_sync"; - desc.args = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, pipeTok), - IntegerAttr::get(IndexType::get(ctx), 0), - }); - desc.operands.push_back(msgVal); - return desc; - } - - InterCoreSyncCallDesc desc; - desc.callee = "set_intra_block"; - desc.args = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, pipeTok), - IntegerAttr::get(IndexType::get(ctx), 0), - }); - desc.operands.push_back(eventI32); - return desc; -} - -static InterCoreSyncCallDesc buildInterCoreSyncWaitCall( - ConversionPatternRewriter &rewriter, PTOArch targetArch, - pto::PipeAttr pipeAttr, IntegerAttr eventIdAttr) { - auto *ctx = 
rewriter.getContext(); - std::string pipeTok = pipeTokFromPipeAttr(pipeAttr); - - InterCoreSyncCallDesc desc; - if (targetArch == PTOArch::A3) { - desc.callee = "wait_flag_dev"; - desc.args = rewriter.getArrayAttr({eventIdAttr}); - return desc; - } - - desc.callee = "wait_intra_block"; - desc.args = rewriter.getArrayAttr( - {emitc::OpaqueAttr::get(ctx, pipeTok), eventIdAttr}); - return desc; -} - -static InterCoreSyncCallDesc buildInterCoreSyncWaitCallDyn( - ConversionPatternRewriter &rewriter, Location loc, PTOArch targetArch, - pto::PipeAttr pipeAttr, Value eventIdVal) { - auto *ctx = rewriter.getContext(); - std::string pipeTok = pipeTokFromPipeAttr(pipeAttr); - Value eventI32 = castInterCoreEventIdToI32(rewriter, loc, eventIdVal); - - InterCoreSyncCallDesc desc; - if (targetArch == PTOArch::A3) { - desc.callee = "wait_flag_dev"; - desc.args = rewriter.getArrayAttr({IntegerAttr::get(IndexType::get(ctx), 0)}); - desc.operands.push_back(eventI32); - return desc; - } - - desc.callee = "wait_intra_block"; - desc.args = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, pipeTok), - IntegerAttr::get(IndexType::get(ctx), 0), - }); - desc.operands.push_back(eventI32); - return desc; -} - static bool hasInterCoreSyncOp(func::FuncOp func) { bool found = false; func.walk([&](Operation *op) { @@ -1053,14 +802,14 @@ static emitc::OpaqueType getUnsignedIntOpaqueType(MLIRContext *ctx, } } -static Value makeEmitCOpaqueConstant(ConversionPatternRewriter &rewriter, +Value mlir::pto::makeEmitCOpaqueConstant(ConversionPatternRewriter &rewriter, Location loc, Type type, llvm::StringRef literal) { auto attr = emitc::OpaqueAttr::get(rewriter.getContext(), literal); return rewriter.create(loc, type, attr); } -static Value makeEmitCIntConstant(ConversionPatternRewriter &rewriter, +Value mlir::pto::makeEmitCIntConstant(ConversionPatternRewriter &rewriter, Location loc, Type type, int64_t value) { return makeEmitCOpaqueConstant(rewriter, loc, type, std::to_string(value)); } @@ -1099,7 
+848,7 @@ static Value makeEmitCIntConstant(ConversionPatternRewriter &rewriter, return failure(); } -static Value emitCCast(ConversionPatternRewriter &rewriter, Location loc, +Value mlir::pto::emitCCast(ConversionPatternRewriter &rewriter, Location loc, Type dstType, Value src) { if (src.getType() == dstType) return src; @@ -1110,7 +859,7 @@ static Value emitCCast(ConversionPatternRewriter &rewriter, Location loc, // representing the same N-bit pattern in an unsigned C++ type of the same // width. This avoids incorrect sign-extension when later widening to a larger // unsigned type. -static Value castSignlessIntToUnsignedSameWidth(ConversionPatternRewriter &rewriter, +Value mlir::pto::castSignlessIntToUnsignedSameWidth(ConversionPatternRewriter &rewriter, Location loc, Value v, unsigned bitWidth) { auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); @@ -1173,90 +922,6 @@ struct PTOMGatherToMGATHER : public OpConversionPattern { } }; -struct AffineApplyMulConstToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(affine::AffineApplyOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto map = op.getAffineMap(); - - if (map.getNumDims() != 0 || map.getNumSymbols() != 1) - return failure(); - - auto expr = map.getResult(0); - auto bin = dyn_cast(expr); - if (!bin || bin.getKind() != AffineExprKind::Mul) - return failure(); - - auto lhs = bin.getLHS(); - auto rhs = bin.getRHS(); - - auto symExpr = dyn_cast(lhs); - auto constExpr = dyn_cast(rhs); - if (!symExpr || !constExpr) - return failure(); - - Value inputVal = adaptor.getMapOperands()[0]; - - std::string valStr = std::to_string(constExpr.getValue()); - auto cstAttr = emitc::OpaqueAttr::get(rewriter.getContext(), valStr); - auto cstOp = rewriter.create( - op.getLoc(), inputVal.getType(), cstAttr); - - rewriter.replaceOpWithNewOp( - op, inputVal.getType(), inputVal, cstOp); - - return success(); 
- } -}; - -//===----------------------------------------------------------------------===// -// Kernel inference helpers -//===----------------------------------------------------------------------===// - -enum class KernelKind { VecAdd, Matmul, Unknown }; - -[[maybe_unused]] static KernelKind inferKernelKind(func::FuncOp f) { - bool hasAdd = false; - bool hasMM = false; - f.walk([&](Operation *op) { - if (isa(op)) hasAdd = true; - if (isa(op)) hasMM = true; - if (isa(op)) hasMM = true; - }); - if (hasMM) return KernelKind::Matmul; - if (hasAdd) return KernelKind::VecAdd; - return KernelKind::Unknown; -} - -[[maybe_unused]] static void inferTileMNK(func::FuncOp f, int &M, int &N, int &K) { - M = 32; N = 32; K = 32; - SmallVector subs; - f.walk([&](memref::SubViewOp sv) { subs.push_back(sv); }); - - auto readShape2D = [&](memref::SubViewOp sv, int &d0, int &d1) { - auto resTy = mlir::cast(sv.getResult().getType()); - if (resTy.getRank() == 2 && resTy.hasStaticShape()) { - d0 = (int)resTy.getDimSize(0); - d1 = (int)resTy.getDimSize(1); - } - }; - - if (subs.empty()) return; - - int a0=32, a1=32; - readShape2D(subs[0], a0, a1); - M = a0; N = a1; - - if (subs.size() >= 2) { - int b0=32, b1=32; - readShape2D(subs[0], a0, a1); - readShape2D(subs[1], b0, b1); - M = a0; K = a1; N = b1; - } -} - static std::optional getKernelKindMacro(func::FuncOp funcOp) { auto kernelKindAttr = funcOp->getAttrOfType(FunctionKernelKindAttr::name); @@ -1874,7 +1539,7 @@ struct SubviewToEmitCPattern : public OpConversionPattern { // Helper: build GlobalTensor from a static MemRef (for TLOAD/TSTORE) //===----------------------------------------------------------------------===// -static std::string getElemTypeStringForGT(Type elemTy) { +std::string mlir::pto::getElemTypeStringForGT(Type elemTy) { return getEmitCScalarTypeToken(elemTy); } @@ -1903,7 +1568,7 @@ static bool getStaticMemrefLayout(MemRefType mrTy, SmallVectorImpl &str }); } -static Value 
applyStaticMemrefOffset(ConversionPatternRewriter &rewriter, +Value mlir::pto::applyStaticMemrefOffset(ConversionPatternRewriter &rewriter, Location loc, Value basePtr, int64_t offset) { if (offset == 0) @@ -1925,7 +1590,7 @@ static int64_t multiplyOrDynamic(int64_t lhs, int64_t rhs) { return lhs * rhs; } -static void buildGlobalTensorShapeAndStride(ArrayRef shape, +void mlir::pto::buildGlobalTensorShapeAndStride(ArrayRef shape, ArrayRef strides, SmallVectorImpl &shape5D, SmallVectorImpl &stride5D) { @@ -1944,7 +1609,7 @@ static void buildGlobalTensorShapeAndStride(ArrayRef shape, } } -static std::string joinIntTemplateParams(ArrayRef values) { +std::string mlir::pto::joinIntTemplateParams(ArrayRef values) { std::string result; for (size_t i = 0; i < values.size(); ++i) { if (i != 0) @@ -1954,7 +1619,7 @@ static std::string joinIntTemplateParams(ArrayRef values) { return result; } -static SmallVector buildRowMajorStrides(ArrayRef shape) { +SmallVector mlir::pto::buildRowMajorStrides(ArrayRef shape) { SmallVector strides(shape.size(), 1); int64_t running = 1; for (int i = static_cast(shape.size()) - 1; i >= 0; --i) { @@ -1972,7 +1637,7 @@ static std::string getGlobalTensorTypeStringFromShape(Type elemTy, layoutEnum); } -static std::string getGlobalTensorTypeStringFromShapeAndStrides( +std::string mlir::pto::getGlobalTensorTypeStringFromShapeAndStrides( Type elemTy, ArrayRef shape, ArrayRef strides, StringRef layoutEnum) { SmallVector shape5D; @@ -2043,7 +1708,7 @@ static GlobalTensorTypeNames getGlobalTensorTypeNames(Operation *anchor) { "GT" + suffix + "_layout", }; } -static Value buildGlobalTensorFromMemref(ConversionPatternRewriter &rewriter, +Value mlir::pto::buildGlobalTensorFromMemref(ConversionPatternRewriter &rewriter, Location loc, Value basePtr, MemRefType mrTy, Operation *anchor) { @@ -2127,7 +1792,7 @@ static Value maybeWrapGlobalMemrefAsGlobalTensor( return loweredValue; } -static Value castToGMBytePointer(ConversionPatternRewriter &rewriter, +Value 
mlir::pto::castToGMBytePointer(ConversionPatternRewriter &rewriter, Location loc, Value value) { auto *ctx = rewriter.getContext(); auto targetTy = @@ -2147,7 +1812,7 @@ static Value castToGMBytePointer(ConversionPatternRewriter &rewriter, return rewriter.create(loc, targetTy, value).getResult(); } -static Value materializeTensorViewDataPointer( +Value mlir::pto::materializeTensorViewDataPointer( ConversionPatternRewriter &rewriter, Location loc, Value value, Type sourceType) { auto tvTy = dyn_cast(sourceType); @@ -2205,13 +1870,13 @@ static std::string tileBufPadToken(pto::TileBufConfigAttr configAttr) { return padTok; } -static pto::BLayout getTileBufBLayoutValue(pto::TileBufConfigAttr configAttr) { +pto::BLayout mlir::pto::getTileBufBLayoutValue(pto::TileBufConfigAttr configAttr) { if (auto blAttr = dyn_cast(configAttr.getBLayout())) return blAttr.getValue(); return pto::BLayout::RowMajor; } -static int64_t renderTileTemplateDim(int64_t rawDim, Type elemTy, +int64_t mlir::pto::renderTileTemplateDim(int64_t rawDim, Type elemTy, pto::BLayout blayout, int dimIdx) { assert(dimIdx >= 0 && dimIdx < 2 && "renderTileTemplateDim expects a rank-2 rows/cols dimension index"); @@ -2223,7 +1888,7 @@ static int64_t renderTileTemplateDim(int64_t rawDim, Type elemTy, return dimIdx == packedDim ? 
rawDim * 2 : rawDim; } -static FailureOr buildAsyncScratchTileValue( +FailureOr mlir::pto::buildAsyncScratchTileValue( ConversionPatternRewriter &rewriter, Location loc, Value originalScratch, Value emittedScratch) { Value scratch = peelUnrealized(emittedScratch); @@ -2288,8608 +1953,140 @@ static FailureOr buildAsyncScratchTileValue( return tile; } -static FailureOr buildSyncAllWorkspaceTileValue( - ConversionPatternRewriter &rewriter, Location loc, Value originalWorkspace, - Value emittedWorkspace) { - Value workspace = peelUnrealized(emittedWorkspace); - if (auto opaqueTy = dyn_cast(workspace.getType())) { - StringRef typeStr = opaqueTy.getValue(); - if (typeStr.contains("Tile<") || typeStr.contains("ConvTile<")) - return workspace; - } - - auto memTy = dyn_cast(originalWorkspace.getType()); - if (!memTy) - return failure(); - if (!memTy.hasStaticShape()) - return failure(); - - ArrayRef rawShape = memTy.getShape(); - if (rawShape.empty() || rawShape.size() > 2) - return failure(); +//===----------------------------------------------------------------------===// +// pto.pointer_cast lowering +//===----------------------------------------------------------------------=== +struct PTOMScatterToMSCATTER : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; - int64_t rows = rawShape.size() == 1 ? 1 : rawShape[0]; - int64_t cols = rawShape.size() == 1 ? 
rawShape[0] : rawShape[1]; - SmallVector shape{rows, cols}; - SmallVector validShape{rows, cols}; + LogicalResult matchAndRewrite(pto::MScatterOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto *ctx = rewriter.getContext(); + Value src = peelUnrealized(adaptor.getSrc()); + Value idx = peelUnrealized(adaptor.getIdx()); + Value mem = peelUnrealized(adaptor.getMem()); - auto *ctx = rewriter.getContext(); - pto::TileBufConfigAttr configAttr = pto::TileBufConfigAttr::getDefault(ctx); - if (auto bind = originalWorkspace.getDefiningOp()) { - configAttr = bind.getConfig(); - } else if (auto cast = originalWorkspace.getDefiningOp()) { - if (auto config = cast.getConfig()) - configAttr = *config; - } + Value memArg = maybeWrapGlobalMemrefAsGlobalTensor( + rewriter, op.getLoc(), mem, op.getMem().getType(), op.getOperation()); - Attribute memorySpace = memTy.getMemorySpace(); - if (!memorySpace) - return failure(); + auto scatterAtomicTok = [&](pto::ScatterAtomicOp atomic) -> StringRef { + switch (atomic) { + case pto::ScatterAtomicOp::None: + return "pto::ScatterAtomicOp::None"; + case pto::ScatterAtomicOp::Add: + return "pto::ScatterAtomicOp::Add"; + case pto::ScatterAtomicOp::Max: + return "pto::ScatterAtomicOp::Max"; + case pto::ScatterAtomicOp::Min: + return "pto::ScatterAtomicOp::Min"; + } + llvm_unreachable("unknown ScatterAtomicOp"); + }; + auto scatterOobTok = [&](pto::ScatterOOB mode) -> StringRef { + switch (mode) { + case pto::ScatterOOB::Undefined: + return "pto::ScatterOOB::Undefined"; + case pto::ScatterOOB::Skip: + return "pto::ScatterOOB::Skip"; + case pto::ScatterOOB::Clamp: + return "pto::ScatterOOB::Clamp"; + case pto::ScatterOOB::Wrap: + return "pto::ScatterOOB::Wrap"; + } + llvm_unreachable("unknown ScatterOOB"); + }; - auto tileTy = pto::TileBufType::get(ctx, shape, memTy.getElementType(), - memorySpace, validShape, configAttr); - auto tileTypeString = getEmitCTileTypeString(tileTy); - if (!tileTypeString) - return 
failure(); + SmallVector templateArgVec; + const bool rowCoalesce = + isRowCoalescedMGatherIndexType(op.getSrc().getType(), op.getIdx().getType()); + templateArgVec.push_back(emitc::OpaqueAttr::get( + ctx, rowCoalesce ? "pto::Coalesce::Row" : "pto::Coalesce::Elem")); + if (op.getScatterAtomicOp() != pto::ScatterAtomicOp::None || + op.getScatterOob() != pto::ScatterOOB::Undefined) { + templateArgVec.push_back(emitc::OpaqueAttr::get( + ctx, scatterAtomicTok(op.getScatterAtomicOp()))); + if (op.getScatterOob() != pto::ScatterOOB::Undefined) + templateArgVec.push_back( + emitc::OpaqueAttr::get(ctx, scatterOobTok(op.getScatterOob()))); + } + ArrayAttr templateArgs = rewriter.getArrayAttr(templateArgVec); - auto tileEmitTy = emitc::OpaqueType::get(ctx, *tileTypeString); - Value tile = rewriter - .create(loc, tileEmitTy, - emitc::OpaqueAttr::get(ctx, "")) - .getResult(); + rewriter.create( + op.getLoc(), TypeRange{}, "MSCATTER", + ArrayAttr{}, templateArgs, + ValueRange{memArg, src, idx}); - Value rawPtr = workspace; - auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); - if (isSetFFTsPointerLikeType(rawPtr.getType())) { - auto rcU64 = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); - rawPtr = rewriter - .create(loc, u64Ty, "reinterpret_cast", - ArrayAttr{}, rcU64, - ValueRange{rawPtr}) - .getResult(0); - } else if (rawPtr.getType() != u64Ty) { - rawPtr = rewriter.create(loc, u64Ty, rawPtr).getResult(); + rewriter.eraseOp(op); + return success(); } - - rewriter.create(loc, TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{tile, rawPtr}); - return tile; +}; +static void populatePTOToEmitCPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter, + MLIRContext *ctx, + DataFlowSolver &solver, + PTOArch targetArch) { + (void)solver; + populatePTOToEmitCArithPatterns(patterns, typeConverter, ctx); + populatePTOToEmitCRuntimeOpPatterns(patterns, typeConverter, ctx, targetArch); + populatePTOToEmitCMemoryOpPatterns(patterns, 
typeConverter, ctx); + populatePTOToEmitCTilePatterns(patterns, typeConverter, ctx); + populatePTOToEmitCSimpleOpPatterns(patterns, typeConverter, ctx); + populatePTOToEmitCTileMaterializationPatterns(patterns, typeConverter, ctx); + populatePTOToEmitCSyncPatterns(patterns, typeConverter, ctx, targetArch); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + populatePTOToEmitCKernelOpPatterns(patterns, typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + populatePTOToEmitCCommPatterns(patterns, typeConverter, ctx, targetArch); + populatePTOToEmitCControlFlowPatterns(patterns, typeConverter, ctx); } //===----------------------------------------------------------------------===// -// pto.pointer_cast lowering -//===----------------------------------------------------------------------=== -struct PointerCastConversion : public OpConversionPattern { - static bool getIndexConst(Value v, int64_t &out) { - if (auto cst = v.getDefiningOp()) { - if (auto ia = dyn_cast(cst.getValue())) { - out = ia.getValue().getSExtValue(); - return true; - } - } - return false; - } +// Pass +//===----------------------------------------------------------------------===// - using OpConversionPattern::OpConversionPattern; +namespace { +struct EmitPTOManualPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(EmitPTOManualPass) - enum class TileRole { Vec, Mat, Left, Right, Acc, Bias, Scaling }; + PTOArch targetArch; - static void collectUserOpsThroughCasts(Value v, SmallVectorImpl &out) { - for (Operation *u : v.getUsers()) { - if (auto castOp = dyn_cast(u)) { - for (Value r : castOp.getResults()) - collectUserOpsThroughCasts(r, out); - continue; - } - out.push_back(u); - } - } + EmitPTOManualPass() : targetArch(PTOArch::A3) {} - static Value peelUnrealized(Value v) { - while (auto castOp = v.getDefiningOp()) { - v = castOp.getOperand(0); - } - return v; + explicit EmitPTOManualPass(PTOArch arch) : 
targetArch(arch) {} + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); } - static TileRole inferRole(pto::PointerCastOp op) { - // 1. 优先检查 AddressSpace - if (auto memRefTy = dyn_cast(op.getType())) { - Attribute memorySpace = memRefTy.getMemorySpace(); - if (auto ptoAttr = dyn_cast_or_null(memorySpace)) { - switch (ptoAttr.getAddressSpace()) { - case pto::AddressSpace::LEFT: return TileRole::Left; - case pto::AddressSpace::RIGHT: return TileRole::Right; - case pto::AddressSpace::ACC: return TileRole::Acc; - case pto::AddressSpace::BIAS: return TileRole::Bias; - case pto::AddressSpace::MAT: return TileRole::Mat; - case pto::AddressSpace::SCALING: return TileRole::Scaling; - default: break; - } - } - } + void runOnOperation() override { + LLVM_DEBUG(llvm::dbgs() << "DEBUG: Start PTOToEmitC Pass\n"); + MLIRContext *ctx = &getContext(); + ModuleOp mop = getOperation(); - // 2. 通过 Usage 推导 (Fallback) - SmallVector users; - collectUserOpsThroughCasts(op.getResult(), users); + if (failed(pto::validatePTOEntryFunctions(mop))) + return signalPassFailure(); + pto::annotatePTOEntryFunctions(mop); - for (Operation *user : users) { - if (auto mm = dyn_cast(user)) { - if (mm.getDst() && peelUnrealized(mm.getDst()) == op.getResult()) return TileRole::Acc; - if (peelUnrealized(mm.getLhs()) == op.getResult()) return TileRole::Left; - if (peelUnrealized(mm.getRhs()) == op.getResult()) return TileRole::Right; - } - if (auto mmacc = dyn_cast(user)) { - if (mmacc.getDst() && peelUnrealized(mmacc.getDst()) == op.getResult()) return TileRole::Acc; - if (peelUnrealized(mmacc.getAccIn()) == op.getResult()) return TileRole::Acc; - if (peelUnrealized(mmacc.getLhs()) == op.getResult()) return TileRole::Left; - if (peelUnrealized(mmacc.getRhs()) == op.getResult()) return TileRole::Right; + // A3 requires explicit FFTS base setup for inter-core sync ops. 
+ if (targetArch == PTOArch::A3) { + bool hasMissingSetFFTs = false; + for (auto func : mop.getOps()) { + if (!hasInterCoreSyncOp(func)) + continue; + if (hasSetFFTsOp(func)) + continue; + hasMissingSetFFTs = true; + func.emitError() + << "A3 inter-core sync requires explicit `pto.set_ffts` in the " + "same function when using `pto.sync.set`/`pto.sync.wait`"; } - } - - return TileRole::Vec; - } - - // [新增] 辅助函数:判断 Value 是否源自 arith.constant - static bool isConstant(Value v, int64_t &outVal) { - if (!v) return false; - if (auto cst = v.getDefiningOp()) { - if (auto attr = dyn_cast(cst.getValue())) { - outVal = attr.getInt(); - return true; - } - } - return false; - } - - LogicalResult matchAndRewrite(pto::PointerCastOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - auto selfType = mlir::cast(op.getType()); - ArrayRef shape = selfType.getShape(); - Type elemType = selfType.getElementType(); - - // 1. 推导 Tile Role - TileRole role = inferRole(op); - - // 2. 类型字符串生成 (elemTypeStr, dimStr) - std::string elemTypeStr = getEmitCScalarTypeToken(elemType); - - std::string dimStr; - pto::BLayout blayout = pto::BLayout::RowMajor; - auto dimToString = [&](int64_t dim, const char *symbol, - int dimIdx) -> std::string { - if (dim == ShapedType::kDynamic) - return std::string(symbol); - return std::to_string(renderTileTemplateDim(dim, elemType, blayout, - dimIdx)); - }; - - // 3. Role Token - const char *roleTok = "TileType::Vec"; - switch (role) { - case TileRole::Left: roleTok = "TileType::Left"; break; - case TileRole::Right: roleTok = "TileType::Right"; break; - case TileRole::Acc: roleTok = "TileType::Acc"; break; - case TileRole::Bias: roleTok = "TileType::Bias"; break; - case TileRole::Mat: roleTok = "TileType::Mat"; break; - case TileRole::Vec: roleTok = "TileType::Vec"; break; - case TileRole::Scaling: roleTok = "TileType::Scaling"; break; - } - - // 4. 
Config & Layout (support BLayoutAttr/SLayoutAttr/PadValueAttr after namespace change) - std::string layoutParams = "BLayout::RowMajor"; - std::string extraParams = ""; - if (auto configOpt = op.getConfig()) { - auto config = *configOpt; - int32_t blVal = 0; - if (auto attr = dyn_cast(config.getBLayout())) - blVal = static_cast(attr.getValue()); - - if (blVal == 1) layoutParams = "BLayout::ColMajor"; - blayout = blVal == 1 ? pto::BLayout::ColMajor : pto::BLayout::RowMajor; - - int32_t slVal = 0; - if (auto attr = dyn_cast(config.getSLayout())) - slVal = static_cast(attr.getValue()); - - std::string slStr = (slVal == 1) ? "SLayout::RowMajor" : (slVal == 2) ? "SLayout::ColMajor" : "SLayout::NoneBox"; - - int32_t frVal = 0; - if (auto attr = dyn_cast(config.getSFractalSize())) frVal = attr.getInt(); - - int32_t padVal = 0; - if (auto attr = dyn_cast(config.getPad())) - padVal = static_cast(attr.getValue()); - - std::string padStr = "PadValue::Null"; - switch (padVal) { - case 1: padStr = "PadValue::Zero"; break; - case 2: padStr = "PadValue::Max"; break; - case 3: padStr = "PadValue::Min"; break; - } - - int32_t compactVal = 0; - if (auto attr = dyn_cast(config.getCompactMode())) - compactVal = static_cast(attr.getValue()); - - std::string compactStr = "CompactMode::Null"; - switch (compactVal) { - case 1: compactStr = "CompactMode::Normal"; break; - case 2: compactStr = "CompactMode::RowPlusOne"; break; - } - - if (!slStr.empty()) { - extraParams += ", " + slStr + ", " + std::to_string(frVal) + ", " + - padStr + ", " + compactStr; - } - } else { - extraParams = ", SLayout::NoneBox, 512, PadValue::Null, CompactMode::Null"; - } - - if (role == TileRole::Left) - dimStr = dimToString(shape[0], "M", 0) + ", " + - dimToString(shape[1], "K", 1); - else if (role == TileRole::Right) - dimStr = dimToString(shape[0], "K", 0) + ", " + - dimToString(shape[1], "N", 1); - else if (role == TileRole::Bias) - dimStr = "1, " + dimToString(shape[1], "N", 1); - else - dimStr = 
dimToString(shape[0], "M", 0) + ", " + - dimToString(shape[1], "N", 1); - - // [核心修改] Valid Dims 处理逻辑 (支持混合静态/动态) - std::string vrowTok, vcolTok; - bool useConstructor = false; - - bool rowIsDynamic = false; - bool colIsDynamic = false; - - SmallVector constructorArgs; - - Value vRow = op.getValidRow(); - Value vCol = op.getValidCol(); - Value vRowEmitC = adaptor.getValidRow(); - Value vColEmitC = adaptor.getValidCol(); - bool forceDynamicValid = op->hasAttr(kForceDynamicValidShapeAttrName); - - int64_t cRow = 0, cCol = 0; - bool rowIsConst = vRow && isConstant(vRow, cRow); - bool colIsConst = vCol && isConstant(vCol, cCol); - - auto makeCtorDimValue = [&](Value emitted, int64_t fallback) -> Value { - if (emitted) - return emitted; - return makeEmitCIntConstant( - rewriter, loc, emitc::OpaqueType::get(ctx, "int32_t"), fallback); - }; - auto maybeScaleDynamicValid = [&](Value emitted, int dimIdx) -> Value { - if (!emitted || !pto::isPTOFloat4PackedType(elemType)) - return emitted; - int packedDim = blayout == pto::BLayout::ColMajor ? 0 : 1; - if (dimIdx != packedDim) - return emitted; - auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); - Value two = makeEmitCIntConstant(rewriter, loc, i32Ty, 2); - return rewriter.create(loc, i32Ty, emitted, two).getResult(); - }; - - if (forceDynamicValid) { - vrowTok = "-1"; - vcolTok = "-1"; - useConstructor = true; - constructorArgs.push_back( - makeCtorDimValue(maybeScaleDynamicValid(vRowEmitC, 0), - renderTileTemplateDim(rowIsConst ? cRow : shape[0], - elemType, blayout, 0))); - constructorArgs.push_back( - makeCtorDimValue(maybeScaleDynamicValid(vColEmitC, 1), - renderTileTemplateDim(colIsConst ? 
cCol : shape[1], - elemType, blayout, 1))); - } else { - if (rowIsConst) { - vrowTok = std::to_string( - renderTileTemplateDim(cRow, elemType, blayout, 0)); - } else if (vRow) { - vrowTok = "-1"; - rowIsDynamic = true; - useConstructor = true; - } else { - vrowTok = std::to_string( - renderTileTemplateDim(shape[0], elemType, blayout, 0)); - } - - if (colIsConst) { - vcolTok = std::to_string( - renderTileTemplateDim(cCol, elemType, blayout, 1)); - } else if (vCol) { - vcolTok = "-1"; - colIsDynamic = true; - useConstructor = true; - } else { - vcolTok = std::to_string( - renderTileTemplateDim(shape[1], elemType, blayout, 1)); - } - - if (useConstructor) { - if (rowIsDynamic && vRowEmitC) - constructorArgs.push_back(maybeScaleDynamicValid(vRowEmitC, 0)); - if (colIsDynamic && vColEmitC) - constructorArgs.push_back(maybeScaleDynamicValid(vColEmitC, 1)); - } - } - - // 5. 生成 Tile 类型字符串 - std::string tileTypeStr = - std::string("Tile<") + roleTok + ", " + elemTypeStr + ", " + dimStr + ", " + - layoutParams + ", " + vrowTok + ", " + vcolTok + extraParams + ">"; - - auto tileType = emitc::OpaqueType::get(ctx, tileTypeStr); - Value resultValue; - - if (useConstructor) { - // 使用 CallOpaqueOp 生成构造函数调用 (Tile v = Tile(...)) - auto ctorOp = rewriter.create( - loc, - tileType, // Result Type - tileTypeStr, // Callee Name (类名) - ArrayAttr{}, // args - ArrayAttr{}, // template_args - ValueRange(constructorArgs) // operands - ); - resultValue = ctorOp.getResult(0); - } else { - // 静态情况 (Tile v;) - auto varOp = rewriter.create( - loc, - tileType, - emitc::OpaqueAttr::get(ctx, "") - ); - resultValue = varOp.getResult(); - } - - // TASSIGN: pto-isa expects an integral address. 
- Value addr = adaptor.getAddrs()[0]; - if (isa(addr.getType()) || - (isa(addr.getType()) && - cast(addr.getType()).getValue().ends_with("*"))) { - auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); - auto rcU64 = rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); - addr = rewriter.create( - loc, u64Ty, "reinterpret_cast", - /*args=*/ArrayAttr{}, /*templateArgs=*/rcU64, - /*operands=*/ValueRange{addr}) - .getResult(0); - } - - rewriter.create( - loc, TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{resultValue, addr}); - - rewriter.replaceOp(op, resultValue); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// pto.load_dps / pto.store_dps lowering (FIX: keep optional result) -//===----------------------------------------------------------------------=== - -struct PTOTLoadToTLOAD : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TLoadOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!op.getDst()) - return rewriter.notifyMatchFailure(op, "expected outs(dst) on pto.tload"); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value srcArg = src; - if (auto srcMrTy = dyn_cast(op.getSrc().getType())) { - bool isGlobal = true; - if (auto asAttr = dyn_cast_or_null(srcMrTy.getMemorySpace())) { - auto as = asAttr.getAddressSpace(); - isGlobal = (as == pto::AddressSpace::GM || as == pto::AddressSpace::Zero); - } - if (isGlobal) { - if (Value gt = buildGlobalTensorFromMemref(rewriter, op.getLoc(), src, srcMrTy, - op.getOperation())) - srcArg = gt; - } - } - - rewriter.create( - op.getLoc(), TypeRange{}, "TLOAD", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, srcArg}); - - if (op->getNumResults() == 1) { - rewriter.replaceOp(op, dst); - } else { - rewriter.eraseOp(op); - } - return success(); - } -}; - -struct 
PTOTPrefetchToTPREFETCH : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPrefetchOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!op.getDst()) - return rewriter.notifyMatchFailure(op, "expected outs(dst) on pto.tprefetch"); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value srcArg = src; - if (auto srcMrTy = dyn_cast(op.getSrc().getType())) { - bool isGlobal = true; - if (auto asAttr = dyn_cast_or_null(srcMrTy.getMemorySpace())) { - auto as = asAttr.getAddressSpace(); - isGlobal = (as == pto::AddressSpace::GM || as == pto::AddressSpace::Zero); - } - if (isGlobal) { - if (Value gt = buildGlobalTensorFromMemref(rewriter, op.getLoc(), src, srcMrTy, - op.getOperation())) - srcArg = gt; - } - } - - rewriter.create( - op.getLoc(), TypeRange{}, "TPREFETCH", - ArrayAttr{}, ArrayAttr{}, ValueRange{dst, srcArg}); - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOTPrefetchAsyncToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPrefetchAsyncOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src = peelUnrealized(adaptor.getSrc()); - Value srcArg = src; - if (!isEmitCGlobalTensorLikeType(srcArg.getType())) { - auto srcMrTy = dyn_cast(op.getSrc().getType()); - if (!srcMrTy) - return rewriter.notifyMatchFailure( - op, "expected src to lower to GlobalTensor or memref"); - srcArg = buildGlobalTensorFromMemref(rewriter, op.getLoc(), src, srcMrTy, - op.getSrc().getDefiningOp() - ? 
op.getSrc().getDefiningOp() - : op.getOperation()); - } - if (!srcArg) - return rewriter.notifyMatchFailure(op, - "failed to build GlobalTensor src"); - - Value prefetchCtx = peelUnrealized(adaptor.getCtx()); - - Type eventTy = getTypeConverter()->convertType(op.getEvent().getType()); - if (!eventTy) - return rewriter.notifyMatchFailure( - op, "failed to convert tprefetch_async result type"); - - Value event = rewriter - .create( - op.getLoc(), TypeRange{eventTy}, "TPREFETCH_ASYNC", - ArrayAttr{}, ArrayAttr{}, - ValueRange{srcArg, prefetchCtx}) - .getResult(0); - - rewriter.replaceOp(op, ValueRange{event}); - return success(); - } -}; - -struct PTOMakePrefetchAsyncContextToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::MakePrefetchAsyncContextOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type ctxTy = getTypeConverter()->convertType(op.getCtx().getType()); - if (!ctxTy) - return rewriter.notifyMatchFailure( - op, "failed to convert make_prefetch_async_context result type"); - - Value workspace = peelUnrealized(adaptor.getWorkspace()); - workspace = castToGMBytePointer(rewriter, op.getLoc(), workspace); - - Value ctx = rewriter - .create( - op.getLoc(), TypeRange{ctxTy}, "pto::PrefetchAsyncContext", - ArrayAttr{}, ArrayAttr{}, ValueRange{workspace}) - .getResult(0); - - rewriter.replaceOp(op, ValueRange{ctx}); - return success(); - } -}; - -struct PTOGetPrefetchAsyncSessionToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::GetPrefetchAsyncSessionOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type sessionTy = getTypeConverter()->convertType(op.getSession().getType()); - if (!sessionTy) - return rewriter.notifyMatchFailure( - op, "failed to convert get_prefetch_async_session result type"); - - Value ctx = 
peelUnrealized(adaptor.getCtx()); - Value session = rewriter - .create( - op.getLoc(), TypeRange{sessionTy}, - "PTOAS__PREFETCH_CTX_SESSION", ArrayAttr{}, - ArrayAttr{}, ValueRange{ctx}) - .getResult(0); - - rewriter.replaceOp(op, ValueRange{session}); - return success(); - } -}; - -struct PTOTStoreToTSTORE : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - static std::string stPhaseTok(pto::STPhase phase) { - switch (phase) { - case pto::STPhase::Unspecified: return "STPhase::Unspecified"; - case pto::STPhase::Partial: return "STPhase::Partial"; - case pto::STPhase::Final: return "STPhase::Final"; - } - return "STPhase::Unspecified"; - } - - static std::string atomicTypeTok(pto::AtomicType atomicType) { - switch (atomicType) { - case pto::AtomicType::AtomicNone: return "AtomicType::AtomicNone"; - case pto::AtomicType::AtomicAdd: return "AtomicType::AtomicAdd"; - } - return "AtomicType::AtomicNone"; - } - - static std::string reluPreModeTok(pto::ReluPreMode reluPreMode) { - switch (reluPreMode) { - case pto::ReluPreMode::NoRelu: return "ReluPreMode::NoRelu"; - case pto::ReluPreMode::NormalRelu: return "ReluPreMode::NormalRelu"; - } - return "ReluPreMode::NoRelu"; - } - - LogicalResult matchAndRewrite(pto::TStoreOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!op.getDst()) - return rewriter.notifyMatchFailure(op, "expected outs(dst) on pto.tstore"); - - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value preQuantScalar; - if (op.getPreQuantScalar()) - preQuantScalar = peelUnrealized(adaptor.getPreQuantScalar()); - Value dstArg = dst; - if (auto dstMrTy = dyn_cast(op.getDst().getType())) { - bool isGlobal = true; - if (auto asAttr = dyn_cast_or_null(dstMrTy.getMemorySpace())) { - auto as = asAttr.getAddressSpace(); - isGlobal = (as == pto::AddressSpace::GM || as == 
pto::AddressSpace::Zero); - } - if (isGlobal) { - if (Value gt = buildGlobalTensorFromMemref(rewriter, op.getLoc(), dst, dstMrTy, - op.getOperation())) - dstArg = gt; - } - } - - const auto phase = op.getStPhase(); - const auto atomicType = op.getAtomicType(); - const auto reluPreMode = op.getReluPreMode(); - const bool hasPreQuantScalar = static_cast(preQuantScalar); - const bool phaseNonDefault = phase != pto::STPhase::Unspecified; - const bool atomicNonDefault = atomicType != pto::AtomicType::AtomicNone; - const bool reluNonDefault = reluPreMode != pto::ReluPreMode::NoRelu; - - auto getOpaqueTok = [&](Value v, StringRef name) -> FailureOr { - if (auto ot = mlir::dyn_cast(v.getType())) - return ot.getValue().str(); - return rewriter.notifyMatchFailure(op, (name + " must be emitc::OpaqueType").str()); - }; - - ArrayAttr targs; - // Map op attributes/operands to the exact TSTORE overload family: - // 1) TSTORE(dst, src) - // 2) TSTORE(dst, src) - // 3) TSTORE(dst, src) - // 4) TSTORE(dst, src) - // 5) TSTORE(dst, src) - // 6) TSTORE(dst, src) - // 7) TSTORE(dst, src, preQuant) - // 8) TSTORE(dst, src, preQuant) - if (!hasPreQuantScalar && !reluNonDefault && !atomicNonDefault) { - if (phaseNonDefault) { - targs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, stPhaseTok(phase)), - }); - } else { - targs = ArrayAttr{}; - } - } else { - auto srcTokOr = getOpaqueTok(src, "src"); - auto dstTokOr = getOpaqueTok(dstArg, "dst"); - if (failed(srcTokOr) || failed(dstTokOr)) - return failure(); - - // If there is no preQuant and relu stays default, emit the atomic-only - // overloads (#3/#4) without ReluPreMode template argument. 
- if (!hasPreQuantScalar && !reluNonDefault) { - if (phaseNonDefault) { - targs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, stPhaseTok(phase)), - emitc::OpaqueAttr::get(ctx, *srcTokOr), - emitc::OpaqueAttr::get(ctx, *dstTokOr), - emitc::OpaqueAttr::get(ctx, atomicTypeTok(atomicType)), - }); - } else { - targs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, *srcTokOr), - emitc::OpaqueAttr::get(ctx, *dstTokOr), - emitc::OpaqueAttr::get(ctx, atomicTypeTok(atomicType)), - }); - } - } else { - // Relu/preQuant families (#5/#6/#7/#8): keep AtomicType + ReluPreMode. - if (phaseNonDefault) { - targs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, stPhaseTok(phase)), - emitc::OpaqueAttr::get(ctx, *srcTokOr), - emitc::OpaqueAttr::get(ctx, *dstTokOr), - emitc::OpaqueAttr::get(ctx, atomicTypeTok(atomicType)), - emitc::OpaqueAttr::get(ctx, reluPreModeTok(reluPreMode)), - }); - } else { - targs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, *srcTokOr), - emitc::OpaqueAttr::get(ctx, *dstTokOr), - emitc::OpaqueAttr::get(ctx, atomicTypeTok(atomicType)), - emitc::OpaqueAttr::get(ctx, reluPreModeTok(reluPreMode)), - }); - } - } - } - - SmallVector operands{dstArg, src}; - if (hasPreQuantScalar) - operands.push_back(preQuantScalar); - - rewriter.create( - loc, TypeRange{}, "TSTORE", - /*args=*/ArrayAttr{}, /*templateArgs=*/targs, - /*operands=*/operands); - - if (op->getNumResults() == 1) { - rewriter.replaceOp(op, dst); - } else { - rewriter.eraseOp(op); - } - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// pto.matmul_dps lowering (Simplified: No internal copy/sync) -//===----------------------------------------------------------------------===// -// -// Render `pto.tmatmul` as one of three forms depending on the optional -// `acc_phase` attribute: -// * absent / Unspecified -> `TMATMUL(dst, lhs, rhs)` -// * Partial -> `TMATMUL(dst, lhs, rhs)` -// * Final -> `TMATMUL(dst, lhs, rhs)` 
-// The Unspecified default keeps backward compatibility with all upstream IR -// that does not yet emit an explicit phase attribute. -static ArrayAttr buildAccPhaseTemplateArgs(ConversionPatternRewriter &rewriter, - pto::AccPhase phase) { - StringRef tmpl; - switch (phase) { - case pto::AccPhase::Unspecified: - return ArrayAttr{}; - case pto::AccPhase::Partial: - tmpl = "AccPhase::Partial"; - break; - case pto::AccPhase::Final: - tmpl = "AccPhase::Final"; - break; - } - if (tmpl.empty()) - return ArrayAttr{}; - return rewriter.getArrayAttr( - {emitc::OpaqueAttr::get(rewriter.getContext(), tmpl)}); -} - -struct PTOTMatmulToTMATMUL : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMatmulOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // 1. 获取操作数 (剥离 Cast) - Value lhs = peelUnrealized(adaptor.getLhs()); // A (Left) - Value rhs = peelUnrealized(adaptor.getRhs()); // B (Right) - Value dst = peelUnrealized(adaptor.getDst()); // C (Acc) - - // 2. 根据 acc_phase 属性决定是否生成 TMATMUL(...) - ArrayAttr templateArgs = - buildAccPhaseTemplateArgs(rewriter, op.getAccPhase()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TMATMUL", - /*args=*/ArrayAttr{}, /*template_args=*/templateArgs, - ValueRange{dst, lhs, rhs}); - - // 3. 处理 Op 替换/删除 - if (op->getNumResults() == 1) { - rewriter.replaceOp(op, dst); - } else { - rewriter.eraseOp(op); - } - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// pto.tgemv lowering -//===----------------------------------------------------------------------===// -struct PTOTGemvToTGEMV : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGemvOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // 1. 
获取操作数 (剥离 Cast) - Value lhs = peelUnrealized(adaptor.getLhs()); // A (Matrix) - Value rhs = peelUnrealized(adaptor.getRhs()); // B (Vector) - Value dst = peelUnrealized(adaptor.getDst()); // C (Result) - - // 2. 直接生成函数调用 TGEMV(dst, lhs, rhs) - rewriter.create( - op.getLoc(), TypeRange{}, "TGEMV", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, lhs, rhs}); - - // 3. 处理 Op 替换/删除 - if (op->getNumResults() == 1) { - rewriter.replaceOp(op, dst); - } else { - rewriter.eraseOp(op); - } - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// pto.tgemv.acc lowering -//===----------------------------------------------------------------------===// -struct PTOTGemvAccToTGEMVACC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGemvAccOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!op.getDst()) - return rewriter.notifyMatchFailure(op, "expected outs(dst) for pto.tgemv.acc"); - - // 1. 获取操作数 - Value accIn = peelUnrealized(adaptor.getAccIn()); // AccOld - Value lhs = peelUnrealized(adaptor.getLhs()); // A (Matrix) - Value rhs = peelUnrealized(adaptor.getRhs()); // B (Vector) - Value dst = peelUnrealized(adaptor.getDst()); // AccNew - - // 2. 直接生成函数调用 TGEMV_ACC(dst, accIn, lhs, rhs) - rewriter.create( - op.getLoc(), TypeRange{}, "TGEMV_ACC", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, accIn, lhs, rhs}); - - // 3. 
处理 Op 替换/删除 - if (op->getNumResults() == 1) { - rewriter.replaceOp(op, dst); - } else { - rewriter.eraseOp(op); - } - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// pto.matmul_acc_dps lowering (Simplified: No internal copy/sync) -//===----------------------------------------------------------------------===// -struct PTOTMatmulAccToTMATMULACC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMatmulAccOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!op.getDst()) - return rewriter.notifyMatchFailure(op, "expected outs(dst) for pto.tmatmul.acc"); - - // 1. 获取操作数 - Value accIn = peelUnrealized(adaptor.getAccIn()); // AccOld - Value lhs = peelUnrealized(adaptor.getLhs()); // A (Left) - Value rhs = peelUnrealized(adaptor.getRhs()); // B (Right) - Value dst = peelUnrealized(adaptor.getDst()); // AccNew - - // 2. 根据 acc_phase 属性决定是否生成 TMATMUL_ACC(...) - ArrayAttr templateArgs = - buildAccPhaseTemplateArgs(rewriter, op.getAccPhase()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TMATMUL_ACC", - /*args=*/ArrayAttr{}, /*template_args=*/templateArgs, - ValueRange{dst, accIn, lhs, rhs}); - - // 3. 
处理 Op 替换/删除 - if (op->getNumResults() == 1) { - rewriter.replaceOp(op, dst); - } else { - rewriter.eraseOp(op); - } - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// Return lowering -//===----------------------------------------------------------------------=== - -static constexpr llvm::StringLiteral kAutoSyncTailPendingModeAttr = - "__pto.auto_sync_tail_mode"; - -struct ReturnToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (auto emitcFunc = op->getParentOfType()) { - if (auto modeAttr = - emitcFunc->getAttrOfType(kAutoSyncTailPendingModeAttr)) { - auto *ctx = rewriter.getContext(); - rewriter.setInsertionPoint(op); - auto args = rewriter.getArrayAttr( - {emitc::OpaqueAttr::get(ctx, modeAttr.getValue())}); - rewriter.create( - op.getLoc(), TypeRange{}, "ptoas_auto_sync_tail", - args, ArrayAttr{}, ValueRange{}); - } - } - - auto vals = adaptor.getOperands(); - if (vals.empty()) { - rewriter.replaceOpWithNewOp(op, Value{}); - return success(); - } - if (vals.size() == 1) { - rewriter.replaceOpWithNewOp(op, vals[0]); - return success(); - } - return rewriter.notifyMatchFailure(op, "EmitC cannot return multiple values"); - } -}; - -struct CallToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(func::CallOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (op.getNumResults() > 1) - return rewriter.notifyMatchFailure( - op, "EmitC cannot lower calls with multiple results"); - - SmallVector resultTypes; - if (failed( - getTypeConverter()->convertTypes(op.getResultTypes(), resultTypes))) - return rewriter.notifyMatchFailure(op, - "failed to convert call result types"); - - rewriter.replaceOpWithNewOp(op, op.getCalleeAttr(), - 
resultTypes, - adaptor.getOperands()); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// Sync lowering -//===----------------------------------------------------------------------=== - -static constexpr llvm::StringLiteral kAutoSyncTailBarrierAttr = - "pto.auto_sync_tail_barrier"; -static constexpr llvm::StringLiteral kAutoSyncTailHintAttr = - "pto.auto_sync_tail_hint"; -static constexpr llvm::StringLiteral kAutoSyncTailPolicyBarrierAll = - "barrier_all"; -static constexpr llvm::StringLiteral kAutoSyncTailPolicyMte3ToSEvent0 = - "setwait_mte3_to_s_event0"; -static constexpr llvm::StringLiteral kAutoSyncTailModeBarrierAllToken = - "PTOAutoSyncTailMode::kBarrierAll"; -static constexpr llvm::StringLiteral kAutoSyncTailModeMte3ToSEvent0Token = - "PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0"; - -static std::string getAutoSyncTailModeToken(Operation *op) { - if (op) { - if (auto hintAttr = op->getAttrOfType(kAutoSyncTailHintAttr)) { - if (hintAttr.getValue() == kAutoSyncTailPolicyBarrierAll) - return kAutoSyncTailModeBarrierAllToken.str(); - if (hintAttr.getValue() == kAutoSyncTailPolicyMte3ToSEvent0) - return kAutoSyncTailModeMte3ToSEvent0Token.str(); - } - } - - auto func = op ? op->getParentOfType() : func::FuncOp(); - if (!func) - return kAutoSyncTailModeBarrierAllToken.str(); - - auto hintAttr = func->getAttrOfType(kAutoSyncTailHintAttr); - if (!hintAttr) - return kAutoSyncTailModeBarrierAllToken.str(); - - if (hintAttr.getValue() == kAutoSyncTailPolicyBarrierAll) - return kAutoSyncTailModeBarrierAllToken.str(); - if (hintAttr.getValue() == kAutoSyncTailPolicyMte3ToSEvent0) - return kAutoSyncTailModeMte3ToSEvent0Token.str(); - - // Fallback to the conservative behavior when seeing unknown policies. 
- return kAutoSyncTailModeBarrierAllToken.str(); -} - -[[maybe_unused]] static std::string getPipeName(pto::PIPE pipe) { - switch (pipe) { - case pto::PIPE::PIPE_S: return "PIPE_S"; - case pto::PIPE::PIPE_V: return "PIPE_V"; - case pto::PIPE::PIPE_M: return "PIPE_M"; - case pto::PIPE::PIPE_MTE1: return "PIPE_MTE1"; - case pto::PIPE::PIPE_MTE2: return "PIPE_MTE2"; - case pto::PIPE::PIPE_MTE3: return "PIPE_MTE3"; - case pto::PIPE::PIPE_ALL: return "PIPE_ALL"; - case pto::PIPE::PIPE_MTE4: return "PIPE_MTE4"; - case pto::PIPE::PIPE_MTE5: return "PIPE_MTE5"; - case pto::PIPE::PIPE_V2: return "PIPE_V2"; - case pto::PIPE::PIPE_FIX: return "PIPE_FIX"; - case pto::PIPE::VIRTUAL_PIPE_MTE2_L1A: return "VIRTUAL_PIPE_MTE2_L1A"; - case pto::PIPE::VIRTUAL_PIPE_MTE2_L1B: return "VIRTUAL_PIPE_MTE2_L1B"; - // 默认回退 - default: return "PIPE_ALL"; - } -} - -//===----------------------------------------------------------------------===// -// pto.barrier lowering -> pipe_barrier(...) -//===----------------------------------------------------------------------===// -struct PTOBarrierToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::BarrierOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (op->hasAttr(kAutoSyncTailBarrierAttr)) { - auto modeAttr = rewriter.getStringAttr(getAutoSyncTailModeToken(op)); - if (auto emitcFunc = op->getParentOfType()) { - emitcFunc->setAttr(kAutoSyncTailPendingModeAttr, modeAttr); - } else if (auto funcOp = op->getParentOfType()) { - funcOp->setAttr(kAutoSyncTailPendingModeAttr, modeAttr); - } - rewriter.eraseOp(op); - return success(); - } - - // [FIX] op.getPipe() returns PipeAttr. - // We must call .getPipe() on the attribute to get the actual Enum value. 
- pto::PIPE pipeEnum = op.getPipe().getPipe(); - - // Convert Enum to String (e.g., PIPE_ALL -> "PIPE_ALL") - std::string pipeStr = pto::stringifyPIPE(pipeEnum).str(); - auto *ctx = rewriter.getContext(); - - auto args = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, pipeStr) - }); - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, // void return - "pipe_barrier", // function name - args, // arguments - ArrayAttr{}, // template args - ValueRange{} // operands - ); - - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// Sync lowering (robust for bracket form pto.set_flag[...] / pto.wait_flag[...]) -// Replace your PTOSyncToRuntimeCall with the code below. -//===----------------------------------------------------------------------===// - -static bool tryConvertPipeAttrToToken(Attribute attr, std::string &token) { - if (!attr) - return false; - if (auto pipe = dyn_cast(attr)) { - token = mlir::pto::stringifyPIPE(pipe.getPipe()).str(); - return true; - } - if (auto stringAttr = dyn_cast(attr)) { - token = stringAttr.getValue().str(); - return true; - } - return false; -} - -static bool tryConvertEventAttrToToken(Attribute attr, std::string &token) { - if (!attr) - return false; - if (auto event = dyn_cast(attr)) { - token = mlir::pto::stringifyEVENT(event.getEvent()).str(); - return true; - } - if (auto stringAttr = dyn_cast(attr)) { - token = stringAttr.getValue().str(); - return true; - } - return false; -} - -static bool tryAssignSyncTokens(Attribute srcAttr, Attribute dstAttr, - Attribute evtAttr, std::string &srcTok, - std::string &dstTok, std::string &evtTok) { - std::string localSrc; - std::string localDst; - std::string localEvt; - if (!tryConvertPipeAttrToToken(srcAttr, localSrc) || - !tryConvertPipeAttrToToken(dstAttr, localDst) || - !tryConvertEventAttrToToken(evtAttr, localEvt)) { - return false; - } - srcTok = std::move(localSrc); - dstTok = std::move(localDst); - evtTok = 
std::move(localEvt); - return true; -} - -static bool tryExtractSyncTokensFromNamedAttrs(Operation *op, - StringRef srcName, - StringRef dstName, - StringRef evtName, - std::string &srcTok, - std::string &dstTok, - std::string &evtTok) { - return tryAssignSyncTokens(op->getAttr(srcName), op->getAttr(dstName), - op->getAttr(evtName), srcTok, dstTok, evtTok); -} - -static bool tryExtractSyncTokensFromArrayAttr(Operation *op, StringRef attrName, - std::string &srcTok, - std::string &dstTok, - std::string &evtTok) { - auto arrayAttr = op->getAttrOfType(attrName); - if (!arrayAttr || arrayAttr.size() < 3) - return false; - return tryAssignSyncTokens(arrayAttr[0], arrayAttr[1], arrayAttr[2], srcTok, - dstTok, evtTok); -} - -static bool tryExtractFallbackSyncTokens(Operation *op, std::string &srcTok, - std::string &dstTok, - std::string &evtTok) { - SmallVector pipes; - std::string event; - for (NamedAttribute namedAttr : op->getAttrs()) { - std::string token; - if (tryConvertPipeAttrToToken(namedAttr.getValue(), token)) { - pipes.push_back(std::move(token)); - continue; - } - if (event.empty() && - tryConvertEventAttrToToken(namedAttr.getValue(), token)) { - event = std::move(token); - } - } - if (pipes.size() < 2 || event.empty()) - return false; - srcTok = pipes[0]; - dstTok = pipes[1]; - evtTok = event; - return true; -} - -static LogicalResult extractSyncTripletTokens(Operation *op, - std::string &srcTok, - std::string &dstTok, - std::string &evtTok, - ConversionPatternRewriter &rewriter) { - if (tryExtractSyncTokensFromNamedAttrs(op, "src_pipe", "dst_pipe", "event_id", - srcTok, dstTok, evtTok) || - tryExtractSyncTokensFromNamedAttrs(op, "srcPipe", "dstPipe", "eventId", - srcTok, dstTok, evtTok) || - tryExtractSyncTokensFromNamedAttrs(op, "src", "dst", "event", srcTok, - dstTok, evtTok)) { - return success(); - } - - for (StringRef attrName : {"args", "pipes", "sync", "triplet", "attrs"}) { - if (tryExtractSyncTokensFromArrayAttr(op, attrName, srcTok, dstTok, - 
evtTok)) { - return success(); - } - } - - if (tryExtractFallbackSyncTokens(op, srcTok, dstTok, evtTok)) - return success(); - return rewriter.notifyMatchFailure( - op, "cannot extract PIPE/PIPE/EVENT tokens from pto.{set,wait}_flag"); -} -static inline std::string pipeTokFromPipeEnum(mlir::pto::PIPE p) { - return mlir::pto::stringifyPIPE(p).str(); -} -[[maybe_unused]] static inline std::string evtTokFromEventEnum(mlir::pto::EVENT e) { - return mlir::pto::stringifyEVENT(e).str(); -} -static inline std::string pipeTokFromPipeAttr(mlir::pto::PipeAttr a) { - return mlir::pto::stringifyPIPE(a.getPipe()).str(); -} -static inline std::string evtTokFromEventAttr(mlir::pto::EventAttr a) { - return mlir::pto::stringifyEVENT(a.getEvent()).str(); -} - -template -struct HasGetSrcPipe : std::false_type {}; -template -struct HasGetSrcPipe().getSrcPipe())>> : std::true_type {}; - -template -struct HasGetDstPipe : std::false_type {}; -template -struct HasGetDstPipe().getDstPipe())>> : std::true_type {}; - -template -struct HasGetEventId : std::false_type {}; -template -struct HasGetEventId().getEventId())>> : std::true_type {}; - -template -struct HasGetSrcPipeAttr : std::false_type {}; -template -struct HasGetSrcPipeAttr().getSrcPipeAttr())>> : std::true_type {}; - -template -struct HasGetDstPipeAttr : std::false_type {}; -template -struct HasGetDstPipeAttr().getDstPipeAttr())>> : std::true_type {}; - -template -struct HasGetEventIdAttr : std::false_type {}; -template -struct HasGetEventIdAttr().getEventIdAttr())>> : std::true_type {}; - -template -static LogicalResult extractSyncTokens(SyncOpT op, - std::string &srcTok, - std::string &dstTok, - std::string &evtTok, - ConversionPatternRewriter &rewriter) { - if constexpr (HasGetSrcPipe::value && - HasGetDstPipe::value && - HasGetEventId::value) { - auto s = op.getSrcPipe(); - auto d = op.getDstPipe(); - auto e = op.getEventId(); - - if constexpr (std::is_same::value) srcTok = pipeTokFromPipeEnum(s); - else srcTok = 
pipeTokFromPipeAttr(s); - - if constexpr (std::is_same::value) dstTok = pipeTokFromPipeEnum(d); - else dstTok = pipeTokFromPipeAttr(d); - - if constexpr (std::is_same::value) evtTok = evtTokFromEventEnum(e); - else evtTok = evtTokFromEventAttr(e); - - return success(); - } - - if constexpr (HasGetSrcPipeAttr::value && - HasGetDstPipeAttr::value && - HasGetEventIdAttr::value) { - auto s = op.getSrcPipeAttr(); - auto d = op.getDstPipeAttr(); - auto e = op.getEventIdAttr(); - srcTok = pipeTokFromPipeAttr(s); - dstTok = pipeTokFromPipeAttr(d); - evtTok = evtTokFromEventAttr(e); - return success(); - } - - return extractSyncTripletTokens(op.getOperation(), srcTok, dstTok, evtTok, rewriter); -} -struct PTOSetFlagToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::SetFlagOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - (void)adaptor; - auto *ctx = rewriter.getContext(); - - std::string srcTok, dstTok, evtTok; - if (failed(extractSyncTokens(op, srcTok, dstTok, evtTok, rewriter))) - return failure(); - - auto argsAttr = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, srcTok), - emitc::OpaqueAttr::get(ctx, dstTok), - emitc::OpaqueAttr::get(ctx, evtTok), - }); - - rewriter.replaceOpWithNewOp( - op, TypeRange{}, "set_flag", - /*args=*/argsAttr, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{}); - return success(); - } -}; - -struct PTOWaitFlagToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::WaitFlagOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - (void)adaptor; - auto *ctx = rewriter.getContext(); - - std::string srcTok, dstTok, evtTok; - if (failed(extractSyncTokens(op, srcTok, dstTok, evtTok, rewriter))) - return failure(); - - auto argsAttr = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, srcTok), - 
emitc::OpaqueAttr::get(ctx, dstTok), - emitc::OpaqueAttr::get(ctx, evtTok), - }); - - rewriter.replaceOpWithNewOp( - op, TypeRange{}, "wait_flag", - /*args=*/argsAttr, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{}); - return success(); - } -}; - -struct PTOSyncToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::TSyncOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - SmallVector operands; - operands.reserve(adaptor.getEvents().size()); - for (Value event : adaptor.getEvents()) - operands.push_back(peelUnrealized(event)); - - rewriter.create( - op.getLoc(), TypeRange{}, "TSYNC", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange(operands)); - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOSyncAllToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - static StringRef coreTypeTok(pto::SyncCoreType coreType) { - switch (coreType) { - case pto::SyncCoreType::AIVOnly: - return "SyncCoreType::AIVOnly"; - case pto::SyncCoreType::AICOnly: - return "SyncCoreType::AICOnly"; - case pto::SyncCoreType::Mix: - return "SyncCoreType::Mix"; - } - llvm_unreachable("unhandled SyncCoreType"); - } - - LogicalResult matchAndRewrite(mlir::pto::SyncAllOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto mode = op.getMode().getValue(); - auto coreType = op.getCoreType().getValue(); - - auto buildGmWorkspace = [&]() -> FailureOr { - Value gm = peelUnrealized(adaptor.getGmWorkspace()); - if (isEmitCGlobalTensorLikeType(gm.getType())) - return gm; - - auto memTy = dyn_cast(op.getGmWorkspace().getType()); - if (!memTy) - return failure(); - - Value gt = buildGlobalTensorFromMemref(rewriter, op.getLoc(), gm, memTy, - op.getGmWorkspace().getDefiningOp() - ? 
op.getGmWorkspace().getDefiningOp() - : op.getOperation()); - if (!gt) - return failure(); - return gt; - }; - - if (mode == pto::SyncAllMode::Hard) { - std::string callee = "SYNCALL<" + coreTypeTok(coreType).str() + ">"; - rewriter.create(op.getLoc(), TypeRange{}, callee, - ArrayAttr{}, ArrayAttr{}, - ValueRange{}); - rewriter.eraseOp(op); - return success(); - } - - FailureOr gmWorkspace = buildGmWorkspace(); - if (failed(gmWorkspace)) - return rewriter.notifyMatchFailure(op, - "failed to build gm_workspace GlobalTensor"); - - auto i32Ty = emitc::OpaqueType::get(rewriter.getContext(), "int32_t"); - Value usedCores = adaptor.getUsedCores() - ? peelUnrealized(adaptor.getUsedCores()) - : makeEmitCIntConstant(rewriter, op.getLoc(), i32Ty, 0); - if (usedCores.getType() != i32Ty) - usedCores = rewriter.create(op.getLoc(), i32Ty, usedCores) - .getResult(); - - std::string callee = - "SYNCALL"; - - SmallVector operands{*gmWorkspace}; - switch (coreType) { - case pto::SyncCoreType::AIVOnly: { - FailureOr ubWorkspace = - buildSyncAllWorkspaceTileValue(rewriter, op.getLoc(), - op.getUbWorkspace(), - adaptor.getUbWorkspace()); - if (failed(ubWorkspace)) - return rewriter.notifyMatchFailure( - op, "failed to materialize ub_workspace tile"); - operands.push_back(*ubWorkspace); - break; - } - case pto::SyncCoreType::AICOnly: { - FailureOr l1Workspace = - buildSyncAllWorkspaceTileValue(rewriter, op.getLoc(), - op.getL1Workspace(), - adaptor.getL1Workspace()); - if (failed(l1Workspace)) - return rewriter.notifyMatchFailure( - op, "failed to materialize l1_workspace tile"); - operands.push_back(*l1Workspace); - break; - } - case pto::SyncCoreType::Mix: { - FailureOr ubWorkspace = - buildSyncAllWorkspaceTileValue(rewriter, op.getLoc(), - op.getUbWorkspace(), - adaptor.getUbWorkspace()); - FailureOr l1Workspace = - buildSyncAllWorkspaceTileValue(rewriter, op.getLoc(), - op.getL1Workspace(), - adaptor.getL1Workspace()); - if (failed(ubWorkspace) || failed(l1Workspace)) - return 
rewriter.notifyMatchFailure( - op, "failed to materialize mixed syncall workspace tiles"); - operands.push_back(*ubWorkspace); - operands.push_back(*l1Workspace); - break; - } - } - - operands.push_back(usedCores); - rewriter.create(op.getLoc(), TypeRange{}, callee, - ArrayAttr{}, ArrayAttr{}, - ValueRange(operands)); - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOSyncFlagDynToEmitC : public ConversionPattern { - PTOSyncFlagDynToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - StringRef opName, StringRef callee) - : ConversionPattern(typeConverter, opName, /*benefit=*/1, ctx), - callee(callee.str()) {} - - LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - if (operands.size() != 1) - return rewriter.notifyMatchFailure(op, "expected exactly one dynamic event-id operand"); - - auto srcAttr = op->getAttrOfType("src_pipe"); - auto dstAttr = op->getAttrOfType("dst_pipe"); - if (!srcAttr || !dstAttr) - return rewriter.notifyMatchFailure(op, "missing PipeAttr src_pipe/dst_pipe attrs"); - - auto *ctx = rewriter.getContext(); - std::string srcTok = pipeTokFromPipeAttr(srcAttr); - std::string dstTok = pipeTokFromPipeAttr(dstAttr); - - Value eventVal = operands.front(); - eventVal = - emitCCast(rewriter, op->getLoc(), emitc::OpaqueType::get(ctx, "event_t"), eventVal); - - auto argsAttr = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, srcTok), - emitc::OpaqueAttr::get(ctx, dstTok), - IntegerAttr::get(IndexType::get(ctx), 0), - }); - - rewriter.replaceOpWithNewOp( - op, TypeRange{}, callee, - /*args=*/argsAttr, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{eventVal}); - return success(); - } - -private: - std::string callee; -}; - -struct PTOGetBufToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::GetBufOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { 
- (void)adaptor; - auto *ctx = rewriter.getContext(); - - auto opTypeOr = parseSyncOpTypeLikeAttr(op.getOpTypeAttr()); - if (failed(opTypeOr)) - return rewriter.notifyMatchFailure(op, "get_buf expects pipe_event_type/sync_op_type attr"); - auto pipe = mapSyncOpTypeToPipe(*opTypeOr); - if (!isConcreteSyncPipe(pipe)) - return rewriter.notifyMatchFailure(op, "get_buf op_type cannot map to a concrete pipe"); - std::string pipeTok = pipeTokFromPipeEnum(pipe); - auto argsAttr = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, pipeTok), - op.getBufIdAttr(), - op.getModeAttr(), - }); - - rewriter.replaceOpWithNewOp( - op, TypeRange{}, "get_buf", - /*args=*/argsAttr, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{}); - return success(); - } -}; - -struct PTORlsBufToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::RlsBufOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - (void)adaptor; - auto *ctx = rewriter.getContext(); - - auto opTypeOr = parseSyncOpTypeLikeAttr(op.getOpTypeAttr()); - if (failed(opTypeOr)) - return rewriter.notifyMatchFailure(op, "rls_buf expects pipe_event_type/sync_op_type attr"); - auto pipe = mapSyncOpTypeToPipe(*opTypeOr); - if (!isConcreteSyncPipe(pipe)) - return rewriter.notifyMatchFailure(op, "rls_buf op_type cannot map to a concrete pipe"); - std::string pipeTok = pipeTokFromPipeEnum(pipe); - auto argsAttr = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, pipeTok), - op.getBufIdAttr(), - op.getModeAttr(), - }); - - rewriter.replaceOpWithNewOp( - op, TypeRange{}, "rls_buf", - /*args=*/argsAttr, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{}); - return success(); - } -}; - -struct PTOSetFFTsToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::SetFFTsOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const 
override { - auto *ctx = rewriter.getContext(); - auto loc = op.getLoc(); - - Value fftsAddr = peelUnrealized(adaptor.getFfts()); - auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); - - if (isSetFFTsPointerLikeType(fftsAddr.getType())) { - auto castTyAttr = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); - fftsAddr = - rewriter - .create(loc, u64Ty, "reinterpret_cast", - /*args=*/ArrayAttr{}, - /*templateArgs=*/castTyAttr, - /*operands=*/ValueRange{fftsAddr}) - .getResult(0); - } else if (fftsAddr.getType() != u64Ty) { - fftsAddr = - rewriter.create(loc, u64Ty, fftsAddr).getResult(); - } - - rewriter.replaceOpWithNewOp( - op, TypeRange{}, "set_ffts_base_addr", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{fftsAddr}); - return success(); - } -}; - -struct PTOSyncSetToEmitC : public OpConversionPattern { - PTOSyncSetToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - PTOArch targetArch) - : OpConversionPattern(typeConverter, ctx), - targetArch(targetArch) {} - - LogicalResult - matchAndRewrite(mlir::pto::SyncSetOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op->getLoc(); - auto *ctx = rewriter.getContext(); - IntegerAttr eventIdAttr = op.getEventIdAttr(); - Value eventIdDyn = adaptor.getEventIdDyn(); - int64_t fftsMode = 2; - if (IntegerAttr fftsModeAttr = op.getFftsModeAttr()) - fftsMode = fftsModeAttr.getInt(); - - if ((eventIdAttr != nullptr) == static_cast(eventIdDyn)) - return rewriter.notifyMatchFailure( - op, "expects exactly one of static event_id attr or dynamic event_id operand"); - - // A5 inter-core sync mirrors +16 only for cube-side producer (PIPE_FIX). - // Vec-side producer (PIPE_MTE3) emits a single set; hardware handles the - // subblock mapping in PTO-ISA custom flow. 
- if (targetArch == PTOArch::A5) { - pto::PIPE pipe = op.getPipe().getPipe(); - bool needsMirrorPlus16 = (pipe == pto::PIPE::PIPE_FIX); - std::string pipeTok = pipeTokFromPipeAttr(op.getPipe()); - auto emitSet = [&](Value eventOperand, IntegerAttr eventLiteral, - bool isDynamic) { - if (isDynamic) { - auto args = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, pipeTok), - IntegerAttr::get(IndexType::get(ctx), 0), - }); - rewriter.create(loc, TypeRange{}, "set_intra_block", - /*args=*/args, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{eventOperand}); - return; - } - auto args = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, pipeTok), - eventLiteral, - }); - rewriter.create(loc, TypeRange{}, "set_intra_block", - /*args=*/args, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{}); - }; - - if (eventIdAttr) { - emitSet(Value{}, eventIdAttr, /*isDynamic=*/false); - if (needsMirrorPlus16) { - auto plus16 = IntegerAttr::get(eventIdAttr.getType(), - eventIdAttr.getInt() + 16); - emitSet(Value{}, plus16, /*isDynamic=*/false); - } - } else { - Value eventI32 = castInterCoreEventIdToI32(rewriter, loc, eventIdDyn); - emitSet(eventI32, IntegerAttr{}, /*isDynamic=*/true); - if (needsMirrorPlus16) { - auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); - Value c16 = makeEmitCIntConstant(rewriter, loc, i32Ty, 16); - Value eventI32Plus16 = - rewriter.create(loc, i32Ty, eventI32, c16).getResult(); - emitSet(eventI32Plus16, IntegerAttr{}, /*isDynamic=*/true); - } - } - - rewriter.eraseOp(op); - return success(); - } - - InterCoreSyncCallDesc desc; - if (eventIdAttr) { - desc = buildInterCoreSyncSetCall(rewriter, loc, targetArch, op.getPipe(), - eventIdAttr, fftsMode); - } else { - desc = buildInterCoreSyncSetCallDyn(rewriter, loc, targetArch, op.getPipe(), - eventIdDyn, fftsMode); - } - rewriter.create(loc, TypeRange{}, desc.callee, - /*args=*/desc.args, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/desc.operands); - - rewriter.eraseOp(op); - 
return success(); - } - - PTOArch targetArch; -}; - -struct PTOSyncWaitToEmitC : public OpConversionPattern { - PTOSyncWaitToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - PTOArch targetArch) - : OpConversionPattern(typeConverter, ctx), - targetArch(targetArch) {} - - LogicalResult - matchAndRewrite(mlir::pto::SyncWaitOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op->getLoc(); - IntegerAttr eventIdAttr = op.getEventIdAttr(); - Value eventIdDyn = adaptor.getEventIdDyn(); - - if ((eventIdAttr != nullptr) == static_cast(eventIdDyn)) - return rewriter.notifyMatchFailure( - op, "expects exactly one of static event_id attr or dynamic event_id operand"); - - InterCoreSyncCallDesc desc; - if (eventIdAttr) { - desc = buildInterCoreSyncWaitCall(rewriter, targetArch, op.getPipe(), - eventIdAttr); - } else { - desc = buildInterCoreSyncWaitCallDyn(rewriter, loc, targetArch, op.getPipe(), - eventIdDyn); - } - rewriter.create(loc, TypeRange{}, desc.callee, - desc.args, ArrayAttr{}, desc.operands); - - rewriter.eraseOp(op); - return success(); - } - - PTOArch targetArch; -}; - -// GetBlockIdxOp Lowering (pto.get_block_idx -> get_block_idx()) -struct PTOGetBlockIdxToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::pto::GetBlockIdxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - - rewriter.replaceOpWithNewOp( - op, op.getType(), "get_block_idx", ValueRange{}, ArrayAttr{}, - ArrayAttr{}); - - return success(); - } -}; - -// GetBlockNumOp Lowering (pto.get_block_num -> get_block_num()) -struct PTOGetBlockNumToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::pto::GetBlockNumOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - - rewriter.replaceOpWithNewOp( - op, op.getType(), "get_block_num", 
ValueRange{}, ArrayAttr{}, - ArrayAttr{}); - - return success(); - } -}; - -// GetSubBlockIdxOp Lowering (pto.get_block_idx -> get_subblockid()) -struct PTOGetSubBlockIdxToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::pto::GetSubBlockIdxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - - rewriter.replaceOpWithNewOp( - op, op.getType(), "get_subblockid", ValueRange{}, ArrayAttr{}, - ArrayAttr{}); - - return success(); - } -}; - -// GetSubBlockNumOp Lowering. -struct PTOGetSubBlockNumToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::pto::GetSubBlockNumOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - - rewriter.replaceOpWithNewOp( - op, op.getType(), "get_subblockdim", ValueRange{}, ArrayAttr{}, - ArrayAttr{}); - - return success(); - } -}; - - -struct PTOMScatterToMSCATTER : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::MScatterOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto *ctx = rewriter.getContext(); - Value src = peelUnrealized(adaptor.getSrc()); - Value idx = peelUnrealized(adaptor.getIdx()); - Value mem = peelUnrealized(adaptor.getMem()); - - Value memArg = maybeWrapGlobalMemrefAsGlobalTensor( - rewriter, op.getLoc(), mem, op.getMem().getType(), op.getOperation()); - - auto scatterAtomicTok = [&](pto::ScatterAtomicOp atomic) -> StringRef { - switch (atomic) { - case pto::ScatterAtomicOp::None: - return "pto::ScatterAtomicOp::None"; - case pto::ScatterAtomicOp::Add: - return "pto::ScatterAtomicOp::Add"; - case pto::ScatterAtomicOp::Max: - return "pto::ScatterAtomicOp::Max"; - case pto::ScatterAtomicOp::Min: - return "pto::ScatterAtomicOp::Min"; - } - llvm_unreachable("unknown ScatterAtomicOp"); - }; - auto 
scatterOobTok = [&](pto::ScatterOOB mode) -> StringRef { - switch (mode) { - case pto::ScatterOOB::Undefined: - return "pto::ScatterOOB::Undefined"; - case pto::ScatterOOB::Skip: - return "pto::ScatterOOB::Skip"; - case pto::ScatterOOB::Clamp: - return "pto::ScatterOOB::Clamp"; - case pto::ScatterOOB::Wrap: - return "pto::ScatterOOB::Wrap"; - } - llvm_unreachable("unknown ScatterOOB"); - }; - - SmallVector templateArgVec; - const bool rowCoalesce = - isRowCoalescedMGatherIndexType(op.getSrc().getType(), op.getIdx().getType()); - templateArgVec.push_back(emitc::OpaqueAttr::get( - ctx, rowCoalesce ? "pto::Coalesce::Row" : "pto::Coalesce::Elem")); - if (op.getScatterAtomicOp() != pto::ScatterAtomicOp::None || - op.getScatterOob() != pto::ScatterOOB::Undefined) { - templateArgVec.push_back(emitc::OpaqueAttr::get( - ctx, scatterAtomicTok(op.getScatterAtomicOp()))); - if (op.getScatterOob() != pto::ScatterOOB::Undefined) - templateArgVec.push_back( - emitc::OpaqueAttr::get(ctx, scatterOobTok(op.getScatterOob()))); - } - ArrayAttr templateArgs = rewriter.getArrayAttr(templateArgVec); - - rewriter.create( - op.getLoc(), TypeRange{}, "MSCATTER", - ArrayAttr{}, templateArgs, - ValueRange{memArg, src, idx}); - - rewriter.eraseOp(op); - return success(); - } -}; -struct PTOSetValToSETVAL : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TSetValOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value dst = peelUnrealized(adaptor.getDst()); - Value val = peelUnrealized(adaptor.getVal()); - - // ---- offset: SSA index operand ---- - Value offset = peelUnrealized(adaptor.getOffset()); - - // Emit a marker call and let the ptoas post-processing step lower it to - // the corresponding tile setter. 
- rewriter.create( - op.getLoc(), TypeRange{}, "PTOAS__TILE_SET_VALUE", - ArrayAttr{}, ArrayAttr{}, ValueRange{dst, offset, val}); - - rewriter.eraseOp(op); - return success(); - } -}; -struct PTOGetValToGETVAL : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGetValOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src = peelUnrealized(adaptor.getSrc()); - - // ---- offset: SSA index operand ---- - Value offset = peelUnrealized(adaptor.getOffset()); - - // Emit a marker call and let the ptoas post-processing step lower it to - // the corresponding tile getter. - Type dstTy = getTypeConverter()->convertType(op.getDst().getType()); - if (!dstTy) - return failure(); - auto call = rewriter.create( - op.getLoc(), - TypeRange{dstTy}, - "PTOAS__TILE_GET_VALUE", - ArrayAttr{}, ArrayAttr{}, - ValueRange{src, offset}); - - rewriter.replaceOp(op, call.getResults()); - return success(); - } -}; - -struct PTOTAxpyToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAxpyOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value scalar = peelUnrealized(adaptor.getScalar()); - - rewriter.create( - loc, TypeRange{}, "TAXPY", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, scalar}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOHistogramToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::THistogramOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value idx = 
peelUnrealized(adaptor.getIdx()); - Value dst = peelUnrealized(adaptor.getDst()); - - auto templateArgs = rewriter.getArrayAttr({emitc::OpaqueAttr::get( - ctx, op.getIsMSB() ? "HistByte::BYTE_1" : "HistByte::BYTE_0")}); - rewriter.create( - loc, TypeRange{}, "THISTOGRAM", - /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, - /*operands=*/ValueRange{dst, src, idx}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOGetScaleAddrToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGetScaleAddrOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TGET_SCALE_ADDR", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOSetValidShapeToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::SetValidShapeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto peelAllCasts = [](Value v) { - while (auto castOp = v.getDefiningOp()) - v = castOp.getOperand(0); - if (auto castOp = v.getDefiningOp()) - v = castOp.getOperand(); - return v; - }; - auto isTileLike = [](Value v) -> bool { - auto ot = dyn_cast(v.getType()); - if (!ot) - return false; - StringRef s = ot.getValue(); - return s.contains("Tile<") || s.contains("ConvTile<"); - }; - - Value src = peelAllCasts(peelUnrealized(adaptor.getSource())); - Value row = peelUnrealized(adaptor.getValidRow()); - Value col = peelUnrealized(adaptor.getValidCol()); - - if (!isTileLike(src)) - return rewriter.notifyMatchFailure( - op, "set_validshape source must lower to a tile-like value"); - - rewriter.create( - op.getLoc(), TypeRange{}, 
"PTOAS__TILE_SET_VALIDSHAPE", ArrayAttr{}, - ArrayAttr{}, ValueRange{src, row, col}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOGetValidShapeToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::GetValidShapeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto peelAllCasts = [](Value v) { - while (auto castOp = v.getDefiningOp()) - v = castOp.getOperand(0); - if (auto castOp = v.getDefiningOp()) - v = castOp.getOperand(); - return v; - }; - auto isTileLike = [](Value v) -> bool { - auto ot = dyn_cast(v.getType()); - if (!ot) - return false; - StringRef s = ot.getValue(); - return s.contains("Tile<") || s.contains("ConvTile<"); - }; - - Value src = peelAllCasts(peelUnrealized(adaptor.getSource())); - if (!isTileLike(src)) - return rewriter.notifyMatchFailure( - op, "get_validshape source must lower to a tile-like value"); - - auto resultTy = getTypeConverter()->convertType(rewriter.getIndexType()); - if (!resultTy) - return failure(); - - Value row = rewriter - .create( - op.getLoc(), resultTy, - "PTOAS__TILE_GET_VALID_ROW", ArrayAttr{}, - ArrayAttr{}, ValueRange{src}) - .getResult(0); - Value col = rewriter - .create( - op.getLoc(), resultTy, - "PTOAS__TILE_GET_VALID_COL", ArrayAttr{}, - ArrayAttr{}, ValueRange{src}) - .getResult(0); - rewriter.replaceOp(op, ValueRange{row, col}); - return success(); - } -}; - -struct PTOTAssignToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAssignOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto peelAllCasts = [](Value v) { - while (auto castOp = v.getDefiningOp()) - v = castOp.getOperand(0); - if (auto castOp = v.getDefiningOp()) - v = castOp.getOperand(); - return v; - }; - auto isTileLike = [](Value v) -> bool { - auto ot = dyn_cast(v.getType()); - if (!ot) - return false; - 
StringRef s = ot.getValue(); - return s.contains("Tile<") || s.contains("ConvTile<"); - }; - - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value tile = peelAllCasts(peelUnrealized(adaptor.getTile())); - if (!isTileLike(tile)) - return rewriter.notifyMatchFailure( - op, "tassign tile must lower to a tile-like value"); - - Value addr = peelUnrealized(adaptor.getAddr()); - auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); - if (isa(addr.getType()) || - (isa(addr.getType()) && - cast(addr.getType()).getValue().ends_with("*"))) { - auto rcU64 = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); - addr = rewriter - .create(loc, u64Ty, "reinterpret_cast", - ArrayAttr{}, rcU64, - ValueRange{addr}) - .getResult(0); - } else if (addr.getType() != u64Ty) { - addr = rewriter.create(loc, u64Ty, addr).getResult(); - } - - rewriter.create(loc, TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{tile, addr}); - rewriter.replaceOp(op, tile); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// pto.load_scalar / pto.store_scalar lowering -> ptr[offset] -//===----------------------------------------------------------------------===// - -static Type getPointerLikeElementType(Type type) { - if (auto ptrTy = dyn_cast(type)) - return ptrTy.getElementType(); - if (auto memTy = dyn_cast(type)) - return memTy.getElementType(); - return Type(); -} - -struct PTOPtrToIntToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::PtrToIntOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value ptr = peelUnrealized(adaptor.getPtr()); - Type dstTy = getTypeConverter()->convertType(op.getResult().getType()); - if (!dstTy) - return failure(); - - auto dstOpaque = dyn_cast(dstTy); - if (!dstOpaque) - return failure(); - - auto templateArgs = - 
rewriter.getArrayAttr({emitc::OpaqueAttr::get(rewriter.getContext(), - dstOpaque.getValue())}); - auto cast = rewriter.create( - op.getLoc(), dstTy, "reinterpret_cast", ArrayAttr{}, templateArgs, - ValueRange{ptr}); - rewriter.replaceOp(op, cast.getResult(0)); - return success(); - } -}; - -struct PTOIntToPtrToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::IntToPtrOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value addr = peelUnrealized(adaptor.getAddr()); - Type dstTy = getTypeConverter()->convertType(op.getResult().getType()); - if (!dstTy) - return failure(); - - Type dstElemTy = getPointerLikeElementType(op.getResult().getType()); - if (!dstElemTy) - return failure(); - - std::string castType = - std::string("__gm__ ") + getEmitCScalarTypeToken(dstElemTy) + "*"; - auto templateArgs = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(rewriter.getContext(), - castType)}); - auto cast = rewriter.create( - op.getLoc(), dstTy, "reinterpret_cast", ArrayAttr{}, templateArgs, - ValueRange{addr}); - rewriter.replaceOp(op, cast.getResult(0)); - return success(); - } -}; - -struct PTOLoadScalarToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::LoadScalarOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value ptr = peelUnrealized(adaptor.getPtr()); - Value offset = peelUnrealized(adaptor.getOffset()); - - Type dstTy = getTypeConverter()->convertType(op.getValue().getType()); - if (!dstTy) - return failure(); - - auto call = rewriter.create( - op.getLoc(), TypeRange{dstTy}, "PTOAS__PTR_LOAD", - ArrayAttr{}, ArrayAttr{}, ValueRange{ptr, offset}); - - rewriter.replaceOp(op, call.getResults()); - return success(); - } -}; - -struct PTOStoreScalarToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult 
matchAndRewrite(pto::StoreScalarOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value ptr = peelUnrealized(adaptor.getPtr()); - Value offset = peelUnrealized(adaptor.getOffset()); - Value val = peelUnrealized(adaptor.getValue()); - - rewriter.create( - op.getLoc(), TypeRange{}, "PTOAS__PTR_STORE", - ArrayAttr{}, ArrayAttr{}, ValueRange{ptr, offset, val}); - rewriter.create( - op.getLoc(), TypeRange{}, "PTOAS__SCALAR_GM_STORE_FLUSH", - ArrayAttr{}, ArrayAttr{}, ValueRange{ptr}); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// pto.tabs lowering -> TABS(dst, src) -//===----------------------------------------------------------------------===// - -struct PTOTAbsToTABS : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAbsOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - // intrinsic: TABS(dst, src) - rewriter.create( - op.getLoc(), TypeRange{}, "TABS", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tadd lowering -> TADD(dst, src0, src1) -//===----------------------------------------------------------------------===// - -struct PTOTAddToTADD : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAddOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TADD", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src0, 
src1}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOInitializeL2G2LPipeToEmitC - : public OpConversionPattern { - PTOInitializeL2G2LPipeToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - PTOArch targetArch) - : OpConversionPattern(typeConverter, ctx), - targetArch(targetArch) {} - - LogicalResult matchAndRewrite(mlir::pto::InitializeL2G2LPipeOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto tpipeTok = buildTPipeTokenFromInitOp(op.getOperation(), targetArch); - if (failed(tpipeTok)) - return rewriter.notifyMatchFailure(op, "failed to build TPipe token"); - - auto *ctx = rewriter.getContext(); - auto emitPipeTy = - cast(getTypeConverter()->convertType(op.getPipe().getType())); - - Value gmAddr = peelUnrealized(adaptor.getGmAddr()); - gmAddr = materializeTensorViewDataPointer( - rewriter, op.getLoc(), gmAddr, op.getGmAddr().getType()); - Value localAddr = - op.getLocalAddr() ? peelUnrealized(adaptor.getLocalAddr()) : Value(); - auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); - Value zero = makeEmitCIntConstant(rewriter, op.getLoc(), i32Ty, 0); - - Value c2vBuf = zero; - Value v2cBuf = zero; - if (op.getDirMask() == 1) - c2vBuf = localAddr ? localAddr : zero; - else if (op.getDirMask() == 2) - v2cBuf = localAddr ? 
localAddr : zero; - else if (op.getDirMask() == 3) { - if (localAddr) { - if (!op.getPeerLocalAddr()) - return rewriter.notifyMatchFailure( - op, "bidirectional l2g2l pipe requires peer local buffer"); - c2vBuf = localAddr; - v2cBuf = peelUnrealized(adaptor.getPeerLocalAddr()); - } - } else - return rewriter.notifyMatchFailure(op, "unsupported dir_mask"); - - rewriter.replaceOpWithNewOp( - op, TypeRange{emitPipeTy}, *tpipeTok, ArrayAttr{}, ArrayAttr{}, - ValueRange{gmAddr, c2vBuf, v2cBuf}); - return success(); - } - - PTOArch targetArch; -}; - -struct PTOInitializeL2LPipeToEmitC - : public OpConversionPattern { - PTOInitializeL2LPipeToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - PTOArch targetArch) - : OpConversionPattern(typeConverter, ctx), - targetArch(targetArch) {} - - LogicalResult matchAndRewrite(mlir::pto::InitializeL2LPipeOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto tpipeTok = buildTPipeTokenFromInitOp(op.getOperation(), targetArch); - if (failed(tpipeTok)) - return rewriter.notifyMatchFailure(op, "failed to build TPipe token"); - - auto *ctx = rewriter.getContext(); - auto emitPipeTy = - cast(getTypeConverter()->convertType(op.getPipe().getType())); - - auto gmPtrTy = - emitc::PointerType::get(emitc::OpaqueType::get(ctx, "__gm__ void")); - Value nullGm = - makeEmitCOpaqueConstant(rewriter, op.getLoc(), gmPtrTy, "nullptr"); - auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); - Value zero = makeEmitCIntConstant(rewriter, op.getLoc(), i32Ty, 0); - Value localAddr = peelUnrealized(adaptor.getLocalAddr()); - - Value c2vBuf = zero; - Value v2cBuf = zero; - if (op.getDirMask() == 1) - c2vBuf = localAddr; - else if (op.getDirMask() == 2) - v2cBuf = localAddr; - else if (op.getDirMask() == 3) { - c2vBuf = localAddr; - v2cBuf = peelUnrealized(adaptor.getPeerLocalAddr()); - } else - return rewriter.notifyMatchFailure(op, "unsupported dir_mask"); - - rewriter.replaceOpWithNewOp( - op, TypeRange{emitPipeTy}, 
*tpipeTok, ArrayAttr{}, ArrayAttr{}, - ValueRange{nullGm, c2vBuf, v2cBuf}); - return success(); - } - - PTOArch targetArch; -}; - -struct PTOBuildAsyncSessionToEmitC - : public OpConversionPattern { - PTOBuildAsyncSessionToEmitC(TypeConverter &typeConverter, MLIRContext *ctx) - : OpConversionPattern(typeConverter, ctx) {} - - LogicalResult matchAndRewrite(mlir::pto::BuildAsyncSessionOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto *ctx = rewriter.getContext(); - Location loc = op.getLoc(); - - auto sessionTy = - dyn_cast(getTypeConverter()->convertType(op.getSession().getType())); - if (!sessionTy) - return rewriter.notifyMatchFailure(op, "failed to convert async session type"); - - FailureOr scratchTile = - buildAsyncScratchTileValue(rewriter, loc, op.getScratch(), - adaptor.getScratch()); - if (failed(scratchTile)) - return rewriter.notifyMatchFailure(op, "failed to materialize async scratch tile"); - - Value workspace = - castToGMBytePointer(rewriter, loc, peelUnrealized(adaptor.getWorkspace())); - - Value session = rewriter - .create( - loc, sessionTy, emitc::OpaqueAttr::get(ctx, "")) - .getResult(); - - auto u32Ty = emitc::OpaqueType::get(ctx, "uint32_t"); - - auto makeU32Const = [&](uint64_t value) -> Value { - return makeEmitCOpaqueConstant(rewriter, loc, u32Ty, - std::to_string(value) + "u"); - }; - uint64_t syncId = op.getSyncIdAttr() ? op.getSyncIdAttr().getInt() : 0; - uint64_t blockBytes = - op.getBlockBytesAttr() ? op.getBlockBytesAttr().getInt() : 32 * 1024; - uint64_t commBlockOffset = - op.getCommBlockOffsetAttr() ? op.getCommBlockOffsetAttr().getInt() : 0; - uint64_t queueNum = op.getQueueNumAttr() ? op.getQueueNumAttr().getInt() : 1; - uint64_t channelGroupIdx = op.getChannelGroupIdxAttr() - ? op.getChannelGroupIdxAttr().getInt() - : UINT32_MAX; - - Value syncIdVal = makeU32Const(syncId); - Value channelGroupIdxVal = - channelGroupIdx == UINT32_MAX - ? 
makeEmitCOpaqueConstant(rewriter, loc, u32Ty, "UINT32_MAX") - : makeU32Const(channelGroupIdx); - - auto baseConfigTy = - emitc::OpaqueType::get(ctx, "pto::comm::sdma::SdmaBaseConfig"); - Value baseConfig = - rewriter - .create( - loc, baseConfigTy, - emitc::OpaqueAttr::get( - ctx, "{" + std::to_string(blockBytes) + "ULL, " + - std::to_string(commBlockOffset) + "ULL, " + - std::to_string(queueNum) + "u}")) - .getResult(); - - rewriter.create( - loc, TypeRange{}, "pto::comm::BuildAsyncSession", - ArrayAttr{}, ArrayAttr{}, - ValueRange{*scratchTile, workspace, session, syncIdVal, baseConfig, - channelGroupIdxVal}); - - rewriter.replaceOp(op, session); - return success(); - } -}; - -template -struct PTOAsyncTransferToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - explicit PTOAsyncTransferToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - StringRef callee) - : OpConversionPattern(typeConverter, ctx), callee(callee.str()) {} - - LogicalResult matchAndRewrite(AsyncOp op, typename AsyncOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value dst = peelUnrealized(adaptor.getDst()); - Value src = peelUnrealized(adaptor.getSrc()); - Value dstGT = dst; - Value srcGT = src; - if (!isEmitCGlobalTensorLikeType(dstGT.getType())) { - auto dstMrTy = dyn_cast(op.getDst().getType()); - if (!dstMrTy) - return rewriter.notifyMatchFailure(op, "expected dst to lower to GlobalTensor or memref"); - dstGT = buildGlobalTensorFromMemref(rewriter, op.getLoc(), dst, dstMrTy, - op.getDst().getDefiningOp() - ? op.getDst().getDefiningOp() - : op.getOperation()); - } - if (!isEmitCGlobalTensorLikeType(srcGT.getType())) { - auto srcMrTy = dyn_cast(op.getSrc().getType()); - if (!srcMrTy) - return rewriter.notifyMatchFailure(op, "expected src to lower to GlobalTensor or memref"); - srcGT = buildGlobalTensorFromMemref(rewriter, op.getLoc(), src, srcMrTy, - op.getSrc().getDefiningOp() - ? 
op.getSrc().getDefiningOp() - : op.getOperation()); - } - if (!dstGT || !srcGT) - return rewriter.notifyMatchFailure(op, "failed to build GlobalTensor operands"); - - Type eventTy = this->getTypeConverter()->convertType(op.getEvent().getType()); - if (!eventTy) - return rewriter.notifyMatchFailure(op, "failed to convert async event type"); - - rewriter.replaceOpWithNewOp( - op, TypeRange{eventTy}, callee, ArrayAttr{}, ArrayAttr{}, - ValueRange{dstGT, srcGT, peelUnrealized(adaptor.getSession())}); - return success(); - } - - std::string callee; -}; - -template -struct PTOAsyncEventToEmitC : public OpConversionPattern { - explicit PTOAsyncEventToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - StringRef callee) - : OpConversionPattern(typeConverter, ctx), - callee(callee.str()) {} - - LogicalResult matchAndRewrite(AsyncEventOp op, - typename AsyncEventOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type resultTy = - this->getTypeConverter()->convertType(op.getCompleted().getType()); - if (!resultTy) - return rewriter.notifyMatchFailure(op, "failed to convert async event result type"); - - rewriter.replaceOpWithNewOp( - op, TypeRange{resultTy}, callee, ArrayAttr{}, ArrayAttr{}, - ValueRange{peelUnrealized(adaptor.getEvent()), - peelUnrealized(adaptor.getSession())}); - return success(); - } - - std::string callee; -}; - -static FailureOr buildCommGlobalTensorValue( - ConversionPatternRewriter &rewriter, Location loc, Value originalValue, - Value emittedValue, Operation *anchor) { - Value value = peelUnrealized(emittedValue); - if (isEmitCGlobalTensorLikeType(value.getType())) - return value; - - auto memTy = dyn_cast(originalValue.getType()); - if (!memTy) - return failure(); - - Value gt = buildGlobalTensorFromMemref(rewriter, loc, value, memTy, anchor); - if (!gt) - return failure(); - return gt; -} - -static FailureOr buildCommTileValue(ConversionPatternRewriter &rewriter, - Location loc, Value originalValue, - Value 
emittedValue) { - Value value = peelUnrealized(emittedValue); - if (auto opaqueTy = dyn_cast(value.getType())) { - StringRef typeStr = opaqueTy.getValue(); - if (typeStr.contains("Tile<") || typeStr.contains("ConvTile<")) - return value; - } - return buildAsyncScratchTileValue(rewriter, loc, originalValue, emittedValue); -} - -static FailureOr buildCollectiveParallelGroup( - ConversionPatternRewriter &rewriter, Location loc, - ArrayRef groupGTs, int64_t root) { - if (groupGTs.empty()) - return failure(); - - auto firstTy = dyn_cast(groupGTs.front().getType()); - if (!firstTy) - return failure(); - - auto *ctx = rewriter.getContext(); - auto arrayTy = emitc::ArrayType::get({static_cast(groupGTs.size())}, - firstTy); - auto groupArray = cast>( - rewriter - .create(loc, arrayTy, - emitc::OpaqueAttr::get(ctx, "{}")) - .getResult()); - - auto indexTy = emitc::OpaqueType::get(ctx, "int"); - for (auto [idx, groupVal] : llvm::enumerate(groupGTs)) { - Value idxVal = - makeEmitCIntConstant(rewriter, loc, indexTy, static_cast(idx)); - Value slot = - rewriter.create(loc, groupArray, ValueRange{idxVal}) - .getResult(); - rewriter.create(loc, slot, groupVal); - } - - std::string pgTypeStr = - (Twine("pto::comm::ParallelGroup<") + firstTy.getValue() + ">").str(); - auto pgTy = emitc::OpaqueType::get(ctx, pgTypeStr); - Value sizeVal = makeEmitCIntConstant(rewriter, loc, indexTy, - static_cast(groupGTs.size())); - Value rootVal = makeEmitCIntConstant(rewriter, loc, indexTy, root); - return rewriter - .create( - loc, TypeRange{pgTy}, (Twine(pgTypeStr) + "::Create").str(), - ArrayAttr{}, ArrayAttr{}, ValueRange{groupArray, sizeVal, rootVal}) - .getResult(0); -} - -static std::string notifyOpTok(pto::NotifyOp op) { - switch (op) { - case pto::NotifyOp::AtomicAdd: - return "pto::comm::NotifyOp::AtomicAdd"; - case pto::NotifyOp::Set: - return "pto::comm::NotifyOp::Set"; - } - return "pto::comm::NotifyOp::Set"; -} - -static std::string waitCmpTok(pto::WaitCmp cmp) { - switch (cmp) { - 
case pto::WaitCmp::EQ: - return "pto::comm::WaitCmp::EQ"; - case pto::WaitCmp::NE: - return "pto::comm::WaitCmp::NE"; - case pto::WaitCmp::GT: - return "pto::comm::WaitCmp::GT"; - case pto::WaitCmp::GE: - return "pto::comm::WaitCmp::GE"; - case pto::WaitCmp::LT: - return "pto::comm::WaitCmp::LT"; - case pto::WaitCmp::LE: - return "pto::comm::WaitCmp::LE"; - } - return "pto::comm::WaitCmp::EQ"; -} - -static std::string reduceOpTok(pto::ReduceOp op) { - switch (op) { - case pto::ReduceOp::Sum: - return "pto::comm::ReduceOp::Sum"; - case pto::ReduceOp::Max: - return "pto::comm::ReduceOp::Max"; - case pto::ReduceOp::Min: - return "pto::comm::ReduceOp::Min"; - } - return "pto::comm::ReduceOp::Sum"; -} - -template -static FailureOr> buildCommGroupGlobalTensors( - ConversionPatternRewriter &rewriter, Location loc, OpTy op, - ValueRange originalGroup, ValueRange emittedGroup) { - SmallVector groupGTs; - groupGTs.reserve(originalGroup.size()); - for (auto [orig, emitted] : llvm::zip(originalGroup, emittedGroup)) { - FailureOr gt = - buildCommGlobalTensorValue(rewriter, loc, orig, emitted, op.getOperation()); - if (failed(gt)) - return failure(); - groupGTs.push_back(*gt); - } - return groupGTs; -} - -template -struct PTOCommCollectiveToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - explicit PTOCommCollectiveToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - StringRef apiName) - : OpConversionPattern(typeConverter, ctx), - apiName(apiName.str()) {} - - LogicalResult matchAndRewrite(CollectiveOp op, typename CollectiveOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - - auto buildPong = [&](Value original, Value emitted, StringRef name) -> FailureOr { - if (!original) - return failure(); - return buildCommTileValue(rewriter, loc, original, emitted); - }; - - if constexpr (std::is_same_v) { - FailureOr srcGT = - buildCommGlobalTensorValue(rewriter, loc, op.getSrc(), 
adaptor.getSrc(), - op.getOperation()); - FailureOr pingTile = - buildCommTileValue(rewriter, loc, op.getPing(), adaptor.getPing()); - auto groupGTs = - buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); - if (failed(srcGT) || failed(pingTile) || failed(groupGTs)) - return rewriter.notifyMatchFailure(op, "failed to materialize broadcast operands"); - FailureOr pg = buildCollectiveParallelGroup(rewriter, loc, *groupGTs, op.getRoot()); - if (failed(pg)) - return rewriter.notifyMatchFailure(op, "failed to materialize broadcast group"); - if (op.getPong()) { - FailureOr pongTile = - buildPong(op.getPong(), adaptor.getPong(), "__pong"); - if (failed(pongTile)) - return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); - rewriter.create( - loc, TypeRange{}, "pto::comm::TBROADCAST", ArrayAttr{}, ArrayAttr{}, - ValueRange{*pg, *srcGT, *pingTile, *pongTile}); - } else { - rewriter.create( - loc, TypeRange{}, "pto::comm::TBROADCAST", ArrayAttr{}, ArrayAttr{}, - ValueRange{*pg, *srcGT, *pingTile}); - } - } else if constexpr (std::is_same_v) { - FailureOr dstGT = - buildCommGlobalTensorValue(rewriter, loc, op.getDst(), adaptor.getDst(), - op.getOperation()); - FailureOr pingTile = - buildCommTileValue(rewriter, loc, op.getPing(), adaptor.getPing()); - auto groupGTs = - buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); - if (failed(dstGT) || failed(pingTile) || failed(groupGTs)) - return rewriter.notifyMatchFailure(op, "failed to materialize gather operands"); - FailureOr pg = buildCollectiveParallelGroup(rewriter, loc, *groupGTs, op.getRoot()); - if (failed(pg)) - return rewriter.notifyMatchFailure(op, "failed to materialize gather group"); - if (op.getPong()) { - FailureOr pongTile = - buildPong(op.getPong(), adaptor.getPong(), "__pong"); - if (failed(pongTile)) - return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); - rewriter.create( - loc, TypeRange{}, 
"pto::comm::TGATHER", ArrayAttr{}, ArrayAttr{}, - ValueRange{*pg, *dstGT, *pingTile, *pongTile}); - } else { - rewriter.create( - loc, TypeRange{}, "pto::comm::TGATHER", ArrayAttr{}, ArrayAttr{}, - ValueRange{*pg, *dstGT, *pingTile}); - } - } else if constexpr (std::is_same_v) { - FailureOr srcGT = - buildCommGlobalTensorValue(rewriter, loc, op.getSrc(), adaptor.getSrc(), - op.getOperation()); - FailureOr pingTile = - buildCommTileValue(rewriter, loc, op.getPing(), adaptor.getPing()); - auto groupGTs = - buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); - if (failed(srcGT) || failed(pingTile) || failed(groupGTs)) - return rewriter.notifyMatchFailure(op, "failed to materialize scatter operands"); - FailureOr pg = buildCollectiveParallelGroup(rewriter, loc, *groupGTs, op.getRoot()); - if (failed(pg)) - return rewriter.notifyMatchFailure(op, "failed to materialize scatter group"); - if (op.getPong()) { - FailureOr pongTile = - buildPong(op.getPong(), adaptor.getPong(), "__pong"); - if (failed(pongTile)) - return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); - rewriter.create( - loc, TypeRange{}, "pto::comm::TSCATTER", ArrayAttr{}, ArrayAttr{}, - ValueRange{*pg, *srcGT, *pingTile, *pongTile}); - } else { - rewriter.create( - loc, TypeRange{}, "pto::comm::TSCATTER", ArrayAttr{}, ArrayAttr{}, - ValueRange{*pg, *srcGT, *pingTile}); - } - } else { - FailureOr dstGT = - buildCommGlobalTensorValue(rewriter, loc, op.getDst(), adaptor.getDst(), - op.getOperation()); - FailureOr accTile = - buildCommTileValue(rewriter, loc, op.getAcc(), adaptor.getAcc()); - FailureOr recvPing = - buildCommTileValue(rewriter, loc, op.getRecvPing(), adaptor.getRecvPing()); - auto groupGTs = - buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); - if (failed(dstGT) || failed(accTile) || failed(recvPing) || failed(groupGTs)) - return rewriter.notifyMatchFailure(op, "failed to materialize reduce operands"); - 
FailureOr pg = buildCollectiveParallelGroup(rewriter, loc, *groupGTs, op.getRoot()); - if (failed(pg)) - return rewriter.notifyMatchFailure(op, "failed to materialize reduce group"); - if (op.getRecvPong()) { - FailureOr recvPong = - buildPong(op.getRecvPong(), adaptor.getRecvPong(), "__recv_pong"); - if (failed(recvPong)) - return rewriter.notifyMatchFailure(op, "failed to materialize recv_pong"); - auto reduceTy = - emitc::OpaqueType::get(rewriter.getContext(), "pto::comm::ReduceOp"); - Value reduceOp = makeEmitCOpaqueConstant(rewriter, loc, reduceTy, - reduceOpTok(op.getReduceOp())); - rewriter.create( - loc, TypeRange{}, "pto::comm::TREDUCE", ArrayAttr{}, ArrayAttr{}, - ValueRange{*pg, *dstGT, *accTile, *recvPing, *recvPong, reduceOp}); - } else { - auto reduceTy = - emitc::OpaqueType::get(rewriter.getContext(), "pto::comm::ReduceOp"); - Value reduceOp = makeEmitCOpaqueConstant(rewriter, loc, reduceTy, - reduceOpTok(op.getReduceOp())); - rewriter.create( - loc, TypeRange{}, "pto::comm::TREDUCE", ArrayAttr{}, ArrayAttr{}, - ValueRange{*pg, *dstGT, *accTile, *recvPing, reduceOp}); - } - } - rewriter.eraseOp(op); - return success(); - } - - std::string apiName; -}; - -template -struct PTOP2PCommToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - explicit PTOP2PCommToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - StringRef callee) - : OpConversionPattern(typeConverter, ctx), callee(callee.str()) {} - - LogicalResult matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - FailureOr dstGT = - buildCommGlobalTensorValue(rewriter, op.getLoc(), op.getDst(), adaptor.getDst(), - op.getOperation()); - FailureOr srcGT = - buildCommGlobalTensorValue(rewriter, op.getLoc(), op.getSrc(), adaptor.getSrc(), - op.getOperation()); - FailureOr pingTile = - buildCommTileValue(rewriter, op.getLoc(), op.getPing(), adaptor.getPing()); - if (failed(dstGT) || failed(srcGT) || 
failed(pingTile)) - return rewriter.notifyMatchFailure(op, "failed to materialize p2p operands"); - - SmallVector operands{*dstGT, *srcGT, *pingTile}; - std::string actualCallee = callee; - if constexpr (std::is_same_v) { - if (op.getAtomicType() == pto::AtomicType::AtomicAdd) - actualCallee = "pto::comm::TPUT"; - } - if (op.getPong()) { - FailureOr pongTile = - buildCommTileValue(rewriter, op.getLoc(), op.getPong(), adaptor.getPong()); - if (failed(pongTile)) - return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); - operands.push_back(*pongTile); - } - - rewriter.create(op.getLoc(), TypeRange{}, actualCallee, - ArrayAttr{}, ArrayAttr{}, operands); - rewriter.eraseOp(op); - return success(); - } - - std::string callee; -}; - -template -struct PTOSignalCommToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - explicit PTOSignalCommToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - StringRef callee) - : OpConversionPattern(typeConverter, ctx), - callee(callee.str()) {} - - LogicalResult matchAndRewrite(SignalOp op, typename SignalOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - FailureOr signalGT = buildCommGlobalTensorValue( - rewriter, op.getLoc(), op.getSignal(), adaptor.getSignal(), op.getOperation()); - if (failed(signalGT)) - return rewriter.notifyMatchFailure(op, "failed to materialize signal operand"); - - if constexpr (std::is_same_v) { - auto notifyTy = - emitc::OpaqueType::get(rewriter.getContext(), "pto::comm::NotifyOp"); - Value notifyOp = makeEmitCOpaqueConstant( - rewriter, op.getLoc(), notifyTy, notifyOpTok(op.getNotifyOp())); - SmallVector operands{*signalGT, peelUnrealized(adaptor.getValue()), - notifyOp}; - rewriter.create(op.getLoc(), TypeRange{}, callee, - ArrayAttr{}, ArrayAttr{}, operands); - rewriter.eraseOp(op); - } else { - auto waitCmpTy = - emitc::OpaqueType::get(rewriter.getContext(), "pto::comm::WaitCmp"); - Value waitCmp = 
makeEmitCOpaqueConstant( - rewriter, op.getLoc(), waitCmpTy, waitCmpTok(op.getCmp())); - SmallVector operands{*signalGT, peelUnrealized(adaptor.getCmpValue()), - waitCmp}; - if constexpr (std::is_same_v) { - Type resultTy = this->getTypeConverter()->convertType(op.getResult().getType()); - if (!resultTy) - return rewriter.notifyMatchFailure(op, "failed to convert ttest result type"); - rewriter.replaceOpWithNewOp( - op, TypeRange{resultTy}, callee, ArrayAttr{}, ArrayAttr{}, operands); - } else { - rewriter.create(op.getLoc(), TypeRange{}, callee, - ArrayAttr{}, ArrayAttr{}, operands); - rewriter.eraseOp(op); - } - } - return success(); - } - - std::string callee; -}; - -struct PTODeclareTileMemRefToEmitC - : public OpConversionPattern { - using OpConversionPattern< - mlir::pto::DeclareTileMemRefOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::DeclareTileMemRefOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - (void)adaptor; - Type convertedType = getTypeConverter()->convertType(op.getResult().getType()); - if (!convertedType) - return rewriter.notifyMatchFailure( - op, "failed to convert declare_tile_memref result type"); - rewriter.replaceOp(op, makeEmitCOpaqueConstant(rewriter, op.getLoc(), - convertedType, "nullptr")); - return success(); - } -}; - -struct PTODeclareGlobalToEmitC - : public OpConversionPattern { - using OpConversionPattern< - mlir::pto::DeclareGlobalOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::DeclareGlobalOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - (void)adaptor; - Type convertedType = getTypeConverter()->convertType(op.getEntry().getType()); - if (!convertedType) - return rewriter.notifyMatchFailure( - op, "failed to convert declare_global result type"); - if (auto tvTy = dyn_cast(op.getEntry().getType())) { - if (auto stridesAttr = - op->getAttrOfType(kGlobalTensorStridesAttrName)) { - auto strides = 
stridesAttr.asArrayRef(); - if (strides.size() == static_cast(tvTy.getRank())) { - convertedType = emitc::OpaqueType::get( - rewriter.getContext(), - getGlobalTensorTypeStringFromShapeAndStrides( - tvTy.getElementType(), tvTy.getShape(), strides)); - } - } - } - auto var = rewriter.create( - op.getLoc(), convertedType, - emitc::OpaqueAttr::get(rewriter.getContext(), "")); - rewriter.replaceOp(op, var.getResult()); - return success(); - } -}; - -struct PTODeclareEventIdArrayToEmitC - : public OpConversionPattern { - using OpConversionPattern< - mlir::pto::DeclareEventIdArrayOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::DeclareEventIdArrayOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - (void)adaptor; - Type arrayTy = getTypeConverter()->convertType(op.getArray().getType()); - if (!arrayTy) - return rewriter.notifyMatchFailure(op, - "failed to map declared eventid_array type"); - - auto array = rewriter - .create( - op.getLoc(), arrayTy, - emitc::OpaqueAttr::get(rewriter.getContext(), "")) - .getResult(); - rewriter.replaceOp(op, array); - return success(); - } -}; - -struct PTOEventIdArrayGetToEmitC - : public OpConversionPattern { - using OpConversionPattern< - mlir::pto::EventIdArrayGetOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::EventIdArrayGetOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value array = peelUnrealized(adaptor.getArray()); - Value index = peelUnrealized(adaptor.getIndex()); - - Type resultTy = getTypeConverter()->convertType(op.getResult().getType()); - if (!resultTy) - return rewriter.notifyMatchFailure(op, - "failed to map eventid_array get result type"); - - auto load = - rewriter.create(op.getLoc(), resultTy, array, index); - rewriter.replaceOp(op, load.getResult()); - return success(); - } -}; - -struct PTOEventIdArraySetToEmitC - : public OpConversionPattern { - using OpConversionPattern< - 
mlir::pto::EventIdArraySetOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::EventIdArraySetOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value array = peelUnrealized(adaptor.getArray()); - Value index = peelUnrealized(adaptor.getIndex()); - Value value = peelUnrealized(adaptor.getValue()); - - rewriter.create( - op.getLoc(), TypeRange{}, "PTOAS__EVENTID_ARRAY_STORE", - ArrayAttr{}, ArrayAttr{}, ValueRange{array, index, value}); - rewriter.eraseOp(op); - return success(); - } -}; - -// pto.declare_local_array -> emitc.variable of !emitc.array<...>. -// Renders as `T a[D1][D2]...;` in the emitted C++. -struct PTODeclareLocalArrayToEmitC - : public OpConversionPattern { - using OpConversionPattern< - mlir::pto::DeclareLocalArrayOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::DeclareLocalArrayOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - (void)adaptor; - Type arrayTy = getTypeConverter()->convertType(op.getArray().getType()); - if (!arrayTy) - return rewriter.notifyMatchFailure(op, - "failed to map !pto.local_array type"); - - auto var = rewriter - .create( - op.getLoc(), arrayTy, - emitc::OpaqueAttr::get(rewriter.getContext(), "")) - .getResult(); - rewriter.replaceOp(op, var); - return success(); - } -}; - -// pto.local_array_get %a[%i0, %i1, ...] -> rvalue. -// Lowers to a single emitc.subscript with the full index pack; the C++ emitter -// prints it as `a[i0][i1]...`. The adaptor already exposes target-typed values -// (the type converter has remapped !pto.local_array -> !emitc.array and -// index/integer indices), so they're forwarded directly to the builder. 
-struct PTOLocalArrayGetToEmitC - : public OpConversionPattern { - using OpConversionPattern< - mlir::pto::LocalArrayGetOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::LocalArrayGetOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type resultTy = - getTypeConverter()->convertType(op.getResult().getType()); - if (!resultTy) - return rewriter.notifyMatchFailure( - op, "failed to map local_array element type"); - - auto sub = rewriter.create( - op.getLoc(), resultTy, adaptor.getArray(), adaptor.getIndices()); - rewriter.replaceOp(op, sub.getResult()); - return success(); - } -}; - -// pto.local_array_set %a[%i0, %i1, ...], %v -> emitc.assign to subscript slot. -// The C++ emitter prints this as `a[i0][i1]... = v;`. As above, adaptor values -// are already target-typed; pass them through directly. -struct PTOLocalArraySetToEmitC - : public OpConversionPattern { - using OpConversionPattern< - mlir::pto::LocalArraySetOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::LocalArraySetOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value value = adaptor.getValue(); - Type elemTy = value.getType(); - - Value slot = rewriter - .create(op.getLoc(), elemTy, - adaptor.getArray(), - adaptor.getIndices()) - .getResult(); - rewriter.create(op.getLoc(), slot, value); - rewriter.eraseOp(op); - return success(); - } -}; - -static std::optional getStaticIndexLikeValue(Value value) { - if (!value) - return std::nullopt; - if (auto cst = value.getDefiningOp()) - return cst.value(); - if (auto cst = value.getDefiningOp()) - return cst.value(); - if (auto cst = value.getDefiningOp()) { - if (auto intAttr = dyn_cast(cst.getValue())) - return intAttr.getInt(); - } - return std::nullopt; -} - -static FailureOr buildGlobalTensorViewFromPointer( - ConversionPatternRewriter &rewriter, Location loc, Value ptr, Type elemTy, - ArrayRef shape, ArrayRef strides = {}, - StringRef 
layoutEnum = "pto::Layout::ND") { - if (llvm::any_of(shape, [](int64_t dim) { - return dim == ShapedType::kDynamic; - })) - return failure(); - - auto *ctx = rewriter.getContext(); - SmallVector rowMajorStrides; - ArrayRef effectiveStrides = strides; - if (effectiveStrides.empty()) { - rowMajorStrides = buildRowMajorStrides(shape); - effectiveStrides = rowMajorStrides; - } - SmallVector shape5D; - SmallVector stride5D; - buildGlobalTensorShapeAndStride(shape, effectiveStrides, shape5D, stride5D); - - std::string shapeType = "pto::Shape<" + joinIntTemplateParams(shape5D) + ">"; - std::string strideType = - "pto::Stride<" + joinIntTemplateParams(stride5D) + ">"; - auto shapeVal = rewriter - .create( - loc, emitc::OpaqueType::get(ctx, shapeType), - shapeType, ArrayAttr{}, ArrayAttr{}, ValueRange{}) - .getResult(0); - auto strideVal = rewriter - .create( - loc, emitc::OpaqueType::get(ctx, strideType), - strideType, ArrayAttr{}, ArrayAttr{}, ValueRange{}) - .getResult(0); - - std::string gtTypeStr = - getGlobalTensorTypeStringFromShapeAndStrides(elemTy, shape, - effectiveStrides, - layoutEnum); - auto gtType = emitc::OpaqueType::get(ctx, gtTypeStr); - auto gt = rewriter.create( - loc, gtType, gtTypeStr, ArrayAttr{}, ArrayAttr{}, - ValueRange{ptr, shapeVal, strideVal}); - return gt.getResult(0); -} - -static bool parseIntegerTemplateList(StringRef token, StringRef marker, - SmallVectorImpl &values) { - size_t pos = token.find(marker); - if (pos == StringRef::npos) - return false; - pos += marker.size(); - size_t end = token.find('>', pos); - if (end == StringRef::npos) - return false; - - SmallVector parts; - token.slice(pos, end).split(parts, ','); - values.clear(); - for (StringRef part : parts) { - int64_t value = 0; - if (part.trim().getAsInteger(10, value)) - return false; - values.push_back(value); - } - return true; -} - -static LogicalResult getStaticTensorViewStrides( - Value source, Value convertedSource, pto::TensorViewType sourceType, - SmallVectorImpl 
&strides) { - int64_t rank = sourceType.getRank(); - strides.clear(); - - if (auto makeView = source.getDefiningOp()) { - if ((int64_t)makeView.getStrides().size() != rank) - return failure(); - for (Value strideValue : makeView.getStrides()) { - auto cst = getStaticIndexLikeValue(strideValue); - if (!cst) - return failure(); - strides.push_back(*cst); - } - return success(); - } - - Value src = peelUnrealized(convertedSource); - if (auto opaqueTy = dyn_cast(src.getType())) { - SmallVector stride5D; - StringRef token = opaqueTy.getValue(); - if ((parseIntegerTemplateList(token, "pto::Stride<", stride5D) || - parseIntegerTemplateList(token, "Stride<", stride5D)) && - (int64_t)stride5D.size() >= rank) { - strides.append(stride5D.end() - rank, stride5D.end()); - return success(); - } - } - - auto fallback = buildRowMajorStrides(sourceType.getShape()); - strides.append(fallback.begin(), fallback.end()); - return success(); -} - -struct PTOPartitionViewToEmitC - : public OpConversionPattern { - using OpConversionPattern< - mlir::pto::PartitionViewOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::PartitionViewOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto srcTy = dyn_cast(op.getSource().getType()); - auto resTy = dyn_cast(op.getResult().getType()); - if (!srcTy || !resTy) - return rewriter.notifyMatchFailure( - op, "expected tensor_view source and partition_tensor_view result"); - - if (op.getOffsets().size() != static_cast(srcTy.getRank()) || - op.getSizes().size() != static_cast(srcTy.getRank())) - return rewriter.notifyMatchFailure(op, "rank mismatch"); - - for (auto [idx, value] : llvm::enumerate(op.getSizes())) { - auto cst = getStaticIndexLikeValue(value); - if (!cst) - return rewriter.notifyMatchFailure( - op, "globaltensor partition_view requires static sizes"); - int64_t resultDim = resTy.getShape()[idx]; - if (resultDim != ShapedType::kDynamic && resultDim != *cst) - return 
rewriter.notifyMatchFailure( - op, "partition_view static size does not match result type"); - } - - SmallVector srcStrides; - if (failed(getStaticTensorViewStrides(op.getSource(), adaptor.getSource(), - srcTy, srcStrides))) - return rewriter.notifyMatchFailure( - op, "partition_view requires static source strides"); - int64_t staticLinearOffset = 0; - SmallVector> dynamicOffsetTerms; - for (auto [idx, values] : - llvm::enumerate(llvm::zip(op.getOffsets(), adaptor.getOffsets()))) { - Value originalOffset = std::get<0>(values); - Value convertedOffset = std::get<1>(values); - int64_t stride = srcStrides[idx]; - if (stride == ShapedType::kDynamic) - return rewriter.notifyMatchFailure( - op, "dynamic source stride is not supported"); - - if (auto cst = getStaticIndexLikeValue(originalOffset)) { - if (*cst != 0) - staticLinearOffset += (*cst) * stride; - continue; - } - dynamicOffsetTerms.push_back({convertedOffset, stride}); - } - - auto *ctx = rewriter.getContext(); - std::string elemTypeStr = getElemTypeStringForGT(srcTy.getElementType()); - auto ptrTy = emitc::PointerType::get( - emitc::OpaqueType::get(ctx, "__gm__ " + elemTypeStr)); - Value src = peelUnrealized(adaptor.getSource()); - auto data = rewriter - .create( - op.getLoc(), ptrTy, "PTOAS__GLOBAL_TENSOR_DATA", - ArrayAttr{}, ArrayAttr{}, ValueRange{src}) - .getResult(0); - Value ptr = data; - if (!dynamicOffsetTerms.empty()) { - Type u32Ty = emitc::OpaqueType::get(ctx, "unsigned"); - auto makeU32 = [&](int64_t value) { - return makeEmitCIntConstant(rewriter, op.getLoc(), u32Ty, value); - }; - auto asU32 = [&](Value value) -> Value { - if (value.getType() == u32Ty) - return value; - return rewriter.create(op.getLoc(), u32Ty, value) - .getResult(); - }; - - Value totalOffset = makeU32(staticLinearOffset); - for (auto [offsetValue, stride] : dynamicOffsetTerms) { - Value term = asU32(offsetValue); - if (stride != 1) { - Value strideValue = makeU32(stride); - term = rewriter - .create(op.getLoc(), u32Ty, term, - 
strideValue) - .getResult(); - } - totalOffset = rewriter - .create(op.getLoc(), u32Ty, - totalOffset, term) - .getResult(); - } - ptr = rewriter - .create(op.getLoc(), data.getType(), data, - totalOffset) - .getResult(); - } else { - ptr = applyStaticMemrefOffset(rewriter, op.getLoc(), data, - staticLinearOffset); - } - - auto resultOr = buildGlobalTensorViewFromPointer( - rewriter, op.getLoc(), ptr, resTy.getElementType(), resTy.getShape(), - srcStrides); - if (failed(resultOr)) - return rewriter.notifyMatchFailure( - op, "failed to materialize partition GlobalTensor"); - - rewriter.replaceOp(op, *resultOr); - return success(); - } -}; - -static FailureOr getPipeDataTypeToken(Value value) { - auto opaqueTy = dyn_cast(value.getType()); - if (!opaqueTy) - return failure(); - StringRef token = opaqueTy.getValue(); - if (!token.contains("Tile<") && !token.contains("GlobalTensor<")) - return failure(); - return token.str(); -} - -struct PTOTAllocToEmitC : public OpConversionPattern { - PTOTAllocToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - PTOArch targetArch) - : OpConversionPattern(typeConverter, ctx), - targetArch(targetArch) {} - - LogicalResult matchAndRewrite(mlir::pto::TAllocOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto pipeTok = getTPipeTokenFromValue(op.getPipeHandle(), targetArch); - if (failed(pipeTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve pipe token"); - Value entry = peelUnrealized(adaptor.getEntry()); - auto entryTok = getPipeDataTypeToken(entry); - if (failed(entryTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve entry token"); - auto splitTok = getTileSplitToken(op.getSplit()); - if (failed(splitTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve split token"); - - std::string callee = - "TALLOC<" + *pipeTok + ", " + *entryTok + ", " + *splitTok + ">"; - rewriter.replaceOpWithNewOp( - op, TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, - 
ValueRange{peelUnrealized(adaptor.getPipeHandle()), entry}); - return success(); - } - - PTOArch targetArch; -}; - -struct PTOTPushToEmitC : public OpConversionPattern { - PTOTPushToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - PTOArch targetArch) - : OpConversionPattern(typeConverter, ctx), - targetArch(targetArch) {} - - LogicalResult matchAndRewrite(mlir::pto::TPushOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto pipeTok = getTPipeTokenFromValue(op.getPipeHandle(), targetArch); - if (failed(pipeTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve pipe token"); - // Read the tile type token from the already-converted OpaqueType, which - // preserves the exact layout produced by BindTileOp / PointerCastOp EmitC. - Value convertedTile = peelUnrealized(adaptor.getTile()); - auto tileTok = getPipeDataTypeToken(convertedTile); - if (failed(tileTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve tile token"); - auto splitTok = getTileSplitToken(op.getSplit()); - if (failed(splitTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve split token"); - - std::string callee = - "TPUSH<" + *pipeTok + ", " + *tileTok + ", " + *splitTok + ">"; - rewriter.replaceOpWithNewOp( - op, TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, - ValueRange{peelUnrealized(adaptor.getPipeHandle()), convertedTile}); - return success(); - } - - PTOArch targetArch; -}; - -struct PTOTPopToEmitC : public OpConversionPattern { - PTOTPopToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - PTOArch targetArch) - : OpConversionPattern(typeConverter, ctx), - targetArch(targetArch) {} - - LogicalResult matchAndRewrite(mlir::pto::TPopOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto pipeTok = getTPipeTokenFromValue(op.getPipeHandle(), targetArch); - if (failed(pipeTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve pipe token"); - Value convertedTile = 
peelUnrealized(adaptor.getTile()); - auto tileTok = getPipeDataTypeToken(convertedTile); - if (failed(tileTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve tile token"); - auto splitTok = getTileSplitToken(op.getSplit()); - if (failed(splitTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve split token"); - - std::string callee = - "TPOP<" + *pipeTok + ", " + *tileTok + ", " + *splitTok + ">"; - rewriter.replaceOpWithNewOp( - op, TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, - ValueRange{peelUnrealized(adaptor.getPipeHandle()), convertedTile}); - return success(); - } - - PTOArch targetArch; -}; - -struct PTOTFreeToEmitC : public OpConversionPattern { - PTOTFreeToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - PTOArch targetArch) - : OpConversionPattern(typeConverter, ctx), - targetArch(targetArch) {} - - LogicalResult matchAndRewrite(mlir::pto::TFreeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto pipeTok = getTPipeTokenFromValue(op.getPipeHandle(), targetArch); - if (failed(pipeTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve pipe token"); - auto splitTok = getTileSplitToken(op.getSplit()); - if (failed(splitTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve split token"); - - SmallVector operands{peelUnrealized(adaptor.getPipeHandle())}; - std::string callee; - if (op.getEntry()) { - Value entry = peelUnrealized(adaptor.getEntry()); - auto entryTok = getPipeDataTypeToken(entry); - if (failed(entryTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve entry token"); - callee = "TFREE<" + *pipeTok + ", " + *entryTok + ", " + *splitTok + ">"; - operands.push_back(entry); - } else { - callee = "TFREE<" + *pipeTok + ", " + *splitTok + ">"; - } - rewriter.replaceOpWithNewOp( - op, TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, operands); - return success(); - } - - PTOArch targetArch; -}; - 
-//===----------------------------------------------------------------------===// -// populate patterns -//===----------------------------------------------------------------------=== -struct ReinterpretCastToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - auto resMrTy = dyn_cast(op.getType()); - if (!resMrTy) - return failure(); - - auto asAttr = dyn_cast_or_null(resMrTy.getMemorySpace()); - const bool isGm = (!asAttr || asAttr.getAddressSpace() == pto::AddressSpace::GM); - - bool emitAddPtrTrace = op->hasAttr("pto.addptr_trace"); - Value source = peelUnrealized(adaptor.getSource()); - auto offsets = adaptor.getOffsets(); - Value offsetVal = offsets.empty() ? Value() : offsets[0]; - - // GM: keep pointer arithmetic. - if (isGm) { - if (!offsetVal) { - rewriter.replaceOp(op, source); - return success(); - } - - Type resultType = getTypeConverter()->convertType(op.getType()); - if (!resultType) - return failure(); - - auto addOp = rewriter.create(loc, resultType, source, offsetVal); - if (emitAddPtrTrace) { - rewriter.setInsertionPointAfter(addOp); - rewriter.create( - loc, TypeRange{}, "PTOAS__ADDPTR_TRACE", - ArrayAttr{}, ArrayAttr{}, - ValueRange{addOp.getResult(), source, offsetVal}); - } - rewriter.replaceOp(op, addOp.getResult()); - return success(); - } - - // UB/L1/L0 tiles: materialize a new Tile view by assigning an adjusted - // underlying pointer (in elements). - pto::AddressSpace as = asAttr.getAddressSpace(); - - // Element type token. - Type elemTy = resMrTy.getElementType(); - std::string elemTok = getEmitCScalarTypeToken(elemTy); - int64_t elemBytes = getEmitCScalarByteWidth(elemTy); - - // Tile role. 
- const char *roleTok = "TileType::Vec"; - switch (as) { - case pto::AddressSpace::VEC: - roleTok = "TileType::Vec"; - break; - case pto::AddressSpace::MAT: - roleTok = "TileType::Mat"; - break; - case pto::AddressSpace::LEFT: - roleTok = "TileType::Left"; - break; - case pto::AddressSpace::RIGHT: - roleTok = "TileType::Right"; - break; - case pto::AddressSpace::ACC: - roleTok = "TileType::Acc"; - break; - case pto::AddressSpace::BIAS: - roleTok = "TileType::Bias"; - break; - case pto::AddressSpace::GM: - roleTok = "TileType::Vec"; - break; - case pto::AddressSpace::Zero: - roleTok = "TileType::Vec"; - break; - case pto::AddressSpace::SCALING: - roleTok = "TileType::Scaling"; - break; - } - - // Shape (fallback to 32x32). - int64_t rows = 32, cols = 32; - if (resMrTy.getRank() >= 2 && resMrTy.hasStaticShape()) { - rows = resMrTy.getDimSize(0); - cols = resMrTy.getDimSize(1); - } - int64_t templateRows = - renderTileTemplateDim(rows, elemTy, pto::BLayout::RowMajor, 0); - int64_t templateCols = - renderTileTemplateDim(cols, elemTy, pto::BLayout::RowMajor, 1); - - // Keep a conservative default config for now. - std::string tileTypeStr = - std::string("Tile<") + roleTok + ", " + elemTok + ", " + - std::to_string(templateRows) + ", " + std::to_string(templateCols) + - ", BLayout::RowMajor, " + std::to_string(templateRows) + ", " + - std::to_string(templateCols) + - ", SLayout::NoneBox, 512, PadValue::Null, CompactMode::Null>"; - - auto tileType = emitc::OpaqueType::get(ctx, tileTypeStr); - Value tile = rewriter - .create(loc, tileType, - emitc::OpaqueAttr::get(ctx, "")) - .getResult(); - - // Compute an integer address and assign it to the new tile. - // NOTE: pto-isa TASSIGN requires an integral address (not a pointer). - auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); - auto rcU64 = rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); - - // Non-GM reinterpret_cast operands come from UB/L1/L0 tiles. 
- // We need the underlying address, but `__cce_get_tile_ptr()` is only valid - // inside `__tf__` functions. Use `tile.data()` (via a post-processed marker) - // and compute the adjusted address in bytes. - Value rawPtr = source; - if (auto ot = dyn_cast(source.getType())) { - // Only Tiles have a `.data()` member. For plain address-space pointers - // (e.g. `__ubuf__ float*`), use the pointer value directly. - if (ot.getValue().starts_with("Tile<")) { - rawPtr = materializeTileDataValue(rewriter, loc, source, as, elemTok); - } - } - - Value baseAddr = rawPtr; - if (isSetFFTsPointerLikeType(rawPtr.getType())) { - baseAddr = rewriter - .create(loc, u64Ty, "reinterpret_cast", - /*args=*/ArrayAttr{}, - /*templateArgs=*/rcU64, - /*operands=*/ValueRange{rawPtr}) - .getResult(0); - } else if (rawPtr.getType() != u64Ty) { - baseAddr = rewriter.create(loc, u64Ty, rawPtr).getResult(); - } - - Value addr = baseAddr; - if (offsetVal) { - Value offU64 = offsetVal; - if (offU64.getType() != u64Ty) - offU64 = rewriter.create(loc, u64Ty, offU64).getResult(); - - auto bytesAttr = emitc::OpaqueAttr::get(ctx, std::to_string(elemBytes)); - Value bytesVal = rewriter.create(loc, u64Ty, bytesAttr); - Value byteOff = rewriter.create(loc, u64Ty, offU64, bytesVal); - addr = rewriter.create(loc, u64Ty, baseAddr, byteOff); - } - - rewriter.create(loc, TypeRange{}, "TASSIGN", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{tile, addr}); - - rewriter.replaceOp(op, tile); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.taddc lowering -> TADDC(dst, src0, src1, src2) -//===----------------------------------------------------------------------===// - -struct PTOTAddCToTADDC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAddCOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = 
op.getLoc(); - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value src2 = peelUnrealized(adaptor.getSrc2()); - Value dst = peelUnrealized(adaptor.getDst()); - - // pto-isa does not provide NPU implementation for TADDC yet. - // Decompose: dst = src0 + src1 + src2 - rewriter.create( - loc, TypeRange{}, "TADD", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src0, src1}); - rewriter.create( - loc, TypeRange{}, "TADD", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, dst, src2}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tadds lowering -> TADDS(dst, src, scalar) -//===----------------------------------------------------------------------===// - -struct PTOAddSToTADDS : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAddSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value scalar = peelUnrealized(adaptor.getScalar()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TADDS", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src, scalar}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.taddsc lowering -> TADDSC(dst, src0, scalar, src1) -//===----------------------------------------------------------------------===// - -struct PTOAddSCToTADDSC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAddSCOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value 
dst = peelUnrealized(adaptor.getDst()); - - // pto-isa does not provide NPU implementation for TADDSC yet. - // Decompose: dst = src0 + scalar + src1 - rewriter.create( - loc, TypeRange{}, "TADDS", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src0, scalar}); - rewriter.create( - loc, TypeRange{}, "TADD", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, dst, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; -struct PTOTAndToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAndOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value a = peelUnrealized(adaptor.getSrc0()); - Value b = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TAND", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, a, b}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOConcatToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TConcatOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TCONCAT", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src0, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; -struct PTOConcatidxToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TConcatidxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value src0Idx = peelUnrealized(adaptor.getSrc0Idx()); - Value src1Idx = peelUnrealized(adaptor.getSrc1Idx()); - Value dst = 
peelUnrealized(adaptor.getDst()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TCONCAT", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src0, src1, src0Idx, src1Idx}); - - rewriter.eraseOp(op); - return success(); - } -}; -struct PTOAndSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAndSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TANDS", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src, scalar}); - - rewriter.eraseOp(op); - return success(); - } -}; - - -struct PTOTCIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TCIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value S = peelUnrealized(adaptor.getOperands()[0]); - - // The TCI scalar template parameter should follow the original PTO IR - // scalar type, not the converted EmitC value type. - std::string scalarTok = "int32_t"; - if (auto it = dyn_cast(op->getOperand(0).getType())) { - bool isUnsigned = it.isUnsigned(); - if (it.getWidth() == 16) - scalarTok = isUnsigned ? "uint16_t" : "int16_t"; - else - scalarTok = isUnsigned ? "uint32_t" : "int32_t"; - } - - // descending -> "0"/"1" - std::string descTok = op.getDescending() ? 
"1" : "0"; - - ArrayAttr targs; - if (auto ot = mlir::dyn_cast(dst.getType())) { - std::string tileTok = ot.getValue().str(); // "Tile<...>" - targs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, tileTok), - emitc::OpaqueAttr::get(ctx, scalarTok), - emitc::OpaqueAttr::get(ctx, descTok), - }); - } else { - targs = rewriter.getArrayAttr({}); - } - - rewriter.create( - loc, TypeRange{}, "TCI", - /*args=*/ArrayAttr{}, - /*templateArgs=*/targs, - /*operands=*/ValueRange{dst, S}); - - rewriter.eraseOp(op); - return success(); - } -}; -static std::string cmpModeTok(pto::CmpModeAttr a) { - // 生成 "CmpMode::GT" 这种 token - auto m = a.getValue(); // 取 enum - switch (m) { - case pto::CmpMode::EQ: return "CmpMode::EQ"; - case pto::CmpMode::NE: return "CmpMode::NE"; - case pto::CmpMode::LT: return "CmpMode::LT"; - case pto::CmpMode::LE: return "CmpMode::LE"; - case pto::CmpMode::GT: return "CmpMode::GT"; - case pto::CmpMode::GE: return "CmpMode::GE"; - } - return "CmpMode::EQ"; -} -struct PTOColExpandToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColExpandOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value src = peelUnrealized(adaptor.getSrc()); - - rewriter.create( - loc, TypeRange{}, "TCOLEXPAND", - /*args=*/ArrayAttr(), - /*templateArgs=*/ArrayAttr(), - /*operands=*/ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColExpandMulToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColExpandMulOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - 
rewriter.create( - loc, TypeRange{}, "TCOLEXPANDMUL", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColExpandAddToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColExpandAddOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLEXPANDADD", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColExpandDivToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColExpandDivOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLEXPANDDIV", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColExpandExpdifToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColExpandExpdifOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLEXPANDEXPDIF", - /*args=*/ArrayAttr{}, 
- /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColExpandSubToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColExpandSubOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLEXPANDSUB", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColExpandMaxToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColExpandMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLEXPANDMAX", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColExpandMinToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColExpandMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLEXPANDMIN", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1}); - - 
rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOTTriToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TTriOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value diagonal = peelUnrealized(adaptor.getDiagonal()); - - ArrayAttr templateArgs; - if (auto dstOT = mlir::dyn_cast(dst.getType())) { - templateArgs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), - emitc::OpaqueAttr::get(ctx, std::to_string(op.getUpperOrLower())), - }); - } else { - templateArgs = ArrayAttr{}; - } - - SmallVector operands{dst, diagonal}; - rewriter.create( - loc, TypeRange{}, "TTRI", - /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOCmpToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TCmpOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - - std::string tok = "CmpMode::EQ"; - if (auto a = op.getCmpModeAttr()) - tok = cmpModeTok(a); - - auto modeTy = emitc::OpaqueType::get(ctx, "CmpMode"); - Value modeVal = rewriter.create( - loc, modeTy, emitc::OpaqueAttr::get(ctx, tok)); - - rewriter.create( - loc, - TypeRange{}, - "TCMP", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1, modeVal}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOCmpSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult 
matchAndRewrite(pto::TCmpSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - - // cmpMode -> token - auto cmpAttr = op.getCmpModeAttr(); // PTO_CmpModeAttr - std::string tok = cmpModeTok(cmpAttr); - - auto modeTy = emitc::OpaqueType::get(ctx, "CmpMode"); - Value modeVal = rewriter.create( - loc, modeTy, emitc::OpaqueAttr::get(ctx, tok)); - - rewriter.create( - loc, - TypeRange{}, - "TCMPS", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, scalar, modeVal}); - - rewriter.eraseOp(op); - return success(); - } -}; - - -struct PTOColMaxToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - // intrinsic: TCOLMAX(dst, src) - rewriter.create( - loc, TypeRange{}, "TCOLMAX", - /*args=*/ArrayAttr{}, // default: print all operands - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColArgMaxToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColArgMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLARGMAX", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, 
tmp}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColMinToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - // intrinsic: TCOLMIN(dst, src) - rewriter.create( - loc, TypeRange{}, "TCOLMIN", - /*args=*/ArrayAttr{}, // default: print all operands - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColArgMinToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColArgMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLARGMIN", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, tmp}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColSumToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColSumOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - // Check if tmp exists before accessing it - if (op.getTmp()) { - // Format 2: with tmp and isBinary - Value tmp = peelUnrealized(adaptor.getTmp()); - bool isBinary = false; - if (auto a = op.getIsBinaryAttr()) - isBinary = a.getValue(); - - auto boolTy = emitc::OpaqueType::get(ctx, 
"bool"); - auto tok = isBinary ? "true" : "false"; - Value isBinaryVal = rewriter.create( - loc, boolTy, emitc::OpaqueAttr::get(ctx, tok)); - - rewriter.create( - loc, TypeRange{}, "TCOLSUM", - /*args=*/ArrayAttr(), - /*templateArgs=*/ArrayAttr(), - /*operands=*/ValueRange{dst, src, tmp, isBinaryVal}); - } else { - // Format 1: without tmp and isBinary - rewriter.create( - loc, TypeRange{}, "TCOLSUM", - /*args=*/ArrayAttr(), - /*templateArgs=*/ArrayAttr(), - /*operands=*/ValueRange{dst, src}); - } - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColProdToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColProdOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLPROD", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; -static std::string roundModeTok(mlir::pto::RoundModeAttr attr) { - using RM = mlir::pto::RoundMode; - switch (attr.getValue()) { - case RM::NONE: return "RoundMode::CAST_NONE"; - case RM::RINT: return "RoundMode::CAST_RINT"; - case RM::ROUND: return "RoundMode::CAST_ROUND"; - case RM::FLOOR: return "RoundMode::CAST_FLOOR"; - case RM::CEIL: return "RoundMode::CAST_CEIL"; - case RM::TRUNC: return "RoundMode::CAST_TRUNC"; - case RM::ODD: return "RoundMode::CAST_ODD"; - case RM::CAST_RINT: return "RoundMode::CAST_RINT"; - } - return "RoundMode::CAST_RINT"; -} -static std::string saturationModeTok(mlir::pto::SaturationModeAttr attr) { - using SM = mlir::pto::SaturationMode; - switch (attr.getValue()) { - case SM::ON: return "SaturationMode::ON"; - case SM::OFF: return "SaturationMode::OFF"; - } - return "SaturationMode::OFF"; -} -struct PTOCvtToEmitC : public 
OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TCvtOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - pto::RoundModeAttr rmAttr = op.getRmodeAttr(); - std::string rmTok = rmAttr ? roundModeTok(rmAttr) - : std::string("RoundMode::CAST_RINT"); - auto rmodeTy = emitc::OpaqueType::get(ctx, "RoundMode"); - Value rmodeVal = rewriter.create( - loc, rmodeTy, emitc::OpaqueAttr::get(ctx, rmTok)); - - auto satModeTy = emitc::OpaqueType::get(ctx, "SaturationMode"); - auto satAttr = op.getSatModeAttr(); - std::string satTok = satAttr ? saturationModeTok(satAttr) - : std::string("SaturationMode::OFF"); - Value satModeVal = rewriter.create( - loc, satModeTy, emitc::OpaqueAttr::get(ctx, satTok)); - - SmallVector operands{dst, src, rmodeVal, satModeVal}; - - rewriter.create( - loc, TypeRange{}, "TCVT", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -struct PTORandomToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRandomOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - SmallVector operands{ - dst, - peelUnrealized(adaptor.getKey0()), - peelUnrealized(adaptor.getKey1()), - peelUnrealized(adaptor.getCounter0()), - peelUnrealized(adaptor.getCounter1()), - peelUnrealized(adaptor.getCounter2()), - peelUnrealized(adaptor.getCounter3()), - }; - ArrayAttr templateArgs = rewriter.getArrayAttr( - {emitc::OpaqueAttr::get(ctx, std::to_string(op.getRounds()))}); - - rewriter.create( - loc, TypeRange{}, "PTOAS__TRANDOM", - 
/*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, operands); - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tdiv lowering -> TDIV(dst, src0, src1) -//===----------------------------------------------------------------------===// - -struct PTODivToTDIV : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TDivOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TDIV", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src0, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tdivs lowering -> TDIVS(dst, src, scalar) or TDIVS(dst, scalar, src) -// Order is determined by operand types: if src is tile_buf, order is (tile, scalar) -// Otherwise, order is (scalar, tile) -//===----------------------------------------------------------------------===// - -struct PTODivSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TDivSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value dst = peelUnrealized(adaptor.getDst()); - // Preserve source order from textual parse: - // ins(tile, scalar) -> TDIVS(dst, tile, scalar) - // ins(scalar, tile) -> TDIVS(dst, scalar, tile) - rewriter.create( - loc, TypeRange{}, "TDIVS", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src, scalar}); - - rewriter.eraseOp(op); - return success(); - } -}; - 
-//===----------------------------------------------------------------------===// -// pto.tdivs (TDivSOp) lowering -> TDIVS(dst, src, scalar) or TDIVS(dst, scalar, src) -// Order is determined by operand types: if src is tile_buf, order is (tile, scalar) -// Otherwise, order is (scalar, tile) -//===----------------------------------------------------------------------===// - -struct PTOTDivSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TDivSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value dst = peelUnrealized(adaptor.getDst()); - rewriter.create( - loc, TypeRange{}, "TDIVS", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src, scalar}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.texp lowering -> TEXP(dst, src) -//===----------------------------------------------------------------------===// - -struct PTOExpToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TExpOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TEXP", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.texpands lowering -> TEXPANDS(dst, scalar) -//===----------------------------------------------------------------------===// - -struct PTOExpandsToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - 
LogicalResult matchAndRewrite(pto::TExpandsOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value scalar = peelUnrealized(adaptor.getScalar()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TEXPANDS", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, scalar}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.textract lowering -> TEXTRACT(dst, src, indexRow, indexCol) -//===----------------------------------------------------------------------===// - -struct PTOExtractToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TExtractOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value r0 = peelUnrealized(adaptor.getIndexRow()); - Value c0 = peelUnrealized(adaptor.getIndexCol()); - - rewriter.create( - loc, TypeRange{}, "TEXTRACT", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, r0, c0}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.textract_fp lowering -> TEXTRACT_FP(dst, src, fp, indexRow, indexCol) -//===----------------------------------------------------------------------===// - -struct PTOExtractFPToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TExtractFPOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value fp = peelUnrealized(adaptor.getFp()); - Value dst = peelUnrealized(adaptor.getDst()); - Value r0 = 
peelUnrealized(adaptor.getIndexRow()); - Value c0 = peelUnrealized(adaptor.getIndexCol()); - - rewriter.create( - loc, TypeRange{}, "TEXTRACT_FP", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, fp, r0, c0}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tinsert lowering -> TINSERT(dst, src, indexRow, indexCol) -// Keep lowering arch-agnostic and let PTO-ISA infer proper A5 path. -//===----------------------------------------------------------------------===// - -struct PTOInsertToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TInsertOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value r0 = peelUnrealized(adaptor.getIndexRow()); - Value c0 = peelUnrealized(adaptor.getIndexCol()); - - rewriter.create( - loc, TypeRange{}, "TINSERT", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, r0, c0}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tinsert_fp lowering -> TINSERT_FP(dst, src, fp, indexRow, indexCol) -//===----------------------------------------------------------------------===// - -struct PTOInsertFPToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TInsertFPOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value fp = peelUnrealized(adaptor.getFp()); - Value dst = peelUnrealized(adaptor.getDst()); - Value r0 = peelUnrealized(adaptor.getIndexRow()); - Value c0 = 
peelUnrealized(adaptor.getIndexCol()); - - rewriter.create( - loc, TypeRange{}, "TINSERT_FP", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, fp, r0, c0}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tfillpad lowering -> TFILLPAD(dst, src) -//===----------------------------------------------------------------------===// - -struct PTOFillPadToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TFillPadOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TFILLPAD", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tfillpad_inplace lowering -> TFILLPAD_INPLACE(dst, src) -//===----------------------------------------------------------------------===// - -struct PTOFillPadInplaceToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TFillPadInplaceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TFILLPAD_INPLACE", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tfillpad_expand lowering -> TFILLPAD_EXPAND(dst, src) 
-//===----------------------------------------------------------------------===// - -struct PTOFillPadExpandToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TFillPadExpandOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TFILLPAD_EXPAND", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tgather lowering -// - Index form : TGATHER(dst, src0, indices, tmp) -// - Compare form: TGATHER(dst, src0, kValue, cdst, tmp) -// - Mask form : TGATHER(dst, src0) -//===----------------------------------------------------------------------===// - -[[maybe_unused]] static std::string maskPatternTok(mlir::pto::MaskPatternAttr a) { - - auto v = a.getValue(); // enum - return (std::string("pto::MaskPattern::") + mlir::pto::stringifyMaskPattern(v).str()); -} - -struct PTOGatherToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGatherOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value src0 = peelUnrealized(adaptor.getSrc()); - - auto getOpaqueTok = [&](Value v, StringRef name) -> FailureOr { - if (auto ot = mlir::dyn_cast(v.getType())) - return ot.getValue().str(); - return rewriter.notifyMatchFailure(op, (name + " must be emitc::OpaqueType (tile)").str()); - }; - - // Case 1: index-based TGATHER(dst, src0, indices, tmp) - if (Value idx = adaptor.getIndices()) { - idx = peelUnrealized(idx); - Value tmp = 
peelUnrealized(adaptor.getTmp()); - - rewriter.create( - loc, TypeRange{}, "TGATHER", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, idx, tmp}); - - rewriter.eraseOp(op); - return success(); - } - - // Case 2: compare-based TGATHER( - // dst, src0, kValue, tmp, cdst, offset) - if (Value cdst = adaptor.getCdst()) { - cdst = peelUnrealized(cdst); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value kValue = peelUnrealized(adaptor.getKValue()); - - auto dstTokOr = getOpaqueTok(dst, "dst"); - auto srcTokOr = getOpaqueTok(src0, "src0"); - auto cdstTokOr = getOpaqueTok(cdst, "cdst"); - auto tmpTokOr = getOpaqueTok(tmp, "tmp"); - if (failed(dstTokOr) || failed(srcTokOr) || failed(cdstTokOr) || failed(tmpTokOr)) - return failure(); - - auto cmpAttr = op.getCmpModeAttr(); - std::string cmpTok = cmpAttr ? cmpModeTok(cmpAttr) : "CmpMode::EQ"; - int64_t offset = 0; - if (auto offsetAttr = op.getOffsetAttr()) - offset = offsetAttr.getInt(); - auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); - Value offsetVal = makeEmitCIntConstant(rewriter, loc, i32Ty, offset); - - auto targs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, *dstTokOr), - emitc::OpaqueAttr::get(ctx, *srcTokOr), - emitc::OpaqueAttr::get(ctx, *tmpTokOr), - emitc::OpaqueAttr::get(ctx, *cdstTokOr), - emitc::OpaqueAttr::get(ctx, cmpTok), - }); - - rewriter.create( - loc, TypeRange{}, "TGATHER", - /*args=*/ArrayAttr{}, /*templateArgs=*/targs, - /*operands=*/ValueRange{dst, src0, kValue, tmp, cdst, offsetVal}); - - rewriter.eraseOp(op); - return success(); - } - - // Case 3: mask-pattern TGATHER(dst, src0) - auto mp = op.getMaskPatternAttr(); - if (!mp) - return rewriter.notifyMatchFailure(op, "expected maskPattern, indices, or cdst on tgather"); - - auto dstTokOr = getOpaqueTok(dst, "dst"); - auto srcTokOr = getOpaqueTok(src0, "src0"); - if (failed(dstTokOr) || failed(srcTokOr)) - return failure(); - - // mp is an EnumAttr; stringify name is "P0101" etc. 
- // We emit MaskPattern::P0101 (because generated C++ has `using namespace pto;`) - std::string mpTok = std::string("MaskPattern::") + - mlir::pto::stringifyMaskPattern(mp.getValue()).str(); - - auto targs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, *dstTokOr), - emitc::OpaqueAttr::get(ctx, *srcTokOr), - emitc::OpaqueAttr::get(ctx, mpTok), - }); - - rewriter.create( - loc, TypeRange{}, "TGATHER", - /*args=*/ArrayAttr{}, - /*templateArgs=*/targs, - /*operands=*/ValueRange{dst, src0}); - - rewriter.eraseOp(op); - return success(); - } -}; - - -struct PTOGatherbToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGatherBOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value offsets = peelUnrealized(adaptor.getOffsets()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TGATHERB", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, offsets}); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// TLOG lowering to EmitC (PTOConvert.cpp) -//===----------------------------------------------------------------------===// - -struct PTOLogToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TLogOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src}; - rewriter.create( - loc, TypeRange{}, "TLOG", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - - - 
-//===----------------------------------------------------------------------===// -// TLRELU lowering to EmitC (PTOConvert.cpp) -//===----------------------------------------------------------------------===// - - struct PTOLReluToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TLReluOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value slope = peelUnrealized(adaptor.getSlope()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, slope}; - - rewriter.create( - loc, TypeRange{}, "TLRELU", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// TMAX lowering to EmitC (PTOConvert.cpp) -//===----------------------------------------------------------------------===// - -struct PTOMaxToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TMAX", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// TMAXS lowering to EmitC (PTOConvert.cpp) -//===----------------------------------------------------------------------===// - - struct PTOMaxSToEmitC : public OpConversionPattern { - using 
OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMaxSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, scalar}; - rewriter.create( - loc, TypeRange{}, "TMAXS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - - -//===----------------------------------------------------------------------===// -// TMIN lowering to EmitC (PTOConvert.cpp) -//===----------------------------------------------------------------------===// - -struct PTOMinToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TMIN", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// TMINS lowering to EmitC (PTOConvert.cpp) -//===----------------------------------------------------------------------===// - -//===----------------------------------------------------------------------===// -// TMINS lowering to EmitC (fix APFloat -> FloatAttr) (PTOConvert.cpp) -//===----------------------------------------------------------------------===// - -struct PTOMinsToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult 
matchAndRewrite(pto::TMinSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value scalar = peelUnrealized(adaptor.getScalar()); - - SmallVector operands{dst, src, scalar}; - rewriter.create( - loc, TypeRange{}, "TMINS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering for TMOV op -> EmitC) -//===----------------------------------------------------------------------===// - -struct PTOMovToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMovOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value fp; - if (op.getFp()) - fp = peelUnrealized(adaptor.getFp()); - Value preQuantScalar; - if (op.getPreQuantScalar()) - preQuantScalar = peelUnrealized(adaptor.getPreQuantScalar()); - - auto dstOT = mlir::dyn_cast(dst.getType()); - auto srcOT = mlir::dyn_cast(src.getType()); - if (!dstOT || !srcOT) - return rewriter.notifyMatchFailure( - op, "tmov lowering expects opaque dst/src types"); - - auto modeTok = [&](pto::AccToVecMode mode) -> StringRef { - switch (mode) { - case pto::AccToVecMode::SingleModeVec0: - return "pto::AccToVecMode::SingleModeVec0"; - case pto::AccToVecMode::SingleModeVec1: - return "pto::AccToVecMode::SingleModeVec1"; - case pto::AccToVecMode::DualModeSplitM: - return "pto::AccToVecMode::DualModeSplitM"; - case pto::AccToVecMode::DualModeSplitN: - return "pto::AccToVecMode::DualModeSplitN"; - } - llvm_unreachable("unknown AccToVecMode"); 
- }; - - auto modeAttr = op.getAccToVecModeAttr(); - auto reluTok = [&](pto::ReluPreMode mode) -> StringRef { - switch (mode) { - case pto::ReluPreMode::NoRelu: - return "ReluPreMode::NoRelu"; - case pto::ReluPreMode::NormalRelu: - return "ReluPreMode::NormalRelu"; - } - llvm_unreachable("unknown ReluPreMode"); - }; - - const bool hasFp = static_cast(fp); - const bool hasPreQuantScalar = static_cast(preQuantScalar); - const bool hasMode = static_cast(modeAttr); - const bool reluNonDefault = op.getReluPreMode() != pto::ReluPreMode::NoRelu; - - SmallVector operands{dst, src}; - SmallVector templateArgVec{ - emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), - emitc::OpaqueAttr::get(ctx, srcOT.getValue().str()), - }; - StringRef callee = "TMOV"; - - if (hasFp) { - auto fpOT = mlir::dyn_cast(fp.getType()); - if (!fpOT) - return rewriter.notifyMatchFailure( - op, "tmov fp lowering expects opaque fp type"); - operands.push_back(fp); - templateArgVec.push_back(emitc::OpaqueAttr::get(ctx, fpOT.getValue().str())); - if (hasMode) - templateArgVec.push_back( - emitc::OpaqueAttr::get(ctx, modeTok(modeAttr.getValue()))); - if (hasMode || reluNonDefault) - templateArgVec.push_back( - emitc::OpaqueAttr::get(ctx, reluTok(op.getReluPreMode()))); - callee = hasMode ? 
"TMOV" : "TMOV_FP"; - } else if (hasPreQuantScalar) { - operands.push_back(preQuantScalar); - if (hasMode) - templateArgVec.push_back( - emitc::OpaqueAttr::get(ctx, modeTok(modeAttr.getValue()))); - if (hasMode || reluNonDefault) - templateArgVec.push_back( - emitc::OpaqueAttr::get(ctx, reluTok(op.getReluPreMode()))); - } else if (hasMode) { - templateArgVec.push_back( - emitc::OpaqueAttr::get(ctx, modeTok(modeAttr.getValue()))); - templateArgVec.push_back( - emitc::OpaqueAttr::get(ctx, reluTok(op.getReluPreMode()))); - } else if (reluNonDefault) { - templateArgVec.push_back( - emitc::OpaqueAttr::get(ctx, reluTok(op.getReluPreMode()))); - } - - ArrayAttr templateArgs = - templateArgVec.size() == 2 && !hasFp && !hasPreQuantScalar && - !hasMode && !reluNonDefault - ? ArrayAttr{} - : rewriter.getArrayAttr(templateArgVec); - - rewriter.create( - loc, TypeRange{}, callee, - /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TMOV_FP DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOMovFPToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMovFPOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value src = peelUnrealized(adaptor.getSrc()); - Value fp = peelUnrealized(adaptor.getFp()); - - // TMOV_FP(dstTileData, cTile, fbTile) - ArrayAttr templateArgs; - auto dstOT = mlir::dyn_cast(dst.getType()); - auto srcOT = mlir::dyn_cast(src.getType()); - auto fpOT = mlir::dyn_cast(fp.getType()); - if (dstOT && srcOT && fpOT) { - templateArgs = rewriter.getArrayAttr({ - 
emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), - emitc::OpaqueAttr::get(ctx, srcOT.getValue().str()), - emitc::OpaqueAttr::get(ctx, fpOT.getValue().str()), - }); - } else { - templateArgs = ArrayAttr{}; - } - - SmallVector operands{dst, src, fp}; - rewriter.create( - loc, TypeRange{}, "TMOV_FP", - /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOQuantToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TQuantOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value src = peelUnrealized(adaptor.getSrc()); - Value fp = peelUnrealized(adaptor.getFp()); - - // Optional offset (INT8_ASYM only): passed as pointer (&offset) - Value offsetPtr; - if (op.getOffset()) { - Value offset = peelUnrealized(adaptor.getOffset()); - auto offsetOT = mlir::dyn_cast(offset.getType()); - if (offsetOT) { - offsetPtr = rewriter - .create( - loc, emitc::PointerType::get(offsetOT), "&", offset) - .getResult(); - } - } - - // TQUANT(dst, src, fp[, &offset]) - std::string quantTypeStr = - op.getQuantType() == pto::QuantType::INT8_SYM - ? 
"pto::QuantType::INT8_SYM" - : "pto::QuantType::INT8_ASYM"; - ArrayAttr templateArgs; - auto dstOT = mlir::dyn_cast(dst.getType()); - auto srcOT = mlir::dyn_cast(src.getType()); - auto fpOT = mlir::dyn_cast(fp.getType()); - if (dstOT && srcOT && fpOT) { - templateArgs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, quantTypeStr), - emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), - emitc::OpaqueAttr::get(ctx, srcOT.getValue().str()), - emitc::OpaqueAttr::get(ctx, fpOT.getValue().str()), - }); - } else { - templateArgs = ArrayAttr{}; - } - - SmallVector operands{dst, src, fp}; - if (offsetPtr) - operands.push_back(offsetPtr); - - rewriter.create( - loc, TypeRange{}, "TQUANT", - /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -struct PTODequantToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TDequantOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value src = peelUnrealized(adaptor.getSrc()); - Value scale = peelUnrealized(adaptor.getScale()); - Value offset = peelUnrealized(adaptor.getOffset()); - - // TDEQUANT(dst, src, scale, offset) - ArrayAttr templateArgs; - auto dstOT = mlir::dyn_cast(dst.getType()); - auto srcOT = mlir::dyn_cast(src.getType()); - auto scaleOT = mlir::dyn_cast(scale.getType()); - if (dstOT && srcOT && scaleOT) { - templateArgs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), - emitc::OpaqueAttr::get(ctx, srcOT.getValue().str()), - emitc::OpaqueAttr::get(ctx, scaleOT.getValue().str()), - }); - } else { - templateArgs = ArrayAttr{}; - } - - rewriter.create( - loc, TypeRange{}, "TDEQUANT", - /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, - /*operands=*/SmallVector{dst, src, scale, offset}); 
- - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TMRGSORT DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOMrgSortToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMrgSortOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - if (op.isFormat1()) { - Value src = peelUnrealized(adaptor.getSrcs().front()); - Value dst = peelUnrealized(adaptor.getDsts().front()); - Value blockLen = peelUnrealized(adaptor.getBlockLen()); - - SmallVector operands{dst, src, blockLen}; - rewriter.create( - loc, TypeRange{}, "TMRGSORT", - ArrayAttr{}, ArrayAttr{}, operands); - } else if (op.isFormat2()) { - // pto-isa API: - // TMRGSORT( - // dst, executedNumList, tmp, src0, src1[, src2[, src3]]); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDsts()[0]); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value excuted = peelUnrealized(adaptor.getExcuted()); - - SmallVector srcs; - srcs.reserve(adaptor.getSrcs().size()); - for (Value v : adaptor.getSrcs()) - srcs.push_back(peelUnrealized(v)); - - auto dstOT = mlir::dyn_cast(dst.getType()); - auto tmpOT = mlir::dyn_cast(tmp.getType()); - if (!dstOT || !tmpOT || srcs.size() < 2 || srcs.size() > 4) - return op.emitOpError("format2 expects dst/tmp tilebufs and 2 to 4 srcs"); - - SmallVector targs; - targs.reserve(2 + srcs.size() + 1); - targs.push_back(emitc::OpaqueAttr::get(ctx, dstOT.getValue().str())); - targs.push_back(emitc::OpaqueAttr::get(ctx, tmpOT.getValue().str())); - for (Value v : srcs) { - auto ot = mlir::dyn_cast(v.getType()); - if (!ot) - return op.emitOpError("format2 expects tilebuf srcs"); - targs.push_back(emitc::OpaqueAttr::get(ctx, ot.getValue().str())); - } - 
targs.push_back(emitc::OpaqueAttr::get(ctx, op.getExhausted() ? "true" : "false")); - ArrayAttr templateArgs = rewriter.getArrayAttr(targs); - - SmallVector operands{dst, excuted, tmp}; - operands.append(srcs.begin(), srcs.end()); - - rewriter.create( - loc, TypeRange{}, "TMRGSORT", - /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, operands); - } else { - return op.emitOpError("unsupported mrgsort_dps format"); - } - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TMUL DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOMulToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMulOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TMUL", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TMULS DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOMulsToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMulSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc0()); - Value dst = peelUnrealized(adaptor.getDst()); - Value scalar = peelUnrealized(adaptor.getScalar()); - - 
SmallVector operands{dst, src, scalar}; - rewriter.create( - loc, TypeRange{}, "TMULS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TNEG DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTONegToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TNegOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src}; - rewriter.create( - loc, TypeRange{}, "TNEG", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TNOT DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTONotToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TNotOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src}; - rewriter.create( - loc, TypeRange{}, "TNOT", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TOR DPS/memref op) 
-//===----------------------------------------------------------------------===// - -struct PTOOrToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TOrOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TOR", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TORS DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOOrsToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TOrSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - // NOTE: The conversion type system may materialize integers as emitc.opaque - // (e.g. "int32_t"). For EmitC call emission we can pass the scalar through - // directly without arith casts here. 
- Value s = adaptor.getScalar(); - - SmallVector operands{dst, src0, s}; - rewriter.create( - loc, TypeRange{}, "TORS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TPARTADD DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOPartAddToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPartAddOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TPARTADD", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TPARTMAX DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOPartMaxToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPartMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TPARTMAX", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return 
success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TPARTMIN DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOPartMinToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPartMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TPARTMIN", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOPartArgMaxToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPartArgMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value src0Idx = peelUnrealized(adaptor.getSrc0Idx()); - Value src1Idx = peelUnrealized(adaptor.getSrc1Idx()); - Value dst = peelUnrealized(adaptor.getDst()); - Value dstIdx = peelUnrealized(adaptor.getDstIdx()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TPARTARGMAX", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1, dstIdx, src0Idx, src1Idx}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOPartArgMinToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPartArgMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src0 = 
peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value src0Idx = peelUnrealized(adaptor.getSrc0Idx()); - Value src1Idx = peelUnrealized(adaptor.getSrc1Idx()); - Value dst = peelUnrealized(adaptor.getDst()); - Value dstIdx = peelUnrealized(adaptor.getDstIdx()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TPARTARGMIN", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1, dstIdx, src0Idx, src1Idx}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TPARTMUL DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOPartMulToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPartMulOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TPARTMUL", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TPRELU DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOPreluToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPReluOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = 
peelUnrealized(adaptor.getSrc1()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - // C++ interface: TPRELU(dst, src0, src1, tmp) — last parameter is tmp. - SmallVector operands{dst, src0, src1, tmp}; - rewriter.create( - loc, TypeRange{}, "TPRELU", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TRECIP DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORecipToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRecipOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src}; - rewriter.create( - loc, TypeRange{}, "TRECIP", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TRELU DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOReluToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TReluOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src}; - rewriter.create( - loc, TypeRange{}, "TRELU", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - 
/*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TREM DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORemToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRemOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - SmallVector operands{dst, src0, src1, tmp}; - rewriter.create( - loc, TypeRange{}, "TREM", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOFModToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TFModOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TFMOD", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TREMS DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORemSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult 
matchAndRewrite(pto::TRemSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - Value scalar = peelUnrealized(adaptor.getScalar()); - SmallVector operands{dst, src, scalar, tmp}; - rewriter.create( - loc, TypeRange{}, "TREMS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOFModSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TFModSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value scalar = peelUnrealized(adaptor.getScalar()); - - SmallVector operands{dst, src, scalar}; - rewriter.create( - loc, TypeRange{}, "TFMODS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TROWEXPAND DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORowExpandToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowExpandOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src}; - rewriter.create( - loc, TypeRange{}, "TROWEXPAND", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - 
rewriter.eraseOp(op); - return success(); - } -}; - -struct PTORowExpandAddToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowExpandAddOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TROWEXPANDADD", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTORowExpandExpdifToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowExpandExpdifOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value(); - - SmallVector operands; - if (tmp) - operands.assign({dst, src0, src1, tmp}); - else - operands.assign({dst, src0, src1}); - rewriter.create( - loc, TypeRange{}, "TROWEXPANDEXPDIF", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TROWEXPANDDIV DPS/memref op) -//===----------------------------------------------------------------------===// -// Helper: replace or erase based on whether op has results. 
-static void replaceOrEraseWithOpaqueCall(Operation *op, - StringRef callee, - ArrayRef args, - ConversionPatternRewriter &rewriter) { - TypeRange resultTypes = op->getResultTypes(); - auto call = rewriter.create( - op->getLoc(), resultTypes, callee, ArrayAttr{}, ArrayAttr{}, ValueRange(args)); - if (resultTypes.empty()) - rewriter.eraseOp(op); - else - rewriter.replaceOp(op, call.getResults()); -} - -static void replaceOrEraseWithOpaqueCallAndReturnDst(Operation *op, Value dst, - StringRef callee, - ArrayRef args, - ConversionPatternRewriter &rewriter) { - rewriter.create( - op->getLoc(), TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, ValueRange(args)); - if (op->getNumResults() == 1) - rewriter.replaceOp(op, dst); - else - rewriter.eraseOp(op); -} - -// ---------- TOp ---------- -struct PTOTGemvBiasToTGEMV_BIAS - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGemvBiasOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value a = peelUnrealized(adaptor.getA()); - Value b = peelUnrealized(adaptor.getB()); - Value bias = peelUnrealized(adaptor.getBias()); - Value dst = peelUnrealized(adaptor.getDst()); - - replaceOrEraseWithOpaqueCall(op.getOperation(), "TGEMV_BIAS", - {dst, a, b, bias}, rewriter); - return success(); - } -}; - -struct PTOTGemvMXToTGEMV_MX - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGemvMxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value a = peelUnrealized(adaptor.getA()); - Value aScale = peelUnrealized(adaptor.getAScale()); - Value b = peelUnrealized(adaptor.getB()); - Value bScale = peelUnrealized(adaptor.getBScale()); - Value dst = peelUnrealized(adaptor.getDst()); - - replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TGEMV_MX", - {dst, a, aScale, b, bScale}, rewriter); - return success(); - } -}; - -struct 
PTOTGemvMXAccToTGEMV_MX - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGemvMxAccOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value cIn = peelUnrealized(adaptor.getCIn()); - Value a = peelUnrealized(adaptor.getA()); - Value aScale = peelUnrealized(adaptor.getAScale()); - Value b = peelUnrealized(adaptor.getB()); - Value bScale = peelUnrealized(adaptor.getBScale()); - Value dst = peelUnrealized(adaptor.getDst()); - - replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TGEMV_MX", - {dst, cIn, a, aScale, b, bScale}, rewriter); - return success(); - } -}; - -struct PTOTGemvMXBiasToTGEMV_MX - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGemvMxBiasOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value a = peelUnrealized(adaptor.getA()); - Value aScale = peelUnrealized(adaptor.getAScale()); - Value b = peelUnrealized(adaptor.getB()); - Value bScale = peelUnrealized(adaptor.getBScale()); - Value bias = peelUnrealized(adaptor.getBias()); - Value dst = peelUnrealized(adaptor.getDst()); - - replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TGEMV_MX", - {dst, a, aScale, b, bScale, bias}, rewriter); - return success(); - } -}; - -struct PTOTMatmulBiasToTMATMUL_BIAS - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMatmulBiasOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value a = peelUnrealized(adaptor.getA()); - Value b = peelUnrealized(adaptor.getB()); - Value bias = peelUnrealized(adaptor.getBias()); - Value dst = peelUnrealized(adaptor.getDst()); - - replaceOrEraseWithOpaqueCall(op.getOperation(), "TMATMUL_BIAS", - {dst, a, b, bias}, rewriter); - return success(); - } -}; - -struct PTOTMatmulMXToTMATMUL_MX - : 
public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMatmulMxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value a = peelUnrealized(adaptor.getA()); - Value aScale = peelUnrealized(adaptor.getAScale()); - Value b = peelUnrealized(adaptor.getB()); - Value bScale = peelUnrealized(adaptor.getBScale()); - Value dst = peelUnrealized(adaptor.getDst()); - - replaceOrEraseWithOpaqueCall(op.getOperation(), "TMATMUL_MX", - {dst, a, aScale, b, bScale}, rewriter); - return success(); - } -}; - -struct PTOTMatmulMXAccToTMATMUL_MX_ACC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMatmulMxAccOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value cIn = peelUnrealized(adaptor.getCIn()); - Value a = peelUnrealized(adaptor.getA()); - Value aScale = peelUnrealized(adaptor.getAScale()); - Value b = peelUnrealized(adaptor.getB()); - Value bScale = peelUnrealized(adaptor.getBScale()); - Value dst = peelUnrealized(adaptor.getDst()); - - replaceOrEraseWithOpaqueCall(op.getOperation(), "TMATMUL_MX", - {dst, cIn, a, aScale, b, bScale}, rewriter); - return success(); - } -}; - -struct PTOTMatmulMXBiasToTMATMUL_MX_BIAS - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMatmulMxBiasOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value a = peelUnrealized(adaptor.getA()); - Value aScale = peelUnrealized(adaptor.getAScale()); - Value b = peelUnrealized(adaptor.getB()); - Value bScale = peelUnrealized(adaptor.getBScale()); - Value bias = peelUnrealized(adaptor.getBias()); - Value dst = peelUnrealized(adaptor.getDst()); - - replaceOrEraseWithOpaqueCall(op.getOperation(), "TMATMUL_MX", - {dst, a, aScale, b, bScale, bias}, rewriter); - return success(); - } -}; - -struct 
PTORowExpandDivToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowExpandDivOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value(); - - SmallVector operands; - if (tmp) - operands.assign({dst, src0, src1, tmp}); - else - operands.assign({dst, src0, src1}); - rewriter.create( - loc, TypeRange{}, "TROWEXPANDDIV", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TROWEXPANDMUL DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORowExpandMulToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowExpandMulOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - Value tmp = op.getTmp() ? 
peelUnrealized(adaptor.getTmp()) : Value(); - - SmallVector operands; - if (tmp) - operands.assign({dst, src0, src1, tmp}); - else - operands.assign({dst, src0, src1}); - rewriter.create( - loc, TypeRange{}, "TROWEXPANDMUL", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TROWEXPANDSUB DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORowExpandSubToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowExpandSubOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value(); - - SmallVector operands; - if (tmp) - operands.assign({dst, src0, src1, tmp}); - else - operands.assign({dst, src0, src1}); - rewriter.create( - loc, TypeRange{}, "TROWEXPANDSUB", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTORowExpandMaxToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowExpandMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - Value tmp = op.getTmp() ? 
peelUnrealized(adaptor.getTmp()) : Value(); - - SmallVector operands; - if (tmp) - operands.assign({dst, src0, src1, tmp}); - else - operands.assign({dst, src0, src1}); - rewriter.create( - loc, TypeRange{}, "TROWEXPANDMAX", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTORowExpandMinToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowExpandMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value(); - - SmallVector operands; - if (tmp) - operands.assign({dst, src0, src1, tmp}); - else - operands.assign({dst, src0, src1}); - rewriter.create( - loc, TypeRange{}, "TROWEXPANDMIN", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TROWMAX DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORowMaxToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, tmp}; - rewriter.create( - loc, TypeRange{}, "TROWMAX", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - 
rewriter.eraseOp(op); - return success(); - } -}; - -struct PTORowArgMaxToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowArgMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TROWARGMAX", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, tmp}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TROWMIN DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORowMinToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, tmp}; - rewriter.create( - loc, TypeRange{}, "TROWMIN", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTORowArgMinToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowArgMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, 
TypeRange{}, "TROWARGMIN", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, tmp}); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TROWSUM DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORowSumToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowSumOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, tmp}; - rewriter.create( - loc, TypeRange{}, "TROWSUM", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTORowProdToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowProdOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, tmp}; - rewriter.create( - loc, TypeRange{}, "TROWPROD", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TRSQRT DPS/memref op) -// - no-tmp form : TRSQRT(dst, src) -// - tmp form : TRSQRT(dst, src, tmp) 
-//===----------------------------------------------------------------------===// - -struct PTORsqrtToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRsqrtOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - SmallVector operands{dst, src}; - if (Value tmp = adaptor.getTmp()) - operands.push_back(peelUnrealized(tmp)); - rewriter.create( - loc, TypeRange{}, "TRSQRT", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSCATTER DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOScatterToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TScatterOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - const bool hasMaskPattern = static_cast(op.getMaskPatternAttr()); - const bool hasIndexes = static_cast(op.getIndexes()); - if (hasMaskPattern == hasIndexes) { - return rewriter.notifyMatchFailure( - op, "expected exactly one of indexes operand or maskPattern attribute"); - } - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - if (auto mp = op.getMaskPatternAttr()) { - auto *ctx = rewriter.getContext(); - auto targs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, maskPatternTok(mp)), - }); - rewriter.create( - loc, TypeRange{}, "TSCATTER", - /*args=*/ArrayAttr{}, /*templateArgs=*/targs, - /*operands=*/ValueRange{dst, src}); - } else { - Value idx = peelUnrealized(adaptor.getIndexes()); - 
rewriter.create( - loc, TypeRange{}, "TSCATTER", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, idx}); - } - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSEL DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOSelToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TSelOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value mask = peelUnrealized(adaptor.getMask()); - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, mask, src0, src1, tmp}; - rewriter.create( - loc, TypeRange{}, "TSEL", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSELS DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOSelSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TSelSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value mask = peelUnrealized(adaptor.getMask()); - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, mask, src, tmp, scalar}; - rewriter.create( 
- loc, TypeRange{}, "TSELS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSHL DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOShlSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TShlOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TSHL", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSHR DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOShrSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TShrOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TSHR", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp 
(add lowering for TSHLS/TSHRS DPS: shift by scalar) -//===----------------------------------------------------------------------===// - -struct PTOShlSConstToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TShlSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value dst = peelUnrealized(adaptor.getDst()); - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - SmallVector operands{dst, src, scalar}; - rewriter.create( - loc, TypeRange{}, "TSHLS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOShrSConstToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TShrSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value dst = peelUnrealized(adaptor.getDst()); - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - SmallVector operands{dst, src, scalar}; - rewriter.create( - loc, TypeRange{}, "TSHRS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (TSORT32 DPS/memref op: ins(src, idx[, tmp]) outs(dst)) -//===----------------------------------------------------------------------===// - -struct PTOSORT32SToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TSort32Op op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = 
peelUnrealized(adaptor.getDst()); - Value idx = peelUnrealized(adaptor.getIdx()); - Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value(); - - SmallVector operands; - if (tmp) - operands.assign({dst, src, idx, tmp}); - else - operands.assign({dst, src, idx}); - rewriter.create( - loc, TypeRange{}, "TSORT32", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSQRT DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOSqrtSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TSqrtOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src}; - rewriter.create( - loc, TypeRange{}, "TSQRT", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSTORE_FP DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOStoreFPSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TStoreFPOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value fp = peelUnrealized(adaptor.getFp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, fp}; - rewriter.create( - loc, 
TypeRange{}, "TSTORE_FP", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSUB DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOSubSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TSubOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TSUB", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSUBC DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOSubCSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TSubCOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value src2 = peelUnrealized(adaptor.getSrc2()); - Value dst = peelUnrealized(adaptor.getDst()); - - // pto-isa does not provide NPU implementation for TSUBC yet. 
- // Decompose: dst = src0 - src1 + src2 - rewriter.create( - loc, TypeRange{}, "TSUB", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1}); - rewriter.create( - loc, TypeRange{}, "TADD", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, dst, src2}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSUBS DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOSubSSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TSubSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, scalar}; - rewriter.create( - loc, TypeRange{}, "TSUBS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSUBSC DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOSubSCToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TSubSCOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - // pto-isa does not provide NPU 
implementation for TSUBSC yet. - // Decompose: dst = src0 - scalar + src1 - rewriter.create( - loc, TypeRange{}, "TSUBS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, scalar}); - rewriter.create( - loc, TypeRange{}, "TADD", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, dst, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; - - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TXOR DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOXORToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TXorOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - Value tmp = peelUnrealized(adaptor.getTmp()); - SmallVector operands{dst, src0, src1, tmp}; - rewriter.create( - loc, TypeRange{}, "TXOR", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -struct PTOTTransToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TTransOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, tmp}; - rewriter.create( - loc, TypeRange{}, "TTRANS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; 
-//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TXORS DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOXORSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TXorSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, scalar, tmp}; - rewriter.create( - loc, TypeRange{}, "TXORS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - struct PTOPrintToTPRINT : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPrintOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - - SmallVector operands{src}; - rewriter.create( - loc, TypeRange{}, "TPRINT", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -// pto.print "format", %scalar -> PRINTF("format", scalar) -struct PTOPrintOpToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::PrintOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - std::string fmt = op.getFormat().str(); - if (fmt.empty()) - fmt = "%f"; - std::string quoted = "\""; - for (char c : fmt) { - if (c == '"' || c == '\\') - quoted += '\\'; - else if (c == 
'\n') - quoted += "\\n"; - else if (c == '\t') - quoted += "\\t"; - else - quoted += c; - } - quoted += "\""; - - Value scalar = peelUnrealized(adaptor.getScalar()); - auto argsAttr = rewriter.getArrayAttr( - {emitc::OpaqueAttr::get(ctx, quoted), - IntegerAttr::get(IndexType::get(ctx), 0)}); - rewriter.create( - loc, TypeRange{}, "cce::printf", - /*args=*/argsAttr, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{scalar}); - - rewriter.eraseOp(op); - return success(); - } -}; - -// pto.trap -> TRAP() -struct PTOTrapOpToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TrapOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - rewriter.create( - loc, TypeRange{}, "trap", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{}); - - rewriter.eraseOp(op); - return success(); - } -}; - -// ============================================================================= -// 2. 
BindTileOp Lowering (FIX: Trace back to physical address) -// ============================================================================= -struct PTOBindTileToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - struct TileBuildSpec { - std::string tileTypeStr; - bool useConstructor = false; - SmallVector constructorArgs; - }; - - static bool getIndexConst(Value v, int64_t &out) { - if (!v) - return false; - if (auto cst = v.getDefiningOp()) { - if (auto ia = dyn_cast(cst.getValue())) { - out = ia.getValue().getSExtValue(); - return true; - } - } - return false; - } - - static bool getTilePointerStrides(pto::TileBufConfigAttr configAttr, - Type elemTy, int64_t rows, int64_t cols, - int64_t &rowStride, - int64_t &colStride) { - if (rows == ShapedType::kDynamic || cols == ShapedType::kDynamic) - return false; - - int32_t blVal = 0; - if (auto blAttr = dyn_cast(configAttr.getBLayout())) - blVal = static_cast(blAttr.getValue()); - else if (auto intAttr = dyn_cast(configAttr.getBLayout())) - blVal = static_cast(intAttr.getInt()); - - int32_t slVal = 0; - if (auto slAttr = dyn_cast(configAttr.getSLayout())) - slVal = static_cast(slAttr.getValue()); - else if (auto intAttr = dyn_cast(configAttr.getSLayout())) - slVal = static_cast(intAttr.getInt()); - - bool boxed = slVal != 0; - int64_t innerRows = 1; - int64_t innerCols = 1; - if (boxed) { - int32_t fractal = 512; - if (auto frAttr = dyn_cast(configAttr.getSFractalSize())) - fractal = static_cast(frAttr.getInt()); - - unsigned elemBytes = pto::getPTOStorageElemByteSize(elemTy); - if (elemBytes == 0) - return false; - - switch (fractal) { - case 1024: - innerRows = 16; - innerCols = 16; - break; - case 32: - innerRows = 16; - innerCols = 2; - break; - case 512: - if (slVal == 1) { - innerRows = 16; - innerCols = 32 / elemBytes; - } else if (slVal == 2) { - innerRows = 32 / elemBytes; - innerCols = 16; - } else { - return false; - } - break; - default: - return false; - } - if 
(innerRows <= 0 || innerCols <= 0) - return false; - } - - if (!boxed) { - if (blVal == 1) { - rowStride = 1; - colStride = rows; - } else { - rowStride = cols; - colStride = 1; - } - return true; - } - - if (blVal == 1) { - if (slVal != 1) - return false; - rowStride = innerCols; - colStride = rows; - return true; - } - - rowStride = cols; - colStride = innerRows; - return true; - } - - LogicalResult matchAndRewrite(pto::BindTileOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - auto configAttr = op.getConfigAttr(); - auto viewSemantics = op->getAttrOfType("pto.view_semantics"); - bool isSubView = viewSemantics && viewSemantics.getValue() == "subview"; - - auto peelAllCasts = [](Value v) { - while (auto castOp = v.getDefiningOp()) - v = castOp.getOperand(0); - if (auto castOp = v.getDefiningOp()) - v = castOp.getOperand(); - return v; - }; - auto isTileLike = [](Value v) -> bool { - auto ot = dyn_cast(v.getType()); - if (!ot) - return false; - StringRef s = ot.getValue(); - return s.contains("Tile<") || s.contains("ConvTile<"); - }; - auto buildTileSpec = [&]() -> FailureOr { - auto resMrTy = dyn_cast(op.getType()); - if (!resMrTy) - return failure(); - - const char *roleTok = "TileType::Vec"; - if (auto asAttr = - dyn_cast_or_null(resMrTy.getMemorySpace())) { - switch (asAttr.getAddressSpace()) { - case pto::AddressSpace::VEC: - roleTok = "TileType::Vec"; - break; - case pto::AddressSpace::MAT: - roleTok = "TileType::Mat"; - break; - case pto::AddressSpace::LEFT: - roleTok = "TileType::Left"; - break; - case pto::AddressSpace::RIGHT: - roleTok = "TileType::Right"; - break; - case pto::AddressSpace::ACC: - roleTok = "TileType::Acc"; - break; - case pto::AddressSpace::BIAS: - roleTok = "TileType::Bias"; - break; - case pto::AddressSpace::SCALING: - roleTok = "TileType::Scaling"; - break; - case pto::AddressSpace::GM: - case pto::AddressSpace::Zero: - roleTok = 
"TileType::Vec"; - break; - } - } - - Type elemTy = resMrTy.getElementType(); - Type emitElemTy = getTypeConverter()->convertType(elemTy); - if (!emitElemTy) - return failure(); - auto emitElemOpaque = dyn_cast(emitElemTy); - if (!emitElemOpaque) - return failure(); - std::string elemTypeStr = emitElemOpaque.getValue().str(); - - if (resMrTy.getRank() < 2) - return failure(); - int64_t rows = resMrTy.getDimSize(0); - int64_t cols = resMrTy.getDimSize(1); - if (rows == ShapedType::kDynamic || cols == ShapedType::kDynamic) - return failure(); - - std::string blTok = "BLayout::RowMajor"; - if (auto blAttr = dyn_cast(configAttr.getBLayout())) { - if (static_cast(blAttr.getValue()) == 1) - blTok = "BLayout::ColMajor"; - } - pto::BLayout blayout = getTileBufBLayoutValue(configAttr); - - if (isSubView) { - auto subMrTy = dyn_cast(op.getSource().getType()); - auto subViewOp = op.getSource().getDefiningOp(); - if (subMrTy && subMrTy.getRank() >= 2 && subViewOp) { - int64_t subRows = subMrTy.getDimSize(0); - int64_t subCols = subMrTy.getDimSize(1); - SmallVector inheritedStrides; - int64_t inheritedOffset = ShapedType::kDynamic; - - if (!pto::isPTOFloat4PackedType(elemTy) && - subRows != ShapedType::kDynamic && - subCols != ShapedType::kDynamic && - succeeded(getStridesAndOffset(subMrTy, inheritedStrides, - inheritedOffset)) && - inheritedStrides.size() >= 2) { - int64_t childRowStride = 0; - int64_t childColStride = 0; - bool sameStrides = getTilePointerStrides( - configAttr, elemTy, subRows, subCols, childRowStride, - childColStride); - sameStrides = sameStrides && - inheritedStrides[0] == childRowStride && - inheritedStrides[1] == childColStride; - if (sameStrides) { - rows = subRows; - cols = subCols; - } - } - } - } - - std::string slTok = "SLayout::NoneBox"; - if (auto slAttr = dyn_cast(configAttr.getSLayout())) { - int32_t slVal = static_cast(slAttr.getValue()); - slTok = (slVal == 1) ? "SLayout::RowMajor" - : (slVal == 2) ? 
"SLayout::ColMajor" - : "SLayout::NoneBox"; - } - - int32_t fractal = 512; - if (auto frAttr = dyn_cast(configAttr.getSFractalSize())) - fractal = frAttr.getInt(); - - std::string padTok = "PadValue::Null"; - if (auto padAttr = dyn_cast(configAttr.getPad())) { - switch (static_cast(padAttr.getValue())) { - case 1: - padTok = "PadValue::Zero"; - break; - case 2: - padTok = "PadValue::Max"; - break; - case 3: - padTok = "PadValue::Min"; - break; - default: - padTok = "PadValue::Null"; - break; - } - } - - std::string compactTok = "CompactMode::Null"; - if (auto compactAttr = dyn_cast(configAttr.getCompactMode())) { - switch (static_cast(compactAttr.getValue())) { - case 1: - compactTok = "CompactMode::Normal"; - break; - case 2: - compactTok = "CompactMode::RowPlusOne"; - break; - default: - compactTok = "CompactMode::Null"; - break; - } - } - - std::string vrowTok, vcolTok; - bool useConstructor = false; - bool rowIsDynamic = false; - bool colIsDynamic = false; - SmallVector constructorArgs; - - Value vRow = op.getValidRow(); - Value vCol = op.getValidCol(); - Value vRowEmitC = adaptor.getValidRow(); - Value vColEmitC = adaptor.getValidCol(); - bool forceDynamicValid = op->hasAttr(kForceDynamicValidShapeAttrName); - int64_t cRow = 0, cCol = 0; - bool rowIsConst = vRow && getIndexConst(vRow, cRow); - bool colIsConst = vCol && getIndexConst(vCol, cCol); - - auto makeCtorDimValue = [&](Value emitted, int64_t fallback) -> Value { - if (emitted) - return emitted; - return makeEmitCIntConstant( - rewriter, loc, emitc::OpaqueType::get(ctx, "int32_t"), fallback); - }; - auto maybeScaleDynamicValid = [&](Value emitted, int dimIdx) -> Value { - if (!emitted || !pto::isPTOFloat4PackedType(elemTy)) - return emitted; - int packedDim = blayout == pto::BLayout::ColMajor ? 
0 : 1; - if (dimIdx != packedDim) - return emitted; - auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); - Value two = makeEmitCIntConstant(rewriter, loc, i32Ty, 2); - return rewriter.create(loc, i32Ty, emitted, two).getResult(); - }; - - if (forceDynamicValid) { - vrowTok = "-1"; - vcolTok = "-1"; - useConstructor = true; - constructorArgs.push_back( - makeCtorDimValue(maybeScaleDynamicValid(vRowEmitC, 0), - renderTileTemplateDim(rowIsConst ? cRow : rows, - elemTy, blayout, 0))); - constructorArgs.push_back( - makeCtorDimValue(maybeScaleDynamicValid(vColEmitC, 1), - renderTileTemplateDim(colIsConst ? cCol : cols, - elemTy, blayout, 1))); - } else { - if (rowIsConst) { - vrowTok = std::to_string( - renderTileTemplateDim(cRow, elemTy, blayout, 0)); - } else if (vRow) { - vrowTok = "-1"; - rowIsDynamic = true; - useConstructor = true; - } else { - vrowTok = std::to_string( - renderTileTemplateDim(rows, elemTy, blayout, 0)); - } - - if (colIsConst) { - vcolTok = std::to_string( - renderTileTemplateDim(cCol, elemTy, blayout, 1)); - } else if (vCol) { - vcolTok = "-1"; - colIsDynamic = true; - useConstructor = true; - } else { - vcolTok = std::to_string( - renderTileTemplateDim(cols, elemTy, blayout, 1)); - } - - if (useConstructor) { - if (rowIsDynamic && vRowEmitC) - constructorArgs.push_back(maybeScaleDynamicValid(vRowEmitC, 0)); - if (colIsDynamic && vColEmitC) - constructorArgs.push_back(maybeScaleDynamicValid(vColEmitC, 1)); - } - } - - std::string tileTypeStr = std::string("Tile<") + roleTok + ", " + - elemTypeStr + ", " + - std::to_string(renderTileTemplateDim( - rows, elemTy, blayout, 0)) + - ", " + - std::to_string(renderTileTemplateDim( - cols, elemTy, blayout, 1)) + - ", " + blTok + - ", " + vrowTok + ", " + vcolTok + ", " + slTok + - ", " + std::to_string(fractal) + ", " + padTok + - ", " + compactTok + - ">"; - return TileBuildSpec{tileTypeStr, useConstructor, constructorArgs}; - }; - - auto buildTileValue = [&](const TileBuildSpec &spec, - bool 
forceDeclaration = false) -> Value { - auto tileType = emitc::OpaqueType::get(ctx, spec.tileTypeStr); - if (spec.useConstructor && !forceDeclaration) { - return rewriter - .create(loc, tileType, spec.tileTypeStr, - ArrayAttr{}, ArrayAttr{}, - ValueRange(spec.constructorArgs)) - .getResult(0); - } - - return rewriter - .create(loc, tileType, emitc::OpaqueAttr::get(ctx, "")) - .getResult(); - }; - - auto emitElemTypeToString = [&](Type elemTy) -> std::string { - return getEmitCScalarTypeToken(elemTy); - }; - - auto buildIntegralAddress = [&](Value sourceValue) -> FailureOr { - auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); - auto rcU64 = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); - - Value rawPtr = sourceValue; - if (auto ot = dyn_cast(sourceValue.getType())) { - StringRef tyStr = ot.getValue(); - if (tyStr.contains("Tile<") || tyStr.contains("ConvTile<")) { - auto srcMrTy = dyn_cast(op.getSource().getType()); - if (!srcMrTy) - return failure(); - std::string elemTok = emitElemTypeToString(srcMrTy.getElementType()); - pto::AddressSpace as = pto::AddressSpace::GM; - if (auto asAttr = - dyn_cast_or_null(srcMrTy.getMemorySpace())) - as = asAttr.getAddressSpace(); - rawPtr = materializeTileDataValue(rewriter, loc, sourceValue, as, - elemTok); - } - } - - if (isSetFFTsPointerLikeType(rawPtr.getType())) { - return rewriter - .create(loc, u64Ty, "reinterpret_cast", - ArrayAttr{}, rcU64, ValueRange{rawPtr}) - .getResult(0); - } - - if (rawPtr.getType() == u64Ty) - return rawPtr; - return rewriter.create(loc, u64Ty, rawPtr).getResult(); - }; - - if (op.getSource().getDefiningOp()) { - FailureOr tileSpec = buildTileSpec(); - if (failed(tileSpec)) - return failure(); - rewriter.replaceOp(op, buildTileValue(*tileSpec)); - return success(); - } - - Value tileCandidate = peelAllCasts(adaptor.getSource()); - if (viewSemantics && viewSemantics.getValue() == "bitcast" && - isTileLike(tileCandidate)) { - FailureOr tileSpec = buildTileSpec(); - if 
(failed(tileSpec)) - return failure(); - Value dstTile = buildTileValue(*tileSpec); - FailureOr addr = buildIntegralAddress(tileCandidate); - if (failed(addr)) - return failure(); - - rewriter.create(loc, TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dstTile, *addr}); - rewriter.replaceOp(op, dstTile); - return success(); - } - - if (viewSemantics && viewSemantics.getValue() == "treshape" && - isTileLike(tileCandidate)) { - FailureOr tileSpec = buildTileSpec(); - if (failed(tileSpec)) - return failure(); - Value dstTile = buildTileValue(*tileSpec, /*forceDeclaration=*/true); - - rewriter.create(loc, TypeRange{}, "TRESHAPE", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dstTile, tileCandidate}); - rewriter.replaceOp(op, dstTile); - return success(); - } - - // Subview origins are kept distinct from generic tile rebinding: - // even when source/destination C++ tile types match, subview may carry - // shifted base address semantics and should materialize a fresh handle. - if (isSubView) { - FailureOr tileSpec = buildTileSpec(); - if (failed(tileSpec)) - return failure(); - Value dstTile = buildTileValue(*tileSpec); - FailureOr addr = buildIntegralAddress(tileCandidate); - if (failed(addr)) - return failure(); - - rewriter.create(loc, TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dstTile, *addr}); - rewriter.replaceOp(op, dstTile); - return success(); - } - - // Generic tile-to-tile rebind path: preserve the same backing storage and - // rebuild a sibling tile with updated metadata/valid dims. 
- if (isTileLike(tileCandidate)) { - FailureOr tileSpec = buildTileSpec(); - if (failed(tileSpec)) - return failure(); - - if (!tileSpec->useConstructor) { - if (auto srcTy = dyn_cast(tileCandidate.getType())) { - if (srcTy.getValue() == tileSpec->tileTypeStr) { - rewriter.replaceOp(op, tileCandidate); - return success(); - } - } - } - - Value dstTile = buildTileValue(*tileSpec); - FailureOr addr = buildIntegralAddress(tileCandidate); - if (failed(addr)) - return failure(); - - rewriter.create(loc, TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dstTile, *addr}); - rewriter.replaceOp(op, dstTile); - return success(); - } - - SmallVector physAddrs; - Value source = op.getSource(); - - while (auto castOp = source.getDefiningOp()) - source = castOp.getOperand(0); - - if (auto upstreamCast = source.getDefiningOp()) { - auto upstreamOperands = upstreamCast.getAddrs(); - physAddrs.append(upstreamOperands.begin(), upstreamOperands.end()); - } else { - physAddrs.push_back(adaptor.getSource()); - } - - Value vRow = op.getValidRow(); - Value vCol = op.getValidCol(); - - auto newCast = rewriter.create( - loc, op.getType(), physAddrs, vRow ? vRow : Value(), - vCol ? 
vCol : Value(), configAttr); - if (viewSemantics) - newCast->setAttr("pto.view_semantics", viewSemantics); - if (op->hasAttr(kForceDynamicValidShapeAttrName)) - newCast->setAttr(kForceDynamicValidShapeAttrName, - op->getAttr(kForceDynamicValidShapeAttrName)); - rewriter.replaceOp(op, newCast.getResult()); - - return success(); - } -}; - -struct PTOAllocTileToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::AllocTileOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - MLIRContext *ctx = rewriter.getContext(); - auto tileTy = cast(op.getResult().getType()); - auto tileTypeString = getEmitCTileTypeString(tileTy); - if (!tileTypeString) - return rewriter.notifyMatchFailure( - op, "only rank-2 alloc_tile handles can be converted to EmitC"); - - Type convertedTy = getTypeConverter()->convertType(tileTy); - if (!convertedTy) - convertedTy = emitc::OpaqueType::get(ctx, *tileTypeString); - - auto validShape = tileTy.getValidShape(); - bool hasDynamicValidDim = - llvm::any_of(validShape, [](int64_t dim) { return dim < 0; }); - bool useConstructor = hasDynamicValidDim; - - SmallVector constructorArgs; - if (useConstructor) { - Type elemTy = tileTy.getElementType(); - pto::BLayout blayout = getTileBufBLayoutValue(tileTy.getConfigAttr()); - auto maybeScaleDynamicValid = [&](Value emitted, int dimIdx) -> Value { - if (!emitted || !pto::isPTOFloat4PackedType(elemTy)) - return emitted; - int packedDim = blayout == pto::BLayout::ColMajor ? 
0 : 1; - if (dimIdx != packedDim) - return emitted; - auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); - Value two = makeEmitCIntConstant(rewriter, loc, i32Ty, 2); - return rewriter.create(loc, i32Ty, emitted, two) - .getResult(); - }; - - if (validShape.size() > 0 && validShape[0] < 0) { - Value validRow = adaptor.getValidRow(); - if (!validRow) - return rewriter.notifyMatchFailure( - op, "dynamic alloc_tile valid row must have an operand"); - if (validRow) - validRow = peelUnrealized(validRow); - constructorArgs.push_back(maybeScaleDynamicValid(validRow, 0)); - } - if (validShape.size() > 1 && validShape[1] < 0) { - Value validCol = adaptor.getValidCol(); - if (!validCol) - return rewriter.notifyMatchFailure( - op, "dynamic alloc_tile valid col must have an operand"); - if (validCol) - validCol = peelUnrealized(validCol); - constructorArgs.push_back(maybeScaleDynamicValid(validCol, 1)); - } - } - - Value tile; - if (useConstructor) { - tile = rewriter - .create( - loc, convertedTy, *tileTypeString, ArrayAttr{}, - ArrayAttr{}, ValueRange(constructorArgs)) - .getResult(0); - } else { - tile = - rewriter - .create( - loc, convertedTy, emitc::OpaqueAttr::get(ctx, "")) - .getResult(); - } - - Value addr = adaptor.getAddr(); - if (addr) { - addr = peelUnrealized(addr); - auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); - if (isa(addr.getType()) || - (isa(addr.getType()) && - cast(addr.getType()).getValue().ends_with("*"))) { - auto rcU64 = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); - addr = rewriter - .create(loc, u64Ty, "reinterpret_cast", - ArrayAttr{}, rcU64, - ValueRange{addr}) - .getResult(0); - } else if (addr.getType() != u64Ty) { - addr = rewriter.create(loc, u64Ty, addr).getResult(); - } - - rewriter.create(loc, TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{tile, addr}); - } - - rewriter.replaceOp(op, tile); - return success(); - } -}; - -static FailureOr -createEmitCTileVariable(ConversionPatternRewriter 
&rewriter, Location loc, - const TypeConverter *typeConverter, - pto::TileBufType tileTy) { - auto tileTypeString = getEmitCTileTypeString(tileTy); - if (!tileTypeString) - return failure(); - - Type convertedTy = typeConverter->convertType(tileTy); - if (!convertedTy) - convertedTy = emitc::OpaqueType::get(rewriter.getContext(), *tileTypeString); - - return rewriter - .create( - loc, convertedTy, emitc::OpaqueAttr::get(rewriter.getContext(), "")) - .getResult(); -} - -struct PTOTReshapeToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TReshapeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto tileTy = dyn_cast(op.getResult().getType()); - if (!tileTy) - return failure(); - - FailureOr dst = - createEmitCTileVariable(rewriter, op.getLoc(), getTypeConverter(), tileTy); - if (failed(dst)) - return failure(); - - Value src = peelUnrealized(adaptor.getSrc()); - if (auto castOp = src.getDefiningOp()) - src = castOp.getOperand(); - - rewriter.create(op.getLoc(), TypeRange{}, "TRESHAPE", - ArrayAttr{}, ArrayAttr{}, - ValueRange{*dst, src}); - rewriter.replaceOp(op, *dst); - return success(); - } -}; - -struct PTOBitcastToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::BitcastOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto dstTy = dyn_cast(op.getResult().getType()); - auto srcTy = dyn_cast(op.getSrc().getType()); - if (!dstTy || !srcTy) - return failure(); - - FailureOr dst = - createEmitCTileVariable(rewriter, op.getLoc(), getTypeConverter(), dstTy); - if (failed(dst)) - return failure(); - - Value src = peelUnrealized(adaptor.getSrc()); - if (auto castOp = src.getDefiningOp()) - src = castOp.getOperand(); - - pto::AddressSpace as = pto::AddressSpace::GM; - if (auto asAttr = - dyn_cast_or_null(srcTy.getMemorySpace())) - as = 
asAttr.getAddressSpace(); - std::string elemTok = getEmitCScalarTypeToken(srcTy.getElementType()); - - Value rawPtr = materializeTileDataValue(rewriter, op.getLoc(), src, as, elemTok); - auto u64Ty = emitc::OpaqueType::get(rewriter.getContext(), "uint64_t"); - Value addr = rawPtr; - if (isSetFFTsPointerLikeType(rawPtr.getType())) { - auto rcU64 = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(rewriter.getContext(), - "uint64_t")}); - addr = rewriter - .create(op.getLoc(), u64Ty, - "reinterpret_cast", ArrayAttr{}, - rcU64, ValueRange{rawPtr}) - .getResult(0); - } else if (addr.getType() != u64Ty) { - addr = rewriter.create(op.getLoc(), u64Ty, addr).getResult(); - } - - rewriter.create(op.getLoc(), TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{*dst, addr}); - rewriter.replaceOp(op, *dst); - return success(); - } -}; - -struct PTOMaterializeTileToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - static bool isTileLike(Value v) { - auto ot = dyn_cast(v.getType()); - if (!ot) - return false; - StringRef s = ot.getValue(); - return s.contains("Tile<") || s.contains("ConvTile<"); - } - - LogicalResult matchAndRewrite(pto::MaterializeTileOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - MLIRContext *ctx = rewriter.getContext(); - auto tileTy = cast(op.getResult().getType()); - auto tileTypeString = getEmitCTileTypeString(tileTy); - if (!tileTypeString) - return rewriter.notifyMatchFailure( - op, "only rank-2 tile_buf handles can be materialized to EmitC"); - - Type convertedTy = getTypeConverter()->convertType(tileTy); - if (!convertedTy) - convertedTy = emitc::OpaqueType::get(ctx, *tileTypeString); - - Value source = peelUnrealized(adaptor.getSource()); - if (auto castOp = source.getDefiningOp()) - source = castOp.getOperand(); - - auto viewSemantics = op->getAttrOfType("pto.view_semantics"); - bool forceDynamicValid = 
op->hasAttr(kForceDynamicValidShapeAttrName); - bool isReshape = viewSemantics && viewSemantics.getValue() == "treshape"; - bool isSubview = viewSemantics && viewSemantics.getValue() == "subview"; - bool sourceIsDeclaredTile = - op.getSource().getDefiningOp(); - - auto createTileValue = [&]() -> Value { - SmallVector constructorArgs; - bool useConstructor = false; - pto::BLayout blayout = getTileBufBLayoutValue(tileTy.getConfigAttr()); - Type elemTy = tileTy.getElementType(); - auto shape = tileTy.getShape(); - auto validShape = tileTy.getValidShape(); - - auto makeCtorDimValue = [&](Value emitted, int64_t fallback) -> Value { - if (emitted) - return emitted; - return makeEmitCIntConstant( - rewriter, loc, emitc::OpaqueType::get(ctx, "int32_t"), fallback); - }; - auto maybeScaleDynamicValid = [&](Value emitted, int dimIdx) -> Value { - if (!emitted || !pto::isPTOFloat4PackedType(elemTy)) - return emitted; - int packedDim = blayout == pto::BLayout::ColMajor ? 0 : 1; - if (dimIdx != packedDim) - return emitted; - auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); - Value two = makeEmitCIntConstant(rewriter, loc, i32Ty, 2); - return rewriter.create(loc, i32Ty, emitted, two).getResult(); - }; - auto fallbackDim = [&](int dimIdx) { - return renderTileTemplateDim(shape[dimIdx], elemTy, blayout, dimIdx); - }; - - if (forceDynamicValid) { - useConstructor = true; - constructorArgs.push_back(makeCtorDimValue( - maybeScaleDynamicValid(adaptor.getValidRow(), 0), fallbackDim(0))); - constructorArgs.push_back(makeCtorDimValue( - maybeScaleDynamicValid(adaptor.getValidCol(), 1), fallbackDim(1))); - } else { - if (validShape[0] == ShapedType::kDynamic) { - useConstructor = true; - constructorArgs.push_back(makeCtorDimValue( - maybeScaleDynamicValid(adaptor.getValidRow(), 0), fallbackDim(0))); - } - if (validShape[1] == ShapedType::kDynamic) { - useConstructor = true; - constructorArgs.push_back(makeCtorDimValue( - maybeScaleDynamicValid(adaptor.getValidCol(), 1), 
fallbackDim(1))); - } - } - - if (useConstructor) { - return rewriter - .create(loc, convertedTy, *tileTypeString, - ArrayAttr{}, ArrayAttr{}, - ValueRange(constructorArgs)) - .getResult(0); - } - - return rewriter - .create(loc, convertedTy, - emitc::OpaqueAttr::get(ctx, "")) - .getResult(); - }; - - if (!isSubview && !forceDynamicValid && isTileLike(source)) { - if (auto srcTy = dyn_cast(source.getType())) { - if (srcTy.getValue() == *tileTypeString) { - rewriter.replaceOp(op, source); - return success(); - } - } - } - - Value tile = createTileValue(); - if (sourceIsDeclaredTile) { - rewriter.replaceOp(op, tile); - return success(); - } - - if (isReshape && isTileLike(source)) { - rewriter.create(loc, TypeRange{}, "TRESHAPE", - ArrayAttr{}, ArrayAttr{}, - ValueRange{tile, source}); - rewriter.replaceOp(op, tile); - return success(); - } - - pto::AddressSpace as = pto::AddressSpace::GM; - if (auto asAttr = - dyn_cast_or_null(tileTy.getMemorySpace())) - as = asAttr.getAddressSpace(); - std::string elemTok = getEmitCScalarTypeToken(tileTy.getElementType()); - - Value rawPtr = source; - if (isTileLike(rawPtr)) - rawPtr = materializeTileDataValue(rewriter, loc, rawPtr, as, elemTok); - - auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); - Value addr = rawPtr; - if (isSetFFTsPointerLikeType(rawPtr.getType())) { - auto rcU64 = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); - addr = rewriter - .create(loc, u64Ty, "reinterpret_cast", - ArrayAttr{}, rcU64, - ValueRange{rawPtr}) - .getResult(0); - } else if (rawPtr.getType() != u64Ty) { - addr = rewriter.create(loc, u64Ty, rawPtr).getResult(); - } - - rewriter.create(loc, TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{tile, addr}); - rewriter.replaceOp(op, tile); - return success(); - } -}; - -// ============================================================================= -// Arith CmpI -> EmitC Cmp -// ============================================================================= 
-class ArithCmpIToEmitC : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - // 将 arith.cmpi 转换为 emitc.cmp - // 映射 Predicate: eq -> equal, slt -> less, etc. - emitc::CmpPredicate emitcPred = emitc::CmpPredicate::eq; - const bool isUnsignedPred = - op.getPredicate() == arith::CmpIPredicate::ult || - op.getPredicate() == arith::CmpIPredicate::ule || - op.getPredicate() == arith::CmpIPredicate::ugt || - op.getPredicate() == arith::CmpIPredicate::uge; - switch (op.getPredicate()) { - case arith::CmpIPredicate::eq: emitcPred = emitc::CmpPredicate::eq; break; - case arith::CmpIPredicate::ne: emitcPred = emitc::CmpPredicate::ne; break; - case arith::CmpIPredicate::slt: emitcPred = emitc::CmpPredicate::lt; break; - case arith::CmpIPredicate::sle: emitcPred = emitc::CmpPredicate::le; break; - case arith::CmpIPredicate::sgt: emitcPred = emitc::CmpPredicate::gt; break; - case arith::CmpIPredicate::sge: emitcPred = emitc::CmpPredicate::ge; break; - // ... 处理无符号比较 (ult, ule 等) ... - case arith::CmpIPredicate::ult: emitcPred = emitc::CmpPredicate::lt; break; - case arith::CmpIPredicate::ule: emitcPred = emitc::CmpPredicate::le; break; - case arith::CmpIPredicate::ugt: emitcPred = emitc::CmpPredicate::gt; break; - case arith::CmpIPredicate::uge: emitcPred = emitc::CmpPredicate::ge; break; - } - - Type resTy = getTypeConverter()->convertType(op.getType()); - if (!resTy) - return failure(); - - Value lhs = adaptor.getLhs(); - Value rhs = adaptor.getRhs(); - if (isUnsignedPred) { - Type opTy = op.getLhs().getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure( - op, "expected scalar integer or index operands"); - - const unsigned bitWidth = - intTy ? 
intTy.getWidth() : static_cast(kPTOIndexBitWidth); - if (bitWidth != 1) { - lhs = castSignlessIntToUnsignedSameWidth(rewriter, loc, lhs, bitWidth); - rhs = castSignlessIntToUnsignedSameWidth(rewriter, loc, rhs, bitWidth); - } - } - - rewriter.replaceOpWithNewOp( - op, - /*resultType=*/resTy, // i1 -> bool/i1 - emitcPred, - lhs, - rhs - ); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// Section Op Lowering -//===----------------------------------------------------------------------===// -static bool isA5NoSplitPipeOp(Operation *op) { - if (auto talloc = dyn_cast(op)) - return talloc.getSplit() == 0; - if (auto tpush = dyn_cast(op)) - return tpush.getSplit() == 0; - if (auto tpop = dyn_cast(op)) - return tpop.getSplit() == 0; - if (auto tfree = dyn_cast(op)) - return tfree.getSplit() == 0; - if (auto tpush = dyn_cast(op)) - return tpush.getSplit() == 0; - if (auto tpush = dyn_cast(op)) - return tpush.getSplit() == 0; - if (auto talloc = dyn_cast(op)) - return talloc.getSplit() == 0; - if (auto talloc = dyn_cast(op)) - return talloc.getSplit() == 0; - if (auto tpop = dyn_cast(op)) - return tpop.getSplit() == 0; - if (auto tpop = dyn_cast(op)) - return tpop.getSplit() == 0; - if (auto tfree = dyn_cast(op)) - return tfree.getSplit() == 0; - if (auto tfree = dyn_cast(op)) - return tfree.getSplit() == 0; - return false; -} - -static bool hasExplicitSubblockControl(Operation *op) { - bool hasControl = false; - op->walk([&](Operation *nested) { - if (isa(nested)) { - hasControl = true; - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - return hasControl; -} - -static bool needsA5NoSplitVectorGuard(Operation *op) { - auto arch = getTargetArch(op); - if (arch != PTOArch::A5) - return false; - bool isVectorScope = isa(op); - if (auto func = dyn_cast(op)) { - if (auto kernelKindAttr = - func->getAttrOfType( - FunctionKernelKindAttr::name)) { - isVectorScope = - 
kernelKindAttr.getKernelKind() == FunctionKernelKind::Vector; - } - } - if (!isVectorScope) - return false; - if (hasExplicitSubblockControl(op)) - return false; - - bool hasNoSplitPipe = false; - op->walk([&](Operation *nested) { - if (!isA5NoSplitPipeOp(nested)) - return WalkResult::advance(); - hasNoSplitPipe = true; - return WalkResult::interrupt(); - }); - return hasNoSplitPipe; -} - -template -struct SectionToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - std::string getMacroName() const { - if (std::is_same::value) - return "__DAV_CUBE__"; - if (std::is_same::value) - return "__DAV_VEC__"; - return "UNKNOWN_MACRO"; - } - - LogicalResult - matchAndRewrite(SectionOpTy op, typename SectionOpTy::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - bool needsNoSplitGuard = needsA5NoSplitVectorGuard(op.getOperation()); - - std::string startMacro = "\n#if defined(" + getMacroName() + ")"; - rewriter.create(loc, startMacro); - - if constexpr (std::is_same_v) { - // Vector mask is a global HW state and may be modified by previous kernels - // (or earlier sections). Reset it to a well-defined state for deterministic - // execution of VEC ops. 
- rewriter.create(loc, "set_mask_norm();"); - rewriter.create(loc, "set_vector_mask(-1, -1);"); - } - - if (needsNoSplitGuard) { - rewriter.create( - loc, "if (get_subblockid() == 0) {"); - } - - Block &innerBlock = op.getBody().front(); - if (!innerBlock.empty()) { - rewriter.inlineBlockBefore(&innerBlock, op.getOperation(), ValueRange{}); - } - - if (needsNoSplitGuard) - rewriter.create(loc, "}"); - - std::string endMacro = "#endif // " + getMacroName() + "\n"; - rewriter.create(loc, endMacro); - - rewriter.eraseOp(op); - - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// SCF Control-Flow Pre-Lowering -// -// EmitC translation supports `emitc.for`/`emitc.if` plus CFG-style -// `cf.br`/`cf.cond_br`. Upstream SCFToEmitC patterns only cover `scf.for` and -// `scf.if`, so we pre-lower some SCF ops into those supported forms. -//===----------------------------------------------------------------------===// - -namespace { - -static bool isTriviallyInlineableExecuteRegion(scf::ExecuteRegionOp op) { - Region &r = op.getRegion(); - if (!r.hasOneBlock()) - return false; - Block &b = r.front(); - return isa_and_nonnull(b.getTerminator()); -} - -static bool needsWholeFunctionSCFToCF(func::FuncOp func) { - bool needs = false; - func.walk([&](Operation *op) { - if (!isa(op)) - return WalkResult::advance(); - Operation *parentOp = op->getParentOp(); - - // `scf.execute_region` can legally appear in single-block parents. Only - // require whole-function SCFToCF if we need to lower it into CFG blocks - // (multi-block region / non-trivial terminators). 
- if (auto exec = dyn_cast(op)) { - if (parentOp && parentOp->hasTrait() && - !isTriviallyInlineableExecuteRegion(exec)) { - needs = true; - return WalkResult::interrupt(); - } - return WalkResult::advance(); - } - - if (parentOp && parentOp->hasTrait()) { - needs = true; - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - return needs; -} - -// scf.execute_region is semantically just an inlined region producing results -// via scf.yield. Inline it to the parent block to avoid extra lowering needs. -struct SCFExecuteRegionInline - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(scf::ExecuteRegionOp op, - PatternRewriter &rewriter) const override { - if (op.getRegion().empty()) - return rewriter.notifyMatchFailure(op, "expected non-empty region"); - - Block &innerBlock = op.getRegion().front(); - auto yield = dyn_cast(innerBlock.getTerminator()); - if (!yield) - return rewriter.notifyMatchFailure(op, "expected scf.yield terminator"); - - // Move the body operations before the execute_region op. - rewriter.inlineBlockBefore(&innerBlock, op.getOperation(), ValueRange{}); - - // Replace execute_region results with yielded values, then erase the yield. - rewriter.replaceOp(op, yield.getOperands()); - rewriter.eraseOp(yield); - return success(); - } -}; - -// Lower scf.execute_region into CFG blocks with cf.br/cf.cond_br by inlining the -// region blocks into the parent region and rewriting scf.yield to branch into a -// continuation block carrying results. -// -// Note: This requires the parent region to allow multiple blocks (e.g. the -// function body CFG region). For execute_region nested in single-block regions -// (scf.for/scf.if), run SCFToCF first to eliminate the single-block constraint. 
-struct SCFExecuteRegionToCF : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(scf::ExecuteRegionOp op, - PatternRewriter &rewriter) const override { - if (isTriviallyInlineableExecuteRegion(op)) - return rewriter.notifyMatchFailure(op, "trivially inlineable"); - - Operation *parentOp = op->getParentOp(); - if (parentOp && parentOp->hasTrait()) { - return rewriter.notifyMatchFailure( - op, "cannot lower scf.execute_region inside a single-block parent region"); - } - - if (op.getRegion().empty()) - return rewriter.notifyMatchFailure(op, "expected non-empty region"); - - Location loc = op.getLoc(); - Block *curBlock = op->getBlock(); - Region *parentRegion = curBlock->getParent(); - - // Split the parent block so we can branch to a continuation block with phi - // arguments for the execute_region results. - auto execIt = Block::iterator(op.getOperation()); - Block *continueBlock = rewriter.splitBlock(curBlock, std::next(execIt)); - - SmallVector contArgs; - contArgs.reserve(op.getNumResults()); - for (Type t : op.getResultTypes()) - contArgs.push_back(continueBlock->addArgument(t, loc)); - - for (auto it : llvm::enumerate(op.getResults())) - it.value().replaceAllUsesWith(contArgs[it.index()]); - - // Capture blocks before moving the region. - SmallVector movedBlocks; - movedBlocks.reserve(op.getRegion().getBlocks().size()); - for (Block &b : op.getRegion()) - movedBlocks.push_back(&b); - Block *entryBlock = &op.getRegion().front(); - - // Inline the execute_region blocks into the parent region right before the - // continuation block. - rewriter.inlineRegionBefore(op.getRegion(), *parentRegion, - continueBlock->getIterator()); - - // Replace all scf.yield terminators with a branch to the continuation. 
- for (Block *b : movedBlocks) { - auto yield = dyn_cast(b->getTerminator()); - if (!yield) - continue; - rewriter.setInsertionPoint(yield); - rewriter.create(loc, continueBlock, yield.getOperands()); - rewriter.eraseOp(yield); - } - - // Replace execute_region itself with a branch to the inlined entry block. - rewriter.setInsertionPoint(op); - rewriter.create(loc, entryBlock, ValueRange{}); - rewriter.eraseOp(op); - return success(); - } -}; - -// Lower scf.index_switch into CFG blocks with cf.cond_br/cf.br so that we can -// avoid `scf.if` result materialization quirks (and avoid relying on cf.switch, -// which is not supported by EmitC C++ translation). -struct SCFIndexSwitchToCF : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - static LogicalResult cloneYieldingBlockAndBranchTo( - PatternRewriter &rewriter, Location loc, Block &srcBlock, Block *destBlock, - Block *continueBlock) { - rewriter.setInsertionPointToEnd(destBlock); - - IRMapping mapping; - for (Operation &inner : srcBlock.without_terminator()) - rewriter.clone(inner, mapping); - - auto yield = dyn_cast(srcBlock.getTerminator()); - if (!yield) - return failure(); - - SmallVector yieldOperands; - yieldOperands.reserve(yield.getNumOperands()); - for (Value v : yield.getOperands()) - yieldOperands.push_back(mapping.lookupOrDefault(v)); - - rewriter.create(loc, continueBlock, yieldOperands); - return success(); - } - - static Block *splitBlockForContinuation(PatternRewriter &rewriter, - scf::IndexSwitchOp op) { - auto switchIt = Block::iterator(op.getOperation()); - return rewriter.splitBlock(op->getBlock(), std::next(switchIt)); - } - - static void addContinuationArguments(PatternRewriter &rewriter, - scf::IndexSwitchOp op, Location loc, - Block *continueBlock) { - SmallVector contArgs; - contArgs.reserve(op.getNumResults()); - for (Type type : op.getResultTypes()) - contArgs.push_back(continueBlock->addArgument(type, loc)); - for (auto result : llvm::enumerate(op.getResults())) 
- result.value().replaceAllUsesWith(contArgs[result.index()]); - } - - static void createIndexSwitchBlocks(PatternRewriter &rewriter, - Region *parentRegion, - Region::iterator insertPt, - unsigned numCases, - SmallVectorImpl &checkBlocks, - Block *&defaultBlock, - SmallVectorImpl &caseBlocks) { - checkBlocks.reserve(numCases); - caseBlocks.reserve(numCases); - for (unsigned i = 0; i < numCases; ++i) - checkBlocks.push_back(rewriter.createBlock(parentRegion, insertPt)); - defaultBlock = rewriter.createBlock(parentRegion, insertPt); - for (unsigned i = 0; i < numCases; ++i) - caseBlocks.push_back(rewriter.createBlock(parentRegion, insertPt)); - } - - static void populateIndexSwitchCheckBlocks( - PatternRewriter &rewriter, Location loc, Value selector, - ArrayRef cases, ArrayRef checkBlocks, - ArrayRef caseBlocks, Block *defaultBlock) { - for (unsigned i = 0; i < checkBlocks.size(); ++i) { - rewriter.setInsertionPointToEnd(checkBlocks[i]); - Value caseVal = rewriter.create(loc, cases[i]); - Value cond = rewriter.create( - loc, arith::CmpIPredicate::eq, selector, caseVal); - Block *falseDest = - (i + 1 < checkBlocks.size()) ? 
checkBlocks[i + 1] : defaultBlock; - rewriter.create(loc, cond, caseBlocks[i], ValueRange{}, - falseDest, ValueRange{}); - } - } - - LogicalResult matchAndRewrite(scf::IndexSwitchOp op, - PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - Operation *parentOp = op->getParentOp(); - if (parentOp && parentOp->hasTrait()) { - return rewriter.notifyMatchFailure( - op, "cannot lower scf.index_switch inside a single-block parent region"); - } - - Block *curBlock = op->getBlock(); - Region *parentRegion = curBlock->getParent(); - Block *continueBlock = splitBlockForContinuation(rewriter, op); - addContinuationArguments(rewriter, op, loc, continueBlock); - - unsigned numCases = op.getCases().size(); - auto insertPt = continueBlock->getIterator(); - - SmallVector checkBlocks; - SmallVector caseBlocks; - Block *defaultBlock = nullptr; - createIndexSwitchBlocks(rewriter, parentRegion, insertPt, numCases, - checkBlocks, defaultBlock, caseBlocks); - - Value selector = op.getArg(); - auto cases = op.getCases(); - populateIndexSwitchCheckBlocks(rewriter, loc, selector, cases, checkBlocks, - caseBlocks, defaultBlock); - - // Fill case blocks and default block with cloned bodies + branch to cont. - for (unsigned i = 0; i < numCases; ++i) { - if (failed(cloneYieldingBlockAndBranchTo( - rewriter, loc, op.getCaseBlock(i), caseBlocks[i], continueBlock))) - return rewriter.notifyMatchFailure(op, "expected scf.yield terminator"); - } - if (failed(cloneYieldingBlockAndBranchTo(rewriter, loc, op.getDefaultBlock(), - defaultBlock, continueBlock))) - return rewriter.notifyMatchFailure(op, "expected scf.yield terminator"); - - // Replace the original switch op with a branch into the check chain. - Block *entryDest = numCases ? checkBlocks[0] : defaultBlock; - rewriter.setInsertionPointAfter(op); - rewriter.create(loc, entryDest, ValueRange{}); - rewriter.eraseOp(op); - return success(); - } -}; - -// Lower scf.while into CFG blocks with cf.br/cf.cond_br. 
-// -// Note: This requires the parent region to allow multiple blocks. In -// particular, scf.if/scf.for regions are single-block and cannot contain this -// lowering. -struct SCFWhileToCF : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - static LogicalResult validateWhileResultUses(scf::WhileOp op) { - Block *parentBlock = op->getBlock(); - for (Value result : op.getResults()) { - for (OpOperand &use : result.getUses()) { - if (use.getOwner()->getBlock() != parentBlock) - return failure(); - } - } - return success(); - } - - static Block *splitAfterWhileBlock(PatternRewriter &rewriter, - scf::WhileOp op) { - auto whileIt = Block::iterator(op.getOperation()); - return rewriter.splitBlock(op->getBlock(), std::next(whileIt)); - } - - static void addWhileExitArguments(PatternRewriter &rewriter, scf::WhileOp op, - Location loc, Block *afterWhileBlock) { - SmallVector exitArgs; - exitArgs.reserve(op.getNumResults()); - for (Type type : op.getResultTypes()) - exitArgs.push_back(afterWhileBlock->addArgument(type, loc)); - for (auto result : llvm::enumerate(op.getResults())) - result.value().replaceAllUsesWith(exitArgs[result.index()]); - } - - static Block *createWhileHeaderBlock(PatternRewriter &rewriter, - scf::WhileOp op, Location loc, - Block *afterWhileBlock) { - SmallVector headerArgTypes; - for (Value init : op.getInits()) - headerArgTypes.push_back(init.getType()); - SmallVector headerArgLocs(headerArgTypes.size(), loc); - return rewriter.createBlock(afterWhileBlock->getParent(), - afterWhileBlock->getIterator(), headerArgTypes, - headerArgLocs); - } - - static Block *createWhileBodyBlock(PatternRewriter &rewriter, scf::WhileOp op, - Location loc, Block *afterWhileBlock) { - Block &afterRegionBlock = op.getAfter().front(); - SmallVector bodyArgTypes(afterRegionBlock.getArgumentTypes().begin(), - afterRegionBlock.getArgumentTypes().end()); - SmallVector bodyArgLocs(bodyArgTypes.size(), loc); - return 
rewriter.createBlock(afterWhileBlock->getParent(), - afterWhileBlock->getIterator(), bodyArgTypes, - bodyArgLocs); - } - - static void rewriteWhileTerminators(PatternRewriter &rewriter, Location loc, - Block *headerBlock, Block *bodyBlock, - Block *afterWhileBlock) { - auto condOp = cast(headerBlock->getTerminator()); - rewriter.setInsertionPoint(condOp); - rewriter.create(loc, condOp.getCondition(), - /*trueDest=*/bodyBlock, - /*trueOperands=*/condOp.getArgs(), - /*falseDest=*/afterWhileBlock, - /*falseOperands=*/condOp.getArgs()); - rewriter.eraseOp(condOp); - - auto yieldOp = cast(bodyBlock->getTerminator()); - rewriter.setInsertionPoint(yieldOp); - rewriter.create(loc, headerBlock, yieldOp.getOperands()); - rewriter.eraseOp(yieldOp); - } - - LogicalResult matchAndRewrite(scf::WhileOp op, - PatternRewriter &rewriter) const override { - Operation *parentOp = op->getParentOp(); - if (parentOp && parentOp->hasTrait()) { - return rewriter.notifyMatchFailure( - op, "cannot lower scf.while inside a single-block parent region"); - } - - if (failed(validateWhileResultUses(op))) - return rewriter.notifyMatchFailure( - op, "unsupported: while results used outside the parent block"); - - auto loc = op.getLoc(); - Block *afterWhileBlock = splitAfterWhileBlock(rewriter, op); - addWhileExitArguments(rewriter, op, loc, afterWhileBlock); - Block *headerBlock = createWhileHeaderBlock(rewriter, op, loc, - afterWhileBlock); - Block *bodyBlock = createWhileBodyBlock(rewriter, op, loc, afterWhileBlock); - - // Move the before/after region bodies into the new CFG blocks. - Block &afterRegionBlock = op.getAfter().front(); - rewriter.mergeBlocks(&op.getBefore().front(), headerBlock, - headerBlock->getArguments()); - rewriter.mergeBlocks(&afterRegionBlock, bodyBlock, bodyBlock->getArguments()); - rewriteWhileTerminators(rewriter, loc, headerBlock, bodyBlock, - afterWhileBlock); - - // Replace scf.while itself with a branch to the header. 
- rewriter.setInsertionPoint(op); - rewriter.create(loc, headerBlock, op.getInits()); - rewriter.eraseOp(op); - return success(); - } -}; - -// Lower cf.switch into chained comparisons and cf.cond_br/cf.br. -// -// EmitC C++ translation currently supports cf.br/cf.cond_br, but not cf.switch. -struct CFSwitchToCondBr : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - static SmallVector> - collectSwitchCaseOperands(cf::SwitchOp op) { - SmallVector> caseOperands; - caseOperands.reserve(op.getCaseDestinations().size()); - for (auto range : op.getCaseOperands()) - caseOperands.emplace_back(range.begin(), range.end()); - return caseOperands; - } - - static SmallVector getSwitchCaseValues(cf::SwitchOp op) { - SmallVector caseValues; - if (auto caseValuesAttr = op.getCaseValues()) { - for (APInt value : caseValuesAttr->getValues()) - caseValues.push_back(value); - } - return caseValues; - } - - static SmallVector createSwitchCheckBlocks(PatternRewriter &rewriter, - Region *parentRegion, - Block *curBlock, - size_t numCases) { - auto insertPt = std::next(curBlock->getIterator()); - SmallVector checkBlocks; - checkBlocks.reserve(numCases); - for (size_t i = 0; i < numCases; ++i) - checkBlocks.push_back(rewriter.createBlock(parentRegion, insertPt)); - return checkBlocks; - } - - static LogicalResult populateSwitchCheckBlocks( - PatternRewriter &rewriter, Location loc, Value flag, IntegerType flagTy, - ArrayRef caseValues, ArrayRef caseDests, - ArrayRef> caseOperands, Block *defaultDest, - ValueRange defaultOperands, ArrayRef checkBlocks, - cf::SwitchOp op) { - for (size_t i = 0; i < caseDests.size(); ++i) { - rewriter.setInsertionPointToEnd(checkBlocks[i]); - APInt caseVal = caseValues[i]; - if (caseVal.getBitWidth() != flagTy.getWidth()) { - return rewriter.notifyMatchFailure( - op, "case value bitwidth doesn't match flag type"); - } - - Value caseConst = rewriter.create( - loc, flagTy, rewriter.getIntegerAttr(flagTy, caseVal)); - Value cond = 
rewriter.create( - loc, arith::CmpIPredicate::eq, flag, caseConst); - Block *falseDest = - (i + 1 < checkBlocks.size()) ? checkBlocks[i + 1] : defaultDest; - ValueRange falseOperands = - (i + 1 < checkBlocks.size()) ? ValueRange{} : defaultOperands; - rewriter.create(loc, cond, caseDests[i], - caseOperands[i], falseDest, - falseOperands); - } - return success(); - } - - LogicalResult matchAndRewrite(cf::SwitchOp op, - PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - Operation *parentOp = op->getParentOp(); - if (parentOp && parentOp->hasTrait()) { - return rewriter.notifyMatchFailure( - op, "cannot lower cf.switch inside a single-block parent region"); - } - - Block *curBlock = op->getBlock(); - Region *parentRegion = curBlock->getParent(); - - Value flag = op.getFlag(); - auto flagTy = dyn_cast(flag.getType()); - if (!flagTy) - return rewriter.notifyMatchFailure(op, "expected integer switch flag"); - - SmallVector defaultOperands(op.getDefaultOperands().begin(), - op.getDefaultOperands().end()); - Block *defaultDest = op.getDefaultDestination(); - - SmallVector caseDests(op.getCaseDestinations().begin(), - op.getCaseDestinations().end()); - SmallVector> caseOperands = collectSwitchCaseOperands(op); - - if (caseDests.empty()) { - rewriter.replaceOpWithNewOp(op, defaultDest, defaultOperands); - return success(); - } - - if (!op.getCaseValues()) - return rewriter.notifyMatchFailure(op, "missing case_values"); - SmallVector caseValues = getSwitchCaseValues(op); - - if (caseValues.size() != caseDests.size()) - return rewriter.notifyMatchFailure(op, "case_values/destinations mismatch"); - if (caseOperands.size() != caseDests.size()) - return rewriter.notifyMatchFailure(op, "case_operands/destinations mismatch"); - - SmallVector checkBlocks = - createSwitchCheckBlocks(rewriter, parentRegion, curBlock, - caseDests.size()); - if (failed(populateSwitchCheckBlocks(rewriter, loc, flag, flagTy, - caseValues, caseDests, caseOperands, - defaultDest, 
defaultOperands, - checkBlocks, op))) { - return failure(); - } - - // Replace the switch terminator with a branch into the first check block. - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp(op, checkBlocks.front(), - ValueRange{}); - return success(); - } -}; - -} // namespace - -static void populatePTOToEmitCPatterns(RewritePatternSet &patterns, - TypeConverter &typeConverter, - MLIRContext *ctx, - DataFlowSolver &solver, - PTOArch targetArch) { - (void)solver; - patterns.add(typeConverter, ctx); - populatePTOToEmitCArithPatterns(patterns, typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx, "pto.set_flag_dyn", - "set_flag"); - patterns.add(typeConverter, ctx, "pto.wait_flag_dyn", - "wait_flag"); - // Backward-compatible aliases used in some downstream branches. - patterns.add(typeConverter, ctx, "pto.set_flag_d", - "set_flag"); - patterns.add(typeConverter, ctx, "pto.wait_flag_d", - "wait_flag"); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - 
patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add( - typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - 
patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add( - typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add>( - typeConverter, ctx, - "pto::comm::TPUT_ASYNC"); - patterns.add>( - typeConverter, ctx, - "pto::comm::TGET_ASYNC"); - patterns.add>(typeConverter, ctx, - "pto::comm::TPUT"); - patterns.add>(typeConverter, ctx, - "pto::comm::TGET"); - patterns.add>(typeConverter, ctx, - "pto::comm::TNOTIFY"); - patterns.add>(typeConverter, ctx, - "pto::comm::TWAIT"); - patterns.add>(typeConverter, ctx, - 
"pto::comm::TTEST"); - patterns.add>(typeConverter, ctx, - "TBROADCAST"); - patterns.add>(typeConverter, ctx, - "TGATHER"); - patterns.add>(typeConverter, ctx, - "TSCATTER"); - patterns.add>(typeConverter, ctx, - "TREDUCE"); - patterns.add>( - typeConverter, ctx, "PTOAS__ASYNC_EVENT_WAIT"); - patterns.add>( - typeConverter, ctx, "PTOAS__ASYNC_EVENT_TEST"); - patterns.add(typeConverter, ctx, targetArch); - patterns.add(typeConverter, ctx, targetArch); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx, targetArch); - patterns.add(typeConverter, ctx, targetArch); - patterns.add(typeConverter, ctx, targetArch); - patterns.add(typeConverter, ctx, targetArch); - patterns.add(typeConverter, ctx, targetArch); - patterns.add(typeConverter, ctx, targetArch); - patterns.add>(typeConverter, ctx); - patterns.add>(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add< - PTOTMatmulBiasToTMATMUL_BIAS, - PTOTMatmulMXToTMATMUL_MX, - PTOTMatmulMXAccToTMATMUL_MX_ACC, - PTOTMatmulMXBiasToTMATMUL_MX_BIAS, - PTOTMatmulBiasToTMATMUL_BIAS, - PTOTMatmulMXToTMATMUL_MX, - PTOTMatmulMXAccToTMATMUL_MX_ACC, - PTOTMatmulMXBiasToTMATMUL_MX_BIAS, - PTOTGemvBiasToTGEMV_BIAS, - PTOTGemvMXToTGEMV_MX, - PTOTGemvMXAccToTGEMV_MX, - PTOTGemvMXBiasToTGEMV_MX, - PTOBarrierToEmitC - >(typeConverter, ctx); - - patterns.add(typeConverter, ctx); - - populateSCFToEmitCConversionPatterns(patterns); - // Keep 
CFG-style branches type-consistent when block argument types are - // converted (e.g. after lowering scf.while to cf.br/cf.cond_br). - populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter); -} - -//===----------------------------------------------------------------------===// -// Pass -//===----------------------------------------------------------------------===// - -namespace { -struct EmitPTOManualPass - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(EmitPTOManualPass) - - PTOArch targetArch; - - EmitPTOManualPass() : targetArch(PTOArch::A3) {} - - explicit EmitPTOManualPass(PTOArch arch) : targetArch(arch) {} - - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - void runOnOperation() override { - LLVM_DEBUG(llvm::dbgs() << "DEBUG: Start PTOToEmitC Pass\n"); - MLIRContext *ctx = &getContext(); - ModuleOp mop = getOperation(); - - if (failed(pto::validatePTOEntryFunctions(mop))) - return signalPassFailure(); - pto::annotatePTOEntryFunctions(mop); - - // A3 requires explicit FFTS base setup for inter-core sync ops. - if (targetArch == PTOArch::A3) { - bool hasMissingSetFFTs = false; - for (auto func : mop.getOps()) { - if (!hasInterCoreSyncOp(func)) - continue; - if (hasSetFFTsOp(func)) - continue; - hasMissingSetFFTs = true; - func.emitError() - << "A3 inter-core sync requires explicit `pto.set_ffts` in the " - "same function when using `pto.sync.set`/`pto.sync.wait`"; - } - if (hasMissingSetFFTs) - return signalPassFailure(); + if (hasMissingSetFFTs) + return signalPassFailure(); } bool needsEventIdArrayHelper = false; @@ -11013,67 +2210,8 @@ static AICORE inline void ptoas_auto_sync_tail( } // 1.5 Pre-lower SCF constructs not handled by SCFToEmitC. - { - // scf.while / scf.index_switch are lowered via CFG blocks. This is not - // possible inside ops that require single-block regions (e.g. scf.for / - // scf.if). 
If we see such nesting, lower the entire function to the - // ControlFlow dialect first. - bool needsAnySCFToCF = false; - for (auto func : mop.getOps()) { - if (needsWholeFunctionSCFToCF(func)) { - needsAnySCFToCF = true; - break; - } - } - if (needsAnySCFToCF) { - RewritePatternSet scfToCfPatterns(ctx); - populateSCFToControlFlowConversionPatterns(scfToCfPatterns); - FrozenRewritePatternSet frozenSCFToCF(std::move(scfToCfPatterns)); - - ConversionTarget scfToCfTarget(*ctx); - // Only eliminate the single-block SCF constructs; we'll pre-lower - // scf.while/index_switch/execute_region ourselves afterwards. - scfToCfTarget.addIllegalOp(); - scfToCfTarget.markUnknownOpDynamicallyLegal( - [](Operation *) { return true; }); - - for (auto func : mop.getOps()) { - if (!needsWholeFunctionSCFToCF(func)) - continue; - if (failed(applyPartialConversion(func, scfToCfTarget, - frozenSCFToCF))) { - func.emitError() - << "failed to lower nested SCF to ControlFlow (SCFToCF)"; - return signalPassFailure(); - } - } - } - - RewritePatternSet scfLoweringPatterns(ctx); - scfLoweringPatterns.add(ctx); - (void)applyPatternsAndFoldGreedily(mop, std::move(scfLoweringPatterns)); - - bool hasUnsupportedSCF = false; - mop.walk([&](Operation *op) { - if (isa(op)) { - hasUnsupportedSCF = true; - op->emitError() << "Unsupported SCF op remained after pre-lowering"; - return WalkResult::interrupt(); - } - if (isa(op)) { - hasUnsupportedSCF = true; - op->emitError() - << "Unsupported CF op remained after pre-lowering: cf.switch"; - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - if (hasUnsupportedSCF) - return signalPassFailure(); - } + if (failed(runPTOToEmitCSCFPreLowering(mop, ctx))) + return signalPassFailure(); PTOToEmitCTypeConverter typeConverter(ctx, targetArch); diff --git a/lib/PTO/Transforms/PTOToEmitCComm.cpp b/lib/PTO/Transforms/PTOToEmitCComm.cpp new file mode 100644 index 000000000..93aed176d --- /dev/null +++ b/lib/PTO/Transforms/PTOToEmitCComm.cpp @@ 
-0,0 +1,889 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- PTOToEmitCComm.cpp --------------------------------------------------===// +//===----------------------------------------------------------------------===// + +#include "PTOToEmitCInternal.h" + +#include "PTO/IR/PTO.h" + +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" + +#include +#include +#include + +using namespace mlir; +using namespace mlir::pto; + +namespace mlir::pto { +namespace { + +static constexpr llvm::StringLiteral kGlobalTensorStridesAttrName = + "__pto.globaltensor_strides"; + +struct PTOInitializeL2G2LPipeToEmitC + : public OpConversionPattern { + PTOInitializeL2G2LPipeToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + PTOArch targetArch) + : OpConversionPattern(typeConverter, ctx), + targetArch(targetArch) {} + + LogicalResult matchAndRewrite(mlir::pto::InitializeL2G2LPipeOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto tpipeTok = buildTPipeTokenFromInitOp(op.getOperation(), targetArch); + if (failed(tpipeTok)) + return rewriter.notifyMatchFailure(op, "failed to build TPipe token"); + + auto *ctx = rewriter.getContext(); + auto emitPipeTy = + 
cast(getTypeConverter()->convertType(op.getPipe().getType())); + + Value gmAddr = peelUnrealized(adaptor.getGmAddr()); + gmAddr = materializeTensorViewDataPointer( + rewriter, op.getLoc(), gmAddr, op.getGmAddr().getType()); + Value localAddr = + op.getLocalAddr() ? peelUnrealized(adaptor.getLocalAddr()) : Value(); + auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); + Value zero = makeEmitCIntConstant(rewriter, op.getLoc(), i32Ty, 0); + + Value c2vBuf = zero; + Value v2cBuf = zero; + if (op.getDirMask() == 1) + c2vBuf = localAddr ? localAddr : zero; + else if (op.getDirMask() == 2) + v2cBuf = localAddr ? localAddr : zero; + else if (op.getDirMask() == 3) { + if (localAddr) { + if (!op.getPeerLocalAddr()) + return rewriter.notifyMatchFailure( + op, "bidirectional l2g2l pipe requires peer local buffer"); + c2vBuf = localAddr; + v2cBuf = peelUnrealized(adaptor.getPeerLocalAddr()); + } + } else + return rewriter.notifyMatchFailure(op, "unsupported dir_mask"); + + rewriter.replaceOpWithNewOp( + op, TypeRange{emitPipeTy}, *tpipeTok, ArrayAttr{}, ArrayAttr{}, + ValueRange{gmAddr, c2vBuf, v2cBuf}); + return success(); + } + + PTOArch targetArch; +}; + +struct PTOInitializeL2LPipeToEmitC + : public OpConversionPattern { + PTOInitializeL2LPipeToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + PTOArch targetArch) + : OpConversionPattern(typeConverter, ctx), + targetArch(targetArch) {} + + LogicalResult matchAndRewrite(mlir::pto::InitializeL2LPipeOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto tpipeTok = buildTPipeTokenFromInitOp(op.getOperation(), targetArch); + if (failed(tpipeTok)) + return rewriter.notifyMatchFailure(op, "failed to build TPipe token"); + + auto *ctx = rewriter.getContext(); + auto emitPipeTy = + cast(getTypeConverter()->convertType(op.getPipe().getType())); + + auto gmPtrTy = + emitc::PointerType::get(emitc::OpaqueType::get(ctx, "__gm__ void")); + Value nullGm = + makeEmitCOpaqueConstant(rewriter, 
op.getLoc(), gmPtrTy, "nullptr"); + auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); + Value zero = makeEmitCIntConstant(rewriter, op.getLoc(), i32Ty, 0); + Value localAddr = peelUnrealized(adaptor.getLocalAddr()); + + Value c2vBuf = zero; + Value v2cBuf = zero; + if (op.getDirMask() == 1) + c2vBuf = localAddr; + else if (op.getDirMask() == 2) + v2cBuf = localAddr; + else if (op.getDirMask() == 3) { + c2vBuf = localAddr; + v2cBuf = peelUnrealized(adaptor.getPeerLocalAddr()); + } else + return rewriter.notifyMatchFailure(op, "unsupported dir_mask"); + + rewriter.replaceOpWithNewOp( + op, TypeRange{emitPipeTy}, *tpipeTok, ArrayAttr{}, ArrayAttr{}, + ValueRange{nullGm, c2vBuf, v2cBuf}); + return success(); + } + + PTOArch targetArch; +}; + +struct PTOBuildAsyncSessionToEmitC + : public OpConversionPattern { + PTOBuildAsyncSessionToEmitC(TypeConverter &typeConverter, MLIRContext *ctx) + : OpConversionPattern(typeConverter, ctx) {} + + LogicalResult matchAndRewrite(mlir::pto::BuildAsyncSessionOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto *ctx = rewriter.getContext(); + Location loc = op.getLoc(); + + auto sessionTy = + dyn_cast(getTypeConverter()->convertType(op.getSession().getType())); + if (!sessionTy) + return rewriter.notifyMatchFailure(op, "failed to convert async session type"); + + FailureOr scratchTile = + buildAsyncScratchTileValue(rewriter, loc, op.getScratch(), + adaptor.getScratch()); + if (failed(scratchTile)) + return rewriter.notifyMatchFailure(op, "failed to materialize async scratch tile"); + + Value workspace = + castToGMBytePointer(rewriter, loc, peelUnrealized(adaptor.getWorkspace())); + + Value session = rewriter + .create( + loc, sessionTy, emitc::OpaqueAttr::get(ctx, "")) + .getResult(); + + auto u32Ty = emitc::OpaqueType::get(ctx, "uint32_t"); + + auto makeU32Const = [&](uint64_t value) -> Value { + return makeEmitCOpaqueConstant(rewriter, loc, u32Ty, + std::to_string(value) + "u"); + }; + 
uint64_t syncId = op.getSyncIdAttr() ? op.getSyncIdAttr().getInt() : 0; + uint64_t blockBytes = + op.getBlockBytesAttr() ? op.getBlockBytesAttr().getInt() : 32 * 1024; + uint64_t commBlockOffset = + op.getCommBlockOffsetAttr() ? op.getCommBlockOffsetAttr().getInt() : 0; + uint64_t queueNum = op.getQueueNumAttr() ? op.getQueueNumAttr().getInt() : 1; + uint64_t channelGroupIdx = op.getChannelGroupIdxAttr() + ? op.getChannelGroupIdxAttr().getInt() + : UINT32_MAX; + + Value syncIdVal = makeU32Const(syncId); + Value channelGroupIdxVal = + channelGroupIdx == UINT32_MAX + ? makeEmitCOpaqueConstant(rewriter, loc, u32Ty, "UINT32_MAX") + : makeU32Const(channelGroupIdx); + + auto baseConfigTy = + emitc::OpaqueType::get(ctx, "pto::comm::sdma::SdmaBaseConfig"); + Value baseConfig = + rewriter + .create( + loc, baseConfigTy, + emitc::OpaqueAttr::get( + ctx, "{" + std::to_string(blockBytes) + "ULL, " + + std::to_string(commBlockOffset) + "ULL, " + + std::to_string(queueNum) + "u}")) + .getResult(); + + rewriter.create( + loc, TypeRange{}, "pto::comm::BuildAsyncSession", + ArrayAttr{}, ArrayAttr{}, + ValueRange{*scratchTile, workspace, session, syncIdVal, baseConfig, + channelGroupIdxVal}); + + rewriter.replaceOp(op, session); + return success(); + } +}; + +template +struct PTOAsyncTransferToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + explicit PTOAsyncTransferToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + StringRef callee) + : OpConversionPattern(typeConverter, ctx), callee(callee.str()) {} + + LogicalResult matchAndRewrite(AsyncOp op, typename AsyncOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value dst = peelUnrealized(adaptor.getDst()); + Value src = peelUnrealized(adaptor.getSrc()); + Value dstGT = dst; + Value srcGT = src; + if (!isEmitCGlobalTensorLikeType(dstGT.getType())) { + auto dstMrTy = dyn_cast(op.getDst().getType()); + if (!dstMrTy) + return rewriter.notifyMatchFailure(op, 
"expected dst to lower to GlobalTensor or memref"); + dstGT = buildGlobalTensorFromMemref(rewriter, op.getLoc(), dst, dstMrTy, + op.getDst().getDefiningOp() + ? op.getDst().getDefiningOp() + : op.getOperation()); + } + if (!isEmitCGlobalTensorLikeType(srcGT.getType())) { + auto srcMrTy = dyn_cast(op.getSrc().getType()); + if (!srcMrTy) + return rewriter.notifyMatchFailure(op, "expected src to lower to GlobalTensor or memref"); + srcGT = buildGlobalTensorFromMemref(rewriter, op.getLoc(), src, srcMrTy, + op.getSrc().getDefiningOp() + ? op.getSrc().getDefiningOp() + : op.getOperation()); + } + if (!dstGT || !srcGT) + return rewriter.notifyMatchFailure(op, "failed to build GlobalTensor operands"); + + Type eventTy = this->getTypeConverter()->convertType(op.getEvent().getType()); + if (!eventTy) + return rewriter.notifyMatchFailure(op, "failed to convert async event type"); + + rewriter.replaceOpWithNewOp( + op, TypeRange{eventTy}, callee, ArrayAttr{}, ArrayAttr{}, + ValueRange{dstGT, srcGT, peelUnrealized(adaptor.getSession())}); + return success(); + } + + std::string callee; +}; + +template +struct PTOAsyncEventToEmitC : public OpConversionPattern { + explicit PTOAsyncEventToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + StringRef callee) + : OpConversionPattern(typeConverter, ctx), + callee(callee.str()) {} + + LogicalResult matchAndRewrite(AsyncEventOp op, + typename AsyncEventOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type resultTy = + this->getTypeConverter()->convertType(op.getCompleted().getType()); + if (!resultTy) + return rewriter.notifyMatchFailure(op, "failed to convert async event result type"); + + rewriter.replaceOpWithNewOp( + op, TypeRange{resultTy}, callee, ArrayAttr{}, ArrayAttr{}, + ValueRange{peelUnrealized(adaptor.getEvent()), + peelUnrealized(adaptor.getSession())}); + return success(); + } + + std::string callee; +}; + +static FailureOr buildCommGlobalTensorValue( + ConversionPatternRewriter 
&rewriter, Location loc, Value originalValue, + Value emittedValue, Operation *anchor) { + Value value = peelUnrealized(emittedValue); + if (isEmitCGlobalTensorLikeType(value.getType())) + return value; + + auto memTy = dyn_cast(originalValue.getType()); + if (!memTy) + return failure(); + + Value gt = buildGlobalTensorFromMemref(rewriter, loc, value, memTy, anchor); + if (!gt) + return failure(); + return gt; +} + +static FailureOr buildCommTileValue(ConversionPatternRewriter &rewriter, + Location loc, Value originalValue, + Value emittedValue) { + Value value = peelUnrealized(emittedValue); + if (auto opaqueTy = dyn_cast(value.getType())) { + StringRef typeStr = opaqueTy.getValue(); + if (typeStr.contains("Tile<") || typeStr.contains("ConvTile<")) + return value; + } + return buildAsyncScratchTileValue(rewriter, loc, originalValue, emittedValue); +} + +static FailureOr buildCollectiveParallelGroup( + ConversionPatternRewriter &rewriter, Location loc, + ArrayRef groupGTs, int64_t root) { + if (groupGTs.empty()) + return failure(); + + auto firstTy = dyn_cast(groupGTs.front().getType()); + if (!firstTy) + return failure(); + + auto *ctx = rewriter.getContext(); + auto arrayTy = emitc::ArrayType::get({static_cast(groupGTs.size())}, + firstTy); + auto groupArray = cast>( + rewriter + .create(loc, arrayTy, + emitc::OpaqueAttr::get(ctx, "{}")) + .getResult()); + + auto indexTy = emitc::OpaqueType::get(ctx, "int"); + for (auto [idx, groupVal] : llvm::enumerate(groupGTs)) { + Value idxVal = + makeEmitCIntConstant(rewriter, loc, indexTy, static_cast(idx)); + Value slot = + rewriter.create(loc, groupArray, ValueRange{idxVal}) + .getResult(); + rewriter.create(loc, slot, groupVal); + } + + std::string pgTypeStr = + (Twine("pto::comm::ParallelGroup<") + firstTy.getValue() + ">").str(); + auto pgTy = emitc::OpaqueType::get(ctx, pgTypeStr); + Value sizeVal = makeEmitCIntConstant(rewriter, loc, indexTy, + static_cast(groupGTs.size())); + Value rootVal = 
makeEmitCIntConstant(rewriter, loc, indexTy, root); + return rewriter + .create( + loc, TypeRange{pgTy}, (Twine(pgTypeStr) + "::Create").str(), + ArrayAttr{}, ArrayAttr{}, ValueRange{groupArray, sizeVal, rootVal}) + .getResult(0); +} + +static std::string notifyOpTok(pto::NotifyOp op) { + switch (op) { + case pto::NotifyOp::AtomicAdd: + return "pto::comm::NotifyOp::AtomicAdd"; + case pto::NotifyOp::Set: + return "pto::comm::NotifyOp::Set"; + } + return "pto::comm::NotifyOp::Set"; +} + +static std::string waitCmpTok(pto::WaitCmp cmp) { + switch (cmp) { + case pto::WaitCmp::EQ: + return "pto::comm::WaitCmp::EQ"; + case pto::WaitCmp::NE: + return "pto::comm::WaitCmp::NE"; + case pto::WaitCmp::GT: + return "pto::comm::WaitCmp::GT"; + case pto::WaitCmp::GE: + return "pto::comm::WaitCmp::GE"; + case pto::WaitCmp::LT: + return "pto::comm::WaitCmp::LT"; + case pto::WaitCmp::LE: + return "pto::comm::WaitCmp::LE"; + } + return "pto::comm::WaitCmp::EQ"; +} + +static std::string reduceOpTok(pto::ReduceOp op) { + switch (op) { + case pto::ReduceOp::Sum: + return "pto::comm::ReduceOp::Sum"; + case pto::ReduceOp::Max: + return "pto::comm::ReduceOp::Max"; + case pto::ReduceOp::Min: + return "pto::comm::ReduceOp::Min"; + } + return "pto::comm::ReduceOp::Sum"; +} + +template +static FailureOr> buildCommGroupGlobalTensors( + ConversionPatternRewriter &rewriter, Location loc, OpTy op, + ValueRange originalGroup, ValueRange emittedGroup) { + SmallVector groupGTs; + groupGTs.reserve(originalGroup.size()); + for (auto [orig, emitted] : llvm::zip(originalGroup, emittedGroup)) { + FailureOr gt = + buildCommGlobalTensorValue(rewriter, loc, orig, emitted, op.getOperation()); + if (failed(gt)) + return failure(); + groupGTs.push_back(*gt); + } + return groupGTs; +} + +template +struct PTOCommCollectiveToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + explicit PTOCommCollectiveToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + StringRef 
apiName) + : OpConversionPattern(typeConverter, ctx), + apiName(apiName.str()) {} + + LogicalResult matchAndRewrite(CollectiveOp op, typename CollectiveOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + + auto buildPong = [&](Value original, Value emitted, StringRef name) -> FailureOr { + if (!original) + return failure(); + return buildCommTileValue(rewriter, loc, original, emitted); + }; + + if constexpr (std::is_same_v) { + FailureOr srcGT = + buildCommGlobalTensorValue(rewriter, loc, op.getSrc(), adaptor.getSrc(), + op.getOperation()); + FailureOr pingTile = + buildCommTileValue(rewriter, loc, op.getPing(), adaptor.getPing()); + auto groupGTs = + buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); + if (failed(srcGT) || failed(pingTile) || failed(groupGTs)) + return rewriter.notifyMatchFailure(op, "failed to materialize broadcast operands"); + FailureOr pg = buildCollectiveParallelGroup(rewriter, loc, *groupGTs, op.getRoot()); + if (failed(pg)) + return rewriter.notifyMatchFailure(op, "failed to materialize broadcast group"); + if (op.getPong()) { + FailureOr pongTile = + buildPong(op.getPong(), adaptor.getPong(), "__pong"); + if (failed(pongTile)) + return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); + rewriter.create( + loc, TypeRange{}, "pto::comm::TBROADCAST", ArrayAttr{}, ArrayAttr{}, + ValueRange{*pg, *srcGT, *pingTile, *pongTile}); + } else { + rewriter.create( + loc, TypeRange{}, "pto::comm::TBROADCAST", ArrayAttr{}, ArrayAttr{}, + ValueRange{*pg, *srcGT, *pingTile}); + } + } else if constexpr (std::is_same_v) { + FailureOr dstGT = + buildCommGlobalTensorValue(rewriter, loc, op.getDst(), adaptor.getDst(), + op.getOperation()); + FailureOr pingTile = + buildCommTileValue(rewriter, loc, op.getPing(), adaptor.getPing()); + auto groupGTs = + buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); + if (failed(dstGT) || 
failed(pingTile) || failed(groupGTs)) + return rewriter.notifyMatchFailure(op, "failed to materialize gather operands"); + FailureOr pg = buildCollectiveParallelGroup(rewriter, loc, *groupGTs, op.getRoot()); + if (failed(pg)) + return rewriter.notifyMatchFailure(op, "failed to materialize gather group"); + if (op.getPong()) { + FailureOr pongTile = + buildPong(op.getPong(), adaptor.getPong(), "__pong"); + if (failed(pongTile)) + return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); + rewriter.create( + loc, TypeRange{}, "pto::comm::TGATHER", ArrayAttr{}, ArrayAttr{}, + ValueRange{*pg, *dstGT, *pingTile, *pongTile}); + } else { + rewriter.create( + loc, TypeRange{}, "pto::comm::TGATHER", ArrayAttr{}, ArrayAttr{}, + ValueRange{*pg, *dstGT, *pingTile}); + } + } else if constexpr (std::is_same_v) { + FailureOr srcGT = + buildCommGlobalTensorValue(rewriter, loc, op.getSrc(), adaptor.getSrc(), + op.getOperation()); + FailureOr pingTile = + buildCommTileValue(rewriter, loc, op.getPing(), adaptor.getPing()); + auto groupGTs = + buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); + if (failed(srcGT) || failed(pingTile) || failed(groupGTs)) + return rewriter.notifyMatchFailure(op, "failed to materialize scatter operands"); + FailureOr pg = buildCollectiveParallelGroup(rewriter, loc, *groupGTs, op.getRoot()); + if (failed(pg)) + return rewriter.notifyMatchFailure(op, "failed to materialize scatter group"); + if (op.getPong()) { + FailureOr pongTile = + buildPong(op.getPong(), adaptor.getPong(), "__pong"); + if (failed(pongTile)) + return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); + rewriter.create( + loc, TypeRange{}, "pto::comm::TSCATTER", ArrayAttr{}, ArrayAttr{}, + ValueRange{*pg, *srcGT, *pingTile, *pongTile}); + } else { + rewriter.create( + loc, TypeRange{}, "pto::comm::TSCATTER", ArrayAttr{}, ArrayAttr{}, + ValueRange{*pg, *srcGT, *pingTile}); + } + } else { + FailureOr dstGT = + 
buildCommGlobalTensorValue(rewriter, loc, op.getDst(), adaptor.getDst(), + op.getOperation()); + FailureOr accTile = + buildCommTileValue(rewriter, loc, op.getAcc(), adaptor.getAcc()); + FailureOr recvPing = + buildCommTileValue(rewriter, loc, op.getRecvPing(), adaptor.getRecvPing()); + auto groupGTs = + buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); + if (failed(dstGT) || failed(accTile) || failed(recvPing) || failed(groupGTs)) + return rewriter.notifyMatchFailure(op, "failed to materialize reduce operands"); + FailureOr pg = buildCollectiveParallelGroup(rewriter, loc, *groupGTs, op.getRoot()); + if (failed(pg)) + return rewriter.notifyMatchFailure(op, "failed to materialize reduce group"); + if (op.getRecvPong()) { + FailureOr recvPong = + buildPong(op.getRecvPong(), adaptor.getRecvPong(), "__recv_pong"); + if (failed(recvPong)) + return rewriter.notifyMatchFailure(op, "failed to materialize recv_pong"); + auto reduceTy = + emitc::OpaqueType::get(rewriter.getContext(), "pto::comm::ReduceOp"); + Value reduceOp = makeEmitCOpaqueConstant(rewriter, loc, reduceTy, + reduceOpTok(op.getReduceOp())); + rewriter.create( + loc, TypeRange{}, "pto::comm::TREDUCE", ArrayAttr{}, ArrayAttr{}, + ValueRange{*pg, *dstGT, *accTile, *recvPing, *recvPong, reduceOp}); + } else { + auto reduceTy = + emitc::OpaqueType::get(rewriter.getContext(), "pto::comm::ReduceOp"); + Value reduceOp = makeEmitCOpaqueConstant(rewriter, loc, reduceTy, + reduceOpTok(op.getReduceOp())); + rewriter.create( + loc, TypeRange{}, "pto::comm::TREDUCE", ArrayAttr{}, ArrayAttr{}, + ValueRange{*pg, *dstGT, *accTile, *recvPing, reduceOp}); + } + } + rewriter.eraseOp(op); + return success(); + } + + std::string apiName; +}; + +template +struct PTOP2PCommToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + explicit PTOP2PCommToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + StringRef callee) + : OpConversionPattern(typeConverter, 
ctx), callee(callee.str()) {} + + LogicalResult matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr dstGT = + buildCommGlobalTensorValue(rewriter, op.getLoc(), op.getDst(), adaptor.getDst(), + op.getOperation()); + FailureOr srcGT = + buildCommGlobalTensorValue(rewriter, op.getLoc(), op.getSrc(), adaptor.getSrc(), + op.getOperation()); + FailureOr pingTile = + buildCommTileValue(rewriter, op.getLoc(), op.getPing(), adaptor.getPing()); + if (failed(dstGT) || failed(srcGT) || failed(pingTile)) + return rewriter.notifyMatchFailure(op, "failed to materialize p2p operands"); + + SmallVector operands{*dstGT, *srcGT, *pingTile}; + std::string actualCallee = callee; + if constexpr (std::is_same_v) { + if (op.getAtomicType() == pto::AtomicType::AtomicAdd) + actualCallee = "pto::comm::TPUT"; + } + if (op.getPong()) { + FailureOr pongTile = + buildCommTileValue(rewriter, op.getLoc(), op.getPong(), adaptor.getPong()); + if (failed(pongTile)) + return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); + operands.push_back(*pongTile); + } + + rewriter.create(op.getLoc(), TypeRange{}, actualCallee, + ArrayAttr{}, ArrayAttr{}, operands); + rewriter.eraseOp(op); + return success(); + } + + std::string callee; +}; + +template +struct PTOSignalCommToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + explicit PTOSignalCommToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + StringRef callee) + : OpConversionPattern(typeConverter, ctx), + callee(callee.str()) {} + + LogicalResult matchAndRewrite(SignalOp op, typename SignalOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr signalGT = buildCommGlobalTensorValue( + rewriter, op.getLoc(), op.getSignal(), adaptor.getSignal(), op.getOperation()); + if (failed(signalGT)) + return rewriter.notifyMatchFailure(op, "failed to materialize signal operand"); + + if constexpr 
(std::is_same_v) { + auto notifyTy = + emitc::OpaqueType::get(rewriter.getContext(), "pto::comm::NotifyOp"); + Value notifyOp = makeEmitCOpaqueConstant( + rewriter, op.getLoc(), notifyTy, notifyOpTok(op.getNotifyOp())); + SmallVector operands{*signalGT, peelUnrealized(adaptor.getValue()), + notifyOp}; + rewriter.create(op.getLoc(), TypeRange{}, callee, + ArrayAttr{}, ArrayAttr{}, operands); + rewriter.eraseOp(op); + } else { + auto waitCmpTy = + emitc::OpaqueType::get(rewriter.getContext(), "pto::comm::WaitCmp"); + Value waitCmp = makeEmitCOpaqueConstant( + rewriter, op.getLoc(), waitCmpTy, waitCmpTok(op.getCmp())); + SmallVector operands{*signalGT, peelUnrealized(adaptor.getCmpValue()), + waitCmp}; + if constexpr (std::is_same_v) { + Type resultTy = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultTy) + return rewriter.notifyMatchFailure(op, "failed to convert ttest result type"); + rewriter.replaceOpWithNewOp( + op, TypeRange{resultTy}, callee, ArrayAttr{}, ArrayAttr{}, operands); + } else { + rewriter.create(op.getLoc(), TypeRange{}, callee, + ArrayAttr{}, ArrayAttr{}, operands); + rewriter.eraseOp(op); + } + } + return success(); + } + + std::string callee; +}; + +struct PTODeclareTileMemRefToEmitC + : public OpConversionPattern { + using OpConversionPattern< + mlir::pto::DeclareTileMemRefOp>::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::DeclareTileMemRefOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + Type convertedType = getTypeConverter()->convertType(op.getResult().getType()); + if (!convertedType) + return rewriter.notifyMatchFailure( + op, "failed to convert declare_tile_memref result type"); + rewriter.replaceOp(op, makeEmitCOpaqueConstant(rewriter, op.getLoc(), + convertedType, "nullptr")); + return success(); + } +}; + +struct PTODeclareGlobalToEmitC + : public OpConversionPattern { + using OpConversionPattern< + 
mlir::pto::DeclareGlobalOp>::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::DeclareGlobalOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + Type convertedType = getTypeConverter()->convertType(op.getEntry().getType()); + if (!convertedType) + return rewriter.notifyMatchFailure( + op, "failed to convert declare_global result type"); + if (auto tvTy = dyn_cast(op.getEntry().getType())) { + if (auto stridesAttr = + op->getAttrOfType(kGlobalTensorStridesAttrName)) { + auto strides = stridesAttr.asArrayRef(); + if (strides.size() == static_cast(tvTy.getRank())) { + convertedType = emitc::OpaqueType::get( + rewriter.getContext(), + getGlobalTensorTypeStringFromShapeAndStrides( + tvTy.getElementType(), tvTy.getShape(), strides)); + } + } + } + auto var = rewriter.create( + op.getLoc(), convertedType, + emitc::OpaqueAttr::get(rewriter.getContext(), "")); + rewriter.replaceOp(op, var.getResult()); + return success(); + } +}; + +struct PTODeclareEventIdArrayToEmitC + : public OpConversionPattern { + using OpConversionPattern< + mlir::pto::DeclareEventIdArrayOp>::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::DeclareEventIdArrayOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + Type arrayTy = getTypeConverter()->convertType(op.getArray().getType()); + if (!arrayTy) + return rewriter.notifyMatchFailure(op, + "failed to map declared eventid_array type"); + + auto array = rewriter + .create( + op.getLoc(), arrayTy, + emitc::OpaqueAttr::get(rewriter.getContext(), "")) + .getResult(); + rewriter.replaceOp(op, array); + return success(); + } +}; + +struct PTOEventIdArrayGetToEmitC + : public OpConversionPattern { + using OpConversionPattern< + mlir::pto::EventIdArrayGetOp>::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::EventIdArrayGetOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + 
Value array = peelUnrealized(adaptor.getArray()); + Value index = peelUnrealized(adaptor.getIndex()); + + Type resultTy = getTypeConverter()->convertType(op.getResult().getType()); + if (!resultTy) + return rewriter.notifyMatchFailure(op, + "failed to map eventid_array get result type"); + + auto load = + rewriter.create(op.getLoc(), resultTy, array, index); + rewriter.replaceOp(op, load.getResult()); + return success(); + } +}; + +struct PTOEventIdArraySetToEmitC + : public OpConversionPattern { + using OpConversionPattern< + mlir::pto::EventIdArraySetOp>::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::EventIdArraySetOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value array = peelUnrealized(adaptor.getArray()); + Value index = peelUnrealized(adaptor.getIndex()); + Value value = peelUnrealized(adaptor.getValue()); + + rewriter.create( + op.getLoc(), TypeRange{}, "PTOAS__EVENTID_ARRAY_STORE", + ArrayAttr{}, ArrayAttr{}, ValueRange{array, index, value}); + rewriter.eraseOp(op); + return success(); + } +}; + +// pto.declare_local_array -> emitc.variable of !emitc.array<...>. +// Renders as `T a[D1][D2]...;` in the emitted C++. +struct PTODeclareLocalArrayToEmitC + : public OpConversionPattern { + using OpConversionPattern< + mlir::pto::DeclareLocalArrayOp>::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::DeclareLocalArrayOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + Type arrayTy = getTypeConverter()->convertType(op.getArray().getType()); + if (!arrayTy) + return rewriter.notifyMatchFailure(op, + "failed to map !pto.local_array type"); + + auto var = rewriter + .create( + op.getLoc(), arrayTy, + emitc::OpaqueAttr::get(rewriter.getContext(), "")) + .getResult(); + rewriter.replaceOp(op, var); + return success(); + } +}; + +// pto.local_array_get %a[%i0, %i1, ...] -> rvalue. 
+// Lowers to a single emitc.subscript with the full index pack; the C++ emitter +// prints it as `a[i0][i1]...`. The adaptor already exposes target-typed values +// (the type converter has remapped !pto.local_array -> !emitc.array and +// index/integer indices), so they're forwarded directly to the builder. +struct PTOLocalArrayGetToEmitC + : public OpConversionPattern { + using OpConversionPattern< + mlir::pto::LocalArrayGetOp>::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::LocalArrayGetOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type resultTy = + getTypeConverter()->convertType(op.getResult().getType()); + if (!resultTy) + return rewriter.notifyMatchFailure( + op, "failed to map local_array element type"); + + auto sub = rewriter.create( + op.getLoc(), resultTy, adaptor.getArray(), adaptor.getIndices()); + rewriter.replaceOp(op, sub.getResult()); + return success(); + } +}; + +// pto.local_array_set %a[%i0, %i1, ...], %v -> emitc.assign to subscript slot. +// The C++ emitter prints this as `a[i0][i1]... = v;`. As above, adaptor values +// are already target-typed; pass them through directly. 
+struct PTOLocalArraySetToEmitC + : public OpConversionPattern { + using OpConversionPattern< + mlir::pto::LocalArraySetOp>::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::LocalArraySetOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value value = adaptor.getValue(); + Type elemTy = value.getType(); + + Value slot = rewriter + .create(op.getLoc(), elemTy, + adaptor.getArray(), + adaptor.getIndices()) + .getResult(); + rewriter.create(op.getLoc(), slot, value); + rewriter.eraseOp(op); + return success(); + } +}; + + +} // namespace + +void populatePTOToEmitCCommPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter, + MLIRContext *ctx, PTOArch targetArch) { + patterns.add(typeConverter, ctx, targetArch); + patterns.add(typeConverter, ctx, targetArch); + patterns.add(typeConverter, ctx); + patterns.add>( + typeConverter, ctx, + "pto::comm::TPUT_ASYNC"); + patterns.add>( + typeConverter, ctx, + "pto::comm::TGET_ASYNC"); + patterns.add>(typeConverter, ctx, + "pto::comm::TPUT"); + patterns.add>(typeConverter, ctx, + "pto::comm::TGET"); + patterns.add>(typeConverter, ctx, + "pto::comm::TNOTIFY"); + patterns.add>(typeConverter, ctx, + "pto::comm::TWAIT"); + patterns.add>(typeConverter, ctx, + "pto::comm::TTEST"); + patterns.add>(typeConverter, ctx, + "TBROADCAST"); + patterns.add>(typeConverter, ctx, + "TGATHER"); + patterns.add>(typeConverter, ctx, + "TSCATTER"); + patterns.add>(typeConverter, ctx, + "TREDUCE"); + patterns.add>( + typeConverter, ctx, "PTOAS__ASYNC_EVENT_WAIT"); + patterns.add>( + typeConverter, ctx, "PTOAS__ASYNC_EVENT_TEST"); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); +} + +} // namespace mlir::pto diff --git 
a/lib/PTO/Transforms/PTOToEmitCControlFlow.cpp b/lib/PTO/Transforms/PTOToEmitCControlFlow.cpp new file mode 100644 index 000000000..8422fe40d --- /dev/null +++ b/lib/PTO/Transforms/PTOToEmitCControlFlow.cpp @@ -0,0 +1,717 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- PTOToEmitCControlFlow.cpp ------------------------------------------===// +//===----------------------------------------------------------------------===// + +#include "PTOToEmitCInternal.h" + +#include "PTO/IR/PTO.h" + +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" +#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; +using namespace mlir::pto; + +namespace mlir::pto { +namespace { + +//===----------------------------------------------------------------------===// +// Return lowering +//===----------------------------------------------------------------------=== + +static constexpr llvm::StringLiteral 
kAutoSyncTailPendingModeAttr = + "__pto.auto_sync_tail_mode"; + +struct ReturnToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (auto emitcFunc = op->getParentOfType()) { + if (auto modeAttr = + emitcFunc->getAttrOfType(kAutoSyncTailPendingModeAttr)) { + auto *ctx = rewriter.getContext(); + rewriter.setInsertionPoint(op); + auto args = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, modeAttr.getValue())}); + rewriter.create( + op.getLoc(), TypeRange{}, "ptoas_auto_sync_tail", + args, ArrayAttr{}, ValueRange{}); + } + } + + auto vals = adaptor.getOperands(); + if (vals.empty()) { + rewriter.replaceOpWithNewOp(op, Value{}); + return success(); + } + if (vals.size() == 1) { + rewriter.replaceOpWithNewOp(op, vals[0]); + return success(); + } + return rewriter.notifyMatchFailure(op, "EmitC cannot return multiple values"); + } +}; + +struct CallToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(func::CallOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (op.getNumResults() > 1) + return rewriter.notifyMatchFailure( + op, "EmitC cannot lower calls with multiple results"); + + SmallVector resultTypes; + if (failed( + getTypeConverter()->convertTypes(op.getResultTypes(), resultTypes))) + return rewriter.notifyMatchFailure(op, + "failed to convert call result types"); + + rewriter.replaceOpWithNewOp(op, op.getCalleeAttr(), + resultTypes, + adaptor.getOperands()); + return success(); + } +}; + + + +template +struct SectionToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + std::string getMacroName() const { + if (std::is_same::value) + return "__DAV_CUBE__"; + if (std::is_same::value) + return "__DAV_VEC__"; + return "UNKNOWN_MACRO"; + } + + 
LogicalResult + matchAndRewrite(SectionOpTy op, typename SectionOpTy::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + bool needsNoSplitGuard = needsA5NoSplitVectorGuard(op.getOperation()); + + std::string startMacro = "\n#if defined(" + getMacroName() + ")"; + rewriter.create(loc, startMacro); + + if constexpr (std::is_same_v) { + // Vector mask is a global HW state and may be modified by previous kernels + // (or earlier sections). Reset it to a well-defined state for deterministic + // execution of VEC ops. + rewriter.create(loc, "set_mask_norm();"); + rewriter.create(loc, "set_vector_mask(-1, -1);"); + } + + if (needsNoSplitGuard) { + rewriter.create( + loc, "if (get_subblockid() == 0) {"); + } + + Block &innerBlock = op.getBody().front(); + if (!innerBlock.empty()) { + rewriter.inlineBlockBefore(&innerBlock, op.getOperation(), ValueRange{}); + } + + if (needsNoSplitGuard) + rewriter.create(loc, "}"); + + std::string endMacro = "#endif // " + getMacroName() + "\n"; + rewriter.create(loc, endMacro); + + rewriter.eraseOp(op); + + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// SCF Control-Flow Pre-Lowering +// +// EmitC translation supports `emitc.for`/`emitc.if` plus CFG-style +// `cf.br`/`cf.cond_br`. Upstream SCFToEmitC patterns only cover `scf.for` and +// `scf.if`, so we pre-lower some SCF ops into those supported forms. 
+//===----------------------------------------------------------------------===// + +namespace { + +static bool isTriviallyInlineableExecuteRegion(scf::ExecuteRegionOp op) { + Region &r = op.getRegion(); + if (!r.hasOneBlock()) + return false; + Block &b = r.front(); + return isa_and_nonnull(b.getTerminator()); +} + +static bool needsWholeFunctionSCFToCF(func::FuncOp func) { + bool needs = false; + func.walk([&](Operation *op) { + if (!isa(op)) + return WalkResult::advance(); + Operation *parentOp = op->getParentOp(); + + // `scf.execute_region` can legally appear in single-block parents. Only + // require whole-function SCFToCF if we need to lower it into CFG blocks + // (multi-block region / non-trivial terminators). + if (auto exec = dyn_cast(op)) { + if (parentOp && parentOp->hasTrait() && + !isTriviallyInlineableExecuteRegion(exec)) { + needs = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + } + + if (parentOp && parentOp->hasTrait()) { + needs = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return needs; +} + +// scf.execute_region is semantically just an inlined region producing results +// via scf.yield. Inline it to the parent block to avoid extra lowering needs. +struct SCFExecuteRegionInline + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::ExecuteRegionOp op, + PatternRewriter &rewriter) const override { + if (op.getRegion().empty()) + return rewriter.notifyMatchFailure(op, "expected non-empty region"); + + Block &innerBlock = op.getRegion().front(); + auto yield = dyn_cast(innerBlock.getTerminator()); + if (!yield) + return rewriter.notifyMatchFailure(op, "expected scf.yield terminator"); + + // Move the body operations before the execute_region op. + rewriter.inlineBlockBefore(&innerBlock, op.getOperation(), ValueRange{}); + + // Replace execute_region results with yielded values, then erase the yield. 
+ rewriter.replaceOp(op, yield.getOperands()); + rewriter.eraseOp(yield); + return success(); + } +}; + +// Lower scf.execute_region into CFG blocks with cf.br/cf.cond_br by inlining the +// region blocks into the parent region and rewriting scf.yield to branch into a +// continuation block carrying results. +// +// Note: This requires the parent region to allow multiple blocks (e.g. the +// function body CFG region). For execute_region nested in single-block regions +// (scf.for/scf.if), run SCFToCF first to eliminate the single-block constraint. +struct SCFExecuteRegionToCF : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::ExecuteRegionOp op, + PatternRewriter &rewriter) const override { + if (isTriviallyInlineableExecuteRegion(op)) + return rewriter.notifyMatchFailure(op, "trivially inlineable"); + + Operation *parentOp = op->getParentOp(); + if (parentOp && parentOp->hasTrait()) { + return rewriter.notifyMatchFailure( + op, "cannot lower scf.execute_region inside a single-block parent region"); + } + + if (op.getRegion().empty()) + return rewriter.notifyMatchFailure(op, "expected non-empty region"); + + Location loc = op.getLoc(); + Block *curBlock = op->getBlock(); + Region *parentRegion = curBlock->getParent(); + + // Split the parent block so we can branch to a continuation block with phi + // arguments for the execute_region results. + auto execIt = Block::iterator(op.getOperation()); + Block *continueBlock = rewriter.splitBlock(curBlock, std::next(execIt)); + + SmallVector contArgs; + contArgs.reserve(op.getNumResults()); + for (Type t : op.getResultTypes()) + contArgs.push_back(continueBlock->addArgument(t, loc)); + + for (auto it : llvm::enumerate(op.getResults())) + it.value().replaceAllUsesWith(contArgs[it.index()]); + + // Capture blocks before moving the region. 
+ SmallVector movedBlocks; + movedBlocks.reserve(op.getRegion().getBlocks().size()); + for (Block &b : op.getRegion()) + movedBlocks.push_back(&b); + Block *entryBlock = &op.getRegion().front(); + + // Inline the execute_region blocks into the parent region right before the + // continuation block. + rewriter.inlineRegionBefore(op.getRegion(), *parentRegion, + continueBlock->getIterator()); + + // Replace all scf.yield terminators with a branch to the continuation. + for (Block *b : movedBlocks) { + auto yield = dyn_cast(b->getTerminator()); + if (!yield) + continue; + rewriter.setInsertionPoint(yield); + rewriter.create(loc, continueBlock, yield.getOperands()); + rewriter.eraseOp(yield); + } + + // Replace execute_region itself with a branch to the inlined entry block. + rewriter.setInsertionPoint(op); + rewriter.create(loc, entryBlock, ValueRange{}); + rewriter.eraseOp(op); + return success(); + } +}; + +// Lower scf.index_switch into CFG blocks with cf.cond_br/cf.br so that we can +// avoid `scf.if` result materialization quirks (and avoid relying on cf.switch, +// which is not supported by EmitC C++ translation). 
+struct SCFIndexSwitchToCF : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + static LogicalResult cloneYieldingBlockAndBranchTo( + PatternRewriter &rewriter, Location loc, Block &srcBlock, Block *destBlock, + Block *continueBlock) { + rewriter.setInsertionPointToEnd(destBlock); + + IRMapping mapping; + for (Operation &inner : srcBlock.without_terminator()) + rewriter.clone(inner, mapping); + + auto yield = dyn_cast(srcBlock.getTerminator()); + if (!yield) + return failure(); + + SmallVector yieldOperands; + yieldOperands.reserve(yield.getNumOperands()); + for (Value v : yield.getOperands()) + yieldOperands.push_back(mapping.lookupOrDefault(v)); + + rewriter.create(loc, continueBlock, yieldOperands); + return success(); + } + + static Block *splitBlockForContinuation(PatternRewriter &rewriter, + scf::IndexSwitchOp op) { + auto switchIt = Block::iterator(op.getOperation()); + return rewriter.splitBlock(op->getBlock(), std::next(switchIt)); + } + + static void addContinuationArguments(PatternRewriter &rewriter, + scf::IndexSwitchOp op, Location loc, + Block *continueBlock) { + SmallVector contArgs; + contArgs.reserve(op.getNumResults()); + for (Type type : op.getResultTypes()) + contArgs.push_back(continueBlock->addArgument(type, loc)); + for (auto result : llvm::enumerate(op.getResults())) + result.value().replaceAllUsesWith(contArgs[result.index()]); + } + + static void createIndexSwitchBlocks(PatternRewriter &rewriter, + Region *parentRegion, + Region::iterator insertPt, + unsigned numCases, + SmallVectorImpl &checkBlocks, + Block *&defaultBlock, + SmallVectorImpl &caseBlocks) { + checkBlocks.reserve(numCases); + caseBlocks.reserve(numCases); + for (unsigned i = 0; i < numCases; ++i) + checkBlocks.push_back(rewriter.createBlock(parentRegion, insertPt)); + defaultBlock = rewriter.createBlock(parentRegion, insertPt); + for (unsigned i = 0; i < numCases; ++i) + caseBlocks.push_back(rewriter.createBlock(parentRegion, insertPt)); + } + + static 
void populateIndexSwitchCheckBlocks( + PatternRewriter &rewriter, Location loc, Value selector, + ArrayRef cases, ArrayRef checkBlocks, + ArrayRef caseBlocks, Block *defaultBlock) { + for (unsigned i = 0; i < checkBlocks.size(); ++i) { + rewriter.setInsertionPointToEnd(checkBlocks[i]); + Value caseVal = rewriter.create(loc, cases[i]); + Value cond = rewriter.create( + loc, arith::CmpIPredicate::eq, selector, caseVal); + Block *falseDest = + (i + 1 < checkBlocks.size()) ? checkBlocks[i + 1] : defaultBlock; + rewriter.create(loc, cond, caseBlocks[i], ValueRange{}, + falseDest, ValueRange{}); + } + } + + LogicalResult matchAndRewrite(scf::IndexSwitchOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Operation *parentOp = op->getParentOp(); + if (parentOp && parentOp->hasTrait()) { + return rewriter.notifyMatchFailure( + op, "cannot lower scf.index_switch inside a single-block parent region"); + } + + Block *curBlock = op->getBlock(); + Region *parentRegion = curBlock->getParent(); + Block *continueBlock = splitBlockForContinuation(rewriter, op); + addContinuationArguments(rewriter, op, loc, continueBlock); + + unsigned numCases = op.getCases().size(); + auto insertPt = continueBlock->getIterator(); + + SmallVector checkBlocks; + SmallVector caseBlocks; + Block *defaultBlock = nullptr; + createIndexSwitchBlocks(rewriter, parentRegion, insertPt, numCases, + checkBlocks, defaultBlock, caseBlocks); + + Value selector = op.getArg(); + auto cases = op.getCases(); + populateIndexSwitchCheckBlocks(rewriter, loc, selector, cases, checkBlocks, + caseBlocks, defaultBlock); + + // Fill case blocks and default block with cloned bodies + branch to cont. 
+ for (unsigned i = 0; i < numCases; ++i) { + if (failed(cloneYieldingBlockAndBranchTo( + rewriter, loc, op.getCaseBlock(i), caseBlocks[i], continueBlock))) + return rewriter.notifyMatchFailure(op, "expected scf.yield terminator"); + } + if (failed(cloneYieldingBlockAndBranchTo(rewriter, loc, op.getDefaultBlock(), + defaultBlock, continueBlock))) + return rewriter.notifyMatchFailure(op, "expected scf.yield terminator"); + + // Replace the original switch op with a branch into the check chain. + Block *entryDest = numCases ? checkBlocks[0] : defaultBlock; + rewriter.setInsertionPointAfter(op); + rewriter.create(loc, entryDest, ValueRange{}); + rewriter.eraseOp(op); + return success(); + } +}; + +// Lower scf.while into CFG blocks with cf.br/cf.cond_br. +// +// Note: This requires the parent region to allow multiple blocks. In +// particular, scf.if/scf.for regions are single-block and cannot contain this +// lowering. +struct SCFWhileToCF : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + static LogicalResult validateWhileResultUses(scf::WhileOp op) { + Block *parentBlock = op->getBlock(); + for (Value result : op.getResults()) { + for (OpOperand &use : result.getUses()) { + if (use.getOwner()->getBlock() != parentBlock) + return failure(); + } + } + return success(); + } + + static Block *splitAfterWhileBlock(PatternRewriter &rewriter, + scf::WhileOp op) { + auto whileIt = Block::iterator(op.getOperation()); + return rewriter.splitBlock(op->getBlock(), std::next(whileIt)); + } + + static void addWhileExitArguments(PatternRewriter &rewriter, scf::WhileOp op, + Location loc, Block *afterWhileBlock) { + SmallVector exitArgs; + exitArgs.reserve(op.getNumResults()); + for (Type type : op.getResultTypes()) + exitArgs.push_back(afterWhileBlock->addArgument(type, loc)); + for (auto result : llvm::enumerate(op.getResults())) + result.value().replaceAllUsesWith(exitArgs[result.index()]); + } + + static Block *createWhileHeaderBlock(PatternRewriter 
&rewriter, + scf::WhileOp op, Location loc, + Block *afterWhileBlock) { + SmallVector headerArgTypes; + for (Value init : op.getInits()) + headerArgTypes.push_back(init.getType()); + SmallVector headerArgLocs(headerArgTypes.size(), loc); + return rewriter.createBlock(afterWhileBlock->getParent(), + afterWhileBlock->getIterator(), headerArgTypes, + headerArgLocs); + } + + static Block *createWhileBodyBlock(PatternRewriter &rewriter, scf::WhileOp op, + Location loc, Block *afterWhileBlock) { + Block &afterRegionBlock = op.getAfter().front(); + SmallVector bodyArgTypes(afterRegionBlock.getArgumentTypes().begin(), + afterRegionBlock.getArgumentTypes().end()); + SmallVector bodyArgLocs(bodyArgTypes.size(), loc); + return rewriter.createBlock(afterWhileBlock->getParent(), + afterWhileBlock->getIterator(), bodyArgTypes, + bodyArgLocs); + } + + static void rewriteWhileTerminators(PatternRewriter &rewriter, Location loc, + Block *headerBlock, Block *bodyBlock, + Block *afterWhileBlock) { + auto condOp = cast(headerBlock->getTerminator()); + rewriter.setInsertionPoint(condOp); + rewriter.create(loc, condOp.getCondition(), + /*trueDest=*/bodyBlock, + /*trueOperands=*/condOp.getArgs(), + /*falseDest=*/afterWhileBlock, + /*falseOperands=*/condOp.getArgs()); + rewriter.eraseOp(condOp); + + auto yieldOp = cast(bodyBlock->getTerminator()); + rewriter.setInsertionPoint(yieldOp); + rewriter.create(loc, headerBlock, yieldOp.getOperands()); + rewriter.eraseOp(yieldOp); + } + + LogicalResult matchAndRewrite(scf::WhileOp op, + PatternRewriter &rewriter) const override { + Operation *parentOp = op->getParentOp(); + if (parentOp && parentOp->hasTrait()) { + return rewriter.notifyMatchFailure( + op, "cannot lower scf.while inside a single-block parent region"); + } + + if (failed(validateWhileResultUses(op))) + return rewriter.notifyMatchFailure( + op, "unsupported: while results used outside the parent block"); + + auto loc = op.getLoc(); + Block *afterWhileBlock = 
splitAfterWhileBlock(rewriter, op); + addWhileExitArguments(rewriter, op, loc, afterWhileBlock); + Block *headerBlock = createWhileHeaderBlock(rewriter, op, loc, + afterWhileBlock); + Block *bodyBlock = createWhileBodyBlock(rewriter, op, loc, afterWhileBlock); + + // Move the before/after region bodies into the new CFG blocks. + Block &afterRegionBlock = op.getAfter().front(); + rewriter.mergeBlocks(&op.getBefore().front(), headerBlock, + headerBlock->getArguments()); + rewriter.mergeBlocks(&afterRegionBlock, bodyBlock, bodyBlock->getArguments()); + rewriteWhileTerminators(rewriter, loc, headerBlock, bodyBlock, + afterWhileBlock); + + // Replace scf.while itself with a branch to the header. + rewriter.setInsertionPoint(op); + rewriter.create(loc, headerBlock, op.getInits()); + rewriter.eraseOp(op); + return success(); + } +}; + +// Lower cf.switch into chained comparisons and cf.cond_br/cf.br. +// +// EmitC C++ translation currently supports cf.br/cf.cond_br, but not cf.switch. +struct CFSwitchToCondBr : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + static SmallVector> + collectSwitchCaseOperands(cf::SwitchOp op) { + SmallVector> caseOperands; + caseOperands.reserve(op.getCaseDestinations().size()); + for (auto range : op.getCaseOperands()) + caseOperands.emplace_back(range.begin(), range.end()); + return caseOperands; + } + + static SmallVector getSwitchCaseValues(cf::SwitchOp op) { + SmallVector caseValues; + if (auto caseValuesAttr = op.getCaseValues()) { + for (APInt value : caseValuesAttr->getValues()) + caseValues.push_back(value); + } + return caseValues; + } + + static SmallVector createSwitchCheckBlocks(PatternRewriter &rewriter, + Region *parentRegion, + Block *curBlock, + size_t numCases) { + auto insertPt = std::next(curBlock->getIterator()); + SmallVector checkBlocks; + checkBlocks.reserve(numCases); + for (size_t i = 0; i < numCases; ++i) + checkBlocks.push_back(rewriter.createBlock(parentRegion, insertPt)); + return 
checkBlocks; + } + + static LogicalResult populateSwitchCheckBlocks( + PatternRewriter &rewriter, Location loc, Value flag, IntegerType flagTy, + ArrayRef caseValues, ArrayRef caseDests, + ArrayRef> caseOperands, Block *defaultDest, + ValueRange defaultOperands, ArrayRef checkBlocks, + cf::SwitchOp op) { + for (size_t i = 0; i < caseDests.size(); ++i) { + rewriter.setInsertionPointToEnd(checkBlocks[i]); + APInt caseVal = caseValues[i]; + if (caseVal.getBitWidth() != flagTy.getWidth()) { + return rewriter.notifyMatchFailure( + op, "case value bitwidth doesn't match flag type"); + } + + Value caseConst = rewriter.create( + loc, flagTy, rewriter.getIntegerAttr(flagTy, caseVal)); + Value cond = rewriter.create( + loc, arith::CmpIPredicate::eq, flag, caseConst); + Block *falseDest = + (i + 1 < checkBlocks.size()) ? checkBlocks[i + 1] : defaultDest; + ValueRange falseOperands = + (i + 1 < checkBlocks.size()) ? ValueRange{} : defaultOperands; + rewriter.create(loc, cond, caseDests[i], + caseOperands[i], falseDest, + falseOperands); + } + return success(); + } + + LogicalResult matchAndRewrite(cf::SwitchOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Operation *parentOp = op->getParentOp(); + if (parentOp && parentOp->hasTrait()) { + return rewriter.notifyMatchFailure( + op, "cannot lower cf.switch inside a single-block parent region"); + } + + Block *curBlock = op->getBlock(); + Region *parentRegion = curBlock->getParent(); + + Value flag = op.getFlag(); + auto flagTy = dyn_cast(flag.getType()); + if (!flagTy) + return rewriter.notifyMatchFailure(op, "expected integer switch flag"); + + SmallVector defaultOperands(op.getDefaultOperands().begin(), + op.getDefaultOperands().end()); + Block *defaultDest = op.getDefaultDestination(); + + SmallVector caseDests(op.getCaseDestinations().begin(), + op.getCaseDestinations().end()); + SmallVector> caseOperands = collectSwitchCaseOperands(op); + + if (caseDests.empty()) { + 
rewriter.replaceOpWithNewOp(op, defaultDest, defaultOperands); + return success(); + } + + if (!op.getCaseValues()) + return rewriter.notifyMatchFailure(op, "missing case_values"); + SmallVector caseValues = getSwitchCaseValues(op); + + if (caseValues.size() != caseDests.size()) + return rewriter.notifyMatchFailure(op, "case_values/destinations mismatch"); + if (caseOperands.size() != caseDests.size()) + return rewriter.notifyMatchFailure(op, "case_operands/destinations mismatch"); + + SmallVector checkBlocks = + createSwitchCheckBlocks(rewriter, parentRegion, curBlock, + caseDests.size()); + if (failed(populateSwitchCheckBlocks(rewriter, loc, flag, flagTy, + caseValues, caseDests, caseOperands, + defaultDest, defaultOperands, + checkBlocks, op))) { + return failure(); + } + + // Replace the switch terminator with a branch into the first check block. + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp(op, checkBlocks.front(), + ValueRange{}); + return success(); + } +}; + +} // namespace + + +} // namespace + +LogicalResult runPTOToEmitCSCFPreLowering(ModuleOp mop, MLIRContext *ctx) { + bool needsAnySCFToCF = false; + for (auto func : mop.getOps()) { + if (needsWholeFunctionSCFToCF(func)) { + needsAnySCFToCF = true; + break; + } + } + if (needsAnySCFToCF) { + RewritePatternSet scfToCfPatterns(ctx); + populateSCFToControlFlowConversionPatterns(scfToCfPatterns); + FrozenRewritePatternSet frozenSCFToCF(std::move(scfToCfPatterns)); + + ConversionTarget scfToCfTarget(*ctx); + scfToCfTarget.addIllegalOp(); + scfToCfTarget.markUnknownOpDynamicallyLegal( + [](Operation *) { return true; }); + + for (auto func : mop.getOps()) { + if (!needsWholeFunctionSCFToCF(func)) + continue; + if (failed(applyPartialConversion(func, scfToCfTarget, + frozenSCFToCF))) { + func.emitError() + << "failed to lower nested SCF to ControlFlow (SCFToCF)"; + return failure(); + } + } + } + + RewritePatternSet scfLoweringPatterns(ctx); + scfLoweringPatterns.add(ctx); + 
(void)applyPatternsAndFoldGreedily(mop, std::move(scfLoweringPatterns)); + + bool hasUnsupportedSCF = false; + mop.walk([&](Operation *op) { + if (isa(op)) { + hasUnsupportedSCF = true; + op->emitError() << "Unsupported SCF op remained after pre-lowering"; + return WalkResult::interrupt(); + } + if (isa(op)) { + hasUnsupportedSCF = true; + op->emitError() + << "Unsupported CF op remained after pre-lowering: cf.switch"; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return failure(hasUnsupportedSCF); +} + +void populatePTOToEmitCControlFlowPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter, + MLIRContext *ctx) { + patterns.add>(typeConverter, ctx); + patterns.add>(typeConverter, ctx); + patterns.add(typeConverter, ctx); + populateSCFToEmitCConversionPatterns(patterns); + populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter); +} + +} // namespace mlir::pto diff --git a/lib/PTO/Transforms/PTOToEmitCInternal.h b/lib/PTO/Transforms/PTOToEmitCInternal.h index 0d43b8a1b..e8be34ed2 100644 --- a/lib/PTO/Transforms/PTOToEmitCInternal.h +++ b/lib/PTO/Transforms/PTOToEmitCInternal.h @@ -9,16 +9,140 @@ #ifndef MLIR_DIALECT_PTO_TRANSFORMS_PTOTOEMITCINTERNAL_H #define MLIR_DIALECT_PTO_TRANSFORMS_PTOTOEMITCINTERNAL_H +#include "PTO/IR/PTO.h" + #include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Value.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" +#include +#include + namespace mlir::pto { +Value peelUnrealized(Value v); + +Value makeEmitCOpaqueConstant(ConversionPatternRewriter &rewriter, + Location loc, Type type, + llvm::StringRef literal); + +Value makeEmitCIntConstant(ConversionPatternRewriter &rewriter, Location loc, + Type type, int64_t value); + +Value emitCCast(ConversionPatternRewriter &rewriter, Location loc, Type dstType, + Value src); + +Value castSignlessIntToUnsignedSameWidth(ConversionPatternRewriter &rewriter, + Location loc, Value v, + unsigned bitWidth); + 
+std::string getEmitCScalarTypeToken(Type elemTy); + +pto::BLayout getTileBufBLayoutValue(pto::TileBufConfigAttr configAttr); + +int64_t renderTileTemplateDim(int64_t rawDim, Type elemTy, + pto::BLayout blayout, int dimIdx); + +std::optional getEmitCTileTypeString(pto::TileBufType type); + +bool isSetFFTsPointerLikeType(Type ty); + +bool isEmitCGlobalTensorLikeType(Type ty); + +std::string getGlobalTensorTypeStringFromShapeAndStrides( + Type elemTy, ArrayRef shape, ArrayRef strides, + llvm::StringRef layoutEnum = "pto::Layout::ND"); + +std::string getElemTypeStringForGT(Type elemTy); + +SmallVector buildRowMajorStrides(ArrayRef shape); + +void buildGlobalTensorShapeAndStride(ArrayRef shape, + ArrayRef strides, + SmallVectorImpl &shape5D, + SmallVectorImpl &stride5D); + +std::string joinIntTemplateParams(ArrayRef values); + +Value buildGlobalTensorFromMemref(ConversionPatternRewriter &rewriter, + Location loc, Value basePtr, + MemRefType mrTy, Operation *anchor); + +FailureOr buildTPipeTokenFromInitOp(Operation *op, + PTOArch targetArch); + +Value castToGMBytePointer(ConversionPatternRewriter &rewriter, Location loc, + Value value); + +Value materializeTensorViewDataPointer(ConversionPatternRewriter &rewriter, + Location loc, Value value, + Type originalType); + +Value materializeAddressAsPointer(ConversionPatternRewriter &rewriter, + Location loc, Value addr, + pto::AddressSpace as, + llvm::StringRef elemTok); + +Value applyStaticMemrefOffset(ConversionPatternRewriter &rewriter, + Location loc, Value basePtr, int64_t offset); + +FailureOr buildAsyncScratchTileValue(ConversionPatternRewriter &rewriter, + Location loc, Value originalScratch, + Value emittedScratch); + +bool needsA5NoSplitVectorGuard(Operation *op); + +Value materializeTileDataValue(ConversionPatternRewriter &rewriter, + Location loc, Value tile, + pto::AddressSpace as, + llvm::StringRef elemTypeToken); + void populatePTOToEmitCArithPatterns(RewritePatternSet &patterns, TypeConverter &typeConverter, 
MLIRContext *ctx); +void populatePTOToEmitCTilePatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter, + MLIRContext *ctx); + +void populatePTOToEmitCTileExtraPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter, + MLIRContext *ctx); + +void populatePTOToEmitCTileMaterializationPatterns( + RewritePatternSet &patterns, TypeConverter &typeConverter, MLIRContext *ctx); + +void populatePTOToEmitCSyncPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter, + MLIRContext *ctx, PTOArch targetArch); + +void populatePTOToEmitCCommPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter, + MLIRContext *ctx, PTOArch targetArch); + +void populatePTOToEmitCKernelOpPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter, + MLIRContext *ctx); + +LogicalResult runPTOToEmitCSCFPreLowering(ModuleOp mop, MLIRContext *ctx); + +void populatePTOToEmitCControlFlowPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter, + MLIRContext *ctx); + +void populatePTOToEmitCSimpleOpPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter, + MLIRContext *ctx); + +void populatePTOToEmitCRuntimeOpPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter, + MLIRContext *ctx, PTOArch targetArch); + +void populatePTOToEmitCMemoryOpPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter, + MLIRContext *ctx); + } // namespace mlir::pto #endif // MLIR_DIALECT_PTO_TRANSFORMS_PTOTOEMITCINTERNAL_H diff --git a/lib/PTO/Transforms/PTOToEmitCKernelOps.cpp b/lib/PTO/Transforms/PTOToEmitCKernelOps.cpp new file mode 100644 index 000000000..e0a80102d --- /dev/null +++ b/lib/PTO/Transforms/PTOToEmitCKernelOps.cpp @@ -0,0 +1,516 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). 
+// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- PTOToEmitCKernelOps.cpp --------------------------------------------===// +//===----------------------------------------------------------------------===// + +#include "PTOToEmitCInternal.h" + +#include "PTO/IR/PTO.h" + +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" + +#include + +using namespace mlir; +using namespace mlir::pto; + +namespace mlir::pto { +namespace { + +struct PTOTLoadToTLOAD : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TLoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!op.getDst()) + return rewriter.notifyMatchFailure(op, "expected outs(dst) on pto.tload"); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + Value srcArg = src; + if (auto srcMrTy = dyn_cast(op.getSrc().getType())) { + bool isGlobal = true; + if (auto asAttr = dyn_cast_or_null(srcMrTy.getMemorySpace())) { + auto as = asAttr.getAddressSpace(); + isGlobal = (as == pto::AddressSpace::GM || as == pto::AddressSpace::Zero); + } + if (isGlobal) { + if (Value gt = buildGlobalTensorFromMemref(rewriter, op.getLoc(), src, srcMrTy, + op.getOperation())) + srcArg = gt; + } + } + + rewriter.create( + op.getLoc(), TypeRange{}, "TLOAD", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, srcArg}); + + if (op->getNumResults() == 1) 
{ + rewriter.replaceOp(op, dst); + } else { + rewriter.eraseOp(op); + } + return success(); + } +}; + +struct PTOTPrefetchToTPREFETCH : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TPrefetchOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!op.getDst()) + return rewriter.notifyMatchFailure(op, "expected outs(dst) on pto.tprefetch"); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + Value srcArg = src; + if (auto srcMrTy = dyn_cast(op.getSrc().getType())) { + bool isGlobal = true; + if (auto asAttr = dyn_cast_or_null(srcMrTy.getMemorySpace())) { + auto as = asAttr.getAddressSpace(); + isGlobal = (as == pto::AddressSpace::GM || as == pto::AddressSpace::Zero); + } + if (isGlobal) { + if (Value gt = buildGlobalTensorFromMemref(rewriter, op.getLoc(), src, srcMrTy, + op.getOperation())) + srcArg = gt; + } + } + + rewriter.create( + op.getLoc(), TypeRange{}, "TPREFETCH", + ArrayAttr{}, ArrayAttr{}, ValueRange{dst, srcArg}); + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOTPrefetchAsyncToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TPrefetchAsyncOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src = peelUnrealized(adaptor.getSrc()); + Value srcArg = src; + if (!isEmitCGlobalTensorLikeType(srcArg.getType())) { + auto srcMrTy = dyn_cast(op.getSrc().getType()); + if (!srcMrTy) + return rewriter.notifyMatchFailure( + op, "expected src to lower to GlobalTensor or memref"); + srcArg = buildGlobalTensorFromMemref(rewriter, op.getLoc(), src, srcMrTy, + op.getSrc().getDefiningOp() + ? 
op.getSrc().getDefiningOp() + : op.getOperation()); + } + if (!srcArg) + return rewriter.notifyMatchFailure(op, + "failed to build GlobalTensor src"); + + Value prefetchCtx = peelUnrealized(adaptor.getCtx()); + + Type eventTy = getTypeConverter()->convertType(op.getEvent().getType()); + if (!eventTy) + return rewriter.notifyMatchFailure( + op, "failed to convert tprefetch_async result type"); + + Value event = rewriter + .create( + op.getLoc(), TypeRange{eventTy}, "TPREFETCH_ASYNC", + ArrayAttr{}, ArrayAttr{}, + ValueRange{srcArg, prefetchCtx}) + .getResult(0); + + rewriter.replaceOp(op, ValueRange{event}); + return success(); + } +}; + +struct PTOMakePrefetchAsyncContextToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::MakePrefetchAsyncContextOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type ctxTy = getTypeConverter()->convertType(op.getCtx().getType()); + if (!ctxTy) + return rewriter.notifyMatchFailure( + op, "failed to convert make_prefetch_async_context result type"); + + Value workspace = peelUnrealized(adaptor.getWorkspace()); + workspace = castToGMBytePointer(rewriter, op.getLoc(), workspace); + + Value ctx = rewriter + .create( + op.getLoc(), TypeRange{ctxTy}, "pto::PrefetchAsyncContext", + ArrayAttr{}, ArrayAttr{}, ValueRange{workspace}) + .getResult(0); + + rewriter.replaceOp(op, ValueRange{ctx}); + return success(); + } +}; + +struct PTOGetPrefetchAsyncSessionToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::GetPrefetchAsyncSessionOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type sessionTy = getTypeConverter()->convertType(op.getSession().getType()); + if (!sessionTy) + return rewriter.notifyMatchFailure( + op, "failed to convert get_prefetch_async_session result type"); + + Value ctx = 
peelUnrealized(adaptor.getCtx()); + Value session = rewriter + .create( + op.getLoc(), TypeRange{sessionTy}, + "PTOAS__PREFETCH_CTX_SESSION", ArrayAttr{}, + ArrayAttr{}, ValueRange{ctx}) + .getResult(0); + + rewriter.replaceOp(op, ValueRange{session}); + return success(); + } +}; + +struct PTOTStoreToTSTORE : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + static std::string stPhaseTok(pto::STPhase phase) { + switch (phase) { + case pto::STPhase::Unspecified: return "STPhase::Unspecified"; + case pto::STPhase::Partial: return "STPhase::Partial"; + case pto::STPhase::Final: return "STPhase::Final"; + } + return "STPhase::Unspecified"; + } + + static std::string atomicTypeTok(pto::AtomicType atomicType) { + switch (atomicType) { + case pto::AtomicType::AtomicNone: return "AtomicType::AtomicNone"; + case pto::AtomicType::AtomicAdd: return "AtomicType::AtomicAdd"; + } + return "AtomicType::AtomicNone"; + } + + static std::string reluPreModeTok(pto::ReluPreMode reluPreMode) { + switch (reluPreMode) { + case pto::ReluPreMode::NoRelu: return "ReluPreMode::NoRelu"; + case pto::ReluPreMode::NormalRelu: return "ReluPreMode::NormalRelu"; + } + return "ReluPreMode::NoRelu"; + } + + LogicalResult matchAndRewrite(pto::TStoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!op.getDst()) + return rewriter.notifyMatchFailure(op, "expected outs(dst) on pto.tstore"); + + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + Value preQuantScalar; + if (op.getPreQuantScalar()) + preQuantScalar = peelUnrealized(adaptor.getPreQuantScalar()); + Value dstArg = dst; + if (auto dstMrTy = dyn_cast(op.getDst().getType())) { + bool isGlobal = true; + if (auto asAttr = dyn_cast_or_null(dstMrTy.getMemorySpace())) { + auto as = asAttr.getAddressSpace(); + isGlobal = (as == pto::AddressSpace::GM || as == 
pto::AddressSpace::Zero); + } + if (isGlobal) { + if (Value gt = buildGlobalTensorFromMemref(rewriter, op.getLoc(), dst, dstMrTy, + op.getOperation())) + dstArg = gt; + } + } + + const auto phase = op.getStPhase(); + const auto atomicType = op.getAtomicType(); + const auto reluPreMode = op.getReluPreMode(); + const bool hasPreQuantScalar = static_cast(preQuantScalar); + const bool phaseNonDefault = phase != pto::STPhase::Unspecified; + const bool atomicNonDefault = atomicType != pto::AtomicType::AtomicNone; + const bool reluNonDefault = reluPreMode != pto::ReluPreMode::NoRelu; + + auto getOpaqueTok = [&](Value v, StringRef name) -> FailureOr { + if (auto ot = mlir::dyn_cast(v.getType())) + return ot.getValue().str(); + return rewriter.notifyMatchFailure(op, (name + " must be emitc::OpaqueType").str()); + }; + + ArrayAttr targs; + // Map op attributes/operands to the exact TSTORE overload family: + // 1) TSTORE(dst, src) + // 2) TSTORE(dst, src) + // 3) TSTORE(dst, src) + // 4) TSTORE(dst, src) + // 5) TSTORE(dst, src) + // 6) TSTORE(dst, src) + // 7) TSTORE(dst, src, preQuant) + // 8) TSTORE(dst, src, preQuant) + if (!hasPreQuantScalar && !reluNonDefault && !atomicNonDefault) { + if (phaseNonDefault) { + targs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, stPhaseTok(phase)), + }); + } else { + targs = ArrayAttr{}; + } + } else { + auto srcTokOr = getOpaqueTok(src, "src"); + auto dstTokOr = getOpaqueTok(dstArg, "dst"); + if (failed(srcTokOr) || failed(dstTokOr)) + return failure(); + + // If there is no preQuant and relu stays default, emit the atomic-only + // overloads (#3/#4) without ReluPreMode template argument. 
+ if (!hasPreQuantScalar && !reluNonDefault) { + if (phaseNonDefault) { + targs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, stPhaseTok(phase)), + emitc::OpaqueAttr::get(ctx, *srcTokOr), + emitc::OpaqueAttr::get(ctx, *dstTokOr), + emitc::OpaqueAttr::get(ctx, atomicTypeTok(atomicType)), + }); + } else { + targs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, *srcTokOr), + emitc::OpaqueAttr::get(ctx, *dstTokOr), + emitc::OpaqueAttr::get(ctx, atomicTypeTok(atomicType)), + }); + } + } else { + // Relu/preQuant families (#5/#6/#7/#8): keep AtomicType + ReluPreMode. + if (phaseNonDefault) { + targs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, stPhaseTok(phase)), + emitc::OpaqueAttr::get(ctx, *srcTokOr), + emitc::OpaqueAttr::get(ctx, *dstTokOr), + emitc::OpaqueAttr::get(ctx, atomicTypeTok(atomicType)), + emitc::OpaqueAttr::get(ctx, reluPreModeTok(reluPreMode)), + }); + } else { + targs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, *srcTokOr), + emitc::OpaqueAttr::get(ctx, *dstTokOr), + emitc::OpaqueAttr::get(ctx, atomicTypeTok(atomicType)), + emitc::OpaqueAttr::get(ctx, reluPreModeTok(reluPreMode)), + }); + } + } + } + + SmallVector operands{dstArg, src}; + if (hasPreQuantScalar) + operands.push_back(preQuantScalar); + + rewriter.create( + loc, TypeRange{}, "TSTORE", + /*args=*/ArrayAttr{}, /*templateArgs=*/targs, + /*operands=*/operands); + + if (op->getNumResults() == 1) { + rewriter.replaceOp(op, dst); + } else { + rewriter.eraseOp(op); + } + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// pto.matmul_dps lowering (Simplified: No internal copy/sync) +//===----------------------------------------------------------------------===// +// +// Render `pto.tmatmul` as one of three forms depending on the optional +// `acc_phase` attribute: +// * absent / Unspecified -> `TMATMUL(dst, lhs, rhs)` +// * Partial -> `TMATMUL(dst, lhs, rhs)` +// * Final -> `TMATMUL(dst, lhs, rhs)` 
+// The Unspecified default keeps backward compatibility with all upstream IR +// that does not yet emit an explicit phase attribute. Returns an empty ArrayAttr for Unspecified, otherwise a one-element ArrayAttr holding the opaque token "AccPhase::Partial" or "AccPhase::Final". +static ArrayAttr buildAccPhaseTemplateArgs(ConversionPatternRewriter &rewriter, + pto::AccPhase phase) { + StringRef tmpl; + switch (phase) { + case pto::AccPhase::Unspecified: + return ArrayAttr{}; + case pto::AccPhase::Partial: + tmpl = "AccPhase::Partial"; + break; + case pto::AccPhase::Final: + tmpl = "AccPhase::Final"; + break; + } + if (tmpl.empty()) // reached only for a phase value the switch above does not map + return ArrayAttr{}; + return rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(rewriter.getContext(), tmpl)}); +} + +struct PTOTMatmulToTMATMUL : public OpConversionPattern { // Lowers pto::TMatmulOp to a call named "TMATMUL" with operands (dst, lhs, rhs). + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMatmulOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // 1. Fetch the operands (peel casts). NOTE(review): unlike the .acc patterns there is no guard on a missing dst here - confirm dst is mandatory for this op. + Value lhs = peelUnrealized(adaptor.getLhs()); // A (Left) + Value rhs = peelUnrealized(adaptor.getRhs()); // B (Right) + Value dst = peelUnrealized(adaptor.getDst()); // C (Acc) + + // 2. The acc_phase attribute decides whether TMATMUL(...) is emitted with an AccPhase template argument + ArrayAttr templateArgs = + buildAccPhaseTemplateArgs(rewriter, op.getAccPhase()); + + rewriter.create( + op.getLoc(), TypeRange{}, "TMATMUL", + /*args=*/ArrayAttr{}, /*template_args=*/templateArgs, + ValueRange{dst, lhs, rhs}); + + // 3. Replace the op with dst when it has exactly one result, otherwise erase it + if (op->getNumResults() == 1) { + rewriter.replaceOp(op, dst); + } else { + rewriter.eraseOp(op); + } + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// pto.tgemv lowering: emits a call named "TGEMV" with operands (dst, lhs, rhs) +//===----------------------------------------------------------------------===// +struct PTOTGemvToTGEMV : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TGemvOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // 1. 
Fetch the operands (peel casts). NOTE(review): no guard on a missing dst here either - confirm dst is mandatory for this op. + Value lhs = peelUnrealized(adaptor.getLhs()); // A (Matrix) + Value rhs = peelUnrealized(adaptor.getRhs()); // B (Vector) + Value dst = peelUnrealized(adaptor.getDst()); // C (Result) + + // 2. Emit the function call TGEMV(dst, lhs, rhs) directly + rewriter.create( + op.getLoc(), TypeRange{}, "TGEMV", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, lhs, rhs}); + + // 3. Replace the op with dst when it has exactly one result, otherwise erase it + if (op->getNumResults() == 1) { + rewriter.replaceOp(op, dst); + } else { + rewriter.eraseOp(op); + } + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// pto.tgemv.acc lowering: emits a call named "TGEMV_ACC" with operands (dst, accIn, lhs, rhs) +//===----------------------------------------------------------------------===// +struct PTOTGemvAccToTGEMVACC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TGemvAccOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!op.getDst()) + return rewriter.notifyMatchFailure(op, "expected outs(dst) for pto.tgemv.acc"); + + // 1. Fetch the operands (peel casts) + Value accIn = peelUnrealized(adaptor.getAccIn()); // AccOld + Value lhs = peelUnrealized(adaptor.getLhs()); // A (Matrix) + Value rhs = peelUnrealized(adaptor.getRhs()); // B (Vector) + Value dst = peelUnrealized(adaptor.getDst()); // AccNew + + // 2. Emit the function call TGEMV_ACC(dst, accIn, lhs, rhs) directly + rewriter.create( + op.getLoc(), TypeRange{}, "TGEMV_ACC", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, accIn, lhs, rhs}); + + // 3. 
Replace the op with dst when it has exactly one result, otherwise erase it + if (op->getNumResults() == 1) { + rewriter.replaceOp(op, dst); + } else { + rewriter.eraseOp(op); + } + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// pto.tmatmul.acc lowering (simplified: no internal copy/sync); emits "TMATMUL_ACC"(dst, accIn, lhs, rhs) +//===----------------------------------------------------------------------===// +struct PTOTMatmulAccToTMATMULACC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMatmulAccOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!op.getDst()) + return rewriter.notifyMatchFailure(op, "expected outs(dst) for pto.tmatmul.acc"); + + // 1. Fetch the operands (peel casts) + Value accIn = peelUnrealized(adaptor.getAccIn()); // AccOld + Value lhs = peelUnrealized(adaptor.getLhs()); // A (Left) + Value rhs = peelUnrealized(adaptor.getRhs()); // B (Right) + Value dst = peelUnrealized(adaptor.getDst()); // AccNew + + // 2. The acc_phase attribute decides whether TMATMUL_ACC(...) is emitted with an AccPhase template argument + ArrayAttr templateArgs = + buildAccPhaseTemplateArgs(rewriter, op.getAccPhase()); + + rewriter.create( + op.getLoc(), TypeRange{}, "TMATMUL_ACC", + /*args=*/ArrayAttr{}, /*template_args=*/templateArgs, + ValueRange{dst, accIn, lhs, rhs}); + + // 3. 
Replace the op with dst when it has exactly one result, otherwise erase it + if (op->getNumResults() == 1) { + rewriter.replaceOp(op, dst); + } else { + rewriter.eraseOp(op); + } + return success(); + } +}; + + +} // namespace + +void populatePTOToEmitCKernelOpPatterns(RewritePatternSet &patterns, // Adds this file's kernel-op lowering patterns to `patterns`. + TypeConverter &typeConverter, + MLIRContext *ctx) { + patterns.add(typeConverter, ctx); // each pattern is constructed from (typeConverter, ctx) + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); +} + +} // namespace mlir::pto diff --git a/lib/PTO/Transforms/PTOToEmitCMemoryOps.cpp b/lib/PTO/Transforms/PTOToEmitCMemoryOps.cpp new file mode 100644 index 000000000..dba225b3c --- /dev/null +++ b/lib/PTO/Transforms/PTOToEmitCMemoryOps.cpp @@ -0,0 +1,597 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +//===- PTOToEmitCMemoryOps.cpp --------------------------------------------===// +//===----------------------------------------------------------------------===// + +#include "PTOToEmitCInternal.h" + +#include "PTO/IR/PTO.h" +#include "PTO/IR/PTOTypeUtils.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" + +#include +#include + +using namespace mlir; +using namespace mlir::pto; + +namespace mlir::pto { +namespace { + +static constexpr llvm::StringLiteral kForceDynamicValidShapeAttrName = // op attribute that forces both valid dims to be emitted as dynamic (-1) + "__pto.force_dynamic_valid_shape"; + +struct PointerCastConversion : public OpConversionPattern { // Lowers pto::PointerCastOp to a Tile value (variable or constructor call) followed by a "TASSIGN" call. + static bool getIndexConst(Value v, int64_t &out) { // NOTE(review): not called anywhere in the visible part of this struct; overlaps with isConstant below - confirm it is still needed. + if (auto cst = v.getDefiningOp()) { + if (auto ia = dyn_cast(cst.getValue())) { + out = ia.getValue().getSExtValue(); + return true; + } + } + return false; + } + + using OpConversionPattern::OpConversionPattern; + + enum class TileRole { Vec, Mat, Left, Right, Acc, Bias, Scaling }; + + static void collectUserOpsThroughCasts(Value v, SmallVectorImpl &out) { // collects the users of v, looking through cast ops recursively + for (Operation *u : v.getUsers()) { + if (auto castOp = dyn_cast(u)) { + for (Value r : castOp.getResults()) + collectUserOpsThroughCasts(r, out); + continue; + } + out.push_back(u); + } + } + + static Value peelUnrealized(Value v) { // walks up through defining cast ops via operand 0 + while (auto castOp = v.getDefiningOp()) { + v = castOp.getOperand(0); + } + return v; + } + + static TileRole inferRole(pto::PointerCastOp op) { + // 1. 
Prefer the AddressSpace carried by the result type's memory space + if (auto memRefTy = dyn_cast(op.getType())) { + Attribute memorySpace = memRefTy.getMemorySpace(); + if (auto ptoAttr = dyn_cast_or_null(memorySpace)) { + switch (ptoAttr.getAddressSpace()) { + case pto::AddressSpace::LEFT: return TileRole::Left; + case pto::AddressSpace::RIGHT: return TileRole::Right; + case pto::AddressSpace::ACC: return TileRole::Acc; + case pto::AddressSpace::BIAS: return TileRole::Bias; + case pto::AddressSpace::MAT: return TileRole::Mat; + case pto::AddressSpace::SCALING: return TileRole::Scaling; + default: break; + } + } + } + + // 2. Infer the role from how the result is used (fallback) + SmallVector users; + collectUserOpsThroughCasts(op.getResult(), users); + + for (Operation *user : users) { + if (auto mm = dyn_cast(user)) { + if (mm.getDst() && peelUnrealized(mm.getDst()) == op.getResult()) return TileRole::Acc; + if (peelUnrealized(mm.getLhs()) == op.getResult()) return TileRole::Left; + if (peelUnrealized(mm.getRhs()) == op.getResult()) return TileRole::Right; + } + if (auto mmacc = dyn_cast(user)) { + if (mmacc.getDst() && peelUnrealized(mmacc.getDst()) == op.getResult()) return TileRole::Acc; + if (peelUnrealized(mmacc.getAccIn()) == op.getResult()) return TileRole::Acc; + if (peelUnrealized(mmacc.getLhs()) == op.getResult()) return TileRole::Left; + if (peelUnrealized(mmacc.getRhs()) == op.getResult()) return TileRole::Right; + } + } + + return TileRole::Vec; + } + + // Helper: reports whether a Value comes from arith.constant with an integer attribute; the value is returned through outVal + static bool isConstant(Value v, int64_t &outVal) { + if (!v) return false; + if (auto cst = v.getDefiningOp()) { + if (auto attr = dyn_cast(cst.getValue())) { + outVal = attr.getInt(); + return true; + } + } + return false; + } + + LogicalResult matchAndRewrite(pto::PointerCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + auto selfType = mlir::cast(op.getType()); + ArrayRef shape = selfType.getShape(); // NOTE(review): shape[0] and shape[1] are read below without a rank check - confirm the op verifier guarantees rank >= 2. + Type elemType = 
selfType.getElementType(); + + // 1. Infer the tile role + TileRole role = inferRole(op); + + // 2. Build the type strings (elemTypeStr, dimStr) + std::string elemTypeStr = getEmitCScalarTypeToken(elemType); + + std::string dimStr; + pto::BLayout blayout = pto::BLayout::RowMajor; + auto dimToString = [&](int64_t dim, const char *symbol, // dynamic dims print as the symbol, static ones via renderTileTemplateDim + int dimIdx) -> std::string { + if (dim == ShapedType::kDynamic) + return std::string(symbol); + return std::to_string(renderTileTemplateDim(dim, elemType, blayout, + dimIdx)); + }; + + // 3. Role Token + const char *roleTok = "TileType::Vec"; + switch (role) { + case TileRole::Left: roleTok = "TileType::Left"; break; + case TileRole::Right: roleTok = "TileType::Right"; break; + case TileRole::Acc: roleTok = "TileType::Acc"; break; + case TileRole::Bias: roleTok = "TileType::Bias"; break; + case TileRole::Mat: roleTok = "TileType::Mat"; break; + case TileRole::Vec: roleTok = "TileType::Vec"; break; + case TileRole::Scaling: roleTok = "TileType::Scaling"; break; + } + + // 4. Config & Layout (support BLayoutAttr/SLayoutAttr/PadValueAttr after namespace change) + std::string layoutParams = "BLayout::RowMajor"; + std::string extraParams = ""; + if (auto configOpt = op.getConfig()) { + auto config = *configOpt; + int32_t blVal = 0; + if (auto attr = dyn_cast(config.getBLayout())) + blVal = static_cast(attr.getValue()); + + if (blVal == 1) layoutParams = "BLayout::ColMajor"; + blayout = blVal == 1 ? pto::BLayout::ColMajor : pto::BLayout::RowMajor; + + int32_t slVal = 0; + if (auto attr = dyn_cast(config.getSLayout())) + slVal = static_cast(attr.getValue()); + + std::string slStr = (slVal == 1) ? "SLayout::RowMajor" : (slVal == 2) ? 
"SLayout::ColMajor" : "SLayout::NoneBox"; + + int32_t frVal = 0; + if (auto attr = dyn_cast(config.getSFractalSize())) frVal = attr.getInt(); + + int32_t padVal = 0; + if (auto attr = dyn_cast(config.getPad())) + padVal = static_cast(attr.getValue()); + + std::string padStr = "PadValue::Null"; + switch (padVal) { + case 1: padStr = "PadValue::Zero"; break; + case 2: padStr = "PadValue::Max"; break; + case 3: padStr = "PadValue::Min"; break; + } + + int32_t compactVal = 0; + if (auto attr = dyn_cast(config.getCompactMode())) + compactVal = static_cast(attr.getValue()); + + std::string compactStr = "CompactMode::Null"; + switch (compactVal) { + case 1: compactStr = "CompactMode::Normal"; break; + case 2: compactStr = "CompactMode::RowPlusOne"; break; + } + + if (!slStr.empty()) { + extraParams += ", " + slStr + ", " + std::to_string(frVal) + ", " + + padStr + ", " + compactStr; + } + } else { + extraParams = ", SLayout::NoneBox, 512, PadValue::Null, CompactMode::Null"; + } + + if (role == TileRole::Left) + dimStr = dimToString(shape[0], "M", 0) + ", " + + dimToString(shape[1], "K", 1); + else if (role == TileRole::Right) + dimStr = dimToString(shape[0], "K", 0) + ", " + + dimToString(shape[1], "N", 1); + else if (role == TileRole::Bias) + dimStr = "1, " + dimToString(shape[1], "N", 1); + else + dimStr = dimToString(shape[0], "M", 0) + ", " + + dimToString(shape[1], "N", 1); + + // Valid-dims handling (supports a mix of static and dynamic dims): constant dims go into the type string, dynamic ones become "-1" plus a constructor argument + std::string vrowTok, vcolTok; + bool useConstructor = false; + + bool rowIsDynamic = false; + bool colIsDynamic = false; + + SmallVector constructorArgs; + + Value vRow = op.getValidRow(); + Value vCol = op.getValidCol(); + Value vRowEmitC = adaptor.getValidRow(); + Value vColEmitC = adaptor.getValidCol(); + bool forceDynamicValid = op->hasAttr(kForceDynamicValidShapeAttrName); + + int64_t cRow = 0, cCol = 0; + bool rowIsConst = vRow && isConstant(vRow, cRow); + bool colIsConst = vCol && isConstant(vCol, cCol); + + auto makeCtorDimValue = [&](Value 
emitted, int64_t fallback) -> Value { + if (emitted) + return emitted; + return makeEmitCIntConstant( + rewriter, loc, emitc::OpaqueType::get(ctx, "int32_t"), fallback); + }; + auto maybeScaleDynamicValid = [&](Value emitted, int dimIdx) -> Value { // only touches the packed dim of float4-packed element types; combines it with the constant 2 + if (!emitted || !pto::isPTOFloat4PackedType(elemType)) + return emitted; + int packedDim = blayout == pto::BLayout::ColMajor ? 0 : 1; + if (dimIdx != packedDim) + return emitted; + auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); + Value two = makeEmitCIntConstant(rewriter, loc, i32Ty, 2); + return rewriter.create(loc, i32Ty, emitted, two).getResult(); + }; + + if (forceDynamicValid) { + vrowTok = "-1"; + vcolTok = "-1"; + useConstructor = true; + constructorArgs.push_back( + makeCtorDimValue(maybeScaleDynamicValid(vRowEmitC, 0), + renderTileTemplateDim(rowIsConst ? cRow : shape[0], + elemType, blayout, 0))); + constructorArgs.push_back( + makeCtorDimValue(maybeScaleDynamicValid(vColEmitC, 1), + renderTileTemplateDim(colIsConst ? cCol : shape[1], + elemType, blayout, 1))); + } else { + if (rowIsConst) { + vrowTok = std::to_string( + renderTileTemplateDim(cRow, elemType, blayout, 0)); + } else if (vRow) { + vrowTok = "-1"; + rowIsDynamic = true; + useConstructor = true; + } else { + vrowTok = std::to_string( + renderTileTemplateDim(shape[0], elemType, blayout, 0)); + } + + if (colIsConst) { + vcolTok = std::to_string( + renderTileTemplateDim(cCol, elemType, blayout, 1)); + } else if (vCol) { + vcolTok = "-1"; + colIsDynamic = true; + useConstructor = true; + } else { + vcolTok = std::to_string( + renderTileTemplateDim(shape[1], elemType, blayout, 1)); + } + + if (useConstructor) { + if (rowIsDynamic && vRowEmitC) + constructorArgs.push_back(maybeScaleDynamicValid(vRowEmitC, 0)); + if (colIsDynamic && vColEmitC) + constructorArgs.push_back(maybeScaleDynamicValid(vColEmitC, 1)); + } + } + + // 5. 
Build the Tile type string + std::string tileTypeStr = + std::string("Tile<") + roleTok + ", " + elemTypeStr + ", " + dimStr + ", " + + layoutParams + ", " + vrowTok + ", " + vcolTok + extraParams + ">"; + + auto tileType = emitc::OpaqueType::get(ctx, tileTypeStr); + Value resultValue; + + if (useConstructor) { + // Emit a constructor call through CallOpaqueOp (Tile v = Tile(...)) + auto ctorOp = rewriter.create( + loc, + tileType, // Result Type + tileTypeStr, // Callee name (the class name) + ArrayAttr{}, // args + ArrayAttr{}, // template_args + ValueRange(constructorArgs) // operands + ); + resultValue = ctorOp.getResult(0); + } else { + // Static case (Tile v;) + auto varOp = rewriter.create( + loc, + tileType, + emitc::OpaqueAttr::get(ctx, "") + ); + resultValue = varOp.getResult(); + } + + // TASSIGN: pto-isa expects an integral address. + Value addr = adaptor.getAddrs()[0]; + if (isa(addr.getType()) || + (isa(addr.getType()) && + cast(addr.getType()).getValue().ends_with("*"))) { + auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); + auto rcU64 = rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); + addr = rewriter.create( + loc, u64Ty, "reinterpret_cast", + /*args=*/ArrayAttr{}, /*templateArgs=*/rcU64, + /*operands=*/ValueRange{addr}) + .getResult(0); + } + + rewriter.create( + loc, TypeRange{}, "TASSIGN", + ArrayAttr{}, ArrayAttr{}, + ValueRange{resultValue, addr}); + + rewriter.replaceOp(op, resultValue); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Helpers for the PartitionViewOp lowering (static index values, GlobalTensor construction, strides) +//===----------------------------------------------------------------------===// + +// Returns the integer value when `value` is defined by one of the constant ops checked below, std::nullopt otherwise (including for a null Value). + + +static std::optional getStaticIndexLikeValue(Value value) { + if (!value) + return std::nullopt; + if (auto cst = value.getDefiningOp()) + return cst.value(); + if (auto cst = value.getDefiningOp()) + return cst.value(); + if (auto cst = 
value.getDefiningOp()) { + if (auto intAttr = dyn_cast(cst.getValue())) + return intAttr.getInt(); + } + return std::nullopt; +} + +static FailureOr buildGlobalTensorViewFromPointer( // Builds Shape, Stride and GlobalTensor constructor calls around ptr; fails if any dim is dynamic. + ConversionPatternRewriter &rewriter, Location loc, Value ptr, Type elemTy, + ArrayRef shape, ArrayRef strides = {}, + StringRef layoutEnum = "pto::Layout::ND") { + if (llvm::any_of(shape, [](int64_t dim) { + return dim == ShapedType::kDynamic; + })) + return failure(); + + auto *ctx = rewriter.getContext(); + SmallVector rowMajorStrides; + ArrayRef effectiveStrides = strides; + if (effectiveStrides.empty()) { // no strides given: fall back to buildRowMajorStrides(shape) + rowMajorStrides = buildRowMajorStrides(shape); + effectiveStrides = rowMajorStrides; + } + SmallVector shape5D; + SmallVector stride5D; + buildGlobalTensorShapeAndStride(shape, effectiveStrides, shape5D, stride5D); + + std::string shapeType = "pto::Shape<" + joinIntTemplateParams(shape5D) + ">"; + std::string strideType = + "pto::Stride<" + joinIntTemplateParams(stride5D) + ">"; + auto shapeVal = rewriter + .create( + loc, emitc::OpaqueType::get(ctx, shapeType), + shapeType, ArrayAttr{}, ArrayAttr{}, ValueRange{}) + .getResult(0); + auto strideVal = rewriter + .create( + loc, emitc::OpaqueType::get(ctx, strideType), + strideType, ArrayAttr{}, ArrayAttr{}, ValueRange{}) + .getResult(0); + + std::string gtTypeStr = + getGlobalTensorTypeStringFromShapeAndStrides(elemTy, shape, + effectiveStrides, + layoutEnum); + auto gtType = emitc::OpaqueType::get(ctx, gtTypeStr); + auto gt = rewriter.create( + loc, gtType, gtTypeStr, ArrayAttr{}, ArrayAttr{}, + ValueRange{ptr, shapeVal, strideVal}); + return gt.getResult(0); +} + +static bool parseIntegerTemplateList(StringRef token, StringRef marker, // Parses the comma-separated base-10 integers between `marker` and the next '>' into `values`. + SmallVectorImpl &values) { + size_t pos = token.find(marker); + if (pos == StringRef::npos) + return false; + pos += marker.size(); + size_t end = token.find('>', pos); + if (end == StringRef::npos) + return false; + + SmallVector parts; + token.slice(pos, end).split(parts, ','); + 
values.clear(); + for (StringRef part : parts) { + int64_t value = 0; + if (part.trim().getAsInteger(10, value)) // getAsInteger returns true on a parse error + return false; + values.push_back(value); + } + return true; +} + +static LogicalResult getStaticTensorViewStrides( // Stride sources, in order: the defining view op's constant strides, a Stride<...> list in the converted type string, row-major strides of the source shape. + Value source, Value convertedSource, pto::TensorViewType sourceType, + SmallVectorImpl &strides) { + int64_t rank = sourceType.getRank(); + strides.clear(); + + if (auto makeView = source.getDefiningOp()) { + if ((int64_t)makeView.getStrides().size() != rank) + return failure(); + for (Value strideValue : makeView.getStrides()) { + auto cst = getStaticIndexLikeValue(strideValue); + if (!cst) + return failure(); + strides.push_back(*cst); + } + return success(); + } + + Value src = peelUnrealized(convertedSource); + if (auto opaqueTy = dyn_cast(src.getType())) { + SmallVector stride5D; + StringRef token = opaqueTy.getValue(); + if ((parseIntegerTemplateList(token, "pto::Stride<", stride5D) || + parseIntegerTemplateList(token, "Stride<", stride5D)) && + (int64_t)stride5D.size() >= rank) { + strides.append(stride5D.end() - rank, stride5D.end()); // keep the trailing `rank` entries + return success(); + } + } + + auto fallback = buildRowMajorStrides(sourceType.getShape()); // NOTE(review): this path always succeeds; whether a dynamic source shape yields dynamic strides depends on buildRowMajorStrides - confirm. + strides.append(fallback.begin(), fallback.end()); + return success(); +} + +struct PTOPartitionViewToEmitC // Lowers PartitionViewOp to an offset data pointer wrapped in a new GlobalTensor. + : public OpConversionPattern { + using OpConversionPattern< + mlir::pto::PartitionViewOp>::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::PartitionViewOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto srcTy = dyn_cast(op.getSource().getType()); + auto resTy = dyn_cast(op.getResult().getType()); + if (!srcTy || !resTy) + return rewriter.notifyMatchFailure( + op, "expected tensor_view source and partition_tensor_view result"); + + if (op.getOffsets().size() != static_cast(srcTy.getRank()) || + op.getSizes().size() != static_cast(srcTy.getRank())) + return rewriter.notifyMatchFailure(op, "rank mismatch"); + + for (auto [idx, value] : 
llvm::enumerate(op.getSizes())) { + auto cst = getStaticIndexLikeValue(value); + if (!cst) + return rewriter.notifyMatchFailure( + op, "globaltensor partition_view requires static sizes"); + int64_t resultDim = resTy.getShape()[idx]; + if (resultDim != ShapedType::kDynamic && resultDim != *cst) + return rewriter.notifyMatchFailure( + op, "partition_view static size does not match result type"); + } + + SmallVector srcStrides; + if (failed(getStaticTensorViewStrides(op.getSource(), adaptor.getSource(), + srcTy, srcStrides))) + return rewriter.notifyMatchFailure( + op, "partition_view requires static source strides"); + int64_t staticLinearOffset = 0; // sum of constant offset * stride terms + SmallVector> dynamicOffsetTerms; // (converted offset, stride) pairs for non-constant offsets + for (auto [idx, values] : + llvm::enumerate(llvm::zip(op.getOffsets(), adaptor.getOffsets()))) { + Value originalOffset = std::get<0>(values); + Value convertedOffset = std::get<1>(values); + int64_t stride = srcStrides[idx]; + if (stride == ShapedType::kDynamic) + return rewriter.notifyMatchFailure( + op, "dynamic source stride is not supported"); + + if (auto cst = getStaticIndexLikeValue(originalOffset)) { + if (*cst != 0) + staticLinearOffset += (*cst) * stride; + continue; + } + dynamicOffsetTerms.push_back({convertedOffset, stride}); + } + + auto *ctx = rewriter.getContext(); + std::string elemTypeStr = getElemTypeStringForGT(srcTy.getElementType()); + auto ptrTy = emitc::PointerType::get( + emitc::OpaqueType::get(ctx, "__gm__ " + elemTypeStr)); + Value src = peelUnrealized(adaptor.getSource()); + auto data = rewriter + .create( + op.getLoc(), ptrTy, "PTOAS__GLOBAL_TENSOR_DATA", + ArrayAttr{}, ArrayAttr{}, ValueRange{src}) + .getResult(0); + Value ptr = data; + if (!dynamicOffsetTerms.empty()) { + Type u32Ty = emitc::OpaqueType::get(ctx, "unsigned"); // NOTE(review): offsets are accumulated in C "unsigned"; negative or very large linear offsets would wrap - confirm this is acceptable. + auto makeU32 = [&](int64_t value) { + return makeEmitCIntConstant(rewriter, op.getLoc(), u32Ty, value); + }; + auto asU32 = [&](Value value) -> Value { + if (value.getType() == u32Ty) + return value; + return 
rewriter.create(op.getLoc(), u32Ty, value) + .getResult(); + }; + + Value totalOffset = makeU32(staticLinearOffset); + for (auto [offsetValue, stride] : dynamicOffsetTerms) { + Value term = asU32(offsetValue); + if (stride != 1) { + Value strideValue = makeU32(stride); + term = rewriter + .create(op.getLoc(), u32Ty, term, + strideValue) + .getResult(); + } + totalOffset = rewriter + .create(op.getLoc(), u32Ty, + totalOffset, term) + .getResult(); + } + ptr = rewriter + .create(op.getLoc(), data.getType(), data, + totalOffset) + .getResult(); + } else { + ptr = applyStaticMemrefOffset(rewriter, op.getLoc(), data, + staticLinearOffset); + } + + auto resultOr = buildGlobalTensorViewFromPointer( // NOTE(review): receives resTy.getShape(), so a result type with dynamic dims fails here even though the sizes were checked to be constants - confirm intended. + rewriter, op.getLoc(), ptr, resTy.getElementType(), resTy.getShape(), + srcStrides); + if (failed(resultOr)) + return rewriter.notifyMatchFailure( + op, "failed to materialize partition GlobalTensor"); + + rewriter.replaceOp(op, *resultOr); + return success(); + } +}; + + +} // namespace + +void populatePTOToEmitCMemoryOpPatterns(RewritePatternSet &patterns, // Adds this file's two memory-op lowering patterns to `patterns`. + TypeConverter &typeConverter, + MLIRContext *ctx) { + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); +} + +} // namespace mlir::pto diff --git a/lib/PTO/Transforms/PTOToEmitCRuntimeOps.cpp b/lib/PTO/Transforms/PTOToEmitCRuntimeOps.cpp new file mode 100644 index 000000000..a80b79fa0 --- /dev/null +++ b/lib/PTO/Transforms/PTOToEmitCRuntimeOps.cpp @@ -0,0 +1,736 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- PTOToEmitCRuntimeOps.cpp -------------------------------------------===// +//===----------------------------------------------------------------------===// + +#include "PTOToEmitCInternal.h" + +#include "PTO/IR/PTO.h" +#include "PTO/IR/PTOTypeUtils.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" + +#include + +using namespace mlir; +using namespace mlir::pto; + +namespace mlir::pto { +namespace { + +static constexpr unsigned kPTOIndexBitWidth = 32; + +static int64_t getEmitCScalarByteWidth(Type elemTy) { // Byte width of elemTy; types matched by none of the checks fall back to 4. + if (pto::getPTOStorageElemByteSize(elemTy) == 1) + return 1; + if (elemTy.isF16() || elemTy.isBF16() || elemTy.isInteger(16)) + return 2; + if (elemTy.isF32() || elemTy.isInteger(32)) + return 4; + if (elemTy.isF64() || elemTy.isInteger(64)) + return 8; + return 4; +} + +static FailureOr getTileSplitToken(int64_t split) { // Maps split 0/1/2 to a TileSplitAxis token; any other value fails. + switch (split) { + case 0: + return std::string("TileSplitAxis::TILE_NO_SPLIT"); + case 1: + return std::string("TileSplitAxis::TILE_UP_DOWN"); + case 2: + return std::string("TileSplitAxis::TILE_LEFT_RIGHT"); + default: + return failure(); + } +} + +static FailureOr +getTPipeDirectionToken(bool isL2G2L, int8_t dirMask, PTOArch targetArch) { // dirMask 1/2/3 -> DIR_C2V/DIR_V2C/DIR_BOTH; the _GM variants are used only when isL2G2L is set and the target is A5. + if (dirMask == 1) { + if (isL2G2L && targetArch == PTOArch::A5) + return std::string("Direction::DIR_C2V_GM"); + return std::string("Direction::DIR_C2V"); + } + if (dirMask == 2) { + if (isL2G2L && targetArch == 
PTOArch::A5) + return std::string("Direction::DIR_V2C_GM"); + return std::string("Direction::DIR_V2C"); + } + if (dirMask == 3) + return std::string("Direction::DIR_BOTH"); + return failure(); +} + +static std::string buildTPipeToken(int32_t flagBase, llvm::StringRef dirTok, // Renders the type token "TPipe<flagBase, dirTok, slotSize, slotNum, localSlotNum, nosplit>" from its arguments. + int32_t slotSize, int32_t slotNum, + int32_t localSlotNum, bool nosplit) { + std::string token = "TPipe<" + std::to_string(flagBase) + ", " + dirTok.str() + + ", " + std::to_string(slotSize) + ", " + + std::to_string(slotNum); + token += ", " + std::to_string(localSlotNum); + token += nosplit ? ", true" : ", false"; + token += ">"; + return token; +} + +} // namespace + +FailureOr buildTPipeTokenFromInitOp(Operation *op, // Builds the TPipe token from one of the two pipe-init ops handled below; fails for other ops, a missing flag-base attribute, or an invalid direction mask. + PTOArch targetArch) { + if (auto initOp = dyn_cast(op)) { + if (!initOp.getFlagBaseAttr()) + return failure(); + auto dirTok = + getTPipeDirectionToken(/*isL2G2L=*/true, initOp.getDirMask(), targetArch); + if (failed(dirTok)) + return failure(); + int32_t localSlotNum = initOp.getLocalSlotNumAttr() // defaults to the slot count when the attribute is absent + ? 
initOp.getLocalSlotNumAttr().getInt() + : initOp.getSlotNum(); + return buildTPipeToken(initOp.getFlagBaseAttr().getInt(), *dirTok, + initOp.getSlotSize(), initOp.getSlotNum(), + localSlotNum, + initOp.getNosplitAttr() && + initOp.getNosplitAttr().getValue()); + } + + if (auto initOp = dyn_cast(op)) { + if (!initOp.getFlagBaseAttr()) + return failure(); + auto dirTok = + getTPipeDirectionToken(/*isL2G2L=*/false, initOp.getDirMask(), targetArch); + if (failed(dirTok)) + return failure(); + return buildTPipeToken(initOp.getFlagBaseAttr().getInt(), *dirTok, // the literal 2 below is the localSlotNum argument for this op kind + initOp.getSlotSize(), initOp.getSlotNum(), 2, + initOp.getNosplitAttr() && + initOp.getNosplitAttr().getValue()); + } + + return failure(); +} + + +namespace { + +static FailureOr getTPipeTokenFromValue(Value pipeHandle, // Resolves the TPipe token from the op that defines pipeHandle (after peeling casts); fails for values without a defining op. + PTOArch targetArch) { + pipeHandle = peelUnrealized(pipeHandle); + Operation *def = pipeHandle.getDefiningOp(); + if (!def) + return failure(); + return buildTPipeTokenFromInitOp(def, targetArch); +} + + + +static FailureOr getPipeDataTypeToken(Value value) { // Returns the opaque type string when it contains "Tile<" or "GlobalTensor<"; fails otherwise. + auto opaqueTy = dyn_cast(value.getType()); + if (!opaqueTy) + return failure(); + StringRef token = opaqueTy.getValue(); + if (!token.contains("Tile<") && !token.contains("GlobalTensor<")) + return failure(); + return token.str(); +} + +struct PTOTAllocToEmitC : public OpConversionPattern { // Lowers TAllocOp to a call whose callee string is "TALLOC<pipe, entry, split>", with operands (pipe handle, entry). + PTOTAllocToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + PTOArch targetArch) + : OpConversionPattern(typeConverter, ctx), + targetArch(targetArch) {} + + LogicalResult matchAndRewrite(mlir::pto::TAllocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto pipeTok = getTPipeTokenFromValue(op.getPipeHandle(), targetArch); + if (failed(pipeTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve pipe token"); + Value entry = peelUnrealized(adaptor.getEntry()); + auto entryTok = getPipeDataTypeToken(entry); + if (failed(entryTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve entry 
token"); + auto splitTok = getTileSplitToken(op.getSplit()); + if (failed(splitTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve split token"); + + std::string callee = + "TALLOC<" + *pipeTok + ", " + *entryTok + ", " + *splitTok + ">"; + rewriter.replaceOpWithNewOp( + op, TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, + ValueRange{peelUnrealized(adaptor.getPipeHandle()), entry}); + return success(); + } + + PTOArch targetArch; +}; + +struct PTOTPushToEmitC : public OpConversionPattern { // Lowers TPushOp to a call whose callee string is "TPUSH<pipe, tile, split>", with operands (pipe handle, tile). + PTOTPushToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + PTOArch targetArch) + : OpConversionPattern(typeConverter, ctx), + targetArch(targetArch) {} + + LogicalResult matchAndRewrite(mlir::pto::TPushOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto pipeTok = getTPipeTokenFromValue(op.getPipeHandle(), targetArch); + if (failed(pipeTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve pipe token"); + // Read the tile type token from the already-converted OpaqueType, which + // preserves the exact layout produced by BindTileOp / PointerCastOp EmitC. 
+ Value convertedTile = peelUnrealized(adaptor.getTile()); + auto tileTok = getPipeDataTypeToken(convertedTile); + if (failed(tileTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve tile token"); + auto splitTok = getTileSplitToken(op.getSplit()); + if (failed(splitTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve split token"); + + std::string callee = + "TPUSH<" + *pipeTok + ", " + *tileTok + ", " + *splitTok + ">"; + rewriter.replaceOpWithNewOp( + op, TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, + ValueRange{peelUnrealized(adaptor.getPipeHandle()), convertedTile}); + return success(); + } + + PTOArch targetArch; +}; + +struct PTOTPopToEmitC : public OpConversionPattern { // Lowers TPopOp to a call whose callee string is "TPOP<pipe, tile, split>", with operands (pipe handle, tile). + PTOTPopToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + PTOArch targetArch) + : OpConversionPattern(typeConverter, ctx), + targetArch(targetArch) {} + + LogicalResult matchAndRewrite(mlir::pto::TPopOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto pipeTok = getTPipeTokenFromValue(op.getPipeHandle(), targetArch); + if (failed(pipeTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve pipe token"); + Value convertedTile = peelUnrealized(adaptor.getTile()); + auto tileTok = getPipeDataTypeToken(convertedTile); + if (failed(tileTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve tile token"); + auto splitTok = getTileSplitToken(op.getSplit()); + if (failed(splitTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve split token"); + + std::string callee = + "TPOP<" + *pipeTok + ", " + *tileTok + ", " + *splitTok + ">"; + rewriter.replaceOpWithNewOp( + op, TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, + ValueRange{peelUnrealized(adaptor.getPipeHandle()), convertedTile}); + return success(); + } + + PTOArch targetArch; +}; + +struct PTOTFreeToEmitC : public OpConversionPattern { // Lowers TFreeOp to a "TFREE<...>" call; the entry type token and entry operand are added only when the op has an entry. + PTOTFreeToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + PTOArch targetArch) + : 
OpConversionPattern(typeConverter, ctx), + targetArch(targetArch) {} + + LogicalResult matchAndRewrite(mlir::pto::TFreeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto pipeTok = getTPipeTokenFromValue(op.getPipeHandle(), targetArch); + if (failed(pipeTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve pipe token"); + auto splitTok = getTileSplitToken(op.getSplit()); + if (failed(splitTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve split token"); + + SmallVector operands{peelUnrealized(adaptor.getPipeHandle())}; + std::string callee; + if (op.getEntry()) { + Value entry = peelUnrealized(adaptor.getEntry()); + auto entryTok = getPipeDataTypeToken(entry); + if (failed(entryTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve entry token"); + callee = "TFREE<" + *pipeTok + ", " + *entryTok + ", " + *splitTok + ">"; + operands.push_back(entry); + } else { + callee = "TFREE<" + *pipeTok + ", " + *splitTok + ">"; + } + rewriter.replaceOpWithNewOp( + op, TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, operands); + return success(); + } + + PTOArch targetArch; +}; + +//===----------------------------------------------------------------------===// +// memref::ReinterpretCastOp lowering +//===----------------------------------------------------------------------===// +struct ReinterpretCastToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + auto resMrTy = dyn_cast(op.getType()); + if (!resMrTy) + return failure(); + + auto asAttr = dyn_cast_or_null(resMrTy.getMemorySpace()); + const bool isGm = (!asAttr || asAttr.getAddressSpace() == pto::AddressSpace::GM); + + bool emitAddPtrTrace = op->hasAttr("pto.addptr_trace"); + Value source = peelUnrealized(adaptor.getSource()); + 
auto offsets = adaptor.getOffsets(); + Value offsetVal = offsets.empty() ? Value() : offsets[0]; + + // GM: keep pointer arithmetic. + if (isGm) { + if (!offsetVal) { + rewriter.replaceOp(op, source); + return success(); + } + + Type resultType = getTypeConverter()->convertType(op.getType()); + if (!resultType) + return failure(); + + auto addOp = rewriter.create(loc, resultType, source, offsetVal); + if (emitAddPtrTrace) { + rewriter.setInsertionPointAfter(addOp); + rewriter.create( + loc, TypeRange{}, "PTOAS__ADDPTR_TRACE", + ArrayAttr{}, ArrayAttr{}, + ValueRange{addOp.getResult(), source, offsetVal}); + } + rewriter.replaceOp(op, addOp.getResult()); + return success(); + } + + // UB/L1/L0 tiles: materialize a new Tile view by assigning an adjusted + // underlying pointer (in elements). + pto::AddressSpace as = asAttr.getAddressSpace(); + + // Element type token. + Type elemTy = resMrTy.getElementType(); + std::string elemTok = getEmitCScalarTypeToken(elemTy); + int64_t elemBytes = getEmitCScalarByteWidth(elemTy); + + // Tile role. + const char *roleTok = "TileType::Vec"; + switch (as) { + case pto::AddressSpace::VEC: + roleTok = "TileType::Vec"; + break; + case pto::AddressSpace::MAT: + roleTok = "TileType::Mat"; + break; + case pto::AddressSpace::LEFT: + roleTok = "TileType::Left"; + break; + case pto::AddressSpace::RIGHT: + roleTok = "TileType::Right"; + break; + case pto::AddressSpace::ACC: + roleTok = "TileType::Acc"; + break; + case pto::AddressSpace::BIAS: + roleTok = "TileType::Bias"; + break; + case pto::AddressSpace::GM: + roleTok = "TileType::Vec"; + break; + case pto::AddressSpace::Zero: + roleTok = "TileType::Vec"; + break; + case pto::AddressSpace::SCALING: + roleTok = "TileType::Scaling"; + break; + } + + // Shape (fallback to 32x32). 
+ int64_t rows = 32, cols = 32; + if (resMrTy.getRank() >= 2 && resMrTy.hasStaticShape()) { + rows = resMrTy.getDimSize(0); + cols = resMrTy.getDimSize(1); + } + int64_t templateRows = + renderTileTemplateDim(rows, elemTy, pto::BLayout::RowMajor, 0); + int64_t templateCols = + renderTileTemplateDim(cols, elemTy, pto::BLayout::RowMajor, 1); + + // Keep a conservative default config for now. + std::string tileTypeStr = + std::string("Tile<") + roleTok + ", " + elemTok + ", " + + std::to_string(templateRows) + ", " + std::to_string(templateCols) + + ", BLayout::RowMajor, " + std::to_string(templateRows) + ", " + + std::to_string(templateCols) + + ", SLayout::NoneBox, 512, PadValue::Null, CompactMode::Null>"; + + auto tileType = emitc::OpaqueType::get(ctx, tileTypeStr); + Value tile = rewriter + .create(loc, tileType, + emitc::OpaqueAttr::get(ctx, "")) + .getResult(); + + // Compute an integer address and assign it to the new tile. + // NOTE: pto-isa TASSIGN requires an integral address (not a pointer). + auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); + auto rcU64 = rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); + + // Non-GM reinterpret_cast operands come from UB/L1/L0 tiles. + // We need the underlying address, but `__cce_get_tile_ptr()` is only valid + // inside `__tf__` functions. Use `tile.data()` (via a post-processed marker) + // and compute the adjusted address in bytes. + Value rawPtr = source; + if (auto ot = dyn_cast(source.getType())) { + // Only Tiles have a `.data()` member. For plain address-space pointers + // (e.g. `__ubuf__ float*`), use the pointer value directly. 
+ if (ot.getValue().starts_with("Tile<")) { + rawPtr = materializeTileDataValue(rewriter, loc, source, as, elemTok); + } + } + + Value baseAddr = rawPtr; + if (isSetFFTsPointerLikeType(rawPtr.getType())) { + baseAddr = rewriter + .create(loc, u64Ty, "reinterpret_cast", + /*args=*/ArrayAttr{}, + /*templateArgs=*/rcU64, + /*operands=*/ValueRange{rawPtr}) + .getResult(0); + } else if (rawPtr.getType() != u64Ty) { + baseAddr = rewriter.create(loc, u64Ty, rawPtr).getResult(); + } + + Value addr = baseAddr; + if (offsetVal) { + Value offU64 = offsetVal; + if (offU64.getType() != u64Ty) + offU64 = rewriter.create(loc, u64Ty, offU64).getResult(); + + auto bytesAttr = emitc::OpaqueAttr::get(ctx, std::to_string(elemBytes)); + Value bytesVal = rewriter.create(loc, u64Ty, bytesAttr); + Value byteOff = rewriter.create(loc, u64Ty, offU64, bytesVal); + addr = rewriter.create(loc, u64Ty, baseAddr, byteOff); + } + + rewriter.create(loc, TypeRange{}, "TASSIGN", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{tile, addr}); + + rewriter.replaceOp(op, tile); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.taddc lowering -> TADDC(dst, src0, src1, src2) +//===----------------------------------------------------------------------===// + +struct PTOTAddCToTADDC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TAddCOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value src2 = peelUnrealized(adaptor.getSrc2()); + Value dst = peelUnrealized(adaptor.getDst()); + + // pto-isa does not provide NPU implementation for TADDC yet. 
+ // Decompose: dst = src0 + src1 + src2 + // NOTE(review): the first TADD writes dst before src2 is read; if dst may alias src2 the second TADD would read the overwritten tile - confirm aliasing is ruled out upstream. + rewriter.create( + loc, TypeRange{}, "TADD", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src0, src1}); + rewriter.create( + loc, TypeRange{}, "TADD", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, dst, src2}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.tadds lowering -> TADDS(dst, src, scalar) +//===----------------------------------------------------------------------===// + +// Emits a single opaque call "TADDS" with operands (dst, src, scalar) and erases the op; the op yields no replacement values. +struct PTOAddSToTADDS : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TAddSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + Value scalar = peelUnrealized(adaptor.getScalar()); + + rewriter.create( + op.getLoc(), TypeRange{}, "TADDS", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src, scalar}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.taddsc lowering -> TADDS(dst, src0, scalar) followed by TADD(dst, dst, src1) +//===----------------------------------------------------------------------===// + +struct PTOAddSCToTADDSC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TAddSCOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value scalar = peelUnrealized(adaptor.getScalar()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + // pto-isa does not provide NPU implementation for TADDSC yet.
+ // Decompose: dst = src0 + scalar + src1 + rewriter.create( + loc, TypeRange{}, "TADDS", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src0, scalar}); + rewriter.create( + loc, TypeRange{}, "TADD", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, dst, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; +// Tile/vector PTO op conversion patterns live in PTOToEmitCTilePatterns.cpp. + +struct PTOPrintOpToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::PrintOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + std::string fmt = op.getFormat().str(); + if (fmt.empty()) + fmt = "%f"; + std::string quoted = "\""; + for (char c : fmt) { + if (c == '"' || c == '\\') + quoted += '\\'; + else if (c == '\n') + quoted += "\\n"; + else if (c == '\t') + quoted += "\\t"; + else + quoted += c; + } + quoted += "\""; + + Value scalar = peelUnrealized(adaptor.getScalar()); + auto argsAttr = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, quoted), + IntegerAttr::get(IndexType::get(ctx), 0)}); + rewriter.create( + loc, TypeRange{}, "cce::printf", + /*args=*/argsAttr, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{scalar}); + + rewriter.eraseOp(op); + return success(); + } +}; + +// pto.trap -> TRAP() +struct PTOTrapOpToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TrapOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + rewriter.create( + loc, TypeRange{}, "trap", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{}); + + rewriter.eraseOp(op); + return success(); + } +}; + +// ============================================================================= +// Arith CmpI -> EmitC Cmp +// 
============================================================================= +class ArithCmpIToEmitC : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + // Convert arith.cmpi into emitc.cmp. + // Map the predicate: eq -> eq, slt -> lt, etc. + emitc::CmpPredicate emitcPred = emitc::CmpPredicate::eq; + const bool isUnsignedPred = + op.getPredicate() == arith::CmpIPredicate::ult || + op.getPredicate() == arith::CmpIPredicate::ule || + op.getPredicate() == arith::CmpIPredicate::ugt || + op.getPredicate() == arith::CmpIPredicate::uge; + switch (op.getPredicate()) { + case arith::CmpIPredicate::eq: emitcPred = emitc::CmpPredicate::eq; break; + case arith::CmpIPredicate::ne: emitcPred = emitc::CmpPredicate::ne; break; + case arith::CmpIPredicate::slt: emitcPred = emitc::CmpPredicate::lt; break; + case arith::CmpIPredicate::sle: emitcPred = emitc::CmpPredicate::le; break; + case arith::CmpIPredicate::sgt: emitcPred = emitc::CmpPredicate::gt; break; + case arith::CmpIPredicate::sge: emitcPred = emitc::CmpPredicate::ge; break; + // Unsigned comparisons (ult, ule, etc.) map to the same emitc predicates as their signed counterparts.
+ case arith::CmpIPredicate::ult: emitcPred = emitc::CmpPredicate::lt; break; + case arith::CmpIPredicate::ule: emitcPred = emitc::CmpPredicate::le; break; + case arith::CmpIPredicate::ugt: emitcPred = emitc::CmpPredicate::gt; break; + case arith::CmpIPredicate::uge: emitcPred = emitc::CmpPredicate::ge; break; + } + + Type resTy = getTypeConverter()->convertType(op.getType()); + if (!resTy) + return failure(); + + Value lhs = adaptor.getLhs(); + Value rhs = adaptor.getRhs(); + if (isUnsignedPred) { + Type opTy = op.getLhs().getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure( + op, "expected scalar integer or index operands"); + + const unsigned bitWidth = + intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); + if (bitWidth != 1) { + lhs = castSignlessIntToUnsignedSameWidth(rewriter, loc, lhs, bitWidth); + rhs = castSignlessIntToUnsignedSameWidth(rewriter, loc, rhs, bitWidth); + } + } + + rewriter.replaceOpWithNewOp( + op, + /*resultType=*/resTy, // i1 -> bool/i1 + emitcPred, + lhs, + rhs + ); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Section Op Lowering +//===----------------------------------------------------------------------===// +static bool isA5NoSplitPipeOp(Operation *op) { + if (auto talloc = dyn_cast(op)) + return talloc.getSplit() == 0; + if (auto tpush = dyn_cast(op)) + return tpush.getSplit() == 0; + if (auto tpop = dyn_cast(op)) + return tpop.getSplit() == 0; + if (auto tfree = dyn_cast(op)) + return tfree.getSplit() == 0; + if (auto tpush = dyn_cast(op)) + return tpush.getSplit() == 0; + if (auto tpush = dyn_cast(op)) + return tpush.getSplit() == 0; + if (auto talloc = dyn_cast(op)) + return talloc.getSplit() == 0; + if (auto talloc = dyn_cast(op)) + return talloc.getSplit() == 0; + if (auto tpop = dyn_cast(op)) + return tpop.getSplit() == 0; + if (auto tpop = dyn_cast(op)) + return 
tpop.getSplit() == 0; + if (auto tfree = dyn_cast(op)) + return tfree.getSplit() == 0; + if (auto tfree = dyn_cast(op)) + return tfree.getSplit() == 0; + return false; +} + +static bool hasExplicitSubblockControl(Operation *op) { + bool hasControl = false; + op->walk([&](Operation *nested) { + if (isa(nested)) { + hasControl = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return hasControl; +} + +} // namespace + +bool needsA5NoSplitVectorGuard(Operation *op) { + auto arch = getTargetArch(op); + if (arch != PTOArch::A5) + return false; + bool isVectorScope = isa(op); + if (auto func = dyn_cast(op)) { + if (auto kernelKindAttr = + func->getAttrOfType( + FunctionKernelKindAttr::name)) { + isVectorScope = + kernelKindAttr.getKernelKind() == FunctionKernelKind::Vector; + } + } + if (!isVectorScope) + return false; + if (hasExplicitSubblockControl(op)) + return false; + + bool hasNoSplitPipe = false; + op->walk([&](Operation *nested) { + if (!isA5NoSplitPipeOp(nested)) + return WalkResult::advance(); + hasNoSplitPipe = true; + return WalkResult::interrupt(); + }); + return hasNoSplitPipe; +} + + +void populatePTOToEmitCRuntimeOpPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter, + MLIRContext *ctx, PTOArch targetArch) { + patterns.add(typeConverter, ctx, targetArch); + patterns.add(typeConverter, ctx, targetArch); + patterns.add(typeConverter, ctx, targetArch); + patterns.add(typeConverter, ctx, targetArch); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); +} + +} // namespace mlir::pto diff --git a/lib/PTO/Transforms/PTOToEmitCSimpleOps.cpp b/lib/PTO/Transforms/PTOToEmitCSimpleOps.cpp new file mode 100644 index 000000000..24dc00c2e --- /dev/null +++ b/lib/PTO/Transforms/PTOToEmitCSimpleOps.cpp @@ -0,0 +1,563 @@ +// 
Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- PTOToEmitCSimpleOps.cpp --------------------------------------------===// +//===----------------------------------------------------------------------===// + +#include "PTOToEmitCInternal.h" + +#include "PTO/IR/PTO.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" + +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; +using namespace mlir::pto; + +namespace mlir::pto { +namespace { + +struct PTOGetBlockIdxToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(mlir::pto::GetBlockIdxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + rewriter.replaceOpWithNewOp( + op, op.getType(), "get_block_idx", ValueRange{}, ArrayAttr{}, + ArrayAttr{}); + + return success(); + } +}; + +// GetBlockNumOp Lowering (pto.get_block_num -> get_block_num()) +struct PTOGetBlockNumToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(mlir::pto::GetBlockNumOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + rewriter.replaceOpWithNewOp( + op, op.getType(), 
"get_block_num", ValueRange{}, ArrayAttr{}, + ArrayAttr{}); + + return success(); + } +}; + +// GetSubBlockIdxOp Lowering (pto.get_block_idx -> get_subblockid()) +struct PTOGetSubBlockIdxToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(mlir::pto::GetSubBlockIdxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + rewriter.replaceOpWithNewOp( + op, op.getType(), "get_subblockid", ValueRange{}, ArrayAttr{}, + ArrayAttr{}); + + return success(); + } +}; + +// GetSubBlockNumOp Lowering. +struct PTOGetSubBlockNumToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(mlir::pto::GetSubBlockNumOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + rewriter.replaceOpWithNewOp( + op, op.getType(), "get_subblockdim", ValueRange{}, ArrayAttr{}, + ArrayAttr{}); + + return success(); + } +}; + + + + +struct PTOSetValToSETVAL : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TSetValOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value dst = peelUnrealized(adaptor.getDst()); + Value val = peelUnrealized(adaptor.getVal()); + + // ---- offset: SSA index operand ---- + Value offset = peelUnrealized(adaptor.getOffset()); + + // Emit a marker call and let the ptoas post-processing step lower it to + // the corresponding tile setter. 
+ rewriter.create( + op.getLoc(), TypeRange{}, "PTOAS__TILE_SET_VALUE", + ArrayAttr{}, ArrayAttr{}, ValueRange{dst, offset, val}); + + rewriter.eraseOp(op); + return success(); + } +}; +struct PTOGetValToGETVAL : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TGetValOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src = peelUnrealized(adaptor.getSrc()); + + // ---- offset: SSA index operand ---- + Value offset = peelUnrealized(adaptor.getOffset()); + + // Emit a marker call and let the ptoas post-processing step lower it to + // the corresponding tile getter. + Type dstTy = getTypeConverter()->convertType(op.getDst().getType()); + if (!dstTy) + return failure(); + auto call = rewriter.create( + op.getLoc(), + TypeRange{dstTy}, + "PTOAS__TILE_GET_VALUE", + ArrayAttr{}, ArrayAttr{}, + ValueRange{src, offset}); + + rewriter.replaceOp(op, call.getResults()); + return success(); + } +}; + +struct PTOTAxpyToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TAxpyOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + Value scalar = peelUnrealized(adaptor.getScalar()); + + rewriter.create( + loc, TypeRange{}, "TAXPY", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src, scalar}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOHistogramToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::THistogramOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value idx = 
peelUnrealized(adaptor.getIdx()); + Value dst = peelUnrealized(adaptor.getDst()); + + auto templateArgs = rewriter.getArrayAttr({emitc::OpaqueAttr::get( + ctx, op.getIsMSB() ? "HistByte::BYTE_1" : "HistByte::BYTE_0")}); + rewriter.create( + loc, TypeRange{}, "THISTOGRAM", + /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, + /*operands=*/ValueRange{dst, src, idx}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOGetScaleAddrToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TGetScaleAddrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TGET_SCALE_ADDR", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOSetValidShapeToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::SetValidShapeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto peelAllCasts = [](Value v) { + while (auto castOp = v.getDefiningOp()) + v = castOp.getOperand(0); + if (auto castOp = v.getDefiningOp()) + v = castOp.getOperand(); + return v; + }; + auto isTileLike = [](Value v) -> bool { + auto ot = dyn_cast(v.getType()); + if (!ot) + return false; + StringRef s = ot.getValue(); + return s.contains("Tile<") || s.contains("ConvTile<"); + }; + + Value src = peelAllCasts(peelUnrealized(adaptor.getSource())); + Value row = peelUnrealized(adaptor.getValidRow()); + Value col = peelUnrealized(adaptor.getValidCol()); + + if (!isTileLike(src)) + return rewriter.notifyMatchFailure( + op, "set_validshape source must lower to a tile-like value"); + + rewriter.create( + op.getLoc(), TypeRange{}, 
"PTOAS__TILE_SET_VALIDSHAPE", ArrayAttr{}, + ArrayAttr{}, ValueRange{src, row, col}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOGetValidShapeToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::GetValidShapeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto peelAllCasts = [](Value v) { + while (auto castOp = v.getDefiningOp()) + v = castOp.getOperand(0); + if (auto castOp = v.getDefiningOp()) + v = castOp.getOperand(); + return v; + }; + auto isTileLike = [](Value v) -> bool { + auto ot = dyn_cast(v.getType()); + if (!ot) + return false; + StringRef s = ot.getValue(); + return s.contains("Tile<") || s.contains("ConvTile<"); + }; + + Value src = peelAllCasts(peelUnrealized(adaptor.getSource())); + if (!isTileLike(src)) + return rewriter.notifyMatchFailure( + op, "get_validshape source must lower to a tile-like value"); + + auto resultTy = getTypeConverter()->convertType(rewriter.getIndexType()); + if (!resultTy) + return failure(); + + Value row = rewriter + .create( + op.getLoc(), resultTy, + "PTOAS__TILE_GET_VALID_ROW", ArrayAttr{}, + ArrayAttr{}, ValueRange{src}) + .getResult(0); + Value col = rewriter + .create( + op.getLoc(), resultTy, + "PTOAS__TILE_GET_VALID_COL", ArrayAttr{}, + ArrayAttr{}, ValueRange{src}) + .getResult(0); + rewriter.replaceOp(op, ValueRange{row, col}); + return success(); + } +}; + +struct PTOTAssignToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TAssignOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto peelAllCasts = [](Value v) { + while (auto castOp = v.getDefiningOp()) + v = castOp.getOperand(0); + if (auto castOp = v.getDefiningOp()) + v = castOp.getOperand(); + return v; + }; + auto isTileLike = [](Value v) -> bool { + auto ot = dyn_cast(v.getType()); + if (!ot) + return false; + 
StringRef s = ot.getValue(); + return s.contains("Tile<") || s.contains("ConvTile<"); + }; + + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value tile = peelAllCasts(peelUnrealized(adaptor.getTile())); + if (!isTileLike(tile)) + return rewriter.notifyMatchFailure( + op, "tassign tile must lower to a tile-like value"); + + Value addr = peelUnrealized(adaptor.getAddr()); + auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); + if (isa(addr.getType()) || + (isa(addr.getType()) && + cast(addr.getType()).getValue().ends_with("*"))) { + auto rcU64 = + rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); + addr = rewriter + .create(loc, u64Ty, "reinterpret_cast", + ArrayAttr{}, rcU64, + ValueRange{addr}) + .getResult(0); + } else if (addr.getType() != u64Ty) { + addr = rewriter.create(loc, u64Ty, addr).getResult(); + } + + rewriter.create(loc, TypeRange{}, "TASSIGN", + ArrayAttr{}, ArrayAttr{}, + ValueRange{tile, addr}); + rewriter.replaceOp(op, tile); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// pto.load_scalar / pto.store_scalar lowering -> ptr[offset] +//===----------------------------------------------------------------------===// + +static Type getPointerLikeElementType(Type type) { + if (auto ptrTy = dyn_cast(type)) + return ptrTy.getElementType(); + if (auto memTy = dyn_cast(type)) + return memTy.getElementType(); + return Type(); +} + +struct PTOPtrToIntToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::PtrToIntOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value ptr = peelUnrealized(adaptor.getPtr()); + Type dstTy = getTypeConverter()->convertType(op.getResult().getType()); + if (!dstTy) + return failure(); + + auto dstOpaque = dyn_cast(dstTy); + if (!dstOpaque) + return failure(); + + auto templateArgs = + 
rewriter.getArrayAttr({emitc::OpaqueAttr::get(rewriter.getContext(), + dstOpaque.getValue())}); + auto cast = rewriter.create( + op.getLoc(), dstTy, "reinterpret_cast", ArrayAttr{}, templateArgs, + ValueRange{ptr}); + rewriter.replaceOp(op, cast.getResult(0)); + return success(); + } +}; + +struct PTOIntToPtrToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::IntToPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value addr = peelUnrealized(adaptor.getAddr()); + Type dstTy = getTypeConverter()->convertType(op.getResult().getType()); + if (!dstTy) + return failure(); + + Type dstElemTy = getPointerLikeElementType(op.getResult().getType()); + if (!dstElemTy) + return failure(); + + std::string castType = + std::string("__gm__ ") + getEmitCScalarTypeToken(dstElemTy) + "*"; + auto templateArgs = + rewriter.getArrayAttr({emitc::OpaqueAttr::get(rewriter.getContext(), + castType)}); + auto cast = rewriter.create( + op.getLoc(), dstTy, "reinterpret_cast", ArrayAttr{}, templateArgs, + ValueRange{addr}); + rewriter.replaceOp(op, cast.getResult(0)); + return success(); + } +}; + +struct PTOLoadScalarToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::LoadScalarOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value ptr = peelUnrealized(adaptor.getPtr()); + Value offset = peelUnrealized(adaptor.getOffset()); + + Type dstTy = getTypeConverter()->convertType(op.getValue().getType()); + if (!dstTy) + return failure(); + + auto call = rewriter.create( + op.getLoc(), TypeRange{dstTy}, "PTOAS__PTR_LOAD", + ArrayAttr{}, ArrayAttr{}, ValueRange{ptr, offset}); + + rewriter.replaceOp(op, call.getResults()); + return success(); + } +}; + +struct PTOStoreScalarToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult 
matchAndRewrite(pto::StoreScalarOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value ptr = peelUnrealized(adaptor.getPtr()); + Value offset = peelUnrealized(adaptor.getOffset()); + Value val = peelUnrealized(adaptor.getValue()); + + rewriter.create( + op.getLoc(), TypeRange{}, "PTOAS__PTR_STORE", + ArrayAttr{}, ArrayAttr{}, ValueRange{ptr, offset, val}); + rewriter.create( + op.getLoc(), TypeRange{}, "PTOAS__SCALAR_GM_STORE_FLUSH", + ArrayAttr{}, ArrayAttr{}, ValueRange{ptr}); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// pto.tabs lowering -> TABS(dst, src) +//===----------------------------------------------------------------------===// + + + +struct PTOTAbsToTABS : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TAbsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + // intrinsic: TABS(dst, src) + rewriter.create( + op.getLoc(), TypeRange{}, "TABS", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.tadd lowering -> TADD(dst, src0, src1) +//===----------------------------------------------------------------------===// + +struct PTOTAddToTADD : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TAddOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + op.getLoc(), TypeRange{}, "TADD", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, 
src0, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; + + +struct AffineApplyMulConstToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(affine::AffineApplyOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto map = op.getAffineMap(); + + if (map.getNumDims() != 0 || map.getNumSymbols() != 1) + return failure(); + + auto expr = map.getResult(0); + auto bin = dyn_cast(expr); + if (!bin || bin.getKind() != AffineExprKind::Mul) + return failure(); + + auto lhs = bin.getLHS(); + auto rhs = bin.getRHS(); + + auto symExpr = dyn_cast(lhs); + auto constExpr = dyn_cast(rhs); + if (!symExpr || !constExpr) + return failure(); + + Value inputVal = adaptor.getMapOperands()[0]; + + std::string valStr = std::to_string(constExpr.getValue()); + auto cstAttr = emitc::OpaqueAttr::get(rewriter.getContext(), valStr); + auto cstOp = rewriter.create( + op.getLoc(), inputVal.getType(), cstAttr); + + rewriter.replaceOpWithNewOp( + op, inputVal.getType(), inputVal, cstOp); + + return success(); + } +}; + + +} // namespace + +void populatePTOToEmitCSimpleOpPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter, + MLIRContext *ctx) { + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add( + typeConverter, ctx); + patterns.add( + typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); +} + +} // namespace mlir::pto diff --git a/lib/PTO/Transforms/PTOToEmitCSync.cpp b/lib/PTO/Transforms/PTOToEmitCSync.cpp new file mode 100644 index 000000000..efa812e5e --- /dev/null +++ b/lib/PTO/Transforms/PTOToEmitCSync.cpp @@ -0,0 +1,1046 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- PTOToEmitCSync.cpp --------------------------------------------------===// +//===----------------------------------------------------------------------===// + +#include "PTOToEmitCInternal.h" + +#include "PTO/IR/PTO.h" +#include "PTO/IR/PTOSyncUtils.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" + +#include +#include +#include + +using namespace mlir; +using namespace mlir::pto; + +namespace mlir::pto { +namespace { + +static constexpr llvm::StringLiteral kAutoSyncTailPendingModeAttr = + "__pto.auto_sync_tail_mode"; +static inline std::string pipeTokFromPipeAttr(mlir::pto::PipeAttr a); + +struct InterCoreSyncCallDesc { + const char *callee = nullptr; + ArrayAttr args; + SmallVector operands; +}; + +static Value castInterCoreEventIdToI32(ConversionPatternRewriter &rewriter, + Location loc, Value eventId) { + auto i32Ty = emitc::OpaqueType::get(rewriter.getContext(), "int32_t"); + if (eventId.getType() == i32Ty) + return eventId; + return emitCCast(rewriter, loc, i32Ty, eventId); +} + +static Attribute getFFTSModeCodegenArg(ConversionPatternRewriter &rewriter, + int64_t fftsMode) { + auto *ctx 
= rewriter.getContext(); + if (fftsMode == 2) + return emitc::OpaqueAttr::get(ctx, "FFTS_MODE_VAL"); + return emitc::OpaqueAttr::get(ctx, std::to_string(fftsMode)); +} + +static Value createFFTSMsg(ConversionPatternRewriter &rewriter, Location loc, + Value eventI32, int64_t fftsMode) { + auto *ctx = rewriter.getContext(); + auto msgTy = emitc::OpaqueType::get(ctx, "uint16_t"); + auto msgArgs = rewriter.getArrayAttr({ + getFFTSModeCodegenArg(rewriter, fftsMode), + IntegerAttr::get(IndexType::get(ctx), 0), + }); + return rewriter + .create(loc, msgTy, "getFFTSMsg", + /*args=*/msgArgs, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{eventI32}) + .getResult(0); +} + +static InterCoreSyncCallDesc buildInterCoreSyncSetCall( + ConversionPatternRewriter &rewriter, Location loc, PTOArch targetArch, + pto::PipeAttr pipeAttr, IntegerAttr eventIdAttr, int64_t fftsMode) { + auto *ctx = rewriter.getContext(); + std::string pipeTok = pipeTokFromPipeAttr(pipeAttr); + + if (targetArch == PTOArch::A3) { + auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); + Value eventVal = + makeEmitCIntConstant(rewriter, loc, i32Ty, eventIdAttr.getInt()); + Value msgVal = createFFTSMsg(rewriter, loc, eventVal, fftsMode); + + InterCoreSyncCallDesc desc; + desc.callee = "ffts_cross_core_sync"; + desc.args = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, pipeTok), + IntegerAttr::get(IndexType::get(ctx), 0), + }); + desc.operands.push_back(msgVal); + return desc; + } + + InterCoreSyncCallDesc desc; + desc.callee = "set_intra_block"; + desc.args = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, pipeTok), eventIdAttr}); + return desc; +} + +static InterCoreSyncCallDesc buildInterCoreSyncSetCallDyn( + ConversionPatternRewriter &rewriter, Location loc, PTOArch targetArch, + pto::PipeAttr pipeAttr, Value eventIdVal, int64_t fftsMode) { + auto *ctx = rewriter.getContext(); + std::string pipeTok = pipeTokFromPipeAttr(pipeAttr); + Value eventI32 = castInterCoreEventIdToI32(rewriter, 
loc, eventIdVal); + + if (targetArch == PTOArch::A3) { + Value msgVal = createFFTSMsg(rewriter, loc, eventI32, fftsMode); + + InterCoreSyncCallDesc desc; + desc.callee = "ffts_cross_core_sync"; + desc.args = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, pipeTok), + IntegerAttr::get(IndexType::get(ctx), 0), + }); + desc.operands.push_back(msgVal); + return desc; + } + + InterCoreSyncCallDesc desc; + desc.callee = "set_intra_block"; + desc.args = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, pipeTok), + IntegerAttr::get(IndexType::get(ctx), 0), + }); + desc.operands.push_back(eventI32); + return desc; +} + +static InterCoreSyncCallDesc buildInterCoreSyncWaitCall( + ConversionPatternRewriter &rewriter, PTOArch targetArch, + pto::PipeAttr pipeAttr, IntegerAttr eventIdAttr) { + auto *ctx = rewriter.getContext(); + std::string pipeTok = pipeTokFromPipeAttr(pipeAttr); + + InterCoreSyncCallDesc desc; + if (targetArch == PTOArch::A3) { + desc.callee = "wait_flag_dev"; + desc.args = rewriter.getArrayAttr({eventIdAttr}); + return desc; + } + + desc.callee = "wait_intra_block"; + desc.args = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, pipeTok), eventIdAttr}); + return desc; +} + +static InterCoreSyncCallDesc buildInterCoreSyncWaitCallDyn( + ConversionPatternRewriter &rewriter, Location loc, PTOArch targetArch, + pto::PipeAttr pipeAttr, Value eventIdVal) { + auto *ctx = rewriter.getContext(); + std::string pipeTok = pipeTokFromPipeAttr(pipeAttr); + Value eventI32 = castInterCoreEventIdToI32(rewriter, loc, eventIdVal); + + InterCoreSyncCallDesc desc; + if (targetArch == PTOArch::A3) { + desc.callee = "wait_flag_dev"; + desc.args = rewriter.getArrayAttr({IntegerAttr::get(IndexType::get(ctx), 0)}); + desc.operands.push_back(eventI32); + return desc; + } + + desc.callee = "wait_intra_block"; + desc.args = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, pipeTok), + IntegerAttr::get(IndexType::get(ctx), 0), + }); + desc.operands.push_back(eventI32); + 
return desc; +} + + + +static FailureOr buildSyncAllWorkspaceTileValue( + ConversionPatternRewriter &rewriter, Location loc, Value originalWorkspace, + Value emittedWorkspace) { + Value workspace = peelUnrealized(emittedWorkspace); + if (auto opaqueTy = dyn_cast(workspace.getType())) { + StringRef typeStr = opaqueTy.getValue(); + if (typeStr.contains("Tile<") || typeStr.contains("ConvTile<")) + return workspace; + } + + auto memTy = dyn_cast(originalWorkspace.getType()); + if (!memTy) + return failure(); + if (!memTy.hasStaticShape()) + return failure(); + + ArrayRef rawShape = memTy.getShape(); + if (rawShape.empty() || rawShape.size() > 2) + return failure(); + + int64_t rows = rawShape.size() == 1 ? 1 : rawShape[0]; + int64_t cols = rawShape.size() == 1 ? rawShape[0] : rawShape[1]; + SmallVector shape{rows, cols}; + SmallVector validShape{rows, cols}; + + auto *ctx = rewriter.getContext(); + pto::TileBufConfigAttr configAttr = pto::TileBufConfigAttr::getDefault(ctx); + if (auto bind = originalWorkspace.getDefiningOp()) { + configAttr = bind.getConfig(); + } else if (auto cast = originalWorkspace.getDefiningOp()) { + if (auto config = cast.getConfig()) + configAttr = *config; + } + + Attribute memorySpace = memTy.getMemorySpace(); + if (!memorySpace) + return failure(); + + auto tileTy = pto::TileBufType::get(ctx, shape, memTy.getElementType(), + memorySpace, validShape, configAttr); + auto tileTypeString = getEmitCTileTypeString(tileTy); + if (!tileTypeString) + return failure(); + + auto tileEmitTy = emitc::OpaqueType::get(ctx, *tileTypeString); + Value tile = rewriter + .create(loc, tileEmitTy, + emitc::OpaqueAttr::get(ctx, "")) + .getResult(); + + Value rawPtr = workspace; + auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); + if (isSetFFTsPointerLikeType(rawPtr.getType())) { + auto rcU64 = + rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); + rawPtr = rewriter + .create(loc, u64Ty, "reinterpret_cast", + ArrayAttr{}, rcU64, + 
ValueRange{rawPtr}) + .getResult(0); + } else if (rawPtr.getType() != u64Ty) { + rawPtr = rewriter.create(loc, u64Ty, rawPtr).getResult(); + } + + rewriter.create(loc, TypeRange{}, "TASSIGN", + ArrayAttr{}, ArrayAttr{}, + ValueRange{tile, rawPtr}); + return tile; +} + + + +//===----------------------------------------------------------------------===// +// Sync lowering +//===----------------------------------------------------------------------=== + +static constexpr llvm::StringLiteral kAutoSyncTailBarrierAttr = + "pto.auto_sync_tail_barrier"; +static constexpr llvm::StringLiteral kAutoSyncTailHintAttr = + "pto.auto_sync_tail_hint"; +static constexpr llvm::StringLiteral kAutoSyncTailPolicyBarrierAll = + "barrier_all"; +static constexpr llvm::StringLiteral kAutoSyncTailPolicyMte3ToSEvent0 = + "setwait_mte3_to_s_event0"; +static constexpr llvm::StringLiteral kAutoSyncTailModeBarrierAllToken = + "PTOAutoSyncTailMode::kBarrierAll"; +static constexpr llvm::StringLiteral kAutoSyncTailModeMte3ToSEvent0Token = + "PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0"; + +static std::string getAutoSyncTailModeToken(Operation *op) { + if (op) { + if (auto hintAttr = op->getAttrOfType(kAutoSyncTailHintAttr)) { + if (hintAttr.getValue() == kAutoSyncTailPolicyBarrierAll) + return kAutoSyncTailModeBarrierAllToken.str(); + if (hintAttr.getValue() == kAutoSyncTailPolicyMte3ToSEvent0) + return kAutoSyncTailModeMte3ToSEvent0Token.str(); + } + } + + auto func = op ? op->getParentOfType() : func::FuncOp(); + if (!func) + return kAutoSyncTailModeBarrierAllToken.str(); + + auto hintAttr = func->getAttrOfType(kAutoSyncTailHintAttr); + if (!hintAttr) + return kAutoSyncTailModeBarrierAllToken.str(); + + if (hintAttr.getValue() == kAutoSyncTailPolicyBarrierAll) + return kAutoSyncTailModeBarrierAllToken.str(); + if (hintAttr.getValue() == kAutoSyncTailPolicyMte3ToSEvent0) + return kAutoSyncTailModeMte3ToSEvent0Token.str(); + + // Fallback to the conservative behavior when seeing unknown policies. 
+ return kAutoSyncTailModeBarrierAllToken.str(); +} + +[[maybe_unused]] static std::string getPipeName(pto::PIPE pipe) { + switch (pipe) { + case pto::PIPE::PIPE_S: return "PIPE_S"; + case pto::PIPE::PIPE_V: return "PIPE_V"; + case pto::PIPE::PIPE_M: return "PIPE_M"; + case pto::PIPE::PIPE_MTE1: return "PIPE_MTE1"; + case pto::PIPE::PIPE_MTE2: return "PIPE_MTE2"; + case pto::PIPE::PIPE_MTE3: return "PIPE_MTE3"; + case pto::PIPE::PIPE_ALL: return "PIPE_ALL"; + case pto::PIPE::PIPE_MTE4: return "PIPE_MTE4"; + case pto::PIPE::PIPE_MTE5: return "PIPE_MTE5"; + case pto::PIPE::PIPE_V2: return "PIPE_V2"; + case pto::PIPE::PIPE_FIX: return "PIPE_FIX"; + case pto::PIPE::VIRTUAL_PIPE_MTE2_L1A: return "VIRTUAL_PIPE_MTE2_L1A"; + case pto::PIPE::VIRTUAL_PIPE_MTE2_L1B: return "VIRTUAL_PIPE_MTE2_L1B"; + // 默认回退 + default: return "PIPE_ALL"; + } +} + +//===----------------------------------------------------------------------===// +// pto.barrier lowering -> pipe_barrier(...) +//===----------------------------------------------------------------------===// +struct PTOBarrierToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::BarrierOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (op->hasAttr(kAutoSyncTailBarrierAttr)) { + auto modeAttr = rewriter.getStringAttr(getAutoSyncTailModeToken(op)); + if (auto emitcFunc = op->getParentOfType()) { + emitcFunc->setAttr(kAutoSyncTailPendingModeAttr, modeAttr); + } else if (auto funcOp = op->getParentOfType()) { + funcOp->setAttr(kAutoSyncTailPendingModeAttr, modeAttr); + } + rewriter.eraseOp(op); + return success(); + } + + // [FIX] op.getPipe() returns PipeAttr. + // We must call .getPipe() on the attribute to get the actual Enum value. 
+ pto::PIPE pipeEnum = op.getPipe().getPipe(); + + // Convert Enum to String (e.g., PIPE_ALL -> "PIPE_ALL") + std::string pipeStr = pto::stringifyPIPE(pipeEnum).str(); + auto *ctx = rewriter.getContext(); + + auto args = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, pipeStr) + }); + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, // void return + "pipe_barrier", // function name + args, // arguments + ArrayAttr{}, // template args + ValueRange{} // operands + ); + + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Sync lowering (robust for bracket form pto.set_flag[...] / pto.wait_flag[...]) +// Replace your PTOSyncToRuntimeCall with the code below. +//===----------------------------------------------------------------------===// + +static bool tryConvertPipeAttrToToken(Attribute attr, std::string &token) { + if (!attr) + return false; + if (auto pipe = dyn_cast(attr)) { + token = mlir::pto::stringifyPIPE(pipe.getPipe()).str(); + return true; + } + if (auto stringAttr = dyn_cast(attr)) { + token = stringAttr.getValue().str(); + return true; + } + return false; +} + +static bool tryConvertEventAttrToToken(Attribute attr, std::string &token) { + if (!attr) + return false; + if (auto event = dyn_cast(attr)) { + token = mlir::pto::stringifyEVENT(event.getEvent()).str(); + return true; + } + if (auto stringAttr = dyn_cast(attr)) { + token = stringAttr.getValue().str(); + return true; + } + return false; +} + +static bool tryAssignSyncTokens(Attribute srcAttr, Attribute dstAttr, + Attribute evtAttr, std::string &srcTok, + std::string &dstTok, std::string &evtTok) { + std::string localSrc; + std::string localDst; + std::string localEvt; + if (!tryConvertPipeAttrToToken(srcAttr, localSrc) || + !tryConvertPipeAttrToToken(dstAttr, localDst) || + !tryConvertEventAttrToToken(evtAttr, localEvt)) { + return false; + } + srcTok = std::move(localSrc); + dstTok = std::move(localDst); + evtTok = 
std::move(localEvt); + return true; +} + +static bool tryExtractSyncTokensFromNamedAttrs(Operation *op, + StringRef srcName, + StringRef dstName, + StringRef evtName, + std::string &srcTok, + std::string &dstTok, + std::string &evtTok) { + return tryAssignSyncTokens(op->getAttr(srcName), op->getAttr(dstName), + op->getAttr(evtName), srcTok, dstTok, evtTok); +} + +static bool tryExtractSyncTokensFromArrayAttr(Operation *op, StringRef attrName, + std::string &srcTok, + std::string &dstTok, + std::string &evtTok) { + auto arrayAttr = op->getAttrOfType(attrName); + if (!arrayAttr || arrayAttr.size() < 3) + return false; + return tryAssignSyncTokens(arrayAttr[0], arrayAttr[1], arrayAttr[2], srcTok, + dstTok, evtTok); +} + +static bool tryExtractFallbackSyncTokens(Operation *op, std::string &srcTok, + std::string &dstTok, + std::string &evtTok) { + SmallVector pipes; + std::string event; + for (NamedAttribute namedAttr : op->getAttrs()) { + std::string token; + if (tryConvertPipeAttrToToken(namedAttr.getValue(), token)) { + pipes.push_back(std::move(token)); + continue; + } + if (event.empty() && + tryConvertEventAttrToToken(namedAttr.getValue(), token)) { + event = std::move(token); + } + } + if (pipes.size() < 2 || event.empty()) + return false; + srcTok = pipes[0]; + dstTok = pipes[1]; + evtTok = event; + return true; +} + +static LogicalResult extractSyncTripletTokens(Operation *op, + std::string &srcTok, + std::string &dstTok, + std::string &evtTok, + ConversionPatternRewriter &rewriter) { + if (tryExtractSyncTokensFromNamedAttrs(op, "src_pipe", "dst_pipe", "event_id", + srcTok, dstTok, evtTok) || + tryExtractSyncTokensFromNamedAttrs(op, "srcPipe", "dstPipe", "eventId", + srcTok, dstTok, evtTok) || + tryExtractSyncTokensFromNamedAttrs(op, "src", "dst", "event", srcTok, + dstTok, evtTok)) { + return success(); + } + + for (StringRef attrName : {"args", "pipes", "sync", "triplet", "attrs"}) { + if (tryExtractSyncTokensFromArrayAttr(op, attrName, srcTok, dstTok, + 
evtTok)) { + return success(); + } + } + + if (tryExtractFallbackSyncTokens(op, srcTok, dstTok, evtTok)) + return success(); + return rewriter.notifyMatchFailure( + op, "cannot extract PIPE/PIPE/EVENT tokens from pto.{set,wait}_flag"); +} +static inline std::string pipeTokFromPipeEnum(mlir::pto::PIPE p) { + return mlir::pto::stringifyPIPE(p).str(); +} +[[maybe_unused]] static inline std::string evtTokFromEventEnum(mlir::pto::EVENT e) { + return mlir::pto::stringifyEVENT(e).str(); +} +static inline std::string pipeTokFromPipeAttr(mlir::pto::PipeAttr a) { + return mlir::pto::stringifyPIPE(a.getPipe()).str(); +} +static inline std::string evtTokFromEventAttr(mlir::pto::EventAttr a) { + return mlir::pto::stringifyEVENT(a.getEvent()).str(); +} + +template +struct HasGetSrcPipe : std::false_type {}; +template +struct HasGetSrcPipe().getSrcPipe())>> : std::true_type {}; + +template +struct HasGetDstPipe : std::false_type {}; +template +struct HasGetDstPipe().getDstPipe())>> : std::true_type {}; + +template +struct HasGetEventId : std::false_type {}; +template +struct HasGetEventId().getEventId())>> : std::true_type {}; + +template +struct HasGetSrcPipeAttr : std::false_type {}; +template +struct HasGetSrcPipeAttr().getSrcPipeAttr())>> : std::true_type {}; + +template +struct HasGetDstPipeAttr : std::false_type {}; +template +struct HasGetDstPipeAttr().getDstPipeAttr())>> : std::true_type {}; + +template +struct HasGetEventIdAttr : std::false_type {}; +template +struct HasGetEventIdAttr().getEventIdAttr())>> : std::true_type {}; + +template +static LogicalResult extractSyncTokens(SyncOpT op, + std::string &srcTok, + std::string &dstTok, + std::string &evtTok, + ConversionPatternRewriter &rewriter) { + if constexpr (HasGetSrcPipe::value && + HasGetDstPipe::value && + HasGetEventId::value) { + auto s = op.getSrcPipe(); + auto d = op.getDstPipe(); + auto e = op.getEventId(); + + if constexpr (std::is_same::value) srcTok = pipeTokFromPipeEnum(s); + else srcTok = 
pipeTokFromPipeAttr(s); + + if constexpr (std::is_same::value) dstTok = pipeTokFromPipeEnum(d); + else dstTok = pipeTokFromPipeAttr(d); + + if constexpr (std::is_same::value) evtTok = evtTokFromEventEnum(e); + else evtTok = evtTokFromEventAttr(e); + + return success(); + } + + if constexpr (HasGetSrcPipeAttr::value && + HasGetDstPipeAttr::value && + HasGetEventIdAttr::value) { + auto s = op.getSrcPipeAttr(); + auto d = op.getDstPipeAttr(); + auto e = op.getEventIdAttr(); + srcTok = pipeTokFromPipeAttr(s); + dstTok = pipeTokFromPipeAttr(d); + evtTok = evtTokFromEventAttr(e); + return success(); + } + + return extractSyncTripletTokens(op.getOperation(), srcTok, dstTok, evtTok, rewriter); +} +struct PTOSetFlagToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::SetFlagOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + auto *ctx = rewriter.getContext(); + + std::string srcTok, dstTok, evtTok; + if (failed(extractSyncTokens(op, srcTok, dstTok, evtTok, rewriter))) + return failure(); + + auto argsAttr = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, srcTok), + emitc::OpaqueAttr::get(ctx, dstTok), + emitc::OpaqueAttr::get(ctx, evtTok), + }); + + rewriter.replaceOpWithNewOp( + op, TypeRange{}, "set_flag", + /*args=*/argsAttr, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{}); + return success(); + } +}; + +struct PTOWaitFlagToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::WaitFlagOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + auto *ctx = rewriter.getContext(); + + std::string srcTok, dstTok, evtTok; + if (failed(extractSyncTokens(op, srcTok, dstTok, evtTok, rewriter))) + return failure(); + + auto argsAttr = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, srcTok), + 
emitc::OpaqueAttr::get(ctx, dstTok), + emitc::OpaqueAttr::get(ctx, evtTok), + }); + + rewriter.replaceOpWithNewOp( + op, TypeRange{}, "wait_flag", + /*args=*/argsAttr, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{}); + return success(); + } +}; + +struct PTOSyncToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::TSyncOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector operands; + operands.reserve(adaptor.getEvents().size()); + for (Value event : adaptor.getEvents()) + operands.push_back(peelUnrealized(event)); + + rewriter.create( + op.getLoc(), TypeRange{}, "TSYNC", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange(operands)); + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOSyncAllToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + static StringRef coreTypeTok(pto::SyncCoreType coreType) { + switch (coreType) { + case pto::SyncCoreType::AIVOnly: + return "SyncCoreType::AIVOnly"; + case pto::SyncCoreType::AICOnly: + return "SyncCoreType::AICOnly"; + case pto::SyncCoreType::Mix: + return "SyncCoreType::Mix"; + } + llvm_unreachable("unhandled SyncCoreType"); + } + + LogicalResult matchAndRewrite(mlir::pto::SyncAllOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto mode = op.getMode().getValue(); + auto coreType = op.getCoreType().getValue(); + + auto buildGmWorkspace = [&]() -> FailureOr { + Value gm = peelUnrealized(adaptor.getGmWorkspace()); + if (isEmitCGlobalTensorLikeType(gm.getType())) + return gm; + + auto memTy = dyn_cast(op.getGmWorkspace().getType()); + if (!memTy) + return failure(); + + Value gt = buildGlobalTensorFromMemref(rewriter, op.getLoc(), gm, memTy, + op.getGmWorkspace().getDefiningOp() + ? 
op.getGmWorkspace().getDefiningOp() + : op.getOperation()); + if (!gt) + return failure(); + return gt; + }; + + if (mode == pto::SyncAllMode::Hard) { + std::string callee = "SYNCALL<" + coreTypeTok(coreType).str() + ">"; + rewriter.create(op.getLoc(), TypeRange{}, callee, + ArrayAttr{}, ArrayAttr{}, + ValueRange{}); + rewriter.eraseOp(op); + return success(); + } + + FailureOr gmWorkspace = buildGmWorkspace(); + if (failed(gmWorkspace)) + return rewriter.notifyMatchFailure(op, + "failed to build gm_workspace GlobalTensor"); + + auto i32Ty = emitc::OpaqueType::get(rewriter.getContext(), "int32_t"); + Value usedCores = adaptor.getUsedCores() + ? peelUnrealized(adaptor.getUsedCores()) + : makeEmitCIntConstant(rewriter, op.getLoc(), i32Ty, 0); + if (usedCores.getType() != i32Ty) + usedCores = rewriter.create(op.getLoc(), i32Ty, usedCores) + .getResult(); + + std::string callee = + "SYNCALL"; + + SmallVector operands{*gmWorkspace}; + switch (coreType) { + case pto::SyncCoreType::AIVOnly: { + FailureOr ubWorkspace = + buildSyncAllWorkspaceTileValue(rewriter, op.getLoc(), + op.getUbWorkspace(), + adaptor.getUbWorkspace()); + if (failed(ubWorkspace)) + return rewriter.notifyMatchFailure( + op, "failed to materialize ub_workspace tile"); + operands.push_back(*ubWorkspace); + break; + } + case pto::SyncCoreType::AICOnly: { + FailureOr l1Workspace = + buildSyncAllWorkspaceTileValue(rewriter, op.getLoc(), + op.getL1Workspace(), + adaptor.getL1Workspace()); + if (failed(l1Workspace)) + return rewriter.notifyMatchFailure( + op, "failed to materialize l1_workspace tile"); + operands.push_back(*l1Workspace); + break; + } + case pto::SyncCoreType::Mix: { + FailureOr ubWorkspace = + buildSyncAllWorkspaceTileValue(rewriter, op.getLoc(), + op.getUbWorkspace(), + adaptor.getUbWorkspace()); + FailureOr l1Workspace = + buildSyncAllWorkspaceTileValue(rewriter, op.getLoc(), + op.getL1Workspace(), + adaptor.getL1Workspace()); + if (failed(ubWorkspace) || failed(l1Workspace)) + return 
rewriter.notifyMatchFailure( + op, "failed to materialize mixed syncall workspace tiles"); + operands.push_back(*ubWorkspace); + operands.push_back(*l1Workspace); + break; + } + } + + operands.push_back(usedCores); + rewriter.create(op.getLoc(), TypeRange{}, callee, + ArrayAttr{}, ArrayAttr{}, + ValueRange(operands)); + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOSyncFlagDynToEmitC : public ConversionPattern { + PTOSyncFlagDynToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + StringRef opName, StringRef callee) + : ConversionPattern(typeConverter, opName, /*benefit=*/1, ctx), + callee(callee.str()) {} + + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + if (operands.size() != 1) + return rewriter.notifyMatchFailure(op, "expected exactly one dynamic event-id operand"); + + auto srcAttr = op->getAttrOfType("src_pipe"); + auto dstAttr = op->getAttrOfType("dst_pipe"); + if (!srcAttr || !dstAttr) + return rewriter.notifyMatchFailure(op, "missing PipeAttr src_pipe/dst_pipe attrs"); + + auto *ctx = rewriter.getContext(); + std::string srcTok = pipeTokFromPipeAttr(srcAttr); + std::string dstTok = pipeTokFromPipeAttr(dstAttr); + + Value eventVal = operands.front(); + eventVal = + emitCCast(rewriter, op->getLoc(), emitc::OpaqueType::get(ctx, "event_t"), eventVal); + + auto argsAttr = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, srcTok), + emitc::OpaqueAttr::get(ctx, dstTok), + IntegerAttr::get(IndexType::get(ctx), 0), + }); + + rewriter.replaceOpWithNewOp( + op, TypeRange{}, callee, + /*args=*/argsAttr, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{eventVal}); + return success(); + } + +private: + std::string callee; +}; + +struct PTOGetBufToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::GetBufOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { 
+ (void)adaptor; + auto *ctx = rewriter.getContext(); + + auto opTypeOr = parseSyncOpTypeLikeAttr(op.getOpTypeAttr()); + if (failed(opTypeOr)) + return rewriter.notifyMatchFailure(op, "get_buf expects pipe_event_type/sync_op_type attr"); + auto pipe = mapSyncOpTypeToPipe(*opTypeOr); + if (!isConcreteSyncPipe(pipe)) + return rewriter.notifyMatchFailure(op, "get_buf op_type cannot map to a concrete pipe"); + std::string pipeTok = pipeTokFromPipeEnum(pipe); + auto argsAttr = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, pipeTok), + op.getBufIdAttr(), + op.getModeAttr(), + }); + + rewriter.replaceOpWithNewOp( + op, TypeRange{}, "get_buf", + /*args=*/argsAttr, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{}); + return success(); + } +}; + +struct PTORlsBufToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::RlsBufOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + auto *ctx = rewriter.getContext(); + + auto opTypeOr = parseSyncOpTypeLikeAttr(op.getOpTypeAttr()); + if (failed(opTypeOr)) + return rewriter.notifyMatchFailure(op, "rls_buf expects pipe_event_type/sync_op_type attr"); + auto pipe = mapSyncOpTypeToPipe(*opTypeOr); + if (!isConcreteSyncPipe(pipe)) + return rewriter.notifyMatchFailure(op, "rls_buf op_type cannot map to a concrete pipe"); + std::string pipeTok = pipeTokFromPipeEnum(pipe); + auto argsAttr = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, pipeTok), + op.getBufIdAttr(), + op.getModeAttr(), + }); + + rewriter.replaceOpWithNewOp( + op, TypeRange{}, "rls_buf", + /*args=*/argsAttr, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{}); + return success(); + } +}; + +struct PTOSetFFTsToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::SetFFTsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const 
override { + auto *ctx = rewriter.getContext(); + auto loc = op.getLoc(); + + Value fftsAddr = peelUnrealized(adaptor.getFfts()); + auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); + + if (isSetFFTsPointerLikeType(fftsAddr.getType())) { + auto castTyAttr = + rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); + fftsAddr = + rewriter + .create(loc, u64Ty, "reinterpret_cast", + /*args=*/ArrayAttr{}, + /*templateArgs=*/castTyAttr, + /*operands=*/ValueRange{fftsAddr}) + .getResult(0); + } else if (fftsAddr.getType() != u64Ty) { + fftsAddr = + rewriter.create(loc, u64Ty, fftsAddr).getResult(); + } + + rewriter.replaceOpWithNewOp( + op, TypeRange{}, "set_ffts_base_addr", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{fftsAddr}); + return success(); + } +}; + +struct PTOSyncSetToEmitC : public OpConversionPattern { + PTOSyncSetToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + PTOArch targetArch) + : OpConversionPattern(typeConverter, ctx), + targetArch(targetArch) {} + + LogicalResult + matchAndRewrite(mlir::pto::SyncSetOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto *ctx = rewriter.getContext(); + IntegerAttr eventIdAttr = op.getEventIdAttr(); + Value eventIdDyn = adaptor.getEventIdDyn(); + int64_t fftsMode = 2; + if (IntegerAttr fftsModeAttr = op.getFftsModeAttr()) + fftsMode = fftsModeAttr.getInt(); + + if ((eventIdAttr != nullptr) == static_cast(eventIdDyn)) + return rewriter.notifyMatchFailure( + op, "expects exactly one of static event_id attr or dynamic event_id operand"); + + // A5 inter-core sync mirrors +16 only for cube-side producer (PIPE_FIX). + // Vec-side producer (PIPE_MTE3) emits a single set; hardware handles the + // subblock mapping in PTO-ISA custom flow. 
+ if (targetArch == PTOArch::A5) { + pto::PIPE pipe = op.getPipe().getPipe(); + bool needsMirrorPlus16 = (pipe == pto::PIPE::PIPE_FIX); + std::string pipeTok = pipeTokFromPipeAttr(op.getPipe()); + auto emitSet = [&](Value eventOperand, IntegerAttr eventLiteral, + bool isDynamic) { + if (isDynamic) { + auto args = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, pipeTok), + IntegerAttr::get(IndexType::get(ctx), 0), + }); + rewriter.create(loc, TypeRange{}, "set_intra_block", + /*args=*/args, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{eventOperand}); + return; + } + auto args = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, pipeTok), + eventLiteral, + }); + rewriter.create(loc, TypeRange{}, "set_intra_block", + /*args=*/args, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{}); + }; + + if (eventIdAttr) { + emitSet(Value{}, eventIdAttr, /*isDynamic=*/false); + if (needsMirrorPlus16) { + auto plus16 = IntegerAttr::get(eventIdAttr.getType(), + eventIdAttr.getInt() + 16); + emitSet(Value{}, plus16, /*isDynamic=*/false); + } + } else { + Value eventI32 = castInterCoreEventIdToI32(rewriter, loc, eventIdDyn); + emitSet(eventI32, IntegerAttr{}, /*isDynamic=*/true); + if (needsMirrorPlus16) { + auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); + Value c16 = makeEmitCIntConstant(rewriter, loc, i32Ty, 16); + Value eventI32Plus16 = + rewriter.create(loc, i32Ty, eventI32, c16).getResult(); + emitSet(eventI32Plus16, IntegerAttr{}, /*isDynamic=*/true); + } + } + + rewriter.eraseOp(op); + return success(); + } + + InterCoreSyncCallDesc desc; + if (eventIdAttr) { + desc = buildInterCoreSyncSetCall(rewriter, loc, targetArch, op.getPipe(), + eventIdAttr, fftsMode); + } else { + desc = buildInterCoreSyncSetCallDyn(rewriter, loc, targetArch, op.getPipe(), + eventIdDyn, fftsMode); + } + rewriter.create(loc, TypeRange{}, desc.callee, + /*args=*/desc.args, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/desc.operands); + + rewriter.eraseOp(op); + 
return success(); + } + + PTOArch targetArch; +}; + +struct PTOSyncWaitToEmitC : public OpConversionPattern { + PTOSyncWaitToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + PTOArch targetArch) + : OpConversionPattern(typeConverter, ctx), + targetArch(targetArch) {} + + LogicalResult + matchAndRewrite(mlir::pto::SyncWaitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + IntegerAttr eventIdAttr = op.getEventIdAttr(); + Value eventIdDyn = adaptor.getEventIdDyn(); + + if ((eventIdAttr != nullptr) == static_cast(eventIdDyn)) + return rewriter.notifyMatchFailure( + op, "expects exactly one of static event_id attr or dynamic event_id operand"); + + InterCoreSyncCallDesc desc; + if (eventIdAttr) { + desc = buildInterCoreSyncWaitCall(rewriter, targetArch, op.getPipe(), + eventIdAttr); + } else { + desc = buildInterCoreSyncWaitCallDyn(rewriter, loc, targetArch, op.getPipe(), + eventIdDyn); + } + rewriter.create(loc, TypeRange{}, desc.callee, + desc.args, ArrayAttr{}, desc.operands); + + rewriter.eraseOp(op); + return success(); + } + + PTOArch targetArch; +}; + + +} // namespace + +void populatePTOToEmitCSyncPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter, + MLIRContext *ctx, PTOArch targetArch) { + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx, "pto.set_flag_dyn", + "set_flag"); + patterns.add(typeConverter, ctx, "pto.wait_flag_dyn", + "wait_flag"); + patterns.add(typeConverter, ctx, "pto.set_flag_d", + "set_flag"); + patterns.add(typeConverter, ctx, "pto.wait_flag_d", + "wait_flag"); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx, targetArch); + patterns.add(typeConverter, ctx, targetArch); +} + +} // namespace mlir::pto diff 
--git a/lib/PTO/Transforms/PTOToEmitCTileMaterialization.cpp b/lib/PTO/Transforms/PTOToEmitCTileMaterialization.cpp new file mode 100644 index 000000000..5fedd725c --- /dev/null +++ b/lib/PTO/Transforms/PTOToEmitCTileMaterialization.cpp @@ -0,0 +1,923 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- PTOToEmitCTileMaterialization.cpp ----------------------------------===// +//===----------------------------------------------------------------------===// + +#include "PTOToEmitCInternal.h" + +#include "PTO/IR/PTO.h" +#include "PTO/IR/PTOTypeUtils.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" + +#include + +using namespace mlir; +using namespace mlir::pto; + +namespace mlir::pto { +namespace { + +static constexpr llvm::StringLiteral kForceDynamicValidShapeAttrName = + "__pto.force_dynamic_valid_shape"; + +// ============================================================================= +// 2. 
BindTileOp Lowering (FIX: Trace back to physical address) +// ============================================================================= +struct PTOBindTileToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + struct TileBuildSpec { + std::string tileTypeStr; + bool useConstructor = false; + SmallVector constructorArgs; + }; + + static bool getIndexConst(Value v, int64_t &out) { + if (!v) + return false; + if (auto cst = v.getDefiningOp()) { + if (auto ia = dyn_cast(cst.getValue())) { + out = ia.getValue().getSExtValue(); + return true; + } + } + return false; + } + + static bool getTilePointerStrides(pto::TileBufConfigAttr configAttr, + Type elemTy, int64_t rows, int64_t cols, + int64_t &rowStride, + int64_t &colStride) { + if (rows == ShapedType::kDynamic || cols == ShapedType::kDynamic) + return false; + + int32_t blVal = 0; + if (auto blAttr = dyn_cast(configAttr.getBLayout())) + blVal = static_cast(blAttr.getValue()); + else if (auto intAttr = dyn_cast(configAttr.getBLayout())) + blVal = static_cast(intAttr.getInt()); + + int32_t slVal = 0; + if (auto slAttr = dyn_cast(configAttr.getSLayout())) + slVal = static_cast(slAttr.getValue()); + else if (auto intAttr = dyn_cast(configAttr.getSLayout())) + slVal = static_cast(intAttr.getInt()); + + bool boxed = slVal != 0; + int64_t innerRows = 1; + int64_t innerCols = 1; + if (boxed) { + int32_t fractal = 512; + if (auto frAttr = dyn_cast(configAttr.getSFractalSize())) + fractal = static_cast(frAttr.getInt()); + + unsigned elemBytes = pto::getPTOStorageElemByteSize(elemTy); + if (elemBytes == 0) + return false; + + switch (fractal) { + case 1024: + innerRows = 16; + innerCols = 16; + break; + case 32: + innerRows = 16; + innerCols = 2; + break; + case 512: + if (slVal == 1) { + innerRows = 16; + innerCols = 32 / elemBytes; + } else if (slVal == 2) { + innerRows = 32 / elemBytes; + innerCols = 16; + } else { + return false; + } + break; + default: + return false; + } + if 
(innerRows <= 0 || innerCols <= 0) + return false; + } + + if (!boxed) { + if (blVal == 1) { + rowStride = 1; + colStride = rows; + } else { + rowStride = cols; + colStride = 1; + } + return true; + } + + if (blVal == 1) { + if (slVal != 1) + return false; + rowStride = innerCols; + colStride = rows; + return true; + } + + rowStride = cols; + colStride = innerRows; + return true; + } + + LogicalResult matchAndRewrite(pto::BindTileOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + auto configAttr = op.getConfigAttr(); + auto viewSemantics = op->getAttrOfType("pto.view_semantics"); + bool isSubView = viewSemantics && viewSemantics.getValue() == "subview"; + + auto peelAllCasts = [](Value v) { + while (auto castOp = v.getDefiningOp()) + v = castOp.getOperand(0); + if (auto castOp = v.getDefiningOp()) + v = castOp.getOperand(); + return v; + }; + auto isTileLike = [](Value v) -> bool { + auto ot = dyn_cast(v.getType()); + if (!ot) + return false; + StringRef s = ot.getValue(); + return s.contains("Tile<") || s.contains("ConvTile<"); + }; + auto buildTileSpec = [&]() -> FailureOr { + auto resMrTy = dyn_cast(op.getType()); + if (!resMrTy) + return failure(); + + const char *roleTok = "TileType::Vec"; + if (auto asAttr = + dyn_cast_or_null(resMrTy.getMemorySpace())) { + switch (asAttr.getAddressSpace()) { + case pto::AddressSpace::VEC: + roleTok = "TileType::Vec"; + break; + case pto::AddressSpace::MAT: + roleTok = "TileType::Mat"; + break; + case pto::AddressSpace::LEFT: + roleTok = "TileType::Left"; + break; + case pto::AddressSpace::RIGHT: + roleTok = "TileType::Right"; + break; + case pto::AddressSpace::ACC: + roleTok = "TileType::Acc"; + break; + case pto::AddressSpace::BIAS: + roleTok = "TileType::Bias"; + break; + case pto::AddressSpace::SCALING: + roleTok = "TileType::Scaling"; + break; + case pto::AddressSpace::GM: + case pto::AddressSpace::Zero: + roleTok = 
"TileType::Vec"; + break; + } + } + + Type elemTy = resMrTy.getElementType(); + Type emitElemTy = getTypeConverter()->convertType(elemTy); + if (!emitElemTy) + return failure(); + auto emitElemOpaque = dyn_cast(emitElemTy); + if (!emitElemOpaque) + return failure(); + std::string elemTypeStr = emitElemOpaque.getValue().str(); + + if (resMrTy.getRank() < 2) + return failure(); + int64_t rows = resMrTy.getDimSize(0); + int64_t cols = resMrTy.getDimSize(1); + if (rows == ShapedType::kDynamic || cols == ShapedType::kDynamic) + return failure(); + + std::string blTok = "BLayout::RowMajor"; + if (auto blAttr = dyn_cast(configAttr.getBLayout())) { + if (static_cast(blAttr.getValue()) == 1) + blTok = "BLayout::ColMajor"; + } + pto::BLayout blayout = getTileBufBLayoutValue(configAttr); + + if (isSubView) { + auto subMrTy = dyn_cast(op.getSource().getType()); + auto subViewOp = op.getSource().getDefiningOp(); + if (subMrTy && subMrTy.getRank() >= 2 && subViewOp) { + int64_t subRows = subMrTy.getDimSize(0); + int64_t subCols = subMrTy.getDimSize(1); + SmallVector inheritedStrides; + int64_t inheritedOffset = ShapedType::kDynamic; + + if (!pto::isPTOFloat4PackedType(elemTy) && + subRows != ShapedType::kDynamic && + subCols != ShapedType::kDynamic && + succeeded(getStridesAndOffset(subMrTy, inheritedStrides, + inheritedOffset)) && + inheritedStrides.size() >= 2) { + int64_t childRowStride = 0; + int64_t childColStride = 0; + bool sameStrides = getTilePointerStrides( + configAttr, elemTy, subRows, subCols, childRowStride, + childColStride); + sameStrides = sameStrides && + inheritedStrides[0] == childRowStride && + inheritedStrides[1] == childColStride; + if (sameStrides) { + rows = subRows; + cols = subCols; + } + } + } + } + + std::string slTok = "SLayout::NoneBox"; + if (auto slAttr = dyn_cast(configAttr.getSLayout())) { + int32_t slVal = static_cast(slAttr.getValue()); + slTok = (slVal == 1) ? "SLayout::RowMajor" + : (slVal == 2) ? 
"SLayout::ColMajor" + : "SLayout::NoneBox"; + } + + int32_t fractal = 512; + if (auto frAttr = dyn_cast(configAttr.getSFractalSize())) + fractal = frAttr.getInt(); + + std::string padTok = "PadValue::Null"; + if (auto padAttr = dyn_cast(configAttr.getPad())) { + switch (static_cast(padAttr.getValue())) { + case 1: + padTok = "PadValue::Zero"; + break; + case 2: + padTok = "PadValue::Max"; + break; + case 3: + padTok = "PadValue::Min"; + break; + default: + padTok = "PadValue::Null"; + break; + } + } + + std::string compactTok = "CompactMode::Null"; + if (auto compactAttr = dyn_cast(configAttr.getCompactMode())) { + switch (static_cast(compactAttr.getValue())) { + case 1: + compactTok = "CompactMode::Normal"; + break; + case 2: + compactTok = "CompactMode::RowPlusOne"; + break; + default: + compactTok = "CompactMode::Null"; + break; + } + } + + std::string vrowTok, vcolTok; + bool useConstructor = false; + bool rowIsDynamic = false; + bool colIsDynamic = false; + SmallVector constructorArgs; + + Value vRow = op.getValidRow(); + Value vCol = op.getValidCol(); + Value vRowEmitC = adaptor.getValidRow(); + Value vColEmitC = adaptor.getValidCol(); + bool forceDynamicValid = op->hasAttr(kForceDynamicValidShapeAttrName); + int64_t cRow = 0, cCol = 0; + bool rowIsConst = vRow && getIndexConst(vRow, cRow); + bool colIsConst = vCol && getIndexConst(vCol, cCol); + + auto makeCtorDimValue = [&](Value emitted, int64_t fallback) -> Value { + if (emitted) + return emitted; + return makeEmitCIntConstant( + rewriter, loc, emitc::OpaqueType::get(ctx, "int32_t"), fallback); + }; + auto maybeScaleDynamicValid = [&](Value emitted, int dimIdx) -> Value { + if (!emitted || !pto::isPTOFloat4PackedType(elemTy)) + return emitted; + int packedDim = blayout == pto::BLayout::ColMajor ? 
0 : 1; + if (dimIdx != packedDim) + return emitted; + auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); + Value two = makeEmitCIntConstant(rewriter, loc, i32Ty, 2); + return rewriter.create(loc, i32Ty, emitted, two).getResult(); + }; + + if (forceDynamicValid) { + vrowTok = "-1"; + vcolTok = "-1"; + useConstructor = true; + constructorArgs.push_back( + makeCtorDimValue(maybeScaleDynamicValid(vRowEmitC, 0), + renderTileTemplateDim(rowIsConst ? cRow : rows, + elemTy, blayout, 0))); + constructorArgs.push_back( + makeCtorDimValue(maybeScaleDynamicValid(vColEmitC, 1), + renderTileTemplateDim(colIsConst ? cCol : cols, + elemTy, blayout, 1))); + } else { + if (rowIsConst) { + vrowTok = std::to_string( + renderTileTemplateDim(cRow, elemTy, blayout, 0)); + } else if (vRow) { + vrowTok = "-1"; + rowIsDynamic = true; + useConstructor = true; + } else { + vrowTok = std::to_string( + renderTileTemplateDim(rows, elemTy, blayout, 0)); + } + + if (colIsConst) { + vcolTok = std::to_string( + renderTileTemplateDim(cCol, elemTy, blayout, 1)); + } else if (vCol) { + vcolTok = "-1"; + colIsDynamic = true; + useConstructor = true; + } else { + vcolTok = std::to_string( + renderTileTemplateDim(cols, elemTy, blayout, 1)); + } + + if (useConstructor) { + if (rowIsDynamic && vRowEmitC) + constructorArgs.push_back(maybeScaleDynamicValid(vRowEmitC, 0)); + if (colIsDynamic && vColEmitC) + constructorArgs.push_back(maybeScaleDynamicValid(vColEmitC, 1)); + } + } + + std::string tileTypeStr = std::string("Tile<") + roleTok + ", " + + elemTypeStr + ", " + + std::to_string(renderTileTemplateDim( + rows, elemTy, blayout, 0)) + + ", " + + std::to_string(renderTileTemplateDim( + cols, elemTy, blayout, 1)) + + ", " + blTok + + ", " + vrowTok + ", " + vcolTok + ", " + slTok + + ", " + std::to_string(fractal) + ", " + padTok + + ", " + compactTok + + ">"; + return TileBuildSpec{tileTypeStr, useConstructor, constructorArgs}; + }; + + auto buildTileValue = [&](const TileBuildSpec &spec, + bool 
forceDeclaration = false) -> Value { + auto tileType = emitc::OpaqueType::get(ctx, spec.tileTypeStr); + if (spec.useConstructor && !forceDeclaration) { + return rewriter + .create(loc, tileType, spec.tileTypeStr, + ArrayAttr{}, ArrayAttr{}, + ValueRange(spec.constructorArgs)) + .getResult(0); + } + + return rewriter + .create(loc, tileType, emitc::OpaqueAttr::get(ctx, "")) + .getResult(); + }; + + auto emitElemTypeToString = [&](Type elemTy) -> std::string { + return getEmitCScalarTypeToken(elemTy); + }; + + auto buildIntegralAddress = [&](Value sourceValue) -> FailureOr { + auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); + auto rcU64 = + rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); + + Value rawPtr = sourceValue; + if (auto ot = dyn_cast(sourceValue.getType())) { + StringRef tyStr = ot.getValue(); + if (tyStr.contains("Tile<") || tyStr.contains("ConvTile<")) { + auto srcMrTy = dyn_cast(op.getSource().getType()); + if (!srcMrTy) + return failure(); + std::string elemTok = emitElemTypeToString(srcMrTy.getElementType()); + pto::AddressSpace as = pto::AddressSpace::GM; + if (auto asAttr = + dyn_cast_or_null(srcMrTy.getMemorySpace())) + as = asAttr.getAddressSpace(); + rawPtr = materializeTileDataValue(rewriter, loc, sourceValue, as, + elemTok); + } + } + + if (isSetFFTsPointerLikeType(rawPtr.getType())) { + return rewriter + .create(loc, u64Ty, "reinterpret_cast", + ArrayAttr{}, rcU64, ValueRange{rawPtr}) + .getResult(0); + } + + if (rawPtr.getType() == u64Ty) + return rawPtr; + return rewriter.create(loc, u64Ty, rawPtr).getResult(); + }; + + if (op.getSource().getDefiningOp()) { + FailureOr tileSpec = buildTileSpec(); + if (failed(tileSpec)) + return failure(); + rewriter.replaceOp(op, buildTileValue(*tileSpec)); + return success(); + } + + Value tileCandidate = peelAllCasts(adaptor.getSource()); + if (viewSemantics && viewSemantics.getValue() == "bitcast" && + isTileLike(tileCandidate)) { + FailureOr tileSpec = buildTileSpec(); + if 
(failed(tileSpec)) + return failure(); + Value dstTile = buildTileValue(*tileSpec); + FailureOr addr = buildIntegralAddress(tileCandidate); + if (failed(addr)) + return failure(); + + rewriter.create(loc, TypeRange{}, "TASSIGN", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dstTile, *addr}); + rewriter.replaceOp(op, dstTile); + return success(); + } + + if (viewSemantics && viewSemantics.getValue() == "treshape" && + isTileLike(tileCandidate)) { + FailureOr tileSpec = buildTileSpec(); + if (failed(tileSpec)) + return failure(); + Value dstTile = buildTileValue(*tileSpec, /*forceDeclaration=*/true); + + rewriter.create(loc, TypeRange{}, "TRESHAPE", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dstTile, tileCandidate}); + rewriter.replaceOp(op, dstTile); + return success(); + } + + // Subview origins are kept distinct from generic tile rebinding: + // even when source/destination C++ tile types match, subview may carry + // shifted base address semantics and should materialize a fresh handle. + if (isSubView) { + FailureOr tileSpec = buildTileSpec(); + if (failed(tileSpec)) + return failure(); + Value dstTile = buildTileValue(*tileSpec); + FailureOr addr = buildIntegralAddress(tileCandidate); + if (failed(addr)) + return failure(); + + rewriter.create(loc, TypeRange{}, "TASSIGN", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dstTile, *addr}); + rewriter.replaceOp(op, dstTile); + return success(); + } + + // Generic tile-to-tile rebind path: preserve the same backing storage and + // rebuild a sibling tile with updated metadata/valid dims. 
+ if (isTileLike(tileCandidate)) { + FailureOr tileSpec = buildTileSpec(); + if (failed(tileSpec)) + return failure(); + + if (!tileSpec->useConstructor) { + if (auto srcTy = dyn_cast(tileCandidate.getType())) { + if (srcTy.getValue() == tileSpec->tileTypeStr) { + rewriter.replaceOp(op, tileCandidate); + return success(); + } + } + } + + Value dstTile = buildTileValue(*tileSpec); + FailureOr addr = buildIntegralAddress(tileCandidate); + if (failed(addr)) + return failure(); + + rewriter.create(loc, TypeRange{}, "TASSIGN", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dstTile, *addr}); + rewriter.replaceOp(op, dstTile); + return success(); + } + + SmallVector physAddrs; + Value source = op.getSource(); + + while (auto castOp = source.getDefiningOp()) + source = castOp.getOperand(0); + + if (auto upstreamCast = source.getDefiningOp()) { + auto upstreamOperands = upstreamCast.getAddrs(); + physAddrs.append(upstreamOperands.begin(), upstreamOperands.end()); + } else { + physAddrs.push_back(adaptor.getSource()); + } + + Value vRow = op.getValidRow(); + Value vCol = op.getValidCol(); + + auto newCast = rewriter.create( + loc, op.getType(), physAddrs, vRow ? vRow : Value(), + vCol ? 
vCol : Value(), configAttr); + if (viewSemantics) + newCast->setAttr("pto.view_semantics", viewSemantics); + if (op->hasAttr(kForceDynamicValidShapeAttrName)) + newCast->setAttr(kForceDynamicValidShapeAttrName, + op->getAttr(kForceDynamicValidShapeAttrName)); + rewriter.replaceOp(op, newCast.getResult()); + + return success(); + } +}; + +struct PTOAllocTileToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::AllocTileOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + MLIRContext *ctx = rewriter.getContext(); + auto tileTy = cast(op.getResult().getType()); + auto tileTypeString = getEmitCTileTypeString(tileTy); + if (!tileTypeString) + return rewriter.notifyMatchFailure( + op, "only rank-2 alloc_tile handles can be converted to EmitC"); + + Type convertedTy = getTypeConverter()->convertType(tileTy); + if (!convertedTy) + convertedTy = emitc::OpaqueType::get(ctx, *tileTypeString); + + auto validShape = tileTy.getValidShape(); + bool hasDynamicValidDim = + llvm::any_of(validShape, [](int64_t dim) { return dim < 0; }); + bool useConstructor = hasDynamicValidDim; + + SmallVector constructorArgs; + if (useConstructor) { + Type elemTy = tileTy.getElementType(); + pto::BLayout blayout = getTileBufBLayoutValue(tileTy.getConfigAttr()); + auto maybeScaleDynamicValid = [&](Value emitted, int dimIdx) -> Value { + if (!emitted || !pto::isPTOFloat4PackedType(elemTy)) + return emitted; + int packedDim = blayout == pto::BLayout::ColMajor ? 
0 : 1; + if (dimIdx != packedDim) + return emitted; + auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); + Value two = makeEmitCIntConstant(rewriter, loc, i32Ty, 2); + return rewriter.create(loc, i32Ty, emitted, two) + .getResult(); + }; + + if (validShape.size() > 0 && validShape[0] < 0) { + Value validRow = adaptor.getValidRow(); + if (!validRow) + return rewriter.notifyMatchFailure( + op, "dynamic alloc_tile valid row must have an operand"); + if (validRow) + validRow = peelUnrealized(validRow); + constructorArgs.push_back(maybeScaleDynamicValid(validRow, 0)); + } + if (validShape.size() > 1 && validShape[1] < 0) { + Value validCol = adaptor.getValidCol(); + if (!validCol) + return rewriter.notifyMatchFailure( + op, "dynamic alloc_tile valid col must have an operand"); + if (validCol) + validCol = peelUnrealized(validCol); + constructorArgs.push_back(maybeScaleDynamicValid(validCol, 1)); + } + } + + Value tile; + if (useConstructor) { + tile = rewriter + .create( + loc, convertedTy, *tileTypeString, ArrayAttr{}, + ArrayAttr{}, ValueRange(constructorArgs)) + .getResult(0); + } else { + tile = + rewriter + .create( + loc, convertedTy, emitc::OpaqueAttr::get(ctx, "")) + .getResult(); + } + + Value addr = adaptor.getAddr(); + if (addr) { + addr = peelUnrealized(addr); + auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); + if (isa(addr.getType()) || + (isa(addr.getType()) && + cast(addr.getType()).getValue().ends_with("*"))) { + auto rcU64 = + rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); + addr = rewriter + .create(loc, u64Ty, "reinterpret_cast", + ArrayAttr{}, rcU64, + ValueRange{addr}) + .getResult(0); + } else if (addr.getType() != u64Ty) { + addr = rewriter.create(loc, u64Ty, addr).getResult(); + } + + rewriter.create(loc, TypeRange{}, "TASSIGN", + ArrayAttr{}, ArrayAttr{}, + ValueRange{tile, addr}); + } + + rewriter.replaceOp(op, tile); + return success(); + } +}; + +static FailureOr +createEmitCTileVariable(ConversionPatternRewriter 
&rewriter, Location loc, + const TypeConverter *typeConverter, + pto::TileBufType tileTy) { + auto tileTypeString = getEmitCTileTypeString(tileTy); + if (!tileTypeString) + return failure(); + + Type convertedTy = typeConverter->convertType(tileTy); + if (!convertedTy) + convertedTy = emitc::OpaqueType::get(rewriter.getContext(), *tileTypeString); + + return rewriter + .create( + loc, convertedTy, emitc::OpaqueAttr::get(rewriter.getContext(), "")) + .getResult(); +} + +struct PTOTReshapeToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TReshapeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto tileTy = dyn_cast(op.getResult().getType()); + if (!tileTy) + return failure(); + + FailureOr dst = + createEmitCTileVariable(rewriter, op.getLoc(), getTypeConverter(), tileTy); + if (failed(dst)) + return failure(); + + Value src = peelUnrealized(adaptor.getSrc()); + if (auto castOp = src.getDefiningOp()) + src = castOp.getOperand(); + + rewriter.create(op.getLoc(), TypeRange{}, "TRESHAPE", + ArrayAttr{}, ArrayAttr{}, + ValueRange{*dst, src}); + rewriter.replaceOp(op, *dst); + return success(); + } +}; + +struct PTOBitcastToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::BitcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto dstTy = dyn_cast(op.getResult().getType()); + auto srcTy = dyn_cast(op.getSrc().getType()); + if (!dstTy || !srcTy) + return failure(); + + FailureOr dst = + createEmitCTileVariable(rewriter, op.getLoc(), getTypeConverter(), dstTy); + if (failed(dst)) + return failure(); + + Value src = peelUnrealized(adaptor.getSrc()); + if (auto castOp = src.getDefiningOp()) + src = castOp.getOperand(); + + pto::AddressSpace as = pto::AddressSpace::GM; + if (auto asAttr = + dyn_cast_or_null(srcTy.getMemorySpace())) + as = 
asAttr.getAddressSpace(); + std::string elemTok = getEmitCScalarTypeToken(srcTy.getElementType()); + + Value rawPtr = materializeTileDataValue(rewriter, op.getLoc(), src, as, elemTok); + auto u64Ty = emitc::OpaqueType::get(rewriter.getContext(), "uint64_t"); + Value addr = rawPtr; + if (isSetFFTsPointerLikeType(rawPtr.getType())) { + auto rcU64 = + rewriter.getArrayAttr({emitc::OpaqueAttr::get(rewriter.getContext(), + "uint64_t")}); + addr = rewriter + .create(op.getLoc(), u64Ty, + "reinterpret_cast", ArrayAttr{}, + rcU64, ValueRange{rawPtr}) + .getResult(0); + } else if (addr.getType() != u64Ty) { + addr = rewriter.create(op.getLoc(), u64Ty, addr).getResult(); + } + + rewriter.create(op.getLoc(), TypeRange{}, "TASSIGN", + ArrayAttr{}, ArrayAttr{}, + ValueRange{*dst, addr}); + rewriter.replaceOp(op, *dst); + return success(); + } +}; + +struct PTOMaterializeTileToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + static bool isTileLike(Value v) { + auto ot = dyn_cast(v.getType()); + if (!ot) + return false; + StringRef s = ot.getValue(); + return s.contains("Tile<") || s.contains("ConvTile<"); + } + + LogicalResult matchAndRewrite(pto::MaterializeTileOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + MLIRContext *ctx = rewriter.getContext(); + auto tileTy = cast(op.getResult().getType()); + auto tileTypeString = getEmitCTileTypeString(tileTy); + if (!tileTypeString) + return rewriter.notifyMatchFailure( + op, "only rank-2 tile_buf handles can be materialized to EmitC"); + + Type convertedTy = getTypeConverter()->convertType(tileTy); + if (!convertedTy) + convertedTy = emitc::OpaqueType::get(ctx, *tileTypeString); + + Value source = peelUnrealized(adaptor.getSource()); + if (auto castOp = source.getDefiningOp()) + source = castOp.getOperand(); + + auto viewSemantics = op->getAttrOfType("pto.view_semantics"); + bool forceDynamicValid = 
op->hasAttr(kForceDynamicValidShapeAttrName); + bool isReshape = viewSemantics && viewSemantics.getValue() == "treshape"; + bool isSubview = viewSemantics && viewSemantics.getValue() == "subview"; + bool sourceIsDeclaredTile = + op.getSource().getDefiningOp(); + + auto createTileValue = [&]() -> Value { + SmallVector constructorArgs; + bool useConstructor = false; + pto::BLayout blayout = getTileBufBLayoutValue(tileTy.getConfigAttr()); + Type elemTy = tileTy.getElementType(); + auto shape = tileTy.getShape(); + auto validShape = tileTy.getValidShape(); + + auto makeCtorDimValue = [&](Value emitted, int64_t fallback) -> Value { + if (emitted) + return emitted; + return makeEmitCIntConstant( + rewriter, loc, emitc::OpaqueType::get(ctx, "int32_t"), fallback); + }; + auto maybeScaleDynamicValid = [&](Value emitted, int dimIdx) -> Value { + if (!emitted || !pto::isPTOFloat4PackedType(elemTy)) + return emitted; + int packedDim = blayout == pto::BLayout::ColMajor ? 0 : 1; + if (dimIdx != packedDim) + return emitted; + auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); + Value two = makeEmitCIntConstant(rewriter, loc, i32Ty, 2); + return rewriter.create(loc, i32Ty, emitted, two).getResult(); + }; + auto fallbackDim = [&](int dimIdx) { + return renderTileTemplateDim(shape[dimIdx], elemTy, blayout, dimIdx); + }; + + if (forceDynamicValid) { + useConstructor = true; + constructorArgs.push_back(makeCtorDimValue( + maybeScaleDynamicValid(adaptor.getValidRow(), 0), fallbackDim(0))); + constructorArgs.push_back(makeCtorDimValue( + maybeScaleDynamicValid(adaptor.getValidCol(), 1), fallbackDim(1))); + } else { + if (validShape[0] == ShapedType::kDynamic) { + useConstructor = true; + constructorArgs.push_back(makeCtorDimValue( + maybeScaleDynamicValid(adaptor.getValidRow(), 0), fallbackDim(0))); + } + if (validShape[1] == ShapedType::kDynamic) { + useConstructor = true; + constructorArgs.push_back(makeCtorDimValue( + maybeScaleDynamicValid(adaptor.getValidCol(), 1), 
fallbackDim(1))); + } + } + + if (useConstructor) { + return rewriter + .create(loc, convertedTy, *tileTypeString, + ArrayAttr{}, ArrayAttr{}, + ValueRange(constructorArgs)) + .getResult(0); + } + + return rewriter + .create(loc, convertedTy, + emitc::OpaqueAttr::get(ctx, "")) + .getResult(); + }; + + if (!isSubview && !forceDynamicValid && isTileLike(source)) { + if (auto srcTy = dyn_cast(source.getType())) { + if (srcTy.getValue() == *tileTypeString) { + rewriter.replaceOp(op, source); + return success(); + } + } + } + + Value tile = createTileValue(); + if (sourceIsDeclaredTile) { + rewriter.replaceOp(op, tile); + return success(); + } + + if (isReshape && isTileLike(source)) { + rewriter.create(loc, TypeRange{}, "TRESHAPE", + ArrayAttr{}, ArrayAttr{}, + ValueRange{tile, source}); + rewriter.replaceOp(op, tile); + return success(); + } + + pto::AddressSpace as = pto::AddressSpace::GM; + if (auto asAttr = + dyn_cast_or_null(tileTy.getMemorySpace())) + as = asAttr.getAddressSpace(); + std::string elemTok = getEmitCScalarTypeToken(tileTy.getElementType()); + + Value rawPtr = source; + if (isTileLike(rawPtr)) + rawPtr = materializeTileDataValue(rewriter, loc, rawPtr, as, elemTok); + + auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); + Value addr = rawPtr; + if (isSetFFTsPointerLikeType(rawPtr.getType())) { + auto rcU64 = + rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); + addr = rewriter + .create(loc, u64Ty, "reinterpret_cast", + ArrayAttr{}, rcU64, + ValueRange{rawPtr}) + .getResult(0); + } else if (rawPtr.getType() != u64Ty) { + addr = rewriter.create(loc, u64Ty, rawPtr).getResult(); + } + + rewriter.create(loc, TypeRange{}, "TASSIGN", + ArrayAttr{}, ArrayAttr{}, + ValueRange{tile, addr}); + rewriter.replaceOp(op, tile); + return success(); + } +}; + + +} // namespace + +void populatePTOToEmitCTileMaterializationPatterns( + RewritePatternSet &patterns, TypeConverter &typeConverter, MLIRContext *ctx) { + patterns.add(typeConverter, ctx); + 
patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); +} + +} // namespace mlir::pto diff --git a/lib/PTO/Transforms/PTOToEmitCTilePatterns.cpp b/lib/PTO/Transforms/PTOToEmitCTilePatterns.cpp new file mode 100644 index 000000000..0854efa46 --- /dev/null +++ b/lib/PTO/Transforms/PTOToEmitCTilePatterns.cpp @@ -0,0 +1,1438 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +//===- PTOToEmitCTilePatterns.cpp ----------------------------------------===// +//===----------------------------------------------------------------------===// + +#include "PTOToEmitCInternal.h" + +#include "PTO/IR/PTO.h" + +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" + +#include + +using namespace mlir; +using namespace mlir::pto; + +namespace mlir::pto { +namespace { + +struct PTOTAndToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TAndOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value a = peelUnrealized(adaptor.getSrc0()); + Value b = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + op.getLoc(), TypeRange{}, "TAND", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, a, b}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOConcatToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TConcatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + op.getLoc(), TypeRange{}, "TCONCAT", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src0, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; +struct PTOConcatidxToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TConcatidxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value src0Idx = 
peelUnrealized(adaptor.getSrc0Idx()); + Value src1Idx = peelUnrealized(adaptor.getSrc1Idx()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + op.getLoc(), TypeRange{}, "TCONCAT", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src0, src1, src0Idx, src1Idx}); + + rewriter.eraseOp(op); + return success(); + } +}; +struct PTOAndSToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TAndSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src = peelUnrealized(adaptor.getSrc()); + Value scalar = peelUnrealized(adaptor.getScalar()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + op.getLoc(), TypeRange{}, "TANDS", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src, scalar}); + + rewriter.eraseOp(op); + return success(); + } +}; + + +struct PTOTCIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TCIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value dst = peelUnrealized(adaptor.getDst()); + Value S = peelUnrealized(adaptor.getOperands()[0]); + + // The TCI scalar template parameter should follow the original PTO IR + // scalar type, not the converted EmitC value type. + std::string scalarTok = "int32_t"; + if (auto it = dyn_cast(op->getOperand(0).getType())) { + bool isUnsigned = it.isUnsigned(); + if (it.getWidth() == 16) + scalarTok = isUnsigned ? "uint16_t" : "int16_t"; + else + scalarTok = isUnsigned ? "uint32_t" : "int32_t"; + } + + // descending -> "0"/"1" + std::string descTok = op.getDescending() ? 
"1" : "0"; + + ArrayAttr targs; + if (auto ot = mlir::dyn_cast(dst.getType())) { + std::string tileTok = ot.getValue().str(); // "Tile<...>" + targs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, tileTok), + emitc::OpaqueAttr::get(ctx, scalarTok), + emitc::OpaqueAttr::get(ctx, descTok), + }); + } else { + targs = rewriter.getArrayAttr({}); + } + + rewriter.create( + loc, TypeRange{}, "TCI", + /*args=*/ArrayAttr{}, + /*templateArgs=*/targs, + /*operands=*/ValueRange{dst, S}); + + rewriter.eraseOp(op); + return success(); + } +}; +static std::string cmpModeTok(pto::CmpModeAttr a) { + // Render the attribute as a C++ token such as "CmpMode::GT"; falls back to "CmpMode::EQ" for values not listed below. + auto m = a.getValue(); // extract the underlying enum value + switch (m) { + case pto::CmpMode::EQ: return "CmpMode::EQ"; + case pto::CmpMode::NE: return "CmpMode::NE"; + case pto::CmpMode::LT: return "CmpMode::LT"; + case pto::CmpMode::LE: return "CmpMode::LE"; + case pto::CmpMode::GT: return "CmpMode::GT"; + case pto::CmpMode::GE: return "CmpMode::GE"; + } + return "CmpMode::EQ"; +} +struct PTOColExpandToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColExpandOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value dst = peelUnrealized(adaptor.getDst()); + Value src = peelUnrealized(adaptor.getSrc()); + + rewriter.create( + loc, TypeRange{}, "TCOLEXPAND", + /*args=*/ArrayAttr(), + /*templateArgs=*/ArrayAttr(), + /*operands=*/ValueRange{dst, src}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColExpandMulToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColExpandMulOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + +
rewriter.create( + loc, TypeRange{}, "TCOLEXPANDMUL", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColExpandAddToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColExpandAddOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TCOLEXPANDADD", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColExpandDivToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColExpandDivOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TCOLEXPANDDIV", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColExpandExpdifToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColExpandExpdifOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TCOLEXPANDEXPDIF", + /*args=*/ArrayAttr{}, 
+ /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColExpandSubToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColExpandSubOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TCOLEXPANDSUB", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColExpandMaxToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColExpandMaxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TCOLEXPANDMAX", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColExpandMinToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColExpandMinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TCOLEXPANDMIN", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, src1}); + + 
rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOTTriToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TTriOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value dst = peelUnrealized(adaptor.getDst()); + Value diagonal = peelUnrealized(adaptor.getDiagonal()); + + ArrayAttr templateArgs; + if (auto dstOT = mlir::dyn_cast(dst.getType())) { + templateArgs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), + emitc::OpaqueAttr::get(ctx, std::to_string(op.getUpperOrLower())), + }); + } else { + templateArgs = ArrayAttr{}; + } + + SmallVector operands{dst, diagonal}; + rewriter.create( + loc, TypeRange{}, "TTRI", + /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOCmpToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TCmpOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value dst = peelUnrealized(adaptor.getDst()); + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + + std::string tok = "CmpMode::EQ"; + if (auto a = op.getCmpModeAttr()) + tok = cmpModeTok(a); + + auto modeTy = emitc::OpaqueType::get(ctx, "CmpMode"); + Value modeVal = rewriter.create( + loc, modeTy, emitc::OpaqueAttr::get(ctx, tok)); + + rewriter.create( + loc, + TypeRange{}, + "TCMP", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, src1, modeVal}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOCmpSToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult 
matchAndRewrite(pto::TCmpSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value dst = peelUnrealized(adaptor.getDst()); + Value src = peelUnrealized(adaptor.getSrc()); + Value scalar = peelUnrealized(adaptor.getScalar()); + + // cmpMode -> token + auto cmpAttr = op.getCmpModeAttr(); // PTO_CmpModeAttr + std::string tok = cmpModeTok(cmpAttr); + + auto modeTy = emitc::OpaqueType::get(ctx, "CmpMode"); + Value modeVal = rewriter.create( + loc, modeTy, emitc::OpaqueAttr::get(ctx, tok)); + + rewriter.create( + loc, + TypeRange{}, + "TCMPS", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src, scalar, modeVal}); + + rewriter.eraseOp(op); + return success(); + } +}; + + +struct PTOColMaxToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColMaxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + // intrinsic: TCOLMAX(dst, src) + rewriter.create( + loc, TypeRange{}, "TCOLMAX", + /*args=*/ArrayAttr{}, // default: print all operands + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColArgMaxToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColArgMaxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TCOLARGMAX", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src, 
tmp}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColMinToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColMinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + // intrinsic: TCOLMIN(dst, src) + rewriter.create( + loc, TypeRange{}, "TCOLMIN", + /*args=*/ArrayAttr{}, // default: print all operands + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColArgMinToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColArgMinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TCOLARGMIN", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src, tmp}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColSumToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColSumOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + // Check if tmp exists before accessing it + if (op.getTmp()) { + // Format 2: with tmp and isBinary + Value tmp = peelUnrealized(adaptor.getTmp()); + bool isBinary = false; + if (auto a = op.getIsBinaryAttr()) + isBinary = a.getValue(); + + auto boolTy = emitc::OpaqueType::get(ctx, 
"bool"); + auto tok = isBinary ? "true" : "false"; + Value isBinaryVal = rewriter.create( + loc, boolTy, emitc::OpaqueAttr::get(ctx, tok)); + + rewriter.create( + loc, TypeRange{}, "TCOLSUM", + /*args=*/ArrayAttr(), + /*templateArgs=*/ArrayAttr(), + /*operands=*/ValueRange{dst, src, tmp, isBinaryVal}); + } else { + // Format 1: without tmp and isBinary + rewriter.create( + loc, TypeRange{}, "TCOLSUM", + /*args=*/ArrayAttr(), + /*templateArgs=*/ArrayAttr(), + /*operands=*/ValueRange{dst, src}); + } + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColProdToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColProdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TCOLPROD", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src}); + + rewriter.eraseOp(op); + return success(); + } +}; +static std::string roundModeTok(mlir::pto::RoundModeAttr attr) { + using RM = mlir::pto::RoundMode; + switch (attr.getValue()) { + case RM::NONE: return "RoundMode::CAST_NONE"; + case RM::RINT: return "RoundMode::CAST_RINT"; + case RM::ROUND: return "RoundMode::CAST_ROUND"; + case RM::FLOOR: return "RoundMode::CAST_FLOOR"; + case RM::CEIL: return "RoundMode::CAST_CEIL"; + case RM::TRUNC: return "RoundMode::CAST_TRUNC"; + case RM::ODD: return "RoundMode::CAST_ODD"; + case RM::CAST_RINT: return "RoundMode::CAST_RINT"; + } + return "RoundMode::CAST_RINT"; +} +static std::string saturationModeTok(mlir::pto::SaturationModeAttr attr) { + using SM = mlir::pto::SaturationMode; + switch (attr.getValue()) { + case SM::ON: return "SaturationMode::ON"; + case SM::OFF: return "SaturationMode::OFF"; + } + return "SaturationMode::OFF"; +} +struct PTOCvtToEmitC : public 
OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TCvtOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + pto::RoundModeAttr rmAttr = op.getRmodeAttr(); + std::string rmTok = rmAttr ? roundModeTok(rmAttr) + : std::string("RoundMode::CAST_RINT"); + auto rmodeTy = emitc::OpaqueType::get(ctx, "RoundMode"); + Value rmodeVal = rewriter.create( + loc, rmodeTy, emitc::OpaqueAttr::get(ctx, rmTok)); + + auto satModeTy = emitc::OpaqueType::get(ctx, "SaturationMode"); + auto satAttr = op.getSatModeAttr(); + std::string satTok = satAttr ? saturationModeTok(satAttr) + : std::string("SaturationMode::OFF"); + Value satModeVal = rewriter.create( + loc, satModeTy, emitc::OpaqueAttr::get(ctx, satTok)); + + SmallVector operands{dst, src, rmodeVal, satModeVal}; + + rewriter.create( + loc, TypeRange{}, "TCVT", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +struct PTORandomToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRandomOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value dst = peelUnrealized(adaptor.getDst()); + SmallVector operands{ + dst, + peelUnrealized(adaptor.getKey0()), + peelUnrealized(adaptor.getKey1()), + peelUnrealized(adaptor.getCounter0()), + peelUnrealized(adaptor.getCounter1()), + peelUnrealized(adaptor.getCounter2()), + peelUnrealized(adaptor.getCounter3()), + }; + ArrayAttr templateArgs = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, std::to_string(op.getRounds()))}); + + rewriter.create( + loc, TypeRange{}, "PTOAS__TRANDOM", + 
/*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, operands); + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.tdiv lowering -> TDIV(dst, src0, src1) +//===----------------------------------------------------------------------===// + +struct PTODivToTDIV : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TDivOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + op.getLoc(), TypeRange{}, "TDIV", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src0, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.tdivs lowering -> TDIVS(dst, src, scalar) or TDIVS(dst, scalar, src) +// Order is determined by operand types: if src is tile_buf, order is (tile, scalar) +// Otherwise, order is (scalar, tile) +//===----------------------------------------------------------------------===// + +struct PTODivSToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TDivSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value scalar = peelUnrealized(adaptor.getScalar()); + Value dst = peelUnrealized(adaptor.getDst()); + // Preserve source order from textual parse: + // ins(tile, scalar) -> TDIVS(dst, tile, scalar) + // ins(scalar, tile) -> TDIVS(dst, scalar, tile) + rewriter.create( + loc, TypeRange{}, "TDIVS", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src, scalar}); + + rewriter.eraseOp(op); + return success(); + } +}; + 
+//===----------------------------------------------------------------------===// +// pto.tdivs (TDivSOp) lowering -> TDIVS(dst, src, scalar) or TDIVS(dst, scalar, src) +// Order is determined by operand types: if src is tile_buf, order is (tile, scalar) +// Otherwise, order is (scalar, tile) +//===----------------------------------------------------------------------===// + +struct PTOTDivSToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TDivSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value scalar = peelUnrealized(adaptor.getScalar()); + Value dst = peelUnrealized(adaptor.getDst()); + rewriter.create( + loc, TypeRange{}, "TDIVS", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src, scalar}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.texp lowering -> TEXP(dst, src) +//===----------------------------------------------------------------------===// + +struct PTOExpToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TExpOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TEXP", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.texpands lowering -> TEXPANDS(dst, scalar) +//===----------------------------------------------------------------------===// + +struct PTOExpandsToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + 
LogicalResult matchAndRewrite(pto::TExpandsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value scalar = peelUnrealized(adaptor.getScalar()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TEXPANDS", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, scalar}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.textract lowering -> TEXTRACT(dst, src, indexRow, indexCol) +//===----------------------------------------------------------------------===// + +struct PTOExtractToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TExtractOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + Value r0 = peelUnrealized(adaptor.getIndexRow()); + Value c0 = peelUnrealized(adaptor.getIndexCol()); + + rewriter.create( + loc, TypeRange{}, "TEXTRACT", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src, r0, c0}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.textract_fp lowering -> TEXTRACT_FP(dst, src, fp, indexRow, indexCol) +//===----------------------------------------------------------------------===// + +struct PTOExtractFPToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TExtractFPOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value fp = peelUnrealized(adaptor.getFp()); + Value dst = peelUnrealized(adaptor.getDst()); + Value r0 = 
peelUnrealized(adaptor.getIndexRow()); + Value c0 = peelUnrealized(adaptor.getIndexCol()); + + rewriter.create( + loc, TypeRange{}, "TEXTRACT_FP", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src, fp, r0, c0}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.tinsert lowering -> TINSERT(dst, src, indexRow, indexCol) +// Keep lowering arch-agnostic and let PTO-ISA infer proper A5 path. +//===----------------------------------------------------------------------===// + +struct PTOInsertToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TInsertOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + Value r0 = peelUnrealized(adaptor.getIndexRow()); + Value c0 = peelUnrealized(adaptor.getIndexCol()); + + rewriter.create( + loc, TypeRange{}, "TINSERT", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src, r0, c0}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.tinsert_fp lowering -> TINSERT_FP(dst, src, fp, indexRow, indexCol) +//===----------------------------------------------------------------------===// + +struct PTOInsertFPToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TInsertFPOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value fp = peelUnrealized(adaptor.getFp()); + Value dst = peelUnrealized(adaptor.getDst()); + Value r0 = peelUnrealized(adaptor.getIndexRow()); + Value c0 = 
peelUnrealized(adaptor.getIndexCol()); + + rewriter.create( + loc, TypeRange{}, "TINSERT_FP", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src, fp, r0, c0}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.tfillpad lowering -> TFILLPAD(dst, src) +//===----------------------------------------------------------------------===// + +struct PTOFillPadToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TFillPadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TFILLPAD", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.tfillpad_inplace lowering -> TFILLPAD_INPLACE(dst, src) +//===----------------------------------------------------------------------===// + +struct PTOFillPadInplaceToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TFillPadInplaceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TFILLPAD_INPLACE", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.tfillpad_expand lowering -> TFILLPAD_EXPAND(dst, src) 
+//===----------------------------------------------------------------------===// + +struct PTOFillPadExpandToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TFillPadExpandOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TFILLPAD_EXPAND", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.tgather lowering +// - Index form : TGATHER(dst, src0, indices, tmp) +// - Compare form: TGATHER(dst, src0, kValue, cdst, tmp) +// - Mask form : TGATHER(dst, src0) +//===----------------------------------------------------------------------===// + +[[maybe_unused]] static std::string maskPatternTok(mlir::pto::MaskPatternAttr a) { + + auto v = a.getValue(); // enum + return (std::string("pto::MaskPattern::") + mlir::pto::stringifyMaskPattern(v).str()); +} + +struct PTOGatherToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TGatherOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value dst = peelUnrealized(adaptor.getDst()); + Value src0 = peelUnrealized(adaptor.getSrc()); + + auto getOpaqueTok = [&](Value v, StringRef name) -> FailureOr { + if (auto ot = mlir::dyn_cast(v.getType())) + return ot.getValue().str(); + return rewriter.notifyMatchFailure(op, (name + " must be emitc::OpaqueType (tile)").str()); + }; + + // Case 1: index-based TGATHER(dst, src0, indices, tmp) + if (Value idx = adaptor.getIndices()) { + idx = peelUnrealized(idx); + Value tmp = 
peelUnrealized(adaptor.getTmp()); + + rewriter.create( + loc, TypeRange{}, "TGATHER", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, idx, tmp}); + + rewriter.eraseOp(op); + return success(); + } + + // Case 2: compare-based TGATHER( + // dst, src0, kValue, tmp, cdst, offset) + if (Value cdst = adaptor.getCdst()) { + cdst = peelUnrealized(cdst); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value kValue = peelUnrealized(adaptor.getKValue()); + + auto dstTokOr = getOpaqueTok(dst, "dst"); + auto srcTokOr = getOpaqueTok(src0, "src0"); + auto cdstTokOr = getOpaqueTok(cdst, "cdst"); + auto tmpTokOr = getOpaqueTok(tmp, "tmp"); + if (failed(dstTokOr) || failed(srcTokOr) || failed(cdstTokOr) || failed(tmpTokOr)) + return failure(); + + auto cmpAttr = op.getCmpModeAttr(); + std::string cmpTok = cmpAttr ? cmpModeTok(cmpAttr) : "CmpMode::EQ"; + int64_t offset = 0; + if (auto offsetAttr = op.getOffsetAttr()) + offset = offsetAttr.getInt(); + auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); + Value offsetVal = makeEmitCIntConstant(rewriter, loc, i32Ty, offset); + + auto targs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, *dstTokOr), + emitc::OpaqueAttr::get(ctx, *srcTokOr), + emitc::OpaqueAttr::get(ctx, *tmpTokOr), + emitc::OpaqueAttr::get(ctx, *cdstTokOr), + emitc::OpaqueAttr::get(ctx, cmpTok), + }); + + rewriter.create( + loc, TypeRange{}, "TGATHER", + /*args=*/ArrayAttr{}, /*templateArgs=*/targs, + /*operands=*/ValueRange{dst, src0, kValue, tmp, cdst, offsetVal}); + + rewriter.eraseOp(op); + return success(); + } + + // Case 3: mask-pattern TGATHER(dst, src0) + auto mp = op.getMaskPatternAttr(); + if (!mp) + return rewriter.notifyMatchFailure(op, "expected maskPattern, indices, or cdst on tgather"); + + auto dstTokOr = getOpaqueTok(dst, "dst"); + auto srcTokOr = getOpaqueTok(src0, "src0"); + if (failed(dstTokOr) || failed(srcTokOr)) + return failure(); + + // mp is an EnumAttr; stringify name is "P0101" etc. 
+ // We emit MaskPattern::P0101 (because generated C++ has `using namespace pto;`) + std::string mpTok = std::string("MaskPattern::") + + mlir::pto::stringifyMaskPattern(mp.getValue()).str(); + + auto targs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, *dstTokOr), + emitc::OpaqueAttr::get(ctx, *srcTokOr), + emitc::OpaqueAttr::get(ctx, mpTok), + }); + + rewriter.create( + loc, TypeRange{}, "TGATHER", + /*args=*/ArrayAttr{}, + /*templateArgs=*/targs, + /*operands=*/ValueRange{dst, src0}); + + rewriter.eraseOp(op); + return success(); + } +}; + + +struct PTOGatherbToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TGatherBOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value offsets = peelUnrealized(adaptor.getOffsets()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TGATHERB", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src, offsets}); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// TLOG lowering to EmitC (PTOConvert.cpp) +//===----------------------------------------------------------------------===// + +struct PTOLogToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TLogOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src}; + rewriter.create( + loc, TypeRange{}, "TLOG", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + + + 
+//===----------------------------------------------------------------------===// +// TLRELU lowering to EmitC (PTOConvert.cpp) +//===----------------------------------------------------------------------===// + + struct PTOLReluToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TLReluOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value slope = peelUnrealized(adaptor.getSlope()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src, slope}; + + rewriter.create( + loc, TypeRange{}, "TLRELU", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// TMAX lowering to EmitC (PTOConvert.cpp) +//===----------------------------------------------------------------------===// + +struct PTOMaxToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMaxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src0, src1}; + rewriter.create( + loc, TypeRange{}, "TMAX", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// TMAXS lowering to EmitC (PTOConvert.cpp) +//===----------------------------------------------------------------------===// + + struct PTOMaxSToEmitC : public OpConversionPattern { + using 
OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMaxSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc()); + Value scalar = peelUnrealized(adaptor.getScalar()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src0, scalar}; + rewriter.create( + loc, TypeRange{}, "TMAXS", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + + +//===----------------------------------------------------------------------===// +// TMIN lowering to EmitC (PTOConvert.cpp) +//===----------------------------------------------------------------------===// + +struct PTOMinToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src0, src1}; + rewriter.create( + loc, TypeRange{}, "TMIN", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// TMINS lowering to EmitC (PTOConvert.cpp) +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// TMINS lowering to EmitC (fix APFloat -> FloatAttr) (PTOConvert.cpp) +//===----------------------------------------------------------------------===// + +struct PTOMinsToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult 
matchAndRewrite(pto::TMinSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + Value scalar = peelUnrealized(adaptor.getScalar()); + + SmallVector operands{dst, src, scalar}; + rewriter.create( + loc, TypeRange{}, "TMINS", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering for TMOV op -> EmitC) +//===----------------------------------------------------------------------===// + +struct PTOMovToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMovOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + Value fp; + if (op.getFp()) + fp = peelUnrealized(adaptor.getFp()); + Value preQuantScalar; + if (op.getPreQuantScalar()) + preQuantScalar = peelUnrealized(adaptor.getPreQuantScalar()); + + auto dstOT = mlir::dyn_cast(dst.getType()); + auto srcOT = mlir::dyn_cast(src.getType()); + if (!dstOT || !srcOT) + return rewriter.notifyMatchFailure( + op, "tmov lowering expects opaque dst/src types"); + + auto modeTok = [&](pto::AccToVecMode mode) -> StringRef { + switch (mode) { + case pto::AccToVecMode::SingleModeVec0: + return "pto::AccToVecMode::SingleModeVec0"; + case pto::AccToVecMode::SingleModeVec1: + return "pto::AccToVecMode::SingleModeVec1"; + case pto::AccToVecMode::DualModeSplitM: + return "pto::AccToVecMode::DualModeSplitM"; + case pto::AccToVecMode::DualModeSplitN: + return "pto::AccToVecMode::DualModeSplitN"; + } + llvm_unreachable("unknown AccToVecMode"); 
+ }; + + auto modeAttr = op.getAccToVecModeAttr(); + auto reluTok = [&](pto::ReluPreMode mode) -> StringRef { + switch (mode) { + case pto::ReluPreMode::NoRelu: + return "ReluPreMode::NoRelu"; + case pto::ReluPreMode::NormalRelu: + return "ReluPreMode::NormalRelu"; + } + llvm_unreachable("unknown ReluPreMode"); + }; + + const bool hasFp = static_cast(fp); + const bool hasPreQuantScalar = static_cast(preQuantScalar); + const bool hasMode = static_cast(modeAttr); + const bool reluNonDefault = op.getReluPreMode() != pto::ReluPreMode::NoRelu; + + SmallVector operands{dst, src}; + SmallVector templateArgVec{ + emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), + emitc::OpaqueAttr::get(ctx, srcOT.getValue().str()), + }; + StringRef callee = "TMOV"; + + if (hasFp) { + auto fpOT = mlir::dyn_cast(fp.getType()); + if (!fpOT) + return rewriter.notifyMatchFailure( + op, "tmov fp lowering expects opaque fp type"); + operands.push_back(fp); + templateArgVec.push_back(emitc::OpaqueAttr::get(ctx, fpOT.getValue().str())); + if (hasMode) + templateArgVec.push_back( + emitc::OpaqueAttr::get(ctx, modeTok(modeAttr.getValue()))); + if (hasMode || reluNonDefault) + templateArgVec.push_back( + emitc::OpaqueAttr::get(ctx, reluTok(op.getReluPreMode()))); + callee = hasMode ? 
"TMOV" : "TMOV_FP"; + } else if (hasPreQuantScalar) { + operands.push_back(preQuantScalar); + if (hasMode) + templateArgVec.push_back( + emitc::OpaqueAttr::get(ctx, modeTok(modeAttr.getValue()))); + if (hasMode || reluNonDefault) + templateArgVec.push_back( + emitc::OpaqueAttr::get(ctx, reluTok(op.getReluPreMode()))); + } else if (hasMode) { + templateArgVec.push_back( + emitc::OpaqueAttr::get(ctx, modeTok(modeAttr.getValue()))); + templateArgVec.push_back( + emitc::OpaqueAttr::get(ctx, reluTok(op.getReluPreMode()))); + } else if (reluNonDefault) { + templateArgVec.push_back( + emitc::OpaqueAttr::get(ctx, reluTok(op.getReluPreMode()))); + } + + ArrayAttr templateArgs = + templateArgVec.size() == 2 && !hasFp && !hasPreQuantScalar && + !hasMode && !reluNonDefault + ? ArrayAttr{} + : rewriter.getArrayAttr(templateArgVec); + + rewriter.create( + loc, TypeRange{}, callee, + /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +} // namespace + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TMOV_FP DPS/memref op) +//===----------------------------------------------------------------------===// + +void populatePTOToEmitCTilePatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter, + MLIRContext *ctx) { + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add( + typeConverter, ctx); + 
patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + populatePTOToEmitCTileExtraPatterns(patterns, typeConverter, ctx); +} + +} // namespace mlir::pto diff --git a/lib/PTO/Transforms/PTOToEmitCTilePatternsExtra.cpp b/lib/PTO/Transforms/PTOToEmitCTilePatternsExtra.cpp new file mode 100644 index 000000000..e7c5b93cc --- /dev/null +++ b/lib/PTO/Transforms/PTOToEmitCTilePatternsExtra.cpp @@ -0,0 +1,1819 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +//===- PTOToEmitCTilePatternsExtra.cpp -----------------------------------===// +//===----------------------------------------------------------------------===// + +#include "PTOToEmitCInternal.h" + +#include "PTO/IR/PTO.h" + +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" + +#include + +using namespace mlir; +using namespace mlir::pto; + +namespace mlir::pto { +namespace { + +[[maybe_unused]] static std::string maskPatternTok(mlir::pto::MaskPatternAttr a) { + auto value = a.getValue(); + return (std::string("pto::MaskPattern::") + + mlir::pto::stringifyMaskPattern(value).str()); +} + +struct PTOMovFPToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMovFPOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value dst = peelUnrealized(adaptor.getDst()); + Value src = peelUnrealized(adaptor.getSrc()); + Value fp = peelUnrealized(adaptor.getFp()); + + // TMOV_FP(dstTileData, cTile, fbTile) + ArrayAttr templateArgs; + auto dstOT = mlir::dyn_cast(dst.getType()); + auto srcOT = mlir::dyn_cast(src.getType()); + auto fpOT = mlir::dyn_cast(fp.getType()); + if (dstOT && srcOT && fpOT) { + templateArgs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), + emitc::OpaqueAttr::get(ctx, srcOT.getValue().str()), + emitc::OpaqueAttr::get(ctx, fpOT.getValue().str()), + }); + } else { + templateArgs = ArrayAttr{}; + } + + SmallVector operands{dst, src, fp}; + rewriter.create( + loc, TypeRange{}, "TMOV_FP", + /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOQuantToEmitC : public OpConversionPattern { + using 
OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TQuantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value dst = peelUnrealized(adaptor.getDst()); + Value src = peelUnrealized(adaptor.getSrc()); + Value fp = peelUnrealized(adaptor.getFp()); + + // Optional offset (INT8_ASYM only): passed as pointer (&offset) + Value offsetPtr; + if (op.getOffset()) { + Value offset = peelUnrealized(adaptor.getOffset()); + auto offsetOT = mlir::dyn_cast(offset.getType()); + if (offsetOT) { + offsetPtr = rewriter + .create( + loc, emitc::PointerType::get(offsetOT), "&", offset) + .getResult(); + } + } + + // TQUANT(dst, src, fp[, &offset]) + std::string quantTypeStr = + op.getQuantType() == pto::QuantType::INT8_SYM + ? "pto::QuantType::INT8_SYM" + : "pto::QuantType::INT8_ASYM"; + ArrayAttr templateArgs; + auto dstOT = mlir::dyn_cast(dst.getType()); + auto srcOT = mlir::dyn_cast(src.getType()); + auto fpOT = mlir::dyn_cast(fp.getType()); + if (dstOT && srcOT && fpOT) { + templateArgs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, quantTypeStr), + emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), + emitc::OpaqueAttr::get(ctx, srcOT.getValue().str()), + emitc::OpaqueAttr::get(ctx, fpOT.getValue().str()), + }); + } else { + templateArgs = ArrayAttr{}; + } + + SmallVector operands{dst, src, fp}; + if (offsetPtr) + operands.push_back(offsetPtr); + + rewriter.create( + loc, TypeRange{}, "TQUANT", + /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +struct PTODequantToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TDequantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value dst = 
peelUnrealized(adaptor.getDst()); + Value src = peelUnrealized(adaptor.getSrc()); + Value scale = peelUnrealized(adaptor.getScale()); + Value offset = peelUnrealized(adaptor.getOffset()); + + // TDEQUANT(dst, src, scale, offset) + ArrayAttr templateArgs; + auto dstOT = mlir::dyn_cast(dst.getType()); + auto srcOT = mlir::dyn_cast(src.getType()); + auto scaleOT = mlir::dyn_cast(scale.getType()); + if (dstOT && srcOT && scaleOT) { + templateArgs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), + emitc::OpaqueAttr::get(ctx, srcOT.getValue().str()), + emitc::OpaqueAttr::get(ctx, scaleOT.getValue().str()), + }); + } else { + templateArgs = ArrayAttr{}; + } + + rewriter.create( + loc, TypeRange{}, "TDEQUANT", + /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, + /*operands=*/SmallVector{dst, src, scale, offset}); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TMRGSORT DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOMrgSortToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMrgSortOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + if (op.isFormat1()) { + Value src = peelUnrealized(adaptor.getSrcs().front()); + Value dst = peelUnrealized(adaptor.getDsts().front()); + Value blockLen = peelUnrealized(adaptor.getBlockLen()); + + SmallVector operands{dst, src, blockLen}; + rewriter.create( + loc, TypeRange{}, "TMRGSORT", + ArrayAttr{}, ArrayAttr{}, operands); + } else if (op.isFormat2()) { + // pto-isa API: + // TMRGSORT( + // dst, executedNumList, tmp, src0, src1[, src2[, src3]]); + auto *ctx = rewriter.getContext(); + + Value dst = peelUnrealized(adaptor.getDsts()[0]); + Value tmp = 
peelUnrealized(adaptor.getTmp()); + Value excuted = peelUnrealized(adaptor.getExcuted()); + + SmallVector srcs; + srcs.reserve(adaptor.getSrcs().size()); + for (Value v : adaptor.getSrcs()) + srcs.push_back(peelUnrealized(v)); + + auto dstOT = mlir::dyn_cast(dst.getType()); + auto tmpOT = mlir::dyn_cast(tmp.getType()); + if (!dstOT || !tmpOT || srcs.size() < 2 || srcs.size() > 4) + return op.emitOpError("format2 expects dst/tmp tilebufs and 2 to 4 srcs"); + + SmallVector targs; + targs.reserve(2 + srcs.size() + 1); + targs.push_back(emitc::OpaqueAttr::get(ctx, dstOT.getValue().str())); + targs.push_back(emitc::OpaqueAttr::get(ctx, tmpOT.getValue().str())); + for (Value v : srcs) { + auto ot = mlir::dyn_cast(v.getType()); + if (!ot) + return op.emitOpError("format2 expects tilebuf srcs"); + targs.push_back(emitc::OpaqueAttr::get(ctx, ot.getValue().str())); + } + targs.push_back(emitc::OpaqueAttr::get(ctx, op.getExhausted() ? "true" : "false")); + ArrayAttr templateArgs = rewriter.getArrayAttr(targs); + + SmallVector operands{dst, excuted, tmp}; + operands.append(srcs.begin(), srcs.end()); + + rewriter.create( + loc, TypeRange{}, "TMRGSORT", + /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, operands); + } else { + return op.emitOpError("unsupported mrgsort_dps format"); + } + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TMUL DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOMulToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMulOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + 
SmallVector operands{dst, src0, src1}; + rewriter.create( + loc, TypeRange{}, "TMUL", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TMULS DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOMulsToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMulSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc0()); + Value dst = peelUnrealized(adaptor.getDst()); + Value scalar = peelUnrealized(adaptor.getScalar()); + + SmallVector operands{dst, src, scalar}; + rewriter.create( + loc, TypeRange{}, "TMULS", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TNEG DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTONegToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TNegOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src}; + rewriter.create( + loc, TypeRange{}, "TNEG", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// 
PTOConvert.cpp (add lowering + patterns.add for TNOT DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTONotToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TNotOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src}; + rewriter.create( + loc, TypeRange{}, "TNOT", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TOR DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOOrToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TOrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src0, src1}; + rewriter.create( + loc, TypeRange{}, "TOR", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TORS DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOOrsToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TOrSOp op, OpAdaptor adaptor, + 
ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + // NOTE: The conversion type system may materialize integers as emitc.opaque + // (e.g. "int32_t"). For EmitC call emission we can pass the scalar through + // directly without arith casts here. + Value s = adaptor.getScalar(); + + SmallVector operands{dst, src0, s}; + rewriter.create( + loc, TypeRange{}, "TORS", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TPARTADD DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOPartAddToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TPartAddOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src0, src1}; + rewriter.create( + loc, TypeRange{}, "TPARTADD", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TPARTMAX DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOPartMaxToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TPartMaxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto 
loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src0, src1}; + rewriter.create( + loc, TypeRange{}, "TPARTMAX", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TPARTMIN DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOPartMinToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TPartMinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src0, src1}; + rewriter.create( + loc, TypeRange{}, "TPARTMIN", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOPartArgMaxToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TPartArgMaxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value src0Idx = peelUnrealized(adaptor.getSrc0Idx()); + Value src1Idx = peelUnrealized(adaptor.getSrc1Idx()); + Value dst = peelUnrealized(adaptor.getDst()); + Value dstIdx = peelUnrealized(adaptor.getDstIdx()); + + rewriter.create( + op.getLoc(), TypeRange{}, "TPARTARGMAX", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, 
src1, dstIdx, src0Idx, src1Idx}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOPartArgMinToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TPartArgMinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value src0Idx = peelUnrealized(adaptor.getSrc0Idx()); + Value src1Idx = peelUnrealized(adaptor.getSrc1Idx()); + Value dst = peelUnrealized(adaptor.getDst()); + Value dstIdx = peelUnrealized(adaptor.getDstIdx()); + + rewriter.create( + op.getLoc(), TypeRange{}, "TPARTARGMIN", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, src1, dstIdx, src0Idx, src1Idx}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TPARTMUL DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOPartMulToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TPartMulOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src0, src1}; + rewriter.create( + loc, TypeRange{}, "TPARTMUL", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TPRELU DPS/memref op) 
+//===----------------------------------------------------------------------===// + +struct PTOPreluToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TPReluOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value dst = peelUnrealized(adaptor.getDst()); + + // C++ interface: TPRELU(dst, src0, src1, tmp) — last parameter is tmp. + SmallVector operands{dst, src0, src1, tmp}; + rewriter.create( + loc, TypeRange{}, "TPRELU", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TRECIP DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTORecipToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRecipOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src}; + rewriter.create( + loc, TypeRange{}, "TRECIP", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TRELU DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOReluToEmitC : public OpConversionPattern { + using 
OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TReluOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src}; + rewriter.create( + loc, TypeRange{}, "TRELU", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TREM DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTORemToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRemOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value dst = peelUnrealized(adaptor.getDst()); + SmallVector operands{dst, src0, src1, tmp}; + rewriter.create( + loc, TypeRange{}, "TREM", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOFModToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TFModOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src0, src1}; + rewriter.create( + loc, TypeRange{}, "TFMOD", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + 
/*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TREMS DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTORemSToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRemSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value dst = peelUnrealized(adaptor.getDst()); + Value scalar = peelUnrealized(adaptor.getScalar()); + SmallVector operands{dst, src, scalar, tmp}; + rewriter.create( + loc, TypeRange{}, "TREMS", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOFModSToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TFModSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + Value scalar = peelUnrealized(adaptor.getScalar()); + + SmallVector operands{dst, src, scalar}; + rewriter.create( + loc, TypeRange{}, "TFMODS", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TROWEXPAND DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTORowExpandToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + 
LogicalResult matchAndRewrite(pto::TRowExpandOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src}; + rewriter.create( + loc, TypeRange{}, "TROWEXPAND", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTORowExpandAddToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRowExpandAddOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src0, src1}; + rewriter.create( + loc, TypeRange{}, "TROWEXPANDADD", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTORowExpandExpdifToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRowExpandExpdifOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + Value tmp = op.getTmp() ? 
peelUnrealized(adaptor.getTmp()) : Value(); + + SmallVector operands; + if (tmp) + operands.assign({dst, src0, src1, tmp}); + else + operands.assign({dst, src0, src1}); + rewriter.create( + loc, TypeRange{}, "TROWEXPANDEXPDIF", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TROWEXPANDDIV DPS/memref op) +//===----------------------------------------------------------------------===// +// Helper: replace or erase based on whether op has results. +static void replaceOrEraseWithOpaqueCall(Operation *op, + StringRef callee, + ArrayRef args, + ConversionPatternRewriter &rewriter) { + TypeRange resultTypes = op->getResultTypes(); + auto call = rewriter.create( + op->getLoc(), resultTypes, callee, ArrayAttr{}, ArrayAttr{}, ValueRange(args)); + if (resultTypes.empty()) + rewriter.eraseOp(op); + else + rewriter.replaceOp(op, call.getResults()); +} + +static void replaceOrEraseWithOpaqueCallAndReturnDst(Operation *op, Value dst, + StringRef callee, + ArrayRef args, + ConversionPatternRewriter &rewriter) { + rewriter.create( + op->getLoc(), TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, ValueRange(args)); + if (op->getNumResults() == 1) + rewriter.replaceOp(op, dst); + else + rewriter.eraseOp(op); +} + +// ---------- TOp ---------- +struct PTOTGemvBiasToTGEMV_BIAS + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TGemvBiasOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value a = peelUnrealized(adaptor.getA()); + Value b = peelUnrealized(adaptor.getB()); + Value bias = peelUnrealized(adaptor.getBias()); + Value dst = peelUnrealized(adaptor.getDst()); + + replaceOrEraseWithOpaqueCall(op.getOperation(), "TGEMV_BIAS", + {dst, a, b, bias}, rewriter); + 
return success(); + } +}; + +struct PTOTGemvMXToTGEMV_MX + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TGemvMxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value a = peelUnrealized(adaptor.getA()); + Value aScale = peelUnrealized(adaptor.getAScale()); + Value b = peelUnrealized(adaptor.getB()); + Value bScale = peelUnrealized(adaptor.getBScale()); + Value dst = peelUnrealized(adaptor.getDst()); + + replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TGEMV_MX", + {dst, a, aScale, b, bScale}, rewriter); + return success(); + } +}; + +struct PTOTGemvMXAccToTGEMV_MX + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TGemvMxAccOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value cIn = peelUnrealized(adaptor.getCIn()); + Value a = peelUnrealized(adaptor.getA()); + Value aScale = peelUnrealized(adaptor.getAScale()); + Value b = peelUnrealized(adaptor.getB()); + Value bScale = peelUnrealized(adaptor.getBScale()); + Value dst = peelUnrealized(adaptor.getDst()); + + replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TGEMV_MX", + {dst, cIn, a, aScale, b, bScale}, rewriter); + return success(); + } +}; + +struct PTOTGemvMXBiasToTGEMV_MX + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TGemvMxBiasOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value a = peelUnrealized(adaptor.getA()); + Value aScale = peelUnrealized(adaptor.getAScale()); + Value b = peelUnrealized(adaptor.getB()); + Value bScale = peelUnrealized(adaptor.getBScale()); + Value bias = peelUnrealized(adaptor.getBias()); + Value dst = peelUnrealized(adaptor.getDst()); + + replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TGEMV_MX", + {dst, a, 
aScale, b, bScale, bias}, rewriter); + return success(); + } +}; + +struct PTOTMatmulBiasToTMATMUL_BIAS + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMatmulBiasOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value a = peelUnrealized(adaptor.getA()); + Value b = peelUnrealized(adaptor.getB()); + Value bias = peelUnrealized(adaptor.getBias()); + Value dst = peelUnrealized(adaptor.getDst()); + + replaceOrEraseWithOpaqueCall(op.getOperation(), "TMATMUL_BIAS", + {dst, a, b, bias}, rewriter); + return success(); + } +}; + +struct PTOTMatmulMXToTMATMUL_MX + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMatmulMxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value a = peelUnrealized(adaptor.getA()); + Value aScale = peelUnrealized(adaptor.getAScale()); + Value b = peelUnrealized(adaptor.getB()); + Value bScale = peelUnrealized(adaptor.getBScale()); + Value dst = peelUnrealized(adaptor.getDst()); + + replaceOrEraseWithOpaqueCall(op.getOperation(), "TMATMUL_MX", + {dst, a, aScale, b, bScale}, rewriter); + return success(); + } +}; + +struct PTOTMatmulMXAccToTMATMUL_MX_ACC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMatmulMxAccOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value cIn = peelUnrealized(adaptor.getCIn()); + Value a = peelUnrealized(adaptor.getA()); + Value aScale = peelUnrealized(adaptor.getAScale()); + Value b = peelUnrealized(adaptor.getB()); + Value bScale = peelUnrealized(adaptor.getBScale()); + Value dst = peelUnrealized(adaptor.getDst()); + + replaceOrEraseWithOpaqueCall(op.getOperation(), "TMATMUL_MX", + {dst, cIn, a, aScale, b, bScale}, rewriter); + return success(); + } +}; + +struct PTOTMatmulMXBiasToTMATMUL_MX_BIAS + : 
public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMatmulMxBiasOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value a = peelUnrealized(adaptor.getA()); + Value aScale = peelUnrealized(adaptor.getAScale()); + Value b = peelUnrealized(adaptor.getB()); + Value bScale = peelUnrealized(adaptor.getBScale()); + Value bias = peelUnrealized(adaptor.getBias()); + Value dst = peelUnrealized(adaptor.getDst()); + + replaceOrEraseWithOpaqueCall(op.getOperation(), "TMATMUL_MX", + {dst, a, aScale, b, bScale, bias}, rewriter); + return success(); + } +}; + +struct PTORowExpandDivToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRowExpandDivOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + Value tmp = op.getTmp() ? 
peelUnrealized(adaptor.getTmp()) : Value(); + + SmallVector operands; + if (tmp) + operands.assign({dst, src0, src1, tmp}); + else + operands.assign({dst, src0, src1}); + rewriter.create( + loc, TypeRange{}, "TROWEXPANDDIV", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TROWEXPANDMUL DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTORowExpandMulToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRowExpandMulOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + Value tmp = op.getTmp() ? 
peelUnrealized(adaptor.getTmp()) : Value(); + + SmallVector operands; + if (tmp) + operands.assign({dst, src0, src1, tmp}); + else + operands.assign({dst, src0, src1}); + rewriter.create( + loc, TypeRange{}, "TROWEXPANDMUL", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TROWEXPANDSUB DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTORowExpandSubToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRowExpandSubOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value(); + + SmallVector operands; + if (tmp) + operands.assign({dst, src0, src1, tmp}); + else + operands.assign({dst, src0, src1}); + rewriter.create( + loc, TypeRange{}, "TROWEXPANDSUB", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTORowExpandMaxToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRowExpandMaxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + Value tmp = op.getTmp() ? 
peelUnrealized(adaptor.getTmp()) : Value(); + + SmallVector operands; + if (tmp) + operands.assign({dst, src0, src1, tmp}); + else + operands.assign({dst, src0, src1}); + rewriter.create( + loc, TypeRange{}, "TROWEXPANDMAX", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTORowExpandMinToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRowExpandMinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value(); + + SmallVector operands; + if (tmp) + operands.assign({dst, src0, src1, tmp}); + else + operands.assign({dst, src0, src1}); + rewriter.create( + loc, TypeRange{}, "TROWEXPANDMIN", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TROWMAX DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTORowMaxToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRowMaxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src, tmp}; + rewriter.create( + loc, TypeRange{}, "TROWMAX", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + 
rewriter.eraseOp(op); + return success(); + } +}; + +struct PTORowArgMaxToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRowArgMaxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TROWARGMAX", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src, tmp}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TROWMIN DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTORowMinToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRowMinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src, tmp}; + rewriter.create( + loc, TypeRange{}, "TROWMIN", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTORowArgMinToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRowArgMinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, 
TypeRange{}, "TROWARGMIN", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src, tmp}); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TROWSUM DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTORowSumToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRowSumOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src, tmp}; + rewriter.create( + loc, TypeRange{}, "TROWSUM", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTORowProdToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRowProdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src, tmp}; + rewriter.create( + loc, TypeRange{}, "TROWPROD", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TRSQRT DPS/memref op) +// - no-tmp form : TRSQRT(dst, src) +// - tmp form : TRSQRT(dst, src, tmp) 
+//===----------------------------------------------------------------------===// + +struct PTORsqrtToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRsqrtOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + SmallVector operands{dst, src}; + if (Value tmp = adaptor.getTmp()) + operands.push_back(peelUnrealized(tmp)); + rewriter.create( + loc, TypeRange{}, "TRSQRT", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TSCATTER DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOScatterToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TScatterOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + const bool hasMaskPattern = static_cast(op.getMaskPatternAttr()); + const bool hasIndexes = static_cast(op.getIndexes()); + if (hasMaskPattern == hasIndexes) { + return rewriter.notifyMatchFailure( + op, "expected exactly one of indexes operand or maskPattern attribute"); + } + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + if (auto mp = op.getMaskPatternAttr()) { + auto *ctx = rewriter.getContext(); + auto targs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, maskPatternTok(mp)), + }); + rewriter.create( + loc, TypeRange{}, "TSCATTER", + /*args=*/ArrayAttr{}, /*templateArgs=*/targs, + /*operands=*/ValueRange{dst, src}); + } else { + Value idx = peelUnrealized(adaptor.getIndexes()); + 
rewriter.create( + loc, TypeRange{}, "TSCATTER", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src, idx}); + } + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TSEL DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOSelToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TSelOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value mask = peelUnrealized(adaptor.getMask()); + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, mask, src0, src1, tmp}; + rewriter.create( + loc, TypeRange{}, "TSEL", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TSELS DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOSelSToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TSelSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value mask = peelUnrealized(adaptor.getMask()); + Value src = peelUnrealized(adaptor.getSrc()); + Value scalar = peelUnrealized(adaptor.getScalar()); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, mask, src, tmp, scalar}; + rewriter.create( 
+ loc, TypeRange{}, "TSELS", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TSHL DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOShlSToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TShlOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src0, src1}; + rewriter.create( + loc, TypeRange{}, "TSHL", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TSHR DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOShrSToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TShrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src0, src1}; + rewriter.create( + loc, TypeRange{}, "TSHR", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp 
(add lowering for TSHLS/TSHRS DPS: shift by scalar) +//===----------------------------------------------------------------------===// + +struct PTOShlSConstToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TShlSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value dst = peelUnrealized(adaptor.getDst()); + Value src = peelUnrealized(adaptor.getSrc()); + Value scalar = peelUnrealized(adaptor.getScalar()); + SmallVector operands{dst, src, scalar}; + rewriter.create( + loc, TypeRange{}, "TSHLS", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOShrSConstToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TShrSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value dst = peelUnrealized(adaptor.getDst()); + Value src = peelUnrealized(adaptor.getSrc()); + Value scalar = peelUnrealized(adaptor.getScalar()); + SmallVector operands{dst, src, scalar}; + rewriter.create( + loc, TypeRange{}, "TSHRS", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (TSORT32 DPS/memref op: ins(src, idx[, tmp]) outs(dst)) +//===----------------------------------------------------------------------===// + +struct PTOSORT32SToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TSort32Op op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = 
peelUnrealized(adaptor.getDst()); + Value idx = peelUnrealized(adaptor.getIdx()); + Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value(); + + SmallVector operands; + if (tmp) + operands.assign({dst, src, idx, tmp}); + else + operands.assign({dst, src, idx}); + rewriter.create( + loc, TypeRange{}, "TSORT32", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TSQRT DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOSqrtSToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TSqrtOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src}; + rewriter.create( + loc, TypeRange{}, "TSQRT", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TSTORE_FP DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOStoreFPSToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TStoreFPOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value fp = peelUnrealized(adaptor.getFp()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src, fp}; + rewriter.create( + loc, 
TypeRange{}, "TSTORE_FP", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TSUB DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOSubSToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TSubOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src0, src1}; + rewriter.create( + loc, TypeRange{}, "TSUB", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TSUBC DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOSubCSToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TSubCOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value src2 = peelUnrealized(adaptor.getSrc2()); + Value dst = peelUnrealized(adaptor.getDst()); + + // pto-isa does not provide NPU implementation for TSUBC yet. 
+ // Decompose: dst = src0 - src1 + src2 + rewriter.create( + loc, TypeRange{}, "TSUB", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, src1}); + rewriter.create( + loc, TypeRange{}, "TADD", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, dst, src2}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TSUBS DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOSubSSToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TSubSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value scalar = peelUnrealized(adaptor.getScalar()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src, scalar}; + rewriter.create( + loc, TypeRange{}, "TSUBS", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TSUBSC DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOSubSCToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TSubSCOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value scalar = peelUnrealized(adaptor.getScalar()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + // pto-isa does not provide NPU 
implementation for TSUBSC yet. + // Decompose: dst = src0 - scalar + src1 + rewriter.create( + loc, TypeRange{}, "TSUBS", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, scalar}); + rewriter.create( + loc, TypeRange{}, "TADD", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, dst, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; + + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TXOR DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOXORToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TXorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + Value tmp = peelUnrealized(adaptor.getTmp()); + SmallVector operands{dst, src0, src1, tmp}; + rewriter.create( + loc, TypeRange{}, "TXOR", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +struct PTOTTransToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TTransOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src, tmp}; + rewriter.create( + loc, TypeRange{}, "TTRANS", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; 
+//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TXORS DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOXORSToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TXorSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value scalar = peelUnrealized(adaptor.getScalar()); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src, scalar, tmp}; + rewriter.create( + loc, TypeRange{}, "TXORS", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +struct PTOPrintToTPRINT : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TPrintOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + + SmallVector operands{src}; + rewriter.create( + loc, TypeRange{}, "TPRINT", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +// pto.print "format", %scalar -> PRINTF("format", scalar) + +} // namespace + +void populatePTOToEmitCTileExtraPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter, + MLIRContext *ctx) { + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + 
patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add< + PTOTMatmulBiasToTMATMUL_BIAS, + PTOTMatmulMXToTMATMUL_MX, + PTOTMatmulMXAccToTMATMUL_MX_ACC, + PTOTMatmulMXBiasToTMATMUL_MX_BIAS, + PTOTGemvBiasToTGEMV_BIAS, + PTOTGemvMXToTGEMV_MX, + PTOTGemvMXAccToTGEMV_MX, + PTOTGemvMXBiasToTGEMV_MX>(typeConverter, ctx); +} + +} // namespace mlir::pto From 
858e6d9165b09e48061a8b9417d72551d6798d18 Mon Sep 17 00:00:00 2001 From: HecreReed <821896444@qq.com> Date: Tue, 26 May 2026 16:35:53 +0800 Subject: [PATCH 6/8] fix: restore mins lowering and gcc emitc build --- lib/PTO/Transforms/PTOToEmitCInternal.h | 3 +++ lib/PTO/Transforms/PTOToEmitCTilePatterns.cpp | 1 + 2 files changed, 4 insertions(+) diff --git a/lib/PTO/Transforms/PTOToEmitCInternal.h b/lib/PTO/Transforms/PTOToEmitCInternal.h index e8be34ed2..e6c039c91 100644 --- a/lib/PTO/Transforms/PTOToEmitCInternal.h +++ b/lib/PTO/Transforms/PTOToEmitCInternal.h @@ -9,6 +9,9 @@ #ifndef MLIR_DIALECT_PTO_TRANSFORMS_PTOTOEMITCINTERNAL_H #define MLIR_DIALECT_PTO_TRANSFORMS_PTOTOEMITCINTERNAL_H +#pragma GCC diagnostic ignored "-Woverloaded-virtual" +// GCC warns on MLIR OpConversionPattern helper overloads hiding RewritePattern::rewrite. + #include "PTO/IR/PTO.h" #include "mlir/IR/MLIRContext.h" diff --git a/lib/PTO/Transforms/PTOToEmitCTilePatterns.cpp b/lib/PTO/Transforms/PTOToEmitCTilePatterns.cpp index 0854efa46..e723c2c9c 100644 --- a/lib/PTO/Transforms/PTOToEmitCTilePatterns.cpp +++ b/lib/PTO/Transforms/PTOToEmitCTilePatterns.cpp @@ -1432,6 +1432,7 @@ void populatePTOToEmitCTilePatterns(RewritePatternSet &patterns, patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); populatePTOToEmitCTileExtraPatterns(patterns, typeConverter, ctx); } From 3ce0de37072675921c19d74e6bca241afd34c0a3 Mon Sep 17 00:00:00 2001 From: HecreReed <821896444@qq.com> Date: Tue, 26 May 2026 19:17:14 +0800 Subject: [PATCH 7/8] fix: restore split emitc sources after main merge --- lib/PTO/Transforms/PTOToEmitC.cpp | 6487 -------------------- lib/PTO/Transforms/PTOToEmitCSimpleOps.cpp | 34 +- 2 files changed, 32 insertions(+), 6489 deletions(-) diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index 47ace03ac..c8e15b51e 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ 
b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -2020,6493 +2020,6 @@ struct PTOMScatterToMSCATTER : public OpConversionPattern { return success(); } }; -struct PTOTAxpyToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAxpyOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value scalar = peelUnrealized(adaptor.getScalar()); - - rewriter.create( - loc, TypeRange{}, "TAXPY", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, scalar}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOHistogramToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::THistogramOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value idx = peelUnrealized(adaptor.getIdx()); - Value dst = peelUnrealized(adaptor.getDst()); - - StringRef histByte = "HistByte::BYTE_1"; - int64_t byte = 1; - auto byteAttr = op.getByteAttr(); - if (byteAttr) - byte = byteAttr.getInt(); - if (auto legacyIsMSB = op->getAttrOfType("isMSB")) { - int64_t legacyByte = legacyIsMSB.getValue() ? 
1 : 0; - if (byteAttr && byte != legacyByte) - return rewriter.notifyMatchFailure( - op, "conflicting 'byte' and legacy 'isMSB' attributes"); - byte = legacyByte; - } - switch (byte) { - case 0: - histByte = "HistByte::BYTE_0"; - break; - case 1: - histByte = "HistByte::BYTE_1"; - break; - case 2: - histByte = "HistByte::BYTE_2"; - break; - case 3: - histByte = "HistByte::BYTE_3"; - break; - default: - return rewriter.notifyMatchFailure(op, "expected byte to be in range [0, 3]"); - } - - auto templateArgs = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, histByte)}); - rewriter.create( - loc, TypeRange{}, "THISTOGRAM", - /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, - /*operands=*/ValueRange{dst, src, idx}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOGetScaleAddrToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGetScaleAddrOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TGET_SCALE_ADDR", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOSetValidShapeToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::SetValidShapeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto peelAllCasts = [](Value v) { - while (auto castOp = v.getDefiningOp()) - v = castOp.getOperand(0); - if (auto castOp = v.getDefiningOp()) - v = castOp.getOperand(); - return v; - }; - auto isTileLike = [](Value v) -> bool { - auto ot = dyn_cast(v.getType()); - if (!ot) - return false; - StringRef s = ot.getValue(); - return s.contains("Tile<") || s.contains("ConvTile<"); - }; 
- - Value src = peelAllCasts(peelUnrealized(adaptor.getSource())); - Value row = peelUnrealized(adaptor.getValidRow()); - Value col = peelUnrealized(adaptor.getValidCol()); - - if (!isTileLike(src)) - return rewriter.notifyMatchFailure( - op, "set_validshape source must lower to a tile-like value"); - - rewriter.create( - op.getLoc(), TypeRange{}, "PTOAS__TILE_SET_VALIDSHAPE", ArrayAttr{}, - ArrayAttr{}, ValueRange{src, row, col}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOGetValidShapeToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::GetValidShapeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto peelAllCasts = [](Value v) { - while (auto castOp = v.getDefiningOp()) - v = castOp.getOperand(0); - if (auto castOp = v.getDefiningOp()) - v = castOp.getOperand(); - return v; - }; - auto isTileLike = [](Value v) -> bool { - auto ot = dyn_cast(v.getType()); - if (!ot) - return false; - StringRef s = ot.getValue(); - return s.contains("Tile<") || s.contains("ConvTile<"); - }; - - Value src = peelAllCasts(peelUnrealized(adaptor.getSource())); - if (!isTileLike(src)) - return rewriter.notifyMatchFailure( - op, "get_validshape source must lower to a tile-like value"); - - auto resultTy = getTypeConverter()->convertType(rewriter.getIndexType()); - if (!resultTy) - return failure(); - - Value row = rewriter - .create( - op.getLoc(), resultTy, - "PTOAS__TILE_GET_VALID_ROW", ArrayAttr{}, - ArrayAttr{}, ValueRange{src}) - .getResult(0); - Value col = rewriter - .create( - op.getLoc(), resultTy, - "PTOAS__TILE_GET_VALID_COL", ArrayAttr{}, - ArrayAttr{}, ValueRange{src}) - .getResult(0); - rewriter.replaceOp(op, ValueRange{row, col}); - return success(); - } -}; - -struct PTOTAssignToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAssignOp op, OpAdaptor adaptor, - 
ConversionPatternRewriter &rewriter) const override { - auto peelAllCasts = [](Value v) { - while (auto castOp = v.getDefiningOp()) - v = castOp.getOperand(0); - if (auto castOp = v.getDefiningOp()) - v = castOp.getOperand(); - return v; - }; - auto isTileLike = [](Value v) -> bool { - auto ot = dyn_cast(v.getType()); - if (!ot) - return false; - StringRef s = ot.getValue(); - return s.contains("Tile<") || s.contains("ConvTile<"); - }; - - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value tile = peelAllCasts(peelUnrealized(adaptor.getTile())); - if (!isTileLike(tile)) - return rewriter.notifyMatchFailure( - op, "tassign tile must lower to a tile-like value"); - - Value addr = peelUnrealized(adaptor.getAddr()); - auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); - if (isa(addr.getType()) || - (isa(addr.getType()) && - cast(addr.getType()).getValue().ends_with("*"))) { - auto rcU64 = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); - addr = rewriter - .create(loc, u64Ty, "reinterpret_cast", - ArrayAttr{}, rcU64, - ValueRange{addr}) - .getResult(0); - } else if (addr.getType() != u64Ty) { - addr = rewriter.create(loc, u64Ty, addr).getResult(); - } - - rewriter.create(loc, TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{tile, addr}); - rewriter.replaceOp(op, tile); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// pto.load_scalar / pto.store_scalar lowering -> ptr[offset] -//===----------------------------------------------------------------------===// - -static Type getPointerLikeElementType(Type type) { - if (auto ptrTy = dyn_cast(type)) - return ptrTy.getElementType(); - if (auto memTy = dyn_cast(type)) - return memTy.getElementType(); - return Type(); -} - -struct PTOPtrToIntToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::PtrToIntOp op, OpAdaptor adaptor, - 
ConversionPatternRewriter &rewriter) const override { - Value ptr = peelUnrealized(adaptor.getPtr()); - Type dstTy = getTypeConverter()->convertType(op.getResult().getType()); - if (!dstTy) - return failure(); - - auto dstOpaque = dyn_cast(dstTy); - if (!dstOpaque) - return failure(); - - auto templateArgs = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(rewriter.getContext(), - dstOpaque.getValue())}); - auto cast = rewriter.create( - op.getLoc(), dstTy, "reinterpret_cast", ArrayAttr{}, templateArgs, - ValueRange{ptr}); - rewriter.replaceOp(op, cast.getResult(0)); - return success(); - } -}; - -struct PTOIntToPtrToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::IntToPtrOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value addr = peelUnrealized(adaptor.getAddr()); - Type dstTy = getTypeConverter()->convertType(op.getResult().getType()); - if (!dstTy) - return failure(); - - Type dstElemTy = getPointerLikeElementType(op.getResult().getType()); - if (!dstElemTy) - return failure(); - - std::string castType = - std::string("__gm__ ") + getEmitCScalarTypeToken(dstElemTy) + "*"; - auto templateArgs = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(rewriter.getContext(), - castType)}); - auto cast = rewriter.create( - op.getLoc(), dstTy, "reinterpret_cast", ArrayAttr{}, templateArgs, - ValueRange{addr}); - rewriter.replaceOp(op, cast.getResult(0)); - return success(); - } -}; - -struct PTOLoadScalarToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::LoadScalarOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value ptr = peelUnrealized(adaptor.getPtr()); - Value offset = peelUnrealized(adaptor.getOffset()); - - Type dstTy = getTypeConverter()->convertType(op.getValue().getType()); - if (!dstTy) - return failure(); - - auto call = rewriter.create( - 
op.getLoc(), TypeRange{dstTy}, "PTOAS__PTR_LOAD", - ArrayAttr{}, ArrayAttr{}, ValueRange{ptr, offset}); - - rewriter.replaceOp(op, call.getResults()); - return success(); - } -}; - -struct PTOStoreScalarToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::StoreScalarOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value ptr = peelUnrealized(adaptor.getPtr()); - Value offset = peelUnrealized(adaptor.getOffset()); - Value val = peelUnrealized(adaptor.getValue()); - - rewriter.create( - op.getLoc(), TypeRange{}, "PTOAS__PTR_STORE", - ArrayAttr{}, ArrayAttr{}, ValueRange{ptr, offset, val}); - rewriter.create( - op.getLoc(), TypeRange{}, "PTOAS__SCALAR_GM_STORE_FLUSH", - ArrayAttr{}, ArrayAttr{}, ValueRange{ptr}); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// pto.tabs lowering -> TABS(dst, src) -//===----------------------------------------------------------------------===// - -struct PTOTAbsToTABS : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAbsOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - // intrinsic: TABS(dst, src) - rewriter.create( - op.getLoc(), TypeRange{}, "TABS", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tadd lowering -> TADD(dst, src0, src1) -//===----------------------------------------------------------------------===// - -struct PTOTAddToTADD : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAddOp op, OpAdaptor adaptor, - 
ConversionPatternRewriter &rewriter) const override { - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TADD", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src0, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOInitializeL2G2LPipeToEmitC - : public OpConversionPattern { - PTOInitializeL2G2LPipeToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - PTOArch targetArch) - : OpConversionPattern(typeConverter, ctx), - targetArch(targetArch) {} - - LogicalResult matchAndRewrite(mlir::pto::InitializeL2G2LPipeOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto tpipeTok = buildTPipeTokenFromInitOp(op.getOperation(), targetArch); - if (failed(tpipeTok)) - return rewriter.notifyMatchFailure(op, "failed to build TPipe token"); - - auto *ctx = rewriter.getContext(); - auto emitPipeTy = - cast(getTypeConverter()->convertType(op.getPipe().getType())); - - Value gmAddr = peelUnrealized(adaptor.getGmAddr()); - gmAddr = materializeTensorViewDataPointer( - rewriter, op.getLoc(), gmAddr, op.getGmAddr().getType()); - Value localAddr = - op.getLocalAddr() ? peelUnrealized(adaptor.getLocalAddr()) : Value(); - auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); - Value zero = makeEmitCIntConstant(rewriter, op.getLoc(), i32Ty, 0); - - Value c2vBuf = zero; - Value v2cBuf = zero; - if (op.getDirMask() == 1) - c2vBuf = localAddr ? localAddr : zero; - else if (op.getDirMask() == 2) - v2cBuf = localAddr ? 
localAddr : zero; - else if (op.getDirMask() == 3) { - if (localAddr) { - if (!op.getPeerLocalAddr()) - return rewriter.notifyMatchFailure( - op, "bidirectional l2g2l pipe requires peer local buffer"); - c2vBuf = localAddr; - v2cBuf = peelUnrealized(adaptor.getPeerLocalAddr()); - } - } else - return rewriter.notifyMatchFailure(op, "unsupported dir_mask"); - - rewriter.replaceOpWithNewOp( - op, TypeRange{emitPipeTy}, *tpipeTok, ArrayAttr{}, ArrayAttr{}, - ValueRange{gmAddr, c2vBuf, v2cBuf}); - return success(); - } - - PTOArch targetArch; -}; - -struct PTOInitializeL2LPipeToEmitC - : public OpConversionPattern { - PTOInitializeL2LPipeToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - PTOArch targetArch) - : OpConversionPattern(typeConverter, ctx), - targetArch(targetArch) {} - - LogicalResult matchAndRewrite(mlir::pto::InitializeL2LPipeOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto tpipeTok = buildTPipeTokenFromInitOp(op.getOperation(), targetArch); - if (failed(tpipeTok)) - return rewriter.notifyMatchFailure(op, "failed to build TPipe token"); - - auto *ctx = rewriter.getContext(); - auto emitPipeTy = - cast(getTypeConverter()->convertType(op.getPipe().getType())); - - auto gmPtrTy = - emitc::PointerType::get(emitc::OpaqueType::get(ctx, "__gm__ void")); - Value nullGm = - makeEmitCOpaqueConstant(rewriter, op.getLoc(), gmPtrTy, "nullptr"); - auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); - Value zero = makeEmitCIntConstant(rewriter, op.getLoc(), i32Ty, 0); - Value localAddr = peelUnrealized(adaptor.getLocalAddr()); - - Value c2vBuf = zero; - Value v2cBuf = zero; - if (op.getDirMask() == 1) - c2vBuf = localAddr; - else if (op.getDirMask() == 2) - v2cBuf = localAddr; - else if (op.getDirMask() == 3) { - c2vBuf = localAddr; - v2cBuf = peelUnrealized(adaptor.getPeerLocalAddr()); - } else - return rewriter.notifyMatchFailure(op, "unsupported dir_mask"); - - rewriter.replaceOpWithNewOp( - op, TypeRange{emitPipeTy}, 
*tpipeTok, ArrayAttr{}, ArrayAttr{}, - ValueRange{nullGm, c2vBuf, v2cBuf}); - return success(); - } - - PTOArch targetArch; -}; - -struct PTOBuildAsyncSessionToEmitC - : public OpConversionPattern { - PTOBuildAsyncSessionToEmitC(TypeConverter &typeConverter, MLIRContext *ctx) - : OpConversionPattern(typeConverter, ctx) {} - - LogicalResult matchAndRewrite(mlir::pto::BuildAsyncSessionOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto *ctx = rewriter.getContext(); - Location loc = op.getLoc(); - - auto sessionTy = - dyn_cast(getTypeConverter()->convertType(op.getSession().getType())); - if (!sessionTy) - return rewriter.notifyMatchFailure(op, "failed to convert async session type"); - - FailureOr scratchTile = - buildAsyncScratchTileValue(rewriter, loc, op.getScratch(), - adaptor.getScratch()); - if (failed(scratchTile)) - return rewriter.notifyMatchFailure(op, "failed to materialize async scratch tile"); - - Value workspace = - castToGMBytePointer(rewriter, loc, peelUnrealized(adaptor.getWorkspace())); - - Value session = rewriter - .create( - loc, sessionTy, emitc::OpaqueAttr::get(ctx, "")) - .getResult(); - - auto u32Ty = emitc::OpaqueType::get(ctx, "uint32_t"); - - auto makeU32Const = [&](uint64_t value) -> Value { - return makeEmitCOpaqueConstant(rewriter, loc, u32Ty, - std::to_string(value) + "u"); - }; - uint64_t syncId = op.getSyncIdAttr() ? op.getSyncIdAttr().getInt() : 0; - uint64_t blockBytes = - op.getBlockBytesAttr() ? op.getBlockBytesAttr().getInt() : 32 * 1024; - uint64_t commBlockOffset = - op.getCommBlockOffsetAttr() ? op.getCommBlockOffsetAttr().getInt() : 0; - uint64_t queueNum = op.getQueueNumAttr() ? op.getQueueNumAttr().getInt() : 1; - uint64_t channelGroupIdx = op.getChannelGroupIdxAttr() - ? op.getChannelGroupIdxAttr().getInt() - : UINT32_MAX; - - Value syncIdVal = makeU32Const(syncId); - Value channelGroupIdxVal = - channelGroupIdx == UINT32_MAX - ? 
makeEmitCOpaqueConstant(rewriter, loc, u32Ty, "UINT32_MAX") - : makeU32Const(channelGroupIdx); - - auto baseConfigTy = - emitc::OpaqueType::get(ctx, "pto::comm::sdma::SdmaBaseConfig"); - Value baseConfig = - rewriter - .create( - loc, baseConfigTy, - emitc::OpaqueAttr::get( - ctx, "{" + std::to_string(blockBytes) + "ULL, " + - std::to_string(commBlockOffset) + "ULL, " + - std::to_string(queueNum) + "u}")) - .getResult(); - - rewriter.create( - loc, TypeRange{}, "pto::comm::BuildAsyncSession", - ArrayAttr{}, ArrayAttr{}, - ValueRange{*scratchTile, workspace, session, syncIdVal, baseConfig, - channelGroupIdxVal}); - - rewriter.replaceOp(op, session); - return success(); - } -}; - -template -struct PTOAsyncTransferToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - explicit PTOAsyncTransferToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - StringRef callee) - : OpConversionPattern(typeConverter, ctx), callee(callee.str()) {} - - LogicalResult matchAndRewrite(AsyncOp op, typename AsyncOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value dst = peelUnrealized(adaptor.getDst()); - Value src = peelUnrealized(adaptor.getSrc()); - Value dstGT = dst; - Value srcGT = src; - if (!isEmitCGlobalTensorLikeType(dstGT.getType())) { - auto dstMrTy = dyn_cast(op.getDst().getType()); - if (!dstMrTy) - return rewriter.notifyMatchFailure(op, "expected dst to lower to GlobalTensor or memref"); - dstGT = buildGlobalTensorFromMemref(rewriter, op.getLoc(), dst, dstMrTy, - op.getDst().getDefiningOp() - ? op.getDst().getDefiningOp() - : op.getOperation()); - } - if (!isEmitCGlobalTensorLikeType(srcGT.getType())) { - auto srcMrTy = dyn_cast(op.getSrc().getType()); - if (!srcMrTy) - return rewriter.notifyMatchFailure(op, "expected src to lower to GlobalTensor or memref"); - srcGT = buildGlobalTensorFromMemref(rewriter, op.getLoc(), src, srcMrTy, - op.getSrc().getDefiningOp() - ? 
op.getSrc().getDefiningOp() - : op.getOperation()); - } - if (!dstGT || !srcGT) - return rewriter.notifyMatchFailure(op, "failed to build GlobalTensor operands"); - - Type eventTy = this->getTypeConverter()->convertType(op.getEvent().getType()); - if (!eventTy) - return rewriter.notifyMatchFailure(op, "failed to convert async event type"); - - rewriter.replaceOpWithNewOp( - op, TypeRange{eventTy}, callee, ArrayAttr{}, ArrayAttr{}, - ValueRange{dstGT, srcGT, peelUnrealized(adaptor.getSession())}); - return success(); - } - - std::string callee; -}; - -template -struct PTOAsyncEventToEmitC : public OpConversionPattern { - explicit PTOAsyncEventToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - StringRef callee) - : OpConversionPattern(typeConverter, ctx), - callee(callee.str()) {} - - LogicalResult matchAndRewrite(AsyncEventOp op, - typename AsyncEventOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type resultTy = - this->getTypeConverter()->convertType(op.getCompleted().getType()); - if (!resultTy) - return rewriter.notifyMatchFailure(op, "failed to convert async event result type"); - - rewriter.replaceOpWithNewOp( - op, TypeRange{resultTy}, callee, ArrayAttr{}, ArrayAttr{}, - ValueRange{peelUnrealized(adaptor.getEvent()), - peelUnrealized(adaptor.getSession())}); - return success(); - } - - std::string callee; -}; - -static FailureOr buildCommGlobalTensorValue( - ConversionPatternRewriter &rewriter, Location loc, Value originalValue, - Value emittedValue, Operation *anchor) { - Value value = peelUnrealized(emittedValue); - if (isEmitCGlobalTensorLikeType(value.getType())) - return value; - - auto memTy = dyn_cast(originalValue.getType()); - if (!memTy) - return failure(); - - Value gt = buildGlobalTensorFromMemref(rewriter, loc, value, memTy, anchor); - if (!gt) - return failure(); - return gt; -} - -static FailureOr buildCommTileValue(ConversionPatternRewriter &rewriter, - Location loc, Value originalValue, - Value 
emittedValue) { - Value value = peelUnrealized(emittedValue); - if (auto opaqueTy = dyn_cast(value.getType())) { - StringRef typeStr = opaqueTy.getValue(); - if (typeStr.contains("Tile<") || typeStr.contains("ConvTile<")) - return value; - } - return buildAsyncScratchTileValue(rewriter, loc, originalValue, emittedValue); -} - -static FailureOr buildCollectiveParallelGroup( - ConversionPatternRewriter &rewriter, Location loc, - ArrayRef groupGTs, int64_t root) { - if (groupGTs.empty()) - return failure(); - - auto firstTy = dyn_cast(groupGTs.front().getType()); - if (!firstTy) - return failure(); - - auto *ctx = rewriter.getContext(); - auto arrayTy = emitc::ArrayType::get({static_cast(groupGTs.size())}, - firstTy); - auto groupArray = cast>( - rewriter - .create(loc, arrayTy, - emitc::OpaqueAttr::get(ctx, "{}")) - .getResult()); - - auto indexTy = emitc::OpaqueType::get(ctx, "int"); - for (auto [idx, groupVal] : llvm::enumerate(groupGTs)) { - Value idxVal = - makeEmitCIntConstant(rewriter, loc, indexTy, static_cast(idx)); - Value slot = - rewriter.create(loc, groupArray, ValueRange{idxVal}) - .getResult(); - rewriter.create(loc, slot, groupVal); - } - - std::string pgTypeStr = - (Twine("pto::comm::ParallelGroup<") + firstTy.getValue() + ">").str(); - auto pgTy = emitc::OpaqueType::get(ctx, pgTypeStr); - Value sizeVal = makeEmitCIntConstant(rewriter, loc, indexTy, - static_cast(groupGTs.size())); - Value rootVal = makeEmitCIntConstant(rewriter, loc, indexTy, root); - return rewriter - .create( - loc, TypeRange{pgTy}, (Twine(pgTypeStr) + "::Create").str(), - ArrayAttr{}, ArrayAttr{}, ValueRange{groupArray, sizeVal, rootVal}) - .getResult(0); -} - -static std::string notifyOpTok(pto::NotifyOp op) { - switch (op) { - case pto::NotifyOp::AtomicAdd: - return "pto::comm::NotifyOp::AtomicAdd"; - case pto::NotifyOp::Set: - return "pto::comm::NotifyOp::Set"; - } - return "pto::comm::NotifyOp::Set"; -} - -static std::string waitCmpTok(pto::WaitCmp cmp) { - switch (cmp) { - 
case pto::WaitCmp::EQ: - return "pto::comm::WaitCmp::EQ"; - case pto::WaitCmp::NE: - return "pto::comm::WaitCmp::NE"; - case pto::WaitCmp::GT: - return "pto::comm::WaitCmp::GT"; - case pto::WaitCmp::GE: - return "pto::comm::WaitCmp::GE"; - case pto::WaitCmp::LT: - return "pto::comm::WaitCmp::LT"; - case pto::WaitCmp::LE: - return "pto::comm::WaitCmp::LE"; - } - return "pto::comm::WaitCmp::EQ"; -} - -static std::string reduceOpTok(pto::ReduceOp op) { - switch (op) { - case pto::ReduceOp::Sum: - return "pto::comm::ReduceOp::Sum"; - case pto::ReduceOp::Max: - return "pto::comm::ReduceOp::Max"; - case pto::ReduceOp::Min: - return "pto::comm::ReduceOp::Min"; - } - return "pto::comm::ReduceOp::Sum"; -} - -template -static FailureOr> buildCommGroupGlobalTensors( - ConversionPatternRewriter &rewriter, Location loc, OpTy op, - ValueRange originalGroup, ValueRange emittedGroup) { - SmallVector groupGTs; - groupGTs.reserve(originalGroup.size()); - for (auto [orig, emitted] : llvm::zip(originalGroup, emittedGroup)) { - FailureOr gt = - buildCommGlobalTensorValue(rewriter, loc, orig, emitted, op.getOperation()); - if (failed(gt)) - return failure(); - groupGTs.push_back(*gt); - } - return groupGTs; -} - -template -struct PTOCommCollectiveToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - explicit PTOCommCollectiveToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - StringRef apiName) - : OpConversionPattern(typeConverter, ctx), - apiName(apiName.str()) {} - - LogicalResult matchAndRewrite(CollectiveOp op, typename CollectiveOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - - auto buildPong = [&](Value original, Value emitted, StringRef name) -> FailureOr { - if (!original) - return failure(); - return buildCommTileValue(rewriter, loc, original, emitted); - }; - - if constexpr (std::is_same_v) { - FailureOr srcGT = - buildCommGlobalTensorValue(rewriter, loc, op.getSrc(), 
adaptor.getSrc(), - op.getOperation()); - FailureOr pingTile = - buildCommTileValue(rewriter, loc, op.getPing(), adaptor.getPing()); - auto groupGTs = - buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); - if (failed(srcGT) || failed(pingTile) || failed(groupGTs)) - return rewriter.notifyMatchFailure(op, "failed to materialize broadcast operands"); - FailureOr pg = buildCollectiveParallelGroup(rewriter, loc, *groupGTs, op.getRoot()); - if (failed(pg)) - return rewriter.notifyMatchFailure(op, "failed to materialize broadcast group"); - if (op.getPong()) { - FailureOr pongTile = - buildPong(op.getPong(), adaptor.getPong(), "__pong"); - if (failed(pongTile)) - return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); - rewriter.create( - loc, TypeRange{}, "pto::comm::TBROADCAST", ArrayAttr{}, ArrayAttr{}, - ValueRange{*pg, *srcGT, *pingTile, *pongTile}); - } else { - rewriter.create( - loc, TypeRange{}, "pto::comm::TBROADCAST", ArrayAttr{}, ArrayAttr{}, - ValueRange{*pg, *srcGT, *pingTile}); - } - } else if constexpr (std::is_same_v) { - FailureOr dstGT = - buildCommGlobalTensorValue(rewriter, loc, op.getDst(), adaptor.getDst(), - op.getOperation()); - FailureOr pingTile = - buildCommTileValue(rewriter, loc, op.getPing(), adaptor.getPing()); - auto groupGTs = - buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); - if (failed(dstGT) || failed(pingTile) || failed(groupGTs)) - return rewriter.notifyMatchFailure(op, "failed to materialize gather operands"); - FailureOr pg = buildCollectiveParallelGroup(rewriter, loc, *groupGTs, op.getRoot()); - if (failed(pg)) - return rewriter.notifyMatchFailure(op, "failed to materialize gather group"); - if (op.getPong()) { - FailureOr pongTile = - buildPong(op.getPong(), adaptor.getPong(), "__pong"); - if (failed(pongTile)) - return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); - rewriter.create( - loc, TypeRange{}, 
"pto::comm::TGATHER", ArrayAttr{}, ArrayAttr{}, - ValueRange{*pg, *dstGT, *pingTile, *pongTile}); - } else { - rewriter.create( - loc, TypeRange{}, "pto::comm::TGATHER", ArrayAttr{}, ArrayAttr{}, - ValueRange{*pg, *dstGT, *pingTile}); - } - } else if constexpr (std::is_same_v) { - FailureOr srcGT = - buildCommGlobalTensorValue(rewriter, loc, op.getSrc(), adaptor.getSrc(), - op.getOperation()); - FailureOr pingTile = - buildCommTileValue(rewriter, loc, op.getPing(), adaptor.getPing()); - auto groupGTs = - buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); - if (failed(srcGT) || failed(pingTile) || failed(groupGTs)) - return rewriter.notifyMatchFailure(op, "failed to materialize scatter operands"); - FailureOr pg = buildCollectiveParallelGroup(rewriter, loc, *groupGTs, op.getRoot()); - if (failed(pg)) - return rewriter.notifyMatchFailure(op, "failed to materialize scatter group"); - if (op.getPong()) { - FailureOr pongTile = - buildPong(op.getPong(), adaptor.getPong(), "__pong"); - if (failed(pongTile)) - return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); - rewriter.create( - loc, TypeRange{}, "pto::comm::TSCATTER", ArrayAttr{}, ArrayAttr{}, - ValueRange{*pg, *srcGT, *pingTile, *pongTile}); - } else { - rewriter.create( - loc, TypeRange{}, "pto::comm::TSCATTER", ArrayAttr{}, ArrayAttr{}, - ValueRange{*pg, *srcGT, *pingTile}); - } - } else { - FailureOr dstGT = - buildCommGlobalTensorValue(rewriter, loc, op.getDst(), adaptor.getDst(), - op.getOperation()); - FailureOr accTile = - buildCommTileValue(rewriter, loc, op.getAcc(), adaptor.getAcc()); - FailureOr recvPing = - buildCommTileValue(rewriter, loc, op.getRecvPing(), adaptor.getRecvPing()); - auto groupGTs = - buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); - if (failed(dstGT) || failed(accTile) || failed(recvPing) || failed(groupGTs)) - return rewriter.notifyMatchFailure(op, "failed to materialize reduce operands"); - 
FailureOr pg = buildCollectiveParallelGroup(rewriter, loc, *groupGTs, op.getRoot()); - if (failed(pg)) - return rewriter.notifyMatchFailure(op, "failed to materialize reduce group"); - if (op.getRecvPong()) { - FailureOr recvPong = - buildPong(op.getRecvPong(), adaptor.getRecvPong(), "__recv_pong"); - if (failed(recvPong)) - return rewriter.notifyMatchFailure(op, "failed to materialize recv_pong"); - auto reduceTy = - emitc::OpaqueType::get(rewriter.getContext(), "pto::comm::ReduceOp"); - Value reduceOp = makeEmitCOpaqueConstant(rewriter, loc, reduceTy, - reduceOpTok(op.getReduceOp())); - rewriter.create( - loc, TypeRange{}, "pto::comm::TREDUCE", ArrayAttr{}, ArrayAttr{}, - ValueRange{*pg, *dstGT, *accTile, *recvPing, *recvPong, reduceOp}); - } else { - auto reduceTy = - emitc::OpaqueType::get(rewriter.getContext(), "pto::comm::ReduceOp"); - Value reduceOp = makeEmitCOpaqueConstant(rewriter, loc, reduceTy, - reduceOpTok(op.getReduceOp())); - rewriter.create( - loc, TypeRange{}, "pto::comm::TREDUCE", ArrayAttr{}, ArrayAttr{}, - ValueRange{*pg, *dstGT, *accTile, *recvPing, reduceOp}); - } - } - rewriter.eraseOp(op); - return success(); - } - - std::string apiName; -}; - -template -struct PTOP2PCommToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - explicit PTOP2PCommToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - StringRef callee) - : OpConversionPattern(typeConverter, ctx), callee(callee.str()) {} - - LogicalResult matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - FailureOr dstGT = - buildCommGlobalTensorValue(rewriter, op.getLoc(), op.getDst(), adaptor.getDst(), - op.getOperation()); - FailureOr srcGT = - buildCommGlobalTensorValue(rewriter, op.getLoc(), op.getSrc(), adaptor.getSrc(), - op.getOperation()); - FailureOr pingTile = - buildCommTileValue(rewriter, op.getLoc(), op.getPing(), adaptor.getPing()); - if (failed(dstGT) || failed(srcGT) || 
failed(pingTile)) - return rewriter.notifyMatchFailure(op, "failed to materialize p2p operands"); - - SmallVector operands{*dstGT, *srcGT, *pingTile}; - std::string actualCallee = callee; - if constexpr (std::is_same_v) { - if (op.getAtomicType() == pto::AtomicType::AtomicAdd) - actualCallee = "pto::comm::TPUT"; - } - if (op.getPong()) { - FailureOr pongTile = - buildCommTileValue(rewriter, op.getLoc(), op.getPong(), adaptor.getPong()); - if (failed(pongTile)) - return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); - operands.push_back(*pongTile); - } - - rewriter.create(op.getLoc(), TypeRange{}, actualCallee, - ArrayAttr{}, ArrayAttr{}, operands); - rewriter.eraseOp(op); - return success(); - } - - std::string callee; -}; - -template -struct PTOSignalCommToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - explicit PTOSignalCommToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - StringRef callee) - : OpConversionPattern(typeConverter, ctx), - callee(callee.str()) {} - - LogicalResult matchAndRewrite(SignalOp op, typename SignalOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - FailureOr signalGT = buildCommGlobalTensorValue( - rewriter, op.getLoc(), op.getSignal(), adaptor.getSignal(), op.getOperation()); - if (failed(signalGT)) - return rewriter.notifyMatchFailure(op, "failed to materialize signal operand"); - - if constexpr (std::is_same_v) { - auto notifyTy = - emitc::OpaqueType::get(rewriter.getContext(), "pto::comm::NotifyOp"); - Value notifyOp = makeEmitCOpaqueConstant( - rewriter, op.getLoc(), notifyTy, notifyOpTok(op.getNotifyOp())); - SmallVector operands{*signalGT, peelUnrealized(adaptor.getValue()), - notifyOp}; - rewriter.create(op.getLoc(), TypeRange{}, callee, - ArrayAttr{}, ArrayAttr{}, operands); - rewriter.eraseOp(op); - } else { - auto waitCmpTy = - emitc::OpaqueType::get(rewriter.getContext(), "pto::comm::WaitCmp"); - Value waitCmp = 
makeEmitCOpaqueConstant( - rewriter, op.getLoc(), waitCmpTy, waitCmpTok(op.getCmp())); - SmallVector operands{*signalGT, peelUnrealized(adaptor.getCmpValue()), - waitCmp}; - if constexpr (std::is_same_v) { - Type resultTy = this->getTypeConverter()->convertType(op.getResult().getType()); - if (!resultTy) - return rewriter.notifyMatchFailure(op, "failed to convert ttest result type"); - rewriter.replaceOpWithNewOp( - op, TypeRange{resultTy}, callee, ArrayAttr{}, ArrayAttr{}, operands); - } else { - rewriter.create(op.getLoc(), TypeRange{}, callee, - ArrayAttr{}, ArrayAttr{}, operands); - rewriter.eraseOp(op); - } - } - return success(); - } - - std::string callee; -}; - -struct PTODeclareTileMemRefToEmitC - : public OpConversionPattern { - using OpConversionPattern< - mlir::pto::DeclareTileMemRefOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::DeclareTileMemRefOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - (void)adaptor; - Type convertedType = getTypeConverter()->convertType(op.getResult().getType()); - if (!convertedType) - return rewriter.notifyMatchFailure( - op, "failed to convert declare_tile_memref result type"); - rewriter.replaceOp(op, makeEmitCOpaqueConstant(rewriter, op.getLoc(), - convertedType, "nullptr")); - return success(); - } -}; - -struct PTODeclareGlobalToEmitC - : public OpConversionPattern { - using OpConversionPattern< - mlir::pto::DeclareGlobalOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::DeclareGlobalOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - (void)adaptor; - Type convertedType = getTypeConverter()->convertType(op.getEntry().getType()); - if (!convertedType) - return rewriter.notifyMatchFailure( - op, "failed to convert declare_global result type"); - if (auto tvTy = dyn_cast(op.getEntry().getType())) { - if (auto stridesAttr = - op->getAttrOfType(kGlobalTensorStridesAttrName)) { - auto strides = 
stridesAttr.asArrayRef(); - if (strides.size() == static_cast(tvTy.getRank())) { - convertedType = emitc::OpaqueType::get( - rewriter.getContext(), - getGlobalTensorTypeStringFromShapeAndStrides( - tvTy.getElementType(), tvTy.getShape(), strides)); - } - } - } - auto var = rewriter.create( - op.getLoc(), convertedType, - emitc::OpaqueAttr::get(rewriter.getContext(), "")); - rewriter.replaceOp(op, var.getResult()); - return success(); - } -}; - -struct PTODeclareEventIdArrayToEmitC - : public OpConversionPattern { - using OpConversionPattern< - mlir::pto::DeclareEventIdArrayOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::DeclareEventIdArrayOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - (void)adaptor; - Type arrayTy = getTypeConverter()->convertType(op.getArray().getType()); - if (!arrayTy) - return rewriter.notifyMatchFailure(op, - "failed to map declared eventid_array type"); - - auto array = rewriter - .create( - op.getLoc(), arrayTy, - emitc::OpaqueAttr::get(rewriter.getContext(), "")) - .getResult(); - rewriter.replaceOp(op, array); - return success(); - } -}; - -struct PTOEventIdArrayGetToEmitC - : public OpConversionPattern { - using OpConversionPattern< - mlir::pto::EventIdArrayGetOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::EventIdArrayGetOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value array = peelUnrealized(adaptor.getArray()); - Value index = peelUnrealized(adaptor.getIndex()); - - Type resultTy = getTypeConverter()->convertType(op.getResult().getType()); - if (!resultTy) - return rewriter.notifyMatchFailure(op, - "failed to map eventid_array get result type"); - - auto load = - rewriter.create(op.getLoc(), resultTy, array, index); - rewriter.replaceOp(op, load.getResult()); - return success(); - } -}; - -struct PTOEventIdArraySetToEmitC - : public OpConversionPattern { - using OpConversionPattern< - 
mlir::pto::EventIdArraySetOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::EventIdArraySetOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value array = peelUnrealized(adaptor.getArray()); - Value index = peelUnrealized(adaptor.getIndex()); - Value value = peelUnrealized(adaptor.getValue()); - - rewriter.create( - op.getLoc(), TypeRange{}, "PTOAS__EVENTID_ARRAY_STORE", - ArrayAttr{}, ArrayAttr{}, ValueRange{array, index, value}); - rewriter.eraseOp(op); - return success(); - } -}; - -// pto.declare_local_array -> emitc.variable of !emitc.array<...>. -// Renders as `T a[D1][D2]...;` in the emitted C++. -struct PTODeclareLocalArrayToEmitC - : public OpConversionPattern { - using OpConversionPattern< - mlir::pto::DeclareLocalArrayOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::DeclareLocalArrayOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - (void)adaptor; - Type arrayTy = getTypeConverter()->convertType(op.getArray().getType()); - if (!arrayTy) - return rewriter.notifyMatchFailure(op, - "failed to map !pto.local_array type"); - - auto var = rewriter - .create( - op.getLoc(), arrayTy, - emitc::OpaqueAttr::get(rewriter.getContext(), "")) - .getResult(); - rewriter.replaceOp(op, var); - return success(); - } -}; - -// pto.local_array_get %a[%i0, %i1, ...] -> rvalue. -// Lowers to a single emitc.subscript with the full index pack; the C++ emitter -// prints it as `a[i0][i1]...`. The adaptor already exposes target-typed values -// (the type converter has remapped !pto.local_array -> !emitc.array and -// index/integer indices), so they're forwarded directly to the builder. 
-struct PTOLocalArrayGetToEmitC - : public OpConversionPattern { - using OpConversionPattern< - mlir::pto::LocalArrayGetOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::LocalArrayGetOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type resultTy = - getTypeConverter()->convertType(op.getResult().getType()); - if (!resultTy) - return rewriter.notifyMatchFailure( - op, "failed to map local_array element type"); - - auto sub = rewriter.create( - op.getLoc(), resultTy, adaptor.getArray(), adaptor.getIndices()); - rewriter.replaceOp(op, sub.getResult()); - return success(); - } -}; - -// pto.local_array_set %a[%i0, %i1, ...], %v -> emitc.assign to subscript slot. -// The C++ emitter prints this as `a[i0][i1]... = v;`. As above, adaptor values -// are already target-typed; pass them through directly. -struct PTOLocalArraySetToEmitC - : public OpConversionPattern { - using OpConversionPattern< - mlir::pto::LocalArraySetOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::LocalArraySetOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value value = adaptor.getValue(); - Type elemTy = value.getType(); - - Value slot = rewriter - .create(op.getLoc(), elemTy, - adaptor.getArray(), - adaptor.getIndices()) - .getResult(); - rewriter.create(op.getLoc(), slot, value); - rewriter.eraseOp(op); - return success(); - } -}; - -static std::optional getStaticIndexLikeValue(Value value) { - if (!value) - return std::nullopt; - if (auto cst = value.getDefiningOp()) - return cst.value(); - if (auto cst = value.getDefiningOp()) - return cst.value(); - if (auto cst = value.getDefiningOp()) { - if (auto intAttr = dyn_cast(cst.getValue())) - return intAttr.getInt(); - } - return std::nullopt; -} - -static FailureOr buildGlobalTensorViewFromPointer( - ConversionPatternRewriter &rewriter, Location loc, Value ptr, Type elemTy, - ArrayRef shape, ArrayRef strides = {}, - StringRef 
layoutEnum = "pto::Layout::ND") { - if (llvm::any_of(shape, [](int64_t dim) { - return dim == ShapedType::kDynamic; - })) - return failure(); - - auto *ctx = rewriter.getContext(); - SmallVector rowMajorStrides; - ArrayRef effectiveStrides = strides; - if (effectiveStrides.empty()) { - rowMajorStrides = buildRowMajorStrides(shape); - effectiveStrides = rowMajorStrides; - } - SmallVector shape5D; - SmallVector stride5D; - buildGlobalTensorShapeAndStride(shape, effectiveStrides, shape5D, stride5D); - - std::string shapeType = "pto::Shape<" + joinIntTemplateParams(shape5D) + ">"; - std::string strideType = - "pto::Stride<" + joinIntTemplateParams(stride5D) + ">"; - auto shapeVal = rewriter - .create( - loc, emitc::OpaqueType::get(ctx, shapeType), - shapeType, ArrayAttr{}, ArrayAttr{}, ValueRange{}) - .getResult(0); - auto strideVal = rewriter - .create( - loc, emitc::OpaqueType::get(ctx, strideType), - strideType, ArrayAttr{}, ArrayAttr{}, ValueRange{}) - .getResult(0); - - std::string gtTypeStr = - getGlobalTensorTypeStringFromShapeAndStrides(elemTy, shape, - effectiveStrides, - layoutEnum); - auto gtType = emitc::OpaqueType::get(ctx, gtTypeStr); - auto gt = rewriter.create( - loc, gtType, gtTypeStr, ArrayAttr{}, ArrayAttr{}, - ValueRange{ptr, shapeVal, strideVal}); - return gt.getResult(0); -} - -static bool parseIntegerTemplateList(StringRef token, StringRef marker, - SmallVectorImpl &values) { - size_t pos = token.find(marker); - if (pos == StringRef::npos) - return false; - pos += marker.size(); - size_t end = token.find('>', pos); - if (end == StringRef::npos) - return false; - - SmallVector parts; - token.slice(pos, end).split(parts, ','); - values.clear(); - for (StringRef part : parts) { - int64_t value = 0; - if (part.trim().getAsInteger(10, value)) - return false; - values.push_back(value); - } - return true; -} - -static LogicalResult getStaticTensorViewStrides( - Value source, Value convertedSource, pto::TensorViewType sourceType, - SmallVectorImpl 
&strides) { - int64_t rank = sourceType.getRank(); - strides.clear(); - - if (auto makeView = source.getDefiningOp()) { - if ((int64_t)makeView.getStrides().size() != rank) - return failure(); - for (Value strideValue : makeView.getStrides()) { - auto cst = getStaticIndexLikeValue(strideValue); - if (!cst) - return failure(); - strides.push_back(*cst); - } - return success(); - } - - Value src = peelUnrealized(convertedSource); - if (auto opaqueTy = dyn_cast(src.getType())) { - SmallVector stride5D; - StringRef token = opaqueTy.getValue(); - if ((parseIntegerTemplateList(token, "pto::Stride<", stride5D) || - parseIntegerTemplateList(token, "Stride<", stride5D)) && - (int64_t)stride5D.size() >= rank) { - strides.append(stride5D.end() - rank, stride5D.end()); - return success(); - } - } - - auto fallback = buildRowMajorStrides(sourceType.getShape()); - strides.append(fallback.begin(), fallback.end()); - return success(); -} - -struct PTOPartitionViewToEmitC - : public OpConversionPattern { - using OpConversionPattern< - mlir::pto::PartitionViewOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::PartitionViewOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto srcTy = dyn_cast(op.getSource().getType()); - auto resTy = dyn_cast(op.getResult().getType()); - if (!srcTy || !resTy) - return rewriter.notifyMatchFailure( - op, "expected tensor_view source and partition_tensor_view result"); - - if (op.getOffsets().size() != static_cast(srcTy.getRank()) || - op.getSizes().size() != static_cast(srcTy.getRank())) - return rewriter.notifyMatchFailure(op, "rank mismatch"); - - for (auto [idx, value] : llvm::enumerate(op.getSizes())) { - auto cst = getStaticIndexLikeValue(value); - if (!cst) - return rewriter.notifyMatchFailure( - op, "globaltensor partition_view requires static sizes"); - int64_t resultDim = resTy.getShape()[idx]; - if (resultDim != ShapedType::kDynamic && resultDim != *cst) - return 
rewriter.notifyMatchFailure( - op, "partition_view static size does not match result type"); - } - - SmallVector srcStrides; - if (failed(getStaticTensorViewStrides(op.getSource(), adaptor.getSource(), - srcTy, srcStrides))) - return rewriter.notifyMatchFailure( - op, "partition_view requires static source strides"); - int64_t staticLinearOffset = 0; - SmallVector> dynamicOffsetTerms; - for (auto [idx, values] : - llvm::enumerate(llvm::zip(op.getOffsets(), adaptor.getOffsets()))) { - Value originalOffset = std::get<0>(values); - Value convertedOffset = std::get<1>(values); - int64_t stride = srcStrides[idx]; - if (stride == ShapedType::kDynamic) - return rewriter.notifyMatchFailure( - op, "dynamic source stride is not supported"); - - if (auto cst = getStaticIndexLikeValue(originalOffset)) { - if (*cst != 0) - staticLinearOffset += (*cst) * stride; - continue; - } - dynamicOffsetTerms.push_back({convertedOffset, stride}); - } - - auto *ctx = rewriter.getContext(); - std::string elemTypeStr = getElemTypeStringForGT(srcTy.getElementType()); - auto ptrTy = emitc::PointerType::get( - emitc::OpaqueType::get(ctx, "__gm__ " + elemTypeStr)); - Value src = peelUnrealized(adaptor.getSource()); - auto data = rewriter - .create( - op.getLoc(), ptrTy, "PTOAS__GLOBAL_TENSOR_DATA", - ArrayAttr{}, ArrayAttr{}, ValueRange{src}) - .getResult(0); - Value ptr = data; - if (!dynamicOffsetTerms.empty()) { - Type u32Ty = emitc::OpaqueType::get(ctx, "unsigned"); - auto makeU32 = [&](int64_t value) { - return makeEmitCIntConstant(rewriter, op.getLoc(), u32Ty, value); - }; - auto asU32 = [&](Value value) -> Value { - if (value.getType() == u32Ty) - return value; - return rewriter.create(op.getLoc(), u32Ty, value) - .getResult(); - }; - - Value totalOffset = makeU32(staticLinearOffset); - for (auto [offsetValue, stride] : dynamicOffsetTerms) { - Value term = asU32(offsetValue); - if (stride != 1) { - Value strideValue = makeU32(stride); - term = rewriter - .create(op.getLoc(), u32Ty, term, - 
strideValue) - .getResult(); - } - totalOffset = rewriter - .create(op.getLoc(), u32Ty, - totalOffset, term) - .getResult(); - } - ptr = rewriter - .create(op.getLoc(), data.getType(), data, - totalOffset) - .getResult(); - } else { - ptr = applyStaticMemrefOffset(rewriter, op.getLoc(), data, - staticLinearOffset); - } - - auto resultOr = buildGlobalTensorViewFromPointer( - rewriter, op.getLoc(), ptr, resTy.getElementType(), resTy.getShape(), - srcStrides); - if (failed(resultOr)) - return rewriter.notifyMatchFailure( - op, "failed to materialize partition GlobalTensor"); - - rewriter.replaceOp(op, *resultOr); - return success(); - } -}; - -static FailureOr getPipeDataTypeToken(Value value) { - auto opaqueTy = dyn_cast(value.getType()); - if (!opaqueTy) - return failure(); - StringRef token = opaqueTy.getValue(); - if (!token.contains("Tile<") && !token.contains("GlobalTensor<")) - return failure(); - return token.str(); -} - -struct PTOTAllocToEmitC : public OpConversionPattern { - PTOTAllocToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - PTOArch targetArch) - : OpConversionPattern(typeConverter, ctx), - targetArch(targetArch) {} - - LogicalResult matchAndRewrite(mlir::pto::TAllocOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto pipeTok = getTPipeTokenFromValue(op.getPipeHandle(), targetArch); - if (failed(pipeTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve pipe token"); - Value entry = peelUnrealized(adaptor.getEntry()); - auto entryTok = getPipeDataTypeToken(entry); - if (failed(entryTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve entry token"); - auto splitTok = getTileSplitToken(op.getSplit()); - if (failed(splitTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve split token"); - - std::string callee = - "TALLOC<" + *pipeTok + ", " + *entryTok + ", " + *splitTok + ">"; - rewriter.replaceOpWithNewOp( - op, TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, - 
ValueRange{peelUnrealized(adaptor.getPipeHandle()), entry}); - return success(); - } - - PTOArch targetArch; -}; - -struct PTOTPushToEmitC : public OpConversionPattern { - PTOTPushToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - PTOArch targetArch) - : OpConversionPattern(typeConverter, ctx), - targetArch(targetArch) {} - - LogicalResult matchAndRewrite(mlir::pto::TPushOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto pipeTok = getTPipeTokenFromValue(op.getPipeHandle(), targetArch); - if (failed(pipeTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve pipe token"); - // Read the tile type token from the already-converted OpaqueType, which - // preserves the exact layout produced by BindTileOp / PointerCastOp EmitC. - Value convertedTile = peelUnrealized(adaptor.getTile()); - auto tileTok = getPipeDataTypeToken(convertedTile); - if (failed(tileTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve tile token"); - auto splitTok = getTileSplitToken(op.getSplit()); - if (failed(splitTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve split token"); - - std::string callee = - "TPUSH<" + *pipeTok + ", " + *tileTok + ", " + *splitTok + ">"; - rewriter.replaceOpWithNewOp( - op, TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, - ValueRange{peelUnrealized(adaptor.getPipeHandle()), convertedTile}); - return success(); - } - - PTOArch targetArch; -}; - -struct PTOTPopToEmitC : public OpConversionPattern { - PTOTPopToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - PTOArch targetArch) - : OpConversionPattern(typeConverter, ctx), - targetArch(targetArch) {} - - LogicalResult matchAndRewrite(mlir::pto::TPopOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto pipeTok = getTPipeTokenFromValue(op.getPipeHandle(), targetArch); - if (failed(pipeTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve pipe token"); - Value convertedTile = 
peelUnrealized(adaptor.getTile()); - auto tileTok = getPipeDataTypeToken(convertedTile); - if (failed(tileTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve tile token"); - auto splitTok = getTileSplitToken(op.getSplit()); - if (failed(splitTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve split token"); - - std::string callee = - "TPOP<" + *pipeTok + ", " + *tileTok + ", " + *splitTok + ">"; - rewriter.replaceOpWithNewOp( - op, TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, - ValueRange{peelUnrealized(adaptor.getPipeHandle()), convertedTile}); - return success(); - } - - PTOArch targetArch; -}; - -struct PTOTFreeToEmitC : public OpConversionPattern { - PTOTFreeToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - PTOArch targetArch) - : OpConversionPattern(typeConverter, ctx), - targetArch(targetArch) {} - - LogicalResult matchAndRewrite(mlir::pto::TFreeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto pipeTok = getTPipeTokenFromValue(op.getPipeHandle(), targetArch); - if (failed(pipeTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve pipe token"); - auto splitTok = getTileSplitToken(op.getSplit()); - if (failed(splitTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve split token"); - - SmallVector operands{peelUnrealized(adaptor.getPipeHandle())}; - std::string callee; - if (op.getEntry()) { - Value entry = peelUnrealized(adaptor.getEntry()); - auto entryTok = getPipeDataTypeToken(entry); - if (failed(entryTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve entry token"); - callee = "TFREE<" + *pipeTok + ", " + *entryTok + ", " + *splitTok + ">"; - operands.push_back(entry); - } else { - callee = "TFREE<" + *pipeTok + ", " + *splitTok + ">"; - } - rewriter.replaceOpWithNewOp( - op, TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, operands); - return success(); - } - - PTOArch targetArch; -}; - 
-//===----------------------------------------------------------------------===// -// populate patterns -//===----------------------------------------------------------------------=== -struct ReinterpretCastToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - auto resMrTy = dyn_cast(op.getType()); - if (!resMrTy) - return failure(); - - auto asAttr = dyn_cast_or_null(resMrTy.getMemorySpace()); - const bool isGm = (!asAttr || asAttr.getAddressSpace() == pto::AddressSpace::GM); - - bool emitAddPtrTrace = op->hasAttr("pto.addptr_trace"); - Value source = peelUnrealized(adaptor.getSource()); - auto offsets = adaptor.getOffsets(); - Value offsetVal = offsets.empty() ? Value() : offsets[0]; - - // GM: keep pointer arithmetic. - if (isGm) { - if (!offsetVal) { - rewriter.replaceOp(op, source); - return success(); - } - - Type resultType = getTypeConverter()->convertType(op.getType()); - if (!resultType) - return failure(); - - auto addOp = rewriter.create(loc, resultType, source, offsetVal); - if (emitAddPtrTrace) { - rewriter.setInsertionPointAfter(addOp); - rewriter.create( - loc, TypeRange{}, "PTOAS__ADDPTR_TRACE", - ArrayAttr{}, ArrayAttr{}, - ValueRange{addOp.getResult(), source, offsetVal}); - } - rewriter.replaceOp(op, addOp.getResult()); - return success(); - } - - // UB/L1/L0 tiles: materialize a new Tile view by assigning an adjusted - // underlying pointer (in elements). - pto::AddressSpace as = asAttr.getAddressSpace(); - - // Element type token. - Type elemTy = resMrTy.getElementType(); - std::string elemTok = getEmitCScalarTypeToken(elemTy); - int64_t elemBytes = getEmitCScalarByteWidth(elemTy); - - // Tile role. 
- const char *roleTok = "TileType::Vec"; - switch (as) { - case pto::AddressSpace::VEC: - roleTok = "TileType::Vec"; - break; - case pto::AddressSpace::MAT: - roleTok = "TileType::Mat"; - break; - case pto::AddressSpace::LEFT: - roleTok = "TileType::Left"; - break; - case pto::AddressSpace::RIGHT: - roleTok = "TileType::Right"; - break; - case pto::AddressSpace::ACC: - roleTok = "TileType::Acc"; - break; - case pto::AddressSpace::BIAS: - roleTok = "TileType::Bias"; - break; - case pto::AddressSpace::GM: - roleTok = "TileType::Vec"; - break; - case pto::AddressSpace::Zero: - roleTok = "TileType::Vec"; - break; - case pto::AddressSpace::SCALING: - roleTok = "TileType::Scaling"; - break; - } - - // Shape (fallback to 32x32). - int64_t rows = 32, cols = 32; - if (resMrTy.getRank() >= 2 && resMrTy.hasStaticShape()) { - rows = resMrTy.getDimSize(0); - cols = resMrTy.getDimSize(1); - } - int64_t templateRows = - renderTileTemplateDim(rows, elemTy, pto::BLayout::RowMajor, 0); - int64_t templateCols = - renderTileTemplateDim(cols, elemTy, pto::BLayout::RowMajor, 1); - - // Keep a conservative default config for now. - std::string tileTypeStr = - std::string("Tile<") + roleTok + ", " + elemTok + ", " + - std::to_string(templateRows) + ", " + std::to_string(templateCols) + - ", BLayout::RowMajor, " + std::to_string(templateRows) + ", " + - std::to_string(templateCols) + - ", SLayout::NoneBox, 512, PadValue::Null, CompactMode::Null>"; - - auto tileType = emitc::OpaqueType::get(ctx, tileTypeStr); - Value tile = rewriter - .create(loc, tileType, - emitc::OpaqueAttr::get(ctx, "")) - .getResult(); - - // Compute an integer address and assign it to the new tile. - // NOTE: pto-isa TASSIGN requires an integral address (not a pointer). - auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); - auto rcU64 = rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); - - // Non-GM reinterpret_cast operands come from UB/L1/L0 tiles. 
- // We need the underlying address, but `__cce_get_tile_ptr()` is only valid - // inside `__tf__` functions. Use `tile.data()` (via a post-processed marker) - // and compute the adjusted address in bytes. - Value rawPtr = source; - if (auto ot = dyn_cast(source.getType())) { - // Only Tiles have a `.data()` member. For plain address-space pointers - // (e.g. `__ubuf__ float*`), use the pointer value directly. - if (ot.getValue().starts_with("Tile<")) { - rawPtr = materializeTileDataValue(rewriter, loc, source, as, elemTok); - } - } - - Value baseAddr = rawPtr; - if (isSetFFTsPointerLikeType(rawPtr.getType())) { - baseAddr = rewriter - .create(loc, u64Ty, "reinterpret_cast", - /*args=*/ArrayAttr{}, - /*templateArgs=*/rcU64, - /*operands=*/ValueRange{rawPtr}) - .getResult(0); - } else if (rawPtr.getType() != u64Ty) { - baseAddr = rewriter.create(loc, u64Ty, rawPtr).getResult(); - } - - Value addr = baseAddr; - if (offsetVal) { - Value offU64 = offsetVal; - if (offU64.getType() != u64Ty) - offU64 = rewriter.create(loc, u64Ty, offU64).getResult(); - - auto bytesAttr = emitc::OpaqueAttr::get(ctx, std::to_string(elemBytes)); - Value bytesVal = rewriter.create(loc, u64Ty, bytesAttr); - Value byteOff = rewriter.create(loc, u64Ty, offU64, bytesVal); - addr = rewriter.create(loc, u64Ty, baseAddr, byteOff); - } - - rewriter.create(loc, TypeRange{}, "TASSIGN", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{tile, addr}); - - rewriter.replaceOp(op, tile); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.taddc lowering -> TADDC(dst, src0, src1, src2) -//===----------------------------------------------------------------------===// - -struct PTOTAddCToTADDC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAddCOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = 
op.getLoc(); - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value src2 = peelUnrealized(adaptor.getSrc2()); - Value dst = peelUnrealized(adaptor.getDst()); - - // pto-isa does not provide NPU implementation for TADDC yet. - // Decompose: dst = src0 + src1 + src2 - rewriter.create( - loc, TypeRange{}, "TADD", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src0, src1}); - rewriter.create( - loc, TypeRange{}, "TADD", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, dst, src2}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tadds lowering -> TADDS(dst, src, scalar) -//===----------------------------------------------------------------------===// - -struct PTOAddSToTADDS : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAddSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value scalar = peelUnrealized(adaptor.getScalar()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TADDS", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src, scalar}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.taddsc lowering -> TADDSC(dst, src0, scalar, src1) -//===----------------------------------------------------------------------===// - -struct PTOAddSCToTADDSC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAddSCOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value 
dst = peelUnrealized(adaptor.getDst()); - - // pto-isa does not provide NPU implementation for TADDSC yet. - // Decompose: dst = src0 + scalar + src1 - rewriter.create( - loc, TypeRange{}, "TADDS", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src0, scalar}); - rewriter.create( - loc, TypeRange{}, "TADD", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, dst, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; -struct PTOTAndToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAndOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value a = peelUnrealized(adaptor.getSrc0()); - Value b = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TAND", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, a, b}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOConcatToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TConcatOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TCONCAT", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src0, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; -struct PTOConcatidxToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TConcatidxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value src0Idx = peelUnrealized(adaptor.getSrc0Idx()); - Value src1Idx = peelUnrealized(adaptor.getSrc1Idx()); - Value dst = 
peelUnrealized(adaptor.getDst()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TCONCAT", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src0, src1, src0Idx, src1Idx}); - - rewriter.eraseOp(op); - return success(); - } -}; -struct PTOAndSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAndSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TANDS", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src, scalar}); - - rewriter.eraseOp(op); - return success(); - } -}; - - -struct PTOTCIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TCIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value S = peelUnrealized(adaptor.getOperands()[0]); - - // The TCI scalar template parameter should follow the original PTO IR - // scalar type, not the converted EmitC value type. - std::string scalarTok = "int32_t"; - if (auto it = dyn_cast(op->getOperand(0).getType())) { - bool isUnsigned = it.isUnsigned(); - if (it.getWidth() == 16) - scalarTok = isUnsigned ? "uint16_t" : "int16_t"; - else - scalarTok = isUnsigned ? "uint32_t" : "int32_t"; - } - - // descending -> "0"/"1" - std::string descTok = op.getDescending() ? 
"1" : "0"; - - ArrayAttr targs; - if (auto ot = mlir::dyn_cast(dst.getType())) { - std::string tileTok = ot.getValue().str(); // "Tile<...>" - targs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, tileTok), - emitc::OpaqueAttr::get(ctx, scalarTok), - emitc::OpaqueAttr::get(ctx, descTok), - }); - } else { - targs = rewriter.getArrayAttr({}); - } - - rewriter.create( - loc, TypeRange{}, "TCI", - /*args=*/ArrayAttr{}, - /*templateArgs=*/targs, - /*operands=*/ValueRange{dst, S}); - - rewriter.eraseOp(op); - return success(); - } -}; -static std::string cmpModeTok(pto::CmpModeAttr a) { - // 生成 "CmpMode::GT" 这种 token - auto m = a.getValue(); // 取 enum - switch (m) { - case pto::CmpMode::EQ: return "CmpMode::EQ"; - case pto::CmpMode::NE: return "CmpMode::NE"; - case pto::CmpMode::LT: return "CmpMode::LT"; - case pto::CmpMode::LE: return "CmpMode::LE"; - case pto::CmpMode::GT: return "CmpMode::GT"; - case pto::CmpMode::GE: return "CmpMode::GE"; - } - return "CmpMode::EQ"; -} -struct PTOColExpandToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColExpandOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value src = peelUnrealized(adaptor.getSrc()); - - rewriter.create( - loc, TypeRange{}, "TCOLEXPAND", - /*args=*/ArrayAttr(), - /*templateArgs=*/ArrayAttr(), - /*operands=*/ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColExpandMulToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColExpandMulOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - 
rewriter.create( - loc, TypeRange{}, "TCOLEXPANDMUL", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColExpandAddToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColExpandAddOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLEXPANDADD", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColExpandDivToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColExpandDivOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLEXPANDDIV", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColExpandExpdifToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColExpandExpdifOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLEXPANDEXPDIF", - /*args=*/ArrayAttr{}, 
- /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColExpandSubToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColExpandSubOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLEXPANDSUB", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColExpandMaxToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColExpandMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLEXPANDMAX", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColExpandMinToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColExpandMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLEXPANDMIN", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1}); - - 
rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOTTriToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TTriOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value diagonal = peelUnrealized(adaptor.getDiagonal()); - - ArrayAttr templateArgs; - if (auto dstOT = mlir::dyn_cast(dst.getType())) { - templateArgs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), - emitc::OpaqueAttr::get(ctx, std::to_string(op.getUpperOrLower())), - }); - } else { - templateArgs = ArrayAttr{}; - } - - SmallVector operands{dst, diagonal}; - rewriter.create( - loc, TypeRange{}, "TTRI", - /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOCmpToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TCmpOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - - std::string tok = "CmpMode::EQ"; - if (auto a = op.getCmpModeAttr()) - tok = cmpModeTok(a); - - auto modeTy = emitc::OpaqueType::get(ctx, "CmpMode"); - Value modeVal = rewriter.create( - loc, modeTy, emitc::OpaqueAttr::get(ctx, tok)); - - rewriter.create( - loc, - TypeRange{}, - "TCMP", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1, modeVal}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOCmpSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult 
matchAndRewrite(pto::TCmpSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - - // cmpMode -> token - auto cmpAttr = op.getCmpModeAttr(); // PTO_CmpModeAttr - std::string tok = cmpModeTok(cmpAttr); - - auto modeTy = emitc::OpaqueType::get(ctx, "CmpMode"); - Value modeVal = rewriter.create( - loc, modeTy, emitc::OpaqueAttr::get(ctx, tok)); - - rewriter.create( - loc, - TypeRange{}, - "TCMPS", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, scalar, modeVal}); - - rewriter.eraseOp(op); - return success(); - } -}; - - -struct PTOColMaxToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - // intrinsic: TCOLMAX(dst, src) - rewriter.create( - loc, TypeRange{}, "TCOLMAX", - /*args=*/ArrayAttr{}, // default: print all operands - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColArgMaxToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColArgMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLARGMAX", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, 
tmp}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColMinToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - // intrinsic: TCOLMIN(dst, src) - rewriter.create( - loc, TypeRange{}, "TCOLMIN", - /*args=*/ArrayAttr{}, // default: print all operands - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColArgMinToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColArgMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLARGMIN", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, tmp}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColSumToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColSumOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - // Check if tmp exists before accessing it - if (op.getTmp()) { - // Format 2: with tmp and isBinary - Value tmp = peelUnrealized(adaptor.getTmp()); - bool isBinary = false; - if (auto a = op.getIsBinaryAttr()) - isBinary = a.getValue(); - - auto boolTy = emitc::OpaqueType::get(ctx, 
"bool"); - auto tok = isBinary ? "true" : "false"; - Value isBinaryVal = rewriter.create( - loc, boolTy, emitc::OpaqueAttr::get(ctx, tok)); - - rewriter.create( - loc, TypeRange{}, "TCOLSUM", - /*args=*/ArrayAttr(), - /*templateArgs=*/ArrayAttr(), - /*operands=*/ValueRange{dst, src, tmp, isBinaryVal}); - } else { - // Format 1: without tmp and isBinary - rewriter.create( - loc, TypeRange{}, "TCOLSUM", - /*args=*/ArrayAttr(), - /*templateArgs=*/ArrayAttr(), - /*operands=*/ValueRange{dst, src}); - } - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColProdToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColProdOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLPROD", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; -static std::string roundModeTok(mlir::pto::RoundModeAttr attr) { - using RM = mlir::pto::RoundMode; - switch (attr.getValue()) { - case RM::NONE: return "RoundMode::CAST_NONE"; - case RM::RINT: return "RoundMode::CAST_RINT"; - case RM::ROUND: return "RoundMode::CAST_ROUND"; - case RM::FLOOR: return "RoundMode::CAST_FLOOR"; - case RM::CEIL: return "RoundMode::CAST_CEIL"; - case RM::TRUNC: return "RoundMode::CAST_TRUNC"; - case RM::ODD: return "RoundMode::CAST_ODD"; - case RM::CAST_RINT: return "RoundMode::CAST_RINT"; - } - return "RoundMode::CAST_RINT"; -} -static std::string saturationModeTok(mlir::pto::SaturationModeAttr attr) { - using SM = mlir::pto::SaturationMode; - switch (attr.getValue()) { - case SM::ON: return "SaturationMode::ON"; - case SM::OFF: return "SaturationMode::OFF"; - } - return "SaturationMode::OFF"; -} -struct PTOCvtToEmitC : public 
OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TCvtOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - pto::RoundModeAttr rmAttr = op.getRmodeAttr(); - std::string rmTok = rmAttr ? roundModeTok(rmAttr) - : std::string("RoundMode::CAST_RINT"); - auto rmodeTy = emitc::OpaqueType::get(ctx, "RoundMode"); - Value rmodeVal = rewriter.create( - loc, rmodeTy, emitc::OpaqueAttr::get(ctx, rmTok)); - - auto satModeTy = emitc::OpaqueType::get(ctx, "SaturationMode"); - auto satAttr = op.getSatModeAttr(); - std::string satTok = satAttr ? saturationModeTok(satAttr) - : std::string("SaturationMode::OFF"); - Value satModeVal = rewriter.create( - loc, satModeTy, emitc::OpaqueAttr::get(ctx, satTok)); - - SmallVector operands{dst, src, rmodeVal, satModeVal}; - - rewriter.create( - loc, TypeRange{}, "TCVT", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -struct PTORandomToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRandomOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - SmallVector operands{ - dst, - peelUnrealized(adaptor.getKey0()), - peelUnrealized(adaptor.getKey1()), - peelUnrealized(adaptor.getCounter0()), - peelUnrealized(adaptor.getCounter1()), - peelUnrealized(adaptor.getCounter2()), - peelUnrealized(adaptor.getCounter3()), - }; - ArrayAttr templateArgs = rewriter.getArrayAttr( - {emitc::OpaqueAttr::get(ctx, std::to_string(op.getRounds()))}); - - rewriter.create( - loc, TypeRange{}, "PTOAS__TRANDOM", - 
/*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, operands); - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tdiv lowering -> TDIV(dst, src0, src1) -//===----------------------------------------------------------------------===// - -struct PTODivToTDIV : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TDivOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TDIV", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src0, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tdivs lowering -> TDIVS(dst, src, scalar) or TDIVS(dst, scalar, src) -// Order is determined by operand types: if src is tile_buf, order is (tile, scalar) -// Otherwise, order is (scalar, tile) -//===----------------------------------------------------------------------===// - -struct PTODivSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TDivSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value dst = peelUnrealized(adaptor.getDst()); - // Preserve source order from textual parse: - // ins(tile, scalar) -> TDIVS(dst, tile, scalar) - // ins(scalar, tile) -> TDIVS(dst, scalar, tile) - rewriter.create( - loc, TypeRange{}, "TDIVS", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src, scalar}); - - rewriter.eraseOp(op); - return success(); - } -}; - 
-//===----------------------------------------------------------------------===// -// pto.tdivs (TDivSOp) lowering -> TDIVS(dst, src, scalar) or TDIVS(dst, scalar, src) -// Order is determined by operand types: if src is tile_buf, order is (tile, scalar) -// Otherwise, order is (scalar, tile) -//===----------------------------------------------------------------------===// - -struct PTOTDivSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TDivSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value dst = peelUnrealized(adaptor.getDst()); - rewriter.create( - loc, TypeRange{}, "TDIVS", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src, scalar}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.texp lowering -> TEXP(dst, src) -//===----------------------------------------------------------------------===// - -struct PTOExpToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TExpOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TEXP", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.texpands lowering -> TEXPANDS(dst, scalar) -//===----------------------------------------------------------------------===// - -struct PTOExpandsToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - 
LogicalResult matchAndRewrite(pto::TExpandsOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value scalar = peelUnrealized(adaptor.getScalar()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TEXPANDS", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, scalar}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.textract lowering -> TEXTRACT(dst, src, indexRow, indexCol) -//===----------------------------------------------------------------------===// - -struct PTOExtractToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TExtractOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value r0 = peelUnrealized(adaptor.getIndexRow()); - Value c0 = peelUnrealized(adaptor.getIndexCol()); - - rewriter.create( - loc, TypeRange{}, "TEXTRACT", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, r0, c0}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.textract_fp lowering -> TEXTRACT_FP(dst, src, fp, indexRow, indexCol) -//===----------------------------------------------------------------------===// - -struct PTOExtractFPToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TExtractFPOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value fp = peelUnrealized(adaptor.getFp()); - Value dst = peelUnrealized(adaptor.getDst()); - Value r0 = 
peelUnrealized(adaptor.getIndexRow()); - Value c0 = peelUnrealized(adaptor.getIndexCol()); - - rewriter.create( - loc, TypeRange{}, "TEXTRACT_FP", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, fp, r0, c0}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tinsert lowering -> TINSERT(dst, src, indexRow, indexCol) -// Keep lowering arch-agnostic and let PTO-ISA infer proper A5 path. -//===----------------------------------------------------------------------===// - -struct PTOInsertToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TInsertOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value r0 = peelUnrealized(adaptor.getIndexRow()); - Value c0 = peelUnrealized(adaptor.getIndexCol()); - - rewriter.create( - loc, TypeRange{}, "TINSERT", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, r0, c0}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tinsert_fp lowering -> TINSERT_FP(dst, src, fp, indexRow, indexCol) -//===----------------------------------------------------------------------===// - -struct PTOInsertFPToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TInsertFPOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value fp = peelUnrealized(adaptor.getFp()); - Value dst = peelUnrealized(adaptor.getDst()); - Value r0 = peelUnrealized(adaptor.getIndexRow()); - Value c0 = 
peelUnrealized(adaptor.getIndexCol()); - - rewriter.create( - loc, TypeRange{}, "TINSERT_FP", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, fp, r0, c0}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tfillpad lowering -> TFILLPAD(dst, src) -//===----------------------------------------------------------------------===// - -struct PTOFillPadToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TFillPadOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TFILLPAD", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tfillpad_inplace lowering -> TFILLPAD_INPLACE(dst, src) -//===----------------------------------------------------------------------===// - -struct PTOFillPadInplaceToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TFillPadInplaceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TFILLPAD_INPLACE", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tfillpad_expand lowering -> TFILLPAD_EXPAND(dst, src) 
-//===----------------------------------------------------------------------===// - -struct PTOFillPadExpandToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TFillPadExpandOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TFILLPAD_EXPAND", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tgather lowering -// - Index form : TGATHER(dst, src0, indices, tmp) -// - Compare form: TGATHER(dst, src0, kValue, cdst, tmp) -// - Mask form : TGATHER(dst, src0) -//===----------------------------------------------------------------------===// - -[[maybe_unused]] static std::string maskPatternTok(mlir::pto::MaskPatternAttr a) { - - auto v = a.getValue(); // enum - return (std::string("pto::MaskPattern::") + mlir::pto::stringifyMaskPattern(v).str()); -} - -struct PTOGatherToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGatherOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value src0 = peelUnrealized(adaptor.getSrc()); - - auto getOpaqueTok = [&](Value v, StringRef name) -> FailureOr { - if (auto ot = mlir::dyn_cast(v.getType())) - return ot.getValue().str(); - return rewriter.notifyMatchFailure(op, (name + " must be emitc::OpaqueType (tile)").str()); - }; - - // Case 1: index-based TGATHER(dst, src0, indices, tmp) - if (Value idx = adaptor.getIndices()) { - idx = peelUnrealized(idx); - Value tmp = 
peelUnrealized(adaptor.getTmp()); - - rewriter.create( - loc, TypeRange{}, "TGATHER", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, idx, tmp}); - - rewriter.eraseOp(op); - return success(); - } - - // Case 2: compare-based TGATHER( - // dst, src0, kValue, tmp, cdst, offset) - if (Value cdst = adaptor.getCdst()) { - cdst = peelUnrealized(cdst); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value kValue = peelUnrealized(adaptor.getKValue()); - - auto dstTokOr = getOpaqueTok(dst, "dst"); - auto srcTokOr = getOpaqueTok(src0, "src0"); - auto cdstTokOr = getOpaqueTok(cdst, "cdst"); - auto tmpTokOr = getOpaqueTok(tmp, "tmp"); - if (failed(dstTokOr) || failed(srcTokOr) || failed(cdstTokOr) || failed(tmpTokOr)) - return failure(); - - auto cmpAttr = op.getCmpModeAttr(); - std::string cmpTok = cmpAttr ? cmpModeTok(cmpAttr) : "CmpMode::EQ"; - int64_t offset = 0; - if (auto offsetAttr = op.getOffsetAttr()) - offset = offsetAttr.getInt(); - auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); - Value offsetVal = makeEmitCIntConstant(rewriter, loc, i32Ty, offset); - - auto targs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, *dstTokOr), - emitc::OpaqueAttr::get(ctx, *srcTokOr), - emitc::OpaqueAttr::get(ctx, *tmpTokOr), - emitc::OpaqueAttr::get(ctx, *cdstTokOr), - emitc::OpaqueAttr::get(ctx, cmpTok), - }); - - rewriter.create( - loc, TypeRange{}, "TGATHER", - /*args=*/ArrayAttr{}, /*templateArgs=*/targs, - /*operands=*/ValueRange{dst, src0, kValue, tmp, cdst, offsetVal}); - - rewriter.eraseOp(op); - return success(); - } - - // Case 3: mask-pattern TGATHER(dst, src0) - auto mp = op.getMaskPatternAttr(); - if (!mp) - return rewriter.notifyMatchFailure(op, "expected maskPattern, indices, or cdst on tgather"); - - auto dstTokOr = getOpaqueTok(dst, "dst"); - auto srcTokOr = getOpaqueTok(src0, "src0"); - if (failed(dstTokOr) || failed(srcTokOr)) - return failure(); - - // mp is an EnumAttr; stringify name is "P0101" etc. 
- // We emit MaskPattern::P0101 (because generated C++ has `using namespace pto;`) - std::string mpTok = std::string("MaskPattern::") + - mlir::pto::stringifyMaskPattern(mp.getValue()).str(); - - auto targs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, *dstTokOr), - emitc::OpaqueAttr::get(ctx, *srcTokOr), - emitc::OpaqueAttr::get(ctx, mpTok), - }); - - rewriter.create( - loc, TypeRange{}, "TGATHER", - /*args=*/ArrayAttr{}, - /*templateArgs=*/targs, - /*operands=*/ValueRange{dst, src0}); - - rewriter.eraseOp(op); - return success(); - } -}; - - -struct PTOGatherbToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGatherBOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value offsets = peelUnrealized(adaptor.getOffsets()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TGATHERB", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, offsets}); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// TLOG lowering to EmitC (PTOConvert.cpp) -//===----------------------------------------------------------------------===// - -struct PTOLogToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TLogOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src}; - rewriter.create( - loc, TypeRange{}, "TLOG", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - - - 
-//===----------------------------------------------------------------------===// -// TLRELU lowering to EmitC (PTOConvert.cpp) -//===----------------------------------------------------------------------===// - - struct PTOLReluToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TLReluOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value slope = peelUnrealized(adaptor.getSlope()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, slope}; - - rewriter.create( - loc, TypeRange{}, "TLRELU", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// TMAX lowering to EmitC (PTOConvert.cpp) -//===----------------------------------------------------------------------===// - -struct PTOMaxToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TMAX", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// TMAXS lowering to EmitC (PTOConvert.cpp) -//===----------------------------------------------------------------------===// - - struct PTOMaxSToEmitC : public OpConversionPattern { - using 
OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMaxSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, scalar}; - rewriter.create( - loc, TypeRange{}, "TMAXS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - - -//===----------------------------------------------------------------------===// -// TMIN lowering to EmitC (PTOConvert.cpp) -//===----------------------------------------------------------------------===// - -struct PTOMinToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TMIN", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// TMINS lowering to EmitC (PTOConvert.cpp) -//===----------------------------------------------------------------------===// - -//===----------------------------------------------------------------------===// -// TMINS lowering to EmitC (fix APFloat -> FloatAttr) (PTOConvert.cpp) -//===----------------------------------------------------------------------===// - -struct PTOMinsToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult 
matchAndRewrite(pto::TMinSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value scalar = peelUnrealized(adaptor.getScalar()); - - SmallVector operands{dst, src, scalar}; - rewriter.create( - loc, TypeRange{}, "TMINS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering for TMOV op -> EmitC) -//===----------------------------------------------------------------------===// - -struct PTOMovToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMovOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value fp; - if (op.getFp()) - fp = peelUnrealized(adaptor.getFp()); - Value preQuantScalar; - if (op.getPreQuantScalar()) - preQuantScalar = peelUnrealized(adaptor.getPreQuantScalar()); - - auto dstOT = mlir::dyn_cast(dst.getType()); - auto srcOT = mlir::dyn_cast(src.getType()); - if (!dstOT || !srcOT) - return rewriter.notifyMatchFailure( - op, "tmov lowering expects opaque dst/src types"); - - auto modeTok = [&](pto::AccToVecMode mode) -> StringRef { - switch (mode) { - case pto::AccToVecMode::SingleModeVec0: - return "pto::AccToVecMode::SingleModeVec0"; - case pto::AccToVecMode::SingleModeVec1: - return "pto::AccToVecMode::SingleModeVec1"; - case pto::AccToVecMode::DualModeSplitM: - return "pto::AccToVecMode::DualModeSplitM"; - case pto::AccToVecMode::DualModeSplitN: - return "pto::AccToVecMode::DualModeSplitN"; - } - llvm_unreachable("unknown AccToVecMode"); 
- }; - - auto modeAttr = op.getAccToVecModeAttr(); - auto reluTok = [&](pto::ReluPreMode mode) -> StringRef { - switch (mode) { - case pto::ReluPreMode::NoRelu: - return "ReluPreMode::NoRelu"; - case pto::ReluPreMode::NormalRelu: - return "ReluPreMode::NormalRelu"; - } - llvm_unreachable("unknown ReluPreMode"); - }; - - const bool hasFp = static_cast(fp); - const bool hasPreQuantScalar = static_cast(preQuantScalar); - const bool hasMode = static_cast(modeAttr); - const bool reluNonDefault = op.getReluPreMode() != pto::ReluPreMode::NoRelu; - - SmallVector operands{dst, src}; - SmallVector templateArgVec{ - emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), - emitc::OpaqueAttr::get(ctx, srcOT.getValue().str()), - }; - StringRef callee = "TMOV"; - - if (hasFp) { - auto fpOT = mlir::dyn_cast(fp.getType()); - if (!fpOT) - return rewriter.notifyMatchFailure( - op, "tmov fp lowering expects opaque fp type"); - operands.push_back(fp); - templateArgVec.push_back(emitc::OpaqueAttr::get(ctx, fpOT.getValue().str())); - if (hasMode) - templateArgVec.push_back( - emitc::OpaqueAttr::get(ctx, modeTok(modeAttr.getValue()))); - if (hasMode || reluNonDefault) - templateArgVec.push_back( - emitc::OpaqueAttr::get(ctx, reluTok(op.getReluPreMode()))); - callee = hasMode ? 
"TMOV" : "TMOV_FP"; - } else if (hasPreQuantScalar) { - operands.push_back(preQuantScalar); - if (hasMode) - templateArgVec.push_back( - emitc::OpaqueAttr::get(ctx, modeTok(modeAttr.getValue()))); - if (hasMode || reluNonDefault) - templateArgVec.push_back( - emitc::OpaqueAttr::get(ctx, reluTok(op.getReluPreMode()))); - } else if (hasMode) { - templateArgVec.push_back( - emitc::OpaqueAttr::get(ctx, modeTok(modeAttr.getValue()))); - templateArgVec.push_back( - emitc::OpaqueAttr::get(ctx, reluTok(op.getReluPreMode()))); - } else if (reluNonDefault) { - templateArgVec.push_back( - emitc::OpaqueAttr::get(ctx, reluTok(op.getReluPreMode()))); - } - - ArrayAttr templateArgs = - templateArgVec.size() == 2 && !hasFp && !hasPreQuantScalar && - !hasMode && !reluNonDefault - ? ArrayAttr{} - : rewriter.getArrayAttr(templateArgVec); - - rewriter.create( - loc, TypeRange{}, callee, - /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TMOV_FP DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOMovFPToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMovFPOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value src = peelUnrealized(adaptor.getSrc()); - Value fp = peelUnrealized(adaptor.getFp()); - - // TMOV_FP(dstTileData, cTile, fbTile) - ArrayAttr templateArgs; - auto dstOT = mlir::dyn_cast(dst.getType()); - auto srcOT = mlir::dyn_cast(src.getType()); - auto fpOT = mlir::dyn_cast(fp.getType()); - if (dstOT && srcOT && fpOT) { - templateArgs = rewriter.getArrayAttr({ - 
emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), - emitc::OpaqueAttr::get(ctx, srcOT.getValue().str()), - emitc::OpaqueAttr::get(ctx, fpOT.getValue().str()), - }); - } else { - templateArgs = ArrayAttr{}; - } - - SmallVector operands{dst, src, fp}; - rewriter.create( - loc, TypeRange{}, "TMOV_FP", - /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOQuantToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TQuantOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value src = peelUnrealized(adaptor.getSrc()); - Value fp = peelUnrealized(adaptor.getFp()); - - // Optional offset (INT8_ASYM only): passed as pointer (&offset) - Value offsetPtr; - if (op.getOffset()) { - Value offset = peelUnrealized(adaptor.getOffset()); - auto offsetOT = mlir::dyn_cast(offset.getType()); - if (offsetOT) { - offsetPtr = rewriter - .create( - loc, emitc::PointerType::get(offsetOT), "&", offset) - .getResult(); - } - } - - // TQUANT(dst, src, fp[, &offset]) - std::string quantTypeStr = - op.getQuantType() == pto::QuantType::INT8_SYM - ? 
"pto::QuantType::INT8_SYM" - : "pto::QuantType::INT8_ASYM"; - ArrayAttr templateArgs; - auto dstOT = mlir::dyn_cast(dst.getType()); - auto srcOT = mlir::dyn_cast(src.getType()); - auto fpOT = mlir::dyn_cast(fp.getType()); - if (dstOT && srcOT && fpOT) { - templateArgs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, quantTypeStr), - emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), - emitc::OpaqueAttr::get(ctx, srcOT.getValue().str()), - emitc::OpaqueAttr::get(ctx, fpOT.getValue().str()), - }); - } else { - templateArgs = ArrayAttr{}; - } - - SmallVector operands{dst, src, fp}; - if (offsetPtr) - operands.push_back(offsetPtr); - - rewriter.create( - loc, TypeRange{}, "TQUANT", - /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -struct PTODequantToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TDequantOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value src = peelUnrealized(adaptor.getSrc()); - Value scale = peelUnrealized(adaptor.getScale()); - Value offset = peelUnrealized(adaptor.getOffset()); - - // TDEQUANT(dst, src, scale, offset) - ArrayAttr templateArgs; - auto dstOT = mlir::dyn_cast(dst.getType()); - auto srcOT = mlir::dyn_cast(src.getType()); - auto scaleOT = mlir::dyn_cast(scale.getType()); - if (dstOT && srcOT && scaleOT) { - templateArgs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), - emitc::OpaqueAttr::get(ctx, srcOT.getValue().str()), - emitc::OpaqueAttr::get(ctx, scaleOT.getValue().str()), - }); - } else { - templateArgs = ArrayAttr{}; - } - - rewriter.create( - loc, TypeRange{}, "TDEQUANT", - /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, - /*operands=*/SmallVector{dst, src, scale, offset}); 
- - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TMRGSORT DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOMrgSortToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMrgSortOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - if (op.isFormat1()) { - Value src = peelUnrealized(adaptor.getSrcs().front()); - Value dst = peelUnrealized(adaptor.getDsts().front()); - Value blockLen = peelUnrealized(adaptor.getBlockLen()); - - SmallVector operands{dst, src, blockLen}; - rewriter.create( - loc, TypeRange{}, "TMRGSORT", - ArrayAttr{}, ArrayAttr{}, operands); - } else if (op.isFormat2()) { - // pto-isa API: - // TMRGSORT( - // dst, executedNumList, tmp, src0, src1[, src2[, src3]]); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDsts()[0]); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value excuted = peelUnrealized(adaptor.getExcuted()); - - SmallVector srcs; - srcs.reserve(adaptor.getSrcs().size()); - for (Value v : adaptor.getSrcs()) - srcs.push_back(peelUnrealized(v)); - - auto dstOT = mlir::dyn_cast(dst.getType()); - auto tmpOT = mlir::dyn_cast(tmp.getType()); - if (!dstOT || !tmpOT || srcs.size() < 2 || srcs.size() > 4) - return op.emitOpError("format2 expects dst/tmp tilebufs and 2 to 4 srcs"); - - SmallVector targs; - targs.reserve(2 + srcs.size() + 1); - targs.push_back(emitc::OpaqueAttr::get(ctx, dstOT.getValue().str())); - targs.push_back(emitc::OpaqueAttr::get(ctx, tmpOT.getValue().str())); - for (Value v : srcs) { - auto ot = mlir::dyn_cast(v.getType()); - if (!ot) - return op.emitOpError("format2 expects tilebuf srcs"); - targs.push_back(emitc::OpaqueAttr::get(ctx, ot.getValue().str())); - } - 
targs.push_back(emitc::OpaqueAttr::get(ctx, op.getExhausted() ? "true" : "false")); - ArrayAttr templateArgs = rewriter.getArrayAttr(targs); - - SmallVector operands{dst, excuted, tmp}; - operands.append(srcs.begin(), srcs.end()); - - rewriter.create( - loc, TypeRange{}, "TMRGSORT", - /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, operands); - } else { - return op.emitOpError("unsupported mrgsort_dps format"); - } - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TMUL DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOMulToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMulOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TMUL", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TMULS DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOMulsToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMulSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc0()); - Value dst = peelUnrealized(adaptor.getDst()); - Value scalar = peelUnrealized(adaptor.getScalar()); - - 
SmallVector operands{dst, src, scalar}; - rewriter.create( - loc, TypeRange{}, "TMULS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TNEG DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTONegToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TNegOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src}; - rewriter.create( - loc, TypeRange{}, "TNEG", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TNOT DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTONotToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TNotOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src}; - rewriter.create( - loc, TypeRange{}, "TNOT", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TOR DPS/memref op) 
-//===----------------------------------------------------------------------===// - -struct PTOOrToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TOrOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TOR", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TORS DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOOrsToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TOrSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - // NOTE: The conversion type system may materialize integers as emitc.opaque - // (e.g. "int32_t"). For EmitC call emission we can pass the scalar through - // directly without arith casts here. 
- Value s = adaptor.getScalar(); - - SmallVector operands{dst, src0, s}; - rewriter.create( - loc, TypeRange{}, "TORS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TPARTADD DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOPartAddToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPartAddOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TPARTADD", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TPARTMAX DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOPartMaxToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPartMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TPARTMAX", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return 
success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TPARTMIN DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOPartMinToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPartMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TPARTMIN", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOPartArgMaxToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPartArgMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value src0Idx = peelUnrealized(adaptor.getSrc0Idx()); - Value src1Idx = peelUnrealized(adaptor.getSrc1Idx()); - Value dst = peelUnrealized(adaptor.getDst()); - Value dstIdx = peelUnrealized(adaptor.getDstIdx()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TPARTARGMAX", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1, dstIdx, src0Idx, src1Idx}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOPartArgMinToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPartArgMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src0 = 
peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value src0Idx = peelUnrealized(adaptor.getSrc0Idx()); - Value src1Idx = peelUnrealized(adaptor.getSrc1Idx()); - Value dst = peelUnrealized(adaptor.getDst()); - Value dstIdx = peelUnrealized(adaptor.getDstIdx()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TPARTARGMIN", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1, dstIdx, src0Idx, src1Idx}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TPARTMUL DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOPartMulToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPartMulOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TPARTMUL", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TPRELU DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOPreluToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPReluOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = 
peelUnrealized(adaptor.getSrc1()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - // C++ interface: TPRELU(dst, src0, src1, tmp) — last parameter is tmp. - SmallVector operands{dst, src0, src1, tmp}; - rewriter.create( - loc, TypeRange{}, "TPRELU", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TRECIP DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORecipToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRecipOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src}; - rewriter.create( - loc, TypeRange{}, "TRECIP", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TRELU DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOReluToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TReluOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src}; - rewriter.create( - loc, TypeRange{}, "TRELU", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - 
/*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TREM DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORemToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRemOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - SmallVector operands{dst, src0, src1, tmp}; - rewriter.create( - loc, TypeRange{}, "TREM", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOFModToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TFModOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TFMOD", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TREMS DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORemSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult 
matchAndRewrite(pto::TRemSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - Value scalar = peelUnrealized(adaptor.getScalar()); - SmallVector operands{dst, src, scalar, tmp}; - rewriter.create( - loc, TypeRange{}, "TREMS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOFModSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TFModSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value scalar = peelUnrealized(adaptor.getScalar()); - - SmallVector operands{dst, src, scalar}; - rewriter.create( - loc, TypeRange{}, "TFMODS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TROWEXPAND DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORowExpandToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowExpandOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src}; - rewriter.create( - loc, TypeRange{}, "TROWEXPAND", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - 
rewriter.eraseOp(op); - return success(); - } -}; - -struct PTORowExpandAddToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowExpandAddOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TROWEXPANDADD", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTORowExpandExpdifToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowExpandExpdifOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value(); - - SmallVector operands; - if (tmp) - operands.assign({dst, src0, src1, tmp}); - else - operands.assign({dst, src0, src1}); - rewriter.create( - loc, TypeRange{}, "TROWEXPANDEXPDIF", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TROWEXPANDDIV DPS/memref op) -//===----------------------------------------------------------------------===// -// Helper: replace or erase based on whether op has results. 
-static void replaceOrEraseWithOpaqueCall(Operation *op, - StringRef callee, - ArrayRef args, - ConversionPatternRewriter &rewriter) { - TypeRange resultTypes = op->getResultTypes(); - auto call = rewriter.create( - op->getLoc(), resultTypes, callee, ArrayAttr{}, ArrayAttr{}, ValueRange(args)); - if (resultTypes.empty()) - rewriter.eraseOp(op); - else - rewriter.replaceOp(op, call.getResults()); -} - -static void replaceOrEraseWithOpaqueCallAndReturnDst(Operation *op, Value dst, - StringRef callee, - ArrayRef args, - ConversionPatternRewriter &rewriter) { - rewriter.create( - op->getLoc(), TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, ValueRange(args)); - if (op->getNumResults() == 1) - rewriter.replaceOp(op, dst); - else - rewriter.eraseOp(op); -} - -// ---------- TOp ---------- -struct PTOTGemvBiasToTGEMV_BIAS - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGemvBiasOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value a = peelUnrealized(adaptor.getA()); - Value b = peelUnrealized(adaptor.getB()); - Value bias = peelUnrealized(adaptor.getBias()); - Value dst = peelUnrealized(adaptor.getDst()); - - replaceOrEraseWithOpaqueCall(op.getOperation(), "TGEMV_BIAS", - {dst, a, b, bias}, rewriter); - return success(); - } -}; - -struct PTOTGemvMXToTGEMV_MX - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGemvMxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value a = peelUnrealized(adaptor.getA()); - Value aScale = peelUnrealized(adaptor.getAScale()); - Value b = peelUnrealized(adaptor.getB()); - Value bScale = peelUnrealized(adaptor.getBScale()); - Value dst = peelUnrealized(adaptor.getDst()); - - replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TGEMV_MX", - {dst, a, aScale, b, bScale}, rewriter); - return success(); - } -}; - -struct 
PTOTGemvMXAccToTGEMV_MX - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGemvMxAccOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value cIn = peelUnrealized(adaptor.getCIn()); - Value a = peelUnrealized(adaptor.getA()); - Value aScale = peelUnrealized(adaptor.getAScale()); - Value b = peelUnrealized(adaptor.getB()); - Value bScale = peelUnrealized(adaptor.getBScale()); - Value dst = peelUnrealized(adaptor.getDst()); - - replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TGEMV_MX", - {dst, cIn, a, aScale, b, bScale}, rewriter); - return success(); - } -}; - -struct PTOTGemvMXBiasToTGEMV_MX - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGemvMxBiasOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value a = peelUnrealized(adaptor.getA()); - Value aScale = peelUnrealized(adaptor.getAScale()); - Value b = peelUnrealized(adaptor.getB()); - Value bScale = peelUnrealized(adaptor.getBScale()); - Value bias = peelUnrealized(adaptor.getBias()); - Value dst = peelUnrealized(adaptor.getDst()); - - replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TGEMV_MX", - {dst, a, aScale, b, bScale, bias}, rewriter); - return success(); - } -}; - -struct PTOTMatmulBiasToTMATMUL_BIAS - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMatmulBiasOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value a = peelUnrealized(adaptor.getA()); - Value b = peelUnrealized(adaptor.getB()); - Value bias = peelUnrealized(adaptor.getBias()); - Value dst = peelUnrealized(adaptor.getDst()); - - replaceOrEraseWithOpaqueCall(op.getOperation(), "TMATMUL_BIAS", - {dst, a, b, bias}, rewriter); - return success(); - } -}; - -struct PTOTMatmulMXToTMATMUL_MX - : 
public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMatmulMxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value a = peelUnrealized(adaptor.getA()); - Value aScale = peelUnrealized(adaptor.getAScale()); - Value b = peelUnrealized(adaptor.getB()); - Value bScale = peelUnrealized(adaptor.getBScale()); - Value dst = peelUnrealized(adaptor.getDst()); - - replaceOrEraseWithOpaqueCall(op.getOperation(), "TMATMUL_MX", - {dst, a, aScale, b, bScale}, rewriter); - return success(); - } -}; - -struct PTOTMatmulMXAccToTMATMUL_MX_ACC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMatmulMxAccOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value cIn = peelUnrealized(adaptor.getCIn()); - Value a = peelUnrealized(adaptor.getA()); - Value aScale = peelUnrealized(adaptor.getAScale()); - Value b = peelUnrealized(adaptor.getB()); - Value bScale = peelUnrealized(adaptor.getBScale()); - Value dst = peelUnrealized(adaptor.getDst()); - - replaceOrEraseWithOpaqueCall(op.getOperation(), "TMATMUL_MX", - {dst, cIn, a, aScale, b, bScale}, rewriter); - return success(); - } -}; - -struct PTOTMatmulMXBiasToTMATMUL_MX_BIAS - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMatmulMxBiasOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value a = peelUnrealized(adaptor.getA()); - Value aScale = peelUnrealized(adaptor.getAScale()); - Value b = peelUnrealized(adaptor.getB()); - Value bScale = peelUnrealized(adaptor.getBScale()); - Value bias = peelUnrealized(adaptor.getBias()); - Value dst = peelUnrealized(adaptor.getDst()); - - replaceOrEraseWithOpaqueCall(op.getOperation(), "TMATMUL_MX", - {dst, a, aScale, b, bScale, bias}, rewriter); - return success(); - } -}; - -struct 
PTORowExpandDivToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowExpandDivOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value(); - - SmallVector operands; - if (tmp) - operands.assign({dst, src0, src1, tmp}); - else - operands.assign({dst, src0, src1}); - rewriter.create( - loc, TypeRange{}, "TROWEXPANDDIV", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TROWEXPANDMUL DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORowExpandMulToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowExpandMulOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - Value tmp = op.getTmp() ? 
peelUnrealized(adaptor.getTmp()) : Value(); - - SmallVector operands; - if (tmp) - operands.assign({dst, src0, src1, tmp}); - else - operands.assign({dst, src0, src1}); - rewriter.create( - loc, TypeRange{}, "TROWEXPANDMUL", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TROWEXPANDSUB DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORowExpandSubToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowExpandSubOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value(); - - SmallVector operands; - if (tmp) - operands.assign({dst, src0, src1, tmp}); - else - operands.assign({dst, src0, src1}); - rewriter.create( - loc, TypeRange{}, "TROWEXPANDSUB", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTORowExpandMaxToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowExpandMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - Value tmp = op.getTmp() ? 
peelUnrealized(adaptor.getTmp()) : Value(); - - SmallVector operands; - if (tmp) - operands.assign({dst, src0, src1, tmp}); - else - operands.assign({dst, src0, src1}); - rewriter.create( - loc, TypeRange{}, "TROWEXPANDMAX", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTORowExpandMinToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowExpandMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value(); - - SmallVector operands; - if (tmp) - operands.assign({dst, src0, src1, tmp}); - else - operands.assign({dst, src0, src1}); - rewriter.create( - loc, TypeRange{}, "TROWEXPANDMIN", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TROWMAX DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORowMaxToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, tmp}; - rewriter.create( - loc, TypeRange{}, "TROWMAX", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - 
rewriter.eraseOp(op); - return success(); - } -}; - -struct PTORowArgMaxToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowArgMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TROWARGMAX", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, tmp}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TROWMIN DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORowMinToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, tmp}; - rewriter.create( - loc, TypeRange{}, "TROWMIN", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTORowArgMinToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowArgMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, 
TypeRange{}, "TROWARGMIN", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, tmp}); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TROWSUM DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORowSumToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowSumOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, tmp}; - rewriter.create( - loc, TypeRange{}, "TROWSUM", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTORowProdToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowProdOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, tmp}; - rewriter.create( - loc, TypeRange{}, "TROWPROD", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TRSQRT DPS/memref op) -// - no-tmp form : TRSQRT(dst, src) -// - tmp form : TRSQRT(dst, src, tmp) 
-//===----------------------------------------------------------------------===// - -struct PTORsqrtToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRsqrtOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - SmallVector operands{dst, src}; - if (Value tmp = adaptor.getTmp()) - operands.push_back(peelUnrealized(tmp)); - rewriter.create( - loc, TypeRange{}, "TRSQRT", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSCATTER DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOScatterToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TScatterOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - const bool hasMaskPattern = static_cast(op.getMaskPatternAttr()); - const bool hasIndexes = static_cast(op.getIndexes()); - if (hasMaskPattern == hasIndexes) { - return rewriter.notifyMatchFailure( - op, "expected exactly one of indexes operand or maskPattern attribute"); - } - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - if (auto mp = op.getMaskPatternAttr()) { - auto *ctx = rewriter.getContext(); - auto targs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, maskPatternTok(mp)), - }); - rewriter.create( - loc, TypeRange{}, "TSCATTER", - /*args=*/ArrayAttr{}, /*templateArgs=*/targs, - /*operands=*/ValueRange{dst, src}); - } else { - Value idx = peelUnrealized(adaptor.getIndexes()); - 
rewriter.create( - loc, TypeRange{}, "TSCATTER", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, idx}); - } - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSEL DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOSelToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TSelOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value mask = peelUnrealized(adaptor.getMask()); - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, mask, src0, src1, tmp}; - rewriter.create( - loc, TypeRange{}, "TSEL", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSELS DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOSelSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TSelSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value mask = peelUnrealized(adaptor.getMask()); - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, mask, src, tmp, scalar}; - rewriter.create( 
- loc, TypeRange{}, "TSELS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSHL DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOShlSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TShlOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TSHL", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSHR DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOShrSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TShrOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TSHR", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp 
(add lowering for TSHLS/TSHRS DPS: shift by scalar) -//===----------------------------------------------------------------------===// - -struct PTOShlSConstToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TShlSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value dst = peelUnrealized(adaptor.getDst()); - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - SmallVector operands{dst, src, scalar}; - rewriter.create( - loc, TypeRange{}, "TSHLS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOShrSConstToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TShrSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value dst = peelUnrealized(adaptor.getDst()); - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - SmallVector operands{dst, src, scalar}; - rewriter.create( - loc, TypeRange{}, "TSHRS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (TSORT32 DPS/memref op: ins(src, idx[, tmp]) outs(dst)) -//===----------------------------------------------------------------------===// - -struct PTOSORT32SToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TSort32Op op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = 
peelUnrealized(adaptor.getDst()); - Value idx = peelUnrealized(adaptor.getIdx()); - Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value(); - - SmallVector operands; - if (tmp) - operands.assign({dst, src, idx, tmp}); - else - operands.assign({dst, src, idx}); - rewriter.create( - loc, TypeRange{}, "TSORT32", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSQRT DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOSqrtSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TSqrtOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src}; - rewriter.create( - loc, TypeRange{}, "TSQRT", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSTORE_FP DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOStoreFPSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TStoreFPOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value fp = peelUnrealized(adaptor.getFp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, fp}; - rewriter.create( - loc, 
TypeRange{}, "TSTORE_FP", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSUB DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOSubSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TSubOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TSUB", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSUBC DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOSubCSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TSubCOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value src2 = peelUnrealized(adaptor.getSrc2()); - Value dst = peelUnrealized(adaptor.getDst()); - - // pto-isa does not provide NPU implementation for TSUBC yet. 
- // Decompose: dst = src0 - src1 + src2 - rewriter.create( - loc, TypeRange{}, "TSUB", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1}); - rewriter.create( - loc, TypeRange{}, "TADD", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, dst, src2}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSUBS DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOSubSSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TSubSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, scalar}; - rewriter.create( - loc, TypeRange{}, "TSUBS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSUBSC DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOSubSCToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TSubSCOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - // pto-isa does not provide NPU 
implementation for TSUBSC yet. - // Decompose: dst = src0 - scalar + src1 - rewriter.create( - loc, TypeRange{}, "TSUBS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, scalar}); - rewriter.create( - loc, TypeRange{}, "TADD", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, dst, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; - - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TXOR DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOXORToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TXorOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - Value tmp = peelUnrealized(adaptor.getTmp()); - SmallVector operands{dst, src0, src1, tmp}; - rewriter.create( - loc, TypeRange{}, "TXOR", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -struct PTOTTransToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TTransOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, tmp}; - rewriter.create( - loc, TypeRange{}, "TTRANS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; 
-//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TXORS DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOXORSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TXorSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, scalar, tmp}; - rewriter.create( - loc, TypeRange{}, "TXORS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - struct PTOPrintToTPRINT : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPrintOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - - SmallVector operands{src}; - rewriter.create( - loc, TypeRange{}, "TPRINT", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -// pto.print "format", %scalar -> PRINTF("format", scalar) -struct PTOPrintOpToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::PrintOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - std::string fmt = op.getFormat().str(); - if (fmt.empty()) - fmt = "%f"; - std::string quoted = "\""; - for (char c : fmt) { - if (c == '"' || c == '\\') - quoted += '\\'; - else if (c == 
'\n') - quoted += "\\n"; - else if (c == '\t') - quoted += "\\t"; - else - quoted += c; - } - quoted += "\""; - - Value scalar = peelUnrealized(adaptor.getScalar()); - auto argsAttr = rewriter.getArrayAttr( - {emitc::OpaqueAttr::get(ctx, quoted), - IntegerAttr::get(IndexType::get(ctx), 0)}); - rewriter.create( - loc, TypeRange{}, "cce::printf", - /*args=*/argsAttr, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{scalar}); - - rewriter.eraseOp(op); - return success(); - } -}; - -// pto.trap -> TRAP() -struct PTOTrapOpToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TrapOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - rewriter.create( - loc, TypeRange{}, "trap", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{}); - - rewriter.eraseOp(op); - return success(); - } -}; - -// ============================================================================= -// 2. 
BindTileOp Lowering (FIX: Trace back to physical address) -// ============================================================================= -struct PTOBindTileToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - struct TileBuildSpec { - std::string tileTypeStr; - bool useConstructor = false; - SmallVector constructorArgs; - }; - - static bool getIndexConst(Value v, int64_t &out) { - if (!v) - return false; - if (auto cst = v.getDefiningOp()) { - if (auto ia = dyn_cast(cst.getValue())) { - out = ia.getValue().getSExtValue(); - return true; - } - } - return false; - } - - static bool getTilePointerStrides(pto::TileBufConfigAttr configAttr, - Type elemTy, int64_t rows, int64_t cols, - int64_t &rowStride, - int64_t &colStride) { - if (rows == ShapedType::kDynamic || cols == ShapedType::kDynamic) - return false; - - int32_t blVal = 0; - if (auto blAttr = dyn_cast(configAttr.getBLayout())) - blVal = static_cast(blAttr.getValue()); - else if (auto intAttr = dyn_cast(configAttr.getBLayout())) - blVal = static_cast(intAttr.getInt()); - - int32_t slVal = 0; - if (auto slAttr = dyn_cast(configAttr.getSLayout())) - slVal = static_cast(slAttr.getValue()); - else if (auto intAttr = dyn_cast(configAttr.getSLayout())) - slVal = static_cast(intAttr.getInt()); - - bool boxed = slVal != 0; - int64_t innerRows = 1; - int64_t innerCols = 1; - if (boxed) { - int32_t fractal = 512; - if (auto frAttr = dyn_cast(configAttr.getSFractalSize())) - fractal = static_cast(frAttr.getInt()); - - unsigned elemBytes = pto::getPTOStorageElemByteSize(elemTy); - if (elemBytes == 0) - return false; - - switch (fractal) { - case 1024: - innerRows = 16; - innerCols = 16; - break; - case 32: - innerRows = 16; - innerCols = 2; - break; - case 512: - if (slVal == 1) { - innerRows = 16; - innerCols = 32 / elemBytes; - } else if (slVal == 2) { - innerRows = 32 / elemBytes; - innerCols = 16; - } else { - return false; - } - break; - default: - return false; - } - if 
(innerRows <= 0 || innerCols <= 0) - return false; - } - - if (!boxed) { - if (blVal == 1) { - rowStride = 1; - colStride = rows; - } else { - rowStride = cols; - colStride = 1; - } - return true; - } - - if (blVal == 1) { - if (slVal != 1) - return false; - rowStride = innerCols; - colStride = rows; - return true; - } - - rowStride = cols; - colStride = innerRows; - return true; - } - - LogicalResult matchAndRewrite(pto::BindTileOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - auto configAttr = op.getConfigAttr(); - auto viewSemantics = op->getAttrOfType("pto.view_semantics"); - bool isSubView = viewSemantics && viewSemantics.getValue() == "subview"; - - auto peelAllCasts = [](Value v) { - while (auto castOp = v.getDefiningOp()) - v = castOp.getOperand(0); - if (auto castOp = v.getDefiningOp()) - v = castOp.getOperand(); - return v; - }; - auto isTileLike = [](Value v) -> bool { - auto ot = dyn_cast(v.getType()); - if (!ot) - return false; - StringRef s = ot.getValue(); - return s.contains("Tile<") || s.contains("ConvTile<"); - }; - auto buildTileSpec = [&]() -> FailureOr { - auto resMrTy = dyn_cast(op.getType()); - if (!resMrTy) - return failure(); - - const char *roleTok = "TileType::Vec"; - if (auto asAttr = - dyn_cast_or_null(resMrTy.getMemorySpace())) { - switch (asAttr.getAddressSpace()) { - case pto::AddressSpace::VEC: - roleTok = "TileType::Vec"; - break; - case pto::AddressSpace::MAT: - roleTok = "TileType::Mat"; - break; - case pto::AddressSpace::LEFT: - roleTok = "TileType::Left"; - break; - case pto::AddressSpace::RIGHT: - roleTok = "TileType::Right"; - break; - case pto::AddressSpace::ACC: - roleTok = "TileType::Acc"; - break; - case pto::AddressSpace::BIAS: - roleTok = "TileType::Bias"; - break; - case pto::AddressSpace::SCALING: - roleTok = "TileType::Scaling"; - break; - case pto::AddressSpace::GM: - case pto::AddressSpace::Zero: - roleTok = 
"TileType::Vec"; - break; - } - } - - Type elemTy = resMrTy.getElementType(); - Type emitElemTy = getTypeConverter()->convertType(elemTy); - if (!emitElemTy) - return failure(); - auto emitElemOpaque = dyn_cast(emitElemTy); - if (!emitElemOpaque) - return failure(); - std::string elemTypeStr = emitElemOpaque.getValue().str(); - - if (resMrTy.getRank() < 2) - return failure(); - int64_t rows = resMrTy.getDimSize(0); - int64_t cols = resMrTy.getDimSize(1); - if (rows == ShapedType::kDynamic || cols == ShapedType::kDynamic) - return failure(); - - std::string blTok = "BLayout::RowMajor"; - if (auto blAttr = dyn_cast(configAttr.getBLayout())) { - if (static_cast(blAttr.getValue()) == 1) - blTok = "BLayout::ColMajor"; - } - pto::BLayout blayout = getTileBufBLayoutValue(configAttr); - - if (isSubView) { - auto subMrTy = dyn_cast(op.getSource().getType()); - auto subViewOp = op.getSource().getDefiningOp(); - if (subMrTy && subMrTy.getRank() >= 2 && subViewOp) { - int64_t subRows = subMrTy.getDimSize(0); - int64_t subCols = subMrTy.getDimSize(1); - SmallVector inheritedStrides; - int64_t inheritedOffset = ShapedType::kDynamic; - - if (!pto::isPTOFloat4PackedType(elemTy) && - subRows != ShapedType::kDynamic && - subCols != ShapedType::kDynamic && - succeeded(getStridesAndOffset(subMrTy, inheritedStrides, - inheritedOffset)) && - inheritedStrides.size() >= 2) { - int64_t childRowStride = 0; - int64_t childColStride = 0; - bool sameStrides = getTilePointerStrides( - configAttr, elemTy, subRows, subCols, childRowStride, - childColStride); - sameStrides = sameStrides && - inheritedStrides[0] == childRowStride && - inheritedStrides[1] == childColStride; - if (sameStrides) { - rows = subRows; - cols = subCols; - } - } - } - } - - std::string slTok = "SLayout::NoneBox"; - if (auto slAttr = dyn_cast(configAttr.getSLayout())) { - int32_t slVal = static_cast(slAttr.getValue()); - slTok = (slVal == 1) ? "SLayout::RowMajor" - : (slVal == 2) ? 
"SLayout::ColMajor" - : "SLayout::NoneBox"; - } - - int32_t fractal = 512; - if (auto frAttr = dyn_cast(configAttr.getSFractalSize())) - fractal = frAttr.getInt(); - - std::string padTok = "PadValue::Null"; - if (auto padAttr = dyn_cast(configAttr.getPad())) { - switch (static_cast(padAttr.getValue())) { - case 1: - padTok = "PadValue::Zero"; - break; - case 2: - padTok = "PadValue::Max"; - break; - case 3: - padTok = "PadValue::Min"; - break; - default: - padTok = "PadValue::Null"; - break; - } - } - - std::string compactTok = "CompactMode::Null"; - if (auto compactAttr = dyn_cast(configAttr.getCompactMode())) { - switch (static_cast(compactAttr.getValue())) { - case 1: - compactTok = "CompactMode::Normal"; - break; - case 2: - compactTok = "CompactMode::RowPlusOne"; - break; - default: - compactTok = "CompactMode::Null"; - break; - } - } - - std::string vrowTok, vcolTok; - bool useConstructor = false; - bool rowIsDynamic = false; - bool colIsDynamic = false; - SmallVector constructorArgs; - - Value vRow = op.getValidRow(); - Value vCol = op.getValidCol(); - Value vRowEmitC = adaptor.getValidRow(); - Value vColEmitC = adaptor.getValidCol(); - bool forceDynamicValid = op->hasAttr(kForceDynamicValidShapeAttrName); - int64_t cRow = 0, cCol = 0; - bool rowIsConst = vRow && getIndexConst(vRow, cRow); - bool colIsConst = vCol && getIndexConst(vCol, cCol); - - auto makeCtorDimValue = [&](Value emitted, int64_t fallback) -> Value { - if (emitted) - return emitted; - return makeEmitCIntConstant( - rewriter, loc, emitc::OpaqueType::get(ctx, "int32_t"), fallback); - }; - auto maybeScaleDynamicValid = [&](Value emitted, int dimIdx) -> Value { - if (!emitted || !pto::isPTOFloat4PackedType(elemTy)) - return emitted; - int packedDim = blayout == pto::BLayout::ColMajor ? 
0 : 1; - if (dimIdx != packedDim) - return emitted; - auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); - Value two = makeEmitCIntConstant(rewriter, loc, i32Ty, 2); - return rewriter.create(loc, i32Ty, emitted, two).getResult(); - }; - - if (forceDynamicValid) { - vrowTok = "-1"; - vcolTok = "-1"; - useConstructor = true; - constructorArgs.push_back( - makeCtorDimValue(maybeScaleDynamicValid(vRowEmitC, 0), - renderTileTemplateDim(rowIsConst ? cRow : rows, - elemTy, blayout, 0))); - constructorArgs.push_back( - makeCtorDimValue(maybeScaleDynamicValid(vColEmitC, 1), - renderTileTemplateDim(colIsConst ? cCol : cols, - elemTy, blayout, 1))); - } else { - if (rowIsConst) { - vrowTok = std::to_string( - renderTileTemplateDim(cRow, elemTy, blayout, 0)); - } else if (vRow) { - vrowTok = "-1"; - rowIsDynamic = true; - useConstructor = true; - } else { - vrowTok = std::to_string( - renderTileTemplateDim(rows, elemTy, blayout, 0)); - } - - if (colIsConst) { - vcolTok = std::to_string( - renderTileTemplateDim(cCol, elemTy, blayout, 1)); - } else if (vCol) { - vcolTok = "-1"; - colIsDynamic = true; - useConstructor = true; - } else { - vcolTok = std::to_string( - renderTileTemplateDim(cols, elemTy, blayout, 1)); - } - - if (useConstructor) { - if (rowIsDynamic && vRowEmitC) - constructorArgs.push_back(maybeScaleDynamicValid(vRowEmitC, 0)); - if (colIsDynamic && vColEmitC) - constructorArgs.push_back(maybeScaleDynamicValid(vColEmitC, 1)); - } - } - - std::string tileTypeStr = std::string("Tile<") + roleTok + ", " + - elemTypeStr + ", " + - std::to_string(renderTileTemplateDim( - rows, elemTy, blayout, 0)) + - ", " + - std::to_string(renderTileTemplateDim( - cols, elemTy, blayout, 1)) + - ", " + blTok + - ", " + vrowTok + ", " + vcolTok + ", " + slTok + - ", " + std::to_string(fractal) + ", " + padTok + - ", " + compactTok + - ">"; - return TileBuildSpec{tileTypeStr, useConstructor, constructorArgs}; - }; - - auto buildTileValue = [&](const TileBuildSpec &spec, - bool 
forceDeclaration = false) -> Value { - auto tileType = emitc::OpaqueType::get(ctx, spec.tileTypeStr); - if (spec.useConstructor && !forceDeclaration) { - return rewriter - .create(loc, tileType, spec.tileTypeStr, - ArrayAttr{}, ArrayAttr{}, - ValueRange(spec.constructorArgs)) - .getResult(0); - } - - return rewriter - .create(loc, tileType, emitc::OpaqueAttr::get(ctx, "")) - .getResult(); - }; - - auto emitElemTypeToString = [&](Type elemTy) -> std::string { - return getEmitCScalarTypeToken(elemTy); - }; - - auto buildIntegralAddress = [&](Value sourceValue) -> FailureOr { - auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); - auto rcU64 = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); - - Value rawPtr = sourceValue; - if (auto ot = dyn_cast(sourceValue.getType())) { - StringRef tyStr = ot.getValue(); - if (tyStr.contains("Tile<") || tyStr.contains("ConvTile<")) { - auto srcMrTy = dyn_cast(op.getSource().getType()); - if (!srcMrTy) - return failure(); - std::string elemTok = emitElemTypeToString(srcMrTy.getElementType()); - pto::AddressSpace as = pto::AddressSpace::GM; - if (auto asAttr = - dyn_cast_or_null(srcMrTy.getMemorySpace())) - as = asAttr.getAddressSpace(); - rawPtr = materializeTileDataValue(rewriter, loc, sourceValue, as, - elemTok); - } - } - - if (isSetFFTsPointerLikeType(rawPtr.getType())) { - return rewriter - .create(loc, u64Ty, "reinterpret_cast", - ArrayAttr{}, rcU64, ValueRange{rawPtr}) - .getResult(0); - } - - if (rawPtr.getType() == u64Ty) - return rawPtr; - return rewriter.create(loc, u64Ty, rawPtr).getResult(); - }; - - if (op.getSource().getDefiningOp()) { - FailureOr tileSpec = buildTileSpec(); - if (failed(tileSpec)) - return failure(); - rewriter.replaceOp(op, buildTileValue(*tileSpec)); - return success(); - } - - Value tileCandidate = peelAllCasts(adaptor.getSource()); - if (viewSemantics && viewSemantics.getValue() == "bitcast" && - isTileLike(tileCandidate)) { - FailureOr tileSpec = buildTileSpec(); - if 
(failed(tileSpec)) - return failure(); - Value dstTile = buildTileValue(*tileSpec); - FailureOr addr = buildIntegralAddress(tileCandidate); - if (failed(addr)) - return failure(); - - rewriter.create(loc, TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dstTile, *addr}); - rewriter.replaceOp(op, dstTile); - return success(); - } - - if (viewSemantics && viewSemantics.getValue() == "treshape" && - isTileLike(tileCandidate)) { - FailureOr tileSpec = buildTileSpec(); - if (failed(tileSpec)) - return failure(); - Value dstTile = buildTileValue(*tileSpec, /*forceDeclaration=*/true); - - rewriter.create(loc, TypeRange{}, "TRESHAPE", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dstTile, tileCandidate}); - rewriter.replaceOp(op, dstTile); - return success(); - } - - // Subview origins are kept distinct from generic tile rebinding: - // even when source/destination C++ tile types match, subview may carry - // shifted base address semantics and should materialize a fresh handle. - if (isSubView) { - FailureOr tileSpec = buildTileSpec(); - if (failed(tileSpec)) - return failure(); - Value dstTile = buildTileValue(*tileSpec); - FailureOr addr = buildIntegralAddress(tileCandidate); - if (failed(addr)) - return failure(); - - rewriter.create(loc, TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dstTile, *addr}); - rewriter.replaceOp(op, dstTile); - return success(); - } - - // Generic tile-to-tile rebind path: preserve the same backing storage and - // rebuild a sibling tile with updated metadata/valid dims. 
- if (isTileLike(tileCandidate)) { - FailureOr tileSpec = buildTileSpec(); - if (failed(tileSpec)) - return failure(); - - if (!tileSpec->useConstructor) { - if (auto srcTy = dyn_cast(tileCandidate.getType())) { - if (srcTy.getValue() == tileSpec->tileTypeStr) { - rewriter.replaceOp(op, tileCandidate); - return success(); - } - } - } - - Value dstTile = buildTileValue(*tileSpec); - FailureOr addr = buildIntegralAddress(tileCandidate); - if (failed(addr)) - return failure(); - - rewriter.create(loc, TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dstTile, *addr}); - rewriter.replaceOp(op, dstTile); - return success(); - } - - SmallVector physAddrs; - Value source = op.getSource(); - - while (auto castOp = source.getDefiningOp()) - source = castOp.getOperand(0); - - if (auto upstreamCast = source.getDefiningOp()) { - auto upstreamOperands = upstreamCast.getAddrs(); - physAddrs.append(upstreamOperands.begin(), upstreamOperands.end()); - } else { - physAddrs.push_back(adaptor.getSource()); - } - - Value vRow = op.getValidRow(); - Value vCol = op.getValidCol(); - - auto newCast = rewriter.create( - loc, op.getType(), physAddrs, vRow ? vRow : Value(), - vCol ? 
vCol : Value(), configAttr); - if (viewSemantics) - newCast->setAttr("pto.view_semantics", viewSemantics); - if (op->hasAttr(kForceDynamicValidShapeAttrName)) - newCast->setAttr(kForceDynamicValidShapeAttrName, - op->getAttr(kForceDynamicValidShapeAttrName)); - rewriter.replaceOp(op, newCast.getResult()); - - return success(); - } -}; - -struct PTOAllocTileToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::AllocTileOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - MLIRContext *ctx = rewriter.getContext(); - auto tileTy = cast(op.getResult().getType()); - auto tileTypeString = getEmitCTileTypeString(tileTy); - if (!tileTypeString) - return rewriter.notifyMatchFailure( - op, "only rank-2 alloc_tile handles can be converted to EmitC"); - - Type convertedTy = getTypeConverter()->convertType(tileTy); - if (!convertedTy) - convertedTy = emitc::OpaqueType::get(ctx, *tileTypeString); - - auto validShape = tileTy.getValidShape(); - bool hasDynamicValidDim = - llvm::any_of(validShape, [](int64_t dim) { return dim < 0; }); - bool useConstructor = hasDynamicValidDim; - - SmallVector constructorArgs; - if (useConstructor) { - Type elemTy = tileTy.getElementType(); - pto::BLayout blayout = getTileBufBLayoutValue(tileTy.getConfigAttr()); - auto maybeScaleDynamicValid = [&](Value emitted, int dimIdx) -> Value { - if (!emitted || !pto::isPTOFloat4PackedType(elemTy)) - return emitted; - int packedDim = blayout == pto::BLayout::ColMajor ? 
0 : 1; - if (dimIdx != packedDim) - return emitted; - auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); - Value two = makeEmitCIntConstant(rewriter, loc, i32Ty, 2); - return rewriter.create(loc, i32Ty, emitted, two) - .getResult(); - }; - - if (validShape.size() > 0 && validShape[0] < 0) { - Value validRow = adaptor.getValidRow(); - if (!validRow) - return rewriter.notifyMatchFailure( - op, "dynamic alloc_tile valid row must have an operand"); - if (validRow) - validRow = peelUnrealized(validRow); - constructorArgs.push_back(maybeScaleDynamicValid(validRow, 0)); - } - if (validShape.size() > 1 && validShape[1] < 0) { - Value validCol = adaptor.getValidCol(); - if (!validCol) - return rewriter.notifyMatchFailure( - op, "dynamic alloc_tile valid col must have an operand"); - if (validCol) - validCol = peelUnrealized(validCol); - constructorArgs.push_back(maybeScaleDynamicValid(validCol, 1)); - } - } - - Value tile; - if (useConstructor) { - tile = rewriter - .create( - loc, convertedTy, *tileTypeString, ArrayAttr{}, - ArrayAttr{}, ValueRange(constructorArgs)) - .getResult(0); - } else { - tile = - rewriter - .create( - loc, convertedTy, emitc::OpaqueAttr::get(ctx, "")) - .getResult(); - } - - Value addr = adaptor.getAddr(); - if (addr) { - addr = peelUnrealized(addr); - auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); - if (isa(addr.getType()) || - (isa(addr.getType()) && - cast(addr.getType()).getValue().ends_with("*"))) { - auto rcU64 = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); - addr = rewriter - .create(loc, u64Ty, "reinterpret_cast", - ArrayAttr{}, rcU64, - ValueRange{addr}) - .getResult(0); - } else if (addr.getType() != u64Ty) { - addr = rewriter.create(loc, u64Ty, addr).getResult(); - } - - rewriter.create(loc, TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{tile, addr}); - } - - rewriter.replaceOp(op, tile); - return success(); - } -}; - -static FailureOr -createEmitCTileVariable(ConversionPatternRewriter 
&rewriter, Location loc, - const TypeConverter *typeConverter, - pto::TileBufType tileTy) { - auto tileTypeString = getEmitCTileTypeString(tileTy); - if (!tileTypeString) - return failure(); - - Type convertedTy = typeConverter->convertType(tileTy); - if (!convertedTy) - convertedTy = emitc::OpaqueType::get(rewriter.getContext(), *tileTypeString); - - return rewriter - .create( - loc, convertedTy, emitc::OpaqueAttr::get(rewriter.getContext(), "")) - .getResult(); -} - -struct PTOTReshapeToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TReshapeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto tileTy = dyn_cast(op.getResult().getType()); - if (!tileTy) - return failure(); - - FailureOr dst = - createEmitCTileVariable(rewriter, op.getLoc(), getTypeConverter(), tileTy); - if (failed(dst)) - return failure(); - - Value src = peelUnrealized(adaptor.getSrc()); - if (auto castOp = src.getDefiningOp()) - src = castOp.getOperand(); - - rewriter.create(op.getLoc(), TypeRange{}, "TRESHAPE", - ArrayAttr{}, ArrayAttr{}, - ValueRange{*dst, src}); - rewriter.replaceOp(op, *dst); - return success(); - } -}; - -struct PTOBitcastToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::BitcastOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto dstTy = dyn_cast(op.getResult().getType()); - auto srcTy = dyn_cast(op.getSrc().getType()); - if (!dstTy || !srcTy) - return failure(); - - FailureOr dst = - createEmitCTileVariable(rewriter, op.getLoc(), getTypeConverter(), dstTy); - if (failed(dst)) - return failure(); - - Value src = peelUnrealized(adaptor.getSrc()); - if (auto castOp = src.getDefiningOp()) - src = castOp.getOperand(); - - pto::AddressSpace as = pto::AddressSpace::GM; - if (auto asAttr = - dyn_cast_or_null(srcTy.getMemorySpace())) - as = 
asAttr.getAddressSpace(); - std::string elemTok = getEmitCScalarTypeToken(srcTy.getElementType()); - - Value rawPtr = materializeTileDataValue(rewriter, op.getLoc(), src, as, elemTok); - auto u64Ty = emitc::OpaqueType::get(rewriter.getContext(), "uint64_t"); - Value addr = rawPtr; - if (isSetFFTsPointerLikeType(rawPtr.getType())) { - auto rcU64 = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(rewriter.getContext(), - "uint64_t")}); - addr = rewriter - .create(op.getLoc(), u64Ty, - "reinterpret_cast", ArrayAttr{}, - rcU64, ValueRange{rawPtr}) - .getResult(0); - } else if (addr.getType() != u64Ty) { - addr = rewriter.create(op.getLoc(), u64Ty, addr).getResult(); - } - - rewriter.create(op.getLoc(), TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{*dst, addr}); - rewriter.replaceOp(op, *dst); - return success(); - } -}; - -struct PTOMaterializeTileToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - static bool isTileLike(Value v) { - auto ot = dyn_cast(v.getType()); - if (!ot) - return false; - StringRef s = ot.getValue(); - return s.contains("Tile<") || s.contains("ConvTile<"); - } - - LogicalResult matchAndRewrite(pto::MaterializeTileOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - MLIRContext *ctx = rewriter.getContext(); - auto tileTy = cast(op.getResult().getType()); - auto tileTypeString = getEmitCTileTypeString(tileTy); - if (!tileTypeString) - return rewriter.notifyMatchFailure( - op, "only rank-2 tile_buf handles can be materialized to EmitC"); - - Type convertedTy = getTypeConverter()->convertType(tileTy); - if (!convertedTy) - convertedTy = emitc::OpaqueType::get(ctx, *tileTypeString); - - Value source = peelUnrealized(adaptor.getSource()); - if (auto castOp = source.getDefiningOp()) - source = castOp.getOperand(); - - auto viewSemantics = op->getAttrOfType("pto.view_semantics"); - bool forceDynamicValid = 
op->hasAttr(kForceDynamicValidShapeAttrName); - bool isReshape = viewSemantics && viewSemantics.getValue() == "treshape"; - bool isSubview = viewSemantics && viewSemantics.getValue() == "subview"; - bool sourceIsDeclaredTile = - op.getSource().getDefiningOp(); - - auto createTileValue = [&]() -> Value { - SmallVector constructorArgs; - bool useConstructor = false; - pto::BLayout blayout = getTileBufBLayoutValue(tileTy.getConfigAttr()); - Type elemTy = tileTy.getElementType(); - auto shape = tileTy.getShape(); - auto validShape = tileTy.getValidShape(); - - auto makeCtorDimValue = [&](Value emitted, int64_t fallback) -> Value { - if (emitted) - return emitted; - return makeEmitCIntConstant( - rewriter, loc, emitc::OpaqueType::get(ctx, "int32_t"), fallback); - }; - auto maybeScaleDynamicValid = [&](Value emitted, int dimIdx) -> Value { - if (!emitted || !pto::isPTOFloat4PackedType(elemTy)) - return emitted; - int packedDim = blayout == pto::BLayout::ColMajor ? 0 : 1; - if (dimIdx != packedDim) - return emitted; - auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); - Value two = makeEmitCIntConstant(rewriter, loc, i32Ty, 2); - return rewriter.create(loc, i32Ty, emitted, two).getResult(); - }; - auto fallbackDim = [&](int dimIdx) { - return renderTileTemplateDim(shape[dimIdx], elemTy, blayout, dimIdx); - }; - - if (forceDynamicValid) { - useConstructor = true; - constructorArgs.push_back(makeCtorDimValue( - maybeScaleDynamicValid(adaptor.getValidRow(), 0), fallbackDim(0))); - constructorArgs.push_back(makeCtorDimValue( - maybeScaleDynamicValid(adaptor.getValidCol(), 1), fallbackDim(1))); - } else { - if (validShape[0] == ShapedType::kDynamic) { - useConstructor = true; - constructorArgs.push_back(makeCtorDimValue( - maybeScaleDynamicValid(adaptor.getValidRow(), 0), fallbackDim(0))); - } - if (validShape[1] == ShapedType::kDynamic) { - useConstructor = true; - constructorArgs.push_back(makeCtorDimValue( - maybeScaleDynamicValid(adaptor.getValidCol(), 1), 
fallbackDim(1))); - } - } - - if (useConstructor) { - return rewriter - .create(loc, convertedTy, *tileTypeString, - ArrayAttr{}, ArrayAttr{}, - ValueRange(constructorArgs)) - .getResult(0); - } - - return rewriter - .create(loc, convertedTy, - emitc::OpaqueAttr::get(ctx, "")) - .getResult(); - }; - - if (!isSubview && !forceDynamicValid && isTileLike(source)) { - if (auto srcTy = dyn_cast(source.getType())) { - if (srcTy.getValue() == *tileTypeString) { - rewriter.replaceOp(op, source); - return success(); - } - } - } - - Value tile = createTileValue(); - if (sourceIsDeclaredTile) { - rewriter.replaceOp(op, tile); - return success(); - } - - if (isReshape && isTileLike(source)) { - rewriter.create(loc, TypeRange{}, "TRESHAPE", - ArrayAttr{}, ArrayAttr{}, - ValueRange{tile, source}); - rewriter.replaceOp(op, tile); - return success(); - } - - pto::AddressSpace as = pto::AddressSpace::GM; - if (auto asAttr = - dyn_cast_or_null(tileTy.getMemorySpace())) - as = asAttr.getAddressSpace(); - std::string elemTok = getEmitCScalarTypeToken(tileTy.getElementType()); - - Value rawPtr = source; - if (isTileLike(rawPtr)) - rawPtr = materializeTileDataValue(rewriter, loc, rawPtr, as, elemTok); - - auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); - Value addr = rawPtr; - if (isSetFFTsPointerLikeType(rawPtr.getType())) { - auto rcU64 = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); - addr = rewriter - .create(loc, u64Ty, "reinterpret_cast", - ArrayAttr{}, rcU64, - ValueRange{rawPtr}) - .getResult(0); - } else if (rawPtr.getType() != u64Ty) { - addr = rewriter.create(loc, u64Ty, rawPtr).getResult(); - } - - rewriter.create(loc, TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{tile, addr}); - rewriter.replaceOp(op, tile); - return success(); - } -}; - -// ============================================================================= -// Arith CmpI -> EmitC Cmp -// ============================================================================= 
-class ArithCmpIToEmitC : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - // 将 arith.cmpi 转换为 emitc.cmp - // 映射 Predicate: eq -> equal, slt -> less, etc. - emitc::CmpPredicate emitcPred = emitc::CmpPredicate::eq; - const bool isUnsignedPred = - op.getPredicate() == arith::CmpIPredicate::ult || - op.getPredicate() == arith::CmpIPredicate::ule || - op.getPredicate() == arith::CmpIPredicate::ugt || - op.getPredicate() == arith::CmpIPredicate::uge; - switch (op.getPredicate()) { - case arith::CmpIPredicate::eq: emitcPred = emitc::CmpPredicate::eq; break; - case arith::CmpIPredicate::ne: emitcPred = emitc::CmpPredicate::ne; break; - case arith::CmpIPredicate::slt: emitcPred = emitc::CmpPredicate::lt; break; - case arith::CmpIPredicate::sle: emitcPred = emitc::CmpPredicate::le; break; - case arith::CmpIPredicate::sgt: emitcPred = emitc::CmpPredicate::gt; break; - case arith::CmpIPredicate::sge: emitcPred = emitc::CmpPredicate::ge; break; - // ... 处理无符号比较 (ult, ule 等) ... - case arith::CmpIPredicate::ult: emitcPred = emitc::CmpPredicate::lt; break; - case arith::CmpIPredicate::ule: emitcPred = emitc::CmpPredicate::le; break; - case arith::CmpIPredicate::ugt: emitcPred = emitc::CmpPredicate::gt; break; - case arith::CmpIPredicate::uge: emitcPred = emitc::CmpPredicate::ge; break; - } - - Type resTy = getTypeConverter()->convertType(op.getType()); - if (!resTy) - return failure(); - - Value lhs = adaptor.getLhs(); - Value rhs = adaptor.getRhs(); - if (isUnsignedPred) { - Type opTy = op.getLhs().getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure( - op, "expected scalar integer or index operands"); - - const unsigned bitWidth = - intTy ? 
intTy.getWidth() : static_cast(kPTOIndexBitWidth); - if (bitWidth != 1) { - lhs = castSignlessIntToUnsignedSameWidth(rewriter, loc, lhs, bitWidth); - rhs = castSignlessIntToUnsignedSameWidth(rewriter, loc, rhs, bitWidth); - } - } - - rewriter.replaceOpWithNewOp( - op, - /*resultType=*/resTy, // i1 -> bool/i1 - emitcPred, - lhs, - rhs - ); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// Section Op Lowering -//===----------------------------------------------------------------------===// -static bool isA5NoSplitPipeOp(Operation *op) { - if (auto talloc = dyn_cast(op)) - return talloc.getSplit() == 0; - if (auto tpush = dyn_cast(op)) - return tpush.getSplit() == 0; - if (auto tpop = dyn_cast(op)) - return tpop.getSplit() == 0; - if (auto tfree = dyn_cast(op)) - return tfree.getSplit() == 0; - if (auto tpush = dyn_cast(op)) - return tpush.getSplit() == 0; - if (auto tpush = dyn_cast(op)) - return tpush.getSplit() == 0; - if (auto talloc = dyn_cast(op)) - return talloc.getSplit() == 0; - if (auto talloc = dyn_cast(op)) - return talloc.getSplit() == 0; - if (auto tpop = dyn_cast(op)) - return tpop.getSplit() == 0; - if (auto tpop = dyn_cast(op)) - return tpop.getSplit() == 0; - if (auto tfree = dyn_cast(op)) - return tfree.getSplit() == 0; - if (auto tfree = dyn_cast(op)) - return tfree.getSplit() == 0; - return false; -} - -static bool hasExplicitSubblockControl(Operation *op) { - bool hasControl = false; - op->walk([&](Operation *nested) { - if (isa(nested)) { - hasControl = true; - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - return hasControl; -} - -static bool needsA5NoSplitVectorGuard(Operation *op) { - auto arch = getTargetArch(op); - if (arch != PTOArch::A5) - return false; - bool isVectorScope = isa(op); - if (auto func = dyn_cast(op)) { - if (auto kernelKindAttr = - func->getAttrOfType( - FunctionKernelKindAttr::name)) { - isVectorScope = - 
kernelKindAttr.getKernelKind() == FunctionKernelKind::Vector; - } - } - if (!isVectorScope) - return false; - if (hasExplicitSubblockControl(op)) - return false; - - bool hasNoSplitPipe = false; - op->walk([&](Operation *nested) { - if (!isA5NoSplitPipeOp(nested)) - return WalkResult::advance(); - hasNoSplitPipe = true; - return WalkResult::interrupt(); - }); - return hasNoSplitPipe; -} - -template -struct SectionToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - std::string getMacroName() const { - if (std::is_same::value) - return "__DAV_CUBE__"; - if (std::is_same::value) - return "__DAV_VEC__"; - return "UNKNOWN_MACRO"; - } - - LogicalResult - matchAndRewrite(SectionOpTy op, typename SectionOpTy::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - bool needsNoSplitGuard = needsA5NoSplitVectorGuard(op.getOperation()); - - std::string startMacro = "\n#if defined(" + getMacroName() + ")"; - rewriter.create(loc, startMacro); - - if constexpr (std::is_same_v) { - // Vector mask is a global HW state and may be modified by previous kernels - // (or earlier sections). Reset it to a well-defined state for deterministic - // execution of VEC ops. 
- rewriter.create(loc, "set_mask_norm();"); - rewriter.create(loc, "set_vector_mask(-1, -1);"); - } - - if (needsNoSplitGuard) { - rewriter.create( - loc, "if (get_subblockid() == 0) {"); - } - - Block &innerBlock = op.getBody().front(); - if (!innerBlock.empty()) { - rewriter.inlineBlockBefore(&innerBlock, op.getOperation(), ValueRange{}); - } - - if (needsNoSplitGuard) - rewriter.create(loc, "}"); - - std::string endMacro = "#endif // " + getMacroName() + "\n"; - rewriter.create(loc, endMacro); - - rewriter.eraseOp(op); - - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// SCF Control-Flow Pre-Lowering -// -// EmitC translation supports `emitc.for`/`emitc.if` plus CFG-style -// `cf.br`/`cf.cond_br`. Upstream SCFToEmitC patterns only cover `scf.for` and -// `scf.if`, so we pre-lower some SCF ops into those supported forms. -//===----------------------------------------------------------------------===// - -namespace { - -static bool isTriviallyInlineableExecuteRegion(scf::ExecuteRegionOp op) { - Region &r = op.getRegion(); - if (!r.hasOneBlock()) - return false; - Block &b = r.front(); - return isa_and_nonnull(b.getTerminator()); -} - -static bool needsWholeFunctionSCFToCF(func::FuncOp func) { - bool needs = false; - func.walk([&](Operation *op) { - if (!isa(op)) - return WalkResult::advance(); - Operation *parentOp = op->getParentOp(); - - // `scf.execute_region` can legally appear in single-block parents. Only - // require whole-function SCFToCF if we need to lower it into CFG blocks - // (multi-block region / non-trivial terminators). 
- if (auto exec = dyn_cast(op)) { - if (parentOp && parentOp->hasTrait() && - !isTriviallyInlineableExecuteRegion(exec)) { - needs = true; - return WalkResult::interrupt(); - } - return WalkResult::advance(); - } - - if (parentOp && parentOp->hasTrait()) { - needs = true; - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - return needs; -} - -// scf.execute_region is semantically just an inlined region producing results -// via scf.yield. Inline it to the parent block to avoid extra lowering needs. -struct SCFExecuteRegionInline - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(scf::ExecuteRegionOp op, - PatternRewriter &rewriter) const override { - if (op.getRegion().empty()) - return rewriter.notifyMatchFailure(op, "expected non-empty region"); - - Block &innerBlock = op.getRegion().front(); - auto yield = dyn_cast(innerBlock.getTerminator()); - if (!yield) - return rewriter.notifyMatchFailure(op, "expected scf.yield terminator"); - - // Move the body operations before the execute_region op. - rewriter.inlineBlockBefore(&innerBlock, op.getOperation(), ValueRange{}); - - // Replace execute_region results with yielded values, then erase the yield. - rewriter.replaceOp(op, yield.getOperands()); - rewriter.eraseOp(yield); - return success(); - } -}; - -// Lower scf.execute_region into CFG blocks with cf.br/cf.cond_br by inlining the -// region blocks into the parent region and rewriting scf.yield to branch into a -// continuation block carrying results. -// -// Note: This requires the parent region to allow multiple blocks (e.g. the -// function body CFG region). For execute_region nested in single-block regions -// (scf.for/scf.if), run SCFToCF first to eliminate the single-block constraint. 
-struct SCFExecuteRegionToCF : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(scf::ExecuteRegionOp op, - PatternRewriter &rewriter) const override { - if (isTriviallyInlineableExecuteRegion(op)) - return rewriter.notifyMatchFailure(op, "trivially inlineable"); - - Operation *parentOp = op->getParentOp(); - if (parentOp && parentOp->hasTrait()) { - return rewriter.notifyMatchFailure( - op, "cannot lower scf.execute_region inside a single-block parent region"); - } - - if (op.getRegion().empty()) - return rewriter.notifyMatchFailure(op, "expected non-empty region"); - - Location loc = op.getLoc(); - Block *curBlock = op->getBlock(); - Region *parentRegion = curBlock->getParent(); - - // Split the parent block so we can branch to a continuation block with phi - // arguments for the execute_region results. - auto execIt = Block::iterator(op.getOperation()); - Block *continueBlock = rewriter.splitBlock(curBlock, std::next(execIt)); - - SmallVector contArgs; - contArgs.reserve(op.getNumResults()); - for (Type t : op.getResultTypes()) - contArgs.push_back(continueBlock->addArgument(t, loc)); - - for (auto it : llvm::enumerate(op.getResults())) - it.value().replaceAllUsesWith(contArgs[it.index()]); - - // Capture blocks before moving the region. - SmallVector movedBlocks; - movedBlocks.reserve(op.getRegion().getBlocks().size()); - for (Block &b : op.getRegion()) - movedBlocks.push_back(&b); - Block *entryBlock = &op.getRegion().front(); - - // Inline the execute_region blocks into the parent region right before the - // continuation block. - rewriter.inlineRegionBefore(op.getRegion(), *parentRegion, - continueBlock->getIterator()); - - // Replace all scf.yield terminators with a branch to the continuation. 
- for (Block *b : movedBlocks) { - auto yield = dyn_cast(b->getTerminator()); - if (!yield) - continue; - rewriter.setInsertionPoint(yield); - rewriter.create(loc, continueBlock, yield.getOperands()); - rewriter.eraseOp(yield); - } - - // Replace execute_region itself with a branch to the inlined entry block. - rewriter.setInsertionPoint(op); - rewriter.create(loc, entryBlock, ValueRange{}); - rewriter.eraseOp(op); - return success(); - } -}; - -// Lower scf.index_switch into CFG blocks with cf.cond_br/cf.br so that we can -// avoid `scf.if` result materialization quirks (and avoid relying on cf.switch, -// which is not supported by EmitC C++ translation). -struct SCFIndexSwitchToCF : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - static LogicalResult cloneYieldingBlockAndBranchTo( - PatternRewriter &rewriter, Location loc, Block &srcBlock, Block *destBlock, - Block *continueBlock) { - rewriter.setInsertionPointToEnd(destBlock); - - IRMapping mapping; - for (Operation &inner : srcBlock.without_terminator()) - rewriter.clone(inner, mapping); - - auto yield = dyn_cast(srcBlock.getTerminator()); - if (!yield) - return failure(); - - SmallVector yieldOperands; - yieldOperands.reserve(yield.getNumOperands()); - for (Value v : yield.getOperands()) - yieldOperands.push_back(mapping.lookupOrDefault(v)); - - rewriter.create(loc, continueBlock, yieldOperands); - return success(); - } - - static Block *splitBlockForContinuation(PatternRewriter &rewriter, - scf::IndexSwitchOp op) { - auto switchIt = Block::iterator(op.getOperation()); - return rewriter.splitBlock(op->getBlock(), std::next(switchIt)); - } - - static void addContinuationArguments(PatternRewriter &rewriter, - scf::IndexSwitchOp op, Location loc, - Block *continueBlock) { - SmallVector contArgs; - contArgs.reserve(op.getNumResults()); - for (Type type : op.getResultTypes()) - contArgs.push_back(continueBlock->addArgument(type, loc)); - for (auto result : llvm::enumerate(op.getResults())) 
- result.value().replaceAllUsesWith(contArgs[result.index()]); - } - - static void createIndexSwitchBlocks(PatternRewriter &rewriter, - Region *parentRegion, - Region::iterator insertPt, - unsigned numCases, - SmallVectorImpl &checkBlocks, - Block *&defaultBlock, - SmallVectorImpl &caseBlocks) { - checkBlocks.reserve(numCases); - caseBlocks.reserve(numCases); - for (unsigned i = 0; i < numCases; ++i) - checkBlocks.push_back(rewriter.createBlock(parentRegion, insertPt)); - defaultBlock = rewriter.createBlock(parentRegion, insertPt); - for (unsigned i = 0; i < numCases; ++i) - caseBlocks.push_back(rewriter.createBlock(parentRegion, insertPt)); - } - - static void populateIndexSwitchCheckBlocks( - PatternRewriter &rewriter, Location loc, Value selector, - ArrayRef cases, ArrayRef checkBlocks, - ArrayRef caseBlocks, Block *defaultBlock) { - for (unsigned i = 0; i < checkBlocks.size(); ++i) { - rewriter.setInsertionPointToEnd(checkBlocks[i]); - Value caseVal = rewriter.create(loc, cases[i]); - Value cond = rewriter.create( - loc, arith::CmpIPredicate::eq, selector, caseVal); - Block *falseDest = - (i + 1 < checkBlocks.size()) ? 
checkBlocks[i + 1] : defaultBlock; - rewriter.create(loc, cond, caseBlocks[i], ValueRange{}, - falseDest, ValueRange{}); - } - } - - LogicalResult matchAndRewrite(scf::IndexSwitchOp op, - PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - Operation *parentOp = op->getParentOp(); - if (parentOp && parentOp->hasTrait()) { - return rewriter.notifyMatchFailure( - op, "cannot lower scf.index_switch inside a single-block parent region"); - } - - Block *curBlock = op->getBlock(); - Region *parentRegion = curBlock->getParent(); - Block *continueBlock = splitBlockForContinuation(rewriter, op); - addContinuationArguments(rewriter, op, loc, continueBlock); - - unsigned numCases = op.getCases().size(); - auto insertPt = continueBlock->getIterator(); - - SmallVector checkBlocks; - SmallVector caseBlocks; - Block *defaultBlock = nullptr; - createIndexSwitchBlocks(rewriter, parentRegion, insertPt, numCases, - checkBlocks, defaultBlock, caseBlocks); - - Value selector = op.getArg(); - auto cases = op.getCases(); - populateIndexSwitchCheckBlocks(rewriter, loc, selector, cases, checkBlocks, - caseBlocks, defaultBlock); - - // Fill case blocks and default block with cloned bodies + branch to cont. - for (unsigned i = 0; i < numCases; ++i) { - if (failed(cloneYieldingBlockAndBranchTo( - rewriter, loc, op.getCaseBlock(i), caseBlocks[i], continueBlock))) - return rewriter.notifyMatchFailure(op, "expected scf.yield terminator"); - } - if (failed(cloneYieldingBlockAndBranchTo(rewriter, loc, op.getDefaultBlock(), - defaultBlock, continueBlock))) - return rewriter.notifyMatchFailure(op, "expected scf.yield terminator"); - - // Replace the original switch op with a branch into the check chain. - Block *entryDest = numCases ? checkBlocks[0] : defaultBlock; - rewriter.setInsertionPointAfter(op); - rewriter.create(loc, entryDest, ValueRange{}); - rewriter.eraseOp(op); - return success(); - } -}; - -// Lower scf.while into CFG blocks with cf.br/cf.cond_br. 
-// -// Note: This requires the parent region to allow multiple blocks. In -// particular, scf.if/scf.for regions are single-block and cannot contain this -// lowering. -struct SCFWhileToCF : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - static LogicalResult validateWhileResultUses(scf::WhileOp op) { - Block *parentBlock = op->getBlock(); - for (Value result : op.getResults()) { - for (OpOperand &use : result.getUses()) { - if (use.getOwner()->getBlock() != parentBlock) - return failure(); - } - } - return success(); - } - - static Block *splitAfterWhileBlock(PatternRewriter &rewriter, - scf::WhileOp op) { - auto whileIt = Block::iterator(op.getOperation()); - return rewriter.splitBlock(op->getBlock(), std::next(whileIt)); - } - - static void addWhileExitArguments(PatternRewriter &rewriter, scf::WhileOp op, - Location loc, Block *afterWhileBlock) { - SmallVector exitArgs; - exitArgs.reserve(op.getNumResults()); - for (Type type : op.getResultTypes()) - exitArgs.push_back(afterWhileBlock->addArgument(type, loc)); - for (auto result : llvm::enumerate(op.getResults())) - result.value().replaceAllUsesWith(exitArgs[result.index()]); - } - - static Block *createWhileHeaderBlock(PatternRewriter &rewriter, - scf::WhileOp op, Location loc, - Block *afterWhileBlock) { - SmallVector headerArgTypes; - for (Value init : op.getInits()) - headerArgTypes.push_back(init.getType()); - SmallVector headerArgLocs(headerArgTypes.size(), loc); - return rewriter.createBlock(afterWhileBlock->getParent(), - afterWhileBlock->getIterator(), headerArgTypes, - headerArgLocs); - } - - static Block *createWhileBodyBlock(PatternRewriter &rewriter, scf::WhileOp op, - Location loc, Block *afterWhileBlock) { - Block &afterRegionBlock = op.getAfter().front(); - SmallVector bodyArgTypes(afterRegionBlock.getArgumentTypes().begin(), - afterRegionBlock.getArgumentTypes().end()); - SmallVector bodyArgLocs(bodyArgTypes.size(), loc); - return 
rewriter.createBlock(afterWhileBlock->getParent(), - afterWhileBlock->getIterator(), bodyArgTypes, - bodyArgLocs); - } - - static void rewriteWhileTerminators(PatternRewriter &rewriter, Location loc, - Block *headerBlock, Block *bodyBlock, - Block *afterWhileBlock) { - auto condOp = cast(headerBlock->getTerminator()); - rewriter.setInsertionPoint(condOp); - rewriter.create(loc, condOp.getCondition(), - /*trueDest=*/bodyBlock, - /*trueOperands=*/condOp.getArgs(), - /*falseDest=*/afterWhileBlock, - /*falseOperands=*/condOp.getArgs()); - rewriter.eraseOp(condOp); - - auto yieldOp = cast(bodyBlock->getTerminator()); - rewriter.setInsertionPoint(yieldOp); - rewriter.create(loc, headerBlock, yieldOp.getOperands()); - rewriter.eraseOp(yieldOp); - } - - LogicalResult matchAndRewrite(scf::WhileOp op, - PatternRewriter &rewriter) const override { - Operation *parentOp = op->getParentOp(); - if (parentOp && parentOp->hasTrait()) { - return rewriter.notifyMatchFailure( - op, "cannot lower scf.while inside a single-block parent region"); - } - - if (failed(validateWhileResultUses(op))) - return rewriter.notifyMatchFailure( - op, "unsupported: while results used outside the parent block"); - - auto loc = op.getLoc(); - Block *afterWhileBlock = splitAfterWhileBlock(rewriter, op); - addWhileExitArguments(rewriter, op, loc, afterWhileBlock); - Block *headerBlock = createWhileHeaderBlock(rewriter, op, loc, - afterWhileBlock); - Block *bodyBlock = createWhileBodyBlock(rewriter, op, loc, afterWhileBlock); - - // Move the before/after region bodies into the new CFG blocks. - Block &afterRegionBlock = op.getAfter().front(); - rewriter.mergeBlocks(&op.getBefore().front(), headerBlock, - headerBlock->getArguments()); - rewriter.mergeBlocks(&afterRegionBlock, bodyBlock, bodyBlock->getArguments()); - rewriteWhileTerminators(rewriter, loc, headerBlock, bodyBlock, - afterWhileBlock); - - // Replace scf.while itself with a branch to the header. 
- rewriter.setInsertionPoint(op); - rewriter.create(loc, headerBlock, op.getInits()); - rewriter.eraseOp(op); - return success(); - } -}; - -// Lower cf.switch into chained comparisons and cf.cond_br/cf.br. -// -// EmitC C++ translation currently supports cf.br/cf.cond_br, but not cf.switch. -struct CFSwitchToCondBr : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - static SmallVector> - collectSwitchCaseOperands(cf::SwitchOp op) { - SmallVector> caseOperands; - caseOperands.reserve(op.getCaseDestinations().size()); - for (auto range : op.getCaseOperands()) - caseOperands.emplace_back(range.begin(), range.end()); - return caseOperands; - } - - static SmallVector getSwitchCaseValues(cf::SwitchOp op) { - SmallVector caseValues; - if (auto caseValuesAttr = op.getCaseValues()) { - for (APInt value : caseValuesAttr->getValues()) - caseValues.push_back(value); - } - return caseValues; - } - - static SmallVector createSwitchCheckBlocks(PatternRewriter &rewriter, - Region *parentRegion, - Block *curBlock, - size_t numCases) { - auto insertPt = std::next(curBlock->getIterator()); - SmallVector checkBlocks; - checkBlocks.reserve(numCases); - for (size_t i = 0; i < numCases; ++i) - checkBlocks.push_back(rewriter.createBlock(parentRegion, insertPt)); - return checkBlocks; - } - - static LogicalResult populateSwitchCheckBlocks( - PatternRewriter &rewriter, Location loc, Value flag, IntegerType flagTy, - ArrayRef caseValues, ArrayRef caseDests, - ArrayRef> caseOperands, Block *defaultDest, - ValueRange defaultOperands, ArrayRef checkBlocks, - cf::SwitchOp op) { - for (size_t i = 0; i < caseDests.size(); ++i) { - rewriter.setInsertionPointToEnd(checkBlocks[i]); - APInt caseVal = caseValues[i]; - if (caseVal.getBitWidth() != flagTy.getWidth()) { - return rewriter.notifyMatchFailure( - op, "case value bitwidth doesn't match flag type"); - } - - Value caseConst = rewriter.create( - loc, flagTy, rewriter.getIntegerAttr(flagTy, caseVal)); - Value cond = 
rewriter.create( - loc, arith::CmpIPredicate::eq, flag, caseConst); - Block *falseDest = - (i + 1 < checkBlocks.size()) ? checkBlocks[i + 1] : defaultDest; - ValueRange falseOperands = - (i + 1 < checkBlocks.size()) ? ValueRange{} : defaultOperands; - rewriter.create(loc, cond, caseDests[i], - caseOperands[i], falseDest, - falseOperands); - } - return success(); - } - - LogicalResult matchAndRewrite(cf::SwitchOp op, - PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - Operation *parentOp = op->getParentOp(); - if (parentOp && parentOp->hasTrait()) { - return rewriter.notifyMatchFailure( - op, "cannot lower cf.switch inside a single-block parent region"); - } - - Block *curBlock = op->getBlock(); - Region *parentRegion = curBlock->getParent(); - - Value flag = op.getFlag(); - auto flagTy = dyn_cast(flag.getType()); - if (!flagTy) - return rewriter.notifyMatchFailure(op, "expected integer switch flag"); - - SmallVector defaultOperands(op.getDefaultOperands().begin(), - op.getDefaultOperands().end()); - Block *defaultDest = op.getDefaultDestination(); - - SmallVector caseDests(op.getCaseDestinations().begin(), - op.getCaseDestinations().end()); - SmallVector> caseOperands = collectSwitchCaseOperands(op); - - if (caseDests.empty()) { - rewriter.replaceOpWithNewOp(op, defaultDest, defaultOperands); - return success(); - } - - if (!op.getCaseValues()) - return rewriter.notifyMatchFailure(op, "missing case_values"); - SmallVector caseValues = getSwitchCaseValues(op); - - if (caseValues.size() != caseDests.size()) - return rewriter.notifyMatchFailure(op, "case_values/destinations mismatch"); - if (caseOperands.size() != caseDests.size()) - return rewriter.notifyMatchFailure(op, "case_operands/destinations mismatch"); - - SmallVector checkBlocks = - createSwitchCheckBlocks(rewriter, parentRegion, curBlock, - caseDests.size()); - if (failed(populateSwitchCheckBlocks(rewriter, loc, flag, flagTy, - caseValues, caseDests, caseOperands, - defaultDest, 
defaultOperands, - checkBlocks, op))) { - return failure(); - } - - // Replace the switch terminator with a branch into the first check block. - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp(op, checkBlocks.front(), - ValueRange{}); - return success(); - } -}; -} // namespace static void populatePTOToEmitCPatterns(RewritePatternSet &patterns, TypeConverter &typeConverter, MLIRContext *ctx, diff --git a/lib/PTO/Transforms/PTOToEmitCSimpleOps.cpp b/lib/PTO/Transforms/PTOToEmitCSimpleOps.cpp index 24dc00c2e..a2127a34e 100644 --- a/lib/PTO/Transforms/PTOToEmitCSimpleOps.cpp +++ b/lib/PTO/Transforms/PTOToEmitCSimpleOps.cpp @@ -178,8 +178,38 @@ struct PTOHistogramToEmitC : public OpConversionPattern { Value idx = peelUnrealized(adaptor.getIdx()); Value dst = peelUnrealized(adaptor.getDst()); - auto templateArgs = rewriter.getArrayAttr({emitc::OpaqueAttr::get( - ctx, op.getIsMSB() ? "HistByte::BYTE_1" : "HistByte::BYTE_0")}); + StringRef histByte = "HistByte::BYTE_1"; + int64_t byte = 1; + auto byteAttr = op.getByteAttr(); + if (byteAttr) + byte = byteAttr.getInt(); + if (auto legacyIsMSB = op->getAttrOfType("isMSB")) { + int64_t legacyByte = legacyIsMSB.getValue() ? 
1 : 0; + if (byteAttr && byte != legacyByte) + return rewriter.notifyMatchFailure( + op, "conflicting 'byte' and legacy 'isMSB' attributes"); + byte = legacyByte; + } + switch (byte) { + case 0: + histByte = "HistByte::BYTE_0"; + break; + case 1: + histByte = "HistByte::BYTE_1"; + break; + case 2: + histByte = "HistByte::BYTE_2"; + break; + case 3: + histByte = "HistByte::BYTE_3"; + break; + default: + return rewriter.notifyMatchFailure(op, + "expected byte to be in range [0, 3]"); + } + + auto templateArgs = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, histByte)}); rewriter.create( loc, TypeRange{}, "THISTOGRAM", /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, From b535fdc4813eb2faf41635a74c22b94dbd2b32bf Mon Sep 17 00:00:00 2001 From: HecreReed <821896444@qq.com> Date: Tue, 26 May 2026 19:53:20 +0800 Subject: [PATCH 8/8] fix: add utf-8 headers for python buildcheck --- python/pto/dialects/pto.py | 1 + tools/ptobc/tests/opcode_coverage_check.py | 1 + 2 files changed, 2 insertions(+) diff --git a/python/pto/dialects/pto.py b/python/pto/dialects/pto.py index a0b5037c7..eca06e81b 100644 --- a/python/pto/dialects/pto.py +++ b/python/pto/dialects/pto.py @@ -1,4 +1,5 @@ # Copyright (c) 2026 Huawei Technologies Co., Ltd. +# -*- coding: utf-8 -*- # This program is free software, you can redistribute it and/or modify it under the terms and conditions of # CANN Open Software License Agreement Version 2.0 (the "License"). # Please refer to the License for details. You may not use this file except in compliance with the License. diff --git a/tools/ptobc/tests/opcode_coverage_check.py b/tools/ptobc/tests/opcode_coverage_check.py index 757c7ea40..c4d99ff5b 100755 --- a/tools/ptobc/tests/opcode_coverage_check.py +++ b/tools/ptobc/tests/opcode_coverage_check.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# -*- coding: utf-8 -*- # Copyright (c) 2026 Huawei Technologies Co., Ltd. 
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of # CANN Open Software License Agreement Version 2.0 (the "License").