diff --git a/lib/PTO/Transforms/CMakeLists.txt b/lib/PTO/Transforms/CMakeLists.txt index b8194d0a4..35ea68387 100644 --- a/lib/PTO/Transforms/CMakeLists.txt +++ b/lib/PTO/Transforms/CMakeLists.txt @@ -16,8 +16,20 @@ add_mlir_dialect_library(PTOTransforms PTOInjectBarrierAllSync.cpp InsertSync/InsertSyncDebug.cpp PTOViewToMemref.cpp + PTOViewToMemrefCompute.cpp PTOValidateIntToPtrUses.cpp PTOToEmitC.cpp + PTOToEmitCArith.cpp + PTOToEmitCTilePatterns.cpp + PTOToEmitCTilePatternsExtra.cpp + PTOToEmitCTileMaterialization.cpp + PTOToEmitCSync.cpp + PTOToEmitCComm.cpp + PTOToEmitCKernelOps.cpp + PTOToEmitCControlFlow.cpp + PTOToEmitCSimpleOps.cpp + PTOToEmitCRuntimeOps.cpp + PTOToEmitCMemoryOps.cpp Utils.cpp OptMemPlanForPipeline.cpp AllocToPointerCast.cpp @@ -50,6 +62,7 @@ add_mlir_dialect_library(PTOTransforms GraphSyncSolver/GraphSolver.cpp GraphSyncSolver/EventIdSolver.cpp GraphSyncSolver/SyncSolver.cpp + GraphSyncSolver/SyncSolverMerge.cpp GraphSyncSolver/SyncSolverCodeGen.cpp LoweringSyncToPipe.cpp PTOVerifyTFreePass.cpp diff --git a/lib/PTO/Transforms/GraphSyncSolver/SyncSolver.cpp b/lib/PTO/Transforms/GraphSyncSolver/SyncSolver.cpp index 23a4032a6..c920fd69e 100644 --- a/lib/PTO/Transforms/GraphSyncSolver/SyncSolver.cpp +++ b/lib/PTO/Transforms/GraphSyncSolver/SyncSolver.cpp @@ -1911,666 +1911,3 @@ void Solver::handleConflict(Occurrence *occ1, Occurrence *occ2, } } -void Solver::calcAllEventIds() { - for (auto &[pipes, eventIdSolver] : eventIdSolver) { - assert(eventIdSolver != nullptr); - - [[maybe_unused]] auto result = - eventIdSolver->shrinkEventIdMaxToEventIdNum(); - assert(llvm::succeeded(result)); - assert(eventIdSolver->isColorable()); - } -} - -void Solver::collectBackwardSyncEventIds() { - LLVM_DEBUG(llvm::dbgs() << "collectBackwardSyncEventIds\n";); - for (auto &conflictPair : chosenConflictedPairs) { - if (!conflictPair->isUseless && conflictPair->isInnerBackward && - conflictPair->eventIdNode != nullptr) { - LLVM_DEBUG(llvm::dbgs() << " " << conflictPair->str() << "\n";); - for (auto eventId : conflictPair->eventIdNode->getEventIds()) { - auto &e = backwardSyncEvents[conflictPair->backwardSyncLoopOp] - [{conflictPair->setCorePipeInfo, - conflictPair->waitCorePipeInfo}][eventId]; - e = std::max(e, conflictPair->eventIdInfo.eventIdRepeatNum); - } - } - } -} - -void Solver::resetAndBuildSetWaitOpIndex(const SyncMap &syncMapBefore, - const SyncMap &syncMapAfter) { - globalSetWaitIndex = 0; - setWaitStartIndex.clear(); - setWaitEndIndex.clear(); - setWaitStartIndexInclusive.clear(); - setWaitEndIndexInclusive.clear(); - setWaitFlagOpsIndex.clear(); - collectSetWaitOpsIndexes(funcIr.get(), syncMapBefore, syncMapAfter); -} - -std::set> & -Solver::getSetWaitOpsIndexRef(pto::PIPE pipeSrc, pto::PIPE pipeDst, - int64_t eventId) { - auto key = std::make_tuple(pipeSrc, pipeDst, eventId); - return setWaitFlagOpsIndex[key]; -} - -// Collect indices for all Set/Wait ops to facilitate merging decisions. -void Solver::collectSetWaitOpsIndexes(OperationBase *op, - const SyncMap &syncMapBefore, - const SyncMap &syncMapAfter) { - assert(op != nullptr); - setWaitStartIndexInclusive[op] = globalSetWaitIndex++; - if (syncMapBefore.count(op)) { - auto *it = syncMapBefore.find(op); - assert(it != syncMapBefore.end()); - for (auto &syncOp : it->second) { - if (auto *setWaitOp = llvm::dyn_cast(syncOp.get())) { - for (auto eventId : setWaitOp->eventIds) { - auto &index = getSetWaitOpsIndexRef(setWaitOp->pipeSrc, - setWaitOp->pipeDst, eventId); - index.insert({globalSetWaitIndex++, setWaitOp}); - } - } - } - } - setWaitStartIndex[op] = globalSetWaitIndex++; - if (auto *scopeOp = llvm::dyn_cast(op)) { - for (auto &childOp : scopeOp->body) { - collectSetWaitOpsIndexes(childOp.get(), syncMapBefore, syncMapAfter); - } - } - setWaitEndIndex[op] = globalSetWaitIndex++; - if (syncMapAfter.count(op)) { - auto *it = syncMapAfter.find(op); - assert(it != syncMapAfter.end()); - for (auto &syncOp : it->second) { - if (auto *setWaitOp = llvm::dyn_cast(syncOp.get())) { - for (auto eventId : setWaitOp->eventIds) { - auto &index = getSetWaitOpsIndexRef(setWaitOp->pipeSrc, - setWaitOp->pipeDst, eventId); - index.insert({globalSetWaitIndex++, setWaitOp}); - } - } - } - } - setWaitEndIndexInclusive[op] = globalSetWaitIndex++; -} - -bool Solver::checkBackwardSyncEventsContains(OperationBase *op, - CorePipeInfo corePipeSrc, - CorePipeInfo corePipeDst, - int64_t eventId) { - auto *it1 = backwardSyncEvents.find(op); - if (it1 == backwardSyncEvents.end()) { - return false; - } - auto it2 = it1->second.find({corePipeSrc, corePipeDst}); - if (it2 == it1->second.end()) { - return false; - } - return it2->second.contains(eventId); -} - -bool Solver::checkBackwardSyncEventsContainsAfterMerge( - OperationBase *op, CorePipeInfo corePipeSrc, CorePipeInfo corePipeDst) { - auto *it1 = backwardSyncEventsAfterMerge.find(op); - if (it1 == backwardSyncEventsAfterMerge.end()) { - return false; - } - return it1->second.contains({corePipeSrc, corePipeDst}); -} - -// Check whether a backward-sync event id can be merged at scope level. -bool Solver::checkMergeable(Scope *scopeOp, CorePipeInfo corePipeSrc, - CorePipeInfo corePipeDst, int64_t eventId, - bool shouldBeUsedAtleastOnce) { - auto &index = - getSetWaitOpsIndexRef(corePipeSrc.pipe, corePipeDst.pipe, eventId); - if (shouldBeUsedAtleastOnce) { - auto it = index.lower_bound({setWaitStartIndexInclusive[scopeOp], nullptr}); - bool usedAtleastOnce = - it != index.end() && it->first < setWaitEndIndexInclusive[scopeOp]; - if (!usedAtleastOnce) { - return false; - } - } - { - auto it1 = - index.lower_bound({setWaitStartIndexInclusive[scopeOp], nullptr}); - auto it2 = index.lower_bound({setWaitEndIndex[scopeOp], nullptr}); - bool usedBefore = - it1 != index.end() && it1->first < setWaitStartIndex[scopeOp]; - bool usedAfter = - it2 != index.end() && it2->first < setWaitEndIndexInclusive[scopeOp]; - if (usedBefore || usedAfter) { - return false; - } - } - if (auto *conditionOp = llvm::dyn_cast(scopeOp)) { - if (!conditionOp->hasFalseScope()) { - return false; - } - return checkMergeable(conditionOp->getTrueScope(), corePipeSrc, corePipeDst, - eventId, true) && - checkMergeable(conditionOp->getFalseScope(), corePipeSrc, - corePipeDst, eventId, true); - } - if (auto *loopOp = llvm::dyn_cast(scopeOp)) { - for (auto &childOp : loopOp->body) { - if (auto *childScopeOp = llvm::dyn_cast(childOp.get())) { - if (!checkMergeable(childScopeOp, corePipeSrc, corePipeDst, eventId, - false)) { - return false; - } - } - } - for (auto &childOp : loopOp->body) { - if (auto *childScopeOp = llvm::dyn_cast(childOp.get())) { - if (checkMergeable(childScopeOp, corePipeSrc, corePipeDst, eventId, - true)) { - return true; - } - } - } - return false; - } - for (auto &childOp : scopeOp->body) { - auto it1 = - index.lower_bound({setWaitStartIndexInclusive[childOp.get()], nullptr}); - auto it2 = index.lower_bound({setWaitEndIndex[childOp.get()], nullptr}); - bool usedAtleastOnce = it1 != index.end() && - it1->first < setWaitEndIndexInclusive[childOp.get()]; - if (!usedAtleastOnce) { - continue; - } - bool before = - it1 != index.end() && it1->first < setWaitStartIndex[childOp.get()]; - bool after = it2 != index.end() && - it2->first < setWaitEndIndexInclusive[childOp.get()]; - if (before || after) { - return false; - } - if (!checkBackwardSyncEventsContains(childOp.get(), corePipeSrc, - corePipeDst, eventId)) { - return false; - } - if (checkBackwardSyncEventsContainsAfterMerge(childOp.get(), corePipeSrc, - corePipeDst)) { - return false; - } - } - return true; -} - -// Attempt to merge backward sync events across children and prune duplicates. -void Solver::mergeBackwardSyncEventIds(OperationBase *op) { - auto *scopeOp = llvm::dyn_cast_if_present(op); - if (scopeOp == nullptr) { - return; - } - for (auto &op : scopeOp->body) { - mergeBackwardSyncEventIds(op.get()); - } - - if (llvm::isa_and_present(op)) { - return; - } - if (llvm::isa_and_present(op->parentOp)) { - return; - } - - auto *conditionOp = llvm::dyn_cast(op); - if (conditionOp != nullptr) { - if (!conditionOp->hasFalseScope()) { - return; - } - } - - llvm::DenseSet> toBeErased; - - llvm::SmallVector coreTypes; - if (options.isCrossCoreMode()) { - coreTypes = {pto::TCoreType::VECTOR, pto::TCoreType::CUBE}; - } else { - coreTypes = {pto::TCoreType::CUBE_OR_VECTOR}; - } - size_t pipeNumMax = static_cast(pto::PIPE::PIPE_NUM); - const int64_t eventIdMax = getHWAvailableEventIdNum(options.syncMode); - - for (int64_t eventId = 0; eventId < eventIdMax; ++eventId) { - for (auto coreSrc : coreTypes) { - for (auto coreDst : coreTypes) { - for (size_t pipeSrcInt = 0; pipeSrcInt < pipeNumMax; pipeSrcInt++) { - for (size_t pipeDstInt = 0; pipeDstInt < pipeNumMax; pipeDstInt++) { - auto pipeSrc = static_cast(pipeSrcInt); - auto pipeDst = static_cast(pipeDstInt); - auto corePipeSrc = CorePipeInfo(coreSrc, pipeSrc); - auto corePipeDst = CorePipeInfo(coreDst, pipeDst); - if (checkBackwardSyncEventsContains(scopeOp, corePipeSrc, - corePipeDst, eventId)) { - continue; - } - if (checkMergeable(scopeOp, corePipeSrc, corePipeDst, eventId)) { - toBeErased.insert({corePipeSrc, corePipeDst, eventId}); - backwardSyncEvents[scopeOp][{corePipeSrc, corePipeDst}].insert( - {eventId, 1}); - } - } - } - } - } - } - - if (isa(scopeOp)) { - for (auto &op : scopeOp->body) { - if (auto *block = llvm::dyn_cast(op.get())) { - for (auto &childOp : block->body) { - if (auto *childScopeOp = llvm::dyn_cast(childOp.get())) { - for (auto [corePipeSrc, corePipeDst, eventId] : toBeErased) { - if (checkBackwardSyncEventsContains(childScopeOp, corePipeSrc, - corePipeDst, eventId)) { - auto key = std::make_tuple(corePipeSrc, corePipeDst); - backwardSyncEvents[childScopeOp][key].erase(eventId); - if (backwardSyncEvents[childScopeOp][key].empty()) { - backwardSyncEvents[childScopeOp].erase(key); - } - } - } - } - } - } - } - } else { - for (auto &childOp : scopeOp->body) { - if (auto *childScopeOp = llvm::dyn_cast(childOp.get())) { - for (auto [corePipeSrc, corePipeDst, eventId] : toBeErased) { - if (checkBackwardSyncEventsContains(childScopeOp, corePipeSrc, - corePipeDst, eventId)) { - auto key = std::make_tuple(corePipeSrc, corePipeDst); - backwardSyncEvents[childScopeOp][key].erase(eventId); - if (backwardSyncEvents[childScopeOp][key].empty()) { - backwardSyncEvents[childScopeOp].erase(key); - } - } - } - } - } - } -} - -void Solver::mergeBackwardSyncPairs(SyncMap &syncMapBefore, - SyncMap &syncMapAfter) { - if (!options.moveOutAndMergeBackwardSyncPairs) { - return; - } - if (options.isIntraCoreMode()) { - resetAndBuildSetWaitOpIndex(syncMapBefore, syncMapAfter); - auto *scopeOp = llvm::dyn_cast(funcIr.get()); - assert(scopeOp != nullptr && scopeOp->body.front() != nullptr); - mergeBackwardSyncEventIds(scopeOp->body.front().get()); - } -} - -SyncBeforeAfterMap Solver::getBeforeAfterSyncMaps() { - calcAllEventIds(); - SyncMap syncMapBefore, syncMapAfter; - std::vector conflictPairs; - for (auto &conflictPair : chosenConflictedPairs) { - conflictPairs.push_back(conflictPair.get()); - } - for (auto &conflictPair : persistentChosenConflictedPairs) { - conflictPairs.push_back(conflictPair.get()); - } - - for (auto *conflictPair : conflictPairs) { - if (conflictPair->isUseless) { - continue; - } - if (conflictPair->replacedWithUnitFlag) { - continue; - } - assert(conflictPair->setOp != nullptr && conflictPair->waitOp != nullptr); - if (conflictPair->isBarrier()) { - auto barrierOp = std::make_unique( - conflictPair->waitOp->op, conflictPair->waitOp->parentOp, - conflictPair->waitCorePipeInfo.pipe); - LLVM_DEBUG(barrierOp->debugId = conflictPair->id); - syncMapBefore[conflictPair->waitOp].push_back(std::move(barrierOp)); - } else { - assert(conflictPair->eventIdNode != nullptr); - auto setOp = std::make_unique( - conflictPair->setOp->op, conflictPair->setOp->parentOp, - conflictPair->eventIdNode->getEventIds(), - conflictPair->setCorePipeInfo.pipe, - conflictPair->waitCorePipeInfo.pipe); - auto waitOp = std::make_unique( - conflictPair->waitOp->op, conflictPair->waitOp->parentOp, - conflictPair->eventIdNode->getEventIds(), - conflictPair->setCorePipeInfo.pipe, - conflictPair->waitCorePipeInfo.pipe); - if (options.isCrossCoreMode()) { - setOp->coreType = conflictPair->setCorePipeInfo.coreType; - waitOp->coreType = conflictPair->waitCorePipeInfo.coreType; - } - setOp->eventIdInfo = conflictPair->eventIdInfo; - waitOp->eventIdInfo = conflictPair->eventIdInfo; - setOp->checkLastIter = conflictPair->setOnLastIterOnly; - waitOp->checkFirstIter = conflictPair->waitOnFirstIterOnly; - LLVM_DEBUG({ - setOp->debugId = conflictPair->id; - waitOp->debugId = conflictPair->id; - }); - assert(setOp != nullptr && waitOp != nullptr); - syncMapAfter[conflictPair->setOp].push_back(std::move(setOp)); - syncMapBefore[conflictPair->waitOp].push_front(std::move(waitOp)); - } - } - - collectBackwardSyncEventIds(); - mergeBackwardSyncPairs(syncMapBefore, syncMapAfter); - - for (auto &[op, mp] : backwardSyncEvents) { - if (mp.empty()) { - continue; - } - auto *scopeOp = llvm::dyn_cast(op); - assert(scopeOp != nullptr); - for (auto [setWaitCorePipes, eventIdsMp] : mp) { - if (eventIdsMp.empty()) { - continue; - } - llvm::SmallVector eventIds; - for (auto [eventId, repeatNum] : eventIdsMp) { - llvm::SmallVector curEventIds(repeatNum, eventId); - llvm::append_range(eventIds, curEventIds); - } - llvm::sort(eventIds); - auto [corePipeSrc, corePipeDst] = setWaitCorePipes; - auto setOp = - std::make_unique(scopeOp->op, scopeOp->parentOp, eventIds, - corePipeSrc.pipe, corePipeDst.pipe); - auto waitOp = - std::make_unique(scopeOp->op, scopeOp->parentOp, eventIds, - corePipeSrc.pipe, corePipeDst.pipe); - setOp->allAtOnce = true; - waitOp->allAtOnce = true; - if (options.isCrossCoreMode()) { - setOp->coreType = corePipeSrc.coreType; - waitOp->coreType = corePipeDst.coreType; - } - assert(setOp != nullptr && waitOp != nullptr); - syncMapBefore[scopeOp].push_back(std::move(setOp)); - syncMapAfter[scopeOp].push_front(std::move(waitOp)); - } - } - return std::make_pair(std::move(syncMapBefore), std::move(syncMapAfter)); -} - -void Solver::processConflict(Occurrence *occ1, Occurrence *occ2, - RWOperation *rwOp1, RWOperation *rwOp2, - bool isUseless) { - for (auto [corePipeSrc, corePipeDst] : checkMemoryConflicts(rwOp1, rwOp2)) { - if (options.alwaysUsePipeSAsWaitingPipe) { - corePipeDst.pipe = pto::PIPE::PIPE_S; - } - auto eventIdInfo = - getEventIdInfo(occ1, occ2, rwOp1, rwOp2, corePipeSrc, corePipeDst); - handleConflict(occ1, occ2, rwOp1, rwOp2, corePipeSrc, corePipeDst, - eventIdInfo, isUseless); - } -} - -// Main processing loop that iterates processingOrders and attempts to -// discover and record conflicts. -void Solver::processOrders() { - for (auto &[occ1, occ2, rwOp1, rwOp2, isUseless] : processingOrders) { - assert(occ1 != occ2); - assert(occ1->syncIrIndex < occ2->syncIrIndex); - if (checkVisited(occ1, occ2)) { - assert(false && "expected to not check a pair more than once."); - continue; - } - if (checkImpossibleOccPair(occ1, occ2) || checkAlreadySynced(occ1, occ2) || - skipMMad1DecomposedLoopOpt(occ1, occ2) || - checkSkipParallelLoop(occ1, occ2) || - checkSkipCrossCorePair(occ1, occ2)) { - continue; - } - DEBUG_WITH_TYPE("gss-sync-solver-checking", { - llvm::dbgs() << "checking: " << (isUseless ? "is-useless\n" : "\n"); - llvm::dbgs() << occ1->syncIrIndex << ' ' << occ1->startIndex << ' ' - << occ1->endIndex << ' ' << occ1->op->str(0, false) << '\n'; - llvm::dbgs() << occ2->syncIrIndex << ' ' << occ2->startIndex << ' ' - << occ2->endIndex << ' ' << occ2->op->str(0, false) << '\n'; - }); - if (checkAlreadySyncedWithUnitFlag(occ1, occ2)) { - continue; - } - processConflict(occ1, occ2, rwOp1, rwOp2, isUseless); - } -} - -void Solver::insertMergedBackwardSyncPairs() { - for (auto &[scopeOp, st] : backwardSyncEventsAfterMerge) { - for (auto &corePipeInfoPair : st) { - auto [corePipeSrc, corePipeDst] = corePipeInfoPair; - for (auto *scopeOcc : opAllOccurrences[scopeOp]) { - auto *parentScopeOcc = scopeOcc->parentOcc; - assert(parentScopeOcc != nullptr); - Occurrence *setOcc = nullptr; - Occurrence *waitOcc = nullptr; - auto startIndex = scopeOcc->startIndex; - auto endIndex = scopeOcc->endIndex; - if (isa(scopeOp)) { - setOcc = getBeforePlaceHolderOcc(scopeOcc); - waitOcc = getAfterPlaceHolderOcc(scopeOcc); - startIndex = setOcc->endIndex; - endIndex = waitOcc->startIndex; - } - auto conflictPair = std::make_unique( - nullptr, nullptr, nullptr, nullptr, setOcc, waitOcc, corePipeSrc, - corePipeDst, startIndex, endIndex); - assert(conflictPair->startIndex <= conflictPair->endIndex); - conflictPair->isUseless = true; - conflictPair->dontReuse = true; - conflictPair->dontCheckForConflict = true; - conflictPair->couldNotRun = false; // notice this - LLVM_DEBUG({ - llvm::dbgs() << "consider-merged-backward-pair: " - << scopeOp->str(0, false) << ' ' << conflictPair->str() - << "\n"; - }); - scopeOccChosenConflicts[parentScopeOcc].insert(conflictPair.get()); - chosenConflictedPairs.push_back(std::move(conflictPair)); - } - } - } -} - -llvm::LogicalResult Solver::considerOuterBackwardSyncPairs() { - if (!options.considerOuterBackwardSyncPairs) { - return llvm::failure(); - } - bool backwardPairsPositionChanged = false; - for (auto &[scopeOp, st] : backwardSyncEventsAfterMerge) { - SmallVector> toBeErased; - for (auto &corePipeInfoPair : st) { - if (!backwardSyncEvents.contains(scopeOp) || - !backwardSyncEvents[scopeOp].contains(corePipeInfoPair)) { - toBeErased.push_back(corePipeInfoPair); - } - } - if (!toBeErased.empty()) { - backwardPairsPositionChanged = true; - for (auto &corePipeInfoPair : toBeErased) { - st.erase(corePipeInfoPair); - } - } - } - int chosenOpsDepth = -1; - SmallVector chosenOps; - for (auto &[scopeOp, mp] : backwardSyncEvents) { - if (backwardSyncEventsAfterMerge.contains(scopeOp)) { - continue; - } - int scopeOpDepth = scopeOp->getDepth(); - if (chosenOpsDepth == scopeOpDepth) { - chosenOps.push_back(scopeOp); - } else if (chosenOpsDepth == -1 || chosenOpsDepth < scopeOpDepth) { - chosenOps.clear(); - chosenOps.push_back(scopeOp); - chosenOpsDepth = scopeOpDepth; - } - } - if (chosenOps.empty()) { - return llvm::failure(); - } - bool newPairIsInserted = false; - for (auto *chosenOp : chosenOps) { - for (auto &[corePipeInfoPair, eventIdsMp] : backwardSyncEvents[chosenOp]) { - assert(!eventIdsMp.empty()); - if (!eventIdsMp.empty()) { - auto [it, isInserted] = - backwardSyncEventsAfterMerge[chosenOp].insert(corePipeInfoPair); - newPairIsInserted |= isInserted; - } - } - } - return llvm::success(backwardPairsPositionChanged || newPairIsInserted); -} - -llvm::LogicalResult Solver::reuseSyncPairToSaveEventIds() { - if (!options.reuseSyncPairToSaveEventIds || barrierAllPairs.empty()) { - return llvm::failure(); - } - bool limitReached = true; - for (auto [corePipeSrc, corePipeDst] : barrierAllPairs) { - if (reusePairs[{corePipeSrc, corePipeDst}] < maxReuseNum) { - if (reusePairs[{corePipeSrc, corePipeDst}] <= - reusedPairs[{corePipeSrc, corePipeDst}]) { - reusePairs[{corePipeSrc, corePipeDst}] += 1; - limitReached = false; - } - } - } - DEBUG_WITH_TYPE("gss-sync-solver-reuse", { - llvm::dbgs() << "reusePairs: \n"; - for (auto [pipeCorePairs, cnt] : reusePairs) { - llvm::dbgs() << get<0>(pipeCorePairs).pipe << ' ' - << get<1>(pipeCorePairs).pipe << ' ' << cnt << '\n'; - } - }); - return llvm::success(!limitReached); -} - -llvm::LogicalResult Solver::disableMultiEventIdForBarrierAllPairs() { - if (!options.disableMultiEventIdForBarrierAllPairs || - barrierAllPairs.empty()) { - return llvm::failure(); - } - bool newPairIsInserted = false; - for (auto corePipeInfoPair : barrierAllPairs) { - auto [it, isInserted] = disabledMultiEventIdPairs.insert(corePipeInfoPair); - newPairIsInserted |= isInserted; - } - LLVM_DEBUG({ - if (newPairIsInserted) { - llvm::dbgs() << "disabled-multi-event-id-pairs: \n"; - for (auto &[corePipeSrc, corePipeDst] : disabledMultiEventIdPairs) { - llvm::dbgs() << corePipeSrc.coreType << ' ' << corePipeSrc.pipe << ' ' - << corePipeDst.coreType << ' ' << corePipeDst.pipe << '\n'; - } - } - }); - return llvm::success(newPairIsInserted); -} - -llvm::LogicalResult Solver::tryMovingOutBackwardSyncPairsToOuterLoops() { - if (!options.moveOutAndMergeBackwardSyncPairs || !options.isCrossCoreMode() || - dontMoveBackwardSyncPairsToOutmostLoop) { - return llvm::failure(); - } - if (!moveBackwardSyncPairsToOutmostLoop) { - moveBackwardSyncPairsToOutmostLoop = true; - return llvm::success(); - } - if (!barrierAllPairs.empty()) { - moveBackwardSyncPairsToOutmostLoop = false; - dontMoveBackwardSyncPairsToOutmostLoop = true; - return llvm::success(); - } - return llvm::failure(); -} - -// High-level solve orchestration with multiple passes and optional merging -// iterations. -llvm::LogicalResult Solver::runSolver(bool enableOpts1, bool enableOpts2) { - reset(/*resetEventIdRanOutOpts=*/true); - - int64_t runNum = 0; - while (runNum++ < maxRunNum) { - LLVM_DEBUG(llvm::dbgs() << "runNum: " << runNum << '\n'); - - reset(); - insertMergedBackwardSyncPairs(); - processOrders(); - - if (llvm::succeeded(tryMovingOutBackwardSyncPairsToOuterLoops())) { - continue; - } - - if (enableOpts1) { - if (options.considerOuterBackwardSyncPairs) { - getBeforeAfterSyncMaps(); - if (llvm::succeeded(considerOuterBackwardSyncPairs())) { - continue; - } - if (!barrierAllPairs.empty()) { - backwardSyncEventsAfterMerge.clear(); - } - } - } - - if (enableOpts2) { - if (!barrierAllPairs.empty()) { - if (llvm::succeeded(reuseSyncPairToSaveEventIds())) { - continue; - } - if (llvm::succeeded(disableMultiEventIdForBarrierAllPairs())) { - continue; - } - } - } - - if (!barrierAllPairs.empty()) { - pickAndInsertABarrierAll(); - reset(/*resetEventIdRanOutOpts=*/true); - continue; - } - break; - } - - reset(); - insertMergedBackwardSyncPairs(); - processOrders(); - - return llvm::success(runNum < maxRunNum); -} - -void Solver::solve() { - if (llvm::succeeded(runSolver())) { - return; - } - if (!options.isTestMode()) { - if (llvm::succeeded(runSolver(/*enableOpts1=*/false))) { - return; - } - if (llvm::succeeded( - runSolver(/*enableOpts1=*/false, /*enableOpts2=*/false))) { - return; - } - } - llvm_unreachable("GSS: runSolver() failed."); -} diff --git a/lib/PTO/Transforms/GraphSyncSolver/SyncSolverMerge.cpp b/lib/PTO/Transforms/GraphSyncSolver/SyncSolverMerge.cpp new file mode 100644 index 000000000..b35a37b79 --- /dev/null +++ b/lib/PTO/Transforms/GraphSyncSolver/SyncSolverMerge.cpp @@ -0,0 +1,705 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===--------- SyncSolver.cpp ------- Graph Sync Solver -------------------===// +//===----------------------------------------------------------------------===// + +#include "PTO/Transforms/GraphSyncSolver/SyncSolver.h" +#include "PTO/Transforms/GraphSyncSolver/GraphSolver.h" +#include "PTO/Transforms/GraphSyncSolver/MemInfo.h" +#include "PTO/Transforms/GraphSyncSolver/SyncSolverIR.h" +#include "PTO/Transforms/GraphSyncSolver/Utility.h" + +#include "PTO/IR/PTO.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Value.h" +#include "mlir/Interfaces/LoopLikeInterface.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/LogicalResult.h" +#include +#include +#include +#include +#include +#include +#include + +#define DEBUG_TYPE "PTO-gss-solver" + +using namespace mlir; +using namespace pto::syncsolver; + +void Solver::calcAllEventIds() { + for (auto &[pipes, eventIdSolver] : eventIdSolver) { + assert(eventIdSolver != nullptr); + + [[maybe_unused]] auto result = + eventIdSolver->shrinkEventIdMaxToEventIdNum(); + assert(llvm::succeeded(result)); + assert(eventIdSolver->isColorable()); + } +} + +void Solver::collectBackwardSyncEventIds() { + LLVM_DEBUG(llvm::dbgs() << "collectBackwardSyncEventIds\n";); + for (auto &conflictPair : chosenConflictedPairs) { + if (!conflictPair->isUseless && conflictPair->isInnerBackward && + conflictPair->eventIdNode != nullptr) { + LLVM_DEBUG(llvm::dbgs() << " " << conflictPair->str() << "\n";); + for (auto eventId : conflictPair->eventIdNode->getEventIds()) { + auto &e = backwardSyncEvents[conflictPair->backwardSyncLoopOp] + [{conflictPair->setCorePipeInfo, + conflictPair->waitCorePipeInfo}][eventId]; + e = std::max(e, conflictPair->eventIdInfo.eventIdRepeatNum); + } + } + } +} + +void Solver::resetAndBuildSetWaitOpIndex(const SyncMap &syncMapBefore, + const SyncMap &syncMapAfter) { + globalSetWaitIndex = 0; + setWaitStartIndex.clear(); + setWaitEndIndex.clear(); + setWaitStartIndexInclusive.clear(); + setWaitEndIndexInclusive.clear(); + setWaitFlagOpsIndex.clear(); + collectSetWaitOpsIndexes(funcIr.get(), syncMapBefore, syncMapAfter); +} + +std::set> & +Solver::getSetWaitOpsIndexRef(pto::PIPE pipeSrc, pto::PIPE pipeDst, + int64_t eventId) { + auto key = std::make_tuple(pipeSrc, pipeDst, eventId); + return setWaitFlagOpsIndex[key]; +} + +// Collect indices for all Set/Wait ops to facilitate merging decisions. +void Solver::collectSetWaitOpsIndexes(OperationBase *op, + const SyncMap &syncMapBefore, + const SyncMap &syncMapAfter) { + assert(op != nullptr); + setWaitStartIndexInclusive[op] = globalSetWaitIndex++; + if (syncMapBefore.count(op)) { + auto *it = syncMapBefore.find(op); + assert(it != syncMapBefore.end()); + for (auto &syncOp : it->second) { + if (auto *setWaitOp = llvm::dyn_cast(syncOp.get())) { + for (auto eventId : setWaitOp->eventIds) { + auto &index = getSetWaitOpsIndexRef(setWaitOp->pipeSrc, + setWaitOp->pipeDst, eventId); + index.insert({globalSetWaitIndex++, setWaitOp}); + } + } + } + } + setWaitStartIndex[op] = globalSetWaitIndex++; + if (auto *scopeOp = llvm::dyn_cast(op)) { + for (auto &childOp : scopeOp->body) { + collectSetWaitOpsIndexes(childOp.get(), syncMapBefore, syncMapAfter); + } + } + setWaitEndIndex[op] = globalSetWaitIndex++; + if (syncMapAfter.count(op)) { + auto *it = syncMapAfter.find(op); + assert(it != syncMapAfter.end()); + for (auto &syncOp : it->second) { + if (auto *setWaitOp = llvm::dyn_cast(syncOp.get())) { + for (auto eventId : setWaitOp->eventIds) { + auto &index = getSetWaitOpsIndexRef(setWaitOp->pipeSrc, + setWaitOp->pipeDst, eventId); + index.insert({globalSetWaitIndex++, setWaitOp}); + } + } + } + } + setWaitEndIndexInclusive[op] = globalSetWaitIndex++; +} + +bool Solver::checkBackwardSyncEventsContains(OperationBase *op, + CorePipeInfo corePipeSrc, + CorePipeInfo corePipeDst, + int64_t eventId) { + auto *it1 = backwardSyncEvents.find(op); + if (it1 == backwardSyncEvents.end()) { + return false; + } + auto it2 = it1->second.find({corePipeSrc, corePipeDst}); + if (it2 == it1->second.end()) { + return false; + } + return it2->second.contains(eventId); +} + +bool Solver::checkBackwardSyncEventsContainsAfterMerge( + OperationBase *op, CorePipeInfo corePipeSrc, CorePipeInfo corePipeDst) { + auto *it1 = backwardSyncEventsAfterMerge.find(op); + if (it1 == backwardSyncEventsAfterMerge.end()) { + return false; + } + return it1->second.contains({corePipeSrc, corePipeDst}); +} + +// Check whether a backward-sync event id can be merged at scope level. +bool Solver::checkMergeable(Scope *scopeOp, CorePipeInfo corePipeSrc, + CorePipeInfo corePipeDst, int64_t eventId, + bool shouldBeUsedAtleastOnce) { + auto &index = + getSetWaitOpsIndexRef(corePipeSrc.pipe, corePipeDst.pipe, eventId); + if (shouldBeUsedAtleastOnce) { + auto it = index.lower_bound({setWaitStartIndexInclusive[scopeOp], nullptr}); + bool usedAtleastOnce = + it != index.end() && it->first < setWaitEndIndexInclusive[scopeOp]; + if (!usedAtleastOnce) { + return false; + } + } + { + auto it1 = + index.lower_bound({setWaitStartIndexInclusive[scopeOp], nullptr}); + auto it2 = index.lower_bound({setWaitEndIndex[scopeOp], nullptr}); + bool usedBefore = + it1 != index.end() && it1->first < setWaitStartIndex[scopeOp]; + bool usedAfter = + it2 != index.end() && it2->first < setWaitEndIndexInclusive[scopeOp]; + if (usedBefore || usedAfter) { + return false; + } + } + if (auto *conditionOp = llvm::dyn_cast(scopeOp)) { + if (!conditionOp->hasFalseScope()) { + return false; + } + return checkMergeable(conditionOp->getTrueScope(), corePipeSrc, corePipeDst, + eventId, true) && + checkMergeable(conditionOp->getFalseScope(), corePipeSrc, + corePipeDst, eventId, true); + } + if (auto *loopOp = llvm::dyn_cast(scopeOp)) { + for (auto &childOp : loopOp->body) { + if (auto *childScopeOp = llvm::dyn_cast(childOp.get())) { + if (!checkMergeable(childScopeOp, corePipeSrc, corePipeDst, eventId, + false)) { + return false; + } + } + } + for (auto &childOp : loopOp->body) { + if (auto *childScopeOp = llvm::dyn_cast(childOp.get())) { + if (checkMergeable(childScopeOp, corePipeSrc, corePipeDst, eventId, + true)) { + return true; + } + } + } + return false; + } + for (auto &childOp : scopeOp->body) { + auto it1 = + index.lower_bound({setWaitStartIndexInclusive[childOp.get()], nullptr}); + auto it2 = index.lower_bound({setWaitEndIndex[childOp.get()], nullptr}); + bool usedAtleastOnce = it1 != index.end() && + it1->first < setWaitEndIndexInclusive[childOp.get()]; + if (!usedAtleastOnce) { + continue; + } + bool before = + it1 != index.end() && it1->first < setWaitStartIndex[childOp.get()]; + bool after = it2 != index.end() && + it2->first < setWaitEndIndexInclusive[childOp.get()]; + if (before || after) { + return false; + } + if (!checkBackwardSyncEventsContains(childOp.get(), corePipeSrc, + corePipeDst, eventId)) { + return false; + } + if (checkBackwardSyncEventsContainsAfterMerge(childOp.get(), corePipeSrc, + corePipeDst)) { + return false; + } + } + return true; +} + +// Attempt to merge backward sync events across children and prune duplicates. +void Solver::mergeBackwardSyncEventIds(OperationBase *op) { + auto *scopeOp = llvm::dyn_cast_if_present(op); + if (scopeOp == nullptr) { + return; + } + for (auto &op : scopeOp->body) { + mergeBackwardSyncEventIds(op.get()); + } + + if (llvm::isa_and_present(op)) { + return; + } + if (llvm::isa_and_present(op->parentOp)) { + return; + } + + auto *conditionOp = llvm::dyn_cast(op); + if (conditionOp != nullptr) { + if (!conditionOp->hasFalseScope()) { + return; + } + } + + llvm::DenseSet> toBeErased; + + llvm::SmallVector coreTypes; + if (options.isCrossCoreMode()) { + coreTypes = {pto::TCoreType::VECTOR, pto::TCoreType::CUBE}; + } else { + coreTypes = {pto::TCoreType::CUBE_OR_VECTOR}; + } + size_t pipeNumMax = static_cast(pto::PIPE::PIPE_NUM); + const int64_t eventIdMax = getHWAvailableEventIdNum(options.syncMode); + + for (int64_t eventId = 0; eventId < eventIdMax; ++eventId) { + for (auto coreSrc : coreTypes) { + for (auto coreDst : coreTypes) { + for (size_t pipeSrcInt = 0; pipeSrcInt < pipeNumMax; pipeSrcInt++) { + for (size_t pipeDstInt = 0; pipeDstInt < pipeNumMax; pipeDstInt++) { + auto pipeSrc = static_cast(pipeSrcInt); + auto pipeDst = static_cast(pipeDstInt); + auto corePipeSrc = CorePipeInfo(coreSrc, pipeSrc); + auto corePipeDst = CorePipeInfo(coreDst, pipeDst); + if (checkBackwardSyncEventsContains(scopeOp, corePipeSrc, + corePipeDst, eventId)) { + continue; + } + if (checkMergeable(scopeOp, corePipeSrc, corePipeDst, eventId)) { + toBeErased.insert({corePipeSrc, corePipeDst, eventId}); + backwardSyncEvents[scopeOp][{corePipeSrc, corePipeDst}].insert( + {eventId, 1}); + } + } + } + } + } + } + + if (isa(scopeOp)) { + for (auto &op : scopeOp->body) { + if (auto *block = llvm::dyn_cast(op.get())) { + for (auto &childOp : block->body) { + if (auto *childScopeOp = llvm::dyn_cast(childOp.get())) { + for (auto [corePipeSrc, corePipeDst, eventId] : toBeErased) { + if (checkBackwardSyncEventsContains(childScopeOp, corePipeSrc, + corePipeDst, eventId)) { + auto key = std::make_tuple(corePipeSrc, corePipeDst); + backwardSyncEvents[childScopeOp][key].erase(eventId); + if (backwardSyncEvents[childScopeOp][key].empty()) { + backwardSyncEvents[childScopeOp].erase(key); + } + } + } + } + } + } + } + } else { + for (auto &childOp : scopeOp->body) { + if (auto *childScopeOp = llvm::dyn_cast(childOp.get())) { + for (auto [corePipeSrc, corePipeDst, eventId] : toBeErased) { + if (checkBackwardSyncEventsContains(childScopeOp, corePipeSrc, + corePipeDst, eventId)) { + auto key = std::make_tuple(corePipeSrc, corePipeDst); + backwardSyncEvents[childScopeOp][key].erase(eventId); + if (backwardSyncEvents[childScopeOp][key].empty()) { + backwardSyncEvents[childScopeOp].erase(key); + } + } + } + } + } + } +} + +void Solver::mergeBackwardSyncPairs(SyncMap &syncMapBefore, + SyncMap &syncMapAfter) { + if (!options.moveOutAndMergeBackwardSyncPairs) { + return; + } + if (options.isIntraCoreMode()) { + resetAndBuildSetWaitOpIndex(syncMapBefore, syncMapAfter); + auto *scopeOp = llvm::dyn_cast(funcIr.get()); + assert(scopeOp != nullptr && scopeOp->body.front() != nullptr); + mergeBackwardSyncEventIds(scopeOp->body.front().get()); + } +} + +SyncBeforeAfterMap Solver::getBeforeAfterSyncMaps() { + calcAllEventIds(); + SyncMap syncMapBefore, syncMapAfter; + std::vector conflictPairs; + for (auto &conflictPair : chosenConflictedPairs) { + conflictPairs.push_back(conflictPair.get()); + } + for (auto &conflictPair : persistentChosenConflictedPairs) { + conflictPairs.push_back(conflictPair.get()); + } + + for (auto *conflictPair : conflictPairs) { + if (conflictPair->isUseless) { + continue; + } + if (conflictPair->replacedWithUnitFlag) { + continue; + } + assert(conflictPair->setOp != nullptr && conflictPair->waitOp != nullptr); + if (conflictPair->isBarrier()) { + auto barrierOp = std::make_unique( + conflictPair->waitOp->op, conflictPair->waitOp->parentOp, + conflictPair->waitCorePipeInfo.pipe); + LLVM_DEBUG(barrierOp->debugId = conflictPair->id); + syncMapBefore[conflictPair->waitOp].push_back(std::move(barrierOp)); + } else { + assert(conflictPair->eventIdNode != nullptr); + auto setOp = std::make_unique( + conflictPair->setOp->op, conflictPair->setOp->parentOp, + conflictPair->eventIdNode->getEventIds(), + conflictPair->setCorePipeInfo.pipe, + conflictPair->waitCorePipeInfo.pipe); + auto waitOp = std::make_unique( + conflictPair->waitOp->op, conflictPair->waitOp->parentOp, + conflictPair->eventIdNode->getEventIds(), + conflictPair->setCorePipeInfo.pipe, + conflictPair->waitCorePipeInfo.pipe); + if (options.isCrossCoreMode()) { + setOp->coreType = conflictPair->setCorePipeInfo.coreType; + waitOp->coreType = conflictPair->waitCorePipeInfo.coreType; + } + setOp->eventIdInfo = conflictPair->eventIdInfo; + waitOp->eventIdInfo = conflictPair->eventIdInfo; + setOp->checkLastIter = conflictPair->setOnLastIterOnly; + waitOp->checkFirstIter = conflictPair->waitOnFirstIterOnly; + LLVM_DEBUG({ + setOp->debugId = conflictPair->id; + waitOp->debugId = conflictPair->id; + }); + assert(setOp != nullptr && waitOp != nullptr); + syncMapAfter[conflictPair->setOp].push_back(std::move(setOp)); + syncMapBefore[conflictPair->waitOp].push_front(std::move(waitOp)); + } + } + + collectBackwardSyncEventIds(); + mergeBackwardSyncPairs(syncMapBefore, syncMapAfter); + + for (auto &[op, mp] : backwardSyncEvents) { + if (mp.empty()) { + continue; + } + auto *scopeOp = llvm::dyn_cast(op); + assert(scopeOp != nullptr); + for (auto [setWaitCorePipes, eventIdsMp] : mp) { + if (eventIdsMp.empty()) { + continue; + } + llvm::SmallVector eventIds; + for (auto [eventId, repeatNum] : eventIdsMp) { + llvm::SmallVector curEventIds(repeatNum, eventId); + llvm::append_range(eventIds, curEventIds); + } + llvm::sort(eventIds); + auto [corePipeSrc, corePipeDst] = setWaitCorePipes; + auto setOp = + std::make_unique(scopeOp->op, scopeOp->parentOp, eventIds, + corePipeSrc.pipe, corePipeDst.pipe); + auto waitOp = + std::make_unique(scopeOp->op, scopeOp->parentOp, eventIds, + corePipeSrc.pipe, corePipeDst.pipe); + setOp->allAtOnce = true; + waitOp->allAtOnce = true; + if (options.isCrossCoreMode()) { + setOp->coreType = corePipeSrc.coreType; + waitOp->coreType = corePipeDst.coreType; + } + assert(setOp != nullptr && waitOp != nullptr); + syncMapBefore[scopeOp].push_back(std::move(setOp)); + syncMapAfter[scopeOp].push_front(std::move(waitOp)); + } + } + return std::make_pair(std::move(syncMapBefore), std::move(syncMapAfter)); +} + +void Solver::processConflict(Occurrence *occ1, Occurrence *occ2, + RWOperation *rwOp1, RWOperation *rwOp2, + bool isUseless) { + for (auto [corePipeSrc, corePipeDst] : checkMemoryConflicts(rwOp1, rwOp2)) { + if (options.alwaysUsePipeSAsWaitingPipe) { + corePipeDst.pipe = pto::PIPE::PIPE_S; + } + auto eventIdInfo = + getEventIdInfo(occ1, occ2, rwOp1, rwOp2, corePipeSrc, corePipeDst); + handleConflict(occ1, occ2, rwOp1, rwOp2, corePipeSrc, corePipeDst, + eventIdInfo, isUseless); + } +} + +// Main processing loop that iterates processingOrders and attempts to +// discover and record conflicts. +void Solver::processOrders() { + for (auto &[occ1, occ2, rwOp1, rwOp2, isUseless] : processingOrders) { + assert(occ1 != occ2); + assert(occ1->syncIrIndex < occ2->syncIrIndex); + if (checkVisited(occ1, occ2)) { + assert(false && "expected to not check a pair more than once."); + continue; + } + if (checkImpossibleOccPair(occ1, occ2) || checkAlreadySynced(occ1, occ2) || + skipMMad1DecomposedLoopOpt(occ1, occ2) || + checkSkipParallelLoop(occ1, occ2) || + checkSkipCrossCorePair(occ1, occ2)) { + continue; + } + DEBUG_WITH_TYPE("gss-sync-solver-checking", { + llvm::dbgs() << "checking: " << (isUseless ? "is-useless\n" : "\n"); + llvm::dbgs() << occ1->syncIrIndex << ' ' << occ1->startIndex << ' ' + << occ1->endIndex << ' ' << occ1->op->str(0, false) << '\n'; + llvm::dbgs() << occ2->syncIrIndex << ' ' << occ2->startIndex << ' ' + << occ2->endIndex << ' ' << occ2->op->str(0, false) << '\n'; + }); + if (checkAlreadySyncedWithUnitFlag(occ1, occ2)) { + continue; + } + processConflict(occ1, occ2, rwOp1, rwOp2, isUseless); + } +} + +void Solver::insertMergedBackwardSyncPairs() { + for (auto &[scopeOp, st] : backwardSyncEventsAfterMerge) { + for (auto &corePipeInfoPair : st) { + auto [corePipeSrc, corePipeDst] = corePipeInfoPair; + for (auto *scopeOcc : opAllOccurrences[scopeOp]) { + auto *parentScopeOcc = scopeOcc->parentOcc; + assert(parentScopeOcc != nullptr); + Occurrence *setOcc = nullptr; + Occurrence *waitOcc = nullptr; + auto startIndex = scopeOcc->startIndex; + auto endIndex = scopeOcc->endIndex; + if (isa(scopeOp)) { + setOcc = getBeforePlaceHolderOcc(scopeOcc); + waitOcc = getAfterPlaceHolderOcc(scopeOcc); + startIndex = setOcc->endIndex; + endIndex = waitOcc->startIndex; + } + auto conflictPair = std::make_unique( + nullptr, nullptr, nullptr, nullptr, setOcc, waitOcc, corePipeSrc, + corePipeDst, startIndex, endIndex); + assert(conflictPair->startIndex <= conflictPair->endIndex); + conflictPair->isUseless = true; + conflictPair->dontReuse = true; + conflictPair->dontCheckForConflict = true; + conflictPair->couldNotRun = false; // notice this + LLVM_DEBUG({ + llvm::dbgs() << "consider-merged-backward-pair: " + << scopeOp->str(0, false) << ' ' << conflictPair->str() + << "\n"; + }); + scopeOccChosenConflicts[parentScopeOcc].insert(conflictPair.get()); + chosenConflictedPairs.push_back(std::move(conflictPair)); + } + } + } +} + +llvm::LogicalResult Solver::considerOuterBackwardSyncPairs() { + if (!options.considerOuterBackwardSyncPairs) { + return llvm::failure(); + } + bool backwardPairsPositionChanged = false; + for (auto &[scopeOp, st] : backwardSyncEventsAfterMerge) { + SmallVector> toBeErased; + for (auto &corePipeInfoPair : st) { + if (!backwardSyncEvents.contains(scopeOp) || + !backwardSyncEvents[scopeOp].contains(corePipeInfoPair)) { + toBeErased.push_back(corePipeInfoPair); + } + } + if (!toBeErased.empty()) { + backwardPairsPositionChanged = true; + for (auto &corePipeInfoPair : toBeErased) { + st.erase(corePipeInfoPair); + } + } + } + int chosenOpsDepth = -1; + SmallVector chosenOps; + for (auto &[scopeOp, mp] : backwardSyncEvents) { + if (backwardSyncEventsAfterMerge.contains(scopeOp)) { + continue; + } + int scopeOpDepth = scopeOp->getDepth(); + if (chosenOpsDepth == scopeOpDepth) { + chosenOps.push_back(scopeOp); + } else if (chosenOpsDepth == -1 || chosenOpsDepth < scopeOpDepth) { + chosenOps.clear(); + chosenOps.push_back(scopeOp); + chosenOpsDepth = scopeOpDepth; + } + } + if (chosenOps.empty()) { + return llvm::failure(); + } + bool newPairIsInserted = false; + for (auto *chosenOp : chosenOps) { + for (auto &[corePipeInfoPair, eventIdsMp] : backwardSyncEvents[chosenOp]) { + assert(!eventIdsMp.empty()); + if (!eventIdsMp.empty()) { + auto [it, isInserted] = + backwardSyncEventsAfterMerge[chosenOp].insert(corePipeInfoPair); + newPairIsInserted |= isInserted; + } + } + } + return llvm::success(backwardPairsPositionChanged || newPairIsInserted); +} + +llvm::LogicalResult Solver::reuseSyncPairToSaveEventIds() { + if (!options.reuseSyncPairToSaveEventIds || barrierAllPairs.empty()) { + return llvm::failure(); + } + bool limitReached = true; + for (auto [corePipeSrc, corePipeDst] : barrierAllPairs) { + if (reusePairs[{corePipeSrc, corePipeDst}] < maxReuseNum) { + if (reusePairs[{corePipeSrc, corePipeDst}] <= + reusedPairs[{corePipeSrc, corePipeDst}]) { + reusePairs[{corePipeSrc, corePipeDst}] += 1; + limitReached = false; + } + } + } + DEBUG_WITH_TYPE("gss-sync-solver-reuse", { + llvm::dbgs() << "reusePairs: \n"; + for (auto [pipeCorePairs, cnt] : reusePairs) { + llvm::dbgs() << get<0>(pipeCorePairs).pipe << ' ' + << get<1>(pipeCorePairs).pipe << ' ' << cnt << '\n'; + } + }); + return llvm::success(!limitReached); +} + +llvm::LogicalResult Solver::disableMultiEventIdForBarrierAllPairs() { + if (!options.disableMultiEventIdForBarrierAllPairs || + barrierAllPairs.empty()) { + return llvm::failure(); + } + bool newPairIsInserted = false; + for (auto corePipeInfoPair : barrierAllPairs) { + auto [it, isInserted] = disabledMultiEventIdPairs.insert(corePipeInfoPair); + newPairIsInserted |= isInserted; + } + LLVM_DEBUG({ + if (newPairIsInserted) { + llvm::dbgs() << "disabled-multi-event-id-pairs: \n"; + for (auto &[corePipeSrc, corePipeDst] : disabledMultiEventIdPairs) { + llvm::dbgs() << corePipeSrc.coreType << ' ' << corePipeSrc.pipe << ' ' + << corePipeDst.coreType << ' ' << corePipeDst.pipe << '\n'; + } + } + }); + return llvm::success(newPairIsInserted); +} + +llvm::LogicalResult Solver::tryMovingOutBackwardSyncPairsToOuterLoops() { + if (!options.moveOutAndMergeBackwardSyncPairs || !options.isCrossCoreMode() || + dontMoveBackwardSyncPairsToOutmostLoop) { + return llvm::failure(); + } + if (!moveBackwardSyncPairsToOutmostLoop) { + moveBackwardSyncPairsToOutmostLoop = true; + return llvm::success(); + } + if (!barrierAllPairs.empty()) { + moveBackwardSyncPairsToOutmostLoop = false; + dontMoveBackwardSyncPairsToOutmostLoop = true; + return llvm::success(); + } + return llvm::failure(); +} + +// High-level solve orchestration with multiple passes and optional merging +// iterations. +llvm::LogicalResult Solver::runSolver(bool enableOpts1, bool enableOpts2) { + reset(/*resetEventIdRanOutOpts=*/true); + + int64_t runNum = 0; + while (runNum++ < maxRunNum) { + LLVM_DEBUG(llvm::dbgs() << "runNum: " << runNum << '\n'); + + reset(); + insertMergedBackwardSyncPairs(); + processOrders(); + + if (llvm::succeeded(tryMovingOutBackwardSyncPairsToOuterLoops())) { + continue; + } + + if (enableOpts1) { + if (options.considerOuterBackwardSyncPairs) { + getBeforeAfterSyncMaps(); + if (llvm::succeeded(considerOuterBackwardSyncPairs())) { + continue; + } + if (!barrierAllPairs.empty()) { + backwardSyncEventsAfterMerge.clear(); + } + } + } + + if (enableOpts2) { + if (!barrierAllPairs.empty()) { + if (llvm::succeeded(reuseSyncPairToSaveEventIds())) { + continue; + } + if (llvm::succeeded(disableMultiEventIdForBarrierAllPairs())) { + continue; + } + } + } + + if (!barrierAllPairs.empty()) { + pickAndInsertABarrierAll(); + reset(/*resetEventIdRanOutOpts=*/true); + continue; + } + break; + } + + reset(); + insertMergedBackwardSyncPairs(); + processOrders(); + + return llvm::success(runNum < maxRunNum); +} + +void Solver::solve() { + if (llvm::succeeded(runSolver())) { + return; + } + if (!options.isTestMode()) { + if (llvm::succeeded(runSolver(/*enableOpts1=*/false))) { + return; + } + if (llvm::succeeded( + runSolver(/*enableOpts1=*/false, /*enableOpts2=*/false))) { + return; + } + } + llvm_unreachable("GSS: runSolver() failed."); +} diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index b3f0c6bbd..c8e15b51e 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -19,6 +19,7 @@ #include "PTO/IR/PTOTypeUtils.h" #include "PTO/IR/PTOSyncUtils.h" #include "PTO/Transforms/Passes.h" +#include "PTOToEmitCInternal.h" #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" @@ -69,24 +70,14 @@ namespace mlir { using namespace mlir; using namespace mlir::pto; -static std::string getElemTypeStringForGT(Type elemTy); static bool getStaticMemrefLayout(MemRefType mrTy, SmallVectorImpl &strides, int64_t &offset); static int64_t multiplyOrDynamic(int64_t lhs, int64_t rhs); -static void buildGlobalTensorShapeAndStride(ArrayRef shape, - ArrayRef strides, - SmallVectorImpl &shape5D, - SmallVectorImpl &stride5D); -static std::string joinIntTemplateParams(ArrayRef values); -static SmallVector buildRowMajorStrides(ArrayRef shape); static std::string getGlobalTensorTypeStringFromShape(Type elemTy, ArrayRef shape, StringRef layoutEnum = "pto::Layout::ND"); -static std::string getGlobalTensorTypeStringFromShapeAndStrides( - Type elemTy, ArrayRef shape, ArrayRef strides, - StringRef layoutEnum = "pto::Layout::ND"); static emitc::OpaqueType getGlobalTensorOpaqueTypeFromShape( MLIRContext *ctx, Type elemTy, ArrayRef shape, StringRef layoutEnum = "pto::Layout::ND"); @@ -121,20 +112,17 @@ static const char *addrSpaceQualifier(pto::AddressSpace as) { "__pto.lowered_set_validshape"; [[maybe_unused]] static constexpr llvm::StringLiteral kLoweredSetValidShapeConfigAttrName = "__pto.lowered_set_validshape_config"; -static constexpr llvm::StringLiteral kForceDynamicValidShapeAttrName = +[[maybe_unused]] static constexpr llvm::StringLiteral kForceDynamicValidShapeAttrName = "__pto.force_dynamic_valid_shape"; -static constexpr llvm::StringLiteral kGlobalTensorStridesAttrName = +[[maybe_unused]] static constexpr llvm::StringLiteral kGlobalTensorStridesAttrName = "__pto.globaltensor_strides"; -static Value peelUnrealized(Value v) { +Value mlir::pto::peelUnrealized(Value v) { if (auto castOp = v.getDefiningOp()) return castOp.getOperand(0); return v; } -static Value buildGlobalTensorFromMemref(ConversionPatternRewriter &rewriter, - Location loc, Value basePtr, - MemRefType mrTy, Operation *anchor); static Value maybeWrapGlobalMemrefAsGlobalTensor( ConversionPatternRewriter &rewriter, Location loc, Value loweredValue, @@ -258,12 +246,12 @@ static std::string layoutToEmitCString(mlir::pto::Layout layout) { return "pto::Layout::ND"; } -static bool isEmitCGlobalTensorLikeType(Type ty) { +bool mlir::pto::isEmitCGlobalTensorLikeType(Type ty) { auto opaqueTy = dyn_cast(ty); return opaqueTy && opaqueTy.getValue().contains("GlobalTensor<"); } -static std::string getEmitCScalarTypeToken(Type elemTy) { +std::string mlir::pto::getEmitCScalarTypeToken(Type elemTy) { if (pto::isPTOFloat8Type(elemTy) && (elemTy.isFloat8E4M3() || elemTy.isFloat8E4M3FN() || elemTy.isFloat8E4M3FNUZ() || elemTy.isFloat8E4M3B11FNUZ())) @@ -320,7 +308,7 @@ static bool isEmitCPointerLikeType(Type ty) { return false; } -static int64_t getEmitCScalarByteWidth(Type elemTy) { +[[maybe_unused]] static int64_t getEmitCScalarByteWidth(Type elemTy) { if (pto::getPTOStorageElemByteSize(elemTy) == 1) return 1; if (elemTy.isF16() || elemTy.isBF16() || elemTy.isInteger(16)) @@ -335,8 +323,8 @@ static int64_t getEmitCScalarByteWidth(Type elemTy) { static std::string tileBufBLayoutToken(pto::TileBufConfigAttr configAttr); static std::string tileBufSLayoutToken(pto::TileBufConfigAttr configAttr); static std::string tileBufPadToken(pto::TileBufConfigAttr configAttr); -static pto::BLayout getTileBufBLayoutValue(pto::TileBufConfigAttr configAttr); -static int64_t renderTileTemplateDim(int64_t rawDim, Type elemTy, +pto::BLayout mlir::pto::getTileBufBLayoutValue(pto::TileBufConfigAttr configAttr); +int64_t mlir::pto::renderTileTemplateDim(int64_t rawDim, Type elemTy, pto::BLayout blayout, int dimIdx); static const char *tileRoleToken(Attribute memorySpace) { @@ -382,7 +370,7 @@ static std::string tileBufCompactToken(pto::TileBufConfigAttr configAttr) { return compactTok; } -static std::optional getEmitCTileTypeString(pto::TileBufType type) { +std::optional mlir::pto::getEmitCTileTypeString(pto::TileBufType type) { if (type.getRank() != 2) return std::nullopt; auto validShape = type.getValidShape(); @@ -642,11 +630,10 @@ class PTOToEmitCTypeConverter : public TypeConverter { } }; -static constexpr unsigned kPTOIndexBitWidth = +[[maybe_unused]] static constexpr unsigned kPTOIndexBitWidth = 32; // keep consistent with IndexType conversion // Forward declarations (definitions below). -static inline std::string pipeTokFromPipeAttr(mlir::pto::PipeAttr a); static emitc::OpaqueType getSignedIntOpaqueType(MLIRContext *ctx, unsigned bitWidth); static emitc::OpaqueType getUnsignedIntOpaqueType(MLIRContext *ctx, @@ -655,107 +642,10 @@ static emitc::OpaqueType getWiderSignedIntOpaqueType(MLIRContext *ctx, unsigned bitWidth); static emitc::OpaqueType getWiderUnsignedIntOpaqueType(MLIRContext *ctx, unsigned bitWidth); -static Value makeEmitCOpaqueConstant(ConversionPatternRewriter &rewriter, - Location loc, Type type, - llvm::StringRef literal); -static Value makeEmitCIntConstant(ConversionPatternRewriter &rewriter, - Location loc, Type type, int64_t value); -static Value emitCCast(ConversionPatternRewriter &rewriter, Location loc, - Type dstType, Value src); static FailureOr buildEmitCOpaqueConstantLiteral(Type targetType, Attribute valueAttr); -static Value castSignlessIntToUnsignedSameWidth(ConversionPatternRewriter &rewriter, - Location loc, Value v, - unsigned bitWidth); -static bool needsA5NoSplitVectorGuard(Operation *op); - -static FailureOr getTileSplitToken(int64_t split) { - switch (split) { - case 0: - return std::string("TileSplitAxis::TILE_NO_SPLIT"); - case 1: - return std::string("TileSplitAxis::TILE_UP_DOWN"); - case 2: - return std::string("TileSplitAxis::TILE_LEFT_RIGHT"); - default: - return failure(); - } -} - -static FailureOr -getTPipeDirectionToken(bool isL2G2L, int8_t dirMask, PTOArch targetArch) { - if (dirMask == 1) { - if (isL2G2L && targetArch == PTOArch::A5) - return std::string("Direction::DIR_C2V_GM"); - return std::string("Direction::DIR_C2V"); - } - if (dirMask == 2) { - if (isL2G2L && targetArch == PTOArch::A5) - return std::string("Direction::DIR_V2C_GM"); - return std::string("Direction::DIR_V2C"); - } - if (dirMask == 3) - return std::string("Direction::DIR_BOTH"); - return failure(); -} - -static std::string buildTPipeToken(int32_t flagBase, llvm::StringRef dirTok, - int32_t slotSize, int32_t slotNum, - int32_t localSlotNum, bool nosplit) { - std::string token = "TPipe<" + std::to_string(flagBase) + ", " + dirTok.str() + - ", " + std::to_string(slotSize) + ", " + - std::to_string(slotNum); - token += ", " + std::to_string(localSlotNum); - token += nosplit ? ", true" : ", false"; - token += ">"; - return token; -} - -static FailureOr buildTPipeTokenFromInitOp(Operation *op, - PTOArch targetArch) { - if (auto initOp = dyn_cast(op)) { - if (!initOp.getFlagBaseAttr()) - return failure(); - auto dirTok = - getTPipeDirectionToken(/*isL2G2L=*/true, initOp.getDirMask(), targetArch); - if (failed(dirTok)) - return failure(); - int32_t localSlotNum = initOp.getLocalSlotNumAttr() - ? initOp.getLocalSlotNumAttr().getInt() - : initOp.getSlotNum(); - return buildTPipeToken(initOp.getFlagBaseAttr().getInt(), *dirTok, - initOp.getSlotSize(), initOp.getSlotNum(), - localSlotNum, - initOp.getNosplitAttr() && - initOp.getNosplitAttr().getValue()); - } - - if (auto initOp = dyn_cast(op)) { - if (!initOp.getFlagBaseAttr()) - return failure(); - auto dirTok = - getTPipeDirectionToken(/*isL2G2L=*/false, initOp.getDirMask(), targetArch); - if (failed(dirTok)) - return failure(); - return buildTPipeToken(initOp.getFlagBaseAttr().getInt(), *dirTok, - initOp.getSlotSize(), initOp.getSlotNum(), 2, - initOp.getNosplitAttr() && - initOp.getNosplitAttr().getValue()); - } - - return failure(); -} -static FailureOr getTPipeTokenFromValue(Value pipeHandle, - PTOArch targetArch) { - pipeHandle = peelUnrealized(pipeHandle); - Operation *def = pipeHandle.getDefiningOp(); - if (!def) - return failure(); - return buildTPipeTokenFromInitOp(def, targetArch); -} - -static bool isSetFFTsPointerLikeType(Type ty) { +bool mlir::pto::isSetFFTsPointerLikeType(Type ty) { return isEmitCPointerLikeType(ty); } @@ -770,7 +660,7 @@ static Type getTileDataResultType(MLIRContext *ctx, pto::AddressSpace as, return getEmitCPointerType(ctx, addrSpaceQualifier(as), elemTok); } -static Value materializeTileDataValue(ConversionPatternRewriter &rewriter, +Value mlir::pto::materializeTileDataValue(ConversionPatternRewriter &rewriter, Location loc, Value tile, pto::AddressSpace as, StringRef elemTok) { @@ -782,7 +672,7 @@ static Value materializeTileDataValue(ConversionPatternRewriter &rewriter, .getResult(0); } -static Value materializeAddressAsPointer(ConversionPatternRewriter &rewriter, +Value mlir::pto::materializeAddressAsPointer(ConversionPatternRewriter &rewriter, Location loc, Value addr, pto::AddressSpace as, StringRef elemTok) { @@ -804,146 +694,6 @@ static Value materializeAddressAsPointer(ConversionPatternRewriter &rewriter, .getResult(0); } -struct InterCoreSyncCallDesc { - const char *callee = nullptr; - ArrayAttr args; - SmallVector operands; -}; - -static Value castInterCoreEventIdToI32(ConversionPatternRewriter &rewriter, - Location loc, Value eventId) { - auto i32Ty = emitc::OpaqueType::get(rewriter.getContext(), "int32_t"); - if (eventId.getType() == i32Ty) - return eventId; - return emitCCast(rewriter, loc, i32Ty, eventId); -} - -static Attribute getFFTSModeCodegenArg(ConversionPatternRewriter &rewriter, - int64_t fftsMode) { - auto *ctx = rewriter.getContext(); - if (fftsMode == 2) - return emitc::OpaqueAttr::get(ctx, "FFTS_MODE_VAL"); - return emitc::OpaqueAttr::get(ctx, std::to_string(fftsMode)); -} - -static Value createFFTSMsg(ConversionPatternRewriter &rewriter, Location loc, - Value eventI32, int64_t fftsMode) { - auto *ctx = rewriter.getContext(); - auto msgTy = emitc::OpaqueType::get(ctx, "uint16_t"); - auto msgArgs = rewriter.getArrayAttr({ - getFFTSModeCodegenArg(rewriter, fftsMode), - IntegerAttr::get(IndexType::get(ctx), 0), - }); - return rewriter - .create(loc, msgTy, "getFFTSMsg", - /*args=*/msgArgs, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{eventI32}) - .getResult(0); -} - -static InterCoreSyncCallDesc buildInterCoreSyncSetCall( - ConversionPatternRewriter &rewriter, Location loc, PTOArch targetArch, - pto::PipeAttr pipeAttr, IntegerAttr eventIdAttr, int64_t fftsMode) { - auto *ctx = rewriter.getContext(); - std::string pipeTok = pipeTokFromPipeAttr(pipeAttr); - - if (targetArch == PTOArch::A3) { - auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); - Value eventVal = - makeEmitCIntConstant(rewriter, loc, i32Ty, eventIdAttr.getInt()); - Value msgVal = createFFTSMsg(rewriter, loc, eventVal, fftsMode); - - InterCoreSyncCallDesc desc; - desc.callee = "ffts_cross_core_sync"; - desc.args = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, pipeTok), - IntegerAttr::get(IndexType::get(ctx), 0), - }); - desc.operands.push_back(msgVal); - return desc; - } - - InterCoreSyncCallDesc desc; - desc.callee = "set_intra_block"; - desc.args = rewriter.getArrayAttr( - {emitc::OpaqueAttr::get(ctx, pipeTok), eventIdAttr}); - return desc; -} - -static InterCoreSyncCallDesc buildInterCoreSyncSetCallDyn( - ConversionPatternRewriter &rewriter, Location loc, PTOArch targetArch, - pto::PipeAttr pipeAttr, Value eventIdVal, int64_t fftsMode) { - auto *ctx = rewriter.getContext(); - std::string pipeTok = pipeTokFromPipeAttr(pipeAttr); - Value eventI32 = castInterCoreEventIdToI32(rewriter, loc, eventIdVal); - - if (targetArch == PTOArch::A3) { - Value msgVal = createFFTSMsg(rewriter, loc, eventI32, fftsMode); - - InterCoreSyncCallDesc desc; - desc.callee = "ffts_cross_core_sync"; - desc.args = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, pipeTok), - IntegerAttr::get(IndexType::get(ctx), 0), - }); - desc.operands.push_back(msgVal); - return desc; - } - - InterCoreSyncCallDesc desc; - desc.callee = "set_intra_block"; - desc.args = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, pipeTok), - IntegerAttr::get(IndexType::get(ctx), 0), - }); - desc.operands.push_back(eventI32); - return desc; -} - -static InterCoreSyncCallDesc buildInterCoreSyncWaitCall( - ConversionPatternRewriter &rewriter, PTOArch targetArch, - pto::PipeAttr pipeAttr, IntegerAttr eventIdAttr) { - auto *ctx = rewriter.getContext(); - std::string pipeTok = pipeTokFromPipeAttr(pipeAttr); - - InterCoreSyncCallDesc desc; - if (targetArch == PTOArch::A3) { - desc.callee = "wait_flag_dev"; - desc.args = rewriter.getArrayAttr({eventIdAttr}); - return desc; - } - - desc.callee = "wait_intra_block"; - desc.args = rewriter.getArrayAttr( - {emitc::OpaqueAttr::get(ctx, pipeTok), eventIdAttr}); - return desc; -} - -static InterCoreSyncCallDesc buildInterCoreSyncWaitCallDyn( - ConversionPatternRewriter &rewriter, Location loc, PTOArch targetArch, - pto::PipeAttr pipeAttr, Value eventIdVal) { - auto *ctx = rewriter.getContext(); - std::string pipeTok = pipeTokFromPipeAttr(pipeAttr); - Value eventI32 = castInterCoreEventIdToI32(rewriter, loc, eventIdVal); - - InterCoreSyncCallDesc desc; - if (targetArch == PTOArch::A3) { - desc.callee = "wait_flag_dev"; - desc.args = rewriter.getArrayAttr({IntegerAttr::get(IndexType::get(ctx), 0)}); - desc.operands.push_back(eventI32); - return desc; - } - - desc.callee = "wait_intra_block"; - desc.args = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, pipeTok), - IntegerAttr::get(IndexType::get(ctx), 0), - }); - desc.operands.push_back(eventI32); - return desc; -} - static bool hasInterCoreSyncOp(func::FuncOp func) { bool found = false; func.walk([&](Operation *op) { @@ -968,11485 +718,1328 @@ static bool hasSetFFTsOp(func::FuncOp func) { return found; } +// Arith/Affine conversion patterns live in PTOToEmitCArith.cpp. + //===----------------------------------------------------------------------===// -// Arith -> EmitC (full dialect coverage for scalar ops) +// Arith -> EmitC helpers //===----------------------------------------------------------------------===// -template -struct ArithSimpleBinaryToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type dstTy = this->getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getOperands()); - return success(); +static emitc::OpaqueType getSignedIntOpaqueType(MLIRContext *ctx, + unsigned bitWidth) { + switch (bitWidth) { + case 1: + return emitc::OpaqueType::get(ctx, "int8_t"); + case 8: + return emitc::OpaqueType::get(ctx, "int8_t"); + case 16: + return emitc::OpaqueType::get(ctx, "int16_t"); + case 32: + return emitc::OpaqueType::get(ctx, "int32_t"); + case 64: + return emitc::OpaqueType::get(ctx, "int64_t"); + case 128: + return emitc::OpaqueType::get(ctx, "__int128"); + default: + llvm::errs() << "[Debug] Unsupported signed integer bitwidth: " << bitWidth + << "\n"; + return emitc::OpaqueType::get(ctx, "int64_t"); } -}; - -// Integer bitwise ops (andi/ori/xori) on signless integers: perform in unsigned -// to avoid signedness pitfalls, then cast back. -template -struct ArithUnsignedBitwiseBinaryToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - const unsigned bitWidth = - intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - Type dstTy = this->getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - if (bitWidth == 1) { - rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getLhs(), - adaptor.getRhs()); - return success(); - } +} - auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value resU = rewriter.create(loc, uTy, lhsU, rhsU); - Value result = emitCCast(rewriter, loc, dstTy, resU); - rewriter.replaceOp(op, result); - return success(); +static emitc::OpaqueType getUnsignedIntOpaqueType(MLIRContext *ctx, + unsigned bitWidth) { + switch (bitWidth) { + case 1: + return emitc::OpaqueType::get(ctx, "uint8_t"); + case 8: + return emitc::OpaqueType::get(ctx, "uint8_t"); + case 16: + return emitc::OpaqueType::get(ctx, "uint16_t"); + case 32: + return emitc::OpaqueType::get(ctx, "uint32_t"); + case 64: + return emitc::OpaqueType::get(ctx, "uint64_t"); + case 128: + return emitc::OpaqueType::get(ctx, "unsigned __int128"); + default: + llvm::errs() << "[Debug] Unsupported unsigned integer bitwidth: " + << bitWidth << "\n"; + return emitc::OpaqueType::get(ctx, "uint64_t"); } -}; - -struct ArithDivUIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::DivUIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - const unsigned bitWidth = - intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); +} - auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value divU = rewriter.create(loc, uTy, lhsU, rhsU); - Value result = emitCCast(rewriter, loc, dstTy, divU); - rewriter.replaceOp(op, result); - return success(); +[[maybe_unused]] static emitc::OpaqueType getWiderSignedIntOpaqueType(MLIRContext *ctx, + unsigned bitWidth) { + switch (bitWidth) { + case 1: + case 8: + return getSignedIntOpaqueType(ctx, 16); + case 16: + return getSignedIntOpaqueType(ctx, 32); + case 32: + return getSignedIntOpaqueType(ctx, 64); + case 64: + return getSignedIntOpaqueType(ctx, 128); + default: + return getSignedIntOpaqueType(ctx, 128); } -}; - -struct ArithRemUIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::RemUIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - const unsigned bitWidth = - intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); +} - auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value remU = rewriter.create(loc, uTy, lhsU, rhsU); - Value result = emitCCast(rewriter, loc, dstTy, remU); - rewriter.replaceOp(op, result); - return success(); +[[maybe_unused]] static emitc::OpaqueType getWiderUnsignedIntOpaqueType(MLIRContext *ctx, + unsigned bitWidth) { + switch (bitWidth) { + case 1: + case 8: + return getUnsignedIntOpaqueType(ctx, 16); + case 16: + return getUnsignedIntOpaqueType(ctx, 32); + case 32: + return getUnsignedIntOpaqueType(ctx, 64); + case 64: + return getUnsignedIntOpaqueType(ctx, 128); + default: + return getUnsignedIntOpaqueType(ctx, 128); } -}; - -struct ArithCeilDivUIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::CeilDivUIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); +} - const unsigned bitWidth = - intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); +Value mlir::pto::makeEmitCOpaqueConstant(ConversionPatternRewriter &rewriter, + Location loc, Type type, + llvm::StringRef literal) { + auto attr = emitc::OpaqueAttr::get(rewriter.getContext(), literal); + return rewriter.create(loc, type, attr); +} - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); +Value mlir::pto::makeEmitCIntConstant(ConversionPatternRewriter &rewriter, + Location loc, Type type, int64_t value) { + return makeEmitCOpaqueConstant(rewriter, loc, type, std::to_string(value)); +} - auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value one = makeEmitCIntConstant(rewriter, loc, uTy, 1); - Value rhsMinusOne = rewriter.create(loc, uTy, rhsU, one); - Value num = rewriter.create(loc, uTy, lhsU, rhsMinusOne); - Value divU = rewriter.create(loc, uTy, num, rhsU); - Value result = emitCCast(rewriter, loc, dstTy, divU); - rewriter.replaceOp(op, result); - return success(); - } -}; +[[maybe_unused]] static FailureOr buildEmitCOpaqueConstantLiteral(Type targetType, + Attribute valueAttr) { + auto opaqueTy = dyn_cast(targetType); + if (!opaqueTy) + return failure(); -struct ArithCeilDivSIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::CeilDivSIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) + if (opaqueTy.getValue() == "pto::MrgSortExecutedNumList") { + auto dense = dyn_cast_or_null(valueAttr); + if (!dense) return failure(); - Value zero = makeEmitCIntConstant(rewriter, loc, dstTy, 0); - Value one = makeEmitCIntConstant(rewriter, loc, dstTy, 1); - - Value q0 = rewriter.create(loc, dstTy, adaptor.getLhs(), - adaptor.getRhs()); - Value r = rewriter.create(loc, dstTy, adaptor.getLhs(), - adaptor.getRhs()); - - Value rNeZero = rewriter.create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::ne, r, - zero); - Value lhsLt0 = - rewriter.create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, adaptor.getLhs(), - zero); - Value rhsLt0 = - rewriter.create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, adaptor.getRhs(), - zero); - Value signsSame = - rewriter.create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::eq, lhsLt0, rhsLt0); - Value adjust = - rewriter.create(loc, rewriter.getI1Type(), - rNeZero, signsSame); - - Value qPlusOne = rewriter.create(loc, dstTy, q0, one); - Value result = rewriter.create(loc, dstTy, adjust, - qPlusOne, q0); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithFloorDivSIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::FloorDivSIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) + auto vecTy = dyn_cast(dense.getType()); + if (!vecTy || vecTy.getRank() != 1 || vecTy.getNumElements() != 4 || + !vecTy.getElementType().isInteger(16)) return failure(); - Value zero = makeEmitCIntConstant(rewriter, loc, dstTy, 0); - Value one = makeEmitCIntConstant(rewriter, loc, dstTy, 1); - - Value q0 = rewriter.create(loc, dstTy, adaptor.getLhs(), - adaptor.getRhs()); - Value r = rewriter.create(loc, dstTy, adaptor.getLhs(), - adaptor.getRhs()); - - Value rNeZero = rewriter.create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::ne, r, - zero); - Value lhsLt0 = - rewriter.create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, adaptor.getLhs(), - zero); - Value rhsLt0 = - rewriter.create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, adaptor.getRhs(), - zero); - Value signsDifferent = - rewriter.create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::ne, lhsLt0, rhsLt0); - Value adjust = - rewriter.create(loc, rewriter.getI1Type(), - rNeZero, signsDifferent); - - Value qMinusOne = rewriter.create(loc, dstTy, q0, one); - Value result = rewriter.create(loc, dstTy, adjust, - qMinusOne, q0); - rewriter.replaceOp(op, result); - return success(); + std::string literal; + llvm::raw_string_ostream os(literal); + os << "pto::MrgSortExecutedNumList{"; + bool first = true; + for (APInt elem : dense.getValues()) { + if (!first) + os << ", "; + first = false; + os << elem.getZExtValue(); + } + os << "}"; + os.flush(); + return literal; } -}; -struct ArithShiftLeftToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::ShLIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + return failure(); +} - const unsigned bitWidth = - intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); +Value mlir::pto::emitCCast(ConversionPatternRewriter &rewriter, Location loc, + Type dstType, Value src) { + if (src.getType() == dstType) + return src; + return rewriter.createOrFold(loc, dstType, src); +} - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); +// For signless iN integers lowered to signed C++ types, this creates a value +// representing the same N-bit pattern in an unsigned C++ type of the same +// width. This avoids incorrect sign-extension when later widening to a larger +// unsigned type. +Value mlir::pto::castSignlessIntToUnsignedSameWidth(ConversionPatternRewriter &rewriter, + Location loc, Value v, + unsigned bitWidth) { + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + return emitCCast(rewriter, loc, uTy, v); +} - if (bitWidth == 1) { - // Compute on u8 and truncate to i1. - auto u8Ty = getUnsignedIntOpaqueType(rewriter.getContext(), 8); - Value lhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getLhs()); - Value rhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getRhs()); - Value sh = rewriter.create(loc, u8Ty, lhsU8, - rhsU8); - Value masked = - rewriter.create(loc, u8Ty, sh, - makeEmitCIntConstant(rewriter, loc, - u8Ty, 1)); - rewriter.replaceOp(op, emitCCast(rewriter, loc, dstTy, masked)); - return success(); - } +//===----------------------------------------------------------------------===// +// pto.mgather lowering -> MGATHER(dst, src, indexes) (pto-isa) +//===----------------------------------------------------------------------===// - auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value shU = - rewriter.create(loc, uTy, lhsU, rhsU); - Value result = emitCCast(rewriter, loc, dstTy, shU); - rewriter.replaceOp(op, result); - return success(); - } -}; +struct PTOMGatherToMGATHER : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; -struct ArithShiftRightUIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::ShRUIOp op, OpAdaptor adaptor, + LogicalResult matchAndRewrite(pto::MGatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - const unsigned bitWidth = - intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); + auto *ctx = rewriter.getContext(); + Value mem = peelUnrealized(adaptor.getMem()); + Value idx = peelUnrealized(adaptor.getIdx()); + Value dst = peelUnrealized(adaptor.getDst()); - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); + Value memArg = maybeWrapGlobalMemrefAsGlobalTensor( + rewriter, op.getLoc(), mem, op.getMem().getType(), op.getOperation()); - if (bitWidth == 1) { - // (x >> y) on i1 is either x (y==0) or 0 (y!=0); approximate in u8. - auto u8Ty = getUnsignedIntOpaqueType(rewriter.getContext(), 8); - Value lhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getLhs()); - Value rhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getRhs()); - Value sh = rewriter.create(loc, u8Ty, lhsU8, - rhsU8); - Value masked = - rewriter.create(loc, u8Ty, sh, - makeEmitCIntConstant(rewriter, loc, - u8Ty, 1)); - rewriter.replaceOp(op, emitCCast(rewriter, loc, dstTy, masked)); - return success(); + auto gatherOobTok = [&](pto::GatherOOB mode) -> StringRef { + switch (mode) { + case pto::GatherOOB::Undefined: + return "pto::GatherOOB::Undefined"; + case pto::GatherOOB::Clamp: + return "pto::GatherOOB::Clamp"; + case pto::GatherOOB::Wrap: + return "pto::GatherOOB::Wrap"; + case pto::GatherOOB::Zero: + return "pto::GatherOOB::Zero"; + } + llvm_unreachable("unknown GatherOOB"); + }; + + SmallVector templateArgVec; + const bool rowCoalesce = + isRowCoalescedMGatherIndexType(op.getDst().getType(), op.getIdx().getType()); + templateArgVec.push_back(emitc::OpaqueAttr::get( + ctx, rowCoalesce ? "pto::Coalesce::Row" : "pto::Coalesce::Elem")); + if (op.getGatherOob() != pto::GatherOOB::Undefined) { + templateArgVec.push_back( + emitc::OpaqueAttr::get(ctx, gatherOobTok(op.getGatherOob()))); } + ArrayAttr templateArgs = rewriter.getArrayAttr(templateArgVec); + + rewriter.create( + op.getLoc(), TypeRange{}, "MGATHER", + ArrayAttr{}, templateArgs, + ValueRange{dst, memArg, idx}); - auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value shU = - rewriter.create(loc, uTy, lhsU, rhsU); - Value result = emitCCast(rewriter, loc, dstTy, shU); - rewriter.replaceOp(op, result); + if (op->getNumResults() == 0) { + rewriter.eraseOp(op); + } else { + rewriter.replaceOp(op, dst); + } return success(); } }; -struct ArithShiftRightSIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::ShRSIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - const unsigned bitWidth = - intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); +static std::optional getKernelKindMacro(func::FuncOp funcOp) { + auto kernelKindAttr = + funcOp->getAttrOfType(FunctionKernelKindAttr::name); + if (!kernelKindAttr) + return std::nullopt; - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); + switch (kernelKindAttr.getKernelKind()) { + case FunctionKernelKind::Cube: + return StringRef("__DAV_CUBE__"); + case FunctionKernelKind::Vector: + return StringRef("__DAV_VEC__"); + } - if (bitWidth == 1) { - // (x >> y) on i1 is either x (y==0) or 0 (y!=0); approximate in u8. - auto u8Ty = getUnsignedIntOpaqueType(rewriter.getContext(), 8); - Value lhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getLhs()); - Value rhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getRhs()); - Value sh = rewriter.create(loc, u8Ty, lhsU8, - rhsU8); - Value masked = - rewriter.create(loc, u8Ty, sh, - makeEmitCIntConstant(rewriter, loc, - u8Ty, 1)); - rewriter.replaceOp(op, emitCCast(rewriter, loc, dstTy, masked)); - return success(); - } + llvm_unreachable("unexpected kernel kind"); +} - // Signed arithmetic shift; cast RHS to unsigned to interpret shift amount. - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value sh = - rewriter.create(loc, dstTy, adaptor.getLhs(), - rhsU); - rewriter.replaceOp(op, sh); - return success(); - } -}; +struct FuncToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; -struct ArithNegFToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::NegFOp op, OpAdaptor adaptor, + LogicalResult matchAndRewrite(func::FuncOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Type dstTy = getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getOperand()); - return success(); - } -}; + // Convert the function signature with the type converter. + Type convertedTy = getTypeConverter()->convertType(op.getFunctionType()); + auto funcType = dyn_cast_or_null(convertedTy); + if (!funcType) + return rewriter.notifyMatchFailure(op, "failed to convert function type"); + if (funcType.getNumResults() > 1) + return rewriter.notifyMatchFailure( + op, "EmitC cannot return multiple values"); -struct ArithRemFToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::RemFOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type dstTy = getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); + // Create the EmitC function with the converted signature. + auto emitcFunc = + rewriter.create(op.getLoc(), op.getName(), funcType); - // Use builtin `fmod` when possible. For f16, compute in float and cast back. - Type callTy = dstTy; - Value lhs = adaptor.getLhs(); - Value rhs = adaptor.getRhs(); - - if (auto opFloatTy = dyn_cast(op.getType())) { - if (opFloatTy.isF16()) { - auto f32Ty = emitc::OpaqueType::get(rewriter.getContext(), "float"); - lhs = emitCCast(rewriter, loc, f32Ty, lhs); - rhs = emitCCast(rewriter, loc, f32Ty, rhs); - callTy = f32Ty; - } + for (const auto &namedAttr : op->getAttrs()) { + StringRef name = namedAttr.getName().strref(); + if (name == op.getFunctionTypeAttrName() || + name == SymbolTable::getSymbolAttrName() || + name == pto::kPTOEntryAttrName || + name == pto::kLegacyHACCEntryAttrName || + name == "pto.internal.entry") + continue; + emitcFunc->setAttr(namedAttr.getName(), namedAttr.getValue()); } - // Prefer `__builtin_fmod*` to avoid relying on extra headers. - llvm::StringRef callee = "__builtin_fmod"; - if (auto opFloatTy = dyn_cast(op.getType())) { - if (opFloatTy.isF32() || opFloatTy.isF16()) - callee = "__builtin_fmodf"; - else if (opFloatTy.isF64()) - callee = "__builtin_fmod"; + if (op.isDeclaration()) { + emitcFunc.setSpecifiersAttr(rewriter.getStrArrayAttr({"extern"})); + rewriter.eraseOp(op); + return success(); } - auto call = rewriter.create( - loc, TypeRange{callTy}, callee, ValueRange{lhs, rhs}, - /*args=*/ArrayAttr{}, /*template_args=*/ArrayAttr{}); - Value result = call.getResult(0); - if (callTy != dstTy) - result = emitCCast(rewriter, loc, dstTy, result); - - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithSelectToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!op.getCondition().getType().isInteger(1)) - return rewriter.notifyMatchFailure( - op, "only scalar i1 conditions supported for arith.select"); + if (pto::isPTOEntryFunction(op)) { + emitcFunc.setSpecifiersAttr( + rewriter.getStrArrayAttr({"__global__ AICORE"})); + } else if (op.isPrivate()) { + emitcFunc.setSpecifiersAttr( + rewriter.getStrArrayAttr({"static", "AICORE"})); + } else { + emitcFunc.setSpecifiersAttr(rewriter.getStrArrayAttr({"AICORE"})); + } - Type dstTy = getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); + std::optional kernelKindMacro = getKernelKindMacro(op); + bool needsNoSplitGuard = needsA5NoSplitVectorGuard(op.getOperation()); - auto cond = - rewriter.create(op.getLoc(), dstTy, - adaptor.getCondition(), - adaptor.getTrueValue(), - adaptor.getFalseValue()); - rewriter.replaceOp(op, cond.getResult()); - return success(); - } -}; + // Inline the original body, then convert region/block argument types to + // match the converted signature (also covers CFG blocks introduced by + // pre-lowering, e.g. scf.while -> cf.br/cf.cond_br). + rewriter.inlineRegionBefore(op.getBody(), emitcFunc.getBody(), + emitcFunc.end()); -struct ArithExtUIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto dstIntTy = dyn_cast(op.getType()); - auto srcIntTy = dyn_cast(op.getIn().getType()); - if (!dstIntTy || !srcIntTy) - return rewriter.notifyMatchFailure(op, "expected scalar integer types"); + TypeConverter::SignatureConversion entryConv(op.getNumArguments()); + for (unsigned i = 0; i < op.getNumArguments(); ++i) + entryConv.addInputs(i, funcType.getInput(i)); - Type dstTy = getTypeConverter()->convertType(dstIntTy); - if (!dstTy) + if (failed(rewriter.convertRegionTypes(&emitcFunc.getBody(), + *getTypeConverter(), &entryConv))) return failure(); - // i1 -> iN: bool to integer already behaves as 0/1. - if (srcIntTy.getWidth() == 1) { - rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); - return success(); + // Preserve the existing function prologue shape. `kernel_kind` functions are + // emitted with the same macro guard/reset sequence that used to come from + // early pto.section wrapping, but only after SCF pre-lowering has finished. + { + Block &entryBlock = emitcFunc.getBody().front(); + rewriter.setInsertionPointToStart(&entryBlock); + rewriter.create(op.getLoc(), "using T = float;"); + if (kernelKindMacro) { + std::string startMacro = "\n#if defined(" + kernelKindMacro->str() + ")"; + rewriter.create(op.getLoc(), startMacro); + if (*kernelKindMacro == "__DAV_VEC__") { + rewriter.create(op.getLoc(), "set_mask_norm();"); + rewriter.create(op.getLoc(), + "set_vector_mask(-1, -1);"); + if (needsNoSplitGuard) + rewriter.create( + op.getLoc(), "if (get_subblockid() == 0) {"); + } + } + } + + if (kernelKindMacro) { + Block &lastBlock = emitcFunc.getBody().back(); + rewriter.setInsertionPoint(lastBlock.getTerminator()); + if (*kernelKindMacro == "__DAV_VEC__" && needsNoSplitGuard) + rewriter.create(op.getLoc(), "}"); + std::string endMacro = "#endif // " + kernelKindMacro->str() + "\n"; + rewriter.create(op.getLoc(), endMacro); } - auto uDstTy = - getUnsignedIntOpaqueType(rewriter.getContext(), dstIntTy.getWidth()); - Value srcU = - castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getIn(), - srcIntTy.getWidth()); - Value extU = emitCCast(rewriter, loc, uDstTy, srcU); - Value result = emitCCast(rewriter, loc, dstTy, extU); - rewriter.replaceOp(op, result); + rewriter.eraseOp(op); return success(); } }; -struct ArithExtSIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto dstIntTy = dyn_cast(op.getType()); - auto srcIntTy = dyn_cast(op.getIn().getType()); - if (!dstIntTy || !srcIntTy) - return rewriter.notifyMatchFailure(op, "expected scalar integer types"); +//===----------------------------------------------------------------------===// +// SubView lowering to GlobalTensor (keep your existing code) +//===----------------------------------------------------------------------=== - Type dstTy = getTypeConverter()->convertType(dstIntTy); - if (!dstTy) - return failure(); +enum class Role { A, B, C, Unknown }; - // i1 sign-extension: 0 -> 0, 1 -> -1. - if (srcIntTy.getWidth() == 1) { - Value zero = makeEmitCIntConstant(rewriter, loc, dstTy, 0); - Value asInt = emitCCast(rewriter, loc, dstTy, adaptor.getIn()); - Value neg = rewriter.create(loc, dstTy, zero, asInt).getResult(); - rewriter.replaceOp(op, neg); - return success(); - } +template +static std::optional inferMatmulLikeSubviewRole(MatmulLikeOp op, + Value buffer) { + if (op.getLhs() == buffer) + return Role::A; + if (op.getRhs() == buffer) + return Role::B; + return std::nullopt; +} - rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); - return success(); +static std::optional inferSubviewRoleFromLoadUser(mlir::pto::TLoadOp load) { + Value buffer = load.getDst(); + if (!buffer) + return std::nullopt; + for (Operation *user : buffer.getUsers()) { + if (auto matmul = dyn_cast(user)) { + if (auto role = inferMatmulLikeSubviewRole(matmul, buffer)) + return role; + continue; + } + if (auto matmulAcc = dyn_cast(user)) { + if (auto role = inferMatmulLikeSubviewRole(matmulAcc, buffer)) + return role; + } } -}; + return std::nullopt; +} -template -struct ArithCastToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type dstTy = this->getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); - return success(); +static std::optional inferSubviewRoleFromUser(Operation *user, Value result) { + if (auto load = dyn_cast(user)) + return inferSubviewRoleFromLoadUser(load); + if (auto store = dyn_cast(user)) { + if (store.getDst() == result) + return Role::C; } -}; - -struct ArithIndexCastUIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::IndexCastUIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type dstTy = getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - - // MemRef casts are handled elsewhere; for safety, fall back to emitc.cast. - if (isa(op.getIn().getType()) || isa(op.getType())) { - rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); - return success(); - } + return std::nullopt; +} - auto getBW = [](Type t) -> std::optional { - if (auto i = dyn_cast(t)) - return i.getWidth(); - if (isa(t)) - return kPTOIndexBitWidth; - return std::nullopt; - }; +[[maybe_unused]] static Role inferSubviewRole(memref::SubViewOp sv) { + Value result = sv.getResult(); + for (Operation *user : result.getUsers()) { + if (auto role = inferSubviewRoleFromUser(user, result)) + return *role; + } + return Role::Unknown; +} - auto srcBW = getBW(op.getIn().getType()); - auto dstBW = getBW(op.getType()); - if (!srcBW || !dstBW) - return rewriter.notifyMatchFailure(op, "unsupported index_castui types"); +// ============================================================================= +// 4. MemRef SubView -> Explicit Shape/Stride Construction (Full Implementation) +// ============================================================================= +struct SubviewToEmitCPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; - if (*dstBW <= *srcBW) { - rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); - return success(); + // 辅助函数:尝试从 OpFoldResult 中提取静态整数值 + std::optional extractStaticInt(OpFoldResult ofr) const { + if (auto attr = ofr.dyn_cast()) { + if (auto intAttr = dyn_cast(attr)) + return intAttr.getInt(); + } else { + Value v = ofr.get(); + if (auto cOp = v.getDefiningOp()) { + if (auto iAttr = dyn_cast(cOp.getValue())) + return iAttr.getInt(); + } else if (auto idxOp = v.getDefiningOp()) { + return idxOp.value(); + } } - - auto uSrcTy = getUnsignedIntOpaqueType(rewriter.getContext(), *srcBW); - auto uDstTy = getUnsignedIntOpaqueType(rewriter.getContext(), *dstBW); - Value srcU = emitCCast(rewriter, loc, uSrcTy, adaptor.getIn()); - Value extU = emitCCast(rewriter, loc, uDstTy, srcU); - Value result = emitCCast(rewriter, loc, dstTy, extU); - rewriter.replaceOp(op, result); - return success(); + return std::nullopt; } -}; -struct ArithUIToFPToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor, + LogicalResult matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); - auto srcIntTy = dyn_cast(op.getIn().getType()); - if (!srcIntTy) - return rewriter.notifyMatchFailure(op, "expected scalar integer input"); + auto *ctx = rewriter.getContext(); + + // 获取源 MemRef 类型信息 + auto srcType = mlir::cast(op.getSource().getType()); + int64_t rank = srcType.getRank(); - Type dstTy = getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - - // Convert via an unsigned integer type of the same width. - if (srcIntTy.getWidth() == 1) { - rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); - return success(); - } - Value srcU = - castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getIn(), - srcIntTy.getWidth()); - Value fp = rewriter.create(loc, dstTy, srcU).getResult(); - rewriter.replaceOp(op, fp); - return success(); - } -}; - -struct ArithFPToUIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::FPToUIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto dstIntTy = dyn_cast(op.getType()); - if (!dstIntTy) - return rewriter.notifyMatchFailure(op, "expected scalar integer result"); - - Type dstTy = getTypeConverter()->convertType(dstIntTy); - if (!dstTy) - return failure(); - - auto uDstTy = - getUnsignedIntOpaqueType(rewriter.getContext(), dstIntTy.getWidth()); - Value asU = rewriter.create(loc, uDstTy, adaptor.getIn()).getResult(); - Value result = emitCCast(rewriter, loc, dstTy, asU); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithBitcastToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type dstTy = getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - - // For pointer-like types, a regular cast is fine. - if (isa(dstTy)) { - rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); - return success(); - } - - // Only support scalar int/float/index bitcasts here. - auto srcTy = op.getIn().getType(); - auto dstOrigTy = op.getType(); - - auto getBitWidth = [](Type t) -> std::optional { - if (auto it = dyn_cast(t)) - return it.getWidth(); - if (auto ft = dyn_cast(t)) - return ft.getWidth(); - if (isa(t)) - return kPTOIndexBitWidth; - return std::nullopt; - }; - auto srcBW = getBitWidth(srcTy); - auto dstBW = getBitWidth(dstOrigTy); - if (!srcBW || !dstBW || *srcBW != *dstBW) - return rewriter.notifyMatchFailure(op, "bitcast requires equal bitwidth"); - - // Determine the template argument from the destination type string. - auto dstOpaque = dyn_cast(dstTy); - if (!dstOpaque) - return rewriter.notifyMatchFailure(op, "expected emitc opaque dest type"); - - auto templateArgs = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(rewriter.getContext(), - dstOpaque.getValue())}); - auto call = rewriter.create( - loc, TypeRange{dstTy}, "ptoas_bitcast", /*operands=*/ValueRange{adaptor.getIn()}, - /*args=*/ArrayAttr{}, /*template_args=*/templateArgs); - rewriter.replaceOp(op, call.getResult(0)); - return success(); - } -}; - -// arith.cmpf lowering with ordered/unordered semantics. -struct ArithCmpFToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - struct CmpFConfig { - bool unordered = false; - emitc::CmpPredicate predicate = emitc::CmpPredicate::eq; - }; - - static Value isNaN(ConversionPatternRewriter &rewriter, Location loc, - Value v) { - return rewriter - .create(loc, rewriter.getI1Type(), emitc::CmpPredicate::ne, - v, v) - .getResult(); - } - - static Value isNotNaN(ConversionPatternRewriter &rewriter, Location loc, - Value v) { - return rewriter - .create(loc, rewriter.getI1Type(), emitc::CmpPredicate::eq, - v, v) - .getResult(); - } - - static std::optional buildSpecialCmpFResult( - arith::CmpFPredicate predicate, ConversionPatternRewriter &rewriter, - Location loc, Type i1Ty, Value lhs, Value rhs) { - switch (predicate) { - case arith::CmpFPredicate::AlwaysFalse: - return makeEmitCOpaqueConstant(rewriter, loc, i1Ty, "false"); - case arith::CmpFPredicate::AlwaysTrue: - return makeEmitCOpaqueConstant(rewriter, loc, i1Ty, "true"); - case arith::CmpFPredicate::ORD: - return rewriter.create( - loc, i1Ty, isNotNaN(rewriter, loc, lhs), - isNotNaN(rewriter, loc, rhs)) - .getResult(); - case arith::CmpFPredicate::UNO: - return rewriter.create( - loc, i1Ty, isNaN(rewriter, loc, lhs), - isNaN(rewriter, loc, rhs)) - .getResult(); - default: - return std::nullopt; - } - } - - static std::optional - getCmpFConfig(arith::CmpFPredicate predicate) { - switch (predicate) { - case arith::CmpFPredicate::OEQ: - return CmpFConfig{false, emitc::CmpPredicate::eq}; - case arith::CmpFPredicate::OGT: - return CmpFConfig{false, emitc::CmpPredicate::gt}; - case arith::CmpFPredicate::OGE: - return CmpFConfig{false, emitc::CmpPredicate::ge}; - case arith::CmpFPredicate::OLT: - return CmpFConfig{false, emitc::CmpPredicate::lt}; - case arith::CmpFPredicate::OLE: - return CmpFConfig{false, emitc::CmpPredicate::le}; - case arith::CmpFPredicate::ONE: - return CmpFConfig{false, emitc::CmpPredicate::ne}; - case arith::CmpFPredicate::UEQ: - return CmpFConfig{true, emitc::CmpPredicate::eq}; - case arith::CmpFPredicate::UGT: - return CmpFConfig{true, emitc::CmpPredicate::gt}; - case arith::CmpFPredicate::UGE: - return CmpFConfig{true, emitc::CmpPredicate::ge}; - case arith::CmpFPredicate::ULT: - return CmpFConfig{true, emitc::CmpPredicate::lt}; - case arith::CmpFPredicate::ULE: - return CmpFConfig{true, emitc::CmpPredicate::le}; - case arith::CmpFPredicate::UNE: - return CmpFConfig{true, emitc::CmpPredicate::ne}; - default: - return std::nullopt; - } - } - - static Value buildCmpFResult(const CmpFConfig &config, - ConversionPatternRewriter &rewriter, - Location loc, Type i1Ty, Value lhs, Value rhs) { - Value cmp = rewriter - .create(loc, i1Ty, config.predicate, lhs, rhs) - .getResult(); - Value unord = rewriter.create( - loc, i1Ty, isNaN(rewriter, loc, lhs), isNaN(rewriter, loc, rhs)); - if (config.unordered) - return rewriter - .create(loc, i1Ty, unord, cmp) - .getResult(); - Value ord = rewriter.create( - loc, i1Ty, isNotNaN(rewriter, loc, lhs), isNotNaN(rewriter, loc, rhs)); - return rewriter - .create(loc, i1Ty, ord, cmp) - .getResult(); - } - - LogicalResult matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!isa(op.getLhs().getType())) - return rewriter.notifyMatchFailure(op, "cmpf only supported on scalar floats"); - - auto loc = op.getLoc(); - auto i1Ty = rewriter.getI1Type(); - if (auto special = buildSpecialCmpFResult(op.getPredicate(), rewriter, loc, - i1Ty, adaptor.getLhs(), - adaptor.getRhs())) { - rewriter.replaceOp(op, *special); - return success(); - } - - auto config = getCmpFConfig(op.getPredicate()); - if (!config) - return rewriter.notifyMatchFailure(op, "unsupported cmpf predicate"); - rewriter.replaceOp(op, buildCmpFResult(*config, rewriter, loc, i1Ty, - adaptor.getLhs(), adaptor.getRhs())); - return success(); - } -}; - -struct ArithAddUIExtendedToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getSum().getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, - "expected scalar integer or index operands"); - - const unsigned bitWidth = - intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - SmallVector newResultTypes; - if (failed(getTypeConverter()->convertTypes(op->getResultTypes(), - newResultTypes))) - return failure(); - if (newResultTypes.size() != 2) - return failure(); - - Type sumDstTy = newResultTypes[0]; - Type overflowDstTy = newResultTypes[1]; - - auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - auto wideTy = getWiderUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - - Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value lhsWide = emitCCast(rewriter, loc, wideTy, lhsU); - Value rhsWide = emitCCast(rewriter, loc, wideTy, rhsU); - Value sumWide = - rewriter.create(loc, wideTy, lhsWide, rhsWide).getResult(); - - Value sumN = emitCCast(rewriter, loc, uTy, sumWide); - Value sum = emitCCast(rewriter, loc, sumDstTy, sumN); - - Value shiftAmt = makeEmitCIntConstant(rewriter, loc, wideTy, bitWidth); - Value high = rewriter - .create(loc, wideTy, sumWide, - shiftAmt) - .getResult(); - Value zeroWide = makeEmitCIntConstant(rewriter, loc, wideTy, 0); - Value overflow = - rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::ne, high, zeroWide) - .getResult(); - overflow = emitCCast(rewriter, loc, overflowDstTy, overflow); - - rewriter.replaceOp(op, {sum, overflow}); - return success(); - } -}; - -template -struct ArithMulExtendedToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getResult(0).getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, - "expected scalar integer or index operands"); - - const unsigned bitWidth = - intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - SmallVector newResultTypes; - if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), - newResultTypes))) - return failure(); - if (newResultTypes.size() != 2) - return failure(); - - Type lowDstTy = newResultTypes[0]; - Type highDstTy = newResultTypes[1]; - - Type wideTy = isUnsigned ? (Type)getWiderUnsignedIntOpaqueType(rewriter.getContext(), - bitWidth) - : (Type)getWiderSignedIntOpaqueType(rewriter.getContext(), - bitWidth); - - Value lhsWide; - Value rhsWide; - if constexpr (isUnsigned) { - Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - lhsWide = emitCCast(rewriter, loc, wideTy, lhsU); - rhsWide = emitCCast(rewriter, loc, wideTy, rhsU); - } else { - lhsWide = emitCCast(rewriter, loc, wideTy, adaptor.getLhs()); - rhsWide = emitCCast(rewriter, loc, wideTy, adaptor.getRhs()); - } - - Value prodWide = - rewriter.create(loc, wideTy, lhsWide, rhsWide).getResult(); - Value low = emitCCast(rewriter, loc, lowDstTy, prodWide); - - Value shiftAmt = makeEmitCIntConstant(rewriter, loc, wideTy, bitWidth); - Value highWide = rewriter - .create(loc, wideTy, prodWide, - shiftAmt) - .getResult(); - Value high = emitCCast(rewriter, loc, highDstTy, highWide); - - rewriter.replaceOp(op, {low, high}); - return success(); - } -}; - -using ArithMulSIExtendedToEmitC = - ArithMulExtendedToEmitC; -using ArithMulUIExtendedToEmitC = - ArithMulExtendedToEmitC; - -struct ArithMinMaxIToEmitCBase { - static Value makeSelect(ConversionPatternRewriter &rewriter, Location loc, - Type dstTy, Value cond, Value trueV, Value falseV) { - return rewriter - .create(loc, dstTy, cond, trueV, falseV) - .getResult(); - } -}; - -struct ArithMaxSIToEmitC : public OpConversionPattern, - ArithMinMaxIToEmitCBase { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::MaxSIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type dstTy = getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - Value cond = rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, - adaptor.getLhs(), adaptor.getRhs()) - .getResult(); - Value res = makeSelect(rewriter, loc, dstTy, cond, adaptor.getRhs(), - adaptor.getLhs()); - rewriter.replaceOp(op, res); - return success(); - } -}; - -struct ArithMinSIToEmitC : public OpConversionPattern, - ArithMinMaxIToEmitCBase { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::MinSIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type dstTy = getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - Value cond = rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, - adaptor.getLhs(), adaptor.getRhs()) - .getResult(); - Value res = makeSelect(rewriter, loc, dstTy, cond, adaptor.getLhs(), - adaptor.getRhs()); - rewriter.replaceOp(op, res); - return success(); - } -}; - -struct ArithMaxUIToEmitC : public OpConversionPattern, - ArithMinMaxIToEmitCBase { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::MaxUIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - const unsigned bitWidth = - intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - Value lhsU = - castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = - castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value cond = rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, lhsU, rhsU) - .getResult(); - Value res = makeSelect(rewriter, loc, dstTy, cond, adaptor.getRhs(), - adaptor.getLhs()); - rewriter.replaceOp(op, res); - return success(); - } -}; - -struct ArithMinUIToEmitC : public OpConversionPattern, - ArithMinMaxIToEmitCBase { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::MinUIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - const unsigned bitWidth = - intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - Value lhsU = - castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = - castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value cond = rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, lhsU, rhsU) - .getResult(); - Value res = makeSelect(rewriter, loc, dstTy, cond, adaptor.getLhs(), - adaptor.getRhs()); - rewriter.replaceOp(op, res); - return success(); - } -}; - -// Floating-point max/min variants. -struct ArithFloatMinMaxToEmitCBase { - static Value isNaN(ConversionPatternRewriter &rewriter, Location loc, - Value v) { - return rewriter - .create(loc, rewriter.getI1Type(), emitc::CmpPredicate::ne, - v, v) - .getResult(); - } - - static Value makeFZero(ConversionPatternRewriter &rewriter, Location loc, - Type ty) { - return makeEmitCOpaqueConstant(rewriter, loc, ty, "0.0f"); - } -}; - -struct ArithMaxNumFToEmitC : public OpConversionPattern, - ArithFloatMinMaxToEmitCBase { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::MaxNumFOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type dstTy = getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - - Value lhsNaN = isNaN(rewriter, loc, adaptor.getLhs()); - Value rhsNaN = isNaN(rewriter, loc, adaptor.getRhs()); - - Value cmpLt = rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, - adaptor.getLhs(), adaptor.getRhs()) - .getResult(); - Value maxNoNaN = - rewriter - .create(loc, dstTy, cmpLt, adaptor.getRhs(), - adaptor.getLhs()) - .getResult(); - - Value rhsOrMax = - rewriter - .create(loc, dstTy, rhsNaN, adaptor.getLhs(), - maxNoNaN) - .getResult(); - Value res = - rewriter - .create(loc, dstTy, lhsNaN, adaptor.getRhs(), - rhsOrMax) - .getResult(); - rewriter.replaceOp(op, res); - return success(); - } -}; - -struct ArithMinNumFToEmitC : public OpConversionPattern, - ArithFloatMinMaxToEmitCBase { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::MinNumFOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type dstTy = getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - - Value lhsNaN = isNaN(rewriter, loc, adaptor.getLhs()); - Value rhsNaN = isNaN(rewriter, loc, adaptor.getRhs()); - - Value cmpLt = rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, - adaptor.getLhs(), adaptor.getRhs()) - .getResult(); - Value minNoNaN = - rewriter - .create(loc, dstTy, cmpLt, adaptor.getLhs(), - adaptor.getRhs()) - .getResult(); - - Value rhsOrMin = - rewriter - .create(loc, dstTy, rhsNaN, adaptor.getLhs(), - minNoNaN) - .getResult(); - Value res = - rewriter - .create(loc, dstTy, lhsNaN, adaptor.getRhs(), - rhsOrMin) - .getResult(); - rewriter.replaceOp(op, res); - return success(); - } -}; - -template -struct ArithMinMaxFPropagateNaNToEmitC : public OpConversionPattern, - ArithFloatMinMaxToEmitCBase { - using OpConversionPattern::OpConversionPattern; - - static Value buildPrimaryCandidate(ConversionPatternRewriter &rewriter, - Location loc, Type dstTy, Value lhs, - Value rhs) { - Value cmpLt = - rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::lt, lhs, rhs) - .getResult(); - return rewriter - .create( - loc, dstTy, cmpLt, isMaximum ? rhs : lhs, isMaximum ? lhs : rhs) - .getResult(); - } - - static Value buildSignBitValue(ConversionPatternRewriter &rewriter, - Location loc, Value lhs, FloatType floatTy) { - auto bitsTy = - getUnsignedIntOpaqueType(rewriter.getContext(), floatTy.getWidth()); - auto templateArgs = rewriter.getArrayAttr({emitc::OpaqueAttr::get( - rewriter.getContext(), cast(bitsTy).getValue())}); - Value lhsBits = - rewriter - .create(loc, TypeRange{bitsTy}, "ptoas_bitcast", - ValueRange{lhs}, ArrayAttr{}, - templateArgs) - .getResult(0); - Value oneBits = makeEmitCIntConstant(rewriter, loc, bitsTy, 1); - Value shiftAmount = - makeEmitCIntConstant(rewriter, loc, bitsTy, floatTy.getWidth() - 1); - Value signMask = rewriter - .create(loc, bitsTy, oneBits, - shiftAmount) - .getResult(); - return rewriter - .create(loc, bitsTy, lhsBits, signMask) - .getResult(); - } - - static Value buildSignedZeroCandidate(ConversionPatternRewriter &rewriter, - Location loc, Type dstTy, Value lhs, - Value rhs, FloatType floatTy) { - Value zero = makeFZero(rewriter, loc, dstTy); - Value equal = rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::eq, lhs, rhs) - .getResult(); - Value lhsZero = rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::eq, lhs, - zero) - .getResult(); - Value bothZero = rewriter - .create(loc, rewriter.getI1Type(), - equal, lhsZero) - .getResult(); - auto bitsTy = - getUnsignedIntOpaqueType(rewriter.getContext(), floatTy.getWidth()); - Value zeroBits = makeEmitCIntConstant(rewriter, loc, bitsTy, 0); - Value lhsIsNegZero = - rewriter - .create(loc, rewriter.getI1Type(), - emitc::CmpPredicate::ne, - buildSignBitValue(rewriter, loc, lhs, floatTy), - zeroBits) - .getResult(); - Value tie = rewriter - .create( - loc, dstTy, lhsIsNegZero, isMaximum ? rhs : lhs, - isMaximum ? lhs : rhs) - .getResult(); - return rewriter - .create(loc, dstTy, bothZero, tie, - buildPrimaryCandidate(rewriter, loc, dstTy, - lhs, rhs)) - .getResult(); - } - - static Value buildNaNPropagatingResult(ConversionPatternRewriter &rewriter, - Location loc, Type dstTy, Value lhs, - Value rhs, FloatType floatTy) { - Value lhsNaN = isNaN(rewriter, loc, lhs); - Value rhsNaN = isNaN(rewriter, loc, rhs); - Value noNaN = - buildSignedZeroCandidate(rewriter, loc, dstTy, lhs, rhs, floatTy); - Value rhsOrNoNaN = rewriter - .create(loc, dstTy, rhsNaN, rhs, - noNaN) - .getResult(); - return rewriter - .create(loc, dstTy, lhsNaN, lhs, rhsOrNoNaN) - .getResult(); - } - - LogicalResult - matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!isa(op.getType())) - return rewriter.notifyMatchFailure(op, "expected scalar float type"); - - auto loc = op.getLoc(); - Type dstTy = this->getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - - auto floatTy = cast(op.getType()); - rewriter.replaceOp(op, buildNaNPropagatingResult( - rewriter, loc, dstTy, adaptor.getLhs(), - adaptor.getRhs(), floatTy)); - return success(); - } -}; - -using ArithMaximumFToEmitC = - ArithMinMaxFPropagateNaNToEmitC; -using ArithMinimumFToEmitC = - ArithMinMaxFPropagateNaNToEmitC; - -//===----------------------------------------------------------------------===// -// Arith -> EmitC helpers -//===----------------------------------------------------------------------===// - -static emitc::OpaqueType getSignedIntOpaqueType(MLIRContext *ctx, - unsigned bitWidth) { - switch (bitWidth) { - case 1: - return emitc::OpaqueType::get(ctx, "int8_t"); - case 8: - return emitc::OpaqueType::get(ctx, "int8_t"); - case 16: - return emitc::OpaqueType::get(ctx, "int16_t"); - case 32: - return emitc::OpaqueType::get(ctx, "int32_t"); - case 64: - return emitc::OpaqueType::get(ctx, "int64_t"); - case 128: - return emitc::OpaqueType::get(ctx, "__int128"); - default: - llvm::errs() << "[Debug] Unsupported signed integer bitwidth: " << bitWidth - << "\n"; - return emitc::OpaqueType::get(ctx, "int64_t"); - } -} - -static emitc::OpaqueType getUnsignedIntOpaqueType(MLIRContext *ctx, - unsigned bitWidth) { - switch (bitWidth) { - case 1: - return emitc::OpaqueType::get(ctx, "uint8_t"); - case 8: - return emitc::OpaqueType::get(ctx, "uint8_t"); - case 16: - return emitc::OpaqueType::get(ctx, "uint16_t"); - case 32: - return emitc::OpaqueType::get(ctx, "uint32_t"); - case 64: - return emitc::OpaqueType::get(ctx, "uint64_t"); - case 128: - return emitc::OpaqueType::get(ctx, "unsigned __int128"); - default: - llvm::errs() << "[Debug] Unsupported unsigned integer bitwidth: " - << bitWidth << "\n"; - return emitc::OpaqueType::get(ctx, "uint64_t"); - } -} - -static emitc::OpaqueType getWiderSignedIntOpaqueType(MLIRContext *ctx, - unsigned bitWidth) { - switch (bitWidth) { - case 1: - case 8: - return getSignedIntOpaqueType(ctx, 16); - case 16: - return getSignedIntOpaqueType(ctx, 32); - case 32: - return getSignedIntOpaqueType(ctx, 64); - case 64: - return getSignedIntOpaqueType(ctx, 128); - default: - return getSignedIntOpaqueType(ctx, 128); - } -} - -static emitc::OpaqueType getWiderUnsignedIntOpaqueType(MLIRContext *ctx, - unsigned bitWidth) { - switch (bitWidth) { - case 1: - case 8: - return getUnsignedIntOpaqueType(ctx, 16); - case 16: - return getUnsignedIntOpaqueType(ctx, 32); - case 32: - return getUnsignedIntOpaqueType(ctx, 64); - case 64: - return getUnsignedIntOpaqueType(ctx, 128); - default: - return getUnsignedIntOpaqueType(ctx, 128); - } -} - -static Value makeEmitCOpaqueConstant(ConversionPatternRewriter &rewriter, - Location loc, Type type, - llvm::StringRef literal) { - auto attr = emitc::OpaqueAttr::get(rewriter.getContext(), literal); - return rewriter.create(loc, type, attr); -} - -static Value makeEmitCIntConstant(ConversionPatternRewriter &rewriter, - Location loc, Type type, int64_t value) { - return makeEmitCOpaqueConstant(rewriter, loc, type, std::to_string(value)); -} - -static FailureOr buildEmitCOpaqueConstantLiteral(Type targetType, - Attribute valueAttr) { - auto opaqueTy = dyn_cast(targetType); - if (!opaqueTy) - return failure(); - - if (opaqueTy.getValue() == "pto::MrgSortExecutedNumList") { - auto dense = dyn_cast_or_null(valueAttr); - if (!dense) - return failure(); - - auto vecTy = dyn_cast(dense.getType()); - if (!vecTy || vecTy.getRank() != 1 || vecTy.getNumElements() != 4 || - !vecTy.getElementType().isInteger(16)) - return failure(); - - std::string literal; - llvm::raw_string_ostream os(literal); - os << "pto::MrgSortExecutedNumList{"; - bool first = true; - for (APInt elem : dense.getValues()) { - if (!first) - os << ", "; - first = false; - os << elem.getZExtValue(); - } - os << "}"; - os.flush(); - return literal; - } - - return failure(); -} - -static Value emitCCast(ConversionPatternRewriter &rewriter, Location loc, - Type dstType, Value src) { - if (src.getType() == dstType) - return src; - return rewriter.createOrFold(loc, dstType, src); -} - -// For signless iN integers lowered to signed C++ types, this creates a value -// representing the same N-bit pattern in an unsigned C++ type of the same -// width. This avoids incorrect sign-extension when later widening to a larger -// unsigned type. -static Value castSignlessIntToUnsignedSameWidth(ConversionPatternRewriter &rewriter, - Location loc, Value v, - unsigned bitWidth) { - auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - return emitCCast(rewriter, loc, uTy, v); -} - -struct ArithMulIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(arith::MulIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - const unsigned bitWidth = - intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - // i1 mul is equivalent to bitwise AND (mod 2 arithmetic). - if (bitWidth == 1) { - rewriter.replaceOpWithNewOp(op, opTy, adaptor.getLhs(), - adaptor.getRhs()); - return success(); - } - - auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value mulU = rewriter.create(loc, uTy, lhsU, rhsU); - Value result = emitCCast(rewriter, loc, dstTy, mulU); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithAddIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(arith::AddIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - const unsigned bitWidth = - intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - // i1 add is equivalent to XOR (mod 2 arithmetic). - if (bitWidth == 1) { - rewriter.replaceOpWithNewOp(op, opTy, adaptor.getLhs(), - adaptor.getRhs()); - return success(); - } - - auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value addU = rewriter.create(loc, uTy, lhsU, rhsU); - Value result = emitCCast(rewriter, loc, dstTy, addU); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithCastOPToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type newTy = getTypeConverter()->convertType(op.getType()); - if (!newTy) - return failure(); - if (adaptor.getIn().getType() == newTy) { - rewriter.replaceOp(op, adaptor.getIn()); - return success(); - } - rewriter.replaceOpWithNewOp(op, newTy, adaptor.getIn()); - return success(); - } -}; - -struct ArithSubIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(arith::SubIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Type opTy = op.getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); - - const unsigned bitWidth = - intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); - - Type dstTy = getTypeConverter()->convertType(opTy); - if (!dstTy) - return failure(); - - // i1 sub is equivalent to XOR (mod 2 arithmetic). - if (bitWidth == 1) { - rewriter.replaceOpWithNewOp(op, opTy, adaptor.getLhs(), - adaptor.getRhs()); - return success(); - } - - auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); - Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), - bitWidth); - Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), - bitWidth); - Value subU = rewriter.create(loc, uTy, lhsU, rhsU); - Value result = emitCCast(rewriter, loc, dstTy, subU); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct ArithDivSIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(arith::DivSIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type newTy = getTypeConverter()->convertType(op.getType()); - if (!newTy) - return failure(); - rewriter.replaceOpWithNewOp(op, newTy, adaptor.getLhs(), - adaptor.getRhs()); - return success(); - } -}; - -struct ArithRemSIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type newTy = getTypeConverter()->convertType(op.getType()); - if (!newTy) - return failure(); - rewriter.replaceOpWithNewOp(op, newTy, adaptor.getLhs(), - adaptor.getRhs()); - return success(); - } -}; - -struct ArithTruncIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - auto dstIntTy = dyn_cast(op.getType()); - auto srcIntTy = dyn_cast(op.getIn().getType()); - if (!dstIntTy || !srcIntTy) - return rewriter.notifyMatchFailure(op, "expected scalar integer types"); - - Type dstTy = getTypeConverter()->convertType(dstIntTy); - if (!dstTy) - return failure(); - - // to-i1 conversions: Arith wants truncation to the low bit, while C/C++ - // casts to bool are equivalent to `v != 0`. Implement as `(bool)(v & 1)`. - if (dstIntTy.getWidth() == 1) { - if (srcIntTy.getWidth() == 1) { - rewriter.replaceOp(op, adaptor.getIn()); - return success(); - } - - auto uSrcTy = - getUnsignedIntOpaqueType(rewriter.getContext(), srcIntTy.getWidth()); - Value inU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getIn(), - srcIntTy.getWidth()); - Value one = makeEmitCIntConstant(rewriter, loc, uSrcTy, 1); - Value masked = - rewriter.create(loc, uSrcTy, inU, one); - Value asBool = emitCCast(rewriter, loc, dstTy, masked); - rewriter.replaceOp(op, asBool); - return success(); - } - - rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); - return success(); - } -}; - -struct ArithConstantToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type newType = getTypeConverter()->convertType(op.getType()); - if (!newType) - return failure(); - - // `adaptor.getValue()` may be null if attribute conversion isn't defined. - // Use the original attribute as fallback and always cast null-safely. - Attribute valueAttr = adaptor.getValue(); - if (!valueAttr) - valueAttr = op.getValue(); - - if (auto opaqueLiteral = buildEmitCOpaqueConstantLiteral(newType, valueAttr); - succeeded(opaqueLiteral)) { - auto constAttr = emitc::OpaqueAttr::get(rewriter.getContext(), *opaqueLiteral); - rewriter.replaceOpWithNewOp(op, newType, constAttr); - return success(); - } - - if (auto floatAttr = dyn_cast_or_null(valueAttr)) { - SmallString<32> valStr; - floatAttr.getValue().toString(valStr); - llvm::StringRef s(valStr); - // Ensure the literal parses as a floating-point constant in C/C++. - // `APFloat::toString` may emit "1" for integral values; make it "1.0". - const bool hasFloatMarker = - s.contains('.') || s.contains('e') || s.contains('E') || - s.contains('p') || s.contains('P') || s.starts_with("0x") || - s.starts_with("0X") || s.starts_with("nan") || - s.starts_with("-nan") || s.starts_with("inf") || - s.starts_with("-inf"); - if (!hasFloatMarker) - valStr.append(".0"); - // Suffix: keep `f` for f16/f32; omit for f64. - if (!floatAttr.getType().isF64()) - valStr.append("f"); - auto constAttr = emitc::OpaqueAttr::get(rewriter.getContext(), valStr); - rewriter.replaceOpWithNewOp(op, newType, constAttr); - return success(); - } - - if (auto intAttr = dyn_cast_or_null(valueAttr)) { - std::string valStr = std::to_string(intAttr.getValue().getSExtValue()); - auto constAttr = emitc::OpaqueAttr::get(rewriter.getContext(), valStr); - rewriter.replaceOpWithNewOp(op, newType, constAttr); - return success(); - } - - return failure(); - } -}; -//===----------------------------------------------------------------------===// -// pto.mgather lowering -> MGATHER(dst, src, indexes) (pto-isa) -//===----------------------------------------------------------------------===// - -struct PTOMGatherToMGATHER : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::MGatherOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto *ctx = rewriter.getContext(); - Value mem = peelUnrealized(adaptor.getMem()); - Value idx = peelUnrealized(adaptor.getIdx()); - Value dst = peelUnrealized(adaptor.getDst()); - - Value memArg = maybeWrapGlobalMemrefAsGlobalTensor( - rewriter, op.getLoc(), mem, op.getMem().getType(), op.getOperation()); - - auto gatherOobTok = [&](pto::GatherOOB mode) -> StringRef { - switch (mode) { - case pto::GatherOOB::Undefined: - return "pto::GatherOOB::Undefined"; - case pto::GatherOOB::Clamp: - return "pto::GatherOOB::Clamp"; - case pto::GatherOOB::Wrap: - return "pto::GatherOOB::Wrap"; - case pto::GatherOOB::Zero: - return "pto::GatherOOB::Zero"; - } - llvm_unreachable("unknown GatherOOB"); - }; - - SmallVector templateArgVec; - const bool rowCoalesce = - isRowCoalescedMGatherIndexType(op.getDst().getType(), op.getIdx().getType()); - templateArgVec.push_back(emitc::OpaqueAttr::get( - ctx, rowCoalesce ? "pto::Coalesce::Row" : "pto::Coalesce::Elem")); - if (op.getGatherOob() != pto::GatherOOB::Undefined) { - templateArgVec.push_back( - emitc::OpaqueAttr::get(ctx, gatherOobTok(op.getGatherOob()))); - } - ArrayAttr templateArgs = rewriter.getArrayAttr(templateArgVec); - - rewriter.create( - op.getLoc(), TypeRange{}, "MGATHER", - ArrayAttr{}, templateArgs, - ValueRange{dst, memArg, idx}); - - if (op->getNumResults() == 0) { - rewriter.eraseOp(op); - } else { - rewriter.replaceOp(op, dst); - } - return success(); - } -}; - -struct AffineApplyMulConstToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(affine::AffineApplyOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto map = op.getAffineMap(); - - if (map.getNumDims() != 0 || map.getNumSymbols() != 1) - return failure(); - - auto expr = map.getResult(0); - auto bin = dyn_cast(expr); - if (!bin || bin.getKind() != AffineExprKind::Mul) - return failure(); - - auto lhs = bin.getLHS(); - auto rhs = bin.getRHS(); - - auto symExpr = dyn_cast(lhs); - auto constExpr = dyn_cast(rhs); - if (!symExpr || !constExpr) - return failure(); - - Value inputVal = adaptor.getMapOperands()[0]; - - std::string valStr = std::to_string(constExpr.getValue()); - auto cstAttr = emitc::OpaqueAttr::get(rewriter.getContext(), valStr); - auto cstOp = rewriter.create( - op.getLoc(), inputVal.getType(), cstAttr); - - rewriter.replaceOpWithNewOp( - op, inputVal.getType(), inputVal, cstOp); - - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// Kernel inference helpers -//===----------------------------------------------------------------------===// - -enum class KernelKind { VecAdd, Matmul, Unknown }; - -[[maybe_unused]] static KernelKind inferKernelKind(func::FuncOp f) { - bool hasAdd = false; - bool hasMM = false; - f.walk([&](Operation *op) { - if (isa(op)) hasAdd = true; - if (isa(op)) hasMM = true; - if (isa(op)) hasMM = true; - }); - if (hasMM) return KernelKind::Matmul; - if (hasAdd) return KernelKind::VecAdd; - return KernelKind::Unknown; -} - -[[maybe_unused]] static void inferTileMNK(func::FuncOp f, int &M, int &N, int &K) { - M = 32; N = 32; K = 32; - SmallVector subs; - f.walk([&](memref::SubViewOp sv) { subs.push_back(sv); }); - - auto readShape2D = [&](memref::SubViewOp sv, int &d0, int &d1) { - auto resTy = mlir::cast(sv.getResult().getType()); - if (resTy.getRank() == 2 && resTy.hasStaticShape()) { - d0 = (int)resTy.getDimSize(0); - d1 = (int)resTy.getDimSize(1); - } - }; - - if (subs.empty()) return; - - int a0=32, a1=32; - readShape2D(subs[0], a0, a1); - M = a0; N = a1; - - if (subs.size() >= 2) { - int b0=32, b1=32; - readShape2D(subs[0], a0, a1); - readShape2D(subs[1], b0, b1); - M = a0; K = a1; N = b1; - } -} - -static std::optional getKernelKindMacro(func::FuncOp funcOp) { - auto kernelKindAttr = - funcOp->getAttrOfType(FunctionKernelKindAttr::name); - if (!kernelKindAttr) - return std::nullopt; - - switch (kernelKindAttr.getKernelKind()) { - case FunctionKernelKind::Cube: - return StringRef("__DAV_CUBE__"); - case FunctionKernelKind::Vector: - return StringRef("__DAV_VEC__"); - } - - llvm_unreachable("unexpected kernel kind"); -} - -struct FuncToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(func::FuncOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // Convert the function signature with the type converter. - Type convertedTy = getTypeConverter()->convertType(op.getFunctionType()); - auto funcType = dyn_cast_or_null(convertedTy); - if (!funcType) - return rewriter.notifyMatchFailure(op, "failed to convert function type"); - if (funcType.getNumResults() > 1) - return rewriter.notifyMatchFailure( - op, "EmitC cannot return multiple values"); - - // Create the EmitC function with the converted signature. - auto emitcFunc = - rewriter.create(op.getLoc(), op.getName(), funcType); - - for (const auto &namedAttr : op->getAttrs()) { - StringRef name = namedAttr.getName().strref(); - if (name == op.getFunctionTypeAttrName() || - name == SymbolTable::getSymbolAttrName() || - name == pto::kPTOEntryAttrName || - name == pto::kLegacyHACCEntryAttrName || - name == "pto.internal.entry") - continue; - emitcFunc->setAttr(namedAttr.getName(), namedAttr.getValue()); - } - - if (op.isDeclaration()) { - emitcFunc.setSpecifiersAttr(rewriter.getStrArrayAttr({"extern"})); - rewriter.eraseOp(op); - return success(); - } - - if (pto::isPTOEntryFunction(op)) { - emitcFunc.setSpecifiersAttr( - rewriter.getStrArrayAttr({"__global__ AICORE"})); - } else if (op.isPrivate()) { - emitcFunc.setSpecifiersAttr( - rewriter.getStrArrayAttr({"static", "AICORE"})); - } else { - emitcFunc.setSpecifiersAttr(rewriter.getStrArrayAttr({"AICORE"})); - } - - std::optional kernelKindMacro = getKernelKindMacro(op); - bool needsNoSplitGuard = needsA5NoSplitVectorGuard(op.getOperation()); - - // Inline the original body, then convert region/block argument types to - // match the converted signature (also covers CFG blocks introduced by - // pre-lowering, e.g. scf.while -> cf.br/cf.cond_br). - rewriter.inlineRegionBefore(op.getBody(), emitcFunc.getBody(), - emitcFunc.end()); - - TypeConverter::SignatureConversion entryConv(op.getNumArguments()); - for (unsigned i = 0; i < op.getNumArguments(); ++i) - entryConv.addInputs(i, funcType.getInput(i)); - - if (failed(rewriter.convertRegionTypes(&emitcFunc.getBody(), - *getTypeConverter(), &entryConv))) - return failure(); - - // Preserve the existing function prologue shape. `kernel_kind` functions are - // emitted with the same macro guard/reset sequence that used to come from - // early pto.section wrapping, but only after SCF pre-lowering has finished. - { - Block &entryBlock = emitcFunc.getBody().front(); - rewriter.setInsertionPointToStart(&entryBlock); - rewriter.create(op.getLoc(), "using T = float;"); - if (kernelKindMacro) { - std::string startMacro = "\n#if defined(" + kernelKindMacro->str() + ")"; - rewriter.create(op.getLoc(), startMacro); - if (*kernelKindMacro == "__DAV_VEC__") { - rewriter.create(op.getLoc(), "set_mask_norm();"); - rewriter.create(op.getLoc(), - "set_vector_mask(-1, -1);"); - if (needsNoSplitGuard) - rewriter.create( - op.getLoc(), "if (get_subblockid() == 0) {"); - } - } - } - - if (kernelKindMacro) { - Block &lastBlock = emitcFunc.getBody().back(); - rewriter.setInsertionPoint(lastBlock.getTerminator()); - if (*kernelKindMacro == "__DAV_VEC__" && needsNoSplitGuard) - rewriter.create(op.getLoc(), "}"); - std::string endMacro = "#endif // " + kernelKindMacro->str() + "\n"; - rewriter.create(op.getLoc(), endMacro); - } - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// SubView lowering to GlobalTensor (keep your existing code) -//===----------------------------------------------------------------------=== - -enum class Role { A, B, C, Unknown }; - -template -static std::optional inferMatmulLikeSubviewRole(MatmulLikeOp op, - Value buffer) { - if (op.getLhs() == buffer) - return Role::A; - if (op.getRhs() == buffer) - return Role::B; - return std::nullopt; -} - -static std::optional inferSubviewRoleFromLoadUser(mlir::pto::TLoadOp load) { - Value buffer = load.getDst(); - if (!buffer) - return std::nullopt; - for (Operation *user : buffer.getUsers()) { - if (auto matmul = dyn_cast(user)) { - if (auto role = inferMatmulLikeSubviewRole(matmul, buffer)) - return role; - continue; - } - if (auto matmulAcc = dyn_cast(user)) { - if (auto role = inferMatmulLikeSubviewRole(matmulAcc, buffer)) - return role; - } - } - return std::nullopt; -} - -static std::optional inferSubviewRoleFromUser(Operation *user, Value result) { - if (auto load = dyn_cast(user)) - return inferSubviewRoleFromLoadUser(load); - if (auto store = dyn_cast(user)) { - if (store.getDst() == result) - return Role::C; - } - return std::nullopt; -} - -[[maybe_unused]] static Role inferSubviewRole(memref::SubViewOp sv) { - Value result = sv.getResult(); - for (Operation *user : result.getUsers()) { - if (auto role = inferSubviewRoleFromUser(user, result)) - return *role; - } - return Role::Unknown; -} - -// ============================================================================= -// 4. MemRef SubView -> Explicit Shape/Stride Construction (Full Implementation) -// ============================================================================= -struct SubviewToEmitCPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - // 辅助函数:尝试从 OpFoldResult 中提取静态整数值 - std::optional extractStaticInt(OpFoldResult ofr) const { - if (auto attr = ofr.dyn_cast()) { - if (auto intAttr = dyn_cast(attr)) - return intAttr.getInt(); - } else { - Value v = ofr.get(); - if (auto cOp = v.getDefiningOp()) { - if (auto iAttr = dyn_cast(cOp.getValue())) - return iAttr.getInt(); - } else if (auto idxOp = v.getDefiningOp()) { - return idxOp.value(); - } - } - return std::nullopt; - } - - LogicalResult matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - // 获取源 MemRef 类型信息 - auto srcType = mlir::cast(op.getSource().getType()); - int64_t rank = srcType.getRank(); - - auto elemTypeToString = [&](Type elemTy) -> std::string { - if (elemTy.isF16()) - return "half"; - if (elemTy.isBF16()) - return "bfloat16_t"; - if (elemTy.isF32()) - return "float"; - if (elemTy.isF64()) - return "double"; - if (elemTy.isInteger(8)) { - if (elemTy.isSignlessInteger(8) || elemTy.isSignedInteger(8)) - return "int8_t"; - return "uint8_t"; - } - if (elemTy.isInteger(16)) { - if (elemTy.isSignlessInteger(16) || elemTy.isSignedInteger(16)) - return "int16_t"; - return "uint16_t"; - } - if (elemTy.isInteger(32)) { - if (elemTy.isSignlessInteger(32) || elemTy.isSignedInteger(32)) - return "int32_t"; - return "uint32_t"; - } - if (elemTy.isInteger(64)) { - return cast(elemTy).isUnsigned() ? "uint64_t" : "int64_t"; - } - return "float"; - }; - - // ------------------------------------------------------------------------- - // Part 1: 指针偏移计算 (Runtime Pointer Arithmetic) - // ------------------------------------------------------------------------- - - // 准备类型: unsigned - Type u32Ty = emitc::OpaqueType::get(ctx, "unsigned"); - - // Helper: 创建 unsigned 常量 - auto mkU32 = [&](int64_t v) -> Value { - return rewriter.create( - loc, u32Ty, emitc::OpaqueAttr::get(ctx, std::to_string(v))); - }; - - // Helper: 将 OpFoldResult 转为 EmitC Value (用于计算) - auto ofrToEmitCValue = [&](OpFoldResult ofr) -> Value { - if (auto v = ofr.dyn_cast()) { - Value rv = rewriter.getRemappedValue(v); - // 如果类型不匹配,插入 Cast - if (rv.getType() != u32Ty) - return rewriter.create(loc, u32Ty, rv).getResult(); - return rv; - } - if (auto attr = ofr.dyn_cast()) { - if (auto ia = dyn_cast(attr)) - return mkU32(ia.getValue().getSExtValue()); - } - return mkU32(0); - }; - - // 1. 获取 Source 的 Strides (支持动态 Stride 收集) - SmallVector sourceStrides; - - if (auto rc = op.getSource().getDefiningOp()) { - sourceStrides = rc.getMixedStrides(); - } else { - SmallVector strideInts; - int64_t offset = ShapedType::kDynamic; - bool useTypeStrides = succeeded(getStridesAndOffset(srcType, strideInts, offset)); - (void)offset; - if (useTypeStrides) { - for (int64_t s : strideInts) { - if (s == ShapedType::kDynamic) - useTypeStrides = false; - } - } - if (useTypeStrides) { - for (int64_t s : strideInts) { - sourceStrides.push_back(rewriter.getIndexAttr(s)); - } - } else { - // Fallback: Compact Layout - auto shape = srcType.getShape(); - int64_t current = 1; - sourceStrides.resize(rank); - for (int i = rank - 1; i >= 0; --i) { - sourceStrides[i] = rewriter.getIndexAttr(current); - if (shape[i] != ShapedType::kDynamic) current *= shape[i]; - } - } - } - - // 2. 计算运行时 Offset - auto staticOffsets = op.getStaticOffsets(); - auto dynamicOffsets = adaptor.getOffsets(); - int dynOffIdx = 0; - Value totalOffset = mkU32(0); - - for (int i = 0; i < rank; ++i) { - // A. 获取 Offset - Value offVal; - if (staticOffsets[i] == ShapedType::kDynamic) { - Value rawDyn = dynamicOffsets[dynOffIdx++]; - offVal = rewriter.create(loc, u32Ty, rawDyn); - } else { - offVal = mkU32(staticOffsets[i]); - } - - // B. 获取 Stride (用于指针计算) - Value strideVal = mkU32(1); - if (i < (int)sourceStrides.size()) { - strideVal = ofrToEmitCValue(sourceStrides[i]); - } - - // C. 累加 - Value term = rewriter.create(loc, u32Ty, offVal, strideVal); - totalOffset = rewriter.create(loc, u32Ty, totalOffset, term); - } - - // 3. 生成新指针 - // - // NOTE: Some toolchains may materialize kernel pointer params as `void*` even - // when the underlying element type is i16. Pointer arithmetic on `void*` - // is ill-formed in C++, so we explicitly cast to a typed pointer for i16. - Value sourcePtr = adaptor.getSource(); - Value tileCandidate = sourcePtr; - if (auto castOp = sourcePtr.getDefiningOp()) { - tileCandidate = castOp.getOperand(); - } else if (auto uc = - sourcePtr.getDefiningOp()) { - tileCandidate = uc.getOperand(0); - } - if (auto ot = dyn_cast(tileCandidate.getType())) { - auto tyStr = ot.getValue(); - if (tyStr.find("Tile<") != std::string::npos || - tyStr.find("ConvTile<") != std::string::npos) { - std::string elemTok = elemTypeToString(srcType.getElementType()); - pto::AddressSpace as = pto::AddressSpace::GM; - if (auto asAttr = - dyn_cast_or_null(srcType.getMemorySpace())) - as = asAttr.getAddressSpace(); - sourcePtr = - materializeTileDataValue(rewriter, loc, tileCandidate, as, elemTok); - if (tileDataReturnsIntegralAddress(as)) - sourcePtr = - materializeAddressAsPointer(rewriter, loc, sourcePtr, as, elemTok); - } - } - Value newPtr; - { - auto resTy = mlir::cast(op.getResult().getType()); - Type elemTy = resTy.getElementType(); - if (elemTy.isInteger(16)) { - std::string castElemTypeStr = "int16_t"; - if (cast(elemTy).isUnsigned()) - castElemTypeStr = "uint16_t"; - - std::string qualifier = "__gm__"; - if (Attribute ms = srcType.getMemorySpace()) { - if (auto ptoAttr = dyn_cast(ms)) { - qualifier = addrSpaceQualifier(ptoAttr.getAddressSpace()); - } - } - - auto typedPtrTy = emitc::PointerType::get( - emitc::OpaqueType::get(ctx, qualifier + " " + castElemTypeStr)); - Value typedSourcePtr = rewriter.create(loc, typedPtrTy, sourcePtr); - newPtr = rewriter.create(loc, typedPtrTy, typedSourcePtr, totalOffset); - } else { - newPtr = rewriter.create(loc, sourcePtr.getType(), sourcePtr, totalOffset); - } - } - - - // ------------------------------------------------------------------------- - // Part 2: For non-GM memrefs, keep pointer (no GlobalTensor). - // ------------------------------------------------------------------------- - bool isGlobal = true; - if (auto asAttr = dyn_cast_or_null(srcType.getMemorySpace())) { - auto as = asAttr.getAddressSpace(); - isGlobal = (as == pto::AddressSpace::GM || as == pto::AddressSpace::Zero); - } - if (!isGlobal) { - Type dstTy = getTypeConverter()->convertType(op.getType()); - if (!dstTy) - return failure(); - if (newPtr.getType() != dstTy) - newPtr = rewriter.create(loc, dstTy, newPtr); - rewriter.replaceOp(op, newPtr); - return success(); - } - - // ------------------------------------------------------------------------- - // Part 3: 生成 GlobalTensor 类型 (Shape/Stride Template Generation) - // ------------------------------------------------------------------------- - - // When emitting C++ with `declareVariablesAtTop`, value declarations are - // hoisted before body statements. Avoid introducing local `using` aliases - // for templated types (Shape/Stride/GlobalTensor) because those aliases - // would appear after the hoisted declarations and break compilation - // (`unknown type name`). - // - // Instead, use the fully spelled template types as EmitC opaque types. - - auto resTy = mlir::cast(op.getResult().getType()); - - // 1. 解析具体元素类型 - std::string elemTypeStr = getElemTypeStringForGT(resTy.getElementType()); - - // 2. 生成 Shape 模板参数,之后会右对齐有效维度并补齐到 5 维(高维填 1) - SmallVector shapeParamsVec; - SmallVector sizeValues; // 每个维度对应的运行时 size(统一为 unsigned) - auto resShape = resTy.getShape(); - auto mixedSizes = op.getMixedSizes(); - sizeValues.reserve(rank); - for (int i = 0; i < resTy.getRank(); ++i) { - if (resShape[i] == ShapedType::kDynamic) { - shapeParamsVec.push_back(-1); - } else { - shapeParamsVec.push_back(resShape[i]); - } - // size 值:优先从 op.getMixedSizes() 取(可动态/静态),否则退化为类型里的静态 shape。 - if (i < (int)mixedSizes.size()) - sizeValues.push_back(ofrToEmitCValue(mixedSizes[i])); - else - sizeValues.push_back( - mkU32(resShape[i] == ShapedType::kDynamic ? 1 : resShape[i])); - } - - // 3. 生成 Stride 模板参数 + 运行时 stride 值(考虑 subview step) - SmallVector strideTemplateVec; - SmallVector strideValues; // 每个维度对应的运行时 stride(统一为 unsigned) - strideTemplateVec.reserve(rank); - strideValues.reserve(rank); - auto subViewSteps = op.getMixedStrides(); - for (int i = 0; i < rank; ++i) { - OpFoldResult srcStrideOfr = - (i < (int)sourceStrides.size()) ? sourceStrides[i] - : rewriter.getIndexAttr(1); - OpFoldResult stepOfr = (i < (int)subViewSteps.size()) - ? subViewSteps[i] - : rewriter.getIndexAttr(1); - - auto srcStatic = extractStaticInt(srcStrideOfr); - auto stepStatic = extractStaticInt(stepOfr); - if (srcStatic && stepStatic) { - int64_t finalStride = (*srcStatic) * (*stepStatic); - strideTemplateVec.push_back(finalStride); - strideValues.push_back(mkU32(finalStride)); - continue; - } - - strideTemplateVec.push_back(-1); - Value srcV = ofrToEmitCValue(srcStrideOfr); - Value stepV = ofrToEmitCValue(stepOfr); - // 尽量避免乘以 1 生成冗余指令 - if (stepStatic && *stepStatic == 1) - strideValues.push_back(srcV); - else if (srcStatic && *srcStatic == 1) - strideValues.push_back(stepV); - else - strideValues.push_back( - rewriter.create(loc, u32Ty, srcV, stepV)); - } - - // 3.1 右对齐到 5 维:shape 补 1;已有维度继承原 stride; - // 被补出来的高维按“紧密升维”规则连续推导:stride[i] = shape[i+1] * stride[i+1] - SmallVector finalShape; - SmallVector finalStride; - buildGlobalTensorShapeAndStride(shapeParamsVec, strideTemplateVec, - finalShape, finalStride); - Value oneU32 = mkU32(1); - SmallVector finalShapeValues(5, oneU32); - SmallVector finalStrideValues(5, oneU32); - int shift = 5 - rank; - - // 先放入原始 shape/stride(保持用户提供的值) - for (int i = 0; i < rank && i < 5; ++i) { - finalShapeValues[shift + i] = sizeValues[i]; - finalStrideValues[shift + i] = strideValues[i]; - } - - // 从低维到高维倒推补齐 stride(仅对补出来的前置维度生效) - for (int i = 3; i >= 0; --i) { - // 如果该维已由原始 rank 覆盖,则保持原值 - if (i >= shift) - continue; - if (finalStride[i] != -1) { - finalStrideValues[i] = mkU32(finalStride[i]); - continue; - } - // 动态推导:stride[i] = shape[i+1] * stride[i+1] - if (finalShape[i + 1] == 1) { - finalStrideValues[i] = finalStrideValues[i + 1]; - } else { - finalStrideValues[i] = rewriter.create( - loc, u32Ty, finalShapeValues[i + 1], finalStrideValues[i + 1]); - } - } - - std::string shapeParams = joinIntTemplateParams(finalShape); - std::string strideParams = joinIntTemplateParams(finalStride); - - // Spelled-out C++ types. - std::string shapeCppType = "pto::Shape<" + shapeParams + ">"; - std::string strideCppType = "pto::Stride<" + strideParams + ">"; - - // 3.0 Layout: prefer the attribute from InferPTOLayout; only fall back to - // local inference when the pass is disabled. - std::string layoutEnum = "pto::Layout::ND"; - if (auto layout = resolveLayoutForGlobalTensor(op, op.getSource())) { - layoutEnum = layoutToEmitCString(*layout); - } else { - bool allStatic = - llvm::all_of(finalShape, [](int64_t value) { return value != -1; }) && - llvm::all_of(finalStride, [](int64_t value) { return value != -1; }); - - int layoutTag = 0; // ND - auto elemBytes = 4; // default float - if (elemTypeStr.find("half") != std::string::npos || - elemTypeStr.find("f16") != std::string::npos || - elemTypeStr.find("bf16") != std::string::npos) - elemBytes = 2; - else if (elemTypeStr.find("double") != std::string::npos || - elemTypeStr.find("f64") != std::string::npos) - elemBytes = 8; - - if (allStatic) { - if (finalShape[2] == 16 && - finalShape[2] * finalShape[3] * elemBytes == 512 && - finalStride[4] == 1 && finalStride[3] == finalShape[4]) { - layoutTag = 2; // NZ - } else { - bool isRow = finalStride[4] == 1; - for (int i = 3; i >= 0; --i) - isRow &= (finalStride[i] == - multiplyOrDynamic(finalStride[i + 1], finalShape[i + 1])); - bool isCol = finalStride[0] == 1; - for (int i = 0; i < 4; ++i) - isCol &= (finalStride[i + 1] == - multiplyOrDynamic(finalStride[i], finalShape[i])); - if (isCol) - layoutTag = 1; // DN - else - layoutTag = isRow ? 0 : 0; // fallback ND - } - } - - if (layoutTag == 1) - layoutEnum = "pto::Layout::DN"; - else if (layoutTag == 2) - layoutEnum = "pto::Layout::NZ"; - } - // GlobalTensor takes a Layout non-type template parameter; directly use the - // enum constant. - - - // ------------------------------------------------------------------------- - // Part 3: 显式对象实例化 (Explicit Object Instantiation) - // ------------------------------------------------------------------------- - - // A. Instantiate Shape object. - auto shapeTypeOpaque = emitc::OpaqueType::get(ctx, shapeCppType); - SmallVector shapeArgs; - // 从 adaptor.getSizes() 获取 subview 的所有 dynamic sizes - for (Value dynSize : adaptor.getSizes()) { - shapeArgs.push_back(dynSize); - } - - auto shapeInstOp = rewriter.create( - loc, - shapeTypeOpaque, // 返回类型 - shapeCppType, // 调用的“函数名”即类名构造函数 - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange(shapeArgs) - ); - - // B. Instantiate Stride object. - auto strideTypeOpaque = emitc::OpaqueType::get(ctx, strideCppType); - // 仅传入动态 stride 维度对应的值,匹配 pto::Stride 的 N-parameter ctor(并满足其 static_assert)。 - SmallVector strideCtorArgs; - strideCtorArgs.reserve(5); - for (int i = 0; i < 5; ++i) { - if (finalStride[i] == -1) - strideCtorArgs.push_back(finalStrideValues[i]); - } - auto strideInstOp = rewriter.create( - loc, strideTypeOpaque, strideCppType, - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange(strideCtorArgs)); - - // C. Instantiate GlobalTensor object (ptr + shape + stride). - std::string gtCppType = "GlobalTensor<" + elemTypeStr + ", " + shapeCppType + - ", " + strideCppType + ", " + layoutEnum + ">"; - auto gtType = emitc::OpaqueType::get(ctx, gtCppType); - - // 准备构造参数: [ptr, shape_instance, stride_instance] - SmallVector gtConstructorArgs; - gtConstructorArgs.push_back(newPtr); - gtConstructorArgs.push_back(shapeInstOp.getResult(0)); // 拿到 shape_inst 的 SSA Value - gtConstructorArgs.push_back(strideInstOp.getResult(0)); // 拿到 stride_inst 的 SSA Value - - rewriter.replaceOpWithNewOp( - op, - gtType, - gtCppType, - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange(gtConstructorArgs) - ); - - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// Helper: build GlobalTensor from a static MemRef (for TLOAD/TSTORE) -//===----------------------------------------------------------------------===// - -static std::string getElemTypeStringForGT(Type elemTy) { - return getEmitCScalarTypeToken(elemTy); -} - -static bool hasStaticShape(MemRefType mrTy) { - return llvm::none_of(mrTy.getShape(), [](int64_t dim) { - return dim == ShapedType::kDynamic; - }); -} - -static bool getStaticMemrefLayout(MemRefType mrTy, SmallVectorImpl &strides, - int64_t &offset) { - if (failed(getStridesAndOffset(mrTy, strides, offset))) { - strides.clear(); - int64_t stride = 1; - ArrayRef shape = mrTy.getShape(); - for (int i = static_cast(shape.size()) - 1; i >= 0; --i) { - strides.push_back(stride); - stride *= shape[i]; - } - std::reverse(strides.begin(), strides.end()); - offset = 0; - } - return offset != ShapedType::kDynamic && - llvm::none_of(strides, [](int64_t strideValue) { - return strideValue == ShapedType::kDynamic; - }); -} - -static Value applyStaticMemrefOffset(ConversionPatternRewriter &rewriter, - Location loc, Value basePtr, - int64_t offset) { - if (offset == 0) - return basePtr; - auto *ctx = rewriter.getContext(); - Type u32Ty = emitc::OpaqueType::get(ctx, "unsigned"); - auto offVal = rewriter.create( - loc, u32Ty, emitc::OpaqueAttr::get(ctx, std::to_string(offset))); - return rewriter.create(loc, basePtr.getType(), basePtr, offVal); -} - -static int getGlobalTensorElementBytes(Type elemTy) { - return static_cast(getPTOStorageElemByteSize(elemTy)); -} - -static int64_t multiplyOrDynamic(int64_t lhs, int64_t rhs) { - if (lhs < 0 || rhs < 0) - return -1; - return lhs * rhs; -} - -static void buildGlobalTensorShapeAndStride(ArrayRef shape, - ArrayRef strides, - SmallVectorImpl &shape5D, - SmallVectorImpl &stride5D) { - shape5D.assign(5, 1); - stride5D.assign(5, 1); - int rank = static_cast(shape.size()); - int shift = 5 - rank; - for (int i = 0; i < rank && i < 5; ++i) { - shape5D[shift + i] = shape[i]; - stride5D[shift + i] = strides[i]; - } - for (int i = 3; i >= 0; --i) { - if (i >= shift) - continue; - stride5D[i] = multiplyOrDynamic(shape5D[i + 1], stride5D[i + 1]); - } -} - -static std::string joinIntTemplateParams(ArrayRef values) { - std::string result; - for (size_t i = 0; i < values.size(); ++i) { - if (i != 0) - result += ", "; - result += std::to_string(values[i]); - } - return result; -} - -static SmallVector buildRowMajorStrides(ArrayRef shape) { - SmallVector strides(shape.size(), 1); - int64_t running = 1; - for (int i = static_cast(shape.size()) - 1; i >= 0; --i) { - strides[i] = running; - running = multiplyOrDynamic(running, shape[i]); - } - return strides; -} - -static std::string getGlobalTensorTypeStringFromShape(Type elemTy, - ArrayRef shape, - StringRef layoutEnum) { - SmallVector strides = buildRowMajorStrides(shape); - return getGlobalTensorTypeStringFromShapeAndStrides(elemTy, shape, strides, - layoutEnum); -} - -static std::string getGlobalTensorTypeStringFromShapeAndStrides( - Type elemTy, ArrayRef shape, ArrayRef strides, - StringRef layoutEnum) { - SmallVector shape5D; - SmallVector stride5D; - buildGlobalTensorShapeAndStride(shape, strides, shape5D, stride5D); - - std::string elemTypeStr = getElemTypeStringForGT(elemTy); - std::string shapeType = "pto::Shape<" + joinIntTemplateParams(shape5D) + ">"; - std::string strideType = - "pto::Stride<" + joinIntTemplateParams(stride5D) + ">"; - return "GlobalTensor<" + elemTypeStr + ", " + shapeType + ", " + - strideType + ", " + layoutEnum.str() + ">"; -} - -static emitc::OpaqueType getGlobalTensorOpaqueTypeFromShape( - MLIRContext *ctx, Type elemTy, ArrayRef shape, - StringRef layoutEnum) { - return emitc::OpaqueType::get( - ctx, getGlobalTensorTypeStringFromShape(elemTy, shape, layoutEnum)); -} - -static std::string inferFallbackGlobalTensorLayout(ArrayRef shape5D, - ArrayRef stride5D, - Type elemTy) { - int elemBytes = getGlobalTensorElementBytes(elemTy); - if (elemBytes == 0) - return "pto::Layout::ND"; - if (shape5D[2] == 16 && multiplyOrDynamic(shape5D[2], shape5D[3]) * elemBytes == 512 && - stride5D[4] == 1 && stride5D[3] == shape5D[4]) { - return "pto::Layout::NZ"; - } - - bool isRowMajor = stride5D[4] == 1; - for (int i = 3; i >= 0 && isRowMajor; --i) - isRowMajor = stride5D[i] == multiplyOrDynamic(stride5D[i + 1], shape5D[i + 1]); - - bool isColMajor = stride5D[0] == 1; - for (int i = 0; i < 4 && isColMajor; ++i) - isColMajor = stride5D[i + 1] == multiplyOrDynamic(stride5D[i], shape5D[i]); - - if (isColMajor) - return "pto::Layout::DN"; - return isRowMajor ? "pto::Layout::ND" : "pto::Layout::ND"; -} - -static std::string resolveGlobalTensorLayout(Operation *anchor, Value basePtr, - ArrayRef shape5D, - ArrayRef stride5D, - Type elemTy) { - if (auto layout = resolveLayoutForGlobalTensor(anchor, basePtr)) - return layoutToEmitCString(*layout); - return inferFallbackGlobalTensorLayout(shape5D, stride5D, elemTy); -} - -struct GlobalTensorTypeNames { - std::string shapeTypeName; - std::string strideTypeName; - std::string tensorTypeName; - std::string layoutConstName; -}; - -static GlobalTensorTypeNames getGlobalTensorTypeNames(Operation *anchor) { - std::string suffix = "_" + std::to_string(reinterpret_cast(anchor)); - return { - "GTShape" + suffix, - "GTStride" + suffix, - "GT" + suffix, - "GT" + suffix + "_layout", - }; -} -static Value buildGlobalTensorFromMemref(ConversionPatternRewriter &rewriter, - Location loc, Value basePtr, - MemRefType mrTy, - Operation *anchor) { - auto *ctx = rewriter.getContext(); - - ArrayRef shape = mrTy.getShape(); - if (!hasStaticShape(mrTy)) - return Value(); - - SmallVector strides; - int64_t offset = 0; - if (!getStaticMemrefLayout(mrTy, strides, offset)) - return Value(); - - Value ptr = applyStaticMemrefOffset(rewriter, loc, basePtr, offset); - GlobalTensorTypeNames names = getGlobalTensorTypeNames(anchor); - std::string elemTypeStr = getElemTypeStringForGT(mrTy.getElementType()); - SmallVector shape5D; - SmallVector stride5D; - buildGlobalTensorShapeAndStride(shape, strides, shape5D, stride5D); - - rewriter.create( - loc, "using " + names.shapeTypeName + " = pto::Shape<" + - joinIntTemplateParams(shape5D) + ">;"); - rewriter.create( - loc, "using " + names.strideTypeName + " = pto::Stride<" + - joinIntTemplateParams(stride5D) + ">;"); - - std::string layoutEnum = resolveGlobalTensorLayout( - anchor, basePtr, shape5D, stride5D, mrTy.getElementType()); - rewriter.create(loc, "constexpr pto::Layout " + - names.layoutConstName + " = " + - layoutEnum + ";"); - - auto shapeTypeOpaque = emitc::OpaqueType::get(ctx, names.shapeTypeName); - auto strideTypeOpaque = emitc::OpaqueType::get(ctx, names.strideTypeName); - auto shapeInstOp = rewriter.create( - loc, shapeTypeOpaque, names.shapeTypeName, ArrayAttr{}, ArrayAttr{}, - ValueRange{}); - auto strideInstOp = rewriter.create( - loc, strideTypeOpaque, names.strideTypeName, ArrayAttr{}, ArrayAttr{}, - ValueRange{}); - - rewriter.create( - loc, "using " + names.tensorTypeName + " = GlobalTensor<" + elemTypeStr + - ", " + names.shapeTypeName + ", " + names.strideTypeName + - ", " + names.layoutConstName + ">;"); - auto gtType = emitc::OpaqueType::get(ctx, names.tensorTypeName); - - SmallVector gtArgs; - gtArgs.push_back(ptr); - gtArgs.push_back(shapeInstOp.getResult(0)); - gtArgs.push_back(strideInstOp.getResult(0)); - - auto gtInst = rewriter.create( - loc, gtType, names.tensorTypeName, ArrayAttr{}, ArrayAttr{}, - ValueRange(gtArgs)); - - return gtInst.getResult(0); -} - -static Value maybeWrapGlobalMemrefAsGlobalTensor( - ConversionPatternRewriter &rewriter, Location loc, Value loweredValue, - Type originalType, Operation *anchor) { - auto mrTy = dyn_cast(originalType); - if (!mrTy) - return loweredValue; - - bool isGlobal = true; - if (auto asAttr = - dyn_cast_or_null(mrTy.getMemorySpace())) { - auto as = asAttr.getAddressSpace(); - isGlobal = (as == pto::AddressSpace::GM || as == pto::AddressSpace::Zero); - } - if (!isGlobal) - return loweredValue; - - if (Value gt = - buildGlobalTensorFromMemref(rewriter, loc, loweredValue, mrTy, anchor)) - return gt; - return loweredValue; -} - -static Value castToGMBytePointer(ConversionPatternRewriter &rewriter, - Location loc, Value value) { - auto *ctx = rewriter.getContext(); - auto targetTy = - emitc::PointerType::get(emitc::OpaqueType::get(ctx, "__gm__ uint8_t")); - if (value.getType() == targetTy) - return value; - - auto castTyAttr = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "__gm__ uint8_t*")}); - if (isSetFFTsPointerLikeType(value.getType())) { - return rewriter - .create(loc, targetTy, "reinterpret_cast", - ArrayAttr{}, castTyAttr, - ValueRange{value}) - .getResult(0); - } - return rewriter.create(loc, targetTy, value).getResult(); -} - -static Value materializeTensorViewDataPointer( - ConversionPatternRewriter &rewriter, Location loc, Value value, - Type sourceType) { - auto tvTy = dyn_cast(sourceType); - if (!tvTy) - return value; - - auto *ctx = rewriter.getContext(); - std::string elemTypeStr = getElemTypeStringForGT(tvTy.getElementType()); - auto ptrTy = emitc::PointerType::get( - emitc::OpaqueType::get(ctx, "__gm__ " + elemTypeStr)); - return rewriter - .create(loc, ptrTy, "PTOAS__GLOBAL_TENSOR_DATA", - ArrayAttr{}, ArrayAttr{}, ValueRange{value}) - .getResult(0); -} - -static std::string tileBufBLayoutToken(pto::TileBufConfigAttr configAttr) { - std::string blTok = "BLayout::RowMajor"; - if (auto blAttr = dyn_cast(configAttr.getBLayout())) { - if (static_cast(blAttr.getValue()) == 1) - blTok = "BLayout::ColMajor"; - } - return blTok; -} - -static std::string tileBufSLayoutToken(pto::TileBufConfigAttr configAttr) { - std::string slTok = "SLayout::NoneBox"; - if (auto slAttr = dyn_cast(configAttr.getSLayout())) { - int32_t slVal = static_cast(slAttr.getValue()); - slTok = (slVal == 1) ? "SLayout::RowMajor" - : (slVal == 2) ? "SLayout::ColMajor" - : "SLayout::NoneBox"; - } - return slTok; -} - -static std::string tileBufPadToken(pto::TileBufConfigAttr configAttr) { - std::string padTok = "PadValue::Null"; - if (auto padAttr = dyn_cast(configAttr.getPad())) { - switch (static_cast(padAttr.getValue())) { - case 1: - padTok = "PadValue::Zero"; - break; - case 2: - padTok = "PadValue::Max"; - break; - case 3: - padTok = "PadValue::Min"; - break; - default: - padTok = "PadValue::Null"; - break; - } - } - return padTok; -} - -static pto::BLayout getTileBufBLayoutValue(pto::TileBufConfigAttr configAttr) { - if (auto blAttr = dyn_cast(configAttr.getBLayout())) - return blAttr.getValue(); - return pto::BLayout::RowMajor; -} - -static int64_t renderTileTemplateDim(int64_t rawDim, Type elemTy, - pto::BLayout blayout, int dimIdx) { - assert(dimIdx >= 0 && dimIdx < 2 && - "renderTileTemplateDim expects a rank-2 rows/cols dimension index"); - if (rawDim == ShapedType::kDynamic) - return rawDim; - if (!pto::isPTOFloat4PackedType(elemTy)) - return rawDim; - int packedDim = blayout == pto::BLayout::ColMajor ? 0 : 1; - return dimIdx == packedDim ? rawDim * 2 : rawDim; -} - -static FailureOr buildAsyncScratchTileValue( - ConversionPatternRewriter &rewriter, Location loc, Value originalScratch, - Value emittedScratch) { - Value scratch = peelUnrealized(emittedScratch); - if (auto opaqueTy = dyn_cast(scratch.getType())) { - StringRef typeStr = opaqueTy.getValue(); - if (typeStr.contains("Tile<") || typeStr.contains("ConvTile<")) - return scratch; - } - - auto memTy = dyn_cast(originalScratch.getType()); - if (!memTy) - return failure(); - - ArrayRef shape = memTy.getShape(); - if (!memTy.hasStaticShape() || shape.empty() || shape.size() > 2) - return failure(); - - int64_t rows = shape.size() == 1 ? 1 : shape[0]; - int64_t cols = shape.size() == 1 ? shape[0] : shape[1]; - - auto *ctx = rewriter.getContext(); - pto::TileBufConfigAttr configAttr = pto::TileBufConfigAttr::getDefault(ctx); - if (auto bind = originalScratch.getDefiningOp()) { - configAttr = bind.getConfig(); - } else if (auto cast = originalScratch.getDefiningOp()) { - if (auto config = cast.getConfig()) - configAttr = *config; - } - - int32_t fractal = 512; - if (auto frAttr = dyn_cast(configAttr.getSFractalSize())) - fractal = frAttr.getInt(); - - Type elemTy = memTy.getElementType(); - pto::BLayout blayout = getTileBufBLayoutValue(configAttr); - int64_t templateRows = renderTileTemplateDim(rows, elemTy, blayout, 0); - int64_t templateCols = renderTileTemplateDim(cols, elemTy, blayout, 1); - std::string elemTypeStr = getEmitCScalarTypeToken(elemTy); - std::string tileTypeStr = - "Tile"; - - Value tile = rewriter - .create( - loc, emitc::OpaqueType::get(ctx, tileTypeStr), - emitc::OpaqueAttr::get(ctx, "")) - .getResult(); - auto addr = rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); - Value scratchAddr = - rewriter - .create(loc, emitc::OpaqueType::get(ctx, "uint64_t"), - "reinterpret_cast", ArrayAttr{}, addr, - ValueRange{scratch}) - .getResult(0); - rewriter.create(loc, TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{tile, scratchAddr}); - return tile; -} - -static FailureOr buildSyncAllWorkspaceTileValue( - ConversionPatternRewriter &rewriter, Location loc, Value originalWorkspace, - Value emittedWorkspace) { - Value workspace = peelUnrealized(emittedWorkspace); - if (auto opaqueTy = dyn_cast(workspace.getType())) { - StringRef typeStr = opaqueTy.getValue(); - if (typeStr.contains("Tile<") || typeStr.contains("ConvTile<")) - return workspace; - } - - auto memTy = dyn_cast(originalWorkspace.getType()); - if (!memTy) - return failure(); - if (!memTy.hasStaticShape()) - return failure(); - - ArrayRef rawShape = memTy.getShape(); - if (rawShape.empty() || rawShape.size() > 2) - return failure(); - - int64_t rows = rawShape.size() == 1 ? 1 : rawShape[0]; - int64_t cols = rawShape.size() == 1 ? rawShape[0] : rawShape[1]; - SmallVector shape{rows, cols}; - SmallVector validShape{rows, cols}; - - auto *ctx = rewriter.getContext(); - pto::TileBufConfigAttr configAttr = pto::TileBufConfigAttr::getDefault(ctx); - if (auto bind = originalWorkspace.getDefiningOp()) { - configAttr = bind.getConfig(); - } else if (auto cast = originalWorkspace.getDefiningOp()) { - if (auto config = cast.getConfig()) - configAttr = *config; - } - - Attribute memorySpace = memTy.getMemorySpace(); - if (!memorySpace) - return failure(); - - auto tileTy = pto::TileBufType::get(ctx, shape, memTy.getElementType(), - memorySpace, validShape, configAttr); - auto tileTypeString = getEmitCTileTypeString(tileTy); - if (!tileTypeString) - return failure(); - - auto tileEmitTy = emitc::OpaqueType::get(ctx, *tileTypeString); - Value tile = rewriter - .create(loc, tileEmitTy, - emitc::OpaqueAttr::get(ctx, "")) - .getResult(); - - Value rawPtr = workspace; - auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); - if (isSetFFTsPointerLikeType(rawPtr.getType())) { - auto rcU64 = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); - rawPtr = rewriter - .create(loc, u64Ty, "reinterpret_cast", - ArrayAttr{}, rcU64, - ValueRange{rawPtr}) - .getResult(0); - } else if (rawPtr.getType() != u64Ty) { - rawPtr = rewriter.create(loc, u64Ty, rawPtr).getResult(); - } - - rewriter.create(loc, TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{tile, rawPtr}); - return tile; -} - -//===----------------------------------------------------------------------===// -// pto.pointer_cast lowering -//===----------------------------------------------------------------------=== -struct PointerCastConversion : public OpConversionPattern { - static bool getIndexConst(Value v, int64_t &out) { - if (auto cst = v.getDefiningOp()) { - if (auto ia = dyn_cast(cst.getValue())) { - out = ia.getValue().getSExtValue(); - return true; - } - } - return false; - } - - using OpConversionPattern::OpConversionPattern; - - enum class TileRole { Vec, Mat, Left, Right, Acc, Bias, Scaling }; - - static void collectUserOpsThroughCasts(Value v, SmallVectorImpl &out) { - for (Operation *u : v.getUsers()) { - if (auto castOp = dyn_cast(u)) { - for (Value r : castOp.getResults()) - collectUserOpsThroughCasts(r, out); - continue; - } - out.push_back(u); - } - } - - static Value peelUnrealized(Value v) { - while (auto castOp = v.getDefiningOp()) { - v = castOp.getOperand(0); - } - return v; - } - - static TileRole inferRole(pto::PointerCastOp op) { - // 1. 优先检查 AddressSpace - if (auto memRefTy = dyn_cast(op.getType())) { - Attribute memorySpace = memRefTy.getMemorySpace(); - if (auto ptoAttr = dyn_cast_or_null(memorySpace)) { - switch (ptoAttr.getAddressSpace()) { - case pto::AddressSpace::LEFT: return TileRole::Left; - case pto::AddressSpace::RIGHT: return TileRole::Right; - case pto::AddressSpace::ACC: return TileRole::Acc; - case pto::AddressSpace::BIAS: return TileRole::Bias; - case pto::AddressSpace::MAT: return TileRole::Mat; - case pto::AddressSpace::SCALING: return TileRole::Scaling; - default: break; - } - } - } - - // 2. 通过 Usage 推导 (Fallback) - SmallVector users; - collectUserOpsThroughCasts(op.getResult(), users); - - for (Operation *user : users) { - if (auto mm = dyn_cast(user)) { - if (mm.getDst() && peelUnrealized(mm.getDst()) == op.getResult()) return TileRole::Acc; - if (peelUnrealized(mm.getLhs()) == op.getResult()) return TileRole::Left; - if (peelUnrealized(mm.getRhs()) == op.getResult()) return TileRole::Right; - } - if (auto mmacc = dyn_cast(user)) { - if (mmacc.getDst() && peelUnrealized(mmacc.getDst()) == op.getResult()) return TileRole::Acc; - if (peelUnrealized(mmacc.getAccIn()) == op.getResult()) return TileRole::Acc; - if (peelUnrealized(mmacc.getLhs()) == op.getResult()) return TileRole::Left; - if (peelUnrealized(mmacc.getRhs()) == op.getResult()) return TileRole::Right; - } - } - - return TileRole::Vec; - } - - // [新增] 辅助函数:判断 Value 是否源自 arith.constant - static bool isConstant(Value v, int64_t &outVal) { - if (!v) return false; - if (auto cst = v.getDefiningOp()) { - if (auto attr = dyn_cast(cst.getValue())) { - outVal = attr.getInt(); - return true; - } - } - return false; - } - - LogicalResult matchAndRewrite(pto::PointerCastOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - auto selfType = mlir::cast(op.getType()); - ArrayRef shape = selfType.getShape(); - Type elemType = selfType.getElementType(); - - // 1. 推导 Tile Role - TileRole role = inferRole(op); - - // 2. 类型字符串生成 (elemTypeStr, dimStr) - std::string elemTypeStr = getEmitCScalarTypeToken(elemType); - - std::string dimStr; - pto::BLayout blayout = pto::BLayout::RowMajor; - auto dimToString = [&](int64_t dim, const char *symbol, - int dimIdx) -> std::string { - if (dim == ShapedType::kDynamic) - return std::string(symbol); - return std::to_string(renderTileTemplateDim(dim, elemType, blayout, - dimIdx)); - }; - - // 3. Role Token - const char *roleTok = "TileType::Vec"; - switch (role) { - case TileRole::Left: roleTok = "TileType::Left"; break; - case TileRole::Right: roleTok = "TileType::Right"; break; - case TileRole::Acc: roleTok = "TileType::Acc"; break; - case TileRole::Bias: roleTok = "TileType::Bias"; break; - case TileRole::Mat: roleTok = "TileType::Mat"; break; - case TileRole::Vec: roleTok = "TileType::Vec"; break; - case TileRole::Scaling: roleTok = "TileType::Scaling"; break; - } - - // 4. Config & Layout (support BLayoutAttr/SLayoutAttr/PadValueAttr after namespace change) - std::string layoutParams = "BLayout::RowMajor"; - std::string extraParams = ""; - if (auto configOpt = op.getConfig()) { - auto config = *configOpt; - int32_t blVal = 0; - if (auto attr = dyn_cast(config.getBLayout())) - blVal = static_cast(attr.getValue()); - - if (blVal == 1) layoutParams = "BLayout::ColMajor"; - blayout = blVal == 1 ? pto::BLayout::ColMajor : pto::BLayout::RowMajor; - - int32_t slVal = 0; - if (auto attr = dyn_cast(config.getSLayout())) - slVal = static_cast(attr.getValue()); - - std::string slStr = (slVal == 1) ? "SLayout::RowMajor" : (slVal == 2) ? "SLayout::ColMajor" : "SLayout::NoneBox"; - - int32_t frVal = 0; - if (auto attr = dyn_cast(config.getSFractalSize())) frVal = attr.getInt(); - - int32_t padVal = 0; - if (auto attr = dyn_cast(config.getPad())) - padVal = static_cast(attr.getValue()); - - std::string padStr = "PadValue::Null"; - switch (padVal) { - case 1: padStr = "PadValue::Zero"; break; - case 2: padStr = "PadValue::Max"; break; - case 3: padStr = "PadValue::Min"; break; - } - - int32_t compactVal = 0; - if (auto attr = dyn_cast(config.getCompactMode())) - compactVal = static_cast(attr.getValue()); - - std::string compactStr = "CompactMode::Null"; - switch (compactVal) { - case 1: compactStr = "CompactMode::Normal"; break; - case 2: compactStr = "CompactMode::RowPlusOne"; break; - } - - if (!slStr.empty()) { - extraParams += ", " + slStr + ", " + std::to_string(frVal) + ", " + - padStr + ", " + compactStr; - } - } else { - extraParams = ", SLayout::NoneBox, 512, PadValue::Null, CompactMode::Null"; - } - - if (role == TileRole::Left) - dimStr = dimToString(shape[0], "M", 0) + ", " + - dimToString(shape[1], "K", 1); - else if (role == TileRole::Right) - dimStr = dimToString(shape[0], "K", 0) + ", " + - dimToString(shape[1], "N", 1); - else if (role == TileRole::Bias) - dimStr = "1, " + dimToString(shape[1], "N", 1); - else - dimStr = dimToString(shape[0], "M", 0) + ", " + - dimToString(shape[1], "N", 1); - - // [核心修改] Valid Dims 处理逻辑 (支持混合静态/动态) - std::string vrowTok, vcolTok; - bool useConstructor = false; - - bool rowIsDynamic = false; - bool colIsDynamic = false; - - SmallVector constructorArgs; - - Value vRow = op.getValidRow(); - Value vCol = op.getValidCol(); - Value vRowEmitC = adaptor.getValidRow(); - Value vColEmitC = adaptor.getValidCol(); - bool forceDynamicValid = op->hasAttr(kForceDynamicValidShapeAttrName); - - int64_t cRow = 0, cCol = 0; - bool rowIsConst = vRow && isConstant(vRow, cRow); - bool colIsConst = vCol && isConstant(vCol, cCol); - - auto makeCtorDimValue = [&](Value emitted, int64_t fallback) -> Value { - if (emitted) - return emitted; - return makeEmitCIntConstant( - rewriter, loc, emitc::OpaqueType::get(ctx, "int32_t"), fallback); - }; - auto maybeScaleDynamicValid = [&](Value emitted, int dimIdx) -> Value { - if (!emitted || !pto::isPTOFloat4PackedType(elemType)) - return emitted; - int packedDim = blayout == pto::BLayout::ColMajor ? 0 : 1; - if (dimIdx != packedDim) - return emitted; - auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); - Value two = makeEmitCIntConstant(rewriter, loc, i32Ty, 2); - return rewriter.create(loc, i32Ty, emitted, two).getResult(); - }; - - if (forceDynamicValid) { - vrowTok = "-1"; - vcolTok = "-1"; - useConstructor = true; - constructorArgs.push_back( - makeCtorDimValue(maybeScaleDynamicValid(vRowEmitC, 0), - renderTileTemplateDim(rowIsConst ? cRow : shape[0], - elemType, blayout, 0))); - constructorArgs.push_back( - makeCtorDimValue(maybeScaleDynamicValid(vColEmitC, 1), - renderTileTemplateDim(colIsConst ? cCol : shape[1], - elemType, blayout, 1))); - } else { - if (rowIsConst) { - vrowTok = std::to_string( - renderTileTemplateDim(cRow, elemType, blayout, 0)); - } else if (vRow) { - vrowTok = "-1"; - rowIsDynamic = true; - useConstructor = true; - } else { - vrowTok = std::to_string( - renderTileTemplateDim(shape[0], elemType, blayout, 0)); - } - - if (colIsConst) { - vcolTok = std::to_string( - renderTileTemplateDim(cCol, elemType, blayout, 1)); - } else if (vCol) { - vcolTok = "-1"; - colIsDynamic = true; - useConstructor = true; - } else { - vcolTok = std::to_string( - renderTileTemplateDim(shape[1], elemType, blayout, 1)); - } - - if (useConstructor) { - if (rowIsDynamic && vRowEmitC) - constructorArgs.push_back(maybeScaleDynamicValid(vRowEmitC, 0)); - if (colIsDynamic && vColEmitC) - constructorArgs.push_back(maybeScaleDynamicValid(vColEmitC, 1)); - } - } - - // 5. 生成 Tile 类型字符串 - std::string tileTypeStr = - std::string("Tile<") + roleTok + ", " + elemTypeStr + ", " + dimStr + ", " + - layoutParams + ", " + vrowTok + ", " + vcolTok + extraParams + ">"; - - auto tileType = emitc::OpaqueType::get(ctx, tileTypeStr); - Value resultValue; - - if (useConstructor) { - // 使用 CallOpaqueOp 生成构造函数调用 (Tile v = Tile(...)) - auto ctorOp = rewriter.create( - loc, - tileType, // Result Type - tileTypeStr, // Callee Name (类名) - ArrayAttr{}, // args - ArrayAttr{}, // template_args - ValueRange(constructorArgs) // operands - ); - resultValue = ctorOp.getResult(0); - } else { - // 静态情况 (Tile v;) - auto varOp = rewriter.create( - loc, - tileType, - emitc::OpaqueAttr::get(ctx, "") - ); - resultValue = varOp.getResult(); - } - - // TASSIGN: pto-isa expects an integral address. - Value addr = adaptor.getAddrs()[0]; - if (isa(addr.getType()) || - (isa(addr.getType()) && - cast(addr.getType()).getValue().ends_with("*"))) { - auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); - auto rcU64 = rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); - addr = rewriter.create( - loc, u64Ty, "reinterpret_cast", - /*args=*/ArrayAttr{}, /*templateArgs=*/rcU64, - /*operands=*/ValueRange{addr}) - .getResult(0); - } - - rewriter.create( - loc, TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{resultValue, addr}); - - rewriter.replaceOp(op, resultValue); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// pto.load_dps / pto.store_dps lowering (FIX: keep optional result) -//===----------------------------------------------------------------------=== - -struct PTOTLoadToTLOAD : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TLoadOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!op.getDst()) - return rewriter.notifyMatchFailure(op, "expected outs(dst) on pto.tload"); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value srcArg = src; - if (auto srcMrTy = dyn_cast(op.getSrc().getType())) { - bool isGlobal = true; - if (auto asAttr = dyn_cast_or_null(srcMrTy.getMemorySpace())) { - auto as = asAttr.getAddressSpace(); - isGlobal = (as == pto::AddressSpace::GM || as == pto::AddressSpace::Zero); - } - if (isGlobal) { - if (Value gt = buildGlobalTensorFromMemref(rewriter, op.getLoc(), src, srcMrTy, - op.getOperation())) - srcArg = gt; - } - } - - rewriter.create( - op.getLoc(), TypeRange{}, "TLOAD", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, srcArg}); - - if (op->getNumResults() == 1) { - rewriter.replaceOp(op, dst); - } else { - rewriter.eraseOp(op); - } - return success(); - } -}; - -struct PTOTPrefetchToTPREFETCH : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPrefetchOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!op.getDst()) - return rewriter.notifyMatchFailure(op, "expected outs(dst) on pto.tprefetch"); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value srcArg = src; - if (auto srcMrTy = dyn_cast(op.getSrc().getType())) { - bool isGlobal = true; - if (auto asAttr = dyn_cast_or_null(srcMrTy.getMemorySpace())) { - auto as = asAttr.getAddressSpace(); - isGlobal = (as == pto::AddressSpace::GM || as == pto::AddressSpace::Zero); - } - if (isGlobal) { - if (Value gt = buildGlobalTensorFromMemref(rewriter, op.getLoc(), src, srcMrTy, - op.getOperation())) - srcArg = gt; - } - } - - rewriter.create( - op.getLoc(), TypeRange{}, "TPREFETCH", - ArrayAttr{}, ArrayAttr{}, ValueRange{dst, srcArg}); - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOTPrefetchAsyncToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPrefetchAsyncOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src = peelUnrealized(adaptor.getSrc()); - Value srcArg = src; - if (!isEmitCGlobalTensorLikeType(srcArg.getType())) { - auto srcMrTy = dyn_cast(op.getSrc().getType()); - if (!srcMrTy) - return rewriter.notifyMatchFailure( - op, "expected src to lower to GlobalTensor or memref"); - srcArg = buildGlobalTensorFromMemref(rewriter, op.getLoc(), src, srcMrTy, - op.getSrc().getDefiningOp() - ? op.getSrc().getDefiningOp() - : op.getOperation()); - } - if (!srcArg) - return rewriter.notifyMatchFailure(op, - "failed to build GlobalTensor src"); - - Value prefetchCtx = peelUnrealized(adaptor.getCtx()); - - Type eventTy = getTypeConverter()->convertType(op.getEvent().getType()); - if (!eventTy) - return rewriter.notifyMatchFailure( - op, "failed to convert tprefetch_async result type"); - - Value event = rewriter - .create( - op.getLoc(), TypeRange{eventTy}, "TPREFETCH_ASYNC", - ArrayAttr{}, ArrayAttr{}, - ValueRange{srcArg, prefetchCtx}) - .getResult(0); - - rewriter.replaceOp(op, ValueRange{event}); - return success(); - } -}; - -struct PTOMakePrefetchAsyncContextToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::MakePrefetchAsyncContextOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type ctxTy = getTypeConverter()->convertType(op.getCtx().getType()); - if (!ctxTy) - return rewriter.notifyMatchFailure( - op, "failed to convert make_prefetch_async_context result type"); - - Value workspace = peelUnrealized(adaptor.getWorkspace()); - workspace = castToGMBytePointer(rewriter, op.getLoc(), workspace); - - Value ctx = rewriter - .create( - op.getLoc(), TypeRange{ctxTy}, "pto::PrefetchAsyncContext", - ArrayAttr{}, ArrayAttr{}, ValueRange{workspace}) - .getResult(0); - - rewriter.replaceOp(op, ValueRange{ctx}); - return success(); - } -}; - -struct PTOGetPrefetchAsyncSessionToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::GetPrefetchAsyncSessionOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type sessionTy = getTypeConverter()->convertType(op.getSession().getType()); - if (!sessionTy) - return rewriter.notifyMatchFailure( - op, "failed to convert get_prefetch_async_session result type"); - - Value ctx = peelUnrealized(adaptor.getCtx()); - Value session = rewriter - .create( - op.getLoc(), TypeRange{sessionTy}, - "PTOAS__PREFETCH_CTX_SESSION", ArrayAttr{}, - ArrayAttr{}, ValueRange{ctx}) - .getResult(0); - - rewriter.replaceOp(op, ValueRange{session}); - return success(); - } -}; - -struct PTOTStoreToTSTORE : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - static std::string stPhaseTok(pto::STPhase phase) { - switch (phase) { - case pto::STPhase::Unspecified: return "STPhase::Unspecified"; - case pto::STPhase::Partial: return "STPhase::Partial"; - case pto::STPhase::Final: return "STPhase::Final"; - } - return "STPhase::Unspecified"; - } - - static std::string atomicTypeTok(pto::AtomicType atomicType) { - switch (atomicType) { - case pto::AtomicType::AtomicNone: return "AtomicType::AtomicNone"; - case pto::AtomicType::AtomicAdd: return "AtomicType::AtomicAdd"; - } - return "AtomicType::AtomicNone"; - } - - static std::string reluPreModeTok(pto::ReluPreMode reluPreMode) { - switch (reluPreMode) { - case pto::ReluPreMode::NoRelu: return "ReluPreMode::NoRelu"; - case pto::ReluPreMode::NormalRelu: return "ReluPreMode::NormalRelu"; - } - return "ReluPreMode::NoRelu"; - } - - LogicalResult matchAndRewrite(pto::TStoreOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!op.getDst()) - return rewriter.notifyMatchFailure(op, "expected outs(dst) on pto.tstore"); - - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value preQuantScalar; - if (op.getPreQuantScalar()) - preQuantScalar = peelUnrealized(adaptor.getPreQuantScalar()); - Value dstArg = dst; - if (auto dstMrTy = dyn_cast(op.getDst().getType())) { - bool isGlobal = true; - if (auto asAttr = dyn_cast_or_null(dstMrTy.getMemorySpace())) { - auto as = asAttr.getAddressSpace(); - isGlobal = (as == pto::AddressSpace::GM || as == pto::AddressSpace::Zero); - } - if (isGlobal) { - if (Value gt = buildGlobalTensorFromMemref(rewriter, op.getLoc(), dst, dstMrTy, - op.getOperation())) - dstArg = gt; - } - } - - const auto phase = op.getStPhase(); - const auto atomicType = op.getAtomicType(); - const auto reluPreMode = op.getReluPreMode(); - const bool hasPreQuantScalar = static_cast(preQuantScalar); - const bool phaseNonDefault = phase != pto::STPhase::Unspecified; - const bool atomicNonDefault = atomicType != pto::AtomicType::AtomicNone; - const bool reluNonDefault = reluPreMode != pto::ReluPreMode::NoRelu; - - auto getOpaqueTok = [&](Value v, StringRef name) -> FailureOr { - if (auto ot = mlir::dyn_cast(v.getType())) - return ot.getValue().str(); - return rewriter.notifyMatchFailure(op, (name + " must be emitc::OpaqueType").str()); - }; - - ArrayAttr targs; - // Map op attributes/operands to the exact TSTORE overload family: - // 1) TSTORE(dst, src) - // 2) TSTORE(dst, src) - // 3) TSTORE(dst, src) - // 4) TSTORE(dst, src) - // 5) TSTORE(dst, src) - // 6) TSTORE(dst, src) - // 7) TSTORE(dst, src, preQuant) - // 8) TSTORE(dst, src, preQuant) - if (!hasPreQuantScalar && !reluNonDefault && !atomicNonDefault) { - if (phaseNonDefault) { - targs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, stPhaseTok(phase)), - }); - } else { - targs = ArrayAttr{}; - } - } else { - auto srcTokOr = getOpaqueTok(src, "src"); - auto dstTokOr = getOpaqueTok(dstArg, "dst"); - if (failed(srcTokOr) || failed(dstTokOr)) - return failure(); - - // If there is no preQuant and relu stays default, emit the atomic-only - // overloads (#3/#4) without ReluPreMode template argument. - if (!hasPreQuantScalar && !reluNonDefault) { - if (phaseNonDefault) { - targs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, stPhaseTok(phase)), - emitc::OpaqueAttr::get(ctx, *srcTokOr), - emitc::OpaqueAttr::get(ctx, *dstTokOr), - emitc::OpaqueAttr::get(ctx, atomicTypeTok(atomicType)), - }); - } else { - targs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, *srcTokOr), - emitc::OpaqueAttr::get(ctx, *dstTokOr), - emitc::OpaqueAttr::get(ctx, atomicTypeTok(atomicType)), - }); - } - } else { - // Relu/preQuant families (#5/#6/#7/#8): keep AtomicType + ReluPreMode. - if (phaseNonDefault) { - targs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, stPhaseTok(phase)), - emitc::OpaqueAttr::get(ctx, *srcTokOr), - emitc::OpaqueAttr::get(ctx, *dstTokOr), - emitc::OpaqueAttr::get(ctx, atomicTypeTok(atomicType)), - emitc::OpaqueAttr::get(ctx, reluPreModeTok(reluPreMode)), - }); - } else { - targs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, *srcTokOr), - emitc::OpaqueAttr::get(ctx, *dstTokOr), - emitc::OpaqueAttr::get(ctx, atomicTypeTok(atomicType)), - emitc::OpaqueAttr::get(ctx, reluPreModeTok(reluPreMode)), - }); - } - } - } - - SmallVector operands{dstArg, src}; - if (hasPreQuantScalar) - operands.push_back(preQuantScalar); - - rewriter.create( - loc, TypeRange{}, "TSTORE", - /*args=*/ArrayAttr{}, /*templateArgs=*/targs, - /*operands=*/operands); - - if (op->getNumResults() == 1) { - rewriter.replaceOp(op, dst); - } else { - rewriter.eraseOp(op); - } - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// pto.matmul_dps lowering (Simplified: No internal copy/sync) -//===----------------------------------------------------------------------===// -// -// Render `pto.tmatmul` as one of three forms depending on the optional -// `acc_phase` attribute: -// * absent / Unspecified -> `TMATMUL(dst, lhs, rhs)` -// * Partial -> `TMATMUL(dst, lhs, rhs)` -// * Final -> `TMATMUL(dst, lhs, rhs)` -// The Unspecified default keeps backward compatibility with all upstream IR -// that does not yet emit an explicit phase attribute. -static ArrayAttr buildAccPhaseTemplateArgs(ConversionPatternRewriter &rewriter, - pto::AccPhase phase) { - StringRef tmpl; - switch (phase) { - case pto::AccPhase::Unspecified: - return ArrayAttr{}; - case pto::AccPhase::Partial: - tmpl = "AccPhase::Partial"; - break; - case pto::AccPhase::Final: - tmpl = "AccPhase::Final"; - break; - } - if (tmpl.empty()) - return ArrayAttr{}; - return rewriter.getArrayAttr( - {emitc::OpaqueAttr::get(rewriter.getContext(), tmpl)}); -} - -struct PTOTMatmulToTMATMUL : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMatmulOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // 1. 获取操作数 (剥离 Cast) - Value lhs = peelUnrealized(adaptor.getLhs()); // A (Left) - Value rhs = peelUnrealized(adaptor.getRhs()); // B (Right) - Value dst = peelUnrealized(adaptor.getDst()); // C (Acc) - - // 2. 根据 acc_phase 属性决定是否生成 TMATMUL(...) - ArrayAttr templateArgs = - buildAccPhaseTemplateArgs(rewriter, op.getAccPhase()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TMATMUL", - /*args=*/ArrayAttr{}, /*template_args=*/templateArgs, - ValueRange{dst, lhs, rhs}); - - // 3. 处理 Op 替换/删除 - if (op->getNumResults() == 1) { - rewriter.replaceOp(op, dst); - } else { - rewriter.eraseOp(op); - } - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// pto.tgemv lowering -//===----------------------------------------------------------------------===// -struct PTOTGemvToTGEMV : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGemvOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // 1. 获取操作数 (剥离 Cast) - Value lhs = peelUnrealized(adaptor.getLhs()); // A (Matrix) - Value rhs = peelUnrealized(adaptor.getRhs()); // B (Vector) - Value dst = peelUnrealized(adaptor.getDst()); // C (Result) - - // 2. 直接生成函数调用 TGEMV(dst, lhs, rhs) - rewriter.create( - op.getLoc(), TypeRange{}, "TGEMV", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, lhs, rhs}); - - // 3. 处理 Op 替换/删除 - if (op->getNumResults() == 1) { - rewriter.replaceOp(op, dst); - } else { - rewriter.eraseOp(op); - } - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// pto.tgemv.acc lowering -//===----------------------------------------------------------------------===// -struct PTOTGemvAccToTGEMVACC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGemvAccOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!op.getDst()) - return rewriter.notifyMatchFailure(op, "expected outs(dst) for pto.tgemv.acc"); - - // 1. 获取操作数 - Value accIn = peelUnrealized(adaptor.getAccIn()); // AccOld - Value lhs = peelUnrealized(adaptor.getLhs()); // A (Matrix) - Value rhs = peelUnrealized(adaptor.getRhs()); // B (Vector) - Value dst = peelUnrealized(adaptor.getDst()); // AccNew - - // 2. 直接生成函数调用 TGEMV_ACC(dst, accIn, lhs, rhs) - rewriter.create( - op.getLoc(), TypeRange{}, "TGEMV_ACC", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, accIn, lhs, rhs}); - - // 3. 处理 Op 替换/删除 - if (op->getNumResults() == 1) { - rewriter.replaceOp(op, dst); - } else { - rewriter.eraseOp(op); - } - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// pto.matmul_acc_dps lowering (Simplified: No internal copy/sync) -//===----------------------------------------------------------------------===// -struct PTOTMatmulAccToTMATMULACC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMatmulAccOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!op.getDst()) - return rewriter.notifyMatchFailure(op, "expected outs(dst) for pto.tmatmul.acc"); - - // 1. 获取操作数 - Value accIn = peelUnrealized(adaptor.getAccIn()); // AccOld - Value lhs = peelUnrealized(adaptor.getLhs()); // A (Left) - Value rhs = peelUnrealized(adaptor.getRhs()); // B (Right) - Value dst = peelUnrealized(adaptor.getDst()); // AccNew - - // 2. 根据 acc_phase 属性决定是否生成 TMATMUL_ACC(...) - ArrayAttr templateArgs = - buildAccPhaseTemplateArgs(rewriter, op.getAccPhase()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TMATMUL_ACC", - /*args=*/ArrayAttr{}, /*template_args=*/templateArgs, - ValueRange{dst, accIn, lhs, rhs}); - - // 3. 处理 Op 替换/删除 - if (op->getNumResults() == 1) { - rewriter.replaceOp(op, dst); - } else { - rewriter.eraseOp(op); - } - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// Return lowering -//===----------------------------------------------------------------------=== - -static constexpr llvm::StringLiteral kAutoSyncTailPendingModeAttr = - "__pto.auto_sync_tail_mode"; - -struct ReturnToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (auto emitcFunc = op->getParentOfType()) { - if (auto modeAttr = - emitcFunc->getAttrOfType(kAutoSyncTailPendingModeAttr)) { - auto *ctx = rewriter.getContext(); - rewriter.setInsertionPoint(op); - auto args = rewriter.getArrayAttr( - {emitc::OpaqueAttr::get(ctx, modeAttr.getValue())}); - rewriter.create( - op.getLoc(), TypeRange{}, "ptoas_auto_sync_tail", - args, ArrayAttr{}, ValueRange{}); - } - } - - auto vals = adaptor.getOperands(); - if (vals.empty()) { - rewriter.replaceOpWithNewOp(op, Value{}); - return success(); - } - if (vals.size() == 1) { - rewriter.replaceOpWithNewOp(op, vals[0]); - return success(); - } - return rewriter.notifyMatchFailure(op, "EmitC cannot return multiple values"); - } -}; - -struct CallToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(func::CallOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (op.getNumResults() > 1) - return rewriter.notifyMatchFailure( - op, "EmitC cannot lower calls with multiple results"); - - SmallVector resultTypes; - if (failed( - getTypeConverter()->convertTypes(op.getResultTypes(), resultTypes))) - return rewriter.notifyMatchFailure(op, - "failed to convert call result types"); - - rewriter.replaceOpWithNewOp(op, op.getCalleeAttr(), - resultTypes, - adaptor.getOperands()); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// Sync lowering -//===----------------------------------------------------------------------=== - -static constexpr llvm::StringLiteral kAutoSyncTailBarrierAttr = - "pto.auto_sync_tail_barrier"; -static constexpr llvm::StringLiteral kAutoSyncTailHintAttr = - "pto.auto_sync_tail_hint"; -static constexpr llvm::StringLiteral kAutoSyncTailPolicyBarrierAll = - "barrier_all"; -static constexpr llvm::StringLiteral kAutoSyncTailPolicyMte3ToSEvent0 = - "setwait_mte3_to_s_event0"; -static constexpr llvm::StringLiteral kAutoSyncTailModeBarrierAllToken = - "PTOAutoSyncTailMode::kBarrierAll"; -static constexpr llvm::StringLiteral kAutoSyncTailModeMte3ToSEvent0Token = - "PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0"; - -static std::string getAutoSyncTailModeToken(Operation *op) { - if (op) { - if (auto hintAttr = op->getAttrOfType(kAutoSyncTailHintAttr)) { - if (hintAttr.getValue() == kAutoSyncTailPolicyBarrierAll) - return kAutoSyncTailModeBarrierAllToken.str(); - if (hintAttr.getValue() == kAutoSyncTailPolicyMte3ToSEvent0) - return kAutoSyncTailModeMte3ToSEvent0Token.str(); - } - } - - auto func = op ? op->getParentOfType() : func::FuncOp(); - if (!func) - return kAutoSyncTailModeBarrierAllToken.str(); - - auto hintAttr = func->getAttrOfType(kAutoSyncTailHintAttr); - if (!hintAttr) - return kAutoSyncTailModeBarrierAllToken.str(); - - if (hintAttr.getValue() == kAutoSyncTailPolicyBarrierAll) - return kAutoSyncTailModeBarrierAllToken.str(); - if (hintAttr.getValue() == kAutoSyncTailPolicyMte3ToSEvent0) - return kAutoSyncTailModeMte3ToSEvent0Token.str(); - - // Fallback to the conservative behavior when seeing unknown policies. - return kAutoSyncTailModeBarrierAllToken.str(); -} - -[[maybe_unused]] static std::string getPipeName(pto::PIPE pipe) { - switch (pipe) { - case pto::PIPE::PIPE_S: return "PIPE_S"; - case pto::PIPE::PIPE_V: return "PIPE_V"; - case pto::PIPE::PIPE_M: return "PIPE_M"; - case pto::PIPE::PIPE_MTE1: return "PIPE_MTE1"; - case pto::PIPE::PIPE_MTE2: return "PIPE_MTE2"; - case pto::PIPE::PIPE_MTE3: return "PIPE_MTE3"; - case pto::PIPE::PIPE_ALL: return "PIPE_ALL"; - case pto::PIPE::PIPE_MTE4: return "PIPE_MTE4"; - case pto::PIPE::PIPE_MTE5: return "PIPE_MTE5"; - case pto::PIPE::PIPE_V2: return "PIPE_V2"; - case pto::PIPE::PIPE_FIX: return "PIPE_FIX"; - case pto::PIPE::VIRTUAL_PIPE_MTE2_L1A: return "VIRTUAL_PIPE_MTE2_L1A"; - case pto::PIPE::VIRTUAL_PIPE_MTE2_L1B: return "VIRTUAL_PIPE_MTE2_L1B"; - // 默认回退 - default: return "PIPE_ALL"; - } -} - -//===----------------------------------------------------------------------===// -// pto.barrier lowering -> pipe_barrier(...) -//===----------------------------------------------------------------------===// -struct PTOBarrierToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::BarrierOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (op->hasAttr(kAutoSyncTailBarrierAttr)) { - auto modeAttr = rewriter.getStringAttr(getAutoSyncTailModeToken(op)); - if (auto emitcFunc = op->getParentOfType()) { - emitcFunc->setAttr(kAutoSyncTailPendingModeAttr, modeAttr); - } else if (auto funcOp = op->getParentOfType()) { - funcOp->setAttr(kAutoSyncTailPendingModeAttr, modeAttr); - } - rewriter.eraseOp(op); - return success(); - } - - // [FIX] op.getPipe() returns PipeAttr. - // We must call .getPipe() on the attribute to get the actual Enum value. - pto::PIPE pipeEnum = op.getPipe().getPipe(); - - // Convert Enum to String (e.g., PIPE_ALL -> "PIPE_ALL") - std::string pipeStr = pto::stringifyPIPE(pipeEnum).str(); - auto *ctx = rewriter.getContext(); - - auto args = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, pipeStr) - }); - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, // void return - "pipe_barrier", // function name - args, // arguments - ArrayAttr{}, // template args - ValueRange{} // operands - ); - - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// Sync lowering (robust for bracket form pto.set_flag[...] / pto.wait_flag[...]) -// Replace your PTOSyncToRuntimeCall with the code below. -//===----------------------------------------------------------------------===// - -static bool tryConvertPipeAttrToToken(Attribute attr, std::string &token) { - if (!attr) - return false; - if (auto pipe = dyn_cast(attr)) { - token = mlir::pto::stringifyPIPE(pipe.getPipe()).str(); - return true; - } - if (auto stringAttr = dyn_cast(attr)) { - token = stringAttr.getValue().str(); - return true; - } - return false; -} - -static bool tryConvertEventAttrToToken(Attribute attr, std::string &token) { - if (!attr) - return false; - if (auto event = dyn_cast(attr)) { - token = mlir::pto::stringifyEVENT(event.getEvent()).str(); - return true; - } - if (auto stringAttr = dyn_cast(attr)) { - token = stringAttr.getValue().str(); - return true; - } - return false; -} - -static bool tryAssignSyncTokens(Attribute srcAttr, Attribute dstAttr, - Attribute evtAttr, std::string &srcTok, - std::string &dstTok, std::string &evtTok) { - std::string localSrc; - std::string localDst; - std::string localEvt; - if (!tryConvertPipeAttrToToken(srcAttr, localSrc) || - !tryConvertPipeAttrToToken(dstAttr, localDst) || - !tryConvertEventAttrToToken(evtAttr, localEvt)) { - return false; - } - srcTok = std::move(localSrc); - dstTok = std::move(localDst); - evtTok = std::move(localEvt); - return true; -} - -static bool tryExtractSyncTokensFromNamedAttrs(Operation *op, - StringRef srcName, - StringRef dstName, - StringRef evtName, - std::string &srcTok, - std::string &dstTok, - std::string &evtTok) { - return tryAssignSyncTokens(op->getAttr(srcName), op->getAttr(dstName), - op->getAttr(evtName), srcTok, dstTok, evtTok); -} - -static bool tryExtractSyncTokensFromArrayAttr(Operation *op, StringRef attrName, - std::string &srcTok, - std::string &dstTok, - std::string &evtTok) { - auto arrayAttr = op->getAttrOfType(attrName); - if (!arrayAttr || arrayAttr.size() < 3) - return false; - return tryAssignSyncTokens(arrayAttr[0], arrayAttr[1], arrayAttr[2], srcTok, - dstTok, evtTok); -} - -static bool tryExtractFallbackSyncTokens(Operation *op, std::string &srcTok, - std::string &dstTok, - std::string &evtTok) { - SmallVector pipes; - std::string event; - for (NamedAttribute namedAttr : op->getAttrs()) { - std::string token; - if (tryConvertPipeAttrToToken(namedAttr.getValue(), token)) { - pipes.push_back(std::move(token)); - continue; - } - if (event.empty() && - tryConvertEventAttrToToken(namedAttr.getValue(), token)) { - event = std::move(token); - } - } - if (pipes.size() < 2 || event.empty()) - return false; - srcTok = pipes[0]; - dstTok = pipes[1]; - evtTok = event; - return true; -} - -static LogicalResult extractSyncTripletTokens(Operation *op, - std::string &srcTok, - std::string &dstTok, - std::string &evtTok, - ConversionPatternRewriter &rewriter) { - if (tryExtractSyncTokensFromNamedAttrs(op, "src_pipe", "dst_pipe", "event_id", - srcTok, dstTok, evtTok) || - tryExtractSyncTokensFromNamedAttrs(op, "srcPipe", "dstPipe", "eventId", - srcTok, dstTok, evtTok) || - tryExtractSyncTokensFromNamedAttrs(op, "src", "dst", "event", srcTok, - dstTok, evtTok)) { - return success(); - } - - for (StringRef attrName : {"args", "pipes", "sync", "triplet", "attrs"}) { - if (tryExtractSyncTokensFromArrayAttr(op, attrName, srcTok, dstTok, - evtTok)) { - return success(); - } - } - - if (tryExtractFallbackSyncTokens(op, srcTok, dstTok, evtTok)) - return success(); - return rewriter.notifyMatchFailure( - op, "cannot extract PIPE/PIPE/EVENT tokens from pto.{set,wait}_flag"); -} -static inline std::string pipeTokFromPipeEnum(mlir::pto::PIPE p) { - return mlir::pto::stringifyPIPE(p).str(); -} -[[maybe_unused]] static inline std::string evtTokFromEventEnum(mlir::pto::EVENT e) { - return mlir::pto::stringifyEVENT(e).str(); -} -static inline std::string pipeTokFromPipeAttr(mlir::pto::PipeAttr a) { - return mlir::pto::stringifyPIPE(a.getPipe()).str(); -} -static inline std::string evtTokFromEventAttr(mlir::pto::EventAttr a) { - return mlir::pto::stringifyEVENT(a.getEvent()).str(); -} - -template -struct HasGetSrcPipe : std::false_type {}; -template -struct HasGetSrcPipe().getSrcPipe())>> : std::true_type {}; - -template -struct HasGetDstPipe : std::false_type {}; -template -struct HasGetDstPipe().getDstPipe())>> : std::true_type {}; - -template -struct HasGetEventId : std::false_type {}; -template -struct HasGetEventId().getEventId())>> : std::true_type {}; - -template -struct HasGetSrcPipeAttr : std::false_type {}; -template -struct HasGetSrcPipeAttr().getSrcPipeAttr())>> : std::true_type {}; - -template -struct HasGetDstPipeAttr : std::false_type {}; -template -struct HasGetDstPipeAttr().getDstPipeAttr())>> : std::true_type {}; - -template -struct HasGetEventIdAttr : std::false_type {}; -template -struct HasGetEventIdAttr().getEventIdAttr())>> : std::true_type {}; - -template -static LogicalResult extractSyncTokens(SyncOpT op, - std::string &srcTok, - std::string &dstTok, - std::string &evtTok, - ConversionPatternRewriter &rewriter) { - if constexpr (HasGetSrcPipe::value && - HasGetDstPipe::value && - HasGetEventId::value) { - auto s = op.getSrcPipe(); - auto d = op.getDstPipe(); - auto e = op.getEventId(); - - if constexpr (std::is_same::value) srcTok = pipeTokFromPipeEnum(s); - else srcTok = pipeTokFromPipeAttr(s); - - if constexpr (std::is_same::value) dstTok = pipeTokFromPipeEnum(d); - else dstTok = pipeTokFromPipeAttr(d); - - if constexpr (std::is_same::value) evtTok = evtTokFromEventEnum(e); - else evtTok = evtTokFromEventAttr(e); - - return success(); - } - - if constexpr (HasGetSrcPipeAttr::value && - HasGetDstPipeAttr::value && - HasGetEventIdAttr::value) { - auto s = op.getSrcPipeAttr(); - auto d = op.getDstPipeAttr(); - auto e = op.getEventIdAttr(); - srcTok = pipeTokFromPipeAttr(s); - dstTok = pipeTokFromPipeAttr(d); - evtTok = evtTokFromEventAttr(e); - return success(); - } - - return extractSyncTripletTokens(op.getOperation(), srcTok, dstTok, evtTok, rewriter); -} -struct PTOSetFlagToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::SetFlagOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - (void)adaptor; - auto *ctx = rewriter.getContext(); - - std::string srcTok, dstTok, evtTok; - if (failed(extractSyncTokens(op, srcTok, dstTok, evtTok, rewriter))) - return failure(); - - auto argsAttr = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, srcTok), - emitc::OpaqueAttr::get(ctx, dstTok), - emitc::OpaqueAttr::get(ctx, evtTok), - }); - - rewriter.replaceOpWithNewOp( - op, TypeRange{}, "set_flag", - /*args=*/argsAttr, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{}); - return success(); - } -}; - -struct PTOWaitFlagToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::WaitFlagOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - (void)adaptor; - auto *ctx = rewriter.getContext(); - - std::string srcTok, dstTok, evtTok; - if (failed(extractSyncTokens(op, srcTok, dstTok, evtTok, rewriter))) - return failure(); - - auto argsAttr = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, srcTok), - emitc::OpaqueAttr::get(ctx, dstTok), - emitc::OpaqueAttr::get(ctx, evtTok), - }); - - rewriter.replaceOpWithNewOp( - op, TypeRange{}, "wait_flag", - /*args=*/argsAttr, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{}); - return success(); - } -}; - -struct PTOSyncToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::TSyncOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - SmallVector operands; - operands.reserve(adaptor.getEvents().size()); - for (Value event : adaptor.getEvents()) - operands.push_back(peelUnrealized(event)); - - rewriter.create( - op.getLoc(), TypeRange{}, "TSYNC", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange(operands)); - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOSyncAllToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - static StringRef coreTypeTok(pto::SyncCoreType coreType) { - switch (coreType) { - case pto::SyncCoreType::AIVOnly: - return "SyncCoreType::AIVOnly"; - case pto::SyncCoreType::AICOnly: - return "SyncCoreType::AICOnly"; - case pto::SyncCoreType::Mix: - return "SyncCoreType::Mix"; - } - llvm_unreachable("unhandled SyncCoreType"); - } - - LogicalResult matchAndRewrite(mlir::pto::SyncAllOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto mode = op.getMode().getValue(); - auto coreType = op.getCoreType().getValue(); - - auto buildGmWorkspace = [&]() -> FailureOr { - Value gm = peelUnrealized(adaptor.getGmWorkspace()); - if (isEmitCGlobalTensorLikeType(gm.getType())) - return gm; - - auto memTy = dyn_cast(op.getGmWorkspace().getType()); - if (!memTy) - return failure(); - - Value gt = buildGlobalTensorFromMemref(rewriter, op.getLoc(), gm, memTy, - op.getGmWorkspace().getDefiningOp() - ? op.getGmWorkspace().getDefiningOp() - : op.getOperation()); - if (!gt) - return failure(); - return gt; - }; - - if (mode == pto::SyncAllMode::Hard) { - std::string callee = "SYNCALL<" + coreTypeTok(coreType).str() + ">"; - rewriter.create(op.getLoc(), TypeRange{}, callee, - ArrayAttr{}, ArrayAttr{}, - ValueRange{}); - rewriter.eraseOp(op); - return success(); - } - - FailureOr gmWorkspace = buildGmWorkspace(); - if (failed(gmWorkspace)) - return rewriter.notifyMatchFailure(op, - "failed to build gm_workspace GlobalTensor"); - - auto i32Ty = emitc::OpaqueType::get(rewriter.getContext(), "int32_t"); - Value usedCores = adaptor.getUsedCores() - ? peelUnrealized(adaptor.getUsedCores()) - : makeEmitCIntConstant(rewriter, op.getLoc(), i32Ty, 0); - if (usedCores.getType() != i32Ty) - usedCores = rewriter.create(op.getLoc(), i32Ty, usedCores) - .getResult(); - - std::string callee = - "SYNCALL"; - - SmallVector operands{*gmWorkspace}; - switch (coreType) { - case pto::SyncCoreType::AIVOnly: { - FailureOr ubWorkspace = - buildSyncAllWorkspaceTileValue(rewriter, op.getLoc(), - op.getUbWorkspace(), - adaptor.getUbWorkspace()); - if (failed(ubWorkspace)) - return rewriter.notifyMatchFailure( - op, "failed to materialize ub_workspace tile"); - operands.push_back(*ubWorkspace); - break; - } - case pto::SyncCoreType::AICOnly: { - FailureOr l1Workspace = - buildSyncAllWorkspaceTileValue(rewriter, op.getLoc(), - op.getL1Workspace(), - adaptor.getL1Workspace()); - if (failed(l1Workspace)) - return rewriter.notifyMatchFailure( - op, "failed to materialize l1_workspace tile"); - operands.push_back(*l1Workspace); - break; - } - case pto::SyncCoreType::Mix: { - FailureOr ubWorkspace = - buildSyncAllWorkspaceTileValue(rewriter, op.getLoc(), - op.getUbWorkspace(), - adaptor.getUbWorkspace()); - FailureOr l1Workspace = - buildSyncAllWorkspaceTileValue(rewriter, op.getLoc(), - op.getL1Workspace(), - adaptor.getL1Workspace()); - if (failed(ubWorkspace) || failed(l1Workspace)) - return rewriter.notifyMatchFailure( - op, "failed to materialize mixed syncall workspace tiles"); - operands.push_back(*ubWorkspace); - operands.push_back(*l1Workspace); - break; - } - } - - operands.push_back(usedCores); - rewriter.create(op.getLoc(), TypeRange{}, callee, - ArrayAttr{}, ArrayAttr{}, - ValueRange(operands)); - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOSyncFlagDynToEmitC : public ConversionPattern { - PTOSyncFlagDynToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - StringRef opName, StringRef callee) - : ConversionPattern(typeConverter, opName, /*benefit=*/1, ctx), - callee(callee.str()) {} - - LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - if (operands.size() != 1) - return rewriter.notifyMatchFailure(op, "expected exactly one dynamic event-id operand"); - - auto srcAttr = op->getAttrOfType("src_pipe"); - auto dstAttr = op->getAttrOfType("dst_pipe"); - if (!srcAttr || !dstAttr) - return rewriter.notifyMatchFailure(op, "missing PipeAttr src_pipe/dst_pipe attrs"); - - auto *ctx = rewriter.getContext(); - std::string srcTok = pipeTokFromPipeAttr(srcAttr); - std::string dstTok = pipeTokFromPipeAttr(dstAttr); - - Value eventVal = operands.front(); - eventVal = - emitCCast(rewriter, op->getLoc(), emitc::OpaqueType::get(ctx, "event_t"), eventVal); - - auto argsAttr = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, srcTok), - emitc::OpaqueAttr::get(ctx, dstTok), - IntegerAttr::get(IndexType::get(ctx), 0), - }); - - rewriter.replaceOpWithNewOp( - op, TypeRange{}, callee, - /*args=*/argsAttr, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{eventVal}); - return success(); - } - -private: - std::string callee; -}; - -struct PTOGetBufToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::GetBufOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - (void)adaptor; - auto *ctx = rewriter.getContext(); - - auto opTypeOr = parseSyncOpTypeLikeAttr(op.getOpTypeAttr()); - if (failed(opTypeOr)) - return rewriter.notifyMatchFailure(op, "get_buf expects pipe_event_type/sync_op_type attr"); - auto pipe = mapSyncOpTypeToPipe(*opTypeOr); - if (!isConcreteSyncPipe(pipe)) - return rewriter.notifyMatchFailure(op, "get_buf op_type cannot map to a concrete pipe"); - std::string pipeTok = pipeTokFromPipeEnum(pipe); - auto argsAttr = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, pipeTok), - op.getBufIdAttr(), - op.getModeAttr(), - }); - - rewriter.replaceOpWithNewOp( - op, TypeRange{}, "get_buf", - /*args=*/argsAttr, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{}); - return success(); - } -}; - -struct PTORlsBufToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::RlsBufOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - (void)adaptor; - auto *ctx = rewriter.getContext(); - - auto opTypeOr = parseSyncOpTypeLikeAttr(op.getOpTypeAttr()); - if (failed(opTypeOr)) - return rewriter.notifyMatchFailure(op, "rls_buf expects pipe_event_type/sync_op_type attr"); - auto pipe = mapSyncOpTypeToPipe(*opTypeOr); - if (!isConcreteSyncPipe(pipe)) - return rewriter.notifyMatchFailure(op, "rls_buf op_type cannot map to a concrete pipe"); - std::string pipeTok = pipeTokFromPipeEnum(pipe); - auto argsAttr = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, pipeTok), - op.getBufIdAttr(), - op.getModeAttr(), - }); - - rewriter.replaceOpWithNewOp( - op, TypeRange{}, "rls_buf", - /*args=*/argsAttr, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{}); - return success(); - } -}; - -struct PTOSetFFTsToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::SetFFTsOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto *ctx = rewriter.getContext(); - auto loc = op.getLoc(); - - Value fftsAddr = peelUnrealized(adaptor.getFfts()); - auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); - - if (isSetFFTsPointerLikeType(fftsAddr.getType())) { - auto castTyAttr = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); - fftsAddr = - rewriter - .create(loc, u64Ty, "reinterpret_cast", - /*args=*/ArrayAttr{}, - /*templateArgs=*/castTyAttr, - /*operands=*/ValueRange{fftsAddr}) - .getResult(0); - } else if (fftsAddr.getType() != u64Ty) { - fftsAddr = - rewriter.create(loc, u64Ty, fftsAddr).getResult(); - } - - rewriter.replaceOpWithNewOp( - op, TypeRange{}, "set_ffts_base_addr", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{fftsAddr}); - return success(); - } -}; - -struct PTOSyncSetToEmitC : public OpConversionPattern { - PTOSyncSetToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - PTOArch targetArch) - : OpConversionPattern(typeConverter, ctx), - targetArch(targetArch) {} - - LogicalResult - matchAndRewrite(mlir::pto::SyncSetOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op->getLoc(); - auto *ctx = rewriter.getContext(); - IntegerAttr eventIdAttr = op.getEventIdAttr(); - Value eventIdDyn = adaptor.getEventIdDyn(); - int64_t fftsMode = 2; - if (IntegerAttr fftsModeAttr = op.getFftsModeAttr()) - fftsMode = fftsModeAttr.getInt(); - - if ((eventIdAttr != nullptr) == static_cast(eventIdDyn)) - return rewriter.notifyMatchFailure( - op, "expects exactly one of static event_id attr or dynamic event_id operand"); - - // A5 inter-core sync mirrors +16 only for cube-side producer (PIPE_FIX). - // Vec-side producer (PIPE_MTE3) emits a single set; hardware handles the - // subblock mapping in PTO-ISA custom flow. - if (targetArch == PTOArch::A5) { - pto::PIPE pipe = op.getPipe().getPipe(); - bool needsMirrorPlus16 = (pipe == pto::PIPE::PIPE_FIX); - std::string pipeTok = pipeTokFromPipeAttr(op.getPipe()); - auto emitSet = [&](Value eventOperand, IntegerAttr eventLiteral, - bool isDynamic) { - if (isDynamic) { - auto args = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, pipeTok), - IntegerAttr::get(IndexType::get(ctx), 0), - }); - rewriter.create(loc, TypeRange{}, "set_intra_block", - /*args=*/args, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{eventOperand}); - return; - } - auto args = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, pipeTok), - eventLiteral, - }); - rewriter.create(loc, TypeRange{}, "set_intra_block", - /*args=*/args, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{}); - }; - - if (eventIdAttr) { - emitSet(Value{}, eventIdAttr, /*isDynamic=*/false); - if (needsMirrorPlus16) { - auto plus16 = IntegerAttr::get(eventIdAttr.getType(), - eventIdAttr.getInt() + 16); - emitSet(Value{}, plus16, /*isDynamic=*/false); - } - } else { - Value eventI32 = castInterCoreEventIdToI32(rewriter, loc, eventIdDyn); - emitSet(eventI32, IntegerAttr{}, /*isDynamic=*/true); - if (needsMirrorPlus16) { - auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); - Value c16 = makeEmitCIntConstant(rewriter, loc, i32Ty, 16); - Value eventI32Plus16 = - rewriter.create(loc, i32Ty, eventI32, c16).getResult(); - emitSet(eventI32Plus16, IntegerAttr{}, /*isDynamic=*/true); - } - } - - rewriter.eraseOp(op); - return success(); - } - - InterCoreSyncCallDesc desc; - if (eventIdAttr) { - desc = buildInterCoreSyncSetCall(rewriter, loc, targetArch, op.getPipe(), - eventIdAttr, fftsMode); - } else { - desc = buildInterCoreSyncSetCallDyn(rewriter, loc, targetArch, op.getPipe(), - eventIdDyn, fftsMode); - } - rewriter.create(loc, TypeRange{}, desc.callee, - /*args=*/desc.args, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/desc.operands); - - rewriter.eraseOp(op); - return success(); - } - - PTOArch targetArch; -}; - -struct PTOSyncWaitToEmitC : public OpConversionPattern { - PTOSyncWaitToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - PTOArch targetArch) - : OpConversionPattern(typeConverter, ctx), - targetArch(targetArch) {} - - LogicalResult - matchAndRewrite(mlir::pto::SyncWaitOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op->getLoc(); - IntegerAttr eventIdAttr = op.getEventIdAttr(); - Value eventIdDyn = adaptor.getEventIdDyn(); - - if ((eventIdAttr != nullptr) == static_cast(eventIdDyn)) - return rewriter.notifyMatchFailure( - op, "expects exactly one of static event_id attr or dynamic event_id operand"); - - InterCoreSyncCallDesc desc; - if (eventIdAttr) { - desc = buildInterCoreSyncWaitCall(rewriter, targetArch, op.getPipe(), - eventIdAttr); - } else { - desc = buildInterCoreSyncWaitCallDyn(rewriter, loc, targetArch, op.getPipe(), - eventIdDyn); - } - rewriter.create(loc, TypeRange{}, desc.callee, - desc.args, ArrayAttr{}, desc.operands); - - rewriter.eraseOp(op); - return success(); - } - - PTOArch targetArch; -}; - -// GetBlockIdxOp Lowering (pto.get_block_idx -> get_block_idx()) -struct PTOGetBlockIdxToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::pto::GetBlockIdxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - - rewriter.replaceOpWithNewOp( - op, op.getType(), "get_block_idx", ValueRange{}, ArrayAttr{}, - ArrayAttr{}); - - return success(); - } -}; - -// GetBlockNumOp Lowering (pto.get_block_num -> get_block_num()) -struct PTOGetBlockNumToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::pto::GetBlockNumOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - - rewriter.replaceOpWithNewOp( - op, op.getType(), "get_block_num", ValueRange{}, ArrayAttr{}, - ArrayAttr{}); - - return success(); - } -}; - -// GetSubBlockIdxOp Lowering (pto.get_block_idx -> get_subblockid()) -struct PTOGetSubBlockIdxToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::pto::GetSubBlockIdxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - - rewriter.replaceOpWithNewOp( - op, op.getType(), "get_subblockid", ValueRange{}, ArrayAttr{}, - ArrayAttr{}); - - return success(); - } -}; - -// GetSubBlockNumOp Lowering. -struct PTOGetSubBlockNumToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(mlir::pto::GetSubBlockNumOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - - rewriter.replaceOpWithNewOp( - op, op.getType(), "get_subblockdim", ValueRange{}, ArrayAttr{}, - ArrayAttr{}); - - return success(); - } -}; - - -struct PTOMScatterToMSCATTER : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::MScatterOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto *ctx = rewriter.getContext(); - Value src = peelUnrealized(adaptor.getSrc()); - Value idx = peelUnrealized(adaptor.getIdx()); - Value mem = peelUnrealized(adaptor.getMem()); - - Value memArg = maybeWrapGlobalMemrefAsGlobalTensor( - rewriter, op.getLoc(), mem, op.getMem().getType(), op.getOperation()); - - auto scatterAtomicTok = [&](pto::ScatterAtomicOp atomic) -> StringRef { - switch (atomic) { - case pto::ScatterAtomicOp::None: - return "pto::ScatterAtomicOp::None"; - case pto::ScatterAtomicOp::Add: - return "pto::ScatterAtomicOp::Add"; - case pto::ScatterAtomicOp::Max: - return "pto::ScatterAtomicOp::Max"; - case pto::ScatterAtomicOp::Min: - return "pto::ScatterAtomicOp::Min"; - } - llvm_unreachable("unknown ScatterAtomicOp"); - }; - auto scatterOobTok = [&](pto::ScatterOOB mode) -> StringRef { - switch (mode) { - case pto::ScatterOOB::Undefined: - return "pto::ScatterOOB::Undefined"; - case pto::ScatterOOB::Skip: - return "pto::ScatterOOB::Skip"; - case pto::ScatterOOB::Clamp: - return "pto::ScatterOOB::Clamp"; - case pto::ScatterOOB::Wrap: - return "pto::ScatterOOB::Wrap"; - } - llvm_unreachable("unknown ScatterOOB"); - }; - - SmallVector templateArgVec; - const bool rowCoalesce = - isRowCoalescedMGatherIndexType(op.getSrc().getType(), op.getIdx().getType()); - templateArgVec.push_back(emitc::OpaqueAttr::get( - ctx, rowCoalesce ? "pto::Coalesce::Row" : "pto::Coalesce::Elem")); - if (op.getScatterAtomicOp() != pto::ScatterAtomicOp::None || - op.getScatterOob() != pto::ScatterOOB::Undefined) { - templateArgVec.push_back(emitc::OpaqueAttr::get( - ctx, scatterAtomicTok(op.getScatterAtomicOp()))); - if (op.getScatterOob() != pto::ScatterOOB::Undefined) - templateArgVec.push_back( - emitc::OpaqueAttr::get(ctx, scatterOobTok(op.getScatterOob()))); - } - ArrayAttr templateArgs = rewriter.getArrayAttr(templateArgVec); - - rewriter.create( - op.getLoc(), TypeRange{}, "MSCATTER", - ArrayAttr{}, templateArgs, - ValueRange{memArg, src, idx}); - - rewriter.eraseOp(op); - return success(); - } -}; -struct PTOSetValToSETVAL : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TSetValOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value dst = peelUnrealized(adaptor.getDst()); - Value val = peelUnrealized(adaptor.getVal()); - - // ---- offset: SSA index operand ---- - Value offset = peelUnrealized(adaptor.getOffset()); - - // Emit a marker call and let the ptoas post-processing step lower it to - // the corresponding tile setter. - rewriter.create( - op.getLoc(), TypeRange{}, "PTOAS__TILE_SET_VALUE", - ArrayAttr{}, ArrayAttr{}, ValueRange{dst, offset, val}); - - rewriter.eraseOp(op); - return success(); - } -}; -struct PTOGetValToGETVAL : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGetValOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src = peelUnrealized(adaptor.getSrc()); - - // ---- offset: SSA index operand ---- - Value offset = peelUnrealized(adaptor.getOffset()); - - // Emit a marker call and let the ptoas post-processing step lower it to - // the corresponding tile getter. - Type dstTy = getTypeConverter()->convertType(op.getDst().getType()); - if (!dstTy) - return failure(); - auto call = rewriter.create( - op.getLoc(), - TypeRange{dstTy}, - "PTOAS__TILE_GET_VALUE", - ArrayAttr{}, ArrayAttr{}, - ValueRange{src, offset}); - - rewriter.replaceOp(op, call.getResults()); - return success(); - } -}; - -struct PTOTAxpyToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAxpyOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value scalar = peelUnrealized(adaptor.getScalar()); - - rewriter.create( - loc, TypeRange{}, "TAXPY", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, scalar}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOHistogramToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::THistogramOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value idx = peelUnrealized(adaptor.getIdx()); - Value dst = peelUnrealized(adaptor.getDst()); - - StringRef histByte = "HistByte::BYTE_1"; - int64_t byte = 1; - auto byteAttr = op.getByteAttr(); - if (byteAttr) - byte = byteAttr.getInt(); - if (auto legacyIsMSB = op->getAttrOfType("isMSB")) { - int64_t legacyByte = legacyIsMSB.getValue() ? 1 : 0; - if (byteAttr && byte != legacyByte) - return rewriter.notifyMatchFailure( - op, "conflicting 'byte' and legacy 'isMSB' attributes"); - byte = legacyByte; - } - switch (byte) { - case 0: - histByte = "HistByte::BYTE_0"; - break; - case 1: - histByte = "HistByte::BYTE_1"; - break; - case 2: - histByte = "HistByte::BYTE_2"; - break; - case 3: - histByte = "HistByte::BYTE_3"; - break; - default: - return rewriter.notifyMatchFailure(op, "expected byte to be in range [0, 3]"); - } - - auto templateArgs = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, histByte)}); - rewriter.create( - loc, TypeRange{}, "THISTOGRAM", - /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, - /*operands=*/ValueRange{dst, src, idx}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOGetScaleAddrToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGetScaleAddrOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TGET_SCALE_ADDR", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOSetValidShapeToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::SetValidShapeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto peelAllCasts = [](Value v) { - while (auto castOp = v.getDefiningOp()) - v = castOp.getOperand(0); - if (auto castOp = v.getDefiningOp()) - v = castOp.getOperand(); - return v; - }; - auto isTileLike = [](Value v) -> bool { - auto ot = dyn_cast(v.getType()); - if (!ot) - return false; - StringRef s = ot.getValue(); - return s.contains("Tile<") || s.contains("ConvTile<"); - }; - - Value src = peelAllCasts(peelUnrealized(adaptor.getSource())); - Value row = peelUnrealized(adaptor.getValidRow()); - Value col = peelUnrealized(adaptor.getValidCol()); - - if (!isTileLike(src)) - return rewriter.notifyMatchFailure( - op, "set_validshape source must lower to a tile-like value"); - - rewriter.create( - op.getLoc(), TypeRange{}, "PTOAS__TILE_SET_VALIDSHAPE", ArrayAttr{}, - ArrayAttr{}, ValueRange{src, row, col}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOGetValidShapeToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::GetValidShapeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto peelAllCasts = [](Value v) { - while (auto castOp = v.getDefiningOp()) - v = castOp.getOperand(0); - if (auto castOp = v.getDefiningOp()) - v = castOp.getOperand(); - return v; - }; - auto isTileLike = [](Value v) -> bool { - auto ot = dyn_cast(v.getType()); - if (!ot) - return false; - StringRef s = ot.getValue(); - return s.contains("Tile<") || s.contains("ConvTile<"); - }; - - Value src = peelAllCasts(peelUnrealized(adaptor.getSource())); - if (!isTileLike(src)) - return rewriter.notifyMatchFailure( - op, "get_validshape source must lower to a tile-like value"); - - auto resultTy = getTypeConverter()->convertType(rewriter.getIndexType()); - if (!resultTy) - return failure(); - - Value row = rewriter - .create( - op.getLoc(), resultTy, - "PTOAS__TILE_GET_VALID_ROW", ArrayAttr{}, - ArrayAttr{}, ValueRange{src}) - .getResult(0); - Value col = rewriter - .create( - op.getLoc(), resultTy, - "PTOAS__TILE_GET_VALID_COL", ArrayAttr{}, - ArrayAttr{}, ValueRange{src}) - .getResult(0); - rewriter.replaceOp(op, ValueRange{row, col}); - return success(); - } -}; - -struct PTOTAssignToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAssignOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto peelAllCasts = [](Value v) { - while (auto castOp = v.getDefiningOp()) - v = castOp.getOperand(0); - if (auto castOp = v.getDefiningOp()) - v = castOp.getOperand(); - return v; - }; - auto isTileLike = [](Value v) -> bool { - auto ot = dyn_cast(v.getType()); - if (!ot) - return false; - StringRef s = ot.getValue(); - return s.contains("Tile<") || s.contains("ConvTile<"); - }; - - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value tile = peelAllCasts(peelUnrealized(adaptor.getTile())); - if (!isTileLike(tile)) - return rewriter.notifyMatchFailure( - op, "tassign tile must lower to a tile-like value"); - - Value addr = peelUnrealized(adaptor.getAddr()); - auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); - if (isa(addr.getType()) || - (isa(addr.getType()) && - cast(addr.getType()).getValue().ends_with("*"))) { - auto rcU64 = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); - addr = rewriter - .create(loc, u64Ty, "reinterpret_cast", - ArrayAttr{}, rcU64, - ValueRange{addr}) - .getResult(0); - } else if (addr.getType() != u64Ty) { - addr = rewriter.create(loc, u64Ty, addr).getResult(); - } - - rewriter.create(loc, TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{tile, addr}); - rewriter.replaceOp(op, tile); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// pto.load_scalar / pto.store_scalar lowering -> ptr[offset] -//===----------------------------------------------------------------------===// - -static Type getPointerLikeElementType(Type type) { - if (auto ptrTy = dyn_cast(type)) - return ptrTy.getElementType(); - if (auto memTy = dyn_cast(type)) - return memTy.getElementType(); - return Type(); -} - -struct PTOPtrToIntToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::PtrToIntOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value ptr = peelUnrealized(adaptor.getPtr()); - Type dstTy = getTypeConverter()->convertType(op.getResult().getType()); - if (!dstTy) - return failure(); - - auto dstOpaque = dyn_cast(dstTy); - if (!dstOpaque) - return failure(); - - auto templateArgs = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(rewriter.getContext(), - dstOpaque.getValue())}); - auto cast = rewriter.create( - op.getLoc(), dstTy, "reinterpret_cast", ArrayAttr{}, templateArgs, - ValueRange{ptr}); - rewriter.replaceOp(op, cast.getResult(0)); - return success(); - } -}; - -struct PTOIntToPtrToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::IntToPtrOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value addr = peelUnrealized(adaptor.getAddr()); - Type dstTy = getTypeConverter()->convertType(op.getResult().getType()); - if (!dstTy) - return failure(); - - Type dstElemTy = getPointerLikeElementType(op.getResult().getType()); - if (!dstElemTy) - return failure(); - - std::string castType = - std::string("__gm__ ") + getEmitCScalarTypeToken(dstElemTy) + "*"; - auto templateArgs = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(rewriter.getContext(), - castType)}); - auto cast = rewriter.create( - op.getLoc(), dstTy, "reinterpret_cast", ArrayAttr{}, templateArgs, - ValueRange{addr}); - rewriter.replaceOp(op, cast.getResult(0)); - return success(); - } -}; - -struct PTOLoadScalarToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::LoadScalarOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value ptr = peelUnrealized(adaptor.getPtr()); - Value offset = peelUnrealized(adaptor.getOffset()); - - Type dstTy = getTypeConverter()->convertType(op.getValue().getType()); - if (!dstTy) - return failure(); - - auto call = rewriter.create( - op.getLoc(), TypeRange{dstTy}, "PTOAS__PTR_LOAD", - ArrayAttr{}, ArrayAttr{}, ValueRange{ptr, offset}); - - rewriter.replaceOp(op, call.getResults()); - return success(); - } -}; - -struct PTOStoreScalarToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::StoreScalarOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value ptr = peelUnrealized(adaptor.getPtr()); - Value offset = peelUnrealized(adaptor.getOffset()); - Value val = peelUnrealized(adaptor.getValue()); - - rewriter.create( - op.getLoc(), TypeRange{}, "PTOAS__PTR_STORE", - ArrayAttr{}, ArrayAttr{}, ValueRange{ptr, offset, val}); - rewriter.create( - op.getLoc(), TypeRange{}, "PTOAS__SCALAR_GM_STORE_FLUSH", - ArrayAttr{}, ArrayAttr{}, ValueRange{ptr}); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// pto.tabs lowering -> TABS(dst, src) -//===----------------------------------------------------------------------===// - -struct PTOTAbsToTABS : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAbsOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - // intrinsic: TABS(dst, src) - rewriter.create( - op.getLoc(), TypeRange{}, "TABS", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tadd lowering -> TADD(dst, src0, src1) -//===----------------------------------------------------------------------===// - -struct PTOTAddToTADD : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAddOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TADD", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src0, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOInitializeL2G2LPipeToEmitC - : public OpConversionPattern { - PTOInitializeL2G2LPipeToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - PTOArch targetArch) - : OpConversionPattern(typeConverter, ctx), - targetArch(targetArch) {} - - LogicalResult matchAndRewrite(mlir::pto::InitializeL2G2LPipeOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto tpipeTok = buildTPipeTokenFromInitOp(op.getOperation(), targetArch); - if (failed(tpipeTok)) - return rewriter.notifyMatchFailure(op, "failed to build TPipe token"); - - auto *ctx = rewriter.getContext(); - auto emitPipeTy = - cast(getTypeConverter()->convertType(op.getPipe().getType())); - - Value gmAddr = peelUnrealized(adaptor.getGmAddr()); - gmAddr = materializeTensorViewDataPointer( - rewriter, op.getLoc(), gmAddr, op.getGmAddr().getType()); - Value localAddr = - op.getLocalAddr() ? peelUnrealized(adaptor.getLocalAddr()) : Value(); - auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); - Value zero = makeEmitCIntConstant(rewriter, op.getLoc(), i32Ty, 0); - - Value c2vBuf = zero; - Value v2cBuf = zero; - if (op.getDirMask() == 1) - c2vBuf = localAddr ? localAddr : zero; - else if (op.getDirMask() == 2) - v2cBuf = localAddr ? localAddr : zero; - else if (op.getDirMask() == 3) { - if (localAddr) { - if (!op.getPeerLocalAddr()) - return rewriter.notifyMatchFailure( - op, "bidirectional l2g2l pipe requires peer local buffer"); - c2vBuf = localAddr; - v2cBuf = peelUnrealized(adaptor.getPeerLocalAddr()); - } - } else - return rewriter.notifyMatchFailure(op, "unsupported dir_mask"); - - rewriter.replaceOpWithNewOp( - op, TypeRange{emitPipeTy}, *tpipeTok, ArrayAttr{}, ArrayAttr{}, - ValueRange{gmAddr, c2vBuf, v2cBuf}); - return success(); - } - - PTOArch targetArch; -}; - -struct PTOInitializeL2LPipeToEmitC - : public OpConversionPattern { - PTOInitializeL2LPipeToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - PTOArch targetArch) - : OpConversionPattern(typeConverter, ctx), - targetArch(targetArch) {} - - LogicalResult matchAndRewrite(mlir::pto::InitializeL2LPipeOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto tpipeTok = buildTPipeTokenFromInitOp(op.getOperation(), targetArch); - if (failed(tpipeTok)) - return rewriter.notifyMatchFailure(op, "failed to build TPipe token"); - - auto *ctx = rewriter.getContext(); - auto emitPipeTy = - cast(getTypeConverter()->convertType(op.getPipe().getType())); - - auto gmPtrTy = - emitc::PointerType::get(emitc::OpaqueType::get(ctx, "__gm__ void")); - Value nullGm = - makeEmitCOpaqueConstant(rewriter, op.getLoc(), gmPtrTy, "nullptr"); - auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); - Value zero = makeEmitCIntConstant(rewriter, op.getLoc(), i32Ty, 0); - Value localAddr = peelUnrealized(adaptor.getLocalAddr()); - - Value c2vBuf = zero; - Value v2cBuf = zero; - if (op.getDirMask() == 1) - c2vBuf = localAddr; - else if (op.getDirMask() == 2) - v2cBuf = localAddr; - else if (op.getDirMask() == 3) { - c2vBuf = localAddr; - v2cBuf = peelUnrealized(adaptor.getPeerLocalAddr()); - } else - return rewriter.notifyMatchFailure(op, "unsupported dir_mask"); - - rewriter.replaceOpWithNewOp( - op, TypeRange{emitPipeTy}, *tpipeTok, ArrayAttr{}, ArrayAttr{}, - ValueRange{nullGm, c2vBuf, v2cBuf}); - return success(); - } - - PTOArch targetArch; -}; - -struct PTOBuildAsyncSessionToEmitC - : public OpConversionPattern { - PTOBuildAsyncSessionToEmitC(TypeConverter &typeConverter, MLIRContext *ctx) - : OpConversionPattern(typeConverter, ctx) {} - - LogicalResult matchAndRewrite(mlir::pto::BuildAsyncSessionOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto *ctx = rewriter.getContext(); - Location loc = op.getLoc(); - - auto sessionTy = - dyn_cast(getTypeConverter()->convertType(op.getSession().getType())); - if (!sessionTy) - return rewriter.notifyMatchFailure(op, "failed to convert async session type"); - - FailureOr scratchTile = - buildAsyncScratchTileValue(rewriter, loc, op.getScratch(), - adaptor.getScratch()); - if (failed(scratchTile)) - return rewriter.notifyMatchFailure(op, "failed to materialize async scratch tile"); - - Value workspace = - castToGMBytePointer(rewriter, loc, peelUnrealized(adaptor.getWorkspace())); - - Value session = rewriter - .create( - loc, sessionTy, emitc::OpaqueAttr::get(ctx, "")) - .getResult(); - - auto u32Ty = emitc::OpaqueType::get(ctx, "uint32_t"); - - auto makeU32Const = [&](uint64_t value) -> Value { - return makeEmitCOpaqueConstant(rewriter, loc, u32Ty, - std::to_string(value) + "u"); - }; - uint64_t syncId = op.getSyncIdAttr() ? op.getSyncIdAttr().getInt() : 0; - uint64_t blockBytes = - op.getBlockBytesAttr() ? op.getBlockBytesAttr().getInt() : 32 * 1024; - uint64_t commBlockOffset = - op.getCommBlockOffsetAttr() ? op.getCommBlockOffsetAttr().getInt() : 0; - uint64_t queueNum = op.getQueueNumAttr() ? op.getQueueNumAttr().getInt() : 1; - uint64_t channelGroupIdx = op.getChannelGroupIdxAttr() - ? op.getChannelGroupIdxAttr().getInt() - : UINT32_MAX; - - Value syncIdVal = makeU32Const(syncId); - Value channelGroupIdxVal = - channelGroupIdx == UINT32_MAX - ? makeEmitCOpaqueConstant(rewriter, loc, u32Ty, "UINT32_MAX") - : makeU32Const(channelGroupIdx); - - auto baseConfigTy = - emitc::OpaqueType::get(ctx, "pto::comm::sdma::SdmaBaseConfig"); - Value baseConfig = - rewriter - .create( - loc, baseConfigTy, - emitc::OpaqueAttr::get( - ctx, "{" + std::to_string(blockBytes) + "ULL, " + - std::to_string(commBlockOffset) + "ULL, " + - std::to_string(queueNum) + "u}")) - .getResult(); - - rewriter.create( - loc, TypeRange{}, "pto::comm::BuildAsyncSession", - ArrayAttr{}, ArrayAttr{}, - ValueRange{*scratchTile, workspace, session, syncIdVal, baseConfig, - channelGroupIdxVal}); - - rewriter.replaceOp(op, session); - return success(); - } -}; - -template -struct PTOAsyncTransferToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - explicit PTOAsyncTransferToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - StringRef callee) - : OpConversionPattern(typeConverter, ctx), callee(callee.str()) {} - - LogicalResult matchAndRewrite(AsyncOp op, typename AsyncOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value dst = peelUnrealized(adaptor.getDst()); - Value src = peelUnrealized(adaptor.getSrc()); - Value dstGT = dst; - Value srcGT = src; - if (!isEmitCGlobalTensorLikeType(dstGT.getType())) { - auto dstMrTy = dyn_cast(op.getDst().getType()); - if (!dstMrTy) - return rewriter.notifyMatchFailure(op, "expected dst to lower to GlobalTensor or memref"); - dstGT = buildGlobalTensorFromMemref(rewriter, op.getLoc(), dst, dstMrTy, - op.getDst().getDefiningOp() - ? op.getDst().getDefiningOp() - : op.getOperation()); - } - if (!isEmitCGlobalTensorLikeType(srcGT.getType())) { - auto srcMrTy = dyn_cast(op.getSrc().getType()); - if (!srcMrTy) - return rewriter.notifyMatchFailure(op, "expected src to lower to GlobalTensor or memref"); - srcGT = buildGlobalTensorFromMemref(rewriter, op.getLoc(), src, srcMrTy, - op.getSrc().getDefiningOp() - ? op.getSrc().getDefiningOp() - : op.getOperation()); - } - if (!dstGT || !srcGT) - return rewriter.notifyMatchFailure(op, "failed to build GlobalTensor operands"); - - Type eventTy = this->getTypeConverter()->convertType(op.getEvent().getType()); - if (!eventTy) - return rewriter.notifyMatchFailure(op, "failed to convert async event type"); - - rewriter.replaceOpWithNewOp( - op, TypeRange{eventTy}, callee, ArrayAttr{}, ArrayAttr{}, - ValueRange{dstGT, srcGT, peelUnrealized(adaptor.getSession())}); - return success(); - } - - std::string callee; -}; - -template -struct PTOAsyncEventToEmitC : public OpConversionPattern { - explicit PTOAsyncEventToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - StringRef callee) - : OpConversionPattern(typeConverter, ctx), - callee(callee.str()) {} - - LogicalResult matchAndRewrite(AsyncEventOp op, - typename AsyncEventOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type resultTy = - this->getTypeConverter()->convertType(op.getCompleted().getType()); - if (!resultTy) - return rewriter.notifyMatchFailure(op, "failed to convert async event result type"); - - rewriter.replaceOpWithNewOp( - op, TypeRange{resultTy}, callee, ArrayAttr{}, ArrayAttr{}, - ValueRange{peelUnrealized(adaptor.getEvent()), - peelUnrealized(adaptor.getSession())}); - return success(); - } - - std::string callee; -}; - -static FailureOr buildCommGlobalTensorValue( - ConversionPatternRewriter &rewriter, Location loc, Value originalValue, - Value emittedValue, Operation *anchor) { - Value value = peelUnrealized(emittedValue); - if (isEmitCGlobalTensorLikeType(value.getType())) - return value; - - auto memTy = dyn_cast(originalValue.getType()); - if (!memTy) - return failure(); - - Value gt = buildGlobalTensorFromMemref(rewriter, loc, value, memTy, anchor); - if (!gt) - return failure(); - return gt; -} - -static FailureOr buildCommTileValue(ConversionPatternRewriter &rewriter, - Location loc, Value originalValue, - Value emittedValue) { - Value value = peelUnrealized(emittedValue); - if (auto opaqueTy = dyn_cast(value.getType())) { - StringRef typeStr = opaqueTy.getValue(); - if (typeStr.contains("Tile<") || typeStr.contains("ConvTile<")) - return value; - } - return buildAsyncScratchTileValue(rewriter, loc, originalValue, emittedValue); -} - -static FailureOr buildCollectiveParallelGroup( - ConversionPatternRewriter &rewriter, Location loc, - ArrayRef groupGTs, int64_t root) { - if (groupGTs.empty()) - return failure(); - - auto firstTy = dyn_cast(groupGTs.front().getType()); - if (!firstTy) - return failure(); - - auto *ctx = rewriter.getContext(); - auto arrayTy = emitc::ArrayType::get({static_cast(groupGTs.size())}, - firstTy); - auto groupArray = cast>( - rewriter - .create(loc, arrayTy, - emitc::OpaqueAttr::get(ctx, "{}")) - .getResult()); - - auto indexTy = emitc::OpaqueType::get(ctx, "int"); - for (auto [idx, groupVal] : llvm::enumerate(groupGTs)) { - Value idxVal = - makeEmitCIntConstant(rewriter, loc, indexTy, static_cast(idx)); - Value slot = - rewriter.create(loc, groupArray, ValueRange{idxVal}) - .getResult(); - rewriter.create(loc, slot, groupVal); - } - - std::string pgTypeStr = - (Twine("pto::comm::ParallelGroup<") + firstTy.getValue() + ">").str(); - auto pgTy = emitc::OpaqueType::get(ctx, pgTypeStr); - Value sizeVal = makeEmitCIntConstant(rewriter, loc, indexTy, - static_cast(groupGTs.size())); - Value rootVal = makeEmitCIntConstant(rewriter, loc, indexTy, root); - return rewriter - .create( - loc, TypeRange{pgTy}, (Twine(pgTypeStr) + "::Create").str(), - ArrayAttr{}, ArrayAttr{}, ValueRange{groupArray, sizeVal, rootVal}) - .getResult(0); -} - -static std::string notifyOpTok(pto::NotifyOp op) { - switch (op) { - case pto::NotifyOp::AtomicAdd: - return "pto::comm::NotifyOp::AtomicAdd"; - case pto::NotifyOp::Set: - return "pto::comm::NotifyOp::Set"; - } - return "pto::comm::NotifyOp::Set"; -} - -static std::string waitCmpTok(pto::WaitCmp cmp) { - switch (cmp) { - case pto::WaitCmp::EQ: - return "pto::comm::WaitCmp::EQ"; - case pto::WaitCmp::NE: - return "pto::comm::WaitCmp::NE"; - case pto::WaitCmp::GT: - return "pto::comm::WaitCmp::GT"; - case pto::WaitCmp::GE: - return "pto::comm::WaitCmp::GE"; - case pto::WaitCmp::LT: - return "pto::comm::WaitCmp::LT"; - case pto::WaitCmp::LE: - return "pto::comm::WaitCmp::LE"; - } - return "pto::comm::WaitCmp::EQ"; -} - -static std::string reduceOpTok(pto::ReduceOp op) { - switch (op) { - case pto::ReduceOp::Sum: - return "pto::comm::ReduceOp::Sum"; - case pto::ReduceOp::Max: - return "pto::comm::ReduceOp::Max"; - case pto::ReduceOp::Min: - return "pto::comm::ReduceOp::Min"; - } - return "pto::comm::ReduceOp::Sum"; -} - -template -static FailureOr> buildCommGroupGlobalTensors( - ConversionPatternRewriter &rewriter, Location loc, OpTy op, - ValueRange originalGroup, ValueRange emittedGroup) { - SmallVector groupGTs; - groupGTs.reserve(originalGroup.size()); - for (auto [orig, emitted] : llvm::zip(originalGroup, emittedGroup)) { - FailureOr gt = - buildCommGlobalTensorValue(rewriter, loc, orig, emitted, op.getOperation()); - if (failed(gt)) - return failure(); - groupGTs.push_back(*gt); - } - return groupGTs; -} - -template -struct PTOCommCollectiveToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - explicit PTOCommCollectiveToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - StringRef apiName) - : OpConversionPattern(typeConverter, ctx), - apiName(apiName.str()) {} - - LogicalResult matchAndRewrite(CollectiveOp op, typename CollectiveOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - - auto buildPong = [&](Value original, Value emitted, StringRef name) -> FailureOr { - if (!original) - return failure(); - return buildCommTileValue(rewriter, loc, original, emitted); - }; - - if constexpr (std::is_same_v) { - FailureOr srcGT = - buildCommGlobalTensorValue(rewriter, loc, op.getSrc(), adaptor.getSrc(), - op.getOperation()); - FailureOr pingTile = - buildCommTileValue(rewriter, loc, op.getPing(), adaptor.getPing()); - auto groupGTs = - buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); - if (failed(srcGT) || failed(pingTile) || failed(groupGTs)) - return rewriter.notifyMatchFailure(op, "failed to materialize broadcast operands"); - FailureOr pg = buildCollectiveParallelGroup(rewriter, loc, *groupGTs, op.getRoot()); - if (failed(pg)) - return rewriter.notifyMatchFailure(op, "failed to materialize broadcast group"); - if (op.getPong()) { - FailureOr pongTile = - buildPong(op.getPong(), adaptor.getPong(), "__pong"); - if (failed(pongTile)) - return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); - rewriter.create( - loc, TypeRange{}, "pto::comm::TBROADCAST", ArrayAttr{}, ArrayAttr{}, - ValueRange{*pg, *srcGT, *pingTile, *pongTile}); - } else { - rewriter.create( - loc, TypeRange{}, "pto::comm::TBROADCAST", ArrayAttr{}, ArrayAttr{}, - ValueRange{*pg, *srcGT, *pingTile}); - } - } else if constexpr (std::is_same_v) { - FailureOr dstGT = - buildCommGlobalTensorValue(rewriter, loc, op.getDst(), adaptor.getDst(), - op.getOperation()); - FailureOr pingTile = - buildCommTileValue(rewriter, loc, op.getPing(), adaptor.getPing()); - auto groupGTs = - buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); - if (failed(dstGT) || failed(pingTile) || failed(groupGTs)) - return rewriter.notifyMatchFailure(op, "failed to materialize gather operands"); - FailureOr pg = buildCollectiveParallelGroup(rewriter, loc, *groupGTs, op.getRoot()); - if (failed(pg)) - return rewriter.notifyMatchFailure(op, "failed to materialize gather group"); - if (op.getPong()) { - FailureOr pongTile = - buildPong(op.getPong(), adaptor.getPong(), "__pong"); - if (failed(pongTile)) - return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); - rewriter.create( - loc, TypeRange{}, "pto::comm::TGATHER", ArrayAttr{}, ArrayAttr{}, - ValueRange{*pg, *dstGT, *pingTile, *pongTile}); - } else { - rewriter.create( - loc, TypeRange{}, "pto::comm::TGATHER", ArrayAttr{}, ArrayAttr{}, - ValueRange{*pg, *dstGT, *pingTile}); - } - } else if constexpr (std::is_same_v) { - FailureOr srcGT = - buildCommGlobalTensorValue(rewriter, loc, op.getSrc(), adaptor.getSrc(), - op.getOperation()); - FailureOr pingTile = - buildCommTileValue(rewriter, loc, op.getPing(), adaptor.getPing()); - auto groupGTs = - buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); - if (failed(srcGT) || failed(pingTile) || failed(groupGTs)) - return rewriter.notifyMatchFailure(op, "failed to materialize scatter operands"); - FailureOr pg = buildCollectiveParallelGroup(rewriter, loc, *groupGTs, op.getRoot()); - if (failed(pg)) - return rewriter.notifyMatchFailure(op, "failed to materialize scatter group"); - if (op.getPong()) { - FailureOr pongTile = - buildPong(op.getPong(), adaptor.getPong(), "__pong"); - if (failed(pongTile)) - return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); - rewriter.create( - loc, TypeRange{}, "pto::comm::TSCATTER", ArrayAttr{}, ArrayAttr{}, - ValueRange{*pg, *srcGT, *pingTile, *pongTile}); - } else { - rewriter.create( - loc, TypeRange{}, "pto::comm::TSCATTER", ArrayAttr{}, ArrayAttr{}, - ValueRange{*pg, *srcGT, *pingTile}); - } - } else { - FailureOr dstGT = - buildCommGlobalTensorValue(rewriter, loc, op.getDst(), adaptor.getDst(), - op.getOperation()); - FailureOr accTile = - buildCommTileValue(rewriter, loc, op.getAcc(), adaptor.getAcc()); - FailureOr recvPing = - buildCommTileValue(rewriter, loc, op.getRecvPing(), adaptor.getRecvPing()); - auto groupGTs = - buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); - if (failed(dstGT) || failed(accTile) || failed(recvPing) || failed(groupGTs)) - return rewriter.notifyMatchFailure(op, "failed to materialize reduce operands"); - FailureOr pg = buildCollectiveParallelGroup(rewriter, loc, *groupGTs, op.getRoot()); - if (failed(pg)) - return rewriter.notifyMatchFailure(op, "failed to materialize reduce group"); - if (op.getRecvPong()) { - FailureOr recvPong = - buildPong(op.getRecvPong(), adaptor.getRecvPong(), "__recv_pong"); - if (failed(recvPong)) - return rewriter.notifyMatchFailure(op, "failed to materialize recv_pong"); - auto reduceTy = - emitc::OpaqueType::get(rewriter.getContext(), "pto::comm::ReduceOp"); - Value reduceOp = makeEmitCOpaqueConstant(rewriter, loc, reduceTy, - reduceOpTok(op.getReduceOp())); - rewriter.create( - loc, TypeRange{}, "pto::comm::TREDUCE", ArrayAttr{}, ArrayAttr{}, - ValueRange{*pg, *dstGT, *accTile, *recvPing, *recvPong, reduceOp}); - } else { - auto reduceTy = - emitc::OpaqueType::get(rewriter.getContext(), "pto::comm::ReduceOp"); - Value reduceOp = makeEmitCOpaqueConstant(rewriter, loc, reduceTy, - reduceOpTok(op.getReduceOp())); - rewriter.create( - loc, TypeRange{}, "pto::comm::TREDUCE", ArrayAttr{}, ArrayAttr{}, - ValueRange{*pg, *dstGT, *accTile, *recvPing, reduceOp}); - } - } - rewriter.eraseOp(op); - return success(); - } - - std::string apiName; -}; - -template -struct PTOP2PCommToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - explicit PTOP2PCommToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - StringRef callee) - : OpConversionPattern(typeConverter, ctx), callee(callee.str()) {} - - LogicalResult matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - FailureOr dstGT = - buildCommGlobalTensorValue(rewriter, op.getLoc(), op.getDst(), adaptor.getDst(), - op.getOperation()); - FailureOr srcGT = - buildCommGlobalTensorValue(rewriter, op.getLoc(), op.getSrc(), adaptor.getSrc(), - op.getOperation()); - FailureOr pingTile = - buildCommTileValue(rewriter, op.getLoc(), op.getPing(), adaptor.getPing()); - if (failed(dstGT) || failed(srcGT) || failed(pingTile)) - return rewriter.notifyMatchFailure(op, "failed to materialize p2p operands"); - - SmallVector operands{*dstGT, *srcGT, *pingTile}; - std::string actualCallee = callee; - if constexpr (std::is_same_v) { - if (op.getAtomicType() == pto::AtomicType::AtomicAdd) - actualCallee = "pto::comm::TPUT"; - } - if (op.getPong()) { - FailureOr pongTile = - buildCommTileValue(rewriter, op.getLoc(), op.getPong(), adaptor.getPong()); - if (failed(pongTile)) - return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); - operands.push_back(*pongTile); - } - - rewriter.create(op.getLoc(), TypeRange{}, actualCallee, - ArrayAttr{}, ArrayAttr{}, operands); - rewriter.eraseOp(op); - return success(); - } - - std::string callee; -}; - -template -struct PTOSignalCommToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - explicit PTOSignalCommToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - StringRef callee) - : OpConversionPattern(typeConverter, ctx), - callee(callee.str()) {} - - LogicalResult matchAndRewrite(SignalOp op, typename SignalOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - FailureOr signalGT = buildCommGlobalTensorValue( - rewriter, op.getLoc(), op.getSignal(), adaptor.getSignal(), op.getOperation()); - if (failed(signalGT)) - return rewriter.notifyMatchFailure(op, "failed to materialize signal operand"); - - if constexpr (std::is_same_v) { - auto notifyTy = - emitc::OpaqueType::get(rewriter.getContext(), "pto::comm::NotifyOp"); - Value notifyOp = makeEmitCOpaqueConstant( - rewriter, op.getLoc(), notifyTy, notifyOpTok(op.getNotifyOp())); - SmallVector operands{*signalGT, peelUnrealized(adaptor.getValue()), - notifyOp}; - rewriter.create(op.getLoc(), TypeRange{}, callee, - ArrayAttr{}, ArrayAttr{}, operands); - rewriter.eraseOp(op); - } else { - auto waitCmpTy = - emitc::OpaqueType::get(rewriter.getContext(), "pto::comm::WaitCmp"); - Value waitCmp = makeEmitCOpaqueConstant( - rewriter, op.getLoc(), waitCmpTy, waitCmpTok(op.getCmp())); - SmallVector operands{*signalGT, peelUnrealized(adaptor.getCmpValue()), - waitCmp}; - if constexpr (std::is_same_v) { - Type resultTy = this->getTypeConverter()->convertType(op.getResult().getType()); - if (!resultTy) - return rewriter.notifyMatchFailure(op, "failed to convert ttest result type"); - rewriter.replaceOpWithNewOp( - op, TypeRange{resultTy}, callee, ArrayAttr{}, ArrayAttr{}, operands); - } else { - rewriter.create(op.getLoc(), TypeRange{}, callee, - ArrayAttr{}, ArrayAttr{}, operands); - rewriter.eraseOp(op); - } - } - return success(); - } - - std::string callee; -}; - -struct PTODeclareTileMemRefToEmitC - : public OpConversionPattern { - using OpConversionPattern< - mlir::pto::DeclareTileMemRefOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::DeclareTileMemRefOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - (void)adaptor; - Type convertedType = getTypeConverter()->convertType(op.getResult().getType()); - if (!convertedType) - return rewriter.notifyMatchFailure( - op, "failed to convert declare_tile_memref result type"); - rewriter.replaceOp(op, makeEmitCOpaqueConstant(rewriter, op.getLoc(), - convertedType, "nullptr")); - return success(); - } -}; - -struct PTODeclareGlobalToEmitC - : public OpConversionPattern { - using OpConversionPattern< - mlir::pto::DeclareGlobalOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::DeclareGlobalOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - (void)adaptor; - Type convertedType = getTypeConverter()->convertType(op.getEntry().getType()); - if (!convertedType) - return rewriter.notifyMatchFailure( - op, "failed to convert declare_global result type"); - if (auto tvTy = dyn_cast(op.getEntry().getType())) { - if (auto stridesAttr = - op->getAttrOfType(kGlobalTensorStridesAttrName)) { - auto strides = stridesAttr.asArrayRef(); - if (strides.size() == static_cast(tvTy.getRank())) { - convertedType = emitc::OpaqueType::get( - rewriter.getContext(), - getGlobalTensorTypeStringFromShapeAndStrides( - tvTy.getElementType(), tvTy.getShape(), strides)); - } - } - } - auto var = rewriter.create( - op.getLoc(), convertedType, - emitc::OpaqueAttr::get(rewriter.getContext(), "")); - rewriter.replaceOp(op, var.getResult()); - return success(); - } -}; - -struct PTODeclareEventIdArrayToEmitC - : public OpConversionPattern { - using OpConversionPattern< - mlir::pto::DeclareEventIdArrayOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::DeclareEventIdArrayOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - (void)adaptor; - Type arrayTy = getTypeConverter()->convertType(op.getArray().getType()); - if (!arrayTy) - return rewriter.notifyMatchFailure(op, - "failed to map declared eventid_array type"); - - auto array = rewriter - .create( - op.getLoc(), arrayTy, - emitc::OpaqueAttr::get(rewriter.getContext(), "")) - .getResult(); - rewriter.replaceOp(op, array); - return success(); - } -}; - -struct PTOEventIdArrayGetToEmitC - : public OpConversionPattern { - using OpConversionPattern< - mlir::pto::EventIdArrayGetOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::EventIdArrayGetOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value array = peelUnrealized(adaptor.getArray()); - Value index = peelUnrealized(adaptor.getIndex()); - - Type resultTy = getTypeConverter()->convertType(op.getResult().getType()); - if (!resultTy) - return rewriter.notifyMatchFailure(op, - "failed to map eventid_array get result type"); - - auto load = - rewriter.create(op.getLoc(), resultTy, array, index); - rewriter.replaceOp(op, load.getResult()); - return success(); - } -}; - -struct PTOEventIdArraySetToEmitC - : public OpConversionPattern { - using OpConversionPattern< - mlir::pto::EventIdArraySetOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::EventIdArraySetOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value array = peelUnrealized(adaptor.getArray()); - Value index = peelUnrealized(adaptor.getIndex()); - Value value = peelUnrealized(adaptor.getValue()); - - rewriter.create( - op.getLoc(), TypeRange{}, "PTOAS__EVENTID_ARRAY_STORE", - ArrayAttr{}, ArrayAttr{}, ValueRange{array, index, value}); - rewriter.eraseOp(op); - return success(); - } -}; - -// pto.declare_local_array -> emitc.variable of !emitc.array<...>. -// Renders as `T a[D1][D2]...;` in the emitted C++. -struct PTODeclareLocalArrayToEmitC - : public OpConversionPattern { - using OpConversionPattern< - mlir::pto::DeclareLocalArrayOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::DeclareLocalArrayOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - (void)adaptor; - Type arrayTy = getTypeConverter()->convertType(op.getArray().getType()); - if (!arrayTy) - return rewriter.notifyMatchFailure(op, - "failed to map !pto.local_array type"); - - auto var = rewriter - .create( - op.getLoc(), arrayTy, - emitc::OpaqueAttr::get(rewriter.getContext(), "")) - .getResult(); - rewriter.replaceOp(op, var); - return success(); - } -}; - -// pto.local_array_get %a[%i0, %i1, ...] -> rvalue. -// Lowers to a single emitc.subscript with the full index pack; the C++ emitter -// prints it as `a[i0][i1]...`. The adaptor already exposes target-typed values -// (the type converter has remapped !pto.local_array -> !emitc.array and -// index/integer indices), so they're forwarded directly to the builder. -struct PTOLocalArrayGetToEmitC - : public OpConversionPattern { - using OpConversionPattern< - mlir::pto::LocalArrayGetOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::LocalArrayGetOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type resultTy = - getTypeConverter()->convertType(op.getResult().getType()); - if (!resultTy) - return rewriter.notifyMatchFailure( - op, "failed to map local_array element type"); - - auto sub = rewriter.create( - op.getLoc(), resultTy, adaptor.getArray(), adaptor.getIndices()); - rewriter.replaceOp(op, sub.getResult()); - return success(); - } -}; - -// pto.local_array_set %a[%i0, %i1, ...], %v -> emitc.assign to subscript slot. -// The C++ emitter prints this as `a[i0][i1]... = v;`. As above, adaptor values -// are already target-typed; pass them through directly. -struct PTOLocalArraySetToEmitC - : public OpConversionPattern { - using OpConversionPattern< - mlir::pto::LocalArraySetOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::LocalArraySetOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value value = adaptor.getValue(); - Type elemTy = value.getType(); - - Value slot = rewriter - .create(op.getLoc(), elemTy, - adaptor.getArray(), - adaptor.getIndices()) - .getResult(); - rewriter.create(op.getLoc(), slot, value); - rewriter.eraseOp(op); - return success(); - } -}; - -static std::optional getStaticIndexLikeValue(Value value) { - if (!value) - return std::nullopt; - if (auto cst = value.getDefiningOp()) - return cst.value(); - if (auto cst = value.getDefiningOp()) - return cst.value(); - if (auto cst = value.getDefiningOp()) { - if (auto intAttr = dyn_cast(cst.getValue())) - return intAttr.getInt(); - } - return std::nullopt; -} - -static FailureOr buildGlobalTensorViewFromPointer( - ConversionPatternRewriter &rewriter, Location loc, Value ptr, Type elemTy, - ArrayRef shape, ArrayRef strides = {}, - StringRef layoutEnum = "pto::Layout::ND") { - if (llvm::any_of(shape, [](int64_t dim) { - return dim == ShapedType::kDynamic; - })) - return failure(); - - auto *ctx = rewriter.getContext(); - SmallVector rowMajorStrides; - ArrayRef effectiveStrides = strides; - if (effectiveStrides.empty()) { - rowMajorStrides = buildRowMajorStrides(shape); - effectiveStrides = rowMajorStrides; - } - SmallVector shape5D; - SmallVector stride5D; - buildGlobalTensorShapeAndStride(shape, effectiveStrides, shape5D, stride5D); - - std::string shapeType = "pto::Shape<" + joinIntTemplateParams(shape5D) + ">"; - std::string strideType = - "pto::Stride<" + joinIntTemplateParams(stride5D) + ">"; - auto shapeVal = rewriter - .create( - loc, emitc::OpaqueType::get(ctx, shapeType), - shapeType, ArrayAttr{}, ArrayAttr{}, ValueRange{}) - .getResult(0); - auto strideVal = rewriter - .create( - loc, emitc::OpaqueType::get(ctx, strideType), - strideType, ArrayAttr{}, ArrayAttr{}, ValueRange{}) - .getResult(0); - - std::string gtTypeStr = - getGlobalTensorTypeStringFromShapeAndStrides(elemTy, shape, - effectiveStrides, - layoutEnum); - auto gtType = emitc::OpaqueType::get(ctx, gtTypeStr); - auto gt = rewriter.create( - loc, gtType, gtTypeStr, ArrayAttr{}, ArrayAttr{}, - ValueRange{ptr, shapeVal, strideVal}); - return gt.getResult(0); -} - -static bool parseIntegerTemplateList(StringRef token, StringRef marker, - SmallVectorImpl &values) { - size_t pos = token.find(marker); - if (pos == StringRef::npos) - return false; - pos += marker.size(); - size_t end = token.find('>', pos); - if (end == StringRef::npos) - return false; - - SmallVector parts; - token.slice(pos, end).split(parts, ','); - values.clear(); - for (StringRef part : parts) { - int64_t value = 0; - if (part.trim().getAsInteger(10, value)) - return false; - values.push_back(value); - } - return true; -} - -static LogicalResult getStaticTensorViewStrides( - Value source, Value convertedSource, pto::TensorViewType sourceType, - SmallVectorImpl &strides) { - int64_t rank = sourceType.getRank(); - strides.clear(); - - if (auto makeView = source.getDefiningOp()) { - if ((int64_t)makeView.getStrides().size() != rank) - return failure(); - for (Value strideValue : makeView.getStrides()) { - auto cst = getStaticIndexLikeValue(strideValue); - if (!cst) - return failure(); - strides.push_back(*cst); - } - return success(); - } - - Value src = peelUnrealized(convertedSource); - if (auto opaqueTy = dyn_cast(src.getType())) { - SmallVector stride5D; - StringRef token = opaqueTy.getValue(); - if ((parseIntegerTemplateList(token, "pto::Stride<", stride5D) || - parseIntegerTemplateList(token, "Stride<", stride5D)) && - (int64_t)stride5D.size() >= rank) { - strides.append(stride5D.end() - rank, stride5D.end()); - return success(); - } - } - - auto fallback = buildRowMajorStrides(sourceType.getShape()); - strides.append(fallback.begin(), fallback.end()); - return success(); -} - -struct PTOPartitionViewToEmitC - : public OpConversionPattern { - using OpConversionPattern< - mlir::pto::PartitionViewOp>::OpConversionPattern; - - LogicalResult matchAndRewrite(mlir::pto::PartitionViewOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto srcTy = dyn_cast(op.getSource().getType()); - auto resTy = dyn_cast(op.getResult().getType()); - if (!srcTy || !resTy) - return rewriter.notifyMatchFailure( - op, "expected tensor_view source and partition_tensor_view result"); - - if (op.getOffsets().size() != static_cast(srcTy.getRank()) || - op.getSizes().size() != static_cast(srcTy.getRank())) - return rewriter.notifyMatchFailure(op, "rank mismatch"); - - for (auto [idx, value] : llvm::enumerate(op.getSizes())) { - auto cst = getStaticIndexLikeValue(value); - if (!cst) - return rewriter.notifyMatchFailure( - op, "globaltensor partition_view requires static sizes"); - int64_t resultDim = resTy.getShape()[idx]; - if (resultDim != ShapedType::kDynamic && resultDim != *cst) - return rewriter.notifyMatchFailure( - op, "partition_view static size does not match result type"); - } - - SmallVector srcStrides; - if (failed(getStaticTensorViewStrides(op.getSource(), adaptor.getSource(), - srcTy, srcStrides))) - return rewriter.notifyMatchFailure( - op, "partition_view requires static source strides"); - int64_t staticLinearOffset = 0; - SmallVector> dynamicOffsetTerms; - for (auto [idx, values] : - llvm::enumerate(llvm::zip(op.getOffsets(), adaptor.getOffsets()))) { - Value originalOffset = std::get<0>(values); - Value convertedOffset = std::get<1>(values); - int64_t stride = srcStrides[idx]; - if (stride == ShapedType::kDynamic) - return rewriter.notifyMatchFailure( - op, "dynamic source stride is not supported"); - - if (auto cst = getStaticIndexLikeValue(originalOffset)) { - if (*cst != 0) - staticLinearOffset += (*cst) * stride; - continue; - } - dynamicOffsetTerms.push_back({convertedOffset, stride}); - } - - auto *ctx = rewriter.getContext(); - std::string elemTypeStr = getElemTypeStringForGT(srcTy.getElementType()); - auto ptrTy = emitc::PointerType::get( - emitc::OpaqueType::get(ctx, "__gm__ " + elemTypeStr)); - Value src = peelUnrealized(adaptor.getSource()); - auto data = rewriter - .create( - op.getLoc(), ptrTy, "PTOAS__GLOBAL_TENSOR_DATA", - ArrayAttr{}, ArrayAttr{}, ValueRange{src}) - .getResult(0); - Value ptr = data; - if (!dynamicOffsetTerms.empty()) { - Type u32Ty = emitc::OpaqueType::get(ctx, "unsigned"); - auto makeU32 = [&](int64_t value) { - return makeEmitCIntConstant(rewriter, op.getLoc(), u32Ty, value); - }; - auto asU32 = [&](Value value) -> Value { - if (value.getType() == u32Ty) - return value; - return rewriter.create(op.getLoc(), u32Ty, value) - .getResult(); - }; - - Value totalOffset = makeU32(staticLinearOffset); - for (auto [offsetValue, stride] : dynamicOffsetTerms) { - Value term = asU32(offsetValue); - if (stride != 1) { - Value strideValue = makeU32(stride); - term = rewriter - .create(op.getLoc(), u32Ty, term, - strideValue) - .getResult(); - } - totalOffset = rewriter - .create(op.getLoc(), u32Ty, - totalOffset, term) - .getResult(); - } - ptr = rewriter - .create(op.getLoc(), data.getType(), data, - totalOffset) - .getResult(); - } else { - ptr = applyStaticMemrefOffset(rewriter, op.getLoc(), data, - staticLinearOffset); - } - - auto resultOr = buildGlobalTensorViewFromPointer( - rewriter, op.getLoc(), ptr, resTy.getElementType(), resTy.getShape(), - srcStrides); - if (failed(resultOr)) - return rewriter.notifyMatchFailure( - op, "failed to materialize partition GlobalTensor"); - - rewriter.replaceOp(op, *resultOr); - return success(); - } -}; - -static FailureOr getPipeDataTypeToken(Value value) { - auto opaqueTy = dyn_cast(value.getType()); - if (!opaqueTy) - return failure(); - StringRef token = opaqueTy.getValue(); - if (!token.contains("Tile<") && !token.contains("GlobalTensor<")) - return failure(); - return token.str(); -} - -struct PTOTAllocToEmitC : public OpConversionPattern { - PTOTAllocToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - PTOArch targetArch) - : OpConversionPattern(typeConverter, ctx), - targetArch(targetArch) {} - - LogicalResult matchAndRewrite(mlir::pto::TAllocOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto pipeTok = getTPipeTokenFromValue(op.getPipeHandle(), targetArch); - if (failed(pipeTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve pipe token"); - Value entry = peelUnrealized(adaptor.getEntry()); - auto entryTok = getPipeDataTypeToken(entry); - if (failed(entryTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve entry token"); - auto splitTok = getTileSplitToken(op.getSplit()); - if (failed(splitTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve split token"); - - std::string callee = - "TALLOC<" + *pipeTok + ", " + *entryTok + ", " + *splitTok + ">"; - rewriter.replaceOpWithNewOp( - op, TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, - ValueRange{peelUnrealized(adaptor.getPipeHandle()), entry}); - return success(); - } - - PTOArch targetArch; -}; - -struct PTOTPushToEmitC : public OpConversionPattern { - PTOTPushToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - PTOArch targetArch) - : OpConversionPattern(typeConverter, ctx), - targetArch(targetArch) {} - - LogicalResult matchAndRewrite(mlir::pto::TPushOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto pipeTok = getTPipeTokenFromValue(op.getPipeHandle(), targetArch); - if (failed(pipeTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve pipe token"); - // Read the tile type token from the already-converted OpaqueType, which - // preserves the exact layout produced by BindTileOp / PointerCastOp EmitC. - Value convertedTile = peelUnrealized(adaptor.getTile()); - auto tileTok = getPipeDataTypeToken(convertedTile); - if (failed(tileTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve tile token"); - auto splitTok = getTileSplitToken(op.getSplit()); - if (failed(splitTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve split token"); - - std::string callee = - "TPUSH<" + *pipeTok + ", " + *tileTok + ", " + *splitTok + ">"; - rewriter.replaceOpWithNewOp( - op, TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, - ValueRange{peelUnrealized(adaptor.getPipeHandle()), convertedTile}); - return success(); - } - - PTOArch targetArch; -}; - -struct PTOTPopToEmitC : public OpConversionPattern { - PTOTPopToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - PTOArch targetArch) - : OpConversionPattern(typeConverter, ctx), - targetArch(targetArch) {} - - LogicalResult matchAndRewrite(mlir::pto::TPopOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto pipeTok = getTPipeTokenFromValue(op.getPipeHandle(), targetArch); - if (failed(pipeTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve pipe token"); - Value convertedTile = peelUnrealized(adaptor.getTile()); - auto tileTok = getPipeDataTypeToken(convertedTile); - if (failed(tileTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve tile token"); - auto splitTok = getTileSplitToken(op.getSplit()); - if (failed(splitTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve split token"); - - std::string callee = - "TPOP<" + *pipeTok + ", " + *tileTok + ", " + *splitTok + ">"; - rewriter.replaceOpWithNewOp( - op, TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, - ValueRange{peelUnrealized(adaptor.getPipeHandle()), convertedTile}); - return success(); - } - - PTOArch targetArch; -}; - -struct PTOTFreeToEmitC : public OpConversionPattern { - PTOTFreeToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, - PTOArch targetArch) - : OpConversionPattern(typeConverter, ctx), - targetArch(targetArch) {} - - LogicalResult matchAndRewrite(mlir::pto::TFreeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto pipeTok = getTPipeTokenFromValue(op.getPipeHandle(), targetArch); - if (failed(pipeTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve pipe token"); - auto splitTok = getTileSplitToken(op.getSplit()); - if (failed(splitTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve split token"); - - SmallVector operands{peelUnrealized(adaptor.getPipeHandle())}; - std::string callee; - if (op.getEntry()) { - Value entry = peelUnrealized(adaptor.getEntry()); - auto entryTok = getPipeDataTypeToken(entry); - if (failed(entryTok)) - return rewriter.notifyMatchFailure(op, "failed to resolve entry token"); - callee = "TFREE<" + *pipeTok + ", " + *entryTok + ", " + *splitTok + ">"; - operands.push_back(entry); - } else { - callee = "TFREE<" + *pipeTok + ", " + *splitTok + ">"; - } - rewriter.replaceOpWithNewOp( - op, TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, operands); - return success(); - } - - PTOArch targetArch; -}; - -//===----------------------------------------------------------------------===// -// populate patterns -//===----------------------------------------------------------------------=== -struct ReinterpretCastToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - auto resMrTy = dyn_cast(op.getType()); - if (!resMrTy) - return failure(); - - auto asAttr = dyn_cast_or_null(resMrTy.getMemorySpace()); - const bool isGm = (!asAttr || asAttr.getAddressSpace() == pto::AddressSpace::GM); - - bool emitAddPtrTrace = op->hasAttr("pto.addptr_trace"); - Value source = peelUnrealized(adaptor.getSource()); - auto offsets = adaptor.getOffsets(); - Value offsetVal = offsets.empty() ? Value() : offsets[0]; - - // GM: keep pointer arithmetic. - if (isGm) { - if (!offsetVal) { - rewriter.replaceOp(op, source); - return success(); - } - - Type resultType = getTypeConverter()->convertType(op.getType()); - if (!resultType) - return failure(); - - auto addOp = rewriter.create(loc, resultType, source, offsetVal); - if (emitAddPtrTrace) { - rewriter.setInsertionPointAfter(addOp); - rewriter.create( - loc, TypeRange{}, "PTOAS__ADDPTR_TRACE", - ArrayAttr{}, ArrayAttr{}, - ValueRange{addOp.getResult(), source, offsetVal}); - } - rewriter.replaceOp(op, addOp.getResult()); - return success(); - } - - // UB/L1/L0 tiles: materialize a new Tile view by assigning an adjusted - // underlying pointer (in elements). - pto::AddressSpace as = asAttr.getAddressSpace(); - - // Element type token. - Type elemTy = resMrTy.getElementType(); - std::string elemTok = getEmitCScalarTypeToken(elemTy); - int64_t elemBytes = getEmitCScalarByteWidth(elemTy); - - // Tile role. - const char *roleTok = "TileType::Vec"; - switch (as) { - case pto::AddressSpace::VEC: - roleTok = "TileType::Vec"; - break; - case pto::AddressSpace::MAT: - roleTok = "TileType::Mat"; - break; - case pto::AddressSpace::LEFT: - roleTok = "TileType::Left"; - break; - case pto::AddressSpace::RIGHT: - roleTok = "TileType::Right"; - break; - case pto::AddressSpace::ACC: - roleTok = "TileType::Acc"; - break; - case pto::AddressSpace::BIAS: - roleTok = "TileType::Bias"; - break; - case pto::AddressSpace::GM: - roleTok = "TileType::Vec"; - break; - case pto::AddressSpace::Zero: - roleTok = "TileType::Vec"; - break; - case pto::AddressSpace::SCALING: - roleTok = "TileType::Scaling"; - break; - } - - // Shape (fallback to 32x32). - int64_t rows = 32, cols = 32; - if (resMrTy.getRank() >= 2 && resMrTy.hasStaticShape()) { - rows = resMrTy.getDimSize(0); - cols = resMrTy.getDimSize(1); - } - int64_t templateRows = - renderTileTemplateDim(rows, elemTy, pto::BLayout::RowMajor, 0); - int64_t templateCols = - renderTileTemplateDim(cols, elemTy, pto::BLayout::RowMajor, 1); - - // Keep a conservative default config for now. - std::string tileTypeStr = - std::string("Tile<") + roleTok + ", " + elemTok + ", " + - std::to_string(templateRows) + ", " + std::to_string(templateCols) + - ", BLayout::RowMajor, " + std::to_string(templateRows) + ", " + - std::to_string(templateCols) + - ", SLayout::NoneBox, 512, PadValue::Null, CompactMode::Null>"; - - auto tileType = emitc::OpaqueType::get(ctx, tileTypeStr); - Value tile = rewriter - .create(loc, tileType, - emitc::OpaqueAttr::get(ctx, "")) - .getResult(); - - // Compute an integer address and assign it to the new tile. - // NOTE: pto-isa TASSIGN requires an integral address (not a pointer). - auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); - auto rcU64 = rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); - - // Non-GM reinterpret_cast operands come from UB/L1/L0 tiles. - // We need the underlying address, but `__cce_get_tile_ptr()` is only valid - // inside `__tf__` functions. Use `tile.data()` (via a post-processed marker) - // and compute the adjusted address in bytes. - Value rawPtr = source; - if (auto ot = dyn_cast(source.getType())) { - // Only Tiles have a `.data()` member. For plain address-space pointers - // (e.g. `__ubuf__ float*`), use the pointer value directly. - if (ot.getValue().starts_with("Tile<")) { - rawPtr = materializeTileDataValue(rewriter, loc, source, as, elemTok); - } - } - - Value baseAddr = rawPtr; - if (isSetFFTsPointerLikeType(rawPtr.getType())) { - baseAddr = rewriter - .create(loc, u64Ty, "reinterpret_cast", - /*args=*/ArrayAttr{}, - /*templateArgs=*/rcU64, - /*operands=*/ValueRange{rawPtr}) - .getResult(0); - } else if (rawPtr.getType() != u64Ty) { - baseAddr = rewriter.create(loc, u64Ty, rawPtr).getResult(); - } - - Value addr = baseAddr; - if (offsetVal) { - Value offU64 = offsetVal; - if (offU64.getType() != u64Ty) - offU64 = rewriter.create(loc, u64Ty, offU64).getResult(); - - auto bytesAttr = emitc::OpaqueAttr::get(ctx, std::to_string(elemBytes)); - Value bytesVal = rewriter.create(loc, u64Ty, bytesAttr); - Value byteOff = rewriter.create(loc, u64Ty, offU64, bytesVal); - addr = rewriter.create(loc, u64Ty, baseAddr, byteOff); - } - - rewriter.create(loc, TypeRange{}, "TASSIGN", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{tile, addr}); - - rewriter.replaceOp(op, tile); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.taddc lowering -> TADDC(dst, src0, src1, src2) -//===----------------------------------------------------------------------===// - -struct PTOTAddCToTADDC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAddCOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value src2 = peelUnrealized(adaptor.getSrc2()); - Value dst = peelUnrealized(adaptor.getDst()); - - // pto-isa does not provide NPU implementation for TADDC yet. - // Decompose: dst = src0 + src1 + src2 - rewriter.create( - loc, TypeRange{}, "TADD", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src0, src1}); - rewriter.create( - loc, TypeRange{}, "TADD", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, dst, src2}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tadds lowering -> TADDS(dst, src, scalar) -//===----------------------------------------------------------------------===// - -struct PTOAddSToTADDS : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAddSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value scalar = peelUnrealized(adaptor.getScalar()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TADDS", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src, scalar}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.taddsc lowering -> TADDSC(dst, src0, scalar, src1) -//===----------------------------------------------------------------------===// - -struct PTOAddSCToTADDSC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAddSCOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - // pto-isa does not provide NPU implementation for TADDSC yet. - // Decompose: dst = src0 + scalar + src1 - rewriter.create( - loc, TypeRange{}, "TADDS", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src0, scalar}); - rewriter.create( - loc, TypeRange{}, "TADD", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, dst, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; -struct PTOTAndToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAndOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value a = peelUnrealized(adaptor.getSrc0()); - Value b = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TAND", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, a, b}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOConcatToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TConcatOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TCONCAT", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src0, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; -struct PTOConcatidxToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TConcatidxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value src0Idx = peelUnrealized(adaptor.getSrc0Idx()); - Value src1Idx = peelUnrealized(adaptor.getSrc1Idx()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TCONCAT", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src0, src1, src0Idx, src1Idx}); - - rewriter.eraseOp(op); - return success(); - } -}; -struct PTOAndSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TAndSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TANDS", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src, scalar}); - - rewriter.eraseOp(op); - return success(); - } -}; - - -struct PTOTCIToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TCIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value S = peelUnrealized(adaptor.getOperands()[0]); - - // The TCI scalar template parameter should follow the original PTO IR - // scalar type, not the converted EmitC value type. - std::string scalarTok = "int32_t"; - if (auto it = dyn_cast(op->getOperand(0).getType())) { - bool isUnsigned = it.isUnsigned(); - if (it.getWidth() == 16) - scalarTok = isUnsigned ? "uint16_t" : "int16_t"; - else - scalarTok = isUnsigned ? "uint32_t" : "int32_t"; - } - - // descending -> "0"/"1" - std::string descTok = op.getDescending() ? "1" : "0"; - - ArrayAttr targs; - if (auto ot = mlir::dyn_cast(dst.getType())) { - std::string tileTok = ot.getValue().str(); // "Tile<...>" - targs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, tileTok), - emitc::OpaqueAttr::get(ctx, scalarTok), - emitc::OpaqueAttr::get(ctx, descTok), - }); - } else { - targs = rewriter.getArrayAttr({}); - } - - rewriter.create( - loc, TypeRange{}, "TCI", - /*args=*/ArrayAttr{}, - /*templateArgs=*/targs, - /*operands=*/ValueRange{dst, S}); - - rewriter.eraseOp(op); - return success(); - } -}; -static std::string cmpModeTok(pto::CmpModeAttr a) { - // 生成 "CmpMode::GT" 这种 token - auto m = a.getValue(); // 取 enum - switch (m) { - case pto::CmpMode::EQ: return "CmpMode::EQ"; - case pto::CmpMode::NE: return "CmpMode::NE"; - case pto::CmpMode::LT: return "CmpMode::LT"; - case pto::CmpMode::LE: return "CmpMode::LE"; - case pto::CmpMode::GT: return "CmpMode::GT"; - case pto::CmpMode::GE: return "CmpMode::GE"; - } - return "CmpMode::EQ"; -} -struct PTOColExpandToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColExpandOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value src = peelUnrealized(adaptor.getSrc()); - - rewriter.create( - loc, TypeRange{}, "TCOLEXPAND", - /*args=*/ArrayAttr(), - /*templateArgs=*/ArrayAttr(), - /*operands=*/ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColExpandMulToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColExpandMulOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLEXPANDMUL", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColExpandAddToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColExpandAddOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLEXPANDADD", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColExpandDivToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColExpandDivOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLEXPANDDIV", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColExpandExpdifToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColExpandExpdifOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLEXPANDEXPDIF", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColExpandSubToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColExpandSubOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLEXPANDSUB", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColExpandMaxToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColExpandMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLEXPANDMAX", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColExpandMinToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColExpandMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLEXPANDMIN", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOTTriToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TTriOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value diagonal = peelUnrealized(adaptor.getDiagonal()); - - ArrayAttr templateArgs; - if (auto dstOT = mlir::dyn_cast(dst.getType())) { - templateArgs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), - emitc::OpaqueAttr::get(ctx, std::to_string(op.getUpperOrLower())), - }); - } else { - templateArgs = ArrayAttr{}; - } - - SmallVector operands{dst, diagonal}; - rewriter.create( - loc, TypeRange{}, "TTRI", - /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOCmpToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TCmpOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - - std::string tok = "CmpMode::EQ"; - if (auto a = op.getCmpModeAttr()) - tok = cmpModeTok(a); - - auto modeTy = emitc::OpaqueType::get(ctx, "CmpMode"); - Value modeVal = rewriter.create( - loc, modeTy, emitc::OpaqueAttr::get(ctx, tok)); - - rewriter.create( - loc, - TypeRange{}, - "TCMP", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1, modeVal}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOCmpSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TCmpSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - - // cmpMode -> token - auto cmpAttr = op.getCmpModeAttr(); // PTO_CmpModeAttr - std::string tok = cmpModeTok(cmpAttr); - - auto modeTy = emitc::OpaqueType::get(ctx, "CmpMode"); - Value modeVal = rewriter.create( - loc, modeTy, emitc::OpaqueAttr::get(ctx, tok)); - - rewriter.create( - loc, - TypeRange{}, - "TCMPS", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, scalar, modeVal}); - - rewriter.eraseOp(op); - return success(); - } -}; - - -struct PTOColMaxToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - // intrinsic: TCOLMAX(dst, src) - rewriter.create( - loc, TypeRange{}, "TCOLMAX", - /*args=*/ArrayAttr{}, // default: print all operands - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColArgMaxToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColArgMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLARGMAX", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, tmp}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColMinToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - // intrinsic: TCOLMIN(dst, src) - rewriter.create( - loc, TypeRange{}, "TCOLMIN", - /*args=*/ArrayAttr{}, // default: print all operands - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColArgMinToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColArgMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLARGMIN", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, tmp}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColSumToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColSumOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - // Check if tmp exists before accessing it - if (op.getTmp()) { - // Format 2: with tmp and isBinary - Value tmp = peelUnrealized(adaptor.getTmp()); - bool isBinary = false; - if (auto a = op.getIsBinaryAttr()) - isBinary = a.getValue(); - - auto boolTy = emitc::OpaqueType::get(ctx, "bool"); - auto tok = isBinary ? "true" : "false"; - Value isBinaryVal = rewriter.create( - loc, boolTy, emitc::OpaqueAttr::get(ctx, tok)); - - rewriter.create( - loc, TypeRange{}, "TCOLSUM", - /*args=*/ArrayAttr(), - /*templateArgs=*/ArrayAttr(), - /*operands=*/ValueRange{dst, src, tmp, isBinaryVal}); - } else { - // Format 1: without tmp and isBinary - rewriter.create( - loc, TypeRange{}, "TCOLSUM", - /*args=*/ArrayAttr(), - /*templateArgs=*/ArrayAttr(), - /*operands=*/ValueRange{dst, src}); - } - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOColProdToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TColProdOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TCOLPROD", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; -static std::string roundModeTok(mlir::pto::RoundModeAttr attr) { - using RM = mlir::pto::RoundMode; - switch (attr.getValue()) { - case RM::NONE: return "RoundMode::CAST_NONE"; - case RM::RINT: return "RoundMode::CAST_RINT"; - case RM::ROUND: return "RoundMode::CAST_ROUND"; - case RM::FLOOR: return "RoundMode::CAST_FLOOR"; - case RM::CEIL: return "RoundMode::CAST_CEIL"; - case RM::TRUNC: return "RoundMode::CAST_TRUNC"; - case RM::ODD: return "RoundMode::CAST_ODD"; - case RM::CAST_RINT: return "RoundMode::CAST_RINT"; - } - return "RoundMode::CAST_RINT"; -} -static std::string saturationModeTok(mlir::pto::SaturationModeAttr attr) { - using SM = mlir::pto::SaturationMode; - switch (attr.getValue()) { - case SM::ON: return "SaturationMode::ON"; - case SM::OFF: return "SaturationMode::OFF"; - } - return "SaturationMode::OFF"; -} -struct PTOCvtToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TCvtOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - pto::RoundModeAttr rmAttr = op.getRmodeAttr(); - std::string rmTok = rmAttr ? roundModeTok(rmAttr) - : std::string("RoundMode::CAST_RINT"); - auto rmodeTy = emitc::OpaqueType::get(ctx, "RoundMode"); - Value rmodeVal = rewriter.create( - loc, rmodeTy, emitc::OpaqueAttr::get(ctx, rmTok)); - - auto satModeTy = emitc::OpaqueType::get(ctx, "SaturationMode"); - auto satAttr = op.getSatModeAttr(); - std::string satTok = satAttr ? saturationModeTok(satAttr) - : std::string("SaturationMode::OFF"); - Value satModeVal = rewriter.create( - loc, satModeTy, emitc::OpaqueAttr::get(ctx, satTok)); - - SmallVector operands{dst, src, rmodeVal, satModeVal}; - - rewriter.create( - loc, TypeRange{}, "TCVT", - /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -struct PTORandomToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRandomOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - SmallVector operands{ - dst, - peelUnrealized(adaptor.getKey0()), - peelUnrealized(adaptor.getKey1()), - peelUnrealized(adaptor.getCounter0()), - peelUnrealized(adaptor.getCounter1()), - peelUnrealized(adaptor.getCounter2()), - peelUnrealized(adaptor.getCounter3()), - }; - ArrayAttr templateArgs = rewriter.getArrayAttr( - {emitc::OpaqueAttr::get(ctx, std::to_string(op.getRounds()))}); - - rewriter.create( - loc, TypeRange{}, "PTOAS__TRANDOM", - /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, operands); - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tdiv lowering -> TDIV(dst, src0, src1) -//===----------------------------------------------------------------------===// - -struct PTODivToTDIV : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TDivOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TDIV", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src0, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tdivs lowering -> TDIVS(dst, src, scalar) or TDIVS(dst, scalar, src) -// Order is determined by operand types: if src is tile_buf, order is (tile, scalar) -// Otherwise, order is (scalar, tile) -//===----------------------------------------------------------------------===// - -struct PTODivSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TDivSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value dst = peelUnrealized(adaptor.getDst()); - // Preserve source order from textual parse: - // ins(tile, scalar) -> TDIVS(dst, tile, scalar) - // ins(scalar, tile) -> TDIVS(dst, scalar, tile) - rewriter.create( - loc, TypeRange{}, "TDIVS", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src, scalar}); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// pto.tdivs (TDivSOp) lowering -> TDIVS(dst, src, scalar) or TDIVS(dst, scalar, src) -// Order is determined by operand types: if src is tile_buf, order is (tile, scalar) -// Otherwise, order is (scalar, tile) -//===----------------------------------------------------------------------===// - -struct PTOTDivSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TDivSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value dst = peelUnrealized(adaptor.getDst()); - rewriter.create( - loc, TypeRange{}, "TDIVS", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src, scalar}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.texp lowering -> TEXP(dst, src) -//===----------------------------------------------------------------------===// - -struct PTOExpToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TExpOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TEXP", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.texpands lowering -> TEXPANDS(dst, scalar) -//===----------------------------------------------------------------------===// - -struct PTOExpandsToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TExpandsOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value scalar = peelUnrealized(adaptor.getScalar()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TEXPANDS", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, scalar}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.textract lowering -> TEXTRACT(dst, src, indexRow, indexCol) -//===----------------------------------------------------------------------===// - -struct PTOExtractToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TExtractOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value r0 = peelUnrealized(adaptor.getIndexRow()); - Value c0 = peelUnrealized(adaptor.getIndexCol()); - - rewriter.create( - loc, TypeRange{}, "TEXTRACT", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, r0, c0}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.textract_fp lowering -> TEXTRACT_FP(dst, src, fp, indexRow, indexCol) -//===----------------------------------------------------------------------===// - -struct PTOExtractFPToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TExtractFPOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value fp = peelUnrealized(adaptor.getFp()); - Value dst = peelUnrealized(adaptor.getDst()); - Value r0 = peelUnrealized(adaptor.getIndexRow()); - Value c0 = peelUnrealized(adaptor.getIndexCol()); - - rewriter.create( - loc, TypeRange{}, "TEXTRACT_FP", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, fp, r0, c0}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tinsert lowering -> TINSERT(dst, src, indexRow, indexCol) -// Keep lowering arch-agnostic and let PTO-ISA infer proper A5 path. -//===----------------------------------------------------------------------===// - -struct PTOInsertToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TInsertOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value r0 = peelUnrealized(adaptor.getIndexRow()); - Value c0 = peelUnrealized(adaptor.getIndexCol()); - - rewriter.create( - loc, TypeRange{}, "TINSERT", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, r0, c0}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tinsert_fp lowering -> TINSERT_FP(dst, src, fp, indexRow, indexCol) -//===----------------------------------------------------------------------===// - -struct PTOInsertFPToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TInsertFPOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value fp = peelUnrealized(adaptor.getFp()); - Value dst = peelUnrealized(adaptor.getDst()); - Value r0 = peelUnrealized(adaptor.getIndexRow()); - Value c0 = peelUnrealized(adaptor.getIndexCol()); - - rewriter.create( - loc, TypeRange{}, "TINSERT_FP", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, fp, r0, c0}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tfillpad lowering -> TFILLPAD(dst, src) -//===----------------------------------------------------------------------===// - -struct PTOFillPadToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TFillPadOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TFILLPAD", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tfillpad_inplace lowering -> TFILLPAD_INPLACE(dst, src) -//===----------------------------------------------------------------------===// - -struct PTOFillPadInplaceToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TFillPadInplaceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TFILLPAD_INPLACE", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tfillpad_expand lowering -> TFILLPAD_EXPAND(dst, src) -//===----------------------------------------------------------------------===// - -struct PTOFillPadExpandToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TFillPadExpandOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TFILLPAD_EXPAND", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// pto.tgather lowering -// - Index form : TGATHER(dst, src0, indices, tmp) -// - Compare form: TGATHER(dst, src0, kValue, cdst, tmp) -// - Mask form : TGATHER(dst, src0) -//===----------------------------------------------------------------------===// - -[[maybe_unused]] static std::string maskPatternTok(mlir::pto::MaskPatternAttr a) { - - auto v = a.getValue(); // enum - return (std::string("pto::MaskPattern::") + mlir::pto::stringifyMaskPattern(v).str()); -} - -struct PTOGatherToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGatherOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value src0 = peelUnrealized(adaptor.getSrc()); - - auto getOpaqueTok = [&](Value v, StringRef name) -> FailureOr { - if (auto ot = mlir::dyn_cast(v.getType())) - return ot.getValue().str(); - return rewriter.notifyMatchFailure(op, (name + " must be emitc::OpaqueType (tile)").str()); - }; - - // Case 1: index-based TGATHER(dst, src0, indices, tmp) - if (Value idx = adaptor.getIndices()) { - idx = peelUnrealized(idx); - Value tmp = peelUnrealized(adaptor.getTmp()); - - rewriter.create( - loc, TypeRange{}, "TGATHER", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, idx, tmp}); - - rewriter.eraseOp(op); - return success(); - } - - // Case 2: compare-based TGATHER( - // dst, src0, kValue, tmp, cdst, offset) - if (Value cdst = adaptor.getCdst()) { - cdst = peelUnrealized(cdst); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value kValue = peelUnrealized(adaptor.getKValue()); - - auto dstTokOr = getOpaqueTok(dst, "dst"); - auto srcTokOr = getOpaqueTok(src0, "src0"); - auto cdstTokOr = getOpaqueTok(cdst, "cdst"); - auto tmpTokOr = getOpaqueTok(tmp, "tmp"); - if (failed(dstTokOr) || failed(srcTokOr) || failed(cdstTokOr) || failed(tmpTokOr)) - return failure(); - - auto cmpAttr = op.getCmpModeAttr(); - std::string cmpTok = cmpAttr ? cmpModeTok(cmpAttr) : "CmpMode::EQ"; - int64_t offset = 0; - if (auto offsetAttr = op.getOffsetAttr()) - offset = offsetAttr.getInt(); - auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); - Value offsetVal = makeEmitCIntConstant(rewriter, loc, i32Ty, offset); - - auto targs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, *dstTokOr), - emitc::OpaqueAttr::get(ctx, *srcTokOr), - emitc::OpaqueAttr::get(ctx, *tmpTokOr), - emitc::OpaqueAttr::get(ctx, *cdstTokOr), - emitc::OpaqueAttr::get(ctx, cmpTok), - }); - - rewriter.create( - loc, TypeRange{}, "TGATHER", - /*args=*/ArrayAttr{}, /*templateArgs=*/targs, - /*operands=*/ValueRange{dst, src0, kValue, tmp, cdst, offsetVal}); - - rewriter.eraseOp(op); - return success(); - } - - // Case 3: mask-pattern TGATHER(dst, src0) - auto mp = op.getMaskPatternAttr(); - if (!mp) - return rewriter.notifyMatchFailure(op, "expected maskPattern, indices, or cdst on tgather"); - - auto dstTokOr = getOpaqueTok(dst, "dst"); - auto srcTokOr = getOpaqueTok(src0, "src0"); - if (failed(dstTokOr) || failed(srcTokOr)) - return failure(); - - // mp is an EnumAttr; stringify name is "P0101" etc. - // We emit MaskPattern::P0101 (because generated C++ has `using namespace pto;`) - std::string mpTok = std::string("MaskPattern::") + - mlir::pto::stringifyMaskPattern(mp.getValue()).str(); - - auto targs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, *dstTokOr), - emitc::OpaqueAttr::get(ctx, *srcTokOr), - emitc::OpaqueAttr::get(ctx, mpTok), - }); - - rewriter.create( - loc, TypeRange{}, "TGATHER", - /*args=*/ArrayAttr{}, - /*templateArgs=*/targs, - /*operands=*/ValueRange{dst, src0}); - - rewriter.eraseOp(op); - return success(); - } -}; - - -struct PTOGatherbToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGatherBOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value offsets = peelUnrealized(adaptor.getOffsets()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TGATHERB", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, offsets}); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// TLOG lowering to EmitC (PTOConvert.cpp) -//===----------------------------------------------------------------------===// - -struct PTOLogToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TLogOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src}; - rewriter.create( - loc, TypeRange{}, "TLOG", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - - - -//===----------------------------------------------------------------------===// -// TLRELU lowering to EmitC (PTOConvert.cpp) -//===----------------------------------------------------------------------===// - - struct PTOLReluToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TLReluOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value slope = peelUnrealized(adaptor.getSlope()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, slope}; - - rewriter.create( - loc, TypeRange{}, "TLRELU", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// TMAX lowering to EmitC (PTOConvert.cpp) -//===----------------------------------------------------------------------===// - -struct PTOMaxToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TMAX", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// TMAXS lowering to EmitC (PTOConvert.cpp) -//===----------------------------------------------------------------------===// - - struct PTOMaxSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMaxSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, scalar}; - rewriter.create( - loc, TypeRange{}, "TMAXS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - - -//===----------------------------------------------------------------------===// -// TMIN lowering to EmitC (PTOConvert.cpp) -//===----------------------------------------------------------------------===// - -struct PTOMinToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TMIN", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// TMINS lowering to EmitC (PTOConvert.cpp) -//===----------------------------------------------------------------------===// - -//===----------------------------------------------------------------------===// -// TMINS lowering to EmitC (fix APFloat -> FloatAttr) (PTOConvert.cpp) -//===----------------------------------------------------------------------===// - -struct PTOMinsToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMinSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value scalar = peelUnrealized(adaptor.getScalar()); - - SmallVector operands{dst, src, scalar}; - rewriter.create( - loc, TypeRange{}, "TMINS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering for TMOV op -> EmitC) -//===----------------------------------------------------------------------===// - -struct PTOMovToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMovOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value fp; - if (op.getFp()) - fp = peelUnrealized(adaptor.getFp()); - Value preQuantScalar; - if (op.getPreQuantScalar()) - preQuantScalar = peelUnrealized(adaptor.getPreQuantScalar()); - - auto dstOT = mlir::dyn_cast(dst.getType()); - auto srcOT = mlir::dyn_cast(src.getType()); - if (!dstOT || !srcOT) - return rewriter.notifyMatchFailure( - op, "tmov lowering expects opaque dst/src types"); - - auto modeTok = [&](pto::AccToVecMode mode) -> StringRef { - switch (mode) { - case pto::AccToVecMode::SingleModeVec0: - return "pto::AccToVecMode::SingleModeVec0"; - case pto::AccToVecMode::SingleModeVec1: - return "pto::AccToVecMode::SingleModeVec1"; - case pto::AccToVecMode::DualModeSplitM: - return "pto::AccToVecMode::DualModeSplitM"; - case pto::AccToVecMode::DualModeSplitN: - return "pto::AccToVecMode::DualModeSplitN"; - } - llvm_unreachable("unknown AccToVecMode"); - }; - - auto modeAttr = op.getAccToVecModeAttr(); - auto reluTok = [&](pto::ReluPreMode mode) -> StringRef { - switch (mode) { - case pto::ReluPreMode::NoRelu: - return "ReluPreMode::NoRelu"; - case pto::ReluPreMode::NormalRelu: - return "ReluPreMode::NormalRelu"; - } - llvm_unreachable("unknown ReluPreMode"); - }; - - const bool hasFp = static_cast(fp); - const bool hasPreQuantScalar = static_cast(preQuantScalar); - const bool hasMode = static_cast(modeAttr); - const bool reluNonDefault = op.getReluPreMode() != pto::ReluPreMode::NoRelu; - - SmallVector operands{dst, src}; - SmallVector templateArgVec{ - emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), - emitc::OpaqueAttr::get(ctx, srcOT.getValue().str()), - }; - StringRef callee = "TMOV"; - - if (hasFp) { - auto fpOT = mlir::dyn_cast(fp.getType()); - if (!fpOT) - return rewriter.notifyMatchFailure( - op, "tmov fp lowering expects opaque fp type"); - operands.push_back(fp); - templateArgVec.push_back(emitc::OpaqueAttr::get(ctx, fpOT.getValue().str())); - if (hasMode) - templateArgVec.push_back( - emitc::OpaqueAttr::get(ctx, modeTok(modeAttr.getValue()))); - if (hasMode || reluNonDefault) - templateArgVec.push_back( - emitc::OpaqueAttr::get(ctx, reluTok(op.getReluPreMode()))); - callee = hasMode ? "TMOV" : "TMOV_FP"; - } else if (hasPreQuantScalar) { - operands.push_back(preQuantScalar); - if (hasMode) - templateArgVec.push_back( - emitc::OpaqueAttr::get(ctx, modeTok(modeAttr.getValue()))); - if (hasMode || reluNonDefault) - templateArgVec.push_back( - emitc::OpaqueAttr::get(ctx, reluTok(op.getReluPreMode()))); - } else if (hasMode) { - templateArgVec.push_back( - emitc::OpaqueAttr::get(ctx, modeTok(modeAttr.getValue()))); - templateArgVec.push_back( - emitc::OpaqueAttr::get(ctx, reluTok(op.getReluPreMode()))); - } else if (reluNonDefault) { - templateArgVec.push_back( - emitc::OpaqueAttr::get(ctx, reluTok(op.getReluPreMode()))); - } - - ArrayAttr templateArgs = - templateArgVec.size() == 2 && !hasFp && !hasPreQuantScalar && - !hasMode && !reluNonDefault - ? ArrayAttr{} - : rewriter.getArrayAttr(templateArgVec); - - rewriter.create( - loc, TypeRange{}, callee, - /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TMOV_FP DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOMovFPToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMovFPOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value src = peelUnrealized(adaptor.getSrc()); - Value fp = peelUnrealized(adaptor.getFp()); - - // TMOV_FP(dstTileData, cTile, fbTile) - ArrayAttr templateArgs; - auto dstOT = mlir::dyn_cast(dst.getType()); - auto srcOT = mlir::dyn_cast(src.getType()); - auto fpOT = mlir::dyn_cast(fp.getType()); - if (dstOT && srcOT && fpOT) { - templateArgs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), - emitc::OpaqueAttr::get(ctx, srcOT.getValue().str()), - emitc::OpaqueAttr::get(ctx, fpOT.getValue().str()), - }); - } else { - templateArgs = ArrayAttr{}; - } - - SmallVector operands{dst, src, fp}; - rewriter.create( - loc, TypeRange{}, "TMOV_FP", - /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOQuantToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TQuantOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value src = peelUnrealized(adaptor.getSrc()); - Value fp = peelUnrealized(adaptor.getFp()); - - // Optional offset (INT8_ASYM only): passed as pointer (&offset) - Value offsetPtr; - if (op.getOffset()) { - Value offset = peelUnrealized(adaptor.getOffset()); - auto offsetOT = mlir::dyn_cast(offset.getType()); - if (offsetOT) { - offsetPtr = rewriter - .create( - loc, emitc::PointerType::get(offsetOT), "&", offset) - .getResult(); - } - } - - // TQUANT(dst, src, fp[, &offset]) - std::string quantTypeStr = - op.getQuantType() == pto::QuantType::INT8_SYM - ? "pto::QuantType::INT8_SYM" - : "pto::QuantType::INT8_ASYM"; - ArrayAttr templateArgs; - auto dstOT = mlir::dyn_cast(dst.getType()); - auto srcOT = mlir::dyn_cast(src.getType()); - auto fpOT = mlir::dyn_cast(fp.getType()); - if (dstOT && srcOT && fpOT) { - templateArgs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, quantTypeStr), - emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), - emitc::OpaqueAttr::get(ctx, srcOT.getValue().str()), - emitc::OpaqueAttr::get(ctx, fpOT.getValue().str()), - }); - } else { - templateArgs = ArrayAttr{}; - } - - SmallVector operands{dst, src, fp}; - if (offsetPtr) - operands.push_back(offsetPtr); - - rewriter.create( - loc, TypeRange{}, "TQUANT", - /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -struct PTODequantToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TDequantOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDst()); - Value src = peelUnrealized(adaptor.getSrc()); - Value scale = peelUnrealized(adaptor.getScale()); - Value offset = peelUnrealized(adaptor.getOffset()); - - // TDEQUANT(dst, src, scale, offset) - ArrayAttr templateArgs; - auto dstOT = mlir::dyn_cast(dst.getType()); - auto srcOT = mlir::dyn_cast(src.getType()); - auto scaleOT = mlir::dyn_cast(scale.getType()); - if (dstOT && srcOT && scaleOT) { - templateArgs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), - emitc::OpaqueAttr::get(ctx, srcOT.getValue().str()), - emitc::OpaqueAttr::get(ctx, scaleOT.getValue().str()), - }); - } else { - templateArgs = ArrayAttr{}; - } - - rewriter.create( - loc, TypeRange{}, "TDEQUANT", - /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, - /*operands=*/SmallVector{dst, src, scale, offset}); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TMRGSORT DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOMrgSortToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMrgSortOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - if (op.isFormat1()) { - Value src = peelUnrealized(adaptor.getSrcs().front()); - Value dst = peelUnrealized(adaptor.getDsts().front()); - Value blockLen = peelUnrealized(adaptor.getBlockLen()); - - SmallVector operands{dst, src, blockLen}; - rewriter.create( - loc, TypeRange{}, "TMRGSORT", - ArrayAttr{}, ArrayAttr{}, operands); - } else if (op.isFormat2()) { - // pto-isa API: - // TMRGSORT( - // dst, executedNumList, tmp, src0, src1[, src2[, src3]]); - auto *ctx = rewriter.getContext(); - - Value dst = peelUnrealized(adaptor.getDsts()[0]); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value excuted = peelUnrealized(adaptor.getExcuted()); - - SmallVector srcs; - srcs.reserve(adaptor.getSrcs().size()); - for (Value v : adaptor.getSrcs()) - srcs.push_back(peelUnrealized(v)); - - auto dstOT = mlir::dyn_cast(dst.getType()); - auto tmpOT = mlir::dyn_cast(tmp.getType()); - if (!dstOT || !tmpOT || srcs.size() < 2 || srcs.size() > 4) - return op.emitOpError("format2 expects dst/tmp tilebufs and 2 to 4 srcs"); - - SmallVector targs; - targs.reserve(2 + srcs.size() + 1); - targs.push_back(emitc::OpaqueAttr::get(ctx, dstOT.getValue().str())); - targs.push_back(emitc::OpaqueAttr::get(ctx, tmpOT.getValue().str())); - for (Value v : srcs) { - auto ot = mlir::dyn_cast(v.getType()); - if (!ot) - return op.emitOpError("format2 expects tilebuf srcs"); - targs.push_back(emitc::OpaqueAttr::get(ctx, ot.getValue().str())); - } - targs.push_back(emitc::OpaqueAttr::get(ctx, op.getExhausted() ? "true" : "false")); - ArrayAttr templateArgs = rewriter.getArrayAttr(targs); - - SmallVector operands{dst, excuted, tmp}; - operands.append(srcs.begin(), srcs.end()); - - rewriter.create( - loc, TypeRange{}, "TMRGSORT", - /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, operands); - } else { - return op.emitOpError("unsupported mrgsort_dps format"); - } - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TMUL DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOMulToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMulOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TMUL", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TMULS DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOMulsToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMulSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc0()); - Value dst = peelUnrealized(adaptor.getDst()); - Value scalar = peelUnrealized(adaptor.getScalar()); - - SmallVector operands{dst, src, scalar}; - rewriter.create( - loc, TypeRange{}, "TMULS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TNEG DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTONegToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TNegOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src}; - rewriter.create( - loc, TypeRange{}, "TNEG", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TNOT DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTONotToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TNotOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src}; - rewriter.create( - loc, TypeRange{}, "TNOT", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TOR DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOOrToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TOrOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TOR", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TORS DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOOrsToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TOrSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - // NOTE: The conversion type system may materialize integers as emitc.opaque - // (e.g. "int32_t"). For EmitC call emission we can pass the scalar through - // directly without arith casts here. - Value s = adaptor.getScalar(); - - SmallVector operands{dst, src0, s}; - rewriter.create( - loc, TypeRange{}, "TORS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TPARTADD DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOPartAddToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPartAddOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TPARTADD", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TPARTMAX DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOPartMaxToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPartMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TPARTMAX", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TPARTMIN DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOPartMinToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPartMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TPARTMIN", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOPartArgMaxToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPartArgMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value src0Idx = peelUnrealized(adaptor.getSrc0Idx()); - Value src1Idx = peelUnrealized(adaptor.getSrc1Idx()); - Value dst = peelUnrealized(adaptor.getDst()); - Value dstIdx = peelUnrealized(adaptor.getDstIdx()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TPARTARGMAX", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1, dstIdx, src0Idx, src1Idx}); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOPartArgMinToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPartArgMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value src0Idx = peelUnrealized(adaptor.getSrc0Idx()); - Value src1Idx = peelUnrealized(adaptor.getSrc1Idx()); - Value dst = peelUnrealized(adaptor.getDst()); - Value dstIdx = peelUnrealized(adaptor.getDstIdx()); - - rewriter.create( - op.getLoc(), TypeRange{}, "TPARTARGMIN", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1, dstIdx, src0Idx, src1Idx}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TPARTMUL DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOPartMulToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPartMulOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TPARTMUL", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TPRELU DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOPreluToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPReluOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - // C++ interface: TPRELU(dst, src0, src1, tmp) — last parameter is tmp. - SmallVector operands{dst, src0, src1, tmp}; - rewriter.create( - loc, TypeRange{}, "TPRELU", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TRECIP DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORecipToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRecipOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src}; - rewriter.create( - loc, TypeRange{}, "TRECIP", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TRELU DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOReluToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TReluOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src}; - rewriter.create( - loc, TypeRange{}, "TRELU", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TREM DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORemToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRemOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - SmallVector operands{dst, src0, src1, tmp}; - rewriter.create( - loc, TypeRange{}, "TREM", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOFModToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TFModOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TFMOD", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TREMS DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORemSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRemSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - Value scalar = peelUnrealized(adaptor.getScalar()); - SmallVector operands{dst, src, scalar, tmp}; - rewriter.create( - loc, TypeRange{}, "TREMS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOFModSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TFModSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value scalar = peelUnrealized(adaptor.getScalar()); - - SmallVector operands{dst, src, scalar}; - rewriter.create( - loc, TypeRange{}, "TFMODS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TROWEXPAND DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORowExpandToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowExpandOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src}; - rewriter.create( - loc, TypeRange{}, "TROWEXPAND", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTORowExpandAddToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowExpandAddOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TROWEXPANDADD", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTORowExpandExpdifToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowExpandExpdifOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value(); - - SmallVector operands; - if (tmp) - operands.assign({dst, src0, src1, tmp}); - else - operands.assign({dst, src0, src1}); - rewriter.create( - loc, TypeRange{}, "TROWEXPANDEXPDIF", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TROWEXPANDDIV DPS/memref op) -//===----------------------------------------------------------------------===// -// Helper: replace or erase based on whether op has results. -static void replaceOrEraseWithOpaqueCall(Operation *op, - StringRef callee, - ArrayRef args, - ConversionPatternRewriter &rewriter) { - TypeRange resultTypes = op->getResultTypes(); - auto call = rewriter.create( - op->getLoc(), resultTypes, callee, ArrayAttr{}, ArrayAttr{}, ValueRange(args)); - if (resultTypes.empty()) - rewriter.eraseOp(op); - else - rewriter.replaceOp(op, call.getResults()); -} - -static void replaceOrEraseWithOpaqueCallAndReturnDst(Operation *op, Value dst, - StringRef callee, - ArrayRef args, - ConversionPatternRewriter &rewriter) { - rewriter.create( - op->getLoc(), TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, ValueRange(args)); - if (op->getNumResults() == 1) - rewriter.replaceOp(op, dst); - else - rewriter.eraseOp(op); -} - -// ---------- TOp ---------- -struct PTOTGemvBiasToTGEMV_BIAS - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGemvBiasOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value a = peelUnrealized(adaptor.getA()); - Value b = peelUnrealized(adaptor.getB()); - Value bias = peelUnrealized(adaptor.getBias()); - Value dst = peelUnrealized(adaptor.getDst()); - - replaceOrEraseWithOpaqueCall(op.getOperation(), "TGEMV_BIAS", - {dst, a, b, bias}, rewriter); - return success(); - } -}; - -struct PTOTGemvMXToTGEMV_MX - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGemvMxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value a = peelUnrealized(adaptor.getA()); - Value aScale = peelUnrealized(adaptor.getAScale()); - Value b = peelUnrealized(adaptor.getB()); - Value bScale = peelUnrealized(adaptor.getBScale()); - Value dst = peelUnrealized(adaptor.getDst()); - - replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TGEMV_MX", - {dst, a, aScale, b, bScale}, rewriter); - return success(); - } -}; - -struct PTOTGemvMXAccToTGEMV_MX - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGemvMxAccOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value cIn = peelUnrealized(adaptor.getCIn()); - Value a = peelUnrealized(adaptor.getA()); - Value aScale = peelUnrealized(adaptor.getAScale()); - Value b = peelUnrealized(adaptor.getB()); - Value bScale = peelUnrealized(adaptor.getBScale()); - Value dst = peelUnrealized(adaptor.getDst()); - - replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TGEMV_MX", - {dst, cIn, a, aScale, b, bScale}, rewriter); - return success(); - } -}; - -struct PTOTGemvMXBiasToTGEMV_MX - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TGemvMxBiasOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value a = peelUnrealized(adaptor.getA()); - Value aScale = peelUnrealized(adaptor.getAScale()); - Value b = peelUnrealized(adaptor.getB()); - Value bScale = peelUnrealized(adaptor.getBScale()); - Value bias = peelUnrealized(adaptor.getBias()); - Value dst = peelUnrealized(adaptor.getDst()); - - replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TGEMV_MX", - {dst, a, aScale, b, bScale, bias}, rewriter); - return success(); - } -}; - -struct PTOTMatmulBiasToTMATMUL_BIAS - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMatmulBiasOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value a = peelUnrealized(adaptor.getA()); - Value b = peelUnrealized(adaptor.getB()); - Value bias = peelUnrealized(adaptor.getBias()); - Value dst = peelUnrealized(adaptor.getDst()); - - replaceOrEraseWithOpaqueCall(op.getOperation(), "TMATMUL_BIAS", - {dst, a, b, bias}, rewriter); - return success(); - } -}; - -struct PTOTMatmulMXToTMATMUL_MX - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMatmulMxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value a = peelUnrealized(adaptor.getA()); - Value aScale = peelUnrealized(adaptor.getAScale()); - Value b = peelUnrealized(adaptor.getB()); - Value bScale = peelUnrealized(adaptor.getBScale()); - Value dst = peelUnrealized(adaptor.getDst()); - - replaceOrEraseWithOpaqueCall(op.getOperation(), "TMATMUL_MX", - {dst, a, aScale, b, bScale}, rewriter); - return success(); - } -}; - -struct PTOTMatmulMXAccToTMATMUL_MX_ACC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMatmulMxAccOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value cIn = peelUnrealized(adaptor.getCIn()); - Value a = peelUnrealized(adaptor.getA()); - Value aScale = peelUnrealized(adaptor.getAScale()); - Value b = peelUnrealized(adaptor.getB()); - Value bScale = peelUnrealized(adaptor.getBScale()); - Value dst = peelUnrealized(adaptor.getDst()); - - replaceOrEraseWithOpaqueCall(op.getOperation(), "TMATMUL_MX", - {dst, cIn, a, aScale, b, bScale}, rewriter); - return success(); - } -}; - -struct PTOTMatmulMXBiasToTMATMUL_MX_BIAS - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TMatmulMxBiasOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value a = peelUnrealized(adaptor.getA()); - Value aScale = peelUnrealized(adaptor.getAScale()); - Value b = peelUnrealized(adaptor.getB()); - Value bScale = peelUnrealized(adaptor.getBScale()); - Value bias = peelUnrealized(adaptor.getBias()); - Value dst = peelUnrealized(adaptor.getDst()); - - replaceOrEraseWithOpaqueCall(op.getOperation(), "TMATMUL_MX", - {dst, a, aScale, b, bScale, bias}, rewriter); - return success(); - } -}; - -struct PTORowExpandDivToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowExpandDivOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value(); - - SmallVector operands; - if (tmp) - operands.assign({dst, src0, src1, tmp}); - else - operands.assign({dst, src0, src1}); - rewriter.create( - loc, TypeRange{}, "TROWEXPANDDIV", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TROWEXPANDMUL DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORowExpandMulToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowExpandMulOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value(); - - SmallVector operands; - if (tmp) - operands.assign({dst, src0, src1, tmp}); - else - operands.assign({dst, src0, src1}); - rewriter.create( - loc, TypeRange{}, "TROWEXPANDMUL", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TROWEXPANDSUB DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORowExpandSubToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowExpandSubOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value(); - - SmallVector operands; - if (tmp) - operands.assign({dst, src0, src1, tmp}); - else - operands.assign({dst, src0, src1}); - rewriter.create( - loc, TypeRange{}, "TROWEXPANDSUB", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTORowExpandMaxToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowExpandMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value(); - - SmallVector operands; - if (tmp) - operands.assign({dst, src0, src1, tmp}); - else - operands.assign({dst, src0, src1}); - rewriter.create( - loc, TypeRange{}, "TROWEXPANDMAX", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTORowExpandMinToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowExpandMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value(); - - SmallVector operands; - if (tmp) - operands.assign({dst, src0, src1, tmp}); - else - operands.assign({dst, src0, src1}); - rewriter.create( - loc, TypeRange{}, "TROWEXPANDMIN", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TROWMAX DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORowMaxToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, tmp}; - rewriter.create( - loc, TypeRange{}, "TROWMAX", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTORowArgMaxToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowArgMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TROWARGMAX", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, tmp}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TROWMIN DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORowMinToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, tmp}; - rewriter.create( - loc, TypeRange{}, "TROWMIN", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTORowArgMinToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowArgMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - rewriter.create( - loc, TypeRange{}, "TROWARGMIN", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, tmp}); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TROWSUM DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTORowSumToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowSumOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, tmp}; - rewriter.create( - loc, TypeRange{}, "TROWSUM", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTORowProdToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRowProdOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, tmp}; - rewriter.create( - loc, TypeRange{}, "TROWPROD", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TRSQRT DPS/memref op) -// - no-tmp form : TRSQRT(dst, src) -// - tmp form : TRSQRT(dst, src, tmp) -//===----------------------------------------------------------------------===// - -struct PTORsqrtToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TRsqrtOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - SmallVector operands{dst, src}; - if (Value tmp = adaptor.getTmp()) - operands.push_back(peelUnrealized(tmp)); - rewriter.create( - loc, TypeRange{}, "TRSQRT", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSCATTER DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOScatterToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TScatterOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - const bool hasMaskPattern = static_cast(op.getMaskPatternAttr()); - const bool hasIndexes = static_cast(op.getIndexes()); - if (hasMaskPattern == hasIndexes) { - return rewriter.notifyMatchFailure( - op, "expected exactly one of indexes operand or maskPattern attribute"); - } - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - if (auto mp = op.getMaskPatternAttr()) { - auto *ctx = rewriter.getContext(); - auto targs = rewriter.getArrayAttr({ - emitc::OpaqueAttr::get(ctx, maskPatternTok(mp)), - }); - rewriter.create( - loc, TypeRange{}, "TSCATTER", - /*args=*/ArrayAttr{}, /*templateArgs=*/targs, - /*operands=*/ValueRange{dst, src}); - } else { - Value idx = peelUnrealized(adaptor.getIndexes()); - rewriter.create( - loc, TypeRange{}, "TSCATTER", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src, idx}); - } - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSEL DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOSelToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TSelOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value mask = peelUnrealized(adaptor.getMask()); - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, mask, src0, src1, tmp}; - rewriter.create( - loc, TypeRange{}, "TSEL", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSELS DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOSelSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TSelSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value mask = peelUnrealized(adaptor.getMask()); - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, mask, src, tmp, scalar}; - rewriter.create( - loc, TypeRange{}, "TSELS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSHL DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOShlSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TShlOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TSHL", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSHR DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOShrSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TShrOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TSHR", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering for TSHLS/TSHRS DPS: shift by scalar) -//===----------------------------------------------------------------------===// - -struct PTOShlSConstToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TShlSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value dst = peelUnrealized(adaptor.getDst()); - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - SmallVector operands{dst, src, scalar}; - rewriter.create( - loc, TypeRange{}, "TSHLS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - rewriter.eraseOp(op); - return success(); - } -}; - -struct PTOShrSConstToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TShrSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value dst = peelUnrealized(adaptor.getDst()); - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - SmallVector operands{dst, src, scalar}; - rewriter.create( - loc, TypeRange{}, "TSHRS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (TSORT32 DPS/memref op: ins(src, idx[, tmp]) outs(dst)) -//===----------------------------------------------------------------------===// - -struct PTOSORT32SToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TSort32Op op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - Value idx = peelUnrealized(adaptor.getIdx()); - Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value(); - - SmallVector operands; - if (tmp) - operands.assign({dst, src, idx, tmp}); - else - operands.assign({dst, src, idx}); - rewriter.create( - loc, TypeRange{}, "TSORT32", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSQRT DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOSqrtSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TSqrtOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src}; - rewriter.create( - loc, TypeRange{}, "TSQRT", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSTORE_FP DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOStoreFPSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TStoreFPOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value fp = peelUnrealized(adaptor.getFp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, fp}; - rewriter.create( - loc, TypeRange{}, "TSTORE_FP", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSUB DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOSubSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TSubOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TSUB", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSUBC DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOSubCSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TSubCOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value src2 = peelUnrealized(adaptor.getSrc2()); - Value dst = peelUnrealized(adaptor.getDst()); - - // pto-isa does not provide NPU implementation for TSUBC yet. - // Decompose: dst = src0 - src1 + src2 - rewriter.create( - loc, TypeRange{}, "TSUB", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, src1}); - rewriter.create( - loc, TypeRange{}, "TADD", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, dst, src2}); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSUBS DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOSubSSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TSubSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, scalar}; - rewriter.create( - loc, TypeRange{}, "TSUBS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TSUBSC DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOSubSCToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TSubSCOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - - // pto-isa does not provide NPU implementation for TSUBSC yet. - // Decompose: dst = src0 - scalar + src1 - rewriter.create( - loc, TypeRange{}, "TSUBS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src0, scalar}); - rewriter.create( - loc, TypeRange{}, "TADD", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, dst, src1}); - - rewriter.eraseOp(op); - return success(); - } -}; - - -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TXOR DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOXORToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TXorOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src0 = peelUnrealized(adaptor.getSrc0()); - Value src1 = peelUnrealized(adaptor.getSrc1()); - Value dst = peelUnrealized(adaptor.getDst()); - Value tmp = peelUnrealized(adaptor.getTmp()); - SmallVector operands{dst, src0, src1, tmp}; - rewriter.create( - loc, TypeRange{}, "TXOR", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -struct PTOTTransToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TTransOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, tmp}; - rewriter.create( - loc, TypeRange{}, "TTRANS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; -//===----------------------------------------------------------------------===// -// PTOConvert.cpp (add lowering + patterns.add for TXORS DPS/memref op) -//===----------------------------------------------------------------------===// - -struct PTOXORSToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TXorSOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - Value scalar = peelUnrealized(adaptor.getScalar()); - Value tmp = peelUnrealized(adaptor.getTmp()); - Value dst = peelUnrealized(adaptor.getDst()); - - SmallVector operands{dst, src, scalar, tmp}; - rewriter.create( - loc, TypeRange{}, "TXORS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - struct PTOPrintToTPRINT : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TPrintOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value src = peelUnrealized(adaptor.getSrc()); - - SmallVector operands{src}; - rewriter.create( - loc, TypeRange{}, "TPRINT", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); - - rewriter.eraseOp(op); - return success(); - } -}; - -// pto.print "format", %scalar -> PRINTF("format", scalar) -struct PTOPrintOpToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::PrintOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - - std::string fmt = op.getFormat().str(); - if (fmt.empty()) - fmt = "%f"; - std::string quoted = "\""; - for (char c : fmt) { - if (c == '"' || c == '\\') - quoted += '\\'; - else if (c == '\n') - quoted += "\\n"; - else if (c == '\t') - quoted += "\\t"; - else - quoted += c; - } - quoted += "\""; - - Value scalar = peelUnrealized(adaptor.getScalar()); - auto argsAttr = rewriter.getArrayAttr( - {emitc::OpaqueAttr::get(ctx, quoted), - IntegerAttr::get(IndexType::get(ctx), 0)}); - rewriter.create( - loc, TypeRange{}, "cce::printf", - /*args=*/argsAttr, - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{scalar}); - - rewriter.eraseOp(op); - return success(); - } -}; - -// pto.trap -> TRAP() -struct PTOTrapOpToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TrapOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - rewriter.create( - loc, TypeRange{}, "trap", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{}); - - rewriter.eraseOp(op); - return success(); - } -}; - -// ============================================================================= -// 2. BindTileOp Lowering (FIX: Trace back to physical address) -// ============================================================================= -struct PTOBindTileToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - struct TileBuildSpec { - std::string tileTypeStr; - bool useConstructor = false; - SmallVector constructorArgs; - }; - - static bool getIndexConst(Value v, int64_t &out) { - if (!v) - return false; - if (auto cst = v.getDefiningOp()) { - if (auto ia = dyn_cast(cst.getValue())) { - out = ia.getValue().getSExtValue(); - return true; + auto elemTypeToString = [&](Type elemTy) -> std::string { + if (elemTy.isF16()) + return "half"; + if (elemTy.isBF16()) + return "bfloat16_t"; + if (elemTy.isF32()) + return "float"; + if (elemTy.isF64()) + return "double"; + if (elemTy.isInteger(8)) { + if (elemTy.isSignlessInteger(8) || elemTy.isSignedInteger(8)) + return "int8_t"; + return "uint8_t"; } - } - return false; - } - - static bool getTilePointerStrides(pto::TileBufConfigAttr configAttr, - Type elemTy, int64_t rows, int64_t cols, - int64_t &rowStride, - int64_t &colStride) { - if (rows == ShapedType::kDynamic || cols == ShapedType::kDynamic) - return false; - - int32_t blVal = 0; - if (auto blAttr = dyn_cast(configAttr.getBLayout())) - blVal = static_cast(blAttr.getValue()); - else if (auto intAttr = dyn_cast(configAttr.getBLayout())) - blVal = static_cast(intAttr.getInt()); - - int32_t slVal = 0; - if (auto slAttr = dyn_cast(configAttr.getSLayout())) - slVal = static_cast(slAttr.getValue()); - else if (auto intAttr = dyn_cast(configAttr.getSLayout())) - slVal = static_cast(intAttr.getInt()); - - bool boxed = slVal != 0; - int64_t innerRows = 1; - int64_t innerCols = 1; - if (boxed) { - int32_t fractal = 512; - if (auto frAttr = dyn_cast(configAttr.getSFractalSize())) - fractal = static_cast(frAttr.getInt()); - - unsigned elemBytes = pto::getPTOStorageElemByteSize(elemTy); - if (elemBytes == 0) - return false; - - switch (fractal) { - case 1024: - innerRows = 16; - innerCols = 16; - break; - case 32: - innerRows = 16; - innerCols = 2; - break; - case 512: - if (slVal == 1) { - innerRows = 16; - innerCols = 32 / elemBytes; - } else if (slVal == 2) { - innerRows = 32 / elemBytes; - innerCols = 16; - } else { - return false; - } - break; - default: - return false; + if (elemTy.isInteger(16)) { + if (elemTy.isSignlessInteger(16) || elemTy.isSignedInteger(16)) + return "int16_t"; + return "uint16_t"; } - if (innerRows <= 0 || innerCols <= 0) - return false; - } - - if (!boxed) { - if (blVal == 1) { - rowStride = 1; - colStride = rows; - } else { - rowStride = cols; - colStride = 1; + if (elemTy.isInteger(32)) { + if (elemTy.isSignlessInteger(32) || elemTy.isSignedInteger(32)) + return "int32_t"; + return "uint32_t"; } - return true; - } - - if (blVal == 1) { - if (slVal != 1) - return false; - rowStride = innerCols; - colStride = rows; - return true; - } - - rowStride = cols; - colStride = innerRows; - return true; - } - - LogicalResult matchAndRewrite(pto::BindTileOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto *ctx = rewriter.getContext(); - auto configAttr = op.getConfigAttr(); - auto viewSemantics = op->getAttrOfType("pto.view_semantics"); - bool isSubView = viewSemantics && viewSemantics.getValue() == "subview"; - - auto peelAllCasts = [](Value v) { - while (auto castOp = v.getDefiningOp()) - v = castOp.getOperand(0); - if (auto castOp = v.getDefiningOp()) - v = castOp.getOperand(); - return v; - }; - auto isTileLike = [](Value v) -> bool { - auto ot = dyn_cast(v.getType()); - if (!ot) - return false; - StringRef s = ot.getValue(); - return s.contains("Tile<") || s.contains("ConvTile<"); - }; - auto buildTileSpec = [&]() -> FailureOr { - auto resMrTy = dyn_cast(op.getType()); - if (!resMrTy) - return failure(); - - const char *roleTok = "TileType::Vec"; - if (auto asAttr = - dyn_cast_or_null(resMrTy.getMemorySpace())) { - switch (asAttr.getAddressSpace()) { - case pto::AddressSpace::VEC: - roleTok = "TileType::Vec"; - break; - case pto::AddressSpace::MAT: - roleTok = "TileType::Mat"; - break; - case pto::AddressSpace::LEFT: - roleTok = "TileType::Left"; - break; - case pto::AddressSpace::RIGHT: - roleTok = "TileType::Right"; - break; - case pto::AddressSpace::ACC: - roleTok = "TileType::Acc"; - break; - case pto::AddressSpace::BIAS: - roleTok = "TileType::Bias"; - break; - case pto::AddressSpace::SCALING: - roleTok = "TileType::Scaling"; - break; - case pto::AddressSpace::GM: - case pto::AddressSpace::Zero: - roleTok = "TileType::Vec"; - break; - } + if (elemTy.isInteger(64)) { + return cast(elemTy).isUnsigned() ? "uint64_t" : "int64_t"; } + return "float"; + }; - Type elemTy = resMrTy.getElementType(); - Type emitElemTy = getTypeConverter()->convertType(elemTy); - if (!emitElemTy) - return failure(); - auto emitElemOpaque = dyn_cast(emitElemTy); - if (!emitElemOpaque) - return failure(); - std::string elemTypeStr = emitElemOpaque.getValue().str(); - - if (resMrTy.getRank() < 2) - return failure(); - int64_t rows = resMrTy.getDimSize(0); - int64_t cols = resMrTy.getDimSize(1); - if (rows == ShapedType::kDynamic || cols == ShapedType::kDynamic) - return failure(); + // ------------------------------------------------------------------------- + // Part 1: 指针偏移计算 (Runtime Pointer Arithmetic) + // ------------------------------------------------------------------------- + + // 准备类型: unsigned + Type u32Ty = emitc::OpaqueType::get(ctx, "unsigned"); + + // Helper: 创建 unsigned 常量 + auto mkU32 = [&](int64_t v) -> Value { + return rewriter.create( + loc, u32Ty, emitc::OpaqueAttr::get(ctx, std::to_string(v))); + }; - std::string blTok = "BLayout::RowMajor"; - if (auto blAttr = dyn_cast(configAttr.getBLayout())) { - if (static_cast(blAttr.getValue()) == 1) - blTok = "BLayout::ColMajor"; - } - pto::BLayout blayout = getTileBufBLayoutValue(configAttr); - - if (isSubView) { - auto subMrTy = dyn_cast(op.getSource().getType()); - auto subViewOp = op.getSource().getDefiningOp(); - if (subMrTy && subMrTy.getRank() >= 2 && subViewOp) { - int64_t subRows = subMrTy.getDimSize(0); - int64_t subCols = subMrTy.getDimSize(1); - SmallVector inheritedStrides; - int64_t inheritedOffset = ShapedType::kDynamic; - - if (!pto::isPTOFloat4PackedType(elemTy) && - subRows != ShapedType::kDynamic && - subCols != ShapedType::kDynamic && - succeeded(getStridesAndOffset(subMrTy, inheritedStrides, - inheritedOffset)) && - inheritedStrides.size() >= 2) { - int64_t childRowStride = 0; - int64_t childColStride = 0; - bool sameStrides = getTilePointerStrides( - configAttr, elemTy, subRows, subCols, childRowStride, - childColStride); - sameStrides = sameStrides && - inheritedStrides[0] == childRowStride && - inheritedStrides[1] == childColStride; - if (sameStrides) { - rows = subRows; - cols = subCols; - } - } - } + // Helper: 将 OpFoldResult 转为 EmitC Value (用于计算) + auto ofrToEmitCValue = [&](OpFoldResult ofr) -> Value { + if (auto v = ofr.dyn_cast()) { + Value rv = rewriter.getRemappedValue(v); + // 如果类型不匹配,插入 Cast + if (rv.getType() != u32Ty) + return rewriter.create(loc, u32Ty, rv).getResult(); + return rv; } - - std::string slTok = "SLayout::NoneBox"; - if (auto slAttr = dyn_cast(configAttr.getSLayout())) { - int32_t slVal = static_cast(slAttr.getValue()); - slTok = (slVal == 1) ? "SLayout::RowMajor" - : (slVal == 2) ? "SLayout::ColMajor" - : "SLayout::NoneBox"; + if (auto attr = ofr.dyn_cast()) { + if (auto ia = dyn_cast(attr)) + return mkU32(ia.getValue().getSExtValue()); } + return mkU32(0); + }; - int32_t fractal = 512; - if (auto frAttr = dyn_cast(configAttr.getSFractalSize())) - fractal = frAttr.getInt(); - - std::string padTok = "PadValue::Null"; - if (auto padAttr = dyn_cast(configAttr.getPad())) { - switch (static_cast(padAttr.getValue())) { - case 1: - padTok = "PadValue::Zero"; - break; - case 2: - padTok = "PadValue::Max"; - break; - case 3: - padTok = "PadValue::Min"; - break; - default: - padTok = "PadValue::Null"; - break; - } - } + // 1. 获取 Source 的 Strides (支持动态 Stride 收集) + SmallVector sourceStrides; - std::string compactTok = "CompactMode::Null"; - if (auto compactAttr = dyn_cast(configAttr.getCompactMode())) { - switch (static_cast(compactAttr.getValue())) { - case 1: - compactTok = "CompactMode::Normal"; - break; - case 2: - compactTok = "CompactMode::RowPlusOne"; - break; - default: - compactTok = "CompactMode::Null"; - break; + if (auto rc = op.getSource().getDefiningOp()) { + sourceStrides = rc.getMixedStrides(); + } else { + SmallVector strideInts; + int64_t offset = ShapedType::kDynamic; + bool useTypeStrides = succeeded(getStridesAndOffset(srcType, strideInts, offset)); + (void)offset; + if (useTypeStrides) { + for (int64_t s : strideInts) { + if (s == ShapedType::kDynamic) + useTypeStrides = false; + } } - } - - std::string vrowTok, vcolTok; - bool useConstructor = false; - bool rowIsDynamic = false; - bool colIsDynamic = false; - SmallVector constructorArgs; - - Value vRow = op.getValidRow(); - Value vCol = op.getValidCol(); - Value vRowEmitC = adaptor.getValidRow(); - Value vColEmitC = adaptor.getValidCol(); - bool forceDynamicValid = op->hasAttr(kForceDynamicValidShapeAttrName); - int64_t cRow = 0, cCol = 0; - bool rowIsConst = vRow && getIndexConst(vRow, cRow); - bool colIsConst = vCol && getIndexConst(vCol, cCol); - - auto makeCtorDimValue = [&](Value emitted, int64_t fallback) -> Value { - if (emitted) - return emitted; - return makeEmitCIntConstant( - rewriter, loc, emitc::OpaqueType::get(ctx, "int32_t"), fallback); - }; - auto maybeScaleDynamicValid = [&](Value emitted, int dimIdx) -> Value { - if (!emitted || !pto::isPTOFloat4PackedType(elemTy)) - return emitted; - int packedDim = blayout == pto::BLayout::ColMajor ? 0 : 1; - if (dimIdx != packedDim) - return emitted; - auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); - Value two = makeEmitCIntConstant(rewriter, loc, i32Ty, 2); - return rewriter.create(loc, i32Ty, emitted, two).getResult(); - }; - - if (forceDynamicValid) { - vrowTok = "-1"; - vcolTok = "-1"; - useConstructor = true; - constructorArgs.push_back( - makeCtorDimValue(maybeScaleDynamicValid(vRowEmitC, 0), - renderTileTemplateDim(rowIsConst ? cRow : rows, - elemTy, blayout, 0))); - constructorArgs.push_back( - makeCtorDimValue(maybeScaleDynamicValid(vColEmitC, 1), - renderTileTemplateDim(colIsConst ? cCol : cols, - elemTy, blayout, 1))); - } else { - if (rowIsConst) { - vrowTok = std::to_string( - renderTileTemplateDim(cRow, elemTy, blayout, 0)); - } else if (vRow) { - vrowTok = "-1"; - rowIsDynamic = true; - useConstructor = true; + if (useTypeStrides) { + for (int64_t s : strideInts) { + sourceStrides.push_back(rewriter.getIndexAttr(s)); + } } else { - vrowTok = std::to_string( - renderTileTemplateDim(rows, elemTy, blayout, 0)); + // Fallback: Compact Layout + auto shape = srcType.getShape(); + int64_t current = 1; + sourceStrides.resize(rank); + for (int i = rank - 1; i >= 0; --i) { + sourceStrides[i] = rewriter.getIndexAttr(current); + if (shape[i] != ShapedType::kDynamic) current *= shape[i]; + } } + } - if (colIsConst) { - vcolTok = std::to_string( - renderTileTemplateDim(cCol, elemTy, blayout, 1)); - } else if (vCol) { - vcolTok = "-1"; - colIsDynamic = true; - useConstructor = true; - } else { - vcolTok = std::to_string( - renderTileTemplateDim(cols, elemTy, blayout, 1)); - } + // 2. 计算运行时 Offset + auto staticOffsets = op.getStaticOffsets(); + auto dynamicOffsets = adaptor.getOffsets(); + int dynOffIdx = 0; + Value totalOffset = mkU32(0); - if (useConstructor) { - if (rowIsDynamic && vRowEmitC) - constructorArgs.push_back(maybeScaleDynamicValid(vRowEmitC, 0)); - if (colIsDynamic && vColEmitC) - constructorArgs.push_back(maybeScaleDynamicValid(vColEmitC, 1)); + for (int i = 0; i < rank; ++i) { + // A. 获取 Offset + Value offVal; + if (staticOffsets[i] == ShapedType::kDynamic) { + Value rawDyn = dynamicOffsets[dynOffIdx++]; + offVal = rewriter.create(loc, u32Ty, rawDyn); + } else { + offVal = mkU32(staticOffsets[i]); } - } - - std::string tileTypeStr = std::string("Tile<") + roleTok + ", " + - elemTypeStr + ", " + - std::to_string(renderTileTemplateDim( - rows, elemTy, blayout, 0)) + - ", " + - std::to_string(renderTileTemplateDim( - cols, elemTy, blayout, 1)) + - ", " + blTok + - ", " + vrowTok + ", " + vcolTok + ", " + slTok + - ", " + std::to_string(fractal) + ", " + padTok + - ", " + compactTok + - ">"; - return TileBuildSpec{tileTypeStr, useConstructor, constructorArgs}; - }; - - auto buildTileValue = [&](const TileBuildSpec &spec, - bool forceDeclaration = false) -> Value { - auto tileType = emitc::OpaqueType::get(ctx, spec.tileTypeStr); - if (spec.useConstructor && !forceDeclaration) { - return rewriter - .create(loc, tileType, spec.tileTypeStr, - ArrayAttr{}, ArrayAttr{}, - ValueRange(spec.constructorArgs)) - .getResult(0); - } - - return rewriter - .create(loc, tileType, emitc::OpaqueAttr::get(ctx, "")) - .getResult(); - }; - - auto emitElemTypeToString = [&](Type elemTy) -> std::string { - return getEmitCScalarTypeToken(elemTy); - }; - auto buildIntegralAddress = [&](Value sourceValue) -> FailureOr { - auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); - auto rcU64 = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); - - Value rawPtr = sourceValue; - if (auto ot = dyn_cast(sourceValue.getType())) { - StringRef tyStr = ot.getValue(); - if (tyStr.contains("Tile<") || tyStr.contains("ConvTile<")) { - auto srcMrTy = dyn_cast(op.getSource().getType()); - if (!srcMrTy) - return failure(); - std::string elemTok = emitElemTypeToString(srcMrTy.getElementType()); - pto::AddressSpace as = pto::AddressSpace::GM; - if (auto asAttr = - dyn_cast_or_null(srcMrTy.getMemorySpace())) - as = asAttr.getAddressSpace(); - rawPtr = materializeTileDataValue(rewriter, loc, sourceValue, as, - elemTok); + // B. 获取 Stride (用于指针计算) + Value strideVal = mkU32(1); + if (i < (int)sourceStrides.size()) { + strideVal = ofrToEmitCValue(sourceStrides[i]); } - } - - if (isSetFFTsPointerLikeType(rawPtr.getType())) { - return rewriter - .create(loc, u64Ty, "reinterpret_cast", - ArrayAttr{}, rcU64, ValueRange{rawPtr}) - .getResult(0); - } - - if (rawPtr.getType() == u64Ty) - return rawPtr; - return rewriter.create(loc, u64Ty, rawPtr).getResult(); - }; - - if (op.getSource().getDefiningOp()) { - FailureOr tileSpec = buildTileSpec(); - if (failed(tileSpec)) - return failure(); - rewriter.replaceOp(op, buildTileValue(*tileSpec)); - return success(); - } - - Value tileCandidate = peelAllCasts(adaptor.getSource()); - if (viewSemantics && viewSemantics.getValue() == "bitcast" && - isTileLike(tileCandidate)) { - FailureOr tileSpec = buildTileSpec(); - if (failed(tileSpec)) - return failure(); - Value dstTile = buildTileValue(*tileSpec); - FailureOr addr = buildIntegralAddress(tileCandidate); - if (failed(addr)) - return failure(); - rewriter.create(loc, TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dstTile, *addr}); - rewriter.replaceOp(op, dstTile); - return success(); + // C. 累加 + Value term = rewriter.create(loc, u32Ty, offVal, strideVal); + totalOffset = rewriter.create(loc, u32Ty, totalOffset, term); } - if (viewSemantics && viewSemantics.getValue() == "treshape" && - isTileLike(tileCandidate)) { - FailureOr tileSpec = buildTileSpec(); - if (failed(tileSpec)) - return failure(); - Value dstTile = buildTileValue(*tileSpec, /*forceDeclaration=*/true); - - rewriter.create(loc, TypeRange{}, "TRESHAPE", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dstTile, tileCandidate}); - rewriter.replaceOp(op, dstTile); - return success(); + // 3. 生成新指针 + // + // NOTE: Some toolchains may materialize kernel pointer params as `void*` even + // when the underlying element type is i16. Pointer arithmetic on `void*` + // is ill-formed in C++, so we explicitly cast to a typed pointer for i16. + Value sourcePtr = adaptor.getSource(); + Value tileCandidate = sourcePtr; + if (auto castOp = sourcePtr.getDefiningOp()) { + tileCandidate = castOp.getOperand(); + } else if (auto uc = + sourcePtr.getDefiningOp()) { + tileCandidate = uc.getOperand(0); } - - // Subview origins are kept distinct from generic tile rebinding: - // even when source/destination C++ tile types match, subview may carry - // shifted base address semantics and should materialize a fresh handle. - if (isSubView) { - FailureOr tileSpec = buildTileSpec(); - if (failed(tileSpec)) - return failure(); - Value dstTile = buildTileValue(*tileSpec); - FailureOr addr = buildIntegralAddress(tileCandidate); - if (failed(addr)) - return failure(); - - rewriter.create(loc, TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dstTile, *addr}); - rewriter.replaceOp(op, dstTile); - return success(); + if (auto ot = dyn_cast(tileCandidate.getType())) { + auto tyStr = ot.getValue(); + if (tyStr.find("Tile<") != std::string::npos || + tyStr.find("ConvTile<") != std::string::npos) { + std::string elemTok = elemTypeToString(srcType.getElementType()); + pto::AddressSpace as = pto::AddressSpace::GM; + if (auto asAttr = + dyn_cast_or_null(srcType.getMemorySpace())) + as = asAttr.getAddressSpace(); + sourcePtr = + materializeTileDataValue(rewriter, loc, tileCandidate, as, elemTok); + if (tileDataReturnsIntegralAddress(as)) + sourcePtr = + materializeAddressAsPointer(rewriter, loc, sourcePtr, as, elemTok); + } } + Value newPtr; + { + auto resTy = mlir::cast(op.getResult().getType()); + Type elemTy = resTy.getElementType(); + if (elemTy.isInteger(16)) { + std::string castElemTypeStr = "int16_t"; + if (cast(elemTy).isUnsigned()) + castElemTypeStr = "uint16_t"; - // Generic tile-to-tile rebind path: preserve the same backing storage and - // rebuild a sibling tile with updated metadata/valid dims. - if (isTileLike(tileCandidate)) { - FailureOr tileSpec = buildTileSpec(); - if (failed(tileSpec)) - return failure(); - - if (!tileSpec->useConstructor) { - if (auto srcTy = dyn_cast(tileCandidate.getType())) { - if (srcTy.getValue() == tileSpec->tileTypeStr) { - rewriter.replaceOp(op, tileCandidate); - return success(); + std::string qualifier = "__gm__"; + if (Attribute ms = srcType.getMemorySpace()) { + if (auto ptoAttr = dyn_cast(ms)) { + qualifier = addrSpaceQualifier(ptoAttr.getAddressSpace()); } } - } - - Value dstTile = buildTileValue(*tileSpec); - FailureOr addr = buildIntegralAddress(tileCandidate); - if (failed(addr)) - return failure(); - rewriter.create(loc, TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dstTile, *addr}); - rewriter.replaceOp(op, dstTile); - return success(); + auto typedPtrTy = emitc::PointerType::get( + emitc::OpaqueType::get(ctx, qualifier + " " + castElemTypeStr)); + Value typedSourcePtr = rewriter.create(loc, typedPtrTy, sourcePtr); + newPtr = rewriter.create(loc, typedPtrTy, typedSourcePtr, totalOffset); + } else { + newPtr = rewriter.create(loc, sourcePtr.getType(), sourcePtr, totalOffset); + } } - SmallVector physAddrs; - Value source = op.getSource(); - while (auto castOp = source.getDefiningOp()) - source = castOp.getOperand(0); - - if (auto upstreamCast = source.getDefiningOp()) { - auto upstreamOperands = upstreamCast.getAddrs(); - physAddrs.append(upstreamOperands.begin(), upstreamOperands.end()); - } else { - physAddrs.push_back(adaptor.getSource()); + // ------------------------------------------------------------------------- + // Part 2: For non-GM memrefs, keep pointer (no GlobalTensor). + // ------------------------------------------------------------------------- + bool isGlobal = true; + if (auto asAttr = dyn_cast_or_null(srcType.getMemorySpace())) { + auto as = asAttr.getAddressSpace(); + isGlobal = (as == pto::AddressSpace::GM || as == pto::AddressSpace::Zero); + } + if (!isGlobal) { + Type dstTy = getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + if (newPtr.getType() != dstTy) + newPtr = rewriter.create(loc, dstTy, newPtr); + rewriter.replaceOp(op, newPtr); + return success(); } - Value vRow = op.getValidRow(); - Value vCol = op.getValidCol(); - - auto newCast = rewriter.create( - loc, op.getType(), physAddrs, vRow ? vRow : Value(), - vCol ? vCol : Value(), configAttr); - if (viewSemantics) - newCast->setAttr("pto.view_semantics", viewSemantics); - if (op->hasAttr(kForceDynamicValidShapeAttrName)) - newCast->setAttr(kForceDynamicValidShapeAttrName, - op->getAttr(kForceDynamicValidShapeAttrName)); - rewriter.replaceOp(op, newCast.getResult()); - - return success(); - } -}; + // ------------------------------------------------------------------------- + // Part 3: 生成 GlobalTensor 类型 (Shape/Stride Template Generation) + // ------------------------------------------------------------------------- + + // When emitting C++ with `declareVariablesAtTop`, value declarations are + // hoisted before body statements. Avoid introducing local `using` aliases + // for templated types (Shape/Stride/GlobalTensor) because those aliases + // would appear after the hoisted declarations and break compilation + // (`unknown type name`). + // + // Instead, use the fully spelled template types as EmitC opaque types. -struct PTOAllocTileToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + auto resTy = mlir::cast(op.getResult().getType()); + + // 1. 解析具体元素类型 + std::string elemTypeStr = getElemTypeStringForGT(resTy.getElementType()); - LogicalResult matchAndRewrite(pto::AllocTileOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - MLIRContext *ctx = rewriter.getContext(); - auto tileTy = cast(op.getResult().getType()); - auto tileTypeString = getEmitCTileTypeString(tileTy); - if (!tileTypeString) - return rewriter.notifyMatchFailure( - op, "only rank-2 alloc_tile handles can be converted to EmitC"); - - Type convertedTy = getTypeConverter()->convertType(tileTy); - if (!convertedTy) - convertedTy = emitc::OpaqueType::get(ctx, *tileTypeString); - - auto validShape = tileTy.getValidShape(); - bool hasDynamicValidDim = - llvm::any_of(validShape, [](int64_t dim) { return dim < 0; }); - bool useConstructor = hasDynamicValidDim; - - SmallVector constructorArgs; - if (useConstructor) { - Type elemTy = tileTy.getElementType(); - pto::BLayout blayout = getTileBufBLayoutValue(tileTy.getConfigAttr()); - auto maybeScaleDynamicValid = [&](Value emitted, int dimIdx) -> Value { - if (!emitted || !pto::isPTOFloat4PackedType(elemTy)) - return emitted; - int packedDim = blayout == pto::BLayout::ColMajor ? 0 : 1; - if (dimIdx != packedDim) - return emitted; - auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); - Value two = makeEmitCIntConstant(rewriter, loc, i32Ty, 2); - return rewriter.create(loc, i32Ty, emitted, two) - .getResult(); - }; - - if (validShape.size() > 0 && validShape[0] < 0) { - Value validRow = adaptor.getValidRow(); - if (!validRow) - return rewriter.notifyMatchFailure( - op, "dynamic alloc_tile valid row must have an operand"); - if (validRow) - validRow = peelUnrealized(validRow); - constructorArgs.push_back(maybeScaleDynamicValid(validRow, 0)); - } - if (validShape.size() > 1 && validShape[1] < 0) { - Value validCol = adaptor.getValidCol(); - if (!validCol) - return rewriter.notifyMatchFailure( - op, "dynamic alloc_tile valid col must have an operand"); - if (validCol) - validCol = peelUnrealized(validCol); - constructorArgs.push_back(maybeScaleDynamicValid(validCol, 1)); + // 2. 生成 Shape 模板参数,之后会右对齐有效维度并补齐到 5 维(高维填 1) + SmallVector shapeParamsVec; + SmallVector sizeValues; // 每个维度对应的运行时 size(统一为 unsigned) + auto resShape = resTy.getShape(); + auto mixedSizes = op.getMixedSizes(); + sizeValues.reserve(rank); + for (int i = 0; i < resTy.getRank(); ++i) { + if (resShape[i] == ShapedType::kDynamic) { + shapeParamsVec.push_back(-1); + } else { + shapeParamsVec.push_back(resShape[i]); } + // size 值:优先从 op.getMixedSizes() 取(可动态/静态),否则退化为类型里的静态 shape。 + if (i < (int)mixedSizes.size()) + sizeValues.push_back(ofrToEmitCValue(mixedSizes[i])); + else + sizeValues.push_back( + mkU32(resShape[i] == ShapedType::kDynamic ? 1 : resShape[i])); } - Value tile; - if (useConstructor) { - tile = rewriter - .create( - loc, convertedTy, *tileTypeString, ArrayAttr{}, - ArrayAttr{}, ValueRange(constructorArgs)) - .getResult(0); - } else { - tile = - rewriter - .create( - loc, convertedTy, emitc::OpaqueAttr::get(ctx, "")) - .getResult(); - } + // 3. 生成 Stride 模板参数 + 运行时 stride 值(考虑 subview step) + SmallVector strideTemplateVec; + SmallVector strideValues; // 每个维度对应的运行时 stride(统一为 unsigned) + strideTemplateVec.reserve(rank); + strideValues.reserve(rank); + auto subViewSteps = op.getMixedStrides(); + for (int i = 0; i < rank; ++i) { + OpFoldResult srcStrideOfr = + (i < (int)sourceStrides.size()) ? sourceStrides[i] + : rewriter.getIndexAttr(1); + OpFoldResult stepOfr = (i < (int)subViewSteps.size()) + ? subViewSteps[i] + : rewriter.getIndexAttr(1); - Value addr = adaptor.getAddr(); - if (addr) { - addr = peelUnrealized(addr); - auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); - if (isa(addr.getType()) || - (isa(addr.getType()) && - cast(addr.getType()).getValue().ends_with("*"))) { - auto rcU64 = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); - addr = rewriter - .create(loc, u64Ty, "reinterpret_cast", - ArrayAttr{}, rcU64, - ValueRange{addr}) - .getResult(0); - } else if (addr.getType() != u64Ty) { - addr = rewriter.create(loc, u64Ty, addr).getResult(); + auto srcStatic = extractStaticInt(srcStrideOfr); + auto stepStatic = extractStaticInt(stepOfr); + if (srcStatic && stepStatic) { + int64_t finalStride = (*srcStatic) * (*stepStatic); + strideTemplateVec.push_back(finalStride); + strideValues.push_back(mkU32(finalStride)); + continue; } - rewriter.create(loc, TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{tile, addr}); + strideTemplateVec.push_back(-1); + Value srcV = ofrToEmitCValue(srcStrideOfr); + Value stepV = ofrToEmitCValue(stepOfr); + // 尽量避免乘以 1 生成冗余指令 + if (stepStatic && *stepStatic == 1) + strideValues.push_back(srcV); + else if (srcStatic && *srcStatic == 1) + strideValues.push_back(stepV); + else + strideValues.push_back( + rewriter.create(loc, u32Ty, srcV, stepV)); } - rewriter.replaceOp(op, tile); - return success(); - } -}; - -static FailureOr -createEmitCTileVariable(ConversionPatternRewriter &rewriter, Location loc, - const TypeConverter *typeConverter, - pto::TileBufType tileTy) { - auto tileTypeString = getEmitCTileTypeString(tileTy); - if (!tileTypeString) - return failure(); - - Type convertedTy = typeConverter->convertType(tileTy); - if (!convertedTy) - convertedTy = emitc::OpaqueType::get(rewriter.getContext(), *tileTypeString); - - return rewriter - .create( - loc, convertedTy, emitc::OpaqueAttr::get(rewriter.getContext(), "")) - .getResult(); -} - -struct PTOTReshapeToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::TReshapeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto tileTy = dyn_cast(op.getResult().getType()); - if (!tileTy) - return failure(); - - FailureOr dst = - createEmitCTileVariable(rewriter, op.getLoc(), getTypeConverter(), tileTy); - if (failed(dst)) - return failure(); - - Value src = peelUnrealized(adaptor.getSrc()); - if (auto castOp = src.getDefiningOp()) - src = castOp.getOperand(); - - rewriter.create(op.getLoc(), TypeRange{}, "TRESHAPE", - ArrayAttr{}, ArrayAttr{}, - ValueRange{*dst, src}); - rewriter.replaceOp(op, *dst); - return success(); - } -}; - -struct PTOBitcastToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(pto::BitcastOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto dstTy = dyn_cast(op.getResult().getType()); - auto srcTy = dyn_cast(op.getSrc().getType()); - if (!dstTy || !srcTy) - return failure(); - - FailureOr dst = - createEmitCTileVariable(rewriter, op.getLoc(), getTypeConverter(), dstTy); - if (failed(dst)) - return failure(); + // 3.1 右对齐到 5 维:shape 补 1;已有维度继承原 stride; + // 被补出来的高维按“紧密升维”规则连续推导:stride[i] = shape[i+1] * stride[i+1] + SmallVector finalShape; + SmallVector finalStride; + buildGlobalTensorShapeAndStride(shapeParamsVec, strideTemplateVec, + finalShape, finalStride); + Value oneU32 = mkU32(1); + SmallVector finalShapeValues(5, oneU32); + SmallVector finalStrideValues(5, oneU32); + int shift = 5 - rank; - Value src = peelUnrealized(adaptor.getSrc()); - if (auto castOp = src.getDefiningOp()) - src = castOp.getOperand(); - - pto::AddressSpace as = pto::AddressSpace::GM; - if (auto asAttr = - dyn_cast_or_null(srcTy.getMemorySpace())) - as = asAttr.getAddressSpace(); - std::string elemTok = getEmitCScalarTypeToken(srcTy.getElementType()); - - Value rawPtr = materializeTileDataValue(rewriter, op.getLoc(), src, as, elemTok); - auto u64Ty = emitc::OpaqueType::get(rewriter.getContext(), "uint64_t"); - Value addr = rawPtr; - if (isSetFFTsPointerLikeType(rawPtr.getType())) { - auto rcU64 = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(rewriter.getContext(), - "uint64_t")}); - addr = rewriter - .create(op.getLoc(), u64Ty, - "reinterpret_cast", ArrayAttr{}, - rcU64, ValueRange{rawPtr}) - .getResult(0); - } else if (addr.getType() != u64Ty) { - addr = rewriter.create(op.getLoc(), u64Ty, addr).getResult(); + // 先放入原始 shape/stride(保持用户提供的值) + for (int i = 0; i < rank && i < 5; ++i) { + finalShapeValues[shift + i] = sizeValues[i]; + finalStrideValues[shift + i] = strideValues[i]; } - rewriter.create(op.getLoc(), TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{*dst, addr}); - rewriter.replaceOp(op, *dst); - return success(); - } -}; - -struct PTOMaterializeTileToEmitC - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - static bool isTileLike(Value v) { - auto ot = dyn_cast(v.getType()); - if (!ot) - return false; - StringRef s = ot.getValue(); - return s.contains("Tile<") || s.contains("ConvTile<"); - } - - LogicalResult matchAndRewrite(pto::MaterializeTileOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - MLIRContext *ctx = rewriter.getContext(); - auto tileTy = cast(op.getResult().getType()); - auto tileTypeString = getEmitCTileTypeString(tileTy); - if (!tileTypeString) - return rewriter.notifyMatchFailure( - op, "only rank-2 tile_buf handles can be materialized to EmitC"); - - Type convertedTy = getTypeConverter()->convertType(tileTy); - if (!convertedTy) - convertedTy = emitc::OpaqueType::get(ctx, *tileTypeString); - - Value source = peelUnrealized(adaptor.getSource()); - if (auto castOp = source.getDefiningOp()) - source = castOp.getOperand(); - - auto viewSemantics = op->getAttrOfType("pto.view_semantics"); - bool forceDynamicValid = op->hasAttr(kForceDynamicValidShapeAttrName); - bool isReshape = viewSemantics && viewSemantics.getValue() == "treshape"; - bool isSubview = viewSemantics && viewSemantics.getValue() == "subview"; - bool sourceIsDeclaredTile = - op.getSource().getDefiningOp(); - - auto createTileValue = [&]() -> Value { - SmallVector constructorArgs; - bool useConstructor = false; - pto::BLayout blayout = getTileBufBLayoutValue(tileTy.getConfigAttr()); - Type elemTy = tileTy.getElementType(); - auto shape = tileTy.getShape(); - auto validShape = tileTy.getValidShape(); - - auto makeCtorDimValue = [&](Value emitted, int64_t fallback) -> Value { - if (emitted) - return emitted; - return makeEmitCIntConstant( - rewriter, loc, emitc::OpaqueType::get(ctx, "int32_t"), fallback); - }; - auto maybeScaleDynamicValid = [&](Value emitted, int dimIdx) -> Value { - if (!emitted || !pto::isPTOFloat4PackedType(elemTy)) - return emitted; - int packedDim = blayout == pto::BLayout::ColMajor ? 0 : 1; - if (dimIdx != packedDim) - return emitted; - auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); - Value two = makeEmitCIntConstant(rewriter, loc, i32Ty, 2); - return rewriter.create(loc, i32Ty, emitted, two).getResult(); - }; - auto fallbackDim = [&](int dimIdx) { - return renderTileTemplateDim(shape[dimIdx], elemTy, blayout, dimIdx); - }; - - if (forceDynamicValid) { - useConstructor = true; - constructorArgs.push_back(makeCtorDimValue( - maybeScaleDynamicValid(adaptor.getValidRow(), 0), fallbackDim(0))); - constructorArgs.push_back(makeCtorDimValue( - maybeScaleDynamicValid(adaptor.getValidCol(), 1), fallbackDim(1))); + // 从低维到高维倒推补齐 stride(仅对补出来的前置维度生效) + for (int i = 3; i >= 0; --i) { + // 如果该维已由原始 rank 覆盖,则保持原值 + if (i >= shift) + continue; + if (finalStride[i] != -1) { + finalStrideValues[i] = mkU32(finalStride[i]); + continue; + } + // 动态推导:stride[i] = shape[i+1] * stride[i+1] + if (finalShape[i + 1] == 1) { + finalStrideValues[i] = finalStrideValues[i + 1]; } else { - if (validShape[0] == ShapedType::kDynamic) { - useConstructor = true; - constructorArgs.push_back(makeCtorDimValue( - maybeScaleDynamicValid(adaptor.getValidRow(), 0), fallbackDim(0))); - } - if (validShape[1] == ShapedType::kDynamic) { - useConstructor = true; - constructorArgs.push_back(makeCtorDimValue( - maybeScaleDynamicValid(adaptor.getValidCol(), 1), fallbackDim(1))); - } + finalStrideValues[i] = rewriter.create( + loc, u32Ty, finalShapeValues[i + 1], finalStrideValues[i + 1]); } + } - if (useConstructor) { - return rewriter - .create(loc, convertedTy, *tileTypeString, - ArrayAttr{}, ArrayAttr{}, - ValueRange(constructorArgs)) - .getResult(0); - } + std::string shapeParams = joinIntTemplateParams(finalShape); + std::string strideParams = joinIntTemplateParams(finalStride); - return rewriter - .create(loc, convertedTy, - emitc::OpaqueAttr::get(ctx, "")) - .getResult(); - }; + // Spelled-out C++ types. + std::string shapeCppType = "pto::Shape<" + shapeParams + ">"; + std::string strideCppType = "pto::Stride<" + strideParams + ">"; - if (!isSubview && !forceDynamicValid && isTileLike(source)) { - if (auto srcTy = dyn_cast(source.getType())) { - if (srcTy.getValue() == *tileTypeString) { - rewriter.replaceOp(op, source); - return success(); - } - } - } + // 3.0 Layout: prefer the attribute from InferPTOLayout; only fall back to + // local inference when the pass is disabled. + std::string layoutEnum = "pto::Layout::ND"; + if (auto layout = resolveLayoutForGlobalTensor(op, op.getSource())) { + layoutEnum = layoutToEmitCString(*layout); + } else { + bool allStatic = + llvm::all_of(finalShape, [](int64_t value) { return value != -1; }) && + llvm::all_of(finalStride, [](int64_t value) { return value != -1; }); - Value tile = createTileValue(); - if (sourceIsDeclaredTile) { - rewriter.replaceOp(op, tile); - return success(); - } + int layoutTag = 0; // ND + auto elemBytes = 4; // default float + if (elemTypeStr.find("half") != std::string::npos || + elemTypeStr.find("f16") != std::string::npos || + elemTypeStr.find("bf16") != std::string::npos) + elemBytes = 2; + else if (elemTypeStr.find("double") != std::string::npos || + elemTypeStr.find("f64") != std::string::npos) + elemBytes = 8; - if (isReshape && isTileLike(source)) { - rewriter.create(loc, TypeRange{}, "TRESHAPE", - ArrayAttr{}, ArrayAttr{}, - ValueRange{tile, source}); - rewriter.replaceOp(op, tile); - return success(); - } + if (allStatic) { + if (finalShape[2] == 16 && + finalShape[2] * finalShape[3] * elemBytes == 512 && + finalStride[4] == 1 && finalStride[3] == finalShape[4]) { + layoutTag = 2; // NZ + } else { + bool isRow = finalStride[4] == 1; + for (int i = 3; i >= 0; --i) + isRow &= (finalStride[i] == + multiplyOrDynamic(finalStride[i + 1], finalShape[i + 1])); + bool isCol = finalStride[0] == 1; + for (int i = 0; i < 4; ++i) + isCol &= (finalStride[i + 1] == + multiplyOrDynamic(finalStride[i], finalShape[i])); + if (isCol) + layoutTag = 1; // DN + else + layoutTag = isRow ? 0 : 0; // fallback ND + } + } - pto::AddressSpace as = pto::AddressSpace::GM; - if (auto asAttr = - dyn_cast_or_null(tileTy.getMemorySpace())) - as = asAttr.getAddressSpace(); - std::string elemTok = getEmitCScalarTypeToken(tileTy.getElementType()); - - Value rawPtr = source; - if (isTileLike(rawPtr)) - rawPtr = materializeTileDataValue(rewriter, loc, rawPtr, as, elemTok); - - auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); - Value addr = rawPtr; - if (isSetFFTsPointerLikeType(rawPtr.getType())) { - auto rcU64 = - rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); - addr = rewriter - .create(loc, u64Ty, "reinterpret_cast", - ArrayAttr{}, rcU64, - ValueRange{rawPtr}) - .getResult(0); - } else if (rawPtr.getType() != u64Ty) { - addr = rewriter.create(loc, u64Ty, rawPtr).getResult(); + if (layoutTag == 1) + layoutEnum = "pto::Layout::DN"; + else if (layoutTag == 2) + layoutEnum = "pto::Layout::NZ"; } + // GlobalTensor takes a Layout non-type template parameter; directly use the + // enum constant. - rewriter.create(loc, TypeRange{}, "TASSIGN", - ArrayAttr{}, ArrayAttr{}, - ValueRange{tile, addr}); - rewriter.replaceOp(op, tile); - return success(); - } -}; -// ============================================================================= -// Arith CmpI -> EmitC Cmp -// ============================================================================= -class ArithCmpIToEmitC : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); + // ------------------------------------------------------------------------- + // Part 3: 显式对象实例化 (Explicit Object Instantiation) + // ------------------------------------------------------------------------- - // 将 arith.cmpi 转换为 emitc.cmp - // 映射 Predicate: eq -> equal, slt -> less, etc. - emitc::CmpPredicate emitcPred = emitc::CmpPredicate::eq; - const bool isUnsignedPred = - op.getPredicate() == arith::CmpIPredicate::ult || - op.getPredicate() == arith::CmpIPredicate::ule || - op.getPredicate() == arith::CmpIPredicate::ugt || - op.getPredicate() == arith::CmpIPredicate::uge; - switch (op.getPredicate()) { - case arith::CmpIPredicate::eq: emitcPred = emitc::CmpPredicate::eq; break; - case arith::CmpIPredicate::ne: emitcPred = emitc::CmpPredicate::ne; break; - case arith::CmpIPredicate::slt: emitcPred = emitc::CmpPredicate::lt; break; - case arith::CmpIPredicate::sle: emitcPred = emitc::CmpPredicate::le; break; - case arith::CmpIPredicate::sgt: emitcPred = emitc::CmpPredicate::gt; break; - case arith::CmpIPredicate::sge: emitcPred = emitc::CmpPredicate::ge; break; - // ... 处理无符号比较 (ult, ule 等) ... - case arith::CmpIPredicate::ult: emitcPred = emitc::CmpPredicate::lt; break; - case arith::CmpIPredicate::ule: emitcPred = emitc::CmpPredicate::le; break; - case arith::CmpIPredicate::ugt: emitcPred = emitc::CmpPredicate::gt; break; - case arith::CmpIPredicate::uge: emitcPred = emitc::CmpPredicate::ge; break; + // A. Instantiate Shape object. + auto shapeTypeOpaque = emitc::OpaqueType::get(ctx, shapeCppType); + SmallVector shapeArgs; + // 从 adaptor.getSizes() 获取 subview 的所有 dynamic sizes + for (Value dynSize : adaptor.getSizes()) { + shapeArgs.push_back(dynSize); + } + + auto shapeInstOp = rewriter.create( + loc, + shapeTypeOpaque, // 返回类型 + shapeCppType, // 调用的“函数名”即类名构造函数 + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange(shapeArgs) + ); + + // B. Instantiate Stride object. + auto strideTypeOpaque = emitc::OpaqueType::get(ctx, strideCppType); + // 仅传入动态 stride 维度对应的值,匹配 pto::Stride 的 N-parameter ctor(并满足其 static_assert)。 + SmallVector strideCtorArgs; + strideCtorArgs.reserve(5); + for (int i = 0; i < 5; ++i) { + if (finalStride[i] == -1) + strideCtorArgs.push_back(finalStrideValues[i]); } + auto strideInstOp = rewriter.create( + loc, strideTypeOpaque, strideCppType, + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange(strideCtorArgs)); - Type resTy = getTypeConverter()->convertType(op.getType()); - if (!resTy) - return failure(); + // C. Instantiate GlobalTensor object (ptr + shape + stride). + std::string gtCppType = "GlobalTensor<" + elemTypeStr + ", " + shapeCppType + + ", " + strideCppType + ", " + layoutEnum + ">"; + auto gtType = emitc::OpaqueType::get(ctx, gtCppType); - Value lhs = adaptor.getLhs(); - Value rhs = adaptor.getRhs(); - if (isUnsignedPred) { - Type opTy = op.getLhs().getType(); - auto intTy = dyn_cast(opTy); - const bool isIndex = isa(opTy); - if (!intTy && !isIndex) - return rewriter.notifyMatchFailure( - op, "expected scalar integer or index operands"); - - const unsigned bitWidth = - intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); - if (bitWidth != 1) { - lhs = castSignlessIntToUnsignedSameWidth(rewriter, loc, lhs, bitWidth); - rhs = castSignlessIntToUnsignedSameWidth(rewriter, loc, rhs, bitWidth); - } - } + // 准备构造参数: [ptr, shape_instance, stride_instance] + SmallVector gtConstructorArgs; + gtConstructorArgs.push_back(newPtr); + gtConstructorArgs.push_back(shapeInstOp.getResult(0)); // 拿到 shape_inst 的 SSA Value + gtConstructorArgs.push_back(strideInstOp.getResult(0)); // 拿到 stride_inst 的 SSA Value - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, - /*resultType=*/resTy, // i1 -> bool/i1 - emitcPred, - lhs, - rhs + gtType, + gtCppType, + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange(gtConstructorArgs) ); + return success(); } }; //===----------------------------------------------------------------------===// -// Section Op Lowering +// Helper: build GlobalTensor from a static MemRef (for TLOAD/TSTORE) //===----------------------------------------------------------------------===// -static bool isA5NoSplitPipeOp(Operation *op) { - if (auto talloc = dyn_cast(op)) - return talloc.getSplit() == 0; - if (auto tpush = dyn_cast(op)) - return tpush.getSplit() == 0; - if (auto tpop = dyn_cast(op)) - return tpop.getSplit() == 0; - if (auto tfree = dyn_cast(op)) - return tfree.getSplit() == 0; - if (auto tpush = dyn_cast(op)) - return tpush.getSplit() == 0; - if (auto tpush = dyn_cast(op)) - return tpush.getSplit() == 0; - if (auto talloc = dyn_cast(op)) - return talloc.getSplit() == 0; - if (auto talloc = dyn_cast(op)) - return talloc.getSplit() == 0; - if (auto tpop = dyn_cast(op)) - return tpop.getSplit() == 0; - if (auto tpop = dyn_cast(op)) - return tpop.getSplit() == 0; - if (auto tfree = dyn_cast(op)) - return tfree.getSplit() == 0; - if (auto tfree = dyn_cast(op)) - return tfree.getSplit() == 0; - return false; + +std::string mlir::pto::getElemTypeStringForGT(Type elemTy) { + return getEmitCScalarTypeToken(elemTy); } -static bool hasExplicitSubblockControl(Operation *op) { - bool hasControl = false; - op->walk([&](Operation *nested) { - if (isa(nested)) { - hasControl = true; - return WalkResult::interrupt(); - } - return WalkResult::advance(); +static bool hasStaticShape(MemRefType mrTy) { + return llvm::none_of(mrTy.getShape(), [](int64_t dim) { + return dim == ShapedType::kDynamic; }); - return hasControl; } -static bool needsA5NoSplitVectorGuard(Operation *op) { - auto arch = getTargetArch(op); - if (arch != PTOArch::A5) - return false; - bool isVectorScope = isa(op); - if (auto func = dyn_cast(op)) { - if (auto kernelKindAttr = - func->getAttrOfType( - FunctionKernelKindAttr::name)) { - isVectorScope = - kernelKindAttr.getKernelKind() == FunctionKernelKind::Vector; +static bool getStaticMemrefLayout(MemRefType mrTy, SmallVectorImpl &strides, + int64_t &offset) { + if (failed(getStridesAndOffset(mrTy, strides, offset))) { + strides.clear(); + int64_t stride = 1; + ArrayRef shape = mrTy.getShape(); + for (int i = static_cast(shape.size()) - 1; i >= 0; --i) { + strides.push_back(stride); + stride *= shape[i]; } + std::reverse(strides.begin(), strides.end()); + offset = 0; } - if (!isVectorScope) - return false; - if (hasExplicitSubblockControl(op)) - return false; - - bool hasNoSplitPipe = false; - op->walk([&](Operation *nested) { - if (!isA5NoSplitPipeOp(nested)) - return WalkResult::advance(); - hasNoSplitPipe = true; - return WalkResult::interrupt(); - }); - return hasNoSplitPipe; + return offset != ShapedType::kDynamic && + llvm::none_of(strides, [](int64_t strideValue) { + return strideValue == ShapedType::kDynamic; + }); } -template -struct SectionToEmitC : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +Value mlir::pto::applyStaticMemrefOffset(ConversionPatternRewriter &rewriter, + Location loc, Value basePtr, + int64_t offset) { + if (offset == 0) + return basePtr; + auto *ctx = rewriter.getContext(); + Type u32Ty = emitc::OpaqueType::get(ctx, "unsigned"); + auto offVal = rewriter.create( + loc, u32Ty, emitc::OpaqueAttr::get(ctx, std::to_string(offset))); + return rewriter.create(loc, basePtr.getType(), basePtr, offVal); +} - std::string getMacroName() const { - if (std::is_same::value) - return "__DAV_CUBE__"; - if (std::is_same::value) - return "__DAV_VEC__"; - return "UNKNOWN_MACRO"; - } +static int getGlobalTensorElementBytes(Type elemTy) { + return static_cast(getPTOStorageElemByteSize(elemTy)); +} - LogicalResult - matchAndRewrite(SectionOpTy op, typename SectionOpTy::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - bool needsNoSplitGuard = needsA5NoSplitVectorGuard(op.getOperation()); +static int64_t multiplyOrDynamic(int64_t lhs, int64_t rhs) { + if (lhs < 0 || rhs < 0) + return -1; + return lhs * rhs; +} - std::string startMacro = "\n#if defined(" + getMacroName() + ")"; - rewriter.create(loc, startMacro); +void mlir::pto::buildGlobalTensorShapeAndStride(ArrayRef shape, + ArrayRef strides, + SmallVectorImpl &shape5D, + SmallVectorImpl &stride5D) { + shape5D.assign(5, 1); + stride5D.assign(5, 1); + int rank = static_cast(shape.size()); + int shift = 5 - rank; + for (int i = 0; i < rank && i < 5; ++i) { + shape5D[shift + i] = shape[i]; + stride5D[shift + i] = strides[i]; + } + for (int i = 3; i >= 0; --i) { + if (i >= shift) + continue; + stride5D[i] = multiplyOrDynamic(shape5D[i + 1], stride5D[i + 1]); + } +} - if constexpr (std::is_same_v) { - // Vector mask is a global HW state and may be modified by previous kernels - // (or earlier sections). Reset it to a well-defined state for deterministic - // execution of VEC ops. - rewriter.create(loc, "set_mask_norm();"); - rewriter.create(loc, "set_vector_mask(-1, -1);"); - } +std::string mlir::pto::joinIntTemplateParams(ArrayRef values) { + std::string result; + for (size_t i = 0; i < values.size(); ++i) { + if (i != 0) + result += ", "; + result += std::to_string(values[i]); + } + return result; +} - if (needsNoSplitGuard) { - rewriter.create( - loc, "if (get_subblockid() == 0) {"); - } +SmallVector mlir::pto::buildRowMajorStrides(ArrayRef shape) { + SmallVector strides(shape.size(), 1); + int64_t running = 1; + for (int i = static_cast(shape.size()) - 1; i >= 0; --i) { + strides[i] = running; + running = multiplyOrDynamic(running, shape[i]); + } + return strides; +} - Block &innerBlock = op.getBody().front(); - if (!innerBlock.empty()) { - rewriter.inlineBlockBefore(&innerBlock, op.getOperation(), ValueRange{}); - } +static std::string getGlobalTensorTypeStringFromShape(Type elemTy, + ArrayRef shape, + StringRef layoutEnum) { + SmallVector strides = buildRowMajorStrides(shape); + return getGlobalTensorTypeStringFromShapeAndStrides(elemTy, shape, strides, + layoutEnum); +} - if (needsNoSplitGuard) - rewriter.create(loc, "}"); +std::string mlir::pto::getGlobalTensorTypeStringFromShapeAndStrides( + Type elemTy, ArrayRef shape, ArrayRef strides, + StringRef layoutEnum) { + SmallVector shape5D; + SmallVector stride5D; + buildGlobalTensorShapeAndStride(shape, strides, shape5D, stride5D); - std::string endMacro = "#endif // " + getMacroName() + "\n"; - rewriter.create(loc, endMacro); + std::string elemTypeStr = getElemTypeStringForGT(elemTy); + std::string shapeType = "pto::Shape<" + joinIntTemplateParams(shape5D) + ">"; + std::string strideType = + "pto::Stride<" + joinIntTemplateParams(stride5D) + ">"; + return "GlobalTensor<" + elemTypeStr + ", " + shapeType + ", " + + strideType + ", " + layoutEnum.str() + ">"; +} - rewriter.eraseOp(op); +static emitc::OpaqueType getGlobalTensorOpaqueTypeFromShape( + MLIRContext *ctx, Type elemTy, ArrayRef shape, + StringRef layoutEnum) { + return emitc::OpaqueType::get( + ctx, getGlobalTensorTypeStringFromShape(elemTy, shape, layoutEnum)); +} - return success(); +static std::string inferFallbackGlobalTensorLayout(ArrayRef shape5D, + ArrayRef stride5D, + Type elemTy) { + int elemBytes = getGlobalTensorElementBytes(elemTy); + if (elemBytes == 0) + return "pto::Layout::ND"; + if (shape5D[2] == 16 && multiplyOrDynamic(shape5D[2], shape5D[3]) * elemBytes == 512 && + stride5D[4] == 1 && stride5D[3] == shape5D[4]) { + return "pto::Layout::NZ"; } -}; -//===----------------------------------------------------------------------===// -// SCF Control-Flow Pre-Lowering -// -// EmitC translation supports `emitc.for`/`emitc.if` plus CFG-style -// `cf.br`/`cf.cond_br`. Upstream SCFToEmitC patterns only cover `scf.for` and -// `scf.if`, so we pre-lower some SCF ops into those supported forms. -//===----------------------------------------------------------------------===// + bool isRowMajor = stride5D[4] == 1; + for (int i = 3; i >= 0 && isRowMajor; --i) + isRowMajor = stride5D[i] == multiplyOrDynamic(stride5D[i + 1], shape5D[i + 1]); -namespace { + bool isColMajor = stride5D[0] == 1; + for (int i = 0; i < 4 && isColMajor; ++i) + isColMajor = stride5D[i + 1] == multiplyOrDynamic(stride5D[i], shape5D[i]); -static bool isTriviallyInlineableExecuteRegion(scf::ExecuteRegionOp op) { - Region &r = op.getRegion(); - if (!r.hasOneBlock()) - return false; - Block &b = r.front(); - return isa_and_nonnull(b.getTerminator()); + if (isColMajor) + return "pto::Layout::DN"; + return isRowMajor ? "pto::Layout::ND" : "pto::Layout::ND"; } -static bool needsWholeFunctionSCFToCF(func::FuncOp func) { - bool needs = false; - func.walk([&](Operation *op) { - if (!isa(op)) - return WalkResult::advance(); - Operation *parentOp = op->getParentOp(); - - // `scf.execute_region` can legally appear in single-block parents. Only - // require whole-function SCFToCF if we need to lower it into CFG blocks - // (multi-block region / non-trivial terminators). - if (auto exec = dyn_cast(op)) { - if (parentOp && parentOp->hasTrait() && - !isTriviallyInlineableExecuteRegion(exec)) { - needs = true; - return WalkResult::interrupt(); - } - return WalkResult::advance(); - } - - if (parentOp && parentOp->hasTrait()) { - needs = true; - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - return needs; +static std::string resolveGlobalTensorLayout(Operation *anchor, Value basePtr, + ArrayRef shape5D, + ArrayRef stride5D, + Type elemTy) { + if (auto layout = resolveLayoutForGlobalTensor(anchor, basePtr)) + return layoutToEmitCString(*layout); + return inferFallbackGlobalTensorLayout(shape5D, stride5D, elemTy); } -// scf.execute_region is semantically just an inlined region producing results -// via scf.yield. Inline it to the parent block to avoid extra lowering needs. -struct SCFExecuteRegionInline - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(scf::ExecuteRegionOp op, - PatternRewriter &rewriter) const override { - if (op.getRegion().empty()) - return rewriter.notifyMatchFailure(op, "expected non-empty region"); +struct GlobalTensorTypeNames { + std::string shapeTypeName; + std::string strideTypeName; + std::string tensorTypeName; + std::string layoutConstName; +}; - Block &innerBlock = op.getRegion().front(); - auto yield = dyn_cast(innerBlock.getTerminator()); - if (!yield) - return rewriter.notifyMatchFailure(op, "expected scf.yield terminator"); +static GlobalTensorTypeNames getGlobalTensorTypeNames(Operation *anchor) { + std::string suffix = "_" + std::to_string(reinterpret_cast(anchor)); + return { + "GTShape" + suffix, + "GTStride" + suffix, + "GT" + suffix, + "GT" + suffix + "_layout", + }; +} +Value mlir::pto::buildGlobalTensorFromMemref(ConversionPatternRewriter &rewriter, + Location loc, Value basePtr, + MemRefType mrTy, + Operation *anchor) { + auto *ctx = rewriter.getContext(); - // Move the body operations before the execute_region op. - rewriter.inlineBlockBefore(&innerBlock, op.getOperation(), ValueRange{}); + ArrayRef shape = mrTy.getShape(); + if (!hasStaticShape(mrTy)) + return Value(); - // Replace execute_region results with yielded values, then erase the yield. - rewriter.replaceOp(op, yield.getOperands()); - rewriter.eraseOp(yield); - return success(); - } -}; + SmallVector strides; + int64_t offset = 0; + if (!getStaticMemrefLayout(mrTy, strides, offset)) + return Value(); -// Lower scf.execute_region into CFG blocks with cf.br/cf.cond_br by inlining the -// region blocks into the parent region and rewriting scf.yield to branch into a -// continuation block carrying results. -// -// Note: This requires the parent region to allow multiple blocks (e.g. the -// function body CFG region). For execute_region nested in single-block regions -// (scf.for/scf.if), run SCFToCF first to eliminate the single-block constraint. -struct SCFExecuteRegionToCF : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(scf::ExecuteRegionOp op, - PatternRewriter &rewriter) const override { - if (isTriviallyInlineableExecuteRegion(op)) - return rewriter.notifyMatchFailure(op, "trivially inlineable"); - - Operation *parentOp = op->getParentOp(); - if (parentOp && parentOp->hasTrait()) { - return rewriter.notifyMatchFailure( - op, "cannot lower scf.execute_region inside a single-block parent region"); - } + Value ptr = applyStaticMemrefOffset(rewriter, loc, basePtr, offset); + GlobalTensorTypeNames names = getGlobalTensorTypeNames(anchor); + std::string elemTypeStr = getElemTypeStringForGT(mrTy.getElementType()); + SmallVector shape5D; + SmallVector stride5D; + buildGlobalTensorShapeAndStride(shape, strides, shape5D, stride5D); - if (op.getRegion().empty()) - return rewriter.notifyMatchFailure(op, "expected non-empty region"); - - Location loc = op.getLoc(); - Block *curBlock = op->getBlock(); - Region *parentRegion = curBlock->getParent(); - - // Split the parent block so we can branch to a continuation block with phi - // arguments for the execute_region results. - auto execIt = Block::iterator(op.getOperation()); - Block *continueBlock = rewriter.splitBlock(curBlock, std::next(execIt)); - - SmallVector contArgs; - contArgs.reserve(op.getNumResults()); - for (Type t : op.getResultTypes()) - contArgs.push_back(continueBlock->addArgument(t, loc)); - - for (auto it : llvm::enumerate(op.getResults())) - it.value().replaceAllUsesWith(contArgs[it.index()]); - - // Capture blocks before moving the region. - SmallVector movedBlocks; - movedBlocks.reserve(op.getRegion().getBlocks().size()); - for (Block &b : op.getRegion()) - movedBlocks.push_back(&b); - Block *entryBlock = &op.getRegion().front(); - - // Inline the execute_region blocks into the parent region right before the - // continuation block. - rewriter.inlineRegionBefore(op.getRegion(), *parentRegion, - continueBlock->getIterator()); - - // Replace all scf.yield terminators with a branch to the continuation. - for (Block *b : movedBlocks) { - auto yield = dyn_cast(b->getTerminator()); - if (!yield) - continue; - rewriter.setInsertionPoint(yield); - rewriter.create(loc, continueBlock, yield.getOperands()); - rewriter.eraseOp(yield); - } + rewriter.create( + loc, "using " + names.shapeTypeName + " = pto::Shape<" + + joinIntTemplateParams(shape5D) + ">;"); + rewriter.create( + loc, "using " + names.strideTypeName + " = pto::Stride<" + + joinIntTemplateParams(stride5D) + ">;"); - // Replace execute_region itself with a branch to the inlined entry block. - rewriter.setInsertionPoint(op); - rewriter.create(loc, entryBlock, ValueRange{}); - rewriter.eraseOp(op); - return success(); - } -}; + std::string layoutEnum = resolveGlobalTensorLayout( + anchor, basePtr, shape5D, stride5D, mrTy.getElementType()); + rewriter.create(loc, "constexpr pto::Layout " + + names.layoutConstName + " = " + + layoutEnum + ";"); -// Lower scf.index_switch into CFG blocks with cf.cond_br/cf.br so that we can -// avoid `scf.if` result materialization quirks (and avoid relying on cf.switch, -// which is not supported by EmitC C++ translation). -struct SCFIndexSwitchToCF : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + auto shapeTypeOpaque = emitc::OpaqueType::get(ctx, names.shapeTypeName); + auto strideTypeOpaque = emitc::OpaqueType::get(ctx, names.strideTypeName); + auto shapeInstOp = rewriter.create( + loc, shapeTypeOpaque, names.shapeTypeName, ArrayAttr{}, ArrayAttr{}, + ValueRange{}); + auto strideInstOp = rewriter.create( + loc, strideTypeOpaque, names.strideTypeName, ArrayAttr{}, ArrayAttr{}, + ValueRange{}); - static LogicalResult cloneYieldingBlockAndBranchTo( - PatternRewriter &rewriter, Location loc, Block &srcBlock, Block *destBlock, - Block *continueBlock) { - rewriter.setInsertionPointToEnd(destBlock); + rewriter.create( + loc, "using " + names.tensorTypeName + " = GlobalTensor<" + elemTypeStr + + ", " + names.shapeTypeName + ", " + names.strideTypeName + + ", " + names.layoutConstName + ">;"); + auto gtType = emitc::OpaqueType::get(ctx, names.tensorTypeName); - IRMapping mapping; - for (Operation &inner : srcBlock.without_terminator()) - rewriter.clone(inner, mapping); + SmallVector gtArgs; + gtArgs.push_back(ptr); + gtArgs.push_back(shapeInstOp.getResult(0)); + gtArgs.push_back(strideInstOp.getResult(0)); - auto yield = dyn_cast(srcBlock.getTerminator()); - if (!yield) - return failure(); + auto gtInst = rewriter.create( + loc, gtType, names.tensorTypeName, ArrayAttr{}, ArrayAttr{}, + ValueRange(gtArgs)); - SmallVector yieldOperands; - yieldOperands.reserve(yield.getNumOperands()); - for (Value v : yield.getOperands()) - yieldOperands.push_back(mapping.lookupOrDefault(v)); + return gtInst.getResult(0); +} - rewriter.create(loc, continueBlock, yieldOperands); - return success(); - } +static Value maybeWrapGlobalMemrefAsGlobalTensor( + ConversionPatternRewriter &rewriter, Location loc, Value loweredValue, + Type originalType, Operation *anchor) { + auto mrTy = dyn_cast(originalType); + if (!mrTy) + return loweredValue; - static Block *splitBlockForContinuation(PatternRewriter &rewriter, - scf::IndexSwitchOp op) { - auto switchIt = Block::iterator(op.getOperation()); - return rewriter.splitBlock(op->getBlock(), std::next(switchIt)); + bool isGlobal = true; + if (auto asAttr = + dyn_cast_or_null(mrTy.getMemorySpace())) { + auto as = asAttr.getAddressSpace(); + isGlobal = (as == pto::AddressSpace::GM || as == pto::AddressSpace::Zero); } + if (!isGlobal) + return loweredValue; - static void addContinuationArguments(PatternRewriter &rewriter, - scf::IndexSwitchOp op, Location loc, - Block *continueBlock) { - SmallVector contArgs; - contArgs.reserve(op.getNumResults()); - for (Type type : op.getResultTypes()) - contArgs.push_back(continueBlock->addArgument(type, loc)); - for (auto result : llvm::enumerate(op.getResults())) - result.value().replaceAllUsesWith(contArgs[result.index()]); - } + if (Value gt = + buildGlobalTensorFromMemref(rewriter, loc, loweredValue, mrTy, anchor)) + return gt; + return loweredValue; +} - static void createIndexSwitchBlocks(PatternRewriter &rewriter, - Region *parentRegion, - Region::iterator insertPt, - unsigned numCases, - SmallVectorImpl &checkBlocks, - Block *&defaultBlock, - SmallVectorImpl &caseBlocks) { - checkBlocks.reserve(numCases); - caseBlocks.reserve(numCases); - for (unsigned i = 0; i < numCases; ++i) - checkBlocks.push_back(rewriter.createBlock(parentRegion, insertPt)); - defaultBlock = rewriter.createBlock(parentRegion, insertPt); - for (unsigned i = 0; i < numCases; ++i) - caseBlocks.push_back(rewriter.createBlock(parentRegion, insertPt)); - } +Value mlir::pto::castToGMBytePointer(ConversionPatternRewriter &rewriter, + Location loc, Value value) { + auto *ctx = rewriter.getContext(); + auto targetTy = + emitc::PointerType::get(emitc::OpaqueType::get(ctx, "__gm__ uint8_t")); + if (value.getType() == targetTy) + return value; - static void populateIndexSwitchCheckBlocks( - PatternRewriter &rewriter, Location loc, Value selector, - ArrayRef cases, ArrayRef checkBlocks, - ArrayRef caseBlocks, Block *defaultBlock) { - for (unsigned i = 0; i < checkBlocks.size(); ++i) { - rewriter.setInsertionPointToEnd(checkBlocks[i]); - Value caseVal = rewriter.create(loc, cases[i]); - Value cond = rewriter.create( - loc, arith::CmpIPredicate::eq, selector, caseVal); - Block *falseDest = - (i + 1 < checkBlocks.size()) ? checkBlocks[i + 1] : defaultBlock; - rewriter.create(loc, cond, caseBlocks[i], ValueRange{}, - falseDest, ValueRange{}); - } + auto castTyAttr = + rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "__gm__ uint8_t*")}); + if (isSetFFTsPointerLikeType(value.getType())) { + return rewriter + .create(loc, targetTy, "reinterpret_cast", + ArrayAttr{}, castTyAttr, + ValueRange{value}) + .getResult(0); } + return rewriter.create(loc, targetTy, value).getResult(); +} - LogicalResult matchAndRewrite(scf::IndexSwitchOp op, - PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - Operation *parentOp = op->getParentOp(); - if (parentOp && parentOp->hasTrait()) { - return rewriter.notifyMatchFailure( - op, "cannot lower scf.index_switch inside a single-block parent region"); - } +Value mlir::pto::materializeTensorViewDataPointer( + ConversionPatternRewriter &rewriter, Location loc, Value value, + Type sourceType) { + auto tvTy = dyn_cast(sourceType); + if (!tvTy) + return value; - Block *curBlock = op->getBlock(); - Region *parentRegion = curBlock->getParent(); - Block *continueBlock = splitBlockForContinuation(rewriter, op); - addContinuationArguments(rewriter, op, loc, continueBlock); - - unsigned numCases = op.getCases().size(); - auto insertPt = continueBlock->getIterator(); - - SmallVector checkBlocks; - SmallVector caseBlocks; - Block *defaultBlock = nullptr; - createIndexSwitchBlocks(rewriter, parentRegion, insertPt, numCases, - checkBlocks, defaultBlock, caseBlocks); - - Value selector = op.getArg(); - auto cases = op.getCases(); - populateIndexSwitchCheckBlocks(rewriter, loc, selector, cases, checkBlocks, - caseBlocks, defaultBlock); - - // Fill case blocks and default block with cloned bodies + branch to cont. - for (unsigned i = 0; i < numCases; ++i) { - if (failed(cloneYieldingBlockAndBranchTo( - rewriter, loc, op.getCaseBlock(i), caseBlocks[i], continueBlock))) - return rewriter.notifyMatchFailure(op, "expected scf.yield terminator"); - } - if (failed(cloneYieldingBlockAndBranchTo(rewriter, loc, op.getDefaultBlock(), - defaultBlock, continueBlock))) - return rewriter.notifyMatchFailure(op, "expected scf.yield terminator"); - - // Replace the original switch op with a branch into the check chain. - Block *entryDest = numCases ? checkBlocks[0] : defaultBlock; - rewriter.setInsertionPointAfter(op); - rewriter.create(loc, entryDest, ValueRange{}); - rewriter.eraseOp(op); - return success(); - } -}; + auto *ctx = rewriter.getContext(); + std::string elemTypeStr = getElemTypeStringForGT(tvTy.getElementType()); + auto ptrTy = emitc::PointerType::get( + emitc::OpaqueType::get(ctx, "__gm__ " + elemTypeStr)); + return rewriter + .create(loc, ptrTy, "PTOAS__GLOBAL_TENSOR_DATA", + ArrayAttr{}, ArrayAttr{}, ValueRange{value}) + .getResult(0); +} -// Lower scf.while into CFG blocks with cf.br/cf.cond_br. -// -// Note: This requires the parent region to allow multiple blocks. In -// particular, scf.if/scf.for regions are single-block and cannot contain this -// lowering. -struct SCFWhileToCF : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - static LogicalResult validateWhileResultUses(scf::WhileOp op) { - Block *parentBlock = op->getBlock(); - for (Value result : op.getResults()) { - for (OpOperand &use : result.getUses()) { - if (use.getOwner()->getBlock() != parentBlock) - return failure(); - } - } - return success(); +static std::string tileBufBLayoutToken(pto::TileBufConfigAttr configAttr) { + std::string blTok = "BLayout::RowMajor"; + if (auto blAttr = dyn_cast(configAttr.getBLayout())) { + if (static_cast(blAttr.getValue()) == 1) + blTok = "BLayout::ColMajor"; } + return blTok; +} - static Block *splitAfterWhileBlock(PatternRewriter &rewriter, - scf::WhileOp op) { - auto whileIt = Block::iterator(op.getOperation()); - return rewriter.splitBlock(op->getBlock(), std::next(whileIt)); +static std::string tileBufSLayoutToken(pto::TileBufConfigAttr configAttr) { + std::string slTok = "SLayout::NoneBox"; + if (auto slAttr = dyn_cast(configAttr.getSLayout())) { + int32_t slVal = static_cast(slAttr.getValue()); + slTok = (slVal == 1) ? "SLayout::RowMajor" + : (slVal == 2) ? "SLayout::ColMajor" + : "SLayout::NoneBox"; } + return slTok; +} - static void addWhileExitArguments(PatternRewriter &rewriter, scf::WhileOp op, - Location loc, Block *afterWhileBlock) { - SmallVector exitArgs; - exitArgs.reserve(op.getNumResults()); - for (Type type : op.getResultTypes()) - exitArgs.push_back(afterWhileBlock->addArgument(type, loc)); - for (auto result : llvm::enumerate(op.getResults())) - result.value().replaceAllUsesWith(exitArgs[result.index()]); +static std::string tileBufPadToken(pto::TileBufConfigAttr configAttr) { + std::string padTok = "PadValue::Null"; + if (auto padAttr = dyn_cast(configAttr.getPad())) { + switch (static_cast(padAttr.getValue())) { + case 1: + padTok = "PadValue::Zero"; + break; + case 2: + padTok = "PadValue::Max"; + break; + case 3: + padTok = "PadValue::Min"; + break; + default: + padTok = "PadValue::Null"; + break; + } } + return padTok; +} - static Block *createWhileHeaderBlock(PatternRewriter &rewriter, - scf::WhileOp op, Location loc, - Block *afterWhileBlock) { - SmallVector headerArgTypes; - for (Value init : op.getInits()) - headerArgTypes.push_back(init.getType()); - SmallVector headerArgLocs(headerArgTypes.size(), loc); - return rewriter.createBlock(afterWhileBlock->getParent(), - afterWhileBlock->getIterator(), headerArgTypes, - headerArgLocs); - } +pto::BLayout mlir::pto::getTileBufBLayoutValue(pto::TileBufConfigAttr configAttr) { + if (auto blAttr = dyn_cast(configAttr.getBLayout())) + return blAttr.getValue(); + return pto::BLayout::RowMajor; +} - static Block *createWhileBodyBlock(PatternRewriter &rewriter, scf::WhileOp op, - Location loc, Block *afterWhileBlock) { - Block &afterRegionBlock = op.getAfter().front(); - SmallVector bodyArgTypes(afterRegionBlock.getArgumentTypes().begin(), - afterRegionBlock.getArgumentTypes().end()); - SmallVector bodyArgLocs(bodyArgTypes.size(), loc); - return rewriter.createBlock(afterWhileBlock->getParent(), - afterWhileBlock->getIterator(), bodyArgTypes, - bodyArgLocs); - } +int64_t mlir::pto::renderTileTemplateDim(int64_t rawDim, Type elemTy, + pto::BLayout blayout, int dimIdx) { + assert(dimIdx >= 0 && dimIdx < 2 && + "renderTileTemplateDim expects a rank-2 rows/cols dimension index"); + if (rawDim == ShapedType::kDynamic) + return rawDim; + if (!pto::isPTOFloat4PackedType(elemTy)) + return rawDim; + int packedDim = blayout == pto::BLayout::ColMajor ? 0 : 1; + return dimIdx == packedDim ? rawDim * 2 : rawDim; +} - static void rewriteWhileTerminators(PatternRewriter &rewriter, Location loc, - Block *headerBlock, Block *bodyBlock, - Block *afterWhileBlock) { - auto condOp = cast(headerBlock->getTerminator()); - rewriter.setInsertionPoint(condOp); - rewriter.create(loc, condOp.getCondition(), - /*trueDest=*/bodyBlock, - /*trueOperands=*/condOp.getArgs(), - /*falseDest=*/afterWhileBlock, - /*falseOperands=*/condOp.getArgs()); - rewriter.eraseOp(condOp); - - auto yieldOp = cast(bodyBlock->getTerminator()); - rewriter.setInsertionPoint(yieldOp); - rewriter.create(loc, headerBlock, yieldOp.getOperands()); - rewriter.eraseOp(yieldOp); +FailureOr mlir::pto::buildAsyncScratchTileValue( + ConversionPatternRewriter &rewriter, Location loc, Value originalScratch, + Value emittedScratch) { + Value scratch = peelUnrealized(emittedScratch); + if (auto opaqueTy = dyn_cast(scratch.getType())) { + StringRef typeStr = opaqueTy.getValue(); + if (typeStr.contains("Tile<") || typeStr.contains("ConvTile<")) + return scratch; } - LogicalResult matchAndRewrite(scf::WhileOp op, - PatternRewriter &rewriter) const override { - Operation *parentOp = op->getParentOp(); - if (parentOp && parentOp->hasTrait()) { - return rewriter.notifyMatchFailure( - op, "cannot lower scf.while inside a single-block parent region"); - } - - if (failed(validateWhileResultUses(op))) - return rewriter.notifyMatchFailure( - op, "unsupported: while results used outside the parent block"); - - auto loc = op.getLoc(); - Block *afterWhileBlock = splitAfterWhileBlock(rewriter, op); - addWhileExitArguments(rewriter, op, loc, afterWhileBlock); - Block *headerBlock = createWhileHeaderBlock(rewriter, op, loc, - afterWhileBlock); - Block *bodyBlock = createWhileBodyBlock(rewriter, op, loc, afterWhileBlock); - - // Move the before/after region bodies into the new CFG blocks. - Block &afterRegionBlock = op.getAfter().front(); - rewriter.mergeBlocks(&op.getBefore().front(), headerBlock, - headerBlock->getArguments()); - rewriter.mergeBlocks(&afterRegionBlock, bodyBlock, bodyBlock->getArguments()); - rewriteWhileTerminators(rewriter, loc, headerBlock, bodyBlock, - afterWhileBlock); - - // Replace scf.while itself with a branch to the header. - rewriter.setInsertionPoint(op); - rewriter.create(loc, headerBlock, op.getInits()); - rewriter.eraseOp(op); - return success(); - } -}; + auto memTy = dyn_cast(originalScratch.getType()); + if (!memTy) + return failure(); -// Lower cf.switch into chained comparisons and cf.cond_br/cf.br. -// -// EmitC C++ translation currently supports cf.br/cf.cond_br, but not cf.switch. -struct CFSwitchToCondBr : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - static SmallVector> - collectSwitchCaseOperands(cf::SwitchOp op) { - SmallVector> caseOperands; - caseOperands.reserve(op.getCaseDestinations().size()); - for (auto range : op.getCaseOperands()) - caseOperands.emplace_back(range.begin(), range.end()); - return caseOperands; - } + ArrayRef shape = memTy.getShape(); + if (!memTy.hasStaticShape() || shape.empty() || shape.size() > 2) + return failure(); - static SmallVector getSwitchCaseValues(cf::SwitchOp op) { - SmallVector caseValues; - if (auto caseValuesAttr = op.getCaseValues()) { - for (APInt value : caseValuesAttr->getValues()) - caseValues.push_back(value); - } - return caseValues; - } + int64_t rows = shape.size() == 1 ? 1 : shape[0]; + int64_t cols = shape.size() == 1 ? shape[0] : shape[1]; - static SmallVector createSwitchCheckBlocks(PatternRewriter &rewriter, - Region *parentRegion, - Block *curBlock, - size_t numCases) { - auto insertPt = std::next(curBlock->getIterator()); - SmallVector checkBlocks; - checkBlocks.reserve(numCases); - for (size_t i = 0; i < numCases; ++i) - checkBlocks.push_back(rewriter.createBlock(parentRegion, insertPt)); - return checkBlocks; + auto *ctx = rewriter.getContext(); + pto::TileBufConfigAttr configAttr = pto::TileBufConfigAttr::getDefault(ctx); + if (auto bind = originalScratch.getDefiningOp()) { + configAttr = bind.getConfig(); + } else if (auto cast = originalScratch.getDefiningOp()) { + if (auto config = cast.getConfig()) + configAttr = *config; } - static LogicalResult populateSwitchCheckBlocks( - PatternRewriter &rewriter, Location loc, Value flag, IntegerType flagTy, - ArrayRef caseValues, ArrayRef caseDests, - ArrayRef> caseOperands, Block *defaultDest, - ValueRange defaultOperands, ArrayRef checkBlocks, - cf::SwitchOp op) { - for (size_t i = 0; i < caseDests.size(); ++i) { - rewriter.setInsertionPointToEnd(checkBlocks[i]); - APInt caseVal = caseValues[i]; - if (caseVal.getBitWidth() != flagTy.getWidth()) { - return rewriter.notifyMatchFailure( - op, "case value bitwidth doesn't match flag type"); - } + int32_t fractal = 512; + if (auto frAttr = dyn_cast(configAttr.getSFractalSize())) + fractal = frAttr.getInt(); - Value caseConst = rewriter.create( - loc, flagTy, rewriter.getIntegerAttr(flagTy, caseVal)); - Value cond = rewriter.create( - loc, arith::CmpIPredicate::eq, flag, caseConst); - Block *falseDest = - (i + 1 < checkBlocks.size()) ? checkBlocks[i + 1] : defaultDest; - ValueRange falseOperands = - (i + 1 < checkBlocks.size()) ? ValueRange{} : defaultOperands; - rewriter.create(loc, cond, caseDests[i], - caseOperands[i], falseDest, - falseOperands); - } - return success(); - } + Type elemTy = memTy.getElementType(); + pto::BLayout blayout = getTileBufBLayoutValue(configAttr); + int64_t templateRows = renderTileTemplateDim(rows, elemTy, blayout, 0); + int64_t templateCols = renderTileTemplateDim(cols, elemTy, blayout, 1); + std::string elemTypeStr = getEmitCScalarTypeToken(elemTy); + std::string tileTypeStr = + "Tile"; - LogicalResult matchAndRewrite(cf::SwitchOp op, - PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - Operation *parentOp = op->getParentOp(); - if (parentOp && parentOp->hasTrait()) { - return rewriter.notifyMatchFailure( - op, "cannot lower cf.switch inside a single-block parent region"); - } + Value tile = rewriter + .create( + loc, emitc::OpaqueType::get(ctx, tileTypeStr), + emitc::OpaqueAttr::get(ctx, "")) + .getResult(); + auto addr = rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); + Value scratchAddr = + rewriter + .create(loc, emitc::OpaqueType::get(ctx, "uint64_t"), + "reinterpret_cast", ArrayAttr{}, addr, + ValueRange{scratch}) + .getResult(0); + rewriter.create(loc, TypeRange{}, "TASSIGN", + ArrayAttr{}, ArrayAttr{}, + ValueRange{tile, scratchAddr}); + return tile; +} - Block *curBlock = op->getBlock(); - Region *parentRegion = curBlock->getParent(); +//===----------------------------------------------------------------------===// +// pto.pointer_cast lowering +//===----------------------------------------------------------------------=== +struct PTOMScatterToMSCATTER : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; - Value flag = op.getFlag(); - auto flagTy = dyn_cast(flag.getType()); - if (!flagTy) - return rewriter.notifyMatchFailure(op, "expected integer switch flag"); + LogicalResult matchAndRewrite(pto::MScatterOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto *ctx = rewriter.getContext(); + Value src = peelUnrealized(adaptor.getSrc()); + Value idx = peelUnrealized(adaptor.getIdx()); + Value mem = peelUnrealized(adaptor.getMem()); - SmallVector defaultOperands(op.getDefaultOperands().begin(), - op.getDefaultOperands().end()); - Block *defaultDest = op.getDefaultDestination(); + Value memArg = maybeWrapGlobalMemrefAsGlobalTensor( + rewriter, op.getLoc(), mem, op.getMem().getType(), op.getOperation()); - SmallVector caseDests(op.getCaseDestinations().begin(), - op.getCaseDestinations().end()); - SmallVector> caseOperands = collectSwitchCaseOperands(op); + auto scatterAtomicTok = [&](pto::ScatterAtomicOp atomic) -> StringRef { + switch (atomic) { + case pto::ScatterAtomicOp::None: + return "pto::ScatterAtomicOp::None"; + case pto::ScatterAtomicOp::Add: + return "pto::ScatterAtomicOp::Add"; + case pto::ScatterAtomicOp::Max: + return "pto::ScatterAtomicOp::Max"; + case pto::ScatterAtomicOp::Min: + return "pto::ScatterAtomicOp::Min"; + } + llvm_unreachable("unknown ScatterAtomicOp"); + }; + auto scatterOobTok = [&](pto::ScatterOOB mode) -> StringRef { + switch (mode) { + case pto::ScatterOOB::Undefined: + return "pto::ScatterOOB::Undefined"; + case pto::ScatterOOB::Skip: + return "pto::ScatterOOB::Skip"; + case pto::ScatterOOB::Clamp: + return "pto::ScatterOOB::Clamp"; + case pto::ScatterOOB::Wrap: + return "pto::ScatterOOB::Wrap"; + } + llvm_unreachable("unknown ScatterOOB"); + }; - if (caseDests.empty()) { - rewriter.replaceOpWithNewOp(op, defaultDest, defaultOperands); - return success(); + SmallVector templateArgVec; + const bool rowCoalesce = + isRowCoalescedMGatherIndexType(op.getSrc().getType(), op.getIdx().getType()); + templateArgVec.push_back(emitc::OpaqueAttr::get( + ctx, rowCoalesce ? "pto::Coalesce::Row" : "pto::Coalesce::Elem")); + if (op.getScatterAtomicOp() != pto::ScatterAtomicOp::None || + op.getScatterOob() != pto::ScatterOOB::Undefined) { + templateArgVec.push_back(emitc::OpaqueAttr::get( + ctx, scatterAtomicTok(op.getScatterAtomicOp()))); + if (op.getScatterOob() != pto::ScatterOOB::Undefined) + templateArgVec.push_back( + emitc::OpaqueAttr::get(ctx, scatterOobTok(op.getScatterOob()))); } + ArrayAttr templateArgs = rewriter.getArrayAttr(templateArgVec); - if (!op.getCaseValues()) - return rewriter.notifyMatchFailure(op, "missing case_values"); - SmallVector caseValues = getSwitchCaseValues(op); - - if (caseValues.size() != caseDests.size()) - return rewriter.notifyMatchFailure(op, "case_values/destinations mismatch"); - if (caseOperands.size() != caseDests.size()) - return rewriter.notifyMatchFailure(op, "case_operands/destinations mismatch"); - - SmallVector checkBlocks = - createSwitchCheckBlocks(rewriter, parentRegion, curBlock, - caseDests.size()); - if (failed(populateSwitchCheckBlocks(rewriter, loc, flag, flagTy, - caseValues, caseDests, caseOperands, - defaultDest, defaultOperands, - checkBlocks, op))) { - return failure(); - } + rewriter.create( + op.getLoc(), TypeRange{}, "MSCATTER", + ArrayAttr{}, templateArgs, + ValueRange{memArg, src, idx}); - // Replace the switch terminator with a branch into the first check block. - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp(op, checkBlocks.front(), - ValueRange{}); + rewriter.eraseOp(op); return success(); } }; - -} // namespace - static void populatePTOToEmitCPatterns(RewritePatternSet &patterns, TypeConverter &typeConverter, MLIRContext *ctx, DataFlowSolver &solver, PTOArch targetArch) { (void)solver; - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx, "pto.set_flag_dyn", - "set_flag"); - patterns.add(typeConverter, ctx, "pto.wait_flag_dyn", - "wait_flag"); - // Backward-compatible aliases used in some downstream branches. - patterns.add(typeConverter, ctx, "pto.set_flag_d", - "set_flag"); - patterns.add(typeConverter, ctx, "pto.wait_flag_d", - "wait_flag"); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add( - typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); + populatePTOToEmitCArithPatterns(patterns, typeConverter, ctx); + populatePTOToEmitCRuntimeOpPatterns(patterns, typeConverter, ctx, targetArch); + populatePTOToEmitCMemoryOpPatterns(patterns, typeConverter, ctx); + populatePTOToEmitCTilePatterns(patterns, typeConverter, ctx); + populatePTOToEmitCSimpleOpPatterns(patterns, typeConverter, ctx); + populatePTOToEmitCTileMaterializationPatterns(patterns, typeConverter, ctx); + populatePTOToEmitCSyncPatterns(patterns, typeConverter, ctx, targetArch); patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add( - typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add>( - typeConverter, ctx); - patterns.add>( - typeConverter, ctx); - patterns.add>( - typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add>(typeConverter, - ctx); - patterns.add>(typeConverter, - ctx); - patterns.add>(typeConverter, - ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add>(typeConverter, ctx); - patterns.add>(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add>(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add>(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); + populatePTOToEmitCKernelOpPatterns(patterns, typeConverter, ctx); patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add>( - typeConverter, ctx, - "pto::comm::TPUT_ASYNC"); - patterns.add>( - typeConverter, ctx, - "pto::comm::TGET_ASYNC"); - patterns.add>(typeConverter, ctx, - "pto::comm::TPUT"); - patterns.add>(typeConverter, ctx, - "pto::comm::TGET"); - patterns.add>(typeConverter, ctx, - "pto::comm::TNOTIFY"); - patterns.add>(typeConverter, ctx, - "pto::comm::TWAIT"); - patterns.add>(typeConverter, ctx, - "pto::comm::TTEST"); - patterns.add>(typeConverter, ctx, - "TBROADCAST"); - patterns.add>(typeConverter, ctx, - "TGATHER"); - patterns.add>(typeConverter, ctx, - "TSCATTER"); - patterns.add>(typeConverter, ctx, - "TREDUCE"); - patterns.add>( - typeConverter, ctx, "PTOAS__ASYNC_EVENT_WAIT"); - patterns.add>( - typeConverter, ctx, "PTOAS__ASYNC_EVENT_TEST"); - patterns.add(typeConverter, ctx, targetArch); - patterns.add(typeConverter, ctx, targetArch); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx, targetArch); - patterns.add(typeConverter, ctx, targetArch); - patterns.add(typeConverter, ctx, targetArch); - patterns.add(typeConverter, ctx, targetArch); - patterns.add(typeConverter, ctx, targetArch); - patterns.add(typeConverter, ctx, targetArch); - patterns.add>(typeConverter, ctx); - patterns.add>(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add(typeConverter, ctx); - patterns.add< - PTOTMatmulBiasToTMATMUL_BIAS, - PTOTMatmulMXToTMATMUL_MX, - PTOTMatmulMXAccToTMATMUL_MX_ACC, - PTOTMatmulMXBiasToTMATMUL_MX_BIAS, - PTOTMatmulBiasToTMATMUL_BIAS, - PTOTMatmulMXToTMATMUL_MX, - PTOTMatmulMXAccToTMATMUL_MX_ACC, - PTOTMatmulMXBiasToTMATMUL_MX_BIAS, - PTOTGemvBiasToTGEMV_BIAS, - PTOTGemvMXToTGEMV_MX, - PTOTGemvMXAccToTGEMV_MX, - PTOTGemvMXBiasToTGEMV_MX, - PTOBarrierToEmitC - >(typeConverter, ctx); - - patterns.add(typeConverter, ctx); - - populateSCFToEmitCConversionPatterns(patterns); - // Keep CFG-style branches type-consistent when block argument types are - // converted (e.g. after lowering scf.while to cf.br/cf.cond_br). - populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter); + populatePTOToEmitCCommPatterns(patterns, typeConverter, ctx, targetArch); + populatePTOToEmitCControlFlowPatterns(patterns, typeConverter, ctx); } //===----------------------------------------------------------------------===// @@ -12617,67 +2210,8 @@ static AICORE inline void ptoas_auto_sync_tail( } // 1.5 Pre-lower SCF constructs not handled by SCFToEmitC. - { - // scf.while / scf.index_switch are lowered via CFG blocks. This is not - // possible inside ops that require single-block regions (e.g. scf.for / - // scf.if). If we see such nesting, lower the entire function to the - // ControlFlow dialect first. - bool needsAnySCFToCF = false; - for (auto func : mop.getOps()) { - if (needsWholeFunctionSCFToCF(func)) { - needsAnySCFToCF = true; - break; - } - } - if (needsAnySCFToCF) { - RewritePatternSet scfToCfPatterns(ctx); - populateSCFToControlFlowConversionPatterns(scfToCfPatterns); - FrozenRewritePatternSet frozenSCFToCF(std::move(scfToCfPatterns)); - - ConversionTarget scfToCfTarget(*ctx); - // Only eliminate the single-block SCF constructs; we'll pre-lower - // scf.while/index_switch/execute_region ourselves afterwards. - scfToCfTarget.addIllegalOp(); - scfToCfTarget.markUnknownOpDynamicallyLegal( - [](Operation *) { return true; }); - - for (auto func : mop.getOps()) { - if (!needsWholeFunctionSCFToCF(func)) - continue; - if (failed(applyPartialConversion(func, scfToCfTarget, - frozenSCFToCF))) { - func.emitError() - << "failed to lower nested SCF to ControlFlow (SCFToCF)"; - return signalPassFailure(); - } - } - } - - RewritePatternSet scfLoweringPatterns(ctx); - scfLoweringPatterns.add(ctx); - (void)applyPatternsAndFoldGreedily(mop, std::move(scfLoweringPatterns)); - - bool hasUnsupportedSCF = false; - mop.walk([&](Operation *op) { - if (isa(op)) { - hasUnsupportedSCF = true; - op->emitError() << "Unsupported SCF op remained after pre-lowering"; - return WalkResult::interrupt(); - } - if (isa(op)) { - hasUnsupportedSCF = true; - op->emitError() - << "Unsupported CF op remained after pre-lowering: cf.switch"; - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - if (hasUnsupportedSCF) - return signalPassFailure(); - } + if (failed(runPTOToEmitCSCFPreLowering(mop, ctx))) + return signalPassFailure(); PTOToEmitCTypeConverter typeConverter(ctx, targetArch); diff --git a/lib/PTO/Transforms/PTOToEmitCArith.cpp b/lib/PTO/Transforms/PTOToEmitCArith.cpp new file mode 100644 index 000000000..19a0717a7 --- /dev/null +++ b/lib/PTO/Transforms/PTOToEmitCArith.cpp @@ -0,0 +1,1782 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- PTOToEmitCArith.cpp ------------------------------------------------===// +//===----------------------------------------------------------------------===// + +#include "PTOToEmitCInternal.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/APSInt.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include + +using namespace mlir; + +namespace mlir::pto { +namespace { + +static constexpr unsigned kPTOIndexBitWidth = 32; + +static emitc::OpaqueType getSignedIntOpaqueType(MLIRContext *ctx, unsigned bitWidth); +static emitc::OpaqueType getUnsignedIntOpaqueType(MLIRContext *ctx, unsigned bitWidth); +static emitc::OpaqueType getWiderSignedIntOpaqueType(MLIRContext *ctx, unsigned bitWidth); +static emitc::OpaqueType getWiderUnsignedIntOpaqueType(MLIRContext *ctx, unsigned bitWidth); +static Value makeEmitCOpaqueConstant(ConversionPatternRewriter &rewriter, + Location loc, Type type, + llvm::StringRef literal); +static Value makeEmitCIntConstant(ConversionPatternRewriter &rewriter, + Location loc, Type type, int64_t value); +static FailureOr buildEmitCOpaqueConstantLiteral(Type targetType, + Attribute valueAttr); +static Value emitCCast(ConversionPatternRewriter &rewriter, Location loc, + Type dstType, Value src); +static Value castSignlessIntToUnsignedSameWidth(ConversionPatternRewriter &rewriter, + Location loc, Value v, + unsigned bitWidth); + +//===----------------------------------------------------------------------===// +// Arith -> EmitC (full dialect coverage for scalar ops) +//===----------------------------------------------------------------------===// + +template +struct ArithSimpleBinaryToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type dstTy = this->getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getOperands()); + return success(); + } +}; + +// Integer bitwise ops (andi/ori/xori) on signless integers: perform in unsigned +// to avoid signedness pitfalls, then cast back. +template +struct ArithUnsignedBitwiseBinaryToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const unsigned bitWidth = + intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = this->getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + if (bitWidth == 1) { + rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getLhs(), + adaptor.getRhs()); + return success(); + } + + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value resU = rewriter.create(loc, uTy, lhsU, rhsU); + Value result = emitCCast(rewriter, loc, dstTy, resU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithDivUIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::DivUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const unsigned bitWidth = + intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value divU = rewriter.create(loc, uTy, lhsU, rhsU); + Value result = emitCCast(rewriter, loc, dstTy, divU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithRemUIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::RemUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const unsigned bitWidth = + intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value remU = rewriter.create(loc, uTy, lhsU, rhsU); + Value result = emitCCast(rewriter, loc, dstTy, remU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithCeilDivUIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::CeilDivUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const unsigned bitWidth = + intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value one = makeEmitCIntConstant(rewriter, loc, uTy, 1); + Value rhsMinusOne = rewriter.create(loc, uTy, rhsU, one); + Value num = rewriter.create(loc, uTy, lhsU, rhsMinusOne); + Value divU = rewriter.create(loc, uTy, num, rhsU); + Value result = emitCCast(rewriter, loc, dstTy, divU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithCeilDivSIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::CeilDivSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + Value zero = makeEmitCIntConstant(rewriter, loc, dstTy, 0); + Value one = makeEmitCIntConstant(rewriter, loc, dstTy, 1); + + Value q0 = rewriter.create(loc, dstTy, adaptor.getLhs(), + adaptor.getRhs()); + Value r = rewriter.create(loc, dstTy, adaptor.getLhs(), + adaptor.getRhs()); + + Value rNeZero = rewriter.create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::ne, r, + zero); + Value lhsLt0 = + rewriter.create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, adaptor.getLhs(), + zero); + Value rhsLt0 = + rewriter.create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, adaptor.getRhs(), + zero); + Value signsSame = + rewriter.create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::eq, lhsLt0, rhsLt0); + Value adjust = + rewriter.create(loc, rewriter.getI1Type(), + rNeZero, signsSame); + + Value qPlusOne = rewriter.create(loc, dstTy, q0, one); + Value result = rewriter.create(loc, dstTy, adjust, + qPlusOne, q0); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithFloorDivSIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::FloorDivSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + Value zero = makeEmitCIntConstant(rewriter, loc, dstTy, 0); + Value one = makeEmitCIntConstant(rewriter, loc, dstTy, 1); + + Value q0 = rewriter.create(loc, dstTy, adaptor.getLhs(), + adaptor.getRhs()); + Value r = rewriter.create(loc, dstTy, adaptor.getLhs(), + adaptor.getRhs()); + + Value rNeZero = rewriter.create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::ne, r, + zero); + Value lhsLt0 = + rewriter.create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, adaptor.getLhs(), + zero); + Value rhsLt0 = + rewriter.create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, adaptor.getRhs(), + zero); + Value signsDifferent = + rewriter.create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::ne, lhsLt0, rhsLt0); + Value adjust = + rewriter.create(loc, rewriter.getI1Type(), + rNeZero, signsDifferent); + + Value qMinusOne = rewriter.create(loc, dstTy, q0, one); + Value result = rewriter.create(loc, dstTy, adjust, + qMinusOne, q0); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithShiftLeftToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::ShLIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const unsigned bitWidth = + intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + if (bitWidth == 1) { + // Compute on u8 and truncate to i1. + auto u8Ty = getUnsignedIntOpaqueType(rewriter.getContext(), 8); + Value lhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getLhs()); + Value rhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getRhs()); + Value sh = rewriter.create(loc, u8Ty, lhsU8, + rhsU8); + Value masked = + rewriter.create(loc, u8Ty, sh, + makeEmitCIntConstant(rewriter, loc, + u8Ty, 1)); + rewriter.replaceOp(op, emitCCast(rewriter, loc, dstTy, masked)); + return success(); + } + + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value shU = + rewriter.create(loc, uTy, lhsU, rhsU); + Value result = emitCCast(rewriter, loc, dstTy, shU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithShiftRightUIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::ShRUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const unsigned bitWidth = + intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + if (bitWidth == 1) { + // (x >> y) on i1 is either x (y==0) or 0 (y!=0); approximate in u8. + auto u8Ty = getUnsignedIntOpaqueType(rewriter.getContext(), 8); + Value lhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getLhs()); + Value rhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getRhs()); + Value sh = rewriter.create(loc, u8Ty, lhsU8, + rhsU8); + Value masked = + rewriter.create(loc, u8Ty, sh, + makeEmitCIntConstant(rewriter, loc, + u8Ty, 1)); + rewriter.replaceOp(op, emitCCast(rewriter, loc, dstTy, masked)); + return success(); + } + + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value shU = + rewriter.create(loc, uTy, lhsU, rhsU); + Value result = emitCCast(rewriter, loc, dstTy, shU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithShiftRightSIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::ShRSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const unsigned bitWidth = + intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + if (bitWidth == 1) { + // (x >> y) on i1 is either x (y==0) or 0 (y!=0); approximate in u8. + auto u8Ty = getUnsignedIntOpaqueType(rewriter.getContext(), 8); + Value lhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getLhs()); + Value rhsU8 = emitCCast(rewriter, loc, u8Ty, adaptor.getRhs()); + Value sh = rewriter.create(loc, u8Ty, lhsU8, + rhsU8); + Value masked = + rewriter.create(loc, u8Ty, sh, + makeEmitCIntConstant(rewriter, loc, + u8Ty, 1)); + rewriter.replaceOp(op, emitCCast(rewriter, loc, dstTy, masked)); + return success(); + } + + // Signed arithmetic shift; cast RHS to unsigned to interpret shift amount. + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value sh = + rewriter.create(loc, dstTy, adaptor.getLhs(), + rhsU); + rewriter.replaceOp(op, sh); + return success(); + } +}; + +struct ArithNegFToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::NegFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type dstTy = getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getOperand()); + return success(); + } +}; + +struct ArithRemFToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::RemFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type dstTy = getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + + // Use builtin `fmod` when possible. For f16, compute in float and cast back. + Type callTy = dstTy; + Value lhs = adaptor.getLhs(); + Value rhs = adaptor.getRhs(); + + if (auto opFloatTy = dyn_cast(op.getType())) { + if (opFloatTy.isF16()) { + auto f32Ty = emitc::OpaqueType::get(rewriter.getContext(), "float"); + lhs = emitCCast(rewriter, loc, f32Ty, lhs); + rhs = emitCCast(rewriter, loc, f32Ty, rhs); + callTy = f32Ty; + } + } + + // Prefer `__builtin_fmod*` to avoid relying on extra headers. + llvm::StringRef callee = "__builtin_fmod"; + if (auto opFloatTy = dyn_cast(op.getType())) { + if (opFloatTy.isF32() || opFloatTy.isF16()) + callee = "__builtin_fmodf"; + else if (opFloatTy.isF64()) + callee = "__builtin_fmod"; + } + + auto call = rewriter.create( + loc, TypeRange{callTy}, callee, ValueRange{lhs, rhs}, + /*args=*/ArrayAttr{}, /*template_args=*/ArrayAttr{}); + Value result = call.getResult(0); + if (callTy != dstTy) + result = emitCCast(rewriter, loc, dstTy, result); + + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithSelectToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!op.getCondition().getType().isInteger(1)) + return rewriter.notifyMatchFailure( + op, "only scalar i1 conditions supported for arith.select"); + + Type dstTy = getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + + auto cond = + rewriter.create(op.getLoc(), dstTy, + adaptor.getCondition(), + adaptor.getTrueValue(), + adaptor.getFalseValue()); + rewriter.replaceOp(op, cond.getResult()); + return success(); + } +}; + +struct ArithExtUIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto dstIntTy = dyn_cast(op.getType()); + auto srcIntTy = dyn_cast(op.getIn().getType()); + if (!dstIntTy || !srcIntTy) + return rewriter.notifyMatchFailure(op, "expected scalar integer types"); + + Type dstTy = getTypeConverter()->convertType(dstIntTy); + if (!dstTy) + return failure(); + + // i1 -> iN: bool to integer already behaves as 0/1. + if (srcIntTy.getWidth() == 1) { + rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); + return success(); + } + + auto uDstTy = + getUnsignedIntOpaqueType(rewriter.getContext(), dstIntTy.getWidth()); + Value srcU = + castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getIn(), + srcIntTy.getWidth()); + Value extU = emitCCast(rewriter, loc, uDstTy, srcU); + Value result = emitCCast(rewriter, loc, dstTy, extU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithExtSIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto dstIntTy = dyn_cast(op.getType()); + auto srcIntTy = dyn_cast(op.getIn().getType()); + if (!dstIntTy || !srcIntTy) + return rewriter.notifyMatchFailure(op, "expected scalar integer types"); + + Type dstTy = getTypeConverter()->convertType(dstIntTy); + if (!dstTy) + return failure(); + + // i1 sign-extension: 0 -> 0, 1 -> -1. + if (srcIntTy.getWidth() == 1) { + Value zero = makeEmitCIntConstant(rewriter, loc, dstTy, 0); + Value asInt = emitCCast(rewriter, loc, dstTy, adaptor.getIn()); + Value neg = rewriter.create(loc, dstTy, zero, asInt).getResult(); + rewriter.replaceOp(op, neg); + return success(); + } + + rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); + return success(); + } +}; + +template +struct ArithCastToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type dstTy = this->getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); + return success(); + } +}; + +struct ArithIndexCastUIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::IndexCastUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type dstTy = getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + + // MemRef casts are handled elsewhere; for safety, fall back to emitc.cast. + if (isa(op.getIn().getType()) || isa(op.getType())) { + rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); + return success(); + } + + auto getBW = [](Type t) -> std::optional { + if (auto i = dyn_cast(t)) + return i.getWidth(); + if (isa(t)) + return kPTOIndexBitWidth; + return std::nullopt; + }; + + auto srcBW = getBW(op.getIn().getType()); + auto dstBW = getBW(op.getType()); + if (!srcBW || !dstBW) + return rewriter.notifyMatchFailure(op, "unsupported index_castui types"); + + if (*dstBW <= *srcBW) { + rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); + return success(); + } + + auto uSrcTy = getUnsignedIntOpaqueType(rewriter.getContext(), *srcBW); + auto uDstTy = getUnsignedIntOpaqueType(rewriter.getContext(), *dstBW); + Value srcU = emitCCast(rewriter, loc, uSrcTy, adaptor.getIn()); + Value extU = emitCCast(rewriter, loc, uDstTy, srcU); + Value result = emitCCast(rewriter, loc, dstTy, extU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithUIToFPToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto srcIntTy = dyn_cast(op.getIn().getType()); + if (!srcIntTy) + return rewriter.notifyMatchFailure(op, "expected scalar integer input"); + + Type dstTy = getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + + // Convert via an unsigned integer type of the same width. + if (srcIntTy.getWidth() == 1) { + rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); + return success(); + } + Value srcU = + castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getIn(), + srcIntTy.getWidth()); + Value fp = rewriter.create(loc, dstTy, srcU).getResult(); + rewriter.replaceOp(op, fp); + return success(); + } +}; + +struct ArithFPToUIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::FPToUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto dstIntTy = dyn_cast(op.getType()); + if (!dstIntTy) + return rewriter.notifyMatchFailure(op, "expected scalar integer result"); + + Type dstTy = getTypeConverter()->convertType(dstIntTy); + if (!dstTy) + return failure(); + + auto uDstTy = + getUnsignedIntOpaqueType(rewriter.getContext(), dstIntTy.getWidth()); + Value asU = rewriter.create(loc, uDstTy, adaptor.getIn()).getResult(); + Value result = emitCCast(rewriter, loc, dstTy, asU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithBitcastToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type dstTy = getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + + // For pointer-like types, a regular cast is fine. + if (isa(dstTy)) { + rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); + return success(); + } + + // Only support scalar int/float/index bitcasts here. + auto srcTy = op.getIn().getType(); + auto dstOrigTy = op.getType(); + + auto getBitWidth = [](Type t) -> std::optional { + if (auto it = dyn_cast(t)) + return it.getWidth(); + if (auto ft = dyn_cast(t)) + return ft.getWidth(); + if (isa(t)) + return kPTOIndexBitWidth; + return std::nullopt; + }; + auto srcBW = getBitWidth(srcTy); + auto dstBW = getBitWidth(dstOrigTy); + if (!srcBW || !dstBW || *srcBW != *dstBW) + return rewriter.notifyMatchFailure(op, "bitcast requires equal bitwidth"); + + // Determine the template argument from the destination type string. + auto dstOpaque = dyn_cast(dstTy); + if (!dstOpaque) + return rewriter.notifyMatchFailure(op, "expected emitc opaque dest type"); + + auto templateArgs = + rewriter.getArrayAttr({emitc::OpaqueAttr::get(rewriter.getContext(), + dstOpaque.getValue())}); + auto call = rewriter.create( + loc, TypeRange{dstTy}, "ptoas_bitcast", /*operands=*/ValueRange{adaptor.getIn()}, + /*args=*/ArrayAttr{}, /*template_args=*/templateArgs); + rewriter.replaceOp(op, call.getResult(0)); + return success(); + } +}; + +// arith.cmpf lowering with ordered/unordered semantics. +struct ArithCmpFToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + struct CmpFConfig { + bool unordered = false; + emitc::CmpPredicate predicate = emitc::CmpPredicate::eq; + }; + + static Value isNaN(ConversionPatternRewriter &rewriter, Location loc, + Value v) { + return rewriter + .create(loc, rewriter.getI1Type(), emitc::CmpPredicate::ne, + v, v) + .getResult(); + } + + static Value isNotNaN(ConversionPatternRewriter &rewriter, Location loc, + Value v) { + return rewriter + .create(loc, rewriter.getI1Type(), emitc::CmpPredicate::eq, + v, v) + .getResult(); + } + + static std::optional buildSpecialCmpFResult( + arith::CmpFPredicate predicate, ConversionPatternRewriter &rewriter, + Location loc, Type i1Ty, Value lhs, Value rhs) { + switch (predicate) { + case arith::CmpFPredicate::AlwaysFalse: + return makeEmitCOpaqueConstant(rewriter, loc, i1Ty, "false"); + case arith::CmpFPredicate::AlwaysTrue: + return makeEmitCOpaqueConstant(rewriter, loc, i1Ty, "true"); + case arith::CmpFPredicate::ORD: + return rewriter.create( + loc, i1Ty, isNotNaN(rewriter, loc, lhs), + isNotNaN(rewriter, loc, rhs)) + .getResult(); + case arith::CmpFPredicate::UNO: + return rewriter.create( + loc, i1Ty, isNaN(rewriter, loc, lhs), + isNaN(rewriter, loc, rhs)) + .getResult(); + default: + return std::nullopt; + } + } + + static std::optional + getCmpFConfig(arith::CmpFPredicate predicate) { + switch (predicate) { + case arith::CmpFPredicate::OEQ: + return CmpFConfig{false, emitc::CmpPredicate::eq}; + case arith::CmpFPredicate::OGT: + return CmpFConfig{false, emitc::CmpPredicate::gt}; + case arith::CmpFPredicate::OGE: + return CmpFConfig{false, emitc::CmpPredicate::ge}; + case arith::CmpFPredicate::OLT: + return CmpFConfig{false, emitc::CmpPredicate::lt}; + case arith::CmpFPredicate::OLE: + return CmpFConfig{false, emitc::CmpPredicate::le}; + case arith::CmpFPredicate::ONE: + return CmpFConfig{false, emitc::CmpPredicate::ne}; + case arith::CmpFPredicate::UEQ: + return CmpFConfig{true, emitc::CmpPredicate::eq}; + case arith::CmpFPredicate::UGT: + return CmpFConfig{true, emitc::CmpPredicate::gt}; + case arith::CmpFPredicate::UGE: + return CmpFConfig{true, emitc::CmpPredicate::ge}; + case arith::CmpFPredicate::ULT: + return CmpFConfig{true, emitc::CmpPredicate::lt}; + case arith::CmpFPredicate::ULE: + return CmpFConfig{true, emitc::CmpPredicate::le}; + case arith::CmpFPredicate::UNE: + return CmpFConfig{true, emitc::CmpPredicate::ne}; + default: + return std::nullopt; + } + } + + static Value buildCmpFResult(const CmpFConfig &config, + ConversionPatternRewriter &rewriter, + Location loc, Type i1Ty, Value lhs, Value rhs) { + Value cmp = rewriter + .create(loc, i1Ty, config.predicate, lhs, rhs) + .getResult(); + Value unord = rewriter.create( + loc, i1Ty, isNaN(rewriter, loc, lhs), isNaN(rewriter, loc, rhs)); + if (config.unordered) + return rewriter + .create(loc, i1Ty, unord, cmp) + .getResult(); + Value ord = rewriter.create( + loc, i1Ty, isNotNaN(rewriter, loc, lhs), isNotNaN(rewriter, loc, rhs)); + return rewriter + .create(loc, i1Ty, ord, cmp) + .getResult(); + } + + LogicalResult matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!isa(op.getLhs().getType())) + return rewriter.notifyMatchFailure(op, "cmpf only supported on scalar floats"); + + auto loc = op.getLoc(); + auto i1Ty = rewriter.getI1Type(); + if (auto special = buildSpecialCmpFResult(op.getPredicate(), rewriter, loc, + i1Ty, adaptor.getLhs(), + adaptor.getRhs())) { + rewriter.replaceOp(op, *special); + return success(); + } + + auto config = getCmpFConfig(op.getPredicate()); + if (!config) + return rewriter.notifyMatchFailure(op, "unsupported cmpf predicate"); + rewriter.replaceOp(op, buildCmpFResult(*config, rewriter, loc, i1Ty, + adaptor.getLhs(), adaptor.getRhs())); + return success(); + } +}; + +struct ArithAddUIExtendedToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getSum().getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, + "expected scalar integer or index operands"); + + const unsigned bitWidth = + intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + SmallVector newResultTypes; + if (failed(getTypeConverter()->convertTypes(op->getResultTypes(), + newResultTypes))) + return failure(); + if (newResultTypes.size() != 2) + return failure(); + + Type sumDstTy = newResultTypes[0]; + Type overflowDstTy = newResultTypes[1]; + + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + auto wideTy = getWiderUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + + Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value lhsWide = emitCCast(rewriter, loc, wideTy, lhsU); + Value rhsWide = emitCCast(rewriter, loc, wideTy, rhsU); + Value sumWide = + rewriter.create(loc, wideTy, lhsWide, rhsWide).getResult(); + + Value sumN = emitCCast(rewriter, loc, uTy, sumWide); + Value sum = emitCCast(rewriter, loc, sumDstTy, sumN); + + Value shiftAmt = makeEmitCIntConstant(rewriter, loc, wideTy, bitWidth); + Value high = rewriter + .create(loc, wideTy, sumWide, + shiftAmt) + .getResult(); + Value zeroWide = makeEmitCIntConstant(rewriter, loc, wideTy, 0); + Value overflow = + rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::ne, high, zeroWide) + .getResult(); + overflow = emitCCast(rewriter, loc, overflowDstTy, overflow); + + rewriter.replaceOp(op, {sum, overflow}); + return success(); + } +}; + +template +struct ArithMulExtendedToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getResult(0).getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, + "expected scalar integer or index operands"); + + const unsigned bitWidth = + intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + SmallVector newResultTypes; + if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), + newResultTypes))) + return failure(); + if (newResultTypes.size() != 2) + return failure(); + + Type lowDstTy = newResultTypes[0]; + Type highDstTy = newResultTypes[1]; + + Type wideTy = isUnsigned ? (Type)getWiderUnsignedIntOpaqueType(rewriter.getContext(), + bitWidth) + : (Type)getWiderSignedIntOpaqueType(rewriter.getContext(), + bitWidth); + + Value lhsWide; + Value rhsWide; + if constexpr (isUnsigned) { + Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + lhsWide = emitCCast(rewriter, loc, wideTy, lhsU); + rhsWide = emitCCast(rewriter, loc, wideTy, rhsU); + } else { + lhsWide = emitCCast(rewriter, loc, wideTy, adaptor.getLhs()); + rhsWide = emitCCast(rewriter, loc, wideTy, adaptor.getRhs()); + } + + Value prodWide = + rewriter.create(loc, wideTy, lhsWide, rhsWide).getResult(); + Value low = emitCCast(rewriter, loc, lowDstTy, prodWide); + + Value shiftAmt = makeEmitCIntConstant(rewriter, loc, wideTy, bitWidth); + Value highWide = rewriter + .create(loc, wideTy, prodWide, + shiftAmt) + .getResult(); + Value high = emitCCast(rewriter, loc, highDstTy, highWide); + + rewriter.replaceOp(op, {low, high}); + return success(); + } +}; + +using ArithMulSIExtendedToEmitC = + ArithMulExtendedToEmitC; +using ArithMulUIExtendedToEmitC = + ArithMulExtendedToEmitC; + +struct ArithMinMaxIToEmitCBase { + static Value makeSelect(ConversionPatternRewriter &rewriter, Location loc, + Type dstTy, Value cond, Value trueV, Value falseV) { + return rewriter + .create(loc, dstTy, cond, trueV, falseV) + .getResult(); + } +}; + +struct ArithMaxSIToEmitC : public OpConversionPattern, + ArithMinMaxIToEmitCBase { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::MaxSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type dstTy = getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + Value cond = rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, + adaptor.getLhs(), adaptor.getRhs()) + .getResult(); + Value res = makeSelect(rewriter, loc, dstTy, cond, adaptor.getRhs(), + adaptor.getLhs()); + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct ArithMinSIToEmitC : public OpConversionPattern, + ArithMinMaxIToEmitCBase { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::MinSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type dstTy = getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + Value cond = rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, + adaptor.getLhs(), adaptor.getRhs()) + .getResult(); + Value res = makeSelect(rewriter, loc, dstTy, cond, adaptor.getLhs(), + adaptor.getRhs()); + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct ArithMaxUIToEmitC : public OpConversionPattern, + ArithMinMaxIToEmitCBase { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::MaxUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const unsigned bitWidth = + intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + Value lhsU = + castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = + castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value cond = rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, lhsU, rhsU) + .getResult(); + Value res = makeSelect(rewriter, loc, dstTy, cond, adaptor.getRhs(), + adaptor.getLhs()); + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct ArithMinUIToEmitC : public OpConversionPattern, + ArithMinMaxIToEmitCBase { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::MinUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const unsigned bitWidth = + intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + Value lhsU = + castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = + castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value cond = rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, lhsU, rhsU) + .getResult(); + Value res = makeSelect(rewriter, loc, dstTy, cond, adaptor.getLhs(), + adaptor.getRhs()); + rewriter.replaceOp(op, res); + return success(); + } +}; + +// Floating-point max/min variants. +struct ArithFloatMinMaxToEmitCBase { + static Value isNaN(ConversionPatternRewriter &rewriter, Location loc, + Value v) { + return rewriter + .create(loc, rewriter.getI1Type(), emitc::CmpPredicate::ne, + v, v) + .getResult(); + } + + static Value makeFZero(ConversionPatternRewriter &rewriter, Location loc, + Type ty) { + return makeEmitCOpaqueConstant(rewriter, loc, ty, "0.0f"); + } +}; + +struct ArithMaxNumFToEmitC : public OpConversionPattern, + ArithFloatMinMaxToEmitCBase { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::MaxNumFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type dstTy = getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + + Value lhsNaN = isNaN(rewriter, loc, adaptor.getLhs()); + Value rhsNaN = isNaN(rewriter, loc, adaptor.getRhs()); + + Value cmpLt = rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, + adaptor.getLhs(), adaptor.getRhs()) + .getResult(); + Value maxNoNaN = + rewriter + .create(loc, dstTy, cmpLt, adaptor.getRhs(), + adaptor.getLhs()) + .getResult(); + + Value rhsOrMax = + rewriter + .create(loc, dstTy, rhsNaN, adaptor.getLhs(), + maxNoNaN) + .getResult(); + Value res = + rewriter + .create(loc, dstTy, lhsNaN, adaptor.getRhs(), + rhsOrMax) + .getResult(); + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct ArithMinNumFToEmitC : public OpConversionPattern, + ArithFloatMinMaxToEmitCBase { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::MinNumFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type dstTy = getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + + Value lhsNaN = isNaN(rewriter, loc, adaptor.getLhs()); + Value rhsNaN = isNaN(rewriter, loc, adaptor.getRhs()); + + Value cmpLt = rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, + adaptor.getLhs(), adaptor.getRhs()) + .getResult(); + Value minNoNaN = + rewriter + .create(loc, dstTy, cmpLt, adaptor.getLhs(), + adaptor.getRhs()) + .getResult(); + + Value rhsOrMin = + rewriter + .create(loc, dstTy, rhsNaN, adaptor.getLhs(), + minNoNaN) + .getResult(); + Value res = + rewriter + .create(loc, dstTy, lhsNaN, adaptor.getRhs(), + rhsOrMin) + .getResult(); + rewriter.replaceOp(op, res); + return success(); + } +}; + +template +struct ArithMinMaxFPropagateNaNToEmitC : public OpConversionPattern, + ArithFloatMinMaxToEmitCBase { + using OpConversionPattern::OpConversionPattern; + + static Value buildPrimaryCandidate(ConversionPatternRewriter &rewriter, + Location loc, Type dstTy, Value lhs, + Value rhs) { + Value cmpLt = + rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::lt, lhs, rhs) + .getResult(); + return rewriter + .create( + loc, dstTy, cmpLt, isMaximum ? rhs : lhs, isMaximum ? lhs : rhs) + .getResult(); + } + + static Value buildSignBitValue(ConversionPatternRewriter &rewriter, + Location loc, Value lhs, FloatType floatTy) { + auto bitsTy = + getUnsignedIntOpaqueType(rewriter.getContext(), floatTy.getWidth()); + auto templateArgs = rewriter.getArrayAttr({emitc::OpaqueAttr::get( + rewriter.getContext(), cast(bitsTy).getValue())}); + Value lhsBits = + rewriter + .create(loc, TypeRange{bitsTy}, "ptoas_bitcast", + ValueRange{lhs}, ArrayAttr{}, + templateArgs) + .getResult(0); + Value oneBits = makeEmitCIntConstant(rewriter, loc, bitsTy, 1); + Value shiftAmount = + makeEmitCIntConstant(rewriter, loc, bitsTy, floatTy.getWidth() - 1); + Value signMask = rewriter + .create(loc, bitsTy, oneBits, + shiftAmount) + .getResult(); + return rewriter + .create(loc, bitsTy, lhsBits, signMask) + .getResult(); + } + + static Value buildSignedZeroCandidate(ConversionPatternRewriter &rewriter, + Location loc, Type dstTy, Value lhs, + Value rhs, FloatType floatTy) { + Value zero = makeFZero(rewriter, loc, dstTy); + Value equal = rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::eq, lhs, rhs) + .getResult(); + Value lhsZero = rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::eq, lhs, + zero) + .getResult(); + Value bothZero = rewriter + .create(loc, rewriter.getI1Type(), + equal, lhsZero) + .getResult(); + auto bitsTy = + getUnsignedIntOpaqueType(rewriter.getContext(), floatTy.getWidth()); + Value zeroBits = makeEmitCIntConstant(rewriter, loc, bitsTy, 0); + Value lhsIsNegZero = + rewriter + .create(loc, rewriter.getI1Type(), + emitc::CmpPredicate::ne, + buildSignBitValue(rewriter, loc, lhs, floatTy), + zeroBits) + .getResult(); + Value tie = rewriter + .create( + loc, dstTy, lhsIsNegZero, isMaximum ? rhs : lhs, + isMaximum ? lhs : rhs) + .getResult(); + return rewriter + .create(loc, dstTy, bothZero, tie, + buildPrimaryCandidate(rewriter, loc, dstTy, + lhs, rhs)) + .getResult(); + } + + static Value buildNaNPropagatingResult(ConversionPatternRewriter &rewriter, + Location loc, Type dstTy, Value lhs, + Value rhs, FloatType floatTy) { + Value lhsNaN = isNaN(rewriter, loc, lhs); + Value rhsNaN = isNaN(rewriter, loc, rhs); + Value noNaN = + buildSignedZeroCandidate(rewriter, loc, dstTy, lhs, rhs, floatTy); + Value rhsOrNoNaN = rewriter + .create(loc, dstTy, rhsNaN, rhs, + noNaN) + .getResult(); + return rewriter + .create(loc, dstTy, lhsNaN, lhs, rhsOrNoNaN) + .getResult(); + } + + LogicalResult + matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!isa(op.getType())) + return rewriter.notifyMatchFailure(op, "expected scalar float type"); + + auto loc = op.getLoc(); + Type dstTy = this->getTypeConverter()->convertType(op.getType()); + if (!dstTy) + return failure(); + + auto floatTy = cast(op.getType()); + rewriter.replaceOp(op, buildNaNPropagatingResult( + rewriter, loc, dstTy, adaptor.getLhs(), + adaptor.getRhs(), floatTy)); + return success(); + } +}; + +using ArithMaximumFToEmitC = + ArithMinMaxFPropagateNaNToEmitC; +using ArithMinimumFToEmitC = + ArithMinMaxFPropagateNaNToEmitC; + +//===----------------------------------------------------------------------===// +// Arith -> EmitC helpers +//===----------------------------------------------------------------------===// + +static emitc::OpaqueType getSignedIntOpaqueType(MLIRContext *ctx, + unsigned bitWidth) { + switch (bitWidth) { + case 1: + return emitc::OpaqueType::get(ctx, "int8_t"); + case 8: + return emitc::OpaqueType::get(ctx, "int8_t"); + case 16: + return emitc::OpaqueType::get(ctx, "int16_t"); + case 32: + return emitc::OpaqueType::get(ctx, "int32_t"); + case 64: + return emitc::OpaqueType::get(ctx, "int64_t"); + case 128: + return emitc::OpaqueType::get(ctx, "__int128"); + default: + llvm::errs() << "[Debug] Unsupported signed integer bitwidth: " << bitWidth + << "\n"; + return emitc::OpaqueType::get(ctx, "int64_t"); + } +} + +static emitc::OpaqueType getUnsignedIntOpaqueType(MLIRContext *ctx, + unsigned bitWidth) { + switch (bitWidth) { + case 1: + return emitc::OpaqueType::get(ctx, "uint8_t"); + case 8: + return emitc::OpaqueType::get(ctx, "uint8_t"); + case 16: + return emitc::OpaqueType::get(ctx, "uint16_t"); + case 32: + return emitc::OpaqueType::get(ctx, "uint32_t"); + case 64: + return emitc::OpaqueType::get(ctx, "uint64_t"); + case 128: + return emitc::OpaqueType::get(ctx, "unsigned __int128"); + default: + llvm::errs() << "[Debug] Unsupported unsigned integer bitwidth: " + << bitWidth << "\n"; + return emitc::OpaqueType::get(ctx, "uint64_t"); + } +} + +static emitc::OpaqueType getWiderSignedIntOpaqueType(MLIRContext *ctx, + unsigned bitWidth) { + switch (bitWidth) { + case 1: + case 8: + return getSignedIntOpaqueType(ctx, 16); + case 16: + return getSignedIntOpaqueType(ctx, 32); + case 32: + return getSignedIntOpaqueType(ctx, 64); + case 64: + return getSignedIntOpaqueType(ctx, 128); + default: + return getSignedIntOpaqueType(ctx, 128); + } +} + +static emitc::OpaqueType getWiderUnsignedIntOpaqueType(MLIRContext *ctx, + unsigned bitWidth) { + switch (bitWidth) { + case 1: + case 8: + return getUnsignedIntOpaqueType(ctx, 16); + case 16: + return getUnsignedIntOpaqueType(ctx, 32); + case 32: + return getUnsignedIntOpaqueType(ctx, 64); + case 64: + return getUnsignedIntOpaqueType(ctx, 128); + default: + return getUnsignedIntOpaqueType(ctx, 128); + } +} + +static Value makeEmitCOpaqueConstant(ConversionPatternRewriter &rewriter, + Location loc, Type type, + llvm::StringRef literal) { + auto attr = emitc::OpaqueAttr::get(rewriter.getContext(), literal); + return rewriter.create(loc, type, attr); +} + +static Value makeEmitCIntConstant(ConversionPatternRewriter &rewriter, + Location loc, Type type, int64_t value) { + return makeEmitCOpaqueConstant(rewriter, loc, type, std::to_string(value)); +} + +static FailureOr buildEmitCOpaqueConstantLiteral(Type targetType, + Attribute valueAttr) { + auto opaqueTy = dyn_cast(targetType); + if (!opaqueTy) + return failure(); + + if (opaqueTy.getValue() == "pto::MrgSortExecutedNumList") { + auto dense = dyn_cast_or_null(valueAttr); + if (!dense) + return failure(); + + auto vecTy = dyn_cast(dense.getType()); + if (!vecTy || vecTy.getRank() != 1 || vecTy.getNumElements() != 4 || + !vecTy.getElementType().isInteger(16)) + return failure(); + + std::string literal; + llvm::raw_string_ostream os(literal); + os << "pto::MrgSortExecutedNumList{"; + bool first = true; + for (APInt elem : dense.getValues()) { + if (!first) + os << ", "; + first = false; + os << elem.getZExtValue(); + } + os << "}"; + os.flush(); + return literal; + } + + return failure(); +} + +static Value emitCCast(ConversionPatternRewriter &rewriter, Location loc, + Type dstType, Value src) { + if (src.getType() == dstType) + return src; + return rewriter.createOrFold(loc, dstType, src); +} + +// For signless iN integers lowered to signed C++ types, this creates a value +// representing the same N-bit pattern in an unsigned C++ type of the same +// width. This avoids incorrect sign-extension when later widening to a larger +// unsigned type. +static Value castSignlessIntToUnsignedSameWidth(ConversionPatternRewriter &rewriter, + Location loc, Value v, + unsigned bitWidth) { + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + return emitCCast(rewriter, loc, uTy, v); +} + +struct ArithMulIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(arith::MulIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const unsigned bitWidth = + intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + // i1 mul is equivalent to bitwise AND (mod 2 arithmetic). + if (bitWidth == 1) { + rewriter.replaceOpWithNewOp(op, opTy, adaptor.getLhs(), + adaptor.getRhs()); + return success(); + } + + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value mulU = rewriter.create(loc, uTy, lhsU, rhsU); + Value result = emitCCast(rewriter, loc, dstTy, mulU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithAddIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(arith::AddIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const unsigned bitWidth = + intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + // i1 add is equivalent to XOR (mod 2 arithmetic). + if (bitWidth == 1) { + rewriter.replaceOpWithNewOp(op, opTy, adaptor.getLhs(), + adaptor.getRhs()); + return success(); + } + + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value addU = rewriter.create(loc, uTy, lhsU, rhsU); + Value result = emitCCast(rewriter, loc, dstTy, addU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithCastOPToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type newTy = getTypeConverter()->convertType(op.getType()); + if (!newTy) + return failure(); + if (adaptor.getIn().getType() == newTy) { + rewriter.replaceOp(op, adaptor.getIn()); + return success(); + } + rewriter.replaceOpWithNewOp(op, newTy, adaptor.getIn()); + return success(); + } +}; + +struct ArithSubIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(arith::SubIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Type opTy = op.getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure(op, "expected scalar integer or index type"); + + const unsigned bitWidth = + intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); + + Type dstTy = getTypeConverter()->convertType(opTy); + if (!dstTy) + return failure(); + + // i1 sub is equivalent to XOR (mod 2 arithmetic). + if (bitWidth == 1) { + rewriter.replaceOpWithNewOp(op, opTy, adaptor.getLhs(), + adaptor.getRhs()); + return success(); + } + + auto uTy = getUnsignedIntOpaqueType(rewriter.getContext(), bitWidth); + Value lhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getLhs(), + bitWidth); + Value rhsU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getRhs(), + bitWidth); + Value subU = rewriter.create(loc, uTy, lhsU, rhsU); + Value result = emitCCast(rewriter, loc, dstTy, subU); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ArithDivSIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(arith::DivSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type newTy = getTypeConverter()->convertType(op.getType()); + if (!newTy) + return failure(); + rewriter.replaceOpWithNewOp(op, newTy, adaptor.getLhs(), + adaptor.getRhs()); + return success(); + } +}; + +struct ArithRemSIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type newTy = getTypeConverter()->convertType(op.getType()); + if (!newTy) + return failure(); + rewriter.replaceOpWithNewOp(op, newTy, adaptor.getLhs(), + adaptor.getRhs()); + return success(); + } +}; + +struct ArithTruncIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + auto dstIntTy = dyn_cast(op.getType()); + auto srcIntTy = dyn_cast(op.getIn().getType()); + if (!dstIntTy || !srcIntTy) + return rewriter.notifyMatchFailure(op, "expected scalar integer types"); + + Type dstTy = getTypeConverter()->convertType(dstIntTy); + if (!dstTy) + return failure(); + + // to-i1 conversions: Arith wants truncation to the low bit, while C/C++ + // casts to bool are equivalent to `v != 0`. Implement as `(bool)(v & 1)`. + if (dstIntTy.getWidth() == 1) { + if (srcIntTy.getWidth() == 1) { + rewriter.replaceOp(op, adaptor.getIn()); + return success(); + } + + auto uSrcTy = + getUnsignedIntOpaqueType(rewriter.getContext(), srcIntTy.getWidth()); + Value inU = castSignlessIntToUnsignedSameWidth(rewriter, loc, adaptor.getIn(), + srcIntTy.getWidth()); + Value one = makeEmitCIntConstant(rewriter, loc, uSrcTy, 1); + Value masked = + rewriter.create(loc, uSrcTy, inU, one); + Value asBool = emitCCast(rewriter, loc, dstTy, masked); + rewriter.replaceOp(op, asBool); + return success(); + } + + rewriter.replaceOpWithNewOp(op, dstTy, adaptor.getIn()); + return success(); + } +}; + +struct ArithConstantToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type newType = getTypeConverter()->convertType(op.getType()); + if (!newType) + return failure(); + + // `adaptor.getValue()` may be null if attribute conversion isn't defined. + // Use the original attribute as fallback and always cast null-safely. + Attribute valueAttr = adaptor.getValue(); + if (!valueAttr) + valueAttr = op.getValue(); + + if (auto opaqueLiteral = buildEmitCOpaqueConstantLiteral(newType, valueAttr); + succeeded(opaqueLiteral)) { + auto constAttr = emitc::OpaqueAttr::get(rewriter.getContext(), *opaqueLiteral); + rewriter.replaceOpWithNewOp(op, newType, constAttr); + return success(); + } + + if (auto floatAttr = dyn_cast_or_null(valueAttr)) { + SmallString<32> valStr; + floatAttr.getValue().toString(valStr); + llvm::StringRef s(valStr); + // Ensure the literal parses as a floating-point constant in C/C++. + // `APFloat::toString` may emit "1" for integral values; make it "1.0". + const bool hasFloatMarker = + s.contains('.') || s.contains('e') || s.contains('E') || + s.contains('p') || s.contains('P') || s.starts_with("0x") || + s.starts_with("0X") || s.starts_with("nan") || + s.starts_with("-nan") || s.starts_with("inf") || + s.starts_with("-inf"); + if (!hasFloatMarker) + valStr.append(".0"); + // Suffix: keep `f` for f16/f32; omit for f64. + if (!floatAttr.getType().isF64()) + valStr.append("f"); + auto constAttr = emitc::OpaqueAttr::get(rewriter.getContext(), valStr); + rewriter.replaceOpWithNewOp(op, newType, constAttr); + return success(); + } + + if (auto intAttr = dyn_cast_or_null(valueAttr)) { + std::string valStr = std::to_string(intAttr.getValue().getSExtValue()); + auto constAttr = emitc::OpaqueAttr::get(rewriter.getContext(), valStr); + rewriter.replaceOpWithNewOp(op, newType, constAttr); + return success(); + } + + return failure(); + } +}; + +} // namespace + +void populatePTOToEmitCArithPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter, + MLIRContext *ctx) { + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add>( + typeConverter, ctx); + patterns.add>( + typeConverter, ctx); + patterns.add>( + typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add>(typeConverter, ctx); + patterns.add>(typeConverter, ctx); + patterns.add>(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add>(typeConverter, ctx); + patterns.add>(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add>(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add>(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); +} + +} // namespace mlir::pto diff --git a/lib/PTO/Transforms/PTOToEmitCComm.cpp b/lib/PTO/Transforms/PTOToEmitCComm.cpp new file mode 100644 index 000000000..93aed176d --- /dev/null +++ b/lib/PTO/Transforms/PTOToEmitCComm.cpp @@ -0,0 +1,889 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- PTOToEmitCComm.cpp --------------------------------------------------===// +//===----------------------------------------------------------------------===// + +#include "PTOToEmitCInternal.h" + +#include "PTO/IR/PTO.h" + +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" + +#include +#include +#include + +using namespace mlir; +using namespace mlir::pto; + +namespace mlir::pto { +namespace { + +static constexpr llvm::StringLiteral kGlobalTensorStridesAttrName = + "__pto.globaltensor_strides"; + +struct PTOInitializeL2G2LPipeToEmitC + : public OpConversionPattern { + PTOInitializeL2G2LPipeToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + PTOArch targetArch) + : OpConversionPattern(typeConverter, ctx), + targetArch(targetArch) {} + + LogicalResult matchAndRewrite(mlir::pto::InitializeL2G2LPipeOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto tpipeTok = buildTPipeTokenFromInitOp(op.getOperation(), targetArch); + if (failed(tpipeTok)) + return rewriter.notifyMatchFailure(op, "failed to build TPipe token"); + + auto *ctx = rewriter.getContext(); + auto emitPipeTy = + cast(getTypeConverter()->convertType(op.getPipe().getType())); + + Value gmAddr = peelUnrealized(adaptor.getGmAddr()); + gmAddr = materializeTensorViewDataPointer( + rewriter, op.getLoc(), gmAddr, op.getGmAddr().getType()); + Value localAddr = + op.getLocalAddr() ? peelUnrealized(adaptor.getLocalAddr()) : Value(); + auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); + Value zero = makeEmitCIntConstant(rewriter, op.getLoc(), i32Ty, 0); + + Value c2vBuf = zero; + Value v2cBuf = zero; + if (op.getDirMask() == 1) + c2vBuf = localAddr ? localAddr : zero; + else if (op.getDirMask() == 2) + v2cBuf = localAddr ? localAddr : zero; + else if (op.getDirMask() == 3) { + if (localAddr) { + if (!op.getPeerLocalAddr()) + return rewriter.notifyMatchFailure( + op, "bidirectional l2g2l pipe requires peer local buffer"); + c2vBuf = localAddr; + v2cBuf = peelUnrealized(adaptor.getPeerLocalAddr()); + } + } else + return rewriter.notifyMatchFailure(op, "unsupported dir_mask"); + + rewriter.replaceOpWithNewOp( + op, TypeRange{emitPipeTy}, *tpipeTok, ArrayAttr{}, ArrayAttr{}, + ValueRange{gmAddr, c2vBuf, v2cBuf}); + return success(); + } + + PTOArch targetArch; +}; + +struct PTOInitializeL2LPipeToEmitC + : public OpConversionPattern { + PTOInitializeL2LPipeToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + PTOArch targetArch) + : OpConversionPattern(typeConverter, ctx), + targetArch(targetArch) {} + + LogicalResult matchAndRewrite(mlir::pto::InitializeL2LPipeOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto tpipeTok = buildTPipeTokenFromInitOp(op.getOperation(), targetArch); + if (failed(tpipeTok)) + return rewriter.notifyMatchFailure(op, "failed to build TPipe token"); + + auto *ctx = rewriter.getContext(); + auto emitPipeTy = + cast(getTypeConverter()->convertType(op.getPipe().getType())); + + auto gmPtrTy = + emitc::PointerType::get(emitc::OpaqueType::get(ctx, "__gm__ void")); + Value nullGm = + makeEmitCOpaqueConstant(rewriter, op.getLoc(), gmPtrTy, "nullptr"); + auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); + Value zero = makeEmitCIntConstant(rewriter, op.getLoc(), i32Ty, 0); + Value localAddr = peelUnrealized(adaptor.getLocalAddr()); + + Value c2vBuf = zero; + Value v2cBuf = zero; + if (op.getDirMask() == 1) + c2vBuf = localAddr; + else if (op.getDirMask() == 2) + v2cBuf = localAddr; + else if (op.getDirMask() == 3) { + c2vBuf = localAddr; + v2cBuf = peelUnrealized(adaptor.getPeerLocalAddr()); + } else + return rewriter.notifyMatchFailure(op, "unsupported dir_mask"); + + rewriter.replaceOpWithNewOp( + op, TypeRange{emitPipeTy}, *tpipeTok, ArrayAttr{}, ArrayAttr{}, + ValueRange{nullGm, c2vBuf, v2cBuf}); + return success(); + } + + PTOArch targetArch; +}; + +struct PTOBuildAsyncSessionToEmitC + : public OpConversionPattern { + PTOBuildAsyncSessionToEmitC(TypeConverter &typeConverter, MLIRContext *ctx) + : OpConversionPattern(typeConverter, ctx) {} + + LogicalResult matchAndRewrite(mlir::pto::BuildAsyncSessionOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto *ctx = rewriter.getContext(); + Location loc = op.getLoc(); + + auto sessionTy = + dyn_cast(getTypeConverter()->convertType(op.getSession().getType())); + if (!sessionTy) + return rewriter.notifyMatchFailure(op, "failed to convert async session type"); + + FailureOr scratchTile = + buildAsyncScratchTileValue(rewriter, loc, op.getScratch(), + adaptor.getScratch()); + if (failed(scratchTile)) + return rewriter.notifyMatchFailure(op, "failed to materialize async scratch tile"); + + Value workspace = + castToGMBytePointer(rewriter, loc, peelUnrealized(adaptor.getWorkspace())); + + Value session = rewriter + .create( + loc, sessionTy, emitc::OpaqueAttr::get(ctx, "")) + .getResult(); + + auto u32Ty = emitc::OpaqueType::get(ctx, "uint32_t"); + + auto makeU32Const = [&](uint64_t value) -> Value { + return makeEmitCOpaqueConstant(rewriter, loc, u32Ty, + std::to_string(value) + "u"); + }; + uint64_t syncId = op.getSyncIdAttr() ? op.getSyncIdAttr().getInt() : 0; + uint64_t blockBytes = + op.getBlockBytesAttr() ? op.getBlockBytesAttr().getInt() : 32 * 1024; + uint64_t commBlockOffset = + op.getCommBlockOffsetAttr() ? op.getCommBlockOffsetAttr().getInt() : 0; + uint64_t queueNum = op.getQueueNumAttr() ? op.getQueueNumAttr().getInt() : 1; + uint64_t channelGroupIdx = op.getChannelGroupIdxAttr() + ? op.getChannelGroupIdxAttr().getInt() + : UINT32_MAX; + + Value syncIdVal = makeU32Const(syncId); + Value channelGroupIdxVal = + channelGroupIdx == UINT32_MAX + ? makeEmitCOpaqueConstant(rewriter, loc, u32Ty, "UINT32_MAX") + : makeU32Const(channelGroupIdx); + + auto baseConfigTy = + emitc::OpaqueType::get(ctx, "pto::comm::sdma::SdmaBaseConfig"); + Value baseConfig = + rewriter + .create( + loc, baseConfigTy, + emitc::OpaqueAttr::get( + ctx, "{" + std::to_string(blockBytes) + "ULL, " + + std::to_string(commBlockOffset) + "ULL, " + + std::to_string(queueNum) + "u}")) + .getResult(); + + rewriter.create( + loc, TypeRange{}, "pto::comm::BuildAsyncSession", + ArrayAttr{}, ArrayAttr{}, + ValueRange{*scratchTile, workspace, session, syncIdVal, baseConfig, + channelGroupIdxVal}); + + rewriter.replaceOp(op, session); + return success(); + } +}; + +template +struct PTOAsyncTransferToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + explicit PTOAsyncTransferToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + StringRef callee) + : OpConversionPattern(typeConverter, ctx), callee(callee.str()) {} + + LogicalResult matchAndRewrite(AsyncOp op, typename AsyncOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value dst = peelUnrealized(adaptor.getDst()); + Value src = peelUnrealized(adaptor.getSrc()); + Value dstGT = dst; + Value srcGT = src; + if (!isEmitCGlobalTensorLikeType(dstGT.getType())) { + auto dstMrTy = dyn_cast(op.getDst().getType()); + if (!dstMrTy) + return rewriter.notifyMatchFailure(op, "expected dst to lower to GlobalTensor or memref"); + dstGT = buildGlobalTensorFromMemref(rewriter, op.getLoc(), dst, dstMrTy, + op.getDst().getDefiningOp() + ? op.getDst().getDefiningOp() + : op.getOperation()); + } + if (!isEmitCGlobalTensorLikeType(srcGT.getType())) { + auto srcMrTy = dyn_cast(op.getSrc().getType()); + if (!srcMrTy) + return rewriter.notifyMatchFailure(op, "expected src to lower to GlobalTensor or memref"); + srcGT = buildGlobalTensorFromMemref(rewriter, op.getLoc(), src, srcMrTy, + op.getSrc().getDefiningOp() + ? op.getSrc().getDefiningOp() + : op.getOperation()); + } + if (!dstGT || !srcGT) + return rewriter.notifyMatchFailure(op, "failed to build GlobalTensor operands"); + + Type eventTy = this->getTypeConverter()->convertType(op.getEvent().getType()); + if (!eventTy) + return rewriter.notifyMatchFailure(op, "failed to convert async event type"); + + rewriter.replaceOpWithNewOp( + op, TypeRange{eventTy}, callee, ArrayAttr{}, ArrayAttr{}, + ValueRange{dstGT, srcGT, peelUnrealized(adaptor.getSession())}); + return success(); + } + + std::string callee; +}; + +template +struct PTOAsyncEventToEmitC : public OpConversionPattern { + explicit PTOAsyncEventToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + StringRef callee) + : OpConversionPattern(typeConverter, ctx), + callee(callee.str()) {} + + LogicalResult matchAndRewrite(AsyncEventOp op, + typename AsyncEventOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type resultTy = + this->getTypeConverter()->convertType(op.getCompleted().getType()); + if (!resultTy) + return rewriter.notifyMatchFailure(op, "failed to convert async event result type"); + + rewriter.replaceOpWithNewOp( + op, TypeRange{resultTy}, callee, ArrayAttr{}, ArrayAttr{}, + ValueRange{peelUnrealized(adaptor.getEvent()), + peelUnrealized(adaptor.getSession())}); + return success(); + } + + std::string callee; +}; + +static FailureOr buildCommGlobalTensorValue( + ConversionPatternRewriter &rewriter, Location loc, Value originalValue, + Value emittedValue, Operation *anchor) { + Value value = peelUnrealized(emittedValue); + if (isEmitCGlobalTensorLikeType(value.getType())) + return value; + + auto memTy = dyn_cast(originalValue.getType()); + if (!memTy) + return failure(); + + Value gt = buildGlobalTensorFromMemref(rewriter, loc, value, memTy, anchor); + if (!gt) + return failure(); + return gt; +} + +static FailureOr buildCommTileValue(ConversionPatternRewriter &rewriter, + Location loc, Value originalValue, + Value emittedValue) { + Value value = peelUnrealized(emittedValue); + if (auto opaqueTy = dyn_cast(value.getType())) { + StringRef typeStr = opaqueTy.getValue(); + if (typeStr.contains("Tile<") || typeStr.contains("ConvTile<")) + return value; + } + return buildAsyncScratchTileValue(rewriter, loc, originalValue, emittedValue); +} + +static FailureOr buildCollectiveParallelGroup( + ConversionPatternRewriter &rewriter, Location loc, + ArrayRef groupGTs, int64_t root) { + if (groupGTs.empty()) + return failure(); + + auto firstTy = dyn_cast(groupGTs.front().getType()); + if (!firstTy) + return failure(); + + auto *ctx = rewriter.getContext(); + auto arrayTy = emitc::ArrayType::get({static_cast(groupGTs.size())}, + firstTy); + auto groupArray = cast>( + rewriter + .create(loc, arrayTy, + emitc::OpaqueAttr::get(ctx, "{}")) + .getResult()); + + auto indexTy = emitc::OpaqueType::get(ctx, "int"); + for (auto [idx, groupVal] : llvm::enumerate(groupGTs)) { + Value idxVal = + makeEmitCIntConstant(rewriter, loc, indexTy, static_cast(idx)); + Value slot = + rewriter.create(loc, groupArray, ValueRange{idxVal}) + .getResult(); + rewriter.create(loc, slot, groupVal); + } + + std::string pgTypeStr = + (Twine("pto::comm::ParallelGroup<") + firstTy.getValue() + ">").str(); + auto pgTy = emitc::OpaqueType::get(ctx, pgTypeStr); + Value sizeVal = makeEmitCIntConstant(rewriter, loc, indexTy, + static_cast(groupGTs.size())); + Value rootVal = makeEmitCIntConstant(rewriter, loc, indexTy, root); + return rewriter + .create( + loc, TypeRange{pgTy}, (Twine(pgTypeStr) + "::Create").str(), + ArrayAttr{}, ArrayAttr{}, ValueRange{groupArray, sizeVal, rootVal}) + .getResult(0); +} + +static std::string notifyOpTok(pto::NotifyOp op) { + switch (op) { + case pto::NotifyOp::AtomicAdd: + return "pto::comm::NotifyOp::AtomicAdd"; + case pto::NotifyOp::Set: + return "pto::comm::NotifyOp::Set"; + } + return "pto::comm::NotifyOp::Set"; +} + +static std::string waitCmpTok(pto::WaitCmp cmp) { + switch (cmp) { + case pto::WaitCmp::EQ: + return "pto::comm::WaitCmp::EQ"; + case pto::WaitCmp::NE: + return "pto::comm::WaitCmp::NE"; + case pto::WaitCmp::GT: + return "pto::comm::WaitCmp::GT"; + case pto::WaitCmp::GE: + return "pto::comm::WaitCmp::GE"; + case pto::WaitCmp::LT: + return "pto::comm::WaitCmp::LT"; + case pto::WaitCmp::LE: + return "pto::comm::WaitCmp::LE"; + } + return "pto::comm::WaitCmp::EQ"; +} + +static std::string reduceOpTok(pto::ReduceOp op) { + switch (op) { + case pto::ReduceOp::Sum: + return "pto::comm::ReduceOp::Sum"; + case pto::ReduceOp::Max: + return "pto::comm::ReduceOp::Max"; + case pto::ReduceOp::Min: + return "pto::comm::ReduceOp::Min"; + } + return "pto::comm::ReduceOp::Sum"; +} + +template +static FailureOr> buildCommGroupGlobalTensors( + ConversionPatternRewriter &rewriter, Location loc, OpTy op, + ValueRange originalGroup, ValueRange emittedGroup) { + SmallVector groupGTs; + groupGTs.reserve(originalGroup.size()); + for (auto [orig, emitted] : llvm::zip(originalGroup, emittedGroup)) { + FailureOr gt = + buildCommGlobalTensorValue(rewriter, loc, orig, emitted, op.getOperation()); + if (failed(gt)) + return failure(); + groupGTs.push_back(*gt); + } + return groupGTs; +} + +template +struct PTOCommCollectiveToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + explicit PTOCommCollectiveToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + StringRef apiName) + : OpConversionPattern(typeConverter, ctx), + apiName(apiName.str()) {} + + LogicalResult matchAndRewrite(CollectiveOp op, typename CollectiveOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + + auto buildPong = [&](Value original, Value emitted, StringRef name) -> FailureOr { + if (!original) + return failure(); + return buildCommTileValue(rewriter, loc, original, emitted); + }; + + if constexpr (std::is_same_v) { + FailureOr srcGT = + buildCommGlobalTensorValue(rewriter, loc, op.getSrc(), adaptor.getSrc(), + op.getOperation()); + FailureOr pingTile = + buildCommTileValue(rewriter, loc, op.getPing(), adaptor.getPing()); + auto groupGTs = + buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); + if (failed(srcGT) || failed(pingTile) || failed(groupGTs)) + return rewriter.notifyMatchFailure(op, "failed to materialize broadcast operands"); + FailureOr pg = buildCollectiveParallelGroup(rewriter, loc, *groupGTs, op.getRoot()); + if (failed(pg)) + return rewriter.notifyMatchFailure(op, "failed to materialize broadcast group"); + if (op.getPong()) { + FailureOr pongTile = + buildPong(op.getPong(), adaptor.getPong(), "__pong"); + if (failed(pongTile)) + return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); + rewriter.create( + loc, TypeRange{}, "pto::comm::TBROADCAST", ArrayAttr{}, ArrayAttr{}, + ValueRange{*pg, *srcGT, *pingTile, *pongTile}); + } else { + rewriter.create( + loc, TypeRange{}, "pto::comm::TBROADCAST", ArrayAttr{}, ArrayAttr{}, + ValueRange{*pg, *srcGT, *pingTile}); + } + } else if constexpr (std::is_same_v) { + FailureOr dstGT = + buildCommGlobalTensorValue(rewriter, loc, op.getDst(), adaptor.getDst(), + op.getOperation()); + FailureOr pingTile = + buildCommTileValue(rewriter, loc, op.getPing(), adaptor.getPing()); + auto groupGTs = + buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); + if (failed(dstGT) || failed(pingTile) || failed(groupGTs)) + return rewriter.notifyMatchFailure(op, "failed to materialize gather operands"); + FailureOr pg = buildCollectiveParallelGroup(rewriter, loc, *groupGTs, op.getRoot()); + if (failed(pg)) + return rewriter.notifyMatchFailure(op, "failed to materialize gather group"); + if (op.getPong()) { + FailureOr pongTile = + buildPong(op.getPong(), adaptor.getPong(), "__pong"); + if (failed(pongTile)) + return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); + rewriter.create( + loc, TypeRange{}, "pto::comm::TGATHER", ArrayAttr{}, ArrayAttr{}, + ValueRange{*pg, *dstGT, *pingTile, *pongTile}); + } else { + rewriter.create( + loc, TypeRange{}, "pto::comm::TGATHER", ArrayAttr{}, ArrayAttr{}, + ValueRange{*pg, *dstGT, *pingTile}); + } + } else if constexpr (std::is_same_v) { + FailureOr srcGT = + buildCommGlobalTensorValue(rewriter, loc, op.getSrc(), adaptor.getSrc(), + op.getOperation()); + FailureOr pingTile = + buildCommTileValue(rewriter, loc, op.getPing(), adaptor.getPing()); + auto groupGTs = + buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); + if (failed(srcGT) || failed(pingTile) || failed(groupGTs)) + return rewriter.notifyMatchFailure(op, "failed to materialize scatter operands"); + FailureOr pg = buildCollectiveParallelGroup(rewriter, loc, *groupGTs, op.getRoot()); + if (failed(pg)) + return rewriter.notifyMatchFailure(op, "failed to materialize scatter group"); + if (op.getPong()) { + FailureOr pongTile = + buildPong(op.getPong(), adaptor.getPong(), "__pong"); + if (failed(pongTile)) + return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); + rewriter.create( + loc, TypeRange{}, "pto::comm::TSCATTER", ArrayAttr{}, ArrayAttr{}, + ValueRange{*pg, *srcGT, *pingTile, *pongTile}); + } else { + rewriter.create( + loc, TypeRange{}, "pto::comm::TSCATTER", ArrayAttr{}, ArrayAttr{}, + ValueRange{*pg, *srcGT, *pingTile}); + } + } else { + FailureOr dstGT = + buildCommGlobalTensorValue(rewriter, loc, op.getDst(), adaptor.getDst(), + op.getOperation()); + FailureOr accTile = + buildCommTileValue(rewriter, loc, op.getAcc(), adaptor.getAcc()); + FailureOr recvPing = + buildCommTileValue(rewriter, loc, op.getRecvPing(), adaptor.getRecvPing()); + auto groupGTs = + buildCommGroupGlobalTensors(rewriter, loc, op, op.getGroup(), adaptor.getGroup()); + if (failed(dstGT) || failed(accTile) || failed(recvPing) || failed(groupGTs)) + return rewriter.notifyMatchFailure(op, "failed to materialize reduce operands"); + FailureOr pg = buildCollectiveParallelGroup(rewriter, loc, *groupGTs, op.getRoot()); + if (failed(pg)) + return rewriter.notifyMatchFailure(op, "failed to materialize reduce group"); + if (op.getRecvPong()) { + FailureOr recvPong = + buildPong(op.getRecvPong(), adaptor.getRecvPong(), "__recv_pong"); + if (failed(recvPong)) + return rewriter.notifyMatchFailure(op, "failed to materialize recv_pong"); + auto reduceTy = + emitc::OpaqueType::get(rewriter.getContext(), "pto::comm::ReduceOp"); + Value reduceOp = makeEmitCOpaqueConstant(rewriter, loc, reduceTy, + reduceOpTok(op.getReduceOp())); + rewriter.create( + loc, TypeRange{}, "pto::comm::TREDUCE", ArrayAttr{}, ArrayAttr{}, + ValueRange{*pg, *dstGT, *accTile, *recvPing, *recvPong, reduceOp}); + } else { + auto reduceTy = + emitc::OpaqueType::get(rewriter.getContext(), "pto::comm::ReduceOp"); + Value reduceOp = makeEmitCOpaqueConstant(rewriter, loc, reduceTy, + reduceOpTok(op.getReduceOp())); + rewriter.create( + loc, TypeRange{}, "pto::comm::TREDUCE", ArrayAttr{}, ArrayAttr{}, + ValueRange{*pg, *dstGT, *accTile, *recvPing, reduceOp}); + } + } + rewriter.eraseOp(op); + return success(); + } + + std::string apiName; +}; + +template +struct PTOP2PCommToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + explicit PTOP2PCommToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + StringRef callee) + : OpConversionPattern(typeConverter, ctx), callee(callee.str()) {} + + LogicalResult matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr dstGT = + buildCommGlobalTensorValue(rewriter, op.getLoc(), op.getDst(), adaptor.getDst(), + op.getOperation()); + FailureOr srcGT = + buildCommGlobalTensorValue(rewriter, op.getLoc(), op.getSrc(), adaptor.getSrc(), + op.getOperation()); + FailureOr pingTile = + buildCommTileValue(rewriter, op.getLoc(), op.getPing(), adaptor.getPing()); + if (failed(dstGT) || failed(srcGT) || failed(pingTile)) + return rewriter.notifyMatchFailure(op, "failed to materialize p2p operands"); + + SmallVector operands{*dstGT, *srcGT, *pingTile}; + std::string actualCallee = callee; + if constexpr (std::is_same_v) { + if (op.getAtomicType() == pto::AtomicType::AtomicAdd) + actualCallee = "pto::comm::TPUT"; + } + if (op.getPong()) { + FailureOr pongTile = + buildCommTileValue(rewriter, op.getLoc(), op.getPong(), adaptor.getPong()); + if (failed(pongTile)) + return rewriter.notifyMatchFailure(op, "failed to materialize pong tile"); + operands.push_back(*pongTile); + } + + rewriter.create(op.getLoc(), TypeRange{}, actualCallee, + ArrayAttr{}, ArrayAttr{}, operands); + rewriter.eraseOp(op); + return success(); + } + + std::string callee; +}; + +template +struct PTOSignalCommToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + explicit PTOSignalCommToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + StringRef callee) + : OpConversionPattern(typeConverter, ctx), + callee(callee.str()) {} + + LogicalResult matchAndRewrite(SignalOp op, typename SignalOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr signalGT = buildCommGlobalTensorValue( + rewriter, op.getLoc(), op.getSignal(), adaptor.getSignal(), op.getOperation()); + if (failed(signalGT)) + return rewriter.notifyMatchFailure(op, "failed to materialize signal operand"); + + if constexpr (std::is_same_v) { + auto notifyTy = + emitc::OpaqueType::get(rewriter.getContext(), "pto::comm::NotifyOp"); + Value notifyOp = makeEmitCOpaqueConstant( + rewriter, op.getLoc(), notifyTy, notifyOpTok(op.getNotifyOp())); + SmallVector operands{*signalGT, peelUnrealized(adaptor.getValue()), + notifyOp}; + rewriter.create(op.getLoc(), TypeRange{}, callee, + ArrayAttr{}, ArrayAttr{}, operands); + rewriter.eraseOp(op); + } else { + auto waitCmpTy = + emitc::OpaqueType::get(rewriter.getContext(), "pto::comm::WaitCmp"); + Value waitCmp = makeEmitCOpaqueConstant( + rewriter, op.getLoc(), waitCmpTy, waitCmpTok(op.getCmp())); + SmallVector operands{*signalGT, peelUnrealized(adaptor.getCmpValue()), + waitCmp}; + if constexpr (std::is_same_v) { + Type resultTy = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultTy) + return rewriter.notifyMatchFailure(op, "failed to convert ttest result type"); + rewriter.replaceOpWithNewOp( + op, TypeRange{resultTy}, callee, ArrayAttr{}, ArrayAttr{}, operands); + } else { + rewriter.create(op.getLoc(), TypeRange{}, callee, + ArrayAttr{}, ArrayAttr{}, operands); + rewriter.eraseOp(op); + } + } + return success(); + } + + std::string callee; +}; + +struct PTODeclareTileMemRefToEmitC + : public OpConversionPattern { + using OpConversionPattern< + mlir::pto::DeclareTileMemRefOp>::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::DeclareTileMemRefOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + Type convertedType = getTypeConverter()->convertType(op.getResult().getType()); + if (!convertedType) + return rewriter.notifyMatchFailure( + op, "failed to convert declare_tile_memref result type"); + rewriter.replaceOp(op, makeEmitCOpaqueConstant(rewriter, op.getLoc(), + convertedType, "nullptr")); + return success(); + } +}; + +struct PTODeclareGlobalToEmitC + : public OpConversionPattern { + using OpConversionPattern< + mlir::pto::DeclareGlobalOp>::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::DeclareGlobalOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + Type convertedType = getTypeConverter()->convertType(op.getEntry().getType()); + if (!convertedType) + return rewriter.notifyMatchFailure( + op, "failed to convert declare_global result type"); + if (auto tvTy = dyn_cast(op.getEntry().getType())) { + if (auto stridesAttr = + op->getAttrOfType(kGlobalTensorStridesAttrName)) { + auto strides = stridesAttr.asArrayRef(); + if (strides.size() == static_cast(tvTy.getRank())) { + convertedType = emitc::OpaqueType::get( + rewriter.getContext(), + getGlobalTensorTypeStringFromShapeAndStrides( + tvTy.getElementType(), tvTy.getShape(), strides)); + } + } + } + auto var = rewriter.create( + op.getLoc(), convertedType, + emitc::OpaqueAttr::get(rewriter.getContext(), "")); + rewriter.replaceOp(op, var.getResult()); + return success(); + } +}; + +struct PTODeclareEventIdArrayToEmitC + : public OpConversionPattern { + using OpConversionPattern< + mlir::pto::DeclareEventIdArrayOp>::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::DeclareEventIdArrayOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + Type arrayTy = getTypeConverter()->convertType(op.getArray().getType()); + if (!arrayTy) + return rewriter.notifyMatchFailure(op, + "failed to map declared eventid_array type"); + + auto array = rewriter + .create( + op.getLoc(), arrayTy, + emitc::OpaqueAttr::get(rewriter.getContext(), "")) + .getResult(); + rewriter.replaceOp(op, array); + return success(); + } +}; + +struct PTOEventIdArrayGetToEmitC + : public OpConversionPattern { + using OpConversionPattern< + mlir::pto::EventIdArrayGetOp>::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::EventIdArrayGetOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value array = peelUnrealized(adaptor.getArray()); + Value index = peelUnrealized(adaptor.getIndex()); + + Type resultTy = getTypeConverter()->convertType(op.getResult().getType()); + if (!resultTy) + return rewriter.notifyMatchFailure(op, + "failed to map eventid_array get result type"); + + auto load = + rewriter.create(op.getLoc(), resultTy, array, index); + rewriter.replaceOp(op, load.getResult()); + return success(); + } +}; + +struct PTOEventIdArraySetToEmitC + : public OpConversionPattern { + using OpConversionPattern< + mlir::pto::EventIdArraySetOp>::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::EventIdArraySetOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value array = peelUnrealized(adaptor.getArray()); + Value index = peelUnrealized(adaptor.getIndex()); + Value value = peelUnrealized(adaptor.getValue()); + + rewriter.create( + op.getLoc(), TypeRange{}, "PTOAS__EVENTID_ARRAY_STORE", + ArrayAttr{}, ArrayAttr{}, ValueRange{array, index, value}); + rewriter.eraseOp(op); + return success(); + } +}; + +// pto.declare_local_array -> emitc.variable of !emitc.array<...>. +// Renders as `T a[D1][D2]...;` in the emitted C++. +struct PTODeclareLocalArrayToEmitC + : public OpConversionPattern { + using OpConversionPattern< + mlir::pto::DeclareLocalArrayOp>::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::DeclareLocalArrayOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + Type arrayTy = getTypeConverter()->convertType(op.getArray().getType()); + if (!arrayTy) + return rewriter.notifyMatchFailure(op, + "failed to map !pto.local_array type"); + + auto var = rewriter + .create( + op.getLoc(), arrayTy, + emitc::OpaqueAttr::get(rewriter.getContext(), "")) + .getResult(); + rewriter.replaceOp(op, var); + return success(); + } +}; + +// pto.local_array_get %a[%i0, %i1, ...] -> rvalue. +// Lowers to a single emitc.subscript with the full index pack; the C++ emitter +// prints it as `a[i0][i1]...`. The adaptor already exposes target-typed values +// (the type converter has remapped !pto.local_array -> !emitc.array and +// index/integer indices), so they're forwarded directly to the builder. +struct PTOLocalArrayGetToEmitC + : public OpConversionPattern { + using OpConversionPattern< + mlir::pto::LocalArrayGetOp>::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::LocalArrayGetOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type resultTy = + getTypeConverter()->convertType(op.getResult().getType()); + if (!resultTy) + return rewriter.notifyMatchFailure( + op, "failed to map local_array element type"); + + auto sub = rewriter.create( + op.getLoc(), resultTy, adaptor.getArray(), adaptor.getIndices()); + rewriter.replaceOp(op, sub.getResult()); + return success(); + } +}; + +// pto.local_array_set %a[%i0, %i1, ...], %v -> emitc.assign to subscript slot. +// The C++ emitter prints this as `a[i0][i1]... = v;`. As above, adaptor values +// are already target-typed; pass them through directly. +struct PTOLocalArraySetToEmitC + : public OpConversionPattern { + using OpConversionPattern< + mlir::pto::LocalArraySetOp>::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::LocalArraySetOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value value = adaptor.getValue(); + Type elemTy = value.getType(); + + Value slot = rewriter + .create(op.getLoc(), elemTy, + adaptor.getArray(), + adaptor.getIndices()) + .getResult(); + rewriter.create(op.getLoc(), slot, value); + rewriter.eraseOp(op); + return success(); + } +}; + + +} // namespace + +void populatePTOToEmitCCommPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter, + MLIRContext *ctx, PTOArch targetArch) { + patterns.add(typeConverter, ctx, targetArch); + patterns.add(typeConverter, ctx, targetArch); + patterns.add(typeConverter, ctx); + patterns.add>( + typeConverter, ctx, + "pto::comm::TPUT_ASYNC"); + patterns.add>( + typeConverter, ctx, + "pto::comm::TGET_ASYNC"); + patterns.add>(typeConverter, ctx, + "pto::comm::TPUT"); + patterns.add>(typeConverter, ctx, + "pto::comm::TGET"); + patterns.add>(typeConverter, ctx, + "pto::comm::TNOTIFY"); + patterns.add>(typeConverter, ctx, + "pto::comm::TWAIT"); + patterns.add>(typeConverter, ctx, + "pto::comm::TTEST"); + patterns.add>(typeConverter, ctx, + "TBROADCAST"); + patterns.add>(typeConverter, ctx, + "TGATHER"); + patterns.add>(typeConverter, ctx, + "TSCATTER"); + patterns.add>(typeConverter, ctx, + "TREDUCE"); + patterns.add>( + typeConverter, ctx, "PTOAS__ASYNC_EVENT_WAIT"); + patterns.add>( + typeConverter, ctx, "PTOAS__ASYNC_EVENT_TEST"); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); +} + +} // namespace mlir::pto diff --git a/lib/PTO/Transforms/PTOToEmitCControlFlow.cpp b/lib/PTO/Transforms/PTOToEmitCControlFlow.cpp new file mode 100644 index 000000000..8422fe40d --- /dev/null +++ b/lib/PTO/Transforms/PTOToEmitCControlFlow.cpp @@ -0,0 +1,717 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- PTOToEmitCControlFlow.cpp ------------------------------------------===// +//===----------------------------------------------------------------------===// + +#include "PTOToEmitCInternal.h" + +#include "PTO/IR/PTO.h" + +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" +#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; +using namespace mlir::pto; + +namespace mlir::pto { +namespace { + +//===----------------------------------------------------------------------===// +// Return lowering +//===----------------------------------------------------------------------=== + +static constexpr llvm::StringLiteral kAutoSyncTailPendingModeAttr = + "__pto.auto_sync_tail_mode"; + +struct ReturnToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (auto emitcFunc = op->getParentOfType()) { + if (auto modeAttr = + emitcFunc->getAttrOfType(kAutoSyncTailPendingModeAttr)) { + auto *ctx = rewriter.getContext(); + rewriter.setInsertionPoint(op); + auto args = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, modeAttr.getValue())}); + rewriter.create( + op.getLoc(), TypeRange{}, "ptoas_auto_sync_tail", + args, ArrayAttr{}, ValueRange{}); + } + } + + auto vals = adaptor.getOperands(); + if (vals.empty()) { + rewriter.replaceOpWithNewOp(op, Value{}); + return success(); + } + if (vals.size() == 1) { + rewriter.replaceOpWithNewOp(op, vals[0]); + return success(); + } + return rewriter.notifyMatchFailure(op, "EmitC cannot return multiple values"); + } +}; + +struct CallToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(func::CallOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (op.getNumResults() > 1) + return rewriter.notifyMatchFailure( + op, "EmitC cannot lower calls with multiple results"); + + SmallVector resultTypes; + if (failed( + getTypeConverter()->convertTypes(op.getResultTypes(), resultTypes))) + return rewriter.notifyMatchFailure(op, + "failed to convert call result types"); + + rewriter.replaceOpWithNewOp(op, op.getCalleeAttr(), + resultTypes, + adaptor.getOperands()); + return success(); + } +}; + + + +template +struct SectionToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + std::string getMacroName() const { + if (std::is_same::value) + return "__DAV_CUBE__"; + if (std::is_same::value) + return "__DAV_VEC__"; + return "UNKNOWN_MACRO"; + } + + LogicalResult + matchAndRewrite(SectionOpTy op, typename SectionOpTy::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + bool needsNoSplitGuard = needsA5NoSplitVectorGuard(op.getOperation()); + + std::string startMacro = "\n#if defined(" + getMacroName() + ")"; + rewriter.create(loc, startMacro); + + if constexpr (std::is_same_v) { + // Vector mask is a global HW state and may be modified by previous kernels + // (or earlier sections). Reset it to a well-defined state for deterministic + // execution of VEC ops. + rewriter.create(loc, "set_mask_norm();"); + rewriter.create(loc, "set_vector_mask(-1, -1);"); + } + + if (needsNoSplitGuard) { + rewriter.create( + loc, "if (get_subblockid() == 0) {"); + } + + Block &innerBlock = op.getBody().front(); + if (!innerBlock.empty()) { + rewriter.inlineBlockBefore(&innerBlock, op.getOperation(), ValueRange{}); + } + + if (needsNoSplitGuard) + rewriter.create(loc, "}"); + + std::string endMacro = "#endif // " + getMacroName() + "\n"; + rewriter.create(loc, endMacro); + + rewriter.eraseOp(op); + + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// SCF Control-Flow Pre-Lowering +// +// EmitC translation supports `emitc.for`/`emitc.if` plus CFG-style +// `cf.br`/`cf.cond_br`. Upstream SCFToEmitC patterns only cover `scf.for` and +// `scf.if`, so we pre-lower some SCF ops into those supported forms. +//===----------------------------------------------------------------------===// + +namespace { + +static bool isTriviallyInlineableExecuteRegion(scf::ExecuteRegionOp op) { + Region &r = op.getRegion(); + if (!r.hasOneBlock()) + return false; + Block &b = r.front(); + return isa_and_nonnull(b.getTerminator()); +} + +static bool needsWholeFunctionSCFToCF(func::FuncOp func) { + bool needs = false; + func.walk([&](Operation *op) { + if (!isa(op)) + return WalkResult::advance(); + Operation *parentOp = op->getParentOp(); + + // `scf.execute_region` can legally appear in single-block parents. Only + // require whole-function SCFToCF if we need to lower it into CFG blocks + // (multi-block region / non-trivial terminators). + if (auto exec = dyn_cast(op)) { + if (parentOp && parentOp->hasTrait() && + !isTriviallyInlineableExecuteRegion(exec)) { + needs = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + } + + if (parentOp && parentOp->hasTrait()) { + needs = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return needs; +} + +// scf.execute_region is semantically just an inlined region producing results +// via scf.yield. Inline it to the parent block to avoid extra lowering needs. +struct SCFExecuteRegionInline + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::ExecuteRegionOp op, + PatternRewriter &rewriter) const override { + if (op.getRegion().empty()) + return rewriter.notifyMatchFailure(op, "expected non-empty region"); + + Block &innerBlock = op.getRegion().front(); + auto yield = dyn_cast(innerBlock.getTerminator()); + if (!yield) + return rewriter.notifyMatchFailure(op, "expected scf.yield terminator"); + + // Move the body operations before the execute_region op. + rewriter.inlineBlockBefore(&innerBlock, op.getOperation(), ValueRange{}); + + // Replace execute_region results with yielded values, then erase the yield. + rewriter.replaceOp(op, yield.getOperands()); + rewriter.eraseOp(yield); + return success(); + } +}; + +// Lower scf.execute_region into CFG blocks with cf.br/cf.cond_br by inlining the +// region blocks into the parent region and rewriting scf.yield to branch into a +// continuation block carrying results. +// +// Note: This requires the parent region to allow multiple blocks (e.g. the +// function body CFG region). For execute_region nested in single-block regions +// (scf.for/scf.if), run SCFToCF first to eliminate the single-block constraint. +struct SCFExecuteRegionToCF : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::ExecuteRegionOp op, + PatternRewriter &rewriter) const override { + if (isTriviallyInlineableExecuteRegion(op)) + return rewriter.notifyMatchFailure(op, "trivially inlineable"); + + Operation *parentOp = op->getParentOp(); + if (parentOp && parentOp->hasTrait()) { + return rewriter.notifyMatchFailure( + op, "cannot lower scf.execute_region inside a single-block parent region"); + } + + if (op.getRegion().empty()) + return rewriter.notifyMatchFailure(op, "expected non-empty region"); + + Location loc = op.getLoc(); + Block *curBlock = op->getBlock(); + Region *parentRegion = curBlock->getParent(); + + // Split the parent block so we can branch to a continuation block with phi + // arguments for the execute_region results. + auto execIt = Block::iterator(op.getOperation()); + Block *continueBlock = rewriter.splitBlock(curBlock, std::next(execIt)); + + SmallVector contArgs; + contArgs.reserve(op.getNumResults()); + for (Type t : op.getResultTypes()) + contArgs.push_back(continueBlock->addArgument(t, loc)); + + for (auto it : llvm::enumerate(op.getResults())) + it.value().replaceAllUsesWith(contArgs[it.index()]); + + // Capture blocks before moving the region. + SmallVector movedBlocks; + movedBlocks.reserve(op.getRegion().getBlocks().size()); + for (Block &b : op.getRegion()) + movedBlocks.push_back(&b); + Block *entryBlock = &op.getRegion().front(); + + // Inline the execute_region blocks into the parent region right before the + // continuation block. + rewriter.inlineRegionBefore(op.getRegion(), *parentRegion, + continueBlock->getIterator()); + + // Replace all scf.yield terminators with a branch to the continuation. + for (Block *b : movedBlocks) { + auto yield = dyn_cast(b->getTerminator()); + if (!yield) + continue; + rewriter.setInsertionPoint(yield); + rewriter.create(loc, continueBlock, yield.getOperands()); + rewriter.eraseOp(yield); + } + + // Replace execute_region itself with a branch to the inlined entry block. + rewriter.setInsertionPoint(op); + rewriter.create(loc, entryBlock, ValueRange{}); + rewriter.eraseOp(op); + return success(); + } +}; + +// Lower scf.index_switch into CFG blocks with cf.cond_br/cf.br so that we can +// avoid `scf.if` result materialization quirks (and avoid relying on cf.switch, +// which is not supported by EmitC C++ translation). +struct SCFIndexSwitchToCF : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + static LogicalResult cloneYieldingBlockAndBranchTo( + PatternRewriter &rewriter, Location loc, Block &srcBlock, Block *destBlock, + Block *continueBlock) { + rewriter.setInsertionPointToEnd(destBlock); + + IRMapping mapping; + for (Operation &inner : srcBlock.without_terminator()) + rewriter.clone(inner, mapping); + + auto yield = dyn_cast(srcBlock.getTerminator()); + if (!yield) + return failure(); + + SmallVector yieldOperands; + yieldOperands.reserve(yield.getNumOperands()); + for (Value v : yield.getOperands()) + yieldOperands.push_back(mapping.lookupOrDefault(v)); + + rewriter.create(loc, continueBlock, yieldOperands); + return success(); + } + + static Block *splitBlockForContinuation(PatternRewriter &rewriter, + scf::IndexSwitchOp op) { + auto switchIt = Block::iterator(op.getOperation()); + return rewriter.splitBlock(op->getBlock(), std::next(switchIt)); + } + + static void addContinuationArguments(PatternRewriter &rewriter, + scf::IndexSwitchOp op, Location loc, + Block *continueBlock) { + SmallVector contArgs; + contArgs.reserve(op.getNumResults()); + for (Type type : op.getResultTypes()) + contArgs.push_back(continueBlock->addArgument(type, loc)); + for (auto result : llvm::enumerate(op.getResults())) + result.value().replaceAllUsesWith(contArgs[result.index()]); + } + + static void createIndexSwitchBlocks(PatternRewriter &rewriter, + Region *parentRegion, + Region::iterator insertPt, + unsigned numCases, + SmallVectorImpl &checkBlocks, + Block *&defaultBlock, + SmallVectorImpl &caseBlocks) { + checkBlocks.reserve(numCases); + caseBlocks.reserve(numCases); + for (unsigned i = 0; i < numCases; ++i) + checkBlocks.push_back(rewriter.createBlock(parentRegion, insertPt)); + defaultBlock = rewriter.createBlock(parentRegion, insertPt); + for (unsigned i = 0; i < numCases; ++i) + caseBlocks.push_back(rewriter.createBlock(parentRegion, insertPt)); + } + + static void populateIndexSwitchCheckBlocks( + PatternRewriter &rewriter, Location loc, Value selector, + ArrayRef cases, ArrayRef checkBlocks, + ArrayRef caseBlocks, Block *defaultBlock) { + for (unsigned i = 0; i < checkBlocks.size(); ++i) { + rewriter.setInsertionPointToEnd(checkBlocks[i]); + Value caseVal = rewriter.create(loc, cases[i]); + Value cond = rewriter.create( + loc, arith::CmpIPredicate::eq, selector, caseVal); + Block *falseDest = + (i + 1 < checkBlocks.size()) ? checkBlocks[i + 1] : defaultBlock; + rewriter.create(loc, cond, caseBlocks[i], ValueRange{}, + falseDest, ValueRange{}); + } + } + + LogicalResult matchAndRewrite(scf::IndexSwitchOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Operation *parentOp = op->getParentOp(); + if (parentOp && parentOp->hasTrait()) { + return rewriter.notifyMatchFailure( + op, "cannot lower scf.index_switch inside a single-block parent region"); + } + + Block *curBlock = op->getBlock(); + Region *parentRegion = curBlock->getParent(); + Block *continueBlock = splitBlockForContinuation(rewriter, op); + addContinuationArguments(rewriter, op, loc, continueBlock); + + unsigned numCases = op.getCases().size(); + auto insertPt = continueBlock->getIterator(); + + SmallVector checkBlocks; + SmallVector caseBlocks; + Block *defaultBlock = nullptr; + createIndexSwitchBlocks(rewriter, parentRegion, insertPt, numCases, + checkBlocks, defaultBlock, caseBlocks); + + Value selector = op.getArg(); + auto cases = op.getCases(); + populateIndexSwitchCheckBlocks(rewriter, loc, selector, cases, checkBlocks, + caseBlocks, defaultBlock); + + // Fill case blocks and default block with cloned bodies + branch to cont. + for (unsigned i = 0; i < numCases; ++i) { + if (failed(cloneYieldingBlockAndBranchTo( + rewriter, loc, op.getCaseBlock(i), caseBlocks[i], continueBlock))) + return rewriter.notifyMatchFailure(op, "expected scf.yield terminator"); + } + if (failed(cloneYieldingBlockAndBranchTo(rewriter, loc, op.getDefaultBlock(), + defaultBlock, continueBlock))) + return rewriter.notifyMatchFailure(op, "expected scf.yield terminator"); + + // Replace the original switch op with a branch into the check chain. + Block *entryDest = numCases ? checkBlocks[0] : defaultBlock; + rewriter.setInsertionPointAfter(op); + rewriter.create(loc, entryDest, ValueRange{}); + rewriter.eraseOp(op); + return success(); + } +}; + +// Lower scf.while into CFG blocks with cf.br/cf.cond_br. +// +// Note: This requires the parent region to allow multiple blocks. In +// particular, scf.if/scf.for regions are single-block and cannot contain this +// lowering. +struct SCFWhileToCF : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + static LogicalResult validateWhileResultUses(scf::WhileOp op) { + Block *parentBlock = op->getBlock(); + for (Value result : op.getResults()) { + for (OpOperand &use : result.getUses()) { + if (use.getOwner()->getBlock() != parentBlock) + return failure(); + } + } + return success(); + } + + static Block *splitAfterWhileBlock(PatternRewriter &rewriter, + scf::WhileOp op) { + auto whileIt = Block::iterator(op.getOperation()); + return rewriter.splitBlock(op->getBlock(), std::next(whileIt)); + } + + static void addWhileExitArguments(PatternRewriter &rewriter, scf::WhileOp op, + Location loc, Block *afterWhileBlock) { + SmallVector exitArgs; + exitArgs.reserve(op.getNumResults()); + for (Type type : op.getResultTypes()) + exitArgs.push_back(afterWhileBlock->addArgument(type, loc)); + for (auto result : llvm::enumerate(op.getResults())) + result.value().replaceAllUsesWith(exitArgs[result.index()]); + } + + static Block *createWhileHeaderBlock(PatternRewriter &rewriter, + scf::WhileOp op, Location loc, + Block *afterWhileBlock) { + SmallVector headerArgTypes; + for (Value init : op.getInits()) + headerArgTypes.push_back(init.getType()); + SmallVector headerArgLocs(headerArgTypes.size(), loc); + return rewriter.createBlock(afterWhileBlock->getParent(), + afterWhileBlock->getIterator(), headerArgTypes, + headerArgLocs); + } + + static Block *createWhileBodyBlock(PatternRewriter &rewriter, scf::WhileOp op, + Location loc, Block *afterWhileBlock) { + Block &afterRegionBlock = op.getAfter().front(); + SmallVector bodyArgTypes(afterRegionBlock.getArgumentTypes().begin(), + afterRegionBlock.getArgumentTypes().end()); + SmallVector bodyArgLocs(bodyArgTypes.size(), loc); + return rewriter.createBlock(afterWhileBlock->getParent(), + afterWhileBlock->getIterator(), bodyArgTypes, + bodyArgLocs); + } + + static void rewriteWhileTerminators(PatternRewriter &rewriter, Location loc, + Block *headerBlock, Block *bodyBlock, + Block *afterWhileBlock) { + auto condOp = cast(headerBlock->getTerminator()); + rewriter.setInsertionPoint(condOp); + rewriter.create(loc, condOp.getCondition(), + /*trueDest=*/bodyBlock, + /*trueOperands=*/condOp.getArgs(), + /*falseDest=*/afterWhileBlock, + /*falseOperands=*/condOp.getArgs()); + rewriter.eraseOp(condOp); + + auto yieldOp = cast(bodyBlock->getTerminator()); + rewriter.setInsertionPoint(yieldOp); + rewriter.create(loc, headerBlock, yieldOp.getOperands()); + rewriter.eraseOp(yieldOp); + } + + LogicalResult matchAndRewrite(scf::WhileOp op, + PatternRewriter &rewriter) const override { + Operation *parentOp = op->getParentOp(); + if (parentOp && parentOp->hasTrait()) { + return rewriter.notifyMatchFailure( + op, "cannot lower scf.while inside a single-block parent region"); + } + + if (failed(validateWhileResultUses(op))) + return rewriter.notifyMatchFailure( + op, "unsupported: while results used outside the parent block"); + + auto loc = op.getLoc(); + Block *afterWhileBlock = splitAfterWhileBlock(rewriter, op); + addWhileExitArguments(rewriter, op, loc, afterWhileBlock); + Block *headerBlock = createWhileHeaderBlock(rewriter, op, loc, + afterWhileBlock); + Block *bodyBlock = createWhileBodyBlock(rewriter, op, loc, afterWhileBlock); + + // Move the before/after region bodies into the new CFG blocks. + Block &afterRegionBlock = op.getAfter().front(); + rewriter.mergeBlocks(&op.getBefore().front(), headerBlock, + headerBlock->getArguments()); + rewriter.mergeBlocks(&afterRegionBlock, bodyBlock, bodyBlock->getArguments()); + rewriteWhileTerminators(rewriter, loc, headerBlock, bodyBlock, + afterWhileBlock); + + // Replace scf.while itself with a branch to the header. + rewriter.setInsertionPoint(op); + rewriter.create(loc, headerBlock, op.getInits()); + rewriter.eraseOp(op); + return success(); + } +}; + +// Lower cf.switch into chained comparisons and cf.cond_br/cf.br. +// +// EmitC C++ translation currently supports cf.br/cf.cond_br, but not cf.switch. +struct CFSwitchToCondBr : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + static SmallVector> + collectSwitchCaseOperands(cf::SwitchOp op) { + SmallVector> caseOperands; + caseOperands.reserve(op.getCaseDestinations().size()); + for (auto range : op.getCaseOperands()) + caseOperands.emplace_back(range.begin(), range.end()); + return caseOperands; + } + + static SmallVector getSwitchCaseValues(cf::SwitchOp op) { + SmallVector caseValues; + if (auto caseValuesAttr = op.getCaseValues()) { + for (APInt value : caseValuesAttr->getValues()) + caseValues.push_back(value); + } + return caseValues; + } + + static SmallVector createSwitchCheckBlocks(PatternRewriter &rewriter, + Region *parentRegion, + Block *curBlock, + size_t numCases) { + auto insertPt = std::next(curBlock->getIterator()); + SmallVector checkBlocks; + checkBlocks.reserve(numCases); + for (size_t i = 0; i < numCases; ++i) + checkBlocks.push_back(rewriter.createBlock(parentRegion, insertPt)); + return checkBlocks; + } + + static LogicalResult populateSwitchCheckBlocks( + PatternRewriter &rewriter, Location loc, Value flag, IntegerType flagTy, + ArrayRef caseValues, ArrayRef caseDests, + ArrayRef> caseOperands, Block *defaultDest, + ValueRange defaultOperands, ArrayRef checkBlocks, + cf::SwitchOp op) { + for (size_t i = 0; i < caseDests.size(); ++i) { + rewriter.setInsertionPointToEnd(checkBlocks[i]); + APInt caseVal = caseValues[i]; + if (caseVal.getBitWidth() != flagTy.getWidth()) { + return rewriter.notifyMatchFailure( + op, "case value bitwidth doesn't match flag type"); + } + + Value caseConst = rewriter.create( + loc, flagTy, rewriter.getIntegerAttr(flagTy, caseVal)); + Value cond = rewriter.create( + loc, arith::CmpIPredicate::eq, flag, caseConst); + Block *falseDest = + (i + 1 < checkBlocks.size()) ? checkBlocks[i + 1] : defaultDest; + ValueRange falseOperands = + (i + 1 < checkBlocks.size()) ? ValueRange{} : defaultOperands; + rewriter.create(loc, cond, caseDests[i], + caseOperands[i], falseDest, + falseOperands); + } + return success(); + } + + LogicalResult matchAndRewrite(cf::SwitchOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Operation *parentOp = op->getParentOp(); + if (parentOp && parentOp->hasTrait()) { + return rewriter.notifyMatchFailure( + op, "cannot lower cf.switch inside a single-block parent region"); + } + + Block *curBlock = op->getBlock(); + Region *parentRegion = curBlock->getParent(); + + Value flag = op.getFlag(); + auto flagTy = dyn_cast(flag.getType()); + if (!flagTy) + return rewriter.notifyMatchFailure(op, "expected integer switch flag"); + + SmallVector defaultOperands(op.getDefaultOperands().begin(), + op.getDefaultOperands().end()); + Block *defaultDest = op.getDefaultDestination(); + + SmallVector caseDests(op.getCaseDestinations().begin(), + op.getCaseDestinations().end()); + SmallVector> caseOperands = collectSwitchCaseOperands(op); + + if (caseDests.empty()) { + rewriter.replaceOpWithNewOp(op, defaultDest, defaultOperands); + return success(); + } + + if (!op.getCaseValues()) + return rewriter.notifyMatchFailure(op, "missing case_values"); + SmallVector caseValues = getSwitchCaseValues(op); + + if (caseValues.size() != caseDests.size()) + return rewriter.notifyMatchFailure(op, "case_values/destinations mismatch"); + if (caseOperands.size() != caseDests.size()) + return rewriter.notifyMatchFailure(op, "case_operands/destinations mismatch"); + + SmallVector checkBlocks = + createSwitchCheckBlocks(rewriter, parentRegion, curBlock, + caseDests.size()); + if (failed(populateSwitchCheckBlocks(rewriter, loc, flag, flagTy, + caseValues, caseDests, caseOperands, + defaultDest, defaultOperands, + checkBlocks, op))) { + return failure(); + } + + // Replace the switch terminator with a branch into the first check block. + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp(op, checkBlocks.front(), + ValueRange{}); + return success(); + } +}; + +} // namespace + + +} // namespace + +LogicalResult runPTOToEmitCSCFPreLowering(ModuleOp mop, MLIRContext *ctx) { + bool needsAnySCFToCF = false; + for (auto func : mop.getOps()) { + if (needsWholeFunctionSCFToCF(func)) { + needsAnySCFToCF = true; + break; + } + } + if (needsAnySCFToCF) { + RewritePatternSet scfToCfPatterns(ctx); + populateSCFToControlFlowConversionPatterns(scfToCfPatterns); + FrozenRewritePatternSet frozenSCFToCF(std::move(scfToCfPatterns)); + + ConversionTarget scfToCfTarget(*ctx); + scfToCfTarget.addIllegalOp(); + scfToCfTarget.markUnknownOpDynamicallyLegal( + [](Operation *) { return true; }); + + for (auto func : mop.getOps()) { + if (!needsWholeFunctionSCFToCF(func)) + continue; + if (failed(applyPartialConversion(func, scfToCfTarget, + frozenSCFToCF))) { + func.emitError() + << "failed to lower nested SCF to ControlFlow (SCFToCF)"; + return failure(); + } + } + } + + RewritePatternSet scfLoweringPatterns(ctx); + scfLoweringPatterns.add(ctx); + (void)applyPatternsAndFoldGreedily(mop, std::move(scfLoweringPatterns)); + + bool hasUnsupportedSCF = false; + mop.walk([&](Operation *op) { + if (isa(op)) { + hasUnsupportedSCF = true; + op->emitError() << "Unsupported SCF op remained after pre-lowering"; + return WalkResult::interrupt(); + } + if (isa(op)) { + hasUnsupportedSCF = true; + op->emitError() + << "Unsupported CF op remained after pre-lowering: cf.switch"; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return failure(hasUnsupportedSCF); +} + +void populatePTOToEmitCControlFlowPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter, + MLIRContext *ctx) { + patterns.add>(typeConverter, ctx); + patterns.add>(typeConverter, ctx); + patterns.add(typeConverter, ctx); + populateSCFToEmitCConversionPatterns(patterns); + populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter); +} + +} // namespace mlir::pto diff --git a/lib/PTO/Transforms/PTOToEmitCInternal.h b/lib/PTO/Transforms/PTOToEmitCInternal.h new file mode 100644 index 000000000..e6c039c91 --- /dev/null +++ b/lib/PTO/Transforms/PTOToEmitCInternal.h @@ -0,0 +1,151 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef MLIR_DIALECT_PTO_TRANSFORMS_PTOTOEMITCINTERNAL_H +#define MLIR_DIALECT_PTO_TRANSFORMS_PTOTOEMITCINTERNAL_H + +#pragma GCC diagnostic ignored "-Woverloaded-virtual" +// GCC warns on MLIR OpConversionPattern helper overloads hiding RewritePattern::rewrite. + +#include "PTO/IR/PTO.h" + +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" + +#include +#include + +namespace mlir::pto { + +Value peelUnrealized(Value v); + +Value makeEmitCOpaqueConstant(ConversionPatternRewriter &rewriter, + Location loc, Type type, + llvm::StringRef literal); + +Value makeEmitCIntConstant(ConversionPatternRewriter &rewriter, Location loc, + Type type, int64_t value); + +Value emitCCast(ConversionPatternRewriter &rewriter, Location loc, Type dstType, + Value src); + +Value castSignlessIntToUnsignedSameWidth(ConversionPatternRewriter &rewriter, + Location loc, Value v, + unsigned bitWidth); + +std::string getEmitCScalarTypeToken(Type elemTy); + +pto::BLayout getTileBufBLayoutValue(pto::TileBufConfigAttr configAttr); + +int64_t renderTileTemplateDim(int64_t rawDim, Type elemTy, + pto::BLayout blayout, int dimIdx); + +std::optional getEmitCTileTypeString(pto::TileBufType type); + +bool isSetFFTsPointerLikeType(Type ty); + +bool isEmitCGlobalTensorLikeType(Type ty); + +std::string getGlobalTensorTypeStringFromShapeAndStrides( + Type elemTy, ArrayRef shape, ArrayRef strides, + llvm::StringRef layoutEnum = "pto::Layout::ND"); + +std::string getElemTypeStringForGT(Type elemTy); + +SmallVector buildRowMajorStrides(ArrayRef shape); + +void buildGlobalTensorShapeAndStride(ArrayRef shape, + ArrayRef strides, + SmallVectorImpl &shape5D, + SmallVectorImpl &stride5D); + +std::string joinIntTemplateParams(ArrayRef values); + +Value buildGlobalTensorFromMemref(ConversionPatternRewriter &rewriter, + Location loc, Value basePtr, + MemRefType mrTy, Operation *anchor); + +FailureOr buildTPipeTokenFromInitOp(Operation *op, + PTOArch targetArch); + +Value castToGMBytePointer(ConversionPatternRewriter &rewriter, Location loc, + Value value); + +Value materializeTensorViewDataPointer(ConversionPatternRewriter &rewriter, + Location loc, Value value, + Type originalType); + +Value materializeAddressAsPointer(ConversionPatternRewriter &rewriter, + Location loc, Value addr, + pto::AddressSpace as, + llvm::StringRef elemTok); + +Value applyStaticMemrefOffset(ConversionPatternRewriter &rewriter, + Location loc, Value basePtr, int64_t offset); + +FailureOr buildAsyncScratchTileValue(ConversionPatternRewriter &rewriter, + Location loc, Value originalScratch, + Value emittedScratch); + +bool needsA5NoSplitVectorGuard(Operation *op); + +Value materializeTileDataValue(ConversionPatternRewriter &rewriter, + Location loc, Value tile, + pto::AddressSpace as, + llvm::StringRef elemTypeToken); + +void populatePTOToEmitCArithPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter, + MLIRContext *ctx); + +void populatePTOToEmitCTilePatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter, + MLIRContext *ctx); + +void populatePTOToEmitCTileExtraPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter, + MLIRContext *ctx); + +void populatePTOToEmitCTileMaterializationPatterns( + RewritePatternSet &patterns, TypeConverter &typeConverter, MLIRContext *ctx); + +void populatePTOToEmitCSyncPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter, + MLIRContext *ctx, PTOArch targetArch); + +void populatePTOToEmitCCommPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter, + MLIRContext *ctx, PTOArch targetArch); + +void populatePTOToEmitCKernelOpPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter, + MLIRContext *ctx); + +LogicalResult runPTOToEmitCSCFPreLowering(ModuleOp mop, MLIRContext *ctx); + +void populatePTOToEmitCControlFlowPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter, + MLIRContext *ctx); + +void populatePTOToEmitCSimpleOpPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter, + MLIRContext *ctx); + +void populatePTOToEmitCRuntimeOpPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter, + MLIRContext *ctx, PTOArch targetArch); + +void populatePTOToEmitCMemoryOpPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter, + MLIRContext *ctx); + +} // namespace mlir::pto + +#endif // MLIR_DIALECT_PTO_TRANSFORMS_PTOTOEMITCINTERNAL_H diff --git a/lib/PTO/Transforms/PTOToEmitCKernelOps.cpp b/lib/PTO/Transforms/PTOToEmitCKernelOps.cpp new file mode 100644 index 000000000..e0a80102d --- /dev/null +++ b/lib/PTO/Transforms/PTOToEmitCKernelOps.cpp @@ -0,0 +1,516 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- PTOToEmitCKernelOps.cpp --------------------------------------------===// +//===----------------------------------------------------------------------===// + +#include "PTOToEmitCInternal.h" + +#include "PTO/IR/PTO.h" + +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" + +#include + +using namespace mlir; +using namespace mlir::pto; + +namespace mlir::pto { +namespace { + +struct PTOTLoadToTLOAD : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TLoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!op.getDst()) + return rewriter.notifyMatchFailure(op, "expected outs(dst) on pto.tload"); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + Value srcArg = src; + if (auto srcMrTy = dyn_cast(op.getSrc().getType())) { + bool isGlobal = true; + if (auto asAttr = dyn_cast_or_null(srcMrTy.getMemorySpace())) { + auto as = asAttr.getAddressSpace(); + isGlobal = (as == pto::AddressSpace::GM || as == pto::AddressSpace::Zero); + } + if (isGlobal) { + if (Value gt = buildGlobalTensorFromMemref(rewriter, op.getLoc(), src, srcMrTy, + op.getOperation())) + srcArg = gt; + } + } + + rewriter.create( + op.getLoc(), TypeRange{}, "TLOAD", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, srcArg}); + + if (op->getNumResults() == 1) { + rewriter.replaceOp(op, dst); + } else { + rewriter.eraseOp(op); + } + return success(); + } +}; + +struct PTOTPrefetchToTPREFETCH : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TPrefetchOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!op.getDst()) + return rewriter.notifyMatchFailure(op, "expected outs(dst) on pto.tprefetch"); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + Value srcArg = src; + if (auto srcMrTy = dyn_cast(op.getSrc().getType())) { + bool isGlobal = true; + if (auto asAttr = dyn_cast_or_null(srcMrTy.getMemorySpace())) { + auto as = asAttr.getAddressSpace(); + isGlobal = (as == pto::AddressSpace::GM || as == pto::AddressSpace::Zero); + } + if (isGlobal) { + if (Value gt = buildGlobalTensorFromMemref(rewriter, op.getLoc(), src, srcMrTy, + op.getOperation())) + srcArg = gt; + } + } + + rewriter.create( + op.getLoc(), TypeRange{}, "TPREFETCH", + ArrayAttr{}, ArrayAttr{}, ValueRange{dst, srcArg}); + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOTPrefetchAsyncToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TPrefetchAsyncOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src = peelUnrealized(adaptor.getSrc()); + Value srcArg = src; + if (!isEmitCGlobalTensorLikeType(srcArg.getType())) { + auto srcMrTy = dyn_cast(op.getSrc().getType()); + if (!srcMrTy) + return rewriter.notifyMatchFailure( + op, "expected src to lower to GlobalTensor or memref"); + srcArg = buildGlobalTensorFromMemref(rewriter, op.getLoc(), src, srcMrTy, + op.getSrc().getDefiningOp() + ? op.getSrc().getDefiningOp() + : op.getOperation()); + } + if (!srcArg) + return rewriter.notifyMatchFailure(op, + "failed to build GlobalTensor src"); + + Value prefetchCtx = peelUnrealized(adaptor.getCtx()); + + Type eventTy = getTypeConverter()->convertType(op.getEvent().getType()); + if (!eventTy) + return rewriter.notifyMatchFailure( + op, "failed to convert tprefetch_async result type"); + + Value event = rewriter + .create( + op.getLoc(), TypeRange{eventTy}, "TPREFETCH_ASYNC", + ArrayAttr{}, ArrayAttr{}, + ValueRange{srcArg, prefetchCtx}) + .getResult(0); + + rewriter.replaceOp(op, ValueRange{event}); + return success(); + } +}; + +struct PTOMakePrefetchAsyncContextToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::MakePrefetchAsyncContextOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type ctxTy = getTypeConverter()->convertType(op.getCtx().getType()); + if (!ctxTy) + return rewriter.notifyMatchFailure( + op, "failed to convert make_prefetch_async_context result type"); + + Value workspace = peelUnrealized(adaptor.getWorkspace()); + workspace = castToGMBytePointer(rewriter, op.getLoc(), workspace); + + Value ctx = rewriter + .create( + op.getLoc(), TypeRange{ctxTy}, "pto::PrefetchAsyncContext", + ArrayAttr{}, ArrayAttr{}, ValueRange{workspace}) + .getResult(0); + + rewriter.replaceOp(op, ValueRange{ctx}); + return success(); + } +}; + +struct PTOGetPrefetchAsyncSessionToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::GetPrefetchAsyncSessionOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type sessionTy = getTypeConverter()->convertType(op.getSession().getType()); + if (!sessionTy) + return rewriter.notifyMatchFailure( + op, "failed to convert get_prefetch_async_session result type"); + + Value ctx = peelUnrealized(adaptor.getCtx()); + Value session = rewriter + .create( + op.getLoc(), TypeRange{sessionTy}, + "PTOAS__PREFETCH_CTX_SESSION", ArrayAttr{}, + ArrayAttr{}, ValueRange{ctx}) + .getResult(0); + + rewriter.replaceOp(op, ValueRange{session}); + return success(); + } +}; + +struct PTOTStoreToTSTORE : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + static std::string stPhaseTok(pto::STPhase phase) { + switch (phase) { + case pto::STPhase::Unspecified: return "STPhase::Unspecified"; + case pto::STPhase::Partial: return "STPhase::Partial"; + case pto::STPhase::Final: return "STPhase::Final"; + } + return "STPhase::Unspecified"; + } + + static std::string atomicTypeTok(pto::AtomicType atomicType) { + switch (atomicType) { + case pto::AtomicType::AtomicNone: return "AtomicType::AtomicNone"; + case pto::AtomicType::AtomicAdd: return "AtomicType::AtomicAdd"; + } + return "AtomicType::AtomicNone"; + } + + static std::string reluPreModeTok(pto::ReluPreMode reluPreMode) { + switch (reluPreMode) { + case pto::ReluPreMode::NoRelu: return "ReluPreMode::NoRelu"; + case pto::ReluPreMode::NormalRelu: return "ReluPreMode::NormalRelu"; + } + return "ReluPreMode::NoRelu"; + } + + LogicalResult matchAndRewrite(pto::TStoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!op.getDst()) + return rewriter.notifyMatchFailure(op, "expected outs(dst) on pto.tstore"); + + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + Value preQuantScalar; + if (op.getPreQuantScalar()) + preQuantScalar = peelUnrealized(adaptor.getPreQuantScalar()); + Value dstArg = dst; + if (auto dstMrTy = dyn_cast(op.getDst().getType())) { + bool isGlobal = true; + if (auto asAttr = dyn_cast_or_null(dstMrTy.getMemorySpace())) { + auto as = asAttr.getAddressSpace(); + isGlobal = (as == pto::AddressSpace::GM || as == pto::AddressSpace::Zero); + } + if (isGlobal) { + if (Value gt = buildGlobalTensorFromMemref(rewriter, op.getLoc(), dst, dstMrTy, + op.getOperation())) + dstArg = gt; + } + } + + const auto phase = op.getStPhase(); + const auto atomicType = op.getAtomicType(); + const auto reluPreMode = op.getReluPreMode(); + const bool hasPreQuantScalar = static_cast(preQuantScalar); + const bool phaseNonDefault = phase != pto::STPhase::Unspecified; + const bool atomicNonDefault = atomicType != pto::AtomicType::AtomicNone; + const bool reluNonDefault = reluPreMode != pto::ReluPreMode::NoRelu; + + auto getOpaqueTok = [&](Value v, StringRef name) -> FailureOr { + if (auto ot = mlir::dyn_cast(v.getType())) + return ot.getValue().str(); + return rewriter.notifyMatchFailure(op, (name + " must be emitc::OpaqueType").str()); + }; + + ArrayAttr targs; + // Map op attributes/operands to the exact TSTORE overload family: + // 1) TSTORE(dst, src) + // 2) TSTORE(dst, src) + // 3) TSTORE(dst, src) + // 4) TSTORE(dst, src) + // 5) TSTORE(dst, src) + // 6) TSTORE(dst, src) + // 7) TSTORE(dst, src, preQuant) + // 8) TSTORE(dst, src, preQuant) + if (!hasPreQuantScalar && !reluNonDefault && !atomicNonDefault) { + if (phaseNonDefault) { + targs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, stPhaseTok(phase)), + }); + } else { + targs = ArrayAttr{}; + } + } else { + auto srcTokOr = getOpaqueTok(src, "src"); + auto dstTokOr = getOpaqueTok(dstArg, "dst"); + if (failed(srcTokOr) || failed(dstTokOr)) + return failure(); + + // If there is no preQuant and relu stays default, emit the atomic-only + // overloads (#3/#4) without ReluPreMode template argument. + if (!hasPreQuantScalar && !reluNonDefault) { + if (phaseNonDefault) { + targs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, stPhaseTok(phase)), + emitc::OpaqueAttr::get(ctx, *srcTokOr), + emitc::OpaqueAttr::get(ctx, *dstTokOr), + emitc::OpaqueAttr::get(ctx, atomicTypeTok(atomicType)), + }); + } else { + targs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, *srcTokOr), + emitc::OpaqueAttr::get(ctx, *dstTokOr), + emitc::OpaqueAttr::get(ctx, atomicTypeTok(atomicType)), + }); + } + } else { + // Relu/preQuant families (#5/#6/#7/#8): keep AtomicType + ReluPreMode. + if (phaseNonDefault) { + targs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, stPhaseTok(phase)), + emitc::OpaqueAttr::get(ctx, *srcTokOr), + emitc::OpaqueAttr::get(ctx, *dstTokOr), + emitc::OpaqueAttr::get(ctx, atomicTypeTok(atomicType)), + emitc::OpaqueAttr::get(ctx, reluPreModeTok(reluPreMode)), + }); + } else { + targs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, *srcTokOr), + emitc::OpaqueAttr::get(ctx, *dstTokOr), + emitc::OpaqueAttr::get(ctx, atomicTypeTok(atomicType)), + emitc::OpaqueAttr::get(ctx, reluPreModeTok(reluPreMode)), + }); + } + } + } + + SmallVector operands{dstArg, src}; + if (hasPreQuantScalar) + operands.push_back(preQuantScalar); + + rewriter.create( + loc, TypeRange{}, "TSTORE", + /*args=*/ArrayAttr{}, /*templateArgs=*/targs, + /*operands=*/operands); + + if (op->getNumResults() == 1) { + rewriter.replaceOp(op, dst); + } else { + rewriter.eraseOp(op); + } + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// pto.matmul_dps lowering (Simplified: No internal copy/sync) +//===----------------------------------------------------------------------===// +// +// Render `pto.tmatmul` as one of three forms depending on the optional +// `acc_phase` attribute: +// * absent / Unspecified -> `TMATMUL(dst, lhs, rhs)` +// * Partial -> `TMATMUL(dst, lhs, rhs)` +// * Final -> `TMATMUL(dst, lhs, rhs)` +// The Unspecified default keeps backward compatibility with all upstream IR +// that does not yet emit an explicit phase attribute. +static ArrayAttr buildAccPhaseTemplateArgs(ConversionPatternRewriter &rewriter, + pto::AccPhase phase) { + StringRef tmpl; + switch (phase) { + case pto::AccPhase::Unspecified: + return ArrayAttr{}; + case pto::AccPhase::Partial: + tmpl = "AccPhase::Partial"; + break; + case pto::AccPhase::Final: + tmpl = "AccPhase::Final"; + break; + } + if (tmpl.empty()) + return ArrayAttr{}; + return rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(rewriter.getContext(), tmpl)}); +} + +struct PTOTMatmulToTMATMUL : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMatmulOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // 1. 获取操作数 (剥离 Cast) + Value lhs = peelUnrealized(adaptor.getLhs()); // A (Left) + Value rhs = peelUnrealized(adaptor.getRhs()); // B (Right) + Value dst = peelUnrealized(adaptor.getDst()); // C (Acc) + + // 2. 根据 acc_phase 属性决定是否生成 TMATMUL(...) + ArrayAttr templateArgs = + buildAccPhaseTemplateArgs(rewriter, op.getAccPhase()); + + rewriter.create( + op.getLoc(), TypeRange{}, "TMATMUL", + /*args=*/ArrayAttr{}, /*template_args=*/templateArgs, + ValueRange{dst, lhs, rhs}); + + // 3. 处理 Op 替换/删除 + if (op->getNumResults() == 1) { + rewriter.replaceOp(op, dst); + } else { + rewriter.eraseOp(op); + } + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// pto.tgemv lowering +//===----------------------------------------------------------------------===// +struct PTOTGemvToTGEMV : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TGemvOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // 1. 获取操作数 (剥离 Cast) + Value lhs = peelUnrealized(adaptor.getLhs()); // A (Matrix) + Value rhs = peelUnrealized(adaptor.getRhs()); // B (Vector) + Value dst = peelUnrealized(adaptor.getDst()); // C (Result) + + // 2. 直接生成函数调用 TGEMV(dst, lhs, rhs) + rewriter.create( + op.getLoc(), TypeRange{}, "TGEMV", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, lhs, rhs}); + + // 3. 处理 Op 替换/删除 + if (op->getNumResults() == 1) { + rewriter.replaceOp(op, dst); + } else { + rewriter.eraseOp(op); + } + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// pto.tgemv.acc lowering +//===----------------------------------------------------------------------===// +struct PTOTGemvAccToTGEMVACC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TGemvAccOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!op.getDst()) + return rewriter.notifyMatchFailure(op, "expected outs(dst) for pto.tgemv.acc"); + + // 1. 获取操作数 + Value accIn = peelUnrealized(adaptor.getAccIn()); // AccOld + Value lhs = peelUnrealized(adaptor.getLhs()); // A (Matrix) + Value rhs = peelUnrealized(adaptor.getRhs()); // B (Vector) + Value dst = peelUnrealized(adaptor.getDst()); // AccNew + + // 2. 直接生成函数调用 TGEMV_ACC(dst, accIn, lhs, rhs) + rewriter.create( + op.getLoc(), TypeRange{}, "TGEMV_ACC", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, accIn, lhs, rhs}); + + // 3. 处理 Op 替换/删除 + if (op->getNumResults() == 1) { + rewriter.replaceOp(op, dst); + } else { + rewriter.eraseOp(op); + } + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// pto.matmul_acc_dps lowering (Simplified: No internal copy/sync) +//===----------------------------------------------------------------------===// +struct PTOTMatmulAccToTMATMULACC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMatmulAccOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!op.getDst()) + return rewriter.notifyMatchFailure(op, "expected outs(dst) for pto.tmatmul.acc"); + + // 1. 获取操作数 + Value accIn = peelUnrealized(adaptor.getAccIn()); // AccOld + Value lhs = peelUnrealized(adaptor.getLhs()); // A (Left) + Value rhs = peelUnrealized(adaptor.getRhs()); // B (Right) + Value dst = peelUnrealized(adaptor.getDst()); // AccNew + + // 2. 根据 acc_phase 属性决定是否生成 TMATMUL_ACC(...) + ArrayAttr templateArgs = + buildAccPhaseTemplateArgs(rewriter, op.getAccPhase()); + + rewriter.create( + op.getLoc(), TypeRange{}, "TMATMUL_ACC", + /*args=*/ArrayAttr{}, /*template_args=*/templateArgs, + ValueRange{dst, accIn, lhs, rhs}); + + // 3. 处理 Op 替换/删除 + if (op->getNumResults() == 1) { + rewriter.replaceOp(op, dst); + } else { + rewriter.eraseOp(op); + } + return success(); + } +}; + + +} // namespace + +void populatePTOToEmitCKernelOpPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter, + MLIRContext *ctx) { + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); +} + +} // namespace mlir::pto diff --git a/lib/PTO/Transforms/PTOToEmitCMemoryOps.cpp b/lib/PTO/Transforms/PTOToEmitCMemoryOps.cpp new file mode 100644 index 000000000..dba225b3c --- /dev/null +++ b/lib/PTO/Transforms/PTOToEmitCMemoryOps.cpp @@ -0,0 +1,597 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- PTOToEmitCMemoryOps.cpp --------------------------------------------===// +//===----------------------------------------------------------------------===// + +#include "PTOToEmitCInternal.h" + +#include "PTO/IR/PTO.h" +#include "PTO/IR/PTOTypeUtils.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" + +#include +#include + +using namespace mlir; +using namespace mlir::pto; + +namespace mlir::pto { +namespace { + +static constexpr llvm::StringLiteral kForceDynamicValidShapeAttrName = + "__pto.force_dynamic_valid_shape"; + +struct PointerCastConversion : public OpConversionPattern { + static bool getIndexConst(Value v, int64_t &out) { + if (auto cst = v.getDefiningOp()) { + if (auto ia = dyn_cast(cst.getValue())) { + out = ia.getValue().getSExtValue(); + return true; + } + } + return false; + } + + using OpConversionPattern::OpConversionPattern; + + enum class TileRole { Vec, Mat, Left, Right, Acc, Bias, Scaling }; + + static void collectUserOpsThroughCasts(Value v, SmallVectorImpl &out) { + for (Operation *u : v.getUsers()) { + if (auto castOp = dyn_cast(u)) { + for (Value r : castOp.getResults()) + collectUserOpsThroughCasts(r, out); + continue; + } + out.push_back(u); + } + } + + static Value peelUnrealized(Value v) { + while (auto castOp = v.getDefiningOp()) { + v = castOp.getOperand(0); + } + return v; + } + + static TileRole inferRole(pto::PointerCastOp op) { + // 1. 优先检查 AddressSpace + if (auto memRefTy = dyn_cast(op.getType())) { + Attribute memorySpace = memRefTy.getMemorySpace(); + if (auto ptoAttr = dyn_cast_or_null(memorySpace)) { + switch (ptoAttr.getAddressSpace()) { + case pto::AddressSpace::LEFT: return TileRole::Left; + case pto::AddressSpace::RIGHT: return TileRole::Right; + case pto::AddressSpace::ACC: return TileRole::Acc; + case pto::AddressSpace::BIAS: return TileRole::Bias; + case pto::AddressSpace::MAT: return TileRole::Mat; + case pto::AddressSpace::SCALING: return TileRole::Scaling; + default: break; + } + } + } + + // 2. 通过 Usage 推导 (Fallback) + SmallVector users; + collectUserOpsThroughCasts(op.getResult(), users); + + for (Operation *user : users) { + if (auto mm = dyn_cast(user)) { + if (mm.getDst() && peelUnrealized(mm.getDst()) == op.getResult()) return TileRole::Acc; + if (peelUnrealized(mm.getLhs()) == op.getResult()) return TileRole::Left; + if (peelUnrealized(mm.getRhs()) == op.getResult()) return TileRole::Right; + } + if (auto mmacc = dyn_cast(user)) { + if (mmacc.getDst() && peelUnrealized(mmacc.getDst()) == op.getResult()) return TileRole::Acc; + if (peelUnrealized(mmacc.getAccIn()) == op.getResult()) return TileRole::Acc; + if (peelUnrealized(mmacc.getLhs()) == op.getResult()) return TileRole::Left; + if (peelUnrealized(mmacc.getRhs()) == op.getResult()) return TileRole::Right; + } + } + + return TileRole::Vec; + } + + // [新增] 辅助函数:判断 Value 是否源自 arith.constant + static bool isConstant(Value v, int64_t &outVal) { + if (!v) return false; + if (auto cst = v.getDefiningOp()) { + if (auto attr = dyn_cast(cst.getValue())) { + outVal = attr.getInt(); + return true; + } + } + return false; + } + + LogicalResult matchAndRewrite(pto::PointerCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + auto selfType = mlir::cast(op.getType()); + ArrayRef shape = selfType.getShape(); + Type elemType = selfType.getElementType(); + + // 1. 推导 Tile Role + TileRole role = inferRole(op); + + // 2. 类型字符串生成 (elemTypeStr, dimStr) + std::string elemTypeStr = getEmitCScalarTypeToken(elemType); + + std::string dimStr; + pto::BLayout blayout = pto::BLayout::RowMajor; + auto dimToString = [&](int64_t dim, const char *symbol, + int dimIdx) -> std::string { + if (dim == ShapedType::kDynamic) + return std::string(symbol); + return std::to_string(renderTileTemplateDim(dim, elemType, blayout, + dimIdx)); + }; + + // 3. Role Token + const char *roleTok = "TileType::Vec"; + switch (role) { + case TileRole::Left: roleTok = "TileType::Left"; break; + case TileRole::Right: roleTok = "TileType::Right"; break; + case TileRole::Acc: roleTok = "TileType::Acc"; break; + case TileRole::Bias: roleTok = "TileType::Bias"; break; + case TileRole::Mat: roleTok = "TileType::Mat"; break; + case TileRole::Vec: roleTok = "TileType::Vec"; break; + case TileRole::Scaling: roleTok = "TileType::Scaling"; break; + } + + // 4. Config & Layout (support BLayoutAttr/SLayoutAttr/PadValueAttr after namespace change) + std::string layoutParams = "BLayout::RowMajor"; + std::string extraParams = ""; + if (auto configOpt = op.getConfig()) { + auto config = *configOpt; + int32_t blVal = 0; + if (auto attr = dyn_cast(config.getBLayout())) + blVal = static_cast(attr.getValue()); + + if (blVal == 1) layoutParams = "BLayout::ColMajor"; + blayout = blVal == 1 ? pto::BLayout::ColMajor : pto::BLayout::RowMajor; + + int32_t slVal = 0; + if (auto attr = dyn_cast(config.getSLayout())) + slVal = static_cast(attr.getValue()); + + std::string slStr = (slVal == 1) ? "SLayout::RowMajor" : (slVal == 2) ? "SLayout::ColMajor" : "SLayout::NoneBox"; + + int32_t frVal = 0; + if (auto attr = dyn_cast(config.getSFractalSize())) frVal = attr.getInt(); + + int32_t padVal = 0; + if (auto attr = dyn_cast(config.getPad())) + padVal = static_cast(attr.getValue()); + + std::string padStr = "PadValue::Null"; + switch (padVal) { + case 1: padStr = "PadValue::Zero"; break; + case 2: padStr = "PadValue::Max"; break; + case 3: padStr = "PadValue::Min"; break; + } + + int32_t compactVal = 0; + if (auto attr = dyn_cast(config.getCompactMode())) + compactVal = static_cast(attr.getValue()); + + std::string compactStr = "CompactMode::Null"; + switch (compactVal) { + case 1: compactStr = "CompactMode::Normal"; break; + case 2: compactStr = "CompactMode::RowPlusOne"; break; + } + + if (!slStr.empty()) { + extraParams += ", " + slStr + ", " + std::to_string(frVal) + ", " + + padStr + ", " + compactStr; + } + } else { + extraParams = ", SLayout::NoneBox, 512, PadValue::Null, CompactMode::Null"; + } + + if (role == TileRole::Left) + dimStr = dimToString(shape[0], "M", 0) + ", " + + dimToString(shape[1], "K", 1); + else if (role == TileRole::Right) + dimStr = dimToString(shape[0], "K", 0) + ", " + + dimToString(shape[1], "N", 1); + else if (role == TileRole::Bias) + dimStr = "1, " + dimToString(shape[1], "N", 1); + else + dimStr = dimToString(shape[0], "M", 0) + ", " + + dimToString(shape[1], "N", 1); + + // [核心修改] Valid Dims 处理逻辑 (支持混合静态/动态) + std::string vrowTok, vcolTok; + bool useConstructor = false; + + bool rowIsDynamic = false; + bool colIsDynamic = false; + + SmallVector constructorArgs; + + Value vRow = op.getValidRow(); + Value vCol = op.getValidCol(); + Value vRowEmitC = adaptor.getValidRow(); + Value vColEmitC = adaptor.getValidCol(); + bool forceDynamicValid = op->hasAttr(kForceDynamicValidShapeAttrName); + + int64_t cRow = 0, cCol = 0; + bool rowIsConst = vRow && isConstant(vRow, cRow); + bool colIsConst = vCol && isConstant(vCol, cCol); + + auto makeCtorDimValue = [&](Value emitted, int64_t fallback) -> Value { + if (emitted) + return emitted; + return makeEmitCIntConstant( + rewriter, loc, emitc::OpaqueType::get(ctx, "int32_t"), fallback); + }; + auto maybeScaleDynamicValid = [&](Value emitted, int dimIdx) -> Value { + if (!emitted || !pto::isPTOFloat4PackedType(elemType)) + return emitted; + int packedDim = blayout == pto::BLayout::ColMajor ? 0 : 1; + if (dimIdx != packedDim) + return emitted; + auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); + Value two = makeEmitCIntConstant(rewriter, loc, i32Ty, 2); + return rewriter.create(loc, i32Ty, emitted, two).getResult(); + }; + + if (forceDynamicValid) { + vrowTok = "-1"; + vcolTok = "-1"; + useConstructor = true; + constructorArgs.push_back( + makeCtorDimValue(maybeScaleDynamicValid(vRowEmitC, 0), + renderTileTemplateDim(rowIsConst ? cRow : shape[0], + elemType, blayout, 0))); + constructorArgs.push_back( + makeCtorDimValue(maybeScaleDynamicValid(vColEmitC, 1), + renderTileTemplateDim(colIsConst ? cCol : shape[1], + elemType, blayout, 1))); + } else { + if (rowIsConst) { + vrowTok = std::to_string( + renderTileTemplateDim(cRow, elemType, blayout, 0)); + } else if (vRow) { + vrowTok = "-1"; + rowIsDynamic = true; + useConstructor = true; + } else { + vrowTok = std::to_string( + renderTileTemplateDim(shape[0], elemType, blayout, 0)); + } + + if (colIsConst) { + vcolTok = std::to_string( + renderTileTemplateDim(cCol, elemType, blayout, 1)); + } else if (vCol) { + vcolTok = "-1"; + colIsDynamic = true; + useConstructor = true; + } else { + vcolTok = std::to_string( + renderTileTemplateDim(shape[1], elemType, blayout, 1)); + } + + if (useConstructor) { + if (rowIsDynamic && vRowEmitC) + constructorArgs.push_back(maybeScaleDynamicValid(vRowEmitC, 0)); + if (colIsDynamic && vColEmitC) + constructorArgs.push_back(maybeScaleDynamicValid(vColEmitC, 1)); + } + } + + // 5. 生成 Tile 类型字符串 + std::string tileTypeStr = + std::string("Tile<") + roleTok + ", " + elemTypeStr + ", " + dimStr + ", " + + layoutParams + ", " + vrowTok + ", " + vcolTok + extraParams + ">"; + + auto tileType = emitc::OpaqueType::get(ctx, tileTypeStr); + Value resultValue; + + if (useConstructor) { + // 使用 CallOpaqueOp 生成构造函数调用 (Tile v = Tile(...)) + auto ctorOp = rewriter.create( + loc, + tileType, // Result Type + tileTypeStr, // Callee Name (类名) + ArrayAttr{}, // args + ArrayAttr{}, // template_args + ValueRange(constructorArgs) // operands + ); + resultValue = ctorOp.getResult(0); + } else { + // 静态情况 (Tile v;) + auto varOp = rewriter.create( + loc, + tileType, + emitc::OpaqueAttr::get(ctx, "") + ); + resultValue = varOp.getResult(); + } + + // TASSIGN: pto-isa expects an integral address. + Value addr = adaptor.getAddrs()[0]; + if (isa(addr.getType()) || + (isa(addr.getType()) && + cast(addr.getType()).getValue().ends_with("*"))) { + auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); + auto rcU64 = rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); + addr = rewriter.create( + loc, u64Ty, "reinterpret_cast", + /*args=*/ArrayAttr{}, /*templateArgs=*/rcU64, + /*operands=*/ValueRange{addr}) + .getResult(0); + } + + rewriter.create( + loc, TypeRange{}, "TASSIGN", + ArrayAttr{}, ArrayAttr{}, + ValueRange{resultValue, addr}); + + rewriter.replaceOp(op, resultValue); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// pto.load_dps / pto.store_dps lowering (FIX: keep optional result) +//===----------------------------------------------------------------------=== + +// GetBlockIdxOp Lowering (pto.get_block_idx -> get_block_idx()) + + +static std::optional getStaticIndexLikeValue(Value value) { + if (!value) + return std::nullopt; + if (auto cst = value.getDefiningOp()) + return cst.value(); + if (auto cst = value.getDefiningOp()) + return cst.value(); + if (auto cst = value.getDefiningOp()) { + if (auto intAttr = dyn_cast(cst.getValue())) + return intAttr.getInt(); + } + return std::nullopt; +} + +static FailureOr buildGlobalTensorViewFromPointer( + ConversionPatternRewriter &rewriter, Location loc, Value ptr, Type elemTy, + ArrayRef shape, ArrayRef strides = {}, + StringRef layoutEnum = "pto::Layout::ND") { + if (llvm::any_of(shape, [](int64_t dim) { + return dim == ShapedType::kDynamic; + })) + return failure(); + + auto *ctx = rewriter.getContext(); + SmallVector rowMajorStrides; + ArrayRef effectiveStrides = strides; + if (effectiveStrides.empty()) { + rowMajorStrides = buildRowMajorStrides(shape); + effectiveStrides = rowMajorStrides; + } + SmallVector shape5D; + SmallVector stride5D; + buildGlobalTensorShapeAndStride(shape, effectiveStrides, shape5D, stride5D); + + std::string shapeType = "pto::Shape<" + joinIntTemplateParams(shape5D) + ">"; + std::string strideType = + "pto::Stride<" + joinIntTemplateParams(stride5D) + ">"; + auto shapeVal = rewriter + .create( + loc, emitc::OpaqueType::get(ctx, shapeType), + shapeType, ArrayAttr{}, ArrayAttr{}, ValueRange{}) + .getResult(0); + auto strideVal = rewriter + .create( + loc, emitc::OpaqueType::get(ctx, strideType), + strideType, ArrayAttr{}, ArrayAttr{}, ValueRange{}) + .getResult(0); + + std::string gtTypeStr = + getGlobalTensorTypeStringFromShapeAndStrides(elemTy, shape, + effectiveStrides, + layoutEnum); + auto gtType = emitc::OpaqueType::get(ctx, gtTypeStr); + auto gt = rewriter.create( + loc, gtType, gtTypeStr, ArrayAttr{}, ArrayAttr{}, + ValueRange{ptr, shapeVal, strideVal}); + return gt.getResult(0); +} + +static bool parseIntegerTemplateList(StringRef token, StringRef marker, + SmallVectorImpl &values) { + size_t pos = token.find(marker); + if (pos == StringRef::npos) + return false; + pos += marker.size(); + size_t end = token.find('>', pos); + if (end == StringRef::npos) + return false; + + SmallVector parts; + token.slice(pos, end).split(parts, ','); + values.clear(); + for (StringRef part : parts) { + int64_t value = 0; + if (part.trim().getAsInteger(10, value)) + return false; + values.push_back(value); + } + return true; +} + +static LogicalResult getStaticTensorViewStrides( + Value source, Value convertedSource, pto::TensorViewType sourceType, + SmallVectorImpl &strides) { + int64_t rank = sourceType.getRank(); + strides.clear(); + + if (auto makeView = source.getDefiningOp()) { + if ((int64_t)makeView.getStrides().size() != rank) + return failure(); + for (Value strideValue : makeView.getStrides()) { + auto cst = getStaticIndexLikeValue(strideValue); + if (!cst) + return failure(); + strides.push_back(*cst); + } + return success(); + } + + Value src = peelUnrealized(convertedSource); + if (auto opaqueTy = dyn_cast(src.getType())) { + SmallVector stride5D; + StringRef token = opaqueTy.getValue(); + if ((parseIntegerTemplateList(token, "pto::Stride<", stride5D) || + parseIntegerTemplateList(token, "Stride<", stride5D)) && + (int64_t)stride5D.size() >= rank) { + strides.append(stride5D.end() - rank, stride5D.end()); + return success(); + } + } + + auto fallback = buildRowMajorStrides(sourceType.getShape()); + strides.append(fallback.begin(), fallback.end()); + return success(); +} + +struct PTOPartitionViewToEmitC + : public OpConversionPattern { + using OpConversionPattern< + mlir::pto::PartitionViewOp>::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::PartitionViewOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto srcTy = dyn_cast(op.getSource().getType()); + auto resTy = dyn_cast(op.getResult().getType()); + if (!srcTy || !resTy) + return rewriter.notifyMatchFailure( + op, "expected tensor_view source and partition_tensor_view result"); + + if (op.getOffsets().size() != static_cast(srcTy.getRank()) || + op.getSizes().size() != static_cast(srcTy.getRank())) + return rewriter.notifyMatchFailure(op, "rank mismatch"); + + for (auto [idx, value] : llvm::enumerate(op.getSizes())) { + auto cst = getStaticIndexLikeValue(value); + if (!cst) + return rewriter.notifyMatchFailure( + op, "globaltensor partition_view requires static sizes"); + int64_t resultDim = resTy.getShape()[idx]; + if (resultDim != ShapedType::kDynamic && resultDim != *cst) + return rewriter.notifyMatchFailure( + op, "partition_view static size does not match result type"); + } + + SmallVector srcStrides; + if (failed(getStaticTensorViewStrides(op.getSource(), adaptor.getSource(), + srcTy, srcStrides))) + return rewriter.notifyMatchFailure( + op, "partition_view requires static source strides"); + int64_t staticLinearOffset = 0; + SmallVector> dynamicOffsetTerms; + for (auto [idx, values] : + llvm::enumerate(llvm::zip(op.getOffsets(), adaptor.getOffsets()))) { + Value originalOffset = std::get<0>(values); + Value convertedOffset = std::get<1>(values); + int64_t stride = srcStrides[idx]; + if (stride == ShapedType::kDynamic) + return rewriter.notifyMatchFailure( + op, "dynamic source stride is not supported"); + + if (auto cst = getStaticIndexLikeValue(originalOffset)) { + if (*cst != 0) + staticLinearOffset += (*cst) * stride; + continue; + } + dynamicOffsetTerms.push_back({convertedOffset, stride}); + } + + auto *ctx = rewriter.getContext(); + std::string elemTypeStr = getElemTypeStringForGT(srcTy.getElementType()); + auto ptrTy = emitc::PointerType::get( + emitc::OpaqueType::get(ctx, "__gm__ " + elemTypeStr)); + Value src = peelUnrealized(adaptor.getSource()); + auto data = rewriter + .create( + op.getLoc(), ptrTy, "PTOAS__GLOBAL_TENSOR_DATA", + ArrayAttr{}, ArrayAttr{}, ValueRange{src}) + .getResult(0); + Value ptr = data; + if (!dynamicOffsetTerms.empty()) { + Type u32Ty = emitc::OpaqueType::get(ctx, "unsigned"); + auto makeU32 = [&](int64_t value) { + return makeEmitCIntConstant(rewriter, op.getLoc(), u32Ty, value); + }; + auto asU32 = [&](Value value) -> Value { + if (value.getType() == u32Ty) + return value; + return rewriter.create(op.getLoc(), u32Ty, value) + .getResult(); + }; + + Value totalOffset = makeU32(staticLinearOffset); + for (auto [offsetValue, stride] : dynamicOffsetTerms) { + Value term = asU32(offsetValue); + if (stride != 1) { + Value strideValue = makeU32(stride); + term = rewriter + .create(op.getLoc(), u32Ty, term, + strideValue) + .getResult(); + } + totalOffset = rewriter + .create(op.getLoc(), u32Ty, + totalOffset, term) + .getResult(); + } + ptr = rewriter + .create(op.getLoc(), data.getType(), data, + totalOffset) + .getResult(); + } else { + ptr = applyStaticMemrefOffset(rewriter, op.getLoc(), data, + staticLinearOffset); + } + + auto resultOr = buildGlobalTensorViewFromPointer( + rewriter, op.getLoc(), ptr, resTy.getElementType(), resTy.getShape(), + srcStrides); + if (failed(resultOr)) + return rewriter.notifyMatchFailure( + op, "failed to materialize partition GlobalTensor"); + + rewriter.replaceOp(op, *resultOr); + return success(); + } +}; + + +} // namespace + +void populatePTOToEmitCMemoryOpPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter, + MLIRContext *ctx) { + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); +} + +} // namespace mlir::pto diff --git a/lib/PTO/Transforms/PTOToEmitCRuntimeOps.cpp b/lib/PTO/Transforms/PTOToEmitCRuntimeOps.cpp new file mode 100644 index 000000000..a80b79fa0 --- /dev/null +++ b/lib/PTO/Transforms/PTOToEmitCRuntimeOps.cpp @@ -0,0 +1,736 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- PTOToEmitCRuntimeOps.cpp -------------------------------------------===// +//===----------------------------------------------------------------------===// + +#include "PTOToEmitCInternal.h" + +#include "PTO/IR/PTO.h" +#include "PTO/IR/PTOTypeUtils.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" + +#include + +using namespace mlir; +using namespace mlir::pto; + +namespace mlir::pto { +namespace { + +static constexpr unsigned kPTOIndexBitWidth = 32; + +static int64_t getEmitCScalarByteWidth(Type elemTy) { + if (pto::getPTOStorageElemByteSize(elemTy) == 1) + return 1; + if (elemTy.isF16() || elemTy.isBF16() || elemTy.isInteger(16)) + return 2; + if (elemTy.isF32() || elemTy.isInteger(32)) + return 4; + if (elemTy.isF64() || elemTy.isInteger(64)) + return 8; + return 4; +} + +static FailureOr getTileSplitToken(int64_t split) { + switch (split) { + case 0: + return std::string("TileSplitAxis::TILE_NO_SPLIT"); + case 1: + return std::string("TileSplitAxis::TILE_UP_DOWN"); + case 2: + return std::string("TileSplitAxis::TILE_LEFT_RIGHT"); + default: + return failure(); + } +} + +static FailureOr +getTPipeDirectionToken(bool isL2G2L, int8_t dirMask, PTOArch targetArch) { + if (dirMask == 1) { + if (isL2G2L && targetArch == PTOArch::A5) + return std::string("Direction::DIR_C2V_GM"); + return std::string("Direction::DIR_C2V"); + } + if (dirMask == 2) { + if (isL2G2L && targetArch == PTOArch::A5) + return std::string("Direction::DIR_V2C_GM"); + return std::string("Direction::DIR_V2C"); + } + if (dirMask == 3) + return std::string("Direction::DIR_BOTH"); + return failure(); +} + +static std::string buildTPipeToken(int32_t flagBase, llvm::StringRef dirTok, + int32_t slotSize, int32_t slotNum, + int32_t localSlotNum, bool nosplit) { + std::string token = "TPipe<" + std::to_string(flagBase) + ", " + dirTok.str() + + ", " + std::to_string(slotSize) + ", " + + std::to_string(slotNum); + token += ", " + std::to_string(localSlotNum); + token += nosplit ? ", true" : ", false"; + token += ">"; + return token; +} + +} // namespace + +FailureOr buildTPipeTokenFromInitOp(Operation *op, + PTOArch targetArch) { + if (auto initOp = dyn_cast(op)) { + if (!initOp.getFlagBaseAttr()) + return failure(); + auto dirTok = + getTPipeDirectionToken(/*isL2G2L=*/true, initOp.getDirMask(), targetArch); + if (failed(dirTok)) + return failure(); + int32_t localSlotNum = initOp.getLocalSlotNumAttr() + ? initOp.getLocalSlotNumAttr().getInt() + : initOp.getSlotNum(); + return buildTPipeToken(initOp.getFlagBaseAttr().getInt(), *dirTok, + initOp.getSlotSize(), initOp.getSlotNum(), + localSlotNum, + initOp.getNosplitAttr() && + initOp.getNosplitAttr().getValue()); + } + + if (auto initOp = dyn_cast(op)) { + if (!initOp.getFlagBaseAttr()) + return failure(); + auto dirTok = + getTPipeDirectionToken(/*isL2G2L=*/false, initOp.getDirMask(), targetArch); + if (failed(dirTok)) + return failure(); + return buildTPipeToken(initOp.getFlagBaseAttr().getInt(), *dirTok, + initOp.getSlotSize(), initOp.getSlotNum(), 2, + initOp.getNosplitAttr() && + initOp.getNosplitAttr().getValue()); + } + + return failure(); +} + + +namespace { + +static FailureOr getTPipeTokenFromValue(Value pipeHandle, + PTOArch targetArch) { + pipeHandle = peelUnrealized(pipeHandle); + Operation *def = pipeHandle.getDefiningOp(); + if (!def) + return failure(); + return buildTPipeTokenFromInitOp(def, targetArch); +} + + + +static FailureOr getPipeDataTypeToken(Value value) { + auto opaqueTy = dyn_cast(value.getType()); + if (!opaqueTy) + return failure(); + StringRef token = opaqueTy.getValue(); + if (!token.contains("Tile<") && !token.contains("GlobalTensor<")) + return failure(); + return token.str(); +} + +struct PTOTAllocToEmitC : public OpConversionPattern { + PTOTAllocToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + PTOArch targetArch) + : OpConversionPattern(typeConverter, ctx), + targetArch(targetArch) {} + + LogicalResult matchAndRewrite(mlir::pto::TAllocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto pipeTok = getTPipeTokenFromValue(op.getPipeHandle(), targetArch); + if (failed(pipeTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve pipe token"); + Value entry = peelUnrealized(adaptor.getEntry()); + auto entryTok = getPipeDataTypeToken(entry); + if (failed(entryTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve entry token"); + auto splitTok = getTileSplitToken(op.getSplit()); + if (failed(splitTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve split token"); + + std::string callee = + "TALLOC<" + *pipeTok + ", " + *entryTok + ", " + *splitTok + ">"; + rewriter.replaceOpWithNewOp( + op, TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, + ValueRange{peelUnrealized(adaptor.getPipeHandle()), entry}); + return success(); + } + + PTOArch targetArch; +}; + +struct PTOTPushToEmitC : public OpConversionPattern { + PTOTPushToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + PTOArch targetArch) + : OpConversionPattern(typeConverter, ctx), + targetArch(targetArch) {} + + LogicalResult matchAndRewrite(mlir::pto::TPushOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto pipeTok = getTPipeTokenFromValue(op.getPipeHandle(), targetArch); + if (failed(pipeTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve pipe token"); + // Read the tile type token from the already-converted OpaqueType, which + // preserves the exact layout produced by BindTileOp / PointerCastOp EmitC. + Value convertedTile = peelUnrealized(adaptor.getTile()); + auto tileTok = getPipeDataTypeToken(convertedTile); + if (failed(tileTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve tile token"); + auto splitTok = getTileSplitToken(op.getSplit()); + if (failed(splitTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve split token"); + + std::string callee = + "TPUSH<" + *pipeTok + ", " + *tileTok + ", " + *splitTok + ">"; + rewriter.replaceOpWithNewOp( + op, TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, + ValueRange{peelUnrealized(adaptor.getPipeHandle()), convertedTile}); + return success(); + } + + PTOArch targetArch; +}; + +struct PTOTPopToEmitC : public OpConversionPattern { + PTOTPopToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + PTOArch targetArch) + : OpConversionPattern(typeConverter, ctx), + targetArch(targetArch) {} + + LogicalResult matchAndRewrite(mlir::pto::TPopOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto pipeTok = getTPipeTokenFromValue(op.getPipeHandle(), targetArch); + if (failed(pipeTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve pipe token"); + Value convertedTile = peelUnrealized(adaptor.getTile()); + auto tileTok = getPipeDataTypeToken(convertedTile); + if (failed(tileTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve tile token"); + auto splitTok = getTileSplitToken(op.getSplit()); + if (failed(splitTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve split token"); + + std::string callee = + "TPOP<" + *pipeTok + ", " + *tileTok + ", " + *splitTok + ">"; + rewriter.replaceOpWithNewOp( + op, TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, + ValueRange{peelUnrealized(adaptor.getPipeHandle()), convertedTile}); + return success(); + } + + PTOArch targetArch; +}; + +struct PTOTFreeToEmitC : public OpConversionPattern { + PTOTFreeToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + PTOArch targetArch) + : OpConversionPattern(typeConverter, ctx), + targetArch(targetArch) {} + + LogicalResult matchAndRewrite(mlir::pto::TFreeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto pipeTok = getTPipeTokenFromValue(op.getPipeHandle(), targetArch); + if (failed(pipeTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve pipe token"); + auto splitTok = getTileSplitToken(op.getSplit()); + if (failed(splitTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve split token"); + + SmallVector operands{peelUnrealized(adaptor.getPipeHandle())}; + std::string callee; + if (op.getEntry()) { + Value entry = peelUnrealized(adaptor.getEntry()); + auto entryTok = getPipeDataTypeToken(entry); + if (failed(entryTok)) + return rewriter.notifyMatchFailure(op, "failed to resolve entry token"); + callee = "TFREE<" + *pipeTok + ", " + *entryTok + ", " + *splitTok + ">"; + operands.push_back(entry); + } else { + callee = "TFREE<" + *pipeTok + ", " + *splitTok + ">"; + } + rewriter.replaceOpWithNewOp( + op, TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, operands); + return success(); + } + + PTOArch targetArch; +}; + +//===----------------------------------------------------------------------===// +// populate patterns +//===----------------------------------------------------------------------=== +struct ReinterpretCastToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + auto resMrTy = dyn_cast(op.getType()); + if (!resMrTy) + return failure(); + + auto asAttr = dyn_cast_or_null(resMrTy.getMemorySpace()); + const bool isGm = (!asAttr || asAttr.getAddressSpace() == pto::AddressSpace::GM); + + bool emitAddPtrTrace = op->hasAttr("pto.addptr_trace"); + Value source = peelUnrealized(adaptor.getSource()); + auto offsets = adaptor.getOffsets(); + Value offsetVal = offsets.empty() ? Value() : offsets[0]; + + // GM: keep pointer arithmetic. + if (isGm) { + if (!offsetVal) { + rewriter.replaceOp(op, source); + return success(); + } + + Type resultType = getTypeConverter()->convertType(op.getType()); + if (!resultType) + return failure(); + + auto addOp = rewriter.create(loc, resultType, source, offsetVal); + if (emitAddPtrTrace) { + rewriter.setInsertionPointAfter(addOp); + rewriter.create( + loc, TypeRange{}, "PTOAS__ADDPTR_TRACE", + ArrayAttr{}, ArrayAttr{}, + ValueRange{addOp.getResult(), source, offsetVal}); + } + rewriter.replaceOp(op, addOp.getResult()); + return success(); + } + + // UB/L1/L0 tiles: materialize a new Tile view by assigning an adjusted + // underlying pointer (in elements). + pto::AddressSpace as = asAttr.getAddressSpace(); + + // Element type token. + Type elemTy = resMrTy.getElementType(); + std::string elemTok = getEmitCScalarTypeToken(elemTy); + int64_t elemBytes = getEmitCScalarByteWidth(elemTy); + + // Tile role. + const char *roleTok = "TileType::Vec"; + switch (as) { + case pto::AddressSpace::VEC: + roleTok = "TileType::Vec"; + break; + case pto::AddressSpace::MAT: + roleTok = "TileType::Mat"; + break; + case pto::AddressSpace::LEFT: + roleTok = "TileType::Left"; + break; + case pto::AddressSpace::RIGHT: + roleTok = "TileType::Right"; + break; + case pto::AddressSpace::ACC: + roleTok = "TileType::Acc"; + break; + case pto::AddressSpace::BIAS: + roleTok = "TileType::Bias"; + break; + case pto::AddressSpace::GM: + roleTok = "TileType::Vec"; + break; + case pto::AddressSpace::Zero: + roleTok = "TileType::Vec"; + break; + case pto::AddressSpace::SCALING: + roleTok = "TileType::Scaling"; + break; + } + + // Shape (fallback to 32x32). + int64_t rows = 32, cols = 32; + if (resMrTy.getRank() >= 2 && resMrTy.hasStaticShape()) { + rows = resMrTy.getDimSize(0); + cols = resMrTy.getDimSize(1); + } + int64_t templateRows = + renderTileTemplateDim(rows, elemTy, pto::BLayout::RowMajor, 0); + int64_t templateCols = + renderTileTemplateDim(cols, elemTy, pto::BLayout::RowMajor, 1); + + // Keep a conservative default config for now. + std::string tileTypeStr = + std::string("Tile<") + roleTok + ", " + elemTok + ", " + + std::to_string(templateRows) + ", " + std::to_string(templateCols) + + ", BLayout::RowMajor, " + std::to_string(templateRows) + ", " + + std::to_string(templateCols) + + ", SLayout::NoneBox, 512, PadValue::Null, CompactMode::Null>"; + + auto tileType = emitc::OpaqueType::get(ctx, tileTypeStr); + Value tile = rewriter + .create(loc, tileType, + emitc::OpaqueAttr::get(ctx, "")) + .getResult(); + + // Compute an integer address and assign it to the new tile. + // NOTE: pto-isa TASSIGN requires an integral address (not a pointer). + auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); + auto rcU64 = rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); + + // Non-GM reinterpret_cast operands come from UB/L1/L0 tiles. + // We need the underlying address, but `__cce_get_tile_ptr()` is only valid + // inside `__tf__` functions. Use `tile.data()` (via a post-processed marker) + // and compute the adjusted address in bytes. + Value rawPtr = source; + if (auto ot = dyn_cast(source.getType())) { + // Only Tiles have a `.data()` member. For plain address-space pointers + // (e.g. `__ubuf__ float*`), use the pointer value directly. + if (ot.getValue().starts_with("Tile<")) { + rawPtr = materializeTileDataValue(rewriter, loc, source, as, elemTok); + } + } + + Value baseAddr = rawPtr; + if (isSetFFTsPointerLikeType(rawPtr.getType())) { + baseAddr = rewriter + .create(loc, u64Ty, "reinterpret_cast", + /*args=*/ArrayAttr{}, + /*templateArgs=*/rcU64, + /*operands=*/ValueRange{rawPtr}) + .getResult(0); + } else if (rawPtr.getType() != u64Ty) { + baseAddr = rewriter.create(loc, u64Ty, rawPtr).getResult(); + } + + Value addr = baseAddr; + if (offsetVal) { + Value offU64 = offsetVal; + if (offU64.getType() != u64Ty) + offU64 = rewriter.create(loc, u64Ty, offU64).getResult(); + + auto bytesAttr = emitc::OpaqueAttr::get(ctx, std::to_string(elemBytes)); + Value bytesVal = rewriter.create(loc, u64Ty, bytesAttr); + Value byteOff = rewriter.create(loc, u64Ty, offU64, bytesVal); + addr = rewriter.create(loc, u64Ty, baseAddr, byteOff); + } + + rewriter.create(loc, TypeRange{}, "TASSIGN", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{tile, addr}); + + rewriter.replaceOp(op, tile); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.taddc lowering -> TADDC(dst, src0, src1, src2) +//===----------------------------------------------------------------------===// + +struct PTOTAddCToTADDC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TAddCOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value src2 = peelUnrealized(adaptor.getSrc2()); + Value dst = peelUnrealized(adaptor.getDst()); + + // pto-isa does not provide NPU implementation for TADDC yet. + // Decompose: dst = src0 + src1 + src2 + rewriter.create( + loc, TypeRange{}, "TADD", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src0, src1}); + rewriter.create( + loc, TypeRange{}, "TADD", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, dst, src2}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.tadds lowering -> TADDS(dst, src, scalar) +//===----------------------------------------------------------------------===// + +struct PTOAddSToTADDS : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TAddSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + Value scalar = peelUnrealized(adaptor.getScalar()); + + rewriter.create( + op.getLoc(), TypeRange{}, "TADDS", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src, scalar}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.taddsc lowering -> TADDSC(dst, src0, scalar, src1) +//===----------------------------------------------------------------------===// + +struct PTOAddSCToTADDSC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TAddSCOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value scalar = peelUnrealized(adaptor.getScalar()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + // pto-isa does not provide NPU implementation for TADDSC yet. + // Decompose: dst = src0 + scalar + src1 + rewriter.create( + loc, TypeRange{}, "TADDS", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src0, scalar}); + rewriter.create( + loc, TypeRange{}, "TADD", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, dst, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; +// Tile/vector PTO op conversion patterns live in PTOToEmitCTilePatterns.cpp. + +struct PTOPrintOpToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::PrintOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + std::string fmt = op.getFormat().str(); + if (fmt.empty()) + fmt = "%f"; + std::string quoted = "\""; + for (char c : fmt) { + if (c == '"' || c == '\\') + quoted += '\\'; + else if (c == '\n') + quoted += "\\n"; + else if (c == '\t') + quoted += "\\t"; + else + quoted += c; + } + quoted += "\""; + + Value scalar = peelUnrealized(adaptor.getScalar()); + auto argsAttr = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, quoted), + IntegerAttr::get(IndexType::get(ctx), 0)}); + rewriter.create( + loc, TypeRange{}, "cce::printf", + /*args=*/argsAttr, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{scalar}); + + rewriter.eraseOp(op); + return success(); + } +}; + +// pto.trap -> TRAP() +struct PTOTrapOpToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TrapOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + rewriter.create( + loc, TypeRange{}, "trap", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{}); + + rewriter.eraseOp(op); + return success(); + } +}; + +// ============================================================================= +// Arith CmpI -> EmitC Cmp +// ============================================================================= +class ArithCmpIToEmitC : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + // 将 arith.cmpi 转换为 emitc.cmp + // 映射 Predicate: eq -> equal, slt -> less, etc. + emitc::CmpPredicate emitcPred = emitc::CmpPredicate::eq; + const bool isUnsignedPred = + op.getPredicate() == arith::CmpIPredicate::ult || + op.getPredicate() == arith::CmpIPredicate::ule || + op.getPredicate() == arith::CmpIPredicate::ugt || + op.getPredicate() == arith::CmpIPredicate::uge; + switch (op.getPredicate()) { + case arith::CmpIPredicate::eq: emitcPred = emitc::CmpPredicate::eq; break; + case arith::CmpIPredicate::ne: emitcPred = emitc::CmpPredicate::ne; break; + case arith::CmpIPredicate::slt: emitcPred = emitc::CmpPredicate::lt; break; + case arith::CmpIPredicate::sle: emitcPred = emitc::CmpPredicate::le; break; + case arith::CmpIPredicate::sgt: emitcPred = emitc::CmpPredicate::gt; break; + case arith::CmpIPredicate::sge: emitcPred = emitc::CmpPredicate::ge; break; + // ... 处理无符号比较 (ult, ule 等) ... + case arith::CmpIPredicate::ult: emitcPred = emitc::CmpPredicate::lt; break; + case arith::CmpIPredicate::ule: emitcPred = emitc::CmpPredicate::le; break; + case arith::CmpIPredicate::ugt: emitcPred = emitc::CmpPredicate::gt; break; + case arith::CmpIPredicate::uge: emitcPred = emitc::CmpPredicate::ge; break; + } + + Type resTy = getTypeConverter()->convertType(op.getType()); + if (!resTy) + return failure(); + + Value lhs = adaptor.getLhs(); + Value rhs = adaptor.getRhs(); + if (isUnsignedPred) { + Type opTy = op.getLhs().getType(); + auto intTy = dyn_cast(opTy); + const bool isIndex = isa(opTy); + if (!intTy && !isIndex) + return rewriter.notifyMatchFailure( + op, "expected scalar integer or index operands"); + + const unsigned bitWidth = + intTy ? intTy.getWidth() : static_cast(kPTOIndexBitWidth); + if (bitWidth != 1) { + lhs = castSignlessIntToUnsignedSameWidth(rewriter, loc, lhs, bitWidth); + rhs = castSignlessIntToUnsignedSameWidth(rewriter, loc, rhs, bitWidth); + } + } + + rewriter.replaceOpWithNewOp( + op, + /*resultType=*/resTy, // i1 -> bool/i1 + emitcPred, + lhs, + rhs + ); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Section Op Lowering +//===----------------------------------------------------------------------===// +static bool isA5NoSplitPipeOp(Operation *op) { + if (auto talloc = dyn_cast(op)) + return talloc.getSplit() == 0; + if (auto tpush = dyn_cast(op)) + return tpush.getSplit() == 0; + if (auto tpop = dyn_cast(op)) + return tpop.getSplit() == 0; + if (auto tfree = dyn_cast(op)) + return tfree.getSplit() == 0; + if (auto tpush = dyn_cast(op)) + return tpush.getSplit() == 0; + if (auto tpush = dyn_cast(op)) + return tpush.getSplit() == 0; + if (auto talloc = dyn_cast(op)) + return talloc.getSplit() == 0; + if (auto talloc = dyn_cast(op)) + return talloc.getSplit() == 0; + if (auto tpop = dyn_cast(op)) + return tpop.getSplit() == 0; + if (auto tpop = dyn_cast(op)) + return tpop.getSplit() == 0; + if (auto tfree = dyn_cast(op)) + return tfree.getSplit() == 0; + if (auto tfree = dyn_cast(op)) + return tfree.getSplit() == 0; + return false; +} + +static bool hasExplicitSubblockControl(Operation *op) { + bool hasControl = false; + op->walk([&](Operation *nested) { + if (isa(nested)) { + hasControl = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return hasControl; +} + +} // namespace + +bool needsA5NoSplitVectorGuard(Operation *op) { + auto arch = getTargetArch(op); + if (arch != PTOArch::A5) + return false; + bool isVectorScope = isa(op); + if (auto func = dyn_cast(op)) { + if (auto kernelKindAttr = + func->getAttrOfType( + FunctionKernelKindAttr::name)) { + isVectorScope = + kernelKindAttr.getKernelKind() == FunctionKernelKind::Vector; + } + } + if (!isVectorScope) + return false; + if (hasExplicitSubblockControl(op)) + return false; + + bool hasNoSplitPipe = false; + op->walk([&](Operation *nested) { + if (!isA5NoSplitPipeOp(nested)) + return WalkResult::advance(); + hasNoSplitPipe = true; + return WalkResult::interrupt(); + }); + return hasNoSplitPipe; +} + + +void populatePTOToEmitCRuntimeOpPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter, + MLIRContext *ctx, PTOArch targetArch) { + patterns.add(typeConverter, ctx, targetArch); + patterns.add(typeConverter, ctx, targetArch); + patterns.add(typeConverter, ctx, targetArch); + patterns.add(typeConverter, ctx, targetArch); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); +} + +} // namespace mlir::pto diff --git a/lib/PTO/Transforms/PTOToEmitCSimpleOps.cpp b/lib/PTO/Transforms/PTOToEmitCSimpleOps.cpp new file mode 100644 index 000000000..a2127a34e --- /dev/null +++ b/lib/PTO/Transforms/PTOToEmitCSimpleOps.cpp @@ -0,0 +1,593 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- PTOToEmitCSimpleOps.cpp --------------------------------------------===// +//===----------------------------------------------------------------------===// + +#include "PTOToEmitCInternal.h" + +#include "PTO/IR/PTO.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" + +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; +using namespace mlir::pto; + +namespace mlir::pto { +namespace { + +struct PTOGetBlockIdxToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(mlir::pto::GetBlockIdxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + rewriter.replaceOpWithNewOp( + op, op.getType(), "get_block_idx", ValueRange{}, ArrayAttr{}, + ArrayAttr{}); + + return success(); + } +}; + +// GetBlockNumOp Lowering (pto.get_block_num -> get_block_num()) +struct PTOGetBlockNumToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(mlir::pto::GetBlockNumOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + rewriter.replaceOpWithNewOp( + op, op.getType(), "get_block_num", ValueRange{}, ArrayAttr{}, + ArrayAttr{}); + + return success(); + } +}; + +// GetSubBlockIdxOp Lowering (pto.get_block_idx -> get_subblockid()) +struct PTOGetSubBlockIdxToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(mlir::pto::GetSubBlockIdxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + rewriter.replaceOpWithNewOp( + op, op.getType(), "get_subblockid", ValueRange{}, ArrayAttr{}, + ArrayAttr{}); + + return success(); + } +}; + +// GetSubBlockNumOp Lowering. +struct PTOGetSubBlockNumToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(mlir::pto::GetSubBlockNumOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + rewriter.replaceOpWithNewOp( + op, op.getType(), "get_subblockdim", ValueRange{}, ArrayAttr{}, + ArrayAttr{}); + + return success(); + } +}; + + + + +struct PTOSetValToSETVAL : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TSetValOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value dst = peelUnrealized(adaptor.getDst()); + Value val = peelUnrealized(adaptor.getVal()); + + // ---- offset: SSA index operand ---- + Value offset = peelUnrealized(adaptor.getOffset()); + + // Emit a marker call and let the ptoas post-processing step lower it to + // the corresponding tile setter. + rewriter.create( + op.getLoc(), TypeRange{}, "PTOAS__TILE_SET_VALUE", + ArrayAttr{}, ArrayAttr{}, ValueRange{dst, offset, val}); + + rewriter.eraseOp(op); + return success(); + } +}; +struct PTOGetValToGETVAL : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TGetValOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src = peelUnrealized(adaptor.getSrc()); + + // ---- offset: SSA index operand ---- + Value offset = peelUnrealized(adaptor.getOffset()); + + // Emit a marker call and let the ptoas post-processing step lower it to + // the corresponding tile getter. + Type dstTy = getTypeConverter()->convertType(op.getDst().getType()); + if (!dstTy) + return failure(); + auto call = rewriter.create( + op.getLoc(), + TypeRange{dstTy}, + "PTOAS__TILE_GET_VALUE", + ArrayAttr{}, ArrayAttr{}, + ValueRange{src, offset}); + + rewriter.replaceOp(op, call.getResults()); + return success(); + } +}; + +struct PTOTAxpyToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TAxpyOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + Value scalar = peelUnrealized(adaptor.getScalar()); + + rewriter.create( + loc, TypeRange{}, "TAXPY", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src, scalar}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOHistogramToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::THistogramOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value idx = peelUnrealized(adaptor.getIdx()); + Value dst = peelUnrealized(adaptor.getDst()); + + StringRef histByte = "HistByte::BYTE_1"; + int64_t byte = 1; + auto byteAttr = op.getByteAttr(); + if (byteAttr) + byte = byteAttr.getInt(); + if (auto legacyIsMSB = op->getAttrOfType("isMSB")) { + int64_t legacyByte = legacyIsMSB.getValue() ? 1 : 0; + if (byteAttr && byte != legacyByte) + return rewriter.notifyMatchFailure( + op, "conflicting 'byte' and legacy 'isMSB' attributes"); + byte = legacyByte; + } + switch (byte) { + case 0: + histByte = "HistByte::BYTE_0"; + break; + case 1: + histByte = "HistByte::BYTE_1"; + break; + case 2: + histByte = "HistByte::BYTE_2"; + break; + case 3: + histByte = "HistByte::BYTE_3"; + break; + default: + return rewriter.notifyMatchFailure(op, + "expected byte to be in range [0, 3]"); + } + + auto templateArgs = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, histByte)}); + rewriter.create( + loc, TypeRange{}, "THISTOGRAM", + /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, + /*operands=*/ValueRange{dst, src, idx}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOGetScaleAddrToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TGetScaleAddrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TGET_SCALE_ADDR", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOSetValidShapeToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::SetValidShapeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto peelAllCasts = [](Value v) { + while (auto castOp = v.getDefiningOp()) + v = castOp.getOperand(0); + if (auto castOp = v.getDefiningOp()) + v = castOp.getOperand(); + return v; + }; + auto isTileLike = [](Value v) -> bool { + auto ot = dyn_cast(v.getType()); + if (!ot) + return false; + StringRef s = ot.getValue(); + return s.contains("Tile<") || s.contains("ConvTile<"); + }; + + Value src = peelAllCasts(peelUnrealized(adaptor.getSource())); + Value row = peelUnrealized(adaptor.getValidRow()); + Value col = peelUnrealized(adaptor.getValidCol()); + + if (!isTileLike(src)) + return rewriter.notifyMatchFailure( + op, "set_validshape source must lower to a tile-like value"); + + rewriter.create( + op.getLoc(), TypeRange{}, "PTOAS__TILE_SET_VALIDSHAPE", ArrayAttr{}, + ArrayAttr{}, ValueRange{src, row, col}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOGetValidShapeToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::GetValidShapeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto peelAllCasts = [](Value v) { + while (auto castOp = v.getDefiningOp()) + v = castOp.getOperand(0); + if (auto castOp = v.getDefiningOp()) + v = castOp.getOperand(); + return v; + }; + auto isTileLike = [](Value v) -> bool { + auto ot = dyn_cast(v.getType()); + if (!ot) + return false; + StringRef s = ot.getValue(); + return s.contains("Tile<") || s.contains("ConvTile<"); + }; + + Value src = peelAllCasts(peelUnrealized(adaptor.getSource())); + if (!isTileLike(src)) + return rewriter.notifyMatchFailure( + op, "get_validshape source must lower to a tile-like value"); + + auto resultTy = getTypeConverter()->convertType(rewriter.getIndexType()); + if (!resultTy) + return failure(); + + Value row = rewriter + .create( + op.getLoc(), resultTy, + "PTOAS__TILE_GET_VALID_ROW", ArrayAttr{}, + ArrayAttr{}, ValueRange{src}) + .getResult(0); + Value col = rewriter + .create( + op.getLoc(), resultTy, + "PTOAS__TILE_GET_VALID_COL", ArrayAttr{}, + ArrayAttr{}, ValueRange{src}) + .getResult(0); + rewriter.replaceOp(op, ValueRange{row, col}); + return success(); + } +}; + +struct PTOTAssignToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TAssignOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto peelAllCasts = [](Value v) { + while (auto castOp = v.getDefiningOp()) + v = castOp.getOperand(0); + if (auto castOp = v.getDefiningOp()) + v = castOp.getOperand(); + return v; + }; + auto isTileLike = [](Value v) -> bool { + auto ot = dyn_cast(v.getType()); + if (!ot) + return false; + StringRef s = ot.getValue(); + return s.contains("Tile<") || s.contains("ConvTile<"); + }; + + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value tile = peelAllCasts(peelUnrealized(adaptor.getTile())); + if (!isTileLike(tile)) + return rewriter.notifyMatchFailure( + op, "tassign tile must lower to a tile-like value"); + + Value addr = peelUnrealized(adaptor.getAddr()); + auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); + if (isa(addr.getType()) || + (isa(addr.getType()) && + cast(addr.getType()).getValue().ends_with("*"))) { + auto rcU64 = + rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); + addr = rewriter + .create(loc, u64Ty, "reinterpret_cast", + ArrayAttr{}, rcU64, + ValueRange{addr}) + .getResult(0); + } else if (addr.getType() != u64Ty) { + addr = rewriter.create(loc, u64Ty, addr).getResult(); + } + + rewriter.create(loc, TypeRange{}, "TASSIGN", + ArrayAttr{}, ArrayAttr{}, + ValueRange{tile, addr}); + rewriter.replaceOp(op, tile); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// pto.load_scalar / pto.store_scalar lowering -> ptr[offset] +//===----------------------------------------------------------------------===// + +static Type getPointerLikeElementType(Type type) { + if (auto ptrTy = dyn_cast(type)) + return ptrTy.getElementType(); + if (auto memTy = dyn_cast(type)) + return memTy.getElementType(); + return Type(); +} + +struct PTOPtrToIntToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::PtrToIntOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value ptr = peelUnrealized(adaptor.getPtr()); + Type dstTy = getTypeConverter()->convertType(op.getResult().getType()); + if (!dstTy) + return failure(); + + auto dstOpaque = dyn_cast(dstTy); + if (!dstOpaque) + return failure(); + + auto templateArgs = + rewriter.getArrayAttr({emitc::OpaqueAttr::get(rewriter.getContext(), + dstOpaque.getValue())}); + auto cast = rewriter.create( + op.getLoc(), dstTy, "reinterpret_cast", ArrayAttr{}, templateArgs, + ValueRange{ptr}); + rewriter.replaceOp(op, cast.getResult(0)); + return success(); + } +}; + +struct PTOIntToPtrToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::IntToPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value addr = peelUnrealized(adaptor.getAddr()); + Type dstTy = getTypeConverter()->convertType(op.getResult().getType()); + if (!dstTy) + return failure(); + + Type dstElemTy = getPointerLikeElementType(op.getResult().getType()); + if (!dstElemTy) + return failure(); + + std::string castType = + std::string("__gm__ ") + getEmitCScalarTypeToken(dstElemTy) + "*"; + auto templateArgs = + rewriter.getArrayAttr({emitc::OpaqueAttr::get(rewriter.getContext(), + castType)}); + auto cast = rewriter.create( + op.getLoc(), dstTy, "reinterpret_cast", ArrayAttr{}, templateArgs, + ValueRange{addr}); + rewriter.replaceOp(op, cast.getResult(0)); + return success(); + } +}; + +struct PTOLoadScalarToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::LoadScalarOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value ptr = peelUnrealized(adaptor.getPtr()); + Value offset = peelUnrealized(adaptor.getOffset()); + + Type dstTy = getTypeConverter()->convertType(op.getValue().getType()); + if (!dstTy) + return failure(); + + auto call = rewriter.create( + op.getLoc(), TypeRange{dstTy}, "PTOAS__PTR_LOAD", + ArrayAttr{}, ArrayAttr{}, ValueRange{ptr, offset}); + + rewriter.replaceOp(op, call.getResults()); + return success(); + } +}; + +struct PTOStoreScalarToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::StoreScalarOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value ptr = peelUnrealized(adaptor.getPtr()); + Value offset = peelUnrealized(adaptor.getOffset()); + Value val = peelUnrealized(adaptor.getValue()); + + rewriter.create( + op.getLoc(), TypeRange{}, "PTOAS__PTR_STORE", + ArrayAttr{}, ArrayAttr{}, ValueRange{ptr, offset, val}); + rewriter.create( + op.getLoc(), TypeRange{}, "PTOAS__SCALAR_GM_STORE_FLUSH", + ArrayAttr{}, ArrayAttr{}, ValueRange{ptr}); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// pto.tabs lowering -> TABS(dst, src) +//===----------------------------------------------------------------------===// + + + +struct PTOTAbsToTABS : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TAbsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + // intrinsic: TABS(dst, src) + rewriter.create( + op.getLoc(), TypeRange{}, "TABS", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.tadd lowering -> TADD(dst, src0, src1) +//===----------------------------------------------------------------------===// + +struct PTOTAddToTADD : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TAddOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + op.getLoc(), TypeRange{}, "TADD", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src0, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; + + +struct AffineApplyMulConstToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(affine::AffineApplyOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto map = op.getAffineMap(); + + if (map.getNumDims() != 0 || map.getNumSymbols() != 1) + return failure(); + + auto expr = map.getResult(0); + auto bin = dyn_cast(expr); + if (!bin || bin.getKind() != AffineExprKind::Mul) + return failure(); + + auto lhs = bin.getLHS(); + auto rhs = bin.getRHS(); + + auto symExpr = dyn_cast(lhs); + auto constExpr = dyn_cast(rhs); + if (!symExpr || !constExpr) + return failure(); + + Value inputVal = adaptor.getMapOperands()[0]; + + std::string valStr = std::to_string(constExpr.getValue()); + auto cstAttr = emitc::OpaqueAttr::get(rewriter.getContext(), valStr); + auto cstOp = rewriter.create( + op.getLoc(), inputVal.getType(), cstAttr); + + rewriter.replaceOpWithNewOp( + op, inputVal.getType(), inputVal, cstOp); + + return success(); + } +}; + + +} // namespace + +void populatePTOToEmitCSimpleOpPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter, + MLIRContext *ctx) { + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add( + typeConverter, ctx); + patterns.add( + typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); +} + +} // namespace mlir::pto diff --git a/lib/PTO/Transforms/PTOToEmitCSync.cpp b/lib/PTO/Transforms/PTOToEmitCSync.cpp new file mode 100644 index 000000000..efa812e5e --- /dev/null +++ b/lib/PTO/Transforms/PTOToEmitCSync.cpp @@ -0,0 +1,1046 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- PTOToEmitCSync.cpp --------------------------------------------------===// +//===----------------------------------------------------------------------===// + +#include "PTOToEmitCInternal.h" + +#include "PTO/IR/PTO.h" +#include "PTO/IR/PTOSyncUtils.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" + +#include +#include +#include + +using namespace mlir; +using namespace mlir::pto; + +namespace mlir::pto { +namespace { + +static constexpr llvm::StringLiteral kAutoSyncTailPendingModeAttr = + "__pto.auto_sync_tail_mode"; +static inline std::string pipeTokFromPipeAttr(mlir::pto::PipeAttr a); + +struct InterCoreSyncCallDesc { + const char *callee = nullptr; + ArrayAttr args; + SmallVector operands; +}; + +static Value castInterCoreEventIdToI32(ConversionPatternRewriter &rewriter, + Location loc, Value eventId) { + auto i32Ty = emitc::OpaqueType::get(rewriter.getContext(), "int32_t"); + if (eventId.getType() == i32Ty) + return eventId; + return emitCCast(rewriter, loc, i32Ty, eventId); +} + +static Attribute getFFTSModeCodegenArg(ConversionPatternRewriter &rewriter, + int64_t fftsMode) { + auto *ctx = rewriter.getContext(); + if (fftsMode == 2) + return emitc::OpaqueAttr::get(ctx, "FFTS_MODE_VAL"); + return emitc::OpaqueAttr::get(ctx, std::to_string(fftsMode)); +} + +static Value createFFTSMsg(ConversionPatternRewriter &rewriter, Location loc, + Value eventI32, int64_t fftsMode) { + auto *ctx = rewriter.getContext(); + auto msgTy = emitc::OpaqueType::get(ctx, "uint16_t"); + auto msgArgs = rewriter.getArrayAttr({ + getFFTSModeCodegenArg(rewriter, fftsMode), + IntegerAttr::get(IndexType::get(ctx), 0), + }); + return rewriter + .create(loc, msgTy, "getFFTSMsg", + /*args=*/msgArgs, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{eventI32}) + .getResult(0); +} + +static InterCoreSyncCallDesc buildInterCoreSyncSetCall( + ConversionPatternRewriter &rewriter, Location loc, PTOArch targetArch, + pto::PipeAttr pipeAttr, IntegerAttr eventIdAttr, int64_t fftsMode) { + auto *ctx = rewriter.getContext(); + std::string pipeTok = pipeTokFromPipeAttr(pipeAttr); + + if (targetArch == PTOArch::A3) { + auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); + Value eventVal = + makeEmitCIntConstant(rewriter, loc, i32Ty, eventIdAttr.getInt()); + Value msgVal = createFFTSMsg(rewriter, loc, eventVal, fftsMode); + + InterCoreSyncCallDesc desc; + desc.callee = "ffts_cross_core_sync"; + desc.args = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, pipeTok), + IntegerAttr::get(IndexType::get(ctx), 0), + }); + desc.operands.push_back(msgVal); + return desc; + } + + InterCoreSyncCallDesc desc; + desc.callee = "set_intra_block"; + desc.args = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, pipeTok), eventIdAttr}); + return desc; +} + +static InterCoreSyncCallDesc buildInterCoreSyncSetCallDyn( + ConversionPatternRewriter &rewriter, Location loc, PTOArch targetArch, + pto::PipeAttr pipeAttr, Value eventIdVal, int64_t fftsMode) { + auto *ctx = rewriter.getContext(); + std::string pipeTok = pipeTokFromPipeAttr(pipeAttr); + Value eventI32 = castInterCoreEventIdToI32(rewriter, loc, eventIdVal); + + if (targetArch == PTOArch::A3) { + Value msgVal = createFFTSMsg(rewriter, loc, eventI32, fftsMode); + + InterCoreSyncCallDesc desc; + desc.callee = "ffts_cross_core_sync"; + desc.args = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, pipeTok), + IntegerAttr::get(IndexType::get(ctx), 0), + }); + desc.operands.push_back(msgVal); + return desc; + } + + InterCoreSyncCallDesc desc; + desc.callee = "set_intra_block"; + desc.args = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, pipeTok), + IntegerAttr::get(IndexType::get(ctx), 0), + }); + desc.operands.push_back(eventI32); + return desc; +} + +static InterCoreSyncCallDesc buildInterCoreSyncWaitCall( + ConversionPatternRewriter &rewriter, PTOArch targetArch, + pto::PipeAttr pipeAttr, IntegerAttr eventIdAttr) { + auto *ctx = rewriter.getContext(); + std::string pipeTok = pipeTokFromPipeAttr(pipeAttr); + + InterCoreSyncCallDesc desc; + if (targetArch == PTOArch::A3) { + desc.callee = "wait_flag_dev"; + desc.args = rewriter.getArrayAttr({eventIdAttr}); + return desc; + } + + desc.callee = "wait_intra_block"; + desc.args = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, pipeTok), eventIdAttr}); + return desc; +} + +static InterCoreSyncCallDesc buildInterCoreSyncWaitCallDyn( + ConversionPatternRewriter &rewriter, Location loc, PTOArch targetArch, + pto::PipeAttr pipeAttr, Value eventIdVal) { + auto *ctx = rewriter.getContext(); + std::string pipeTok = pipeTokFromPipeAttr(pipeAttr); + Value eventI32 = castInterCoreEventIdToI32(rewriter, loc, eventIdVal); + + InterCoreSyncCallDesc desc; + if (targetArch == PTOArch::A3) { + desc.callee = "wait_flag_dev"; + desc.args = rewriter.getArrayAttr({IntegerAttr::get(IndexType::get(ctx), 0)}); + desc.operands.push_back(eventI32); + return desc; + } + + desc.callee = "wait_intra_block"; + desc.args = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, pipeTok), + IntegerAttr::get(IndexType::get(ctx), 0), + }); + desc.operands.push_back(eventI32); + return desc; +} + + + +static FailureOr buildSyncAllWorkspaceTileValue( + ConversionPatternRewriter &rewriter, Location loc, Value originalWorkspace, + Value emittedWorkspace) { + Value workspace = peelUnrealized(emittedWorkspace); + if (auto opaqueTy = dyn_cast(workspace.getType())) { + StringRef typeStr = opaqueTy.getValue(); + if (typeStr.contains("Tile<") || typeStr.contains("ConvTile<")) + return workspace; + } + + auto memTy = dyn_cast(originalWorkspace.getType()); + if (!memTy) + return failure(); + if (!memTy.hasStaticShape()) + return failure(); + + ArrayRef rawShape = memTy.getShape(); + if (rawShape.empty() || rawShape.size() > 2) + return failure(); + + int64_t rows = rawShape.size() == 1 ? 1 : rawShape[0]; + int64_t cols = rawShape.size() == 1 ? rawShape[0] : rawShape[1]; + SmallVector shape{rows, cols}; + SmallVector validShape{rows, cols}; + + auto *ctx = rewriter.getContext(); + pto::TileBufConfigAttr configAttr = pto::TileBufConfigAttr::getDefault(ctx); + if (auto bind = originalWorkspace.getDefiningOp()) { + configAttr = bind.getConfig(); + } else if (auto cast = originalWorkspace.getDefiningOp()) { + if (auto config = cast.getConfig()) + configAttr = *config; + } + + Attribute memorySpace = memTy.getMemorySpace(); + if (!memorySpace) + return failure(); + + auto tileTy = pto::TileBufType::get(ctx, shape, memTy.getElementType(), + memorySpace, validShape, configAttr); + auto tileTypeString = getEmitCTileTypeString(tileTy); + if (!tileTypeString) + return failure(); + + auto tileEmitTy = emitc::OpaqueType::get(ctx, *tileTypeString); + Value tile = rewriter + .create(loc, tileEmitTy, + emitc::OpaqueAttr::get(ctx, "")) + .getResult(); + + Value rawPtr = workspace; + auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); + if (isSetFFTsPointerLikeType(rawPtr.getType())) { + auto rcU64 = + rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); + rawPtr = rewriter + .create(loc, u64Ty, "reinterpret_cast", + ArrayAttr{}, rcU64, + ValueRange{rawPtr}) + .getResult(0); + } else if (rawPtr.getType() != u64Ty) { + rawPtr = rewriter.create(loc, u64Ty, rawPtr).getResult(); + } + + rewriter.create(loc, TypeRange{}, "TASSIGN", + ArrayAttr{}, ArrayAttr{}, + ValueRange{tile, rawPtr}); + return tile; +} + + + +//===----------------------------------------------------------------------===// +// Sync lowering +//===----------------------------------------------------------------------=== + +static constexpr llvm::StringLiteral kAutoSyncTailBarrierAttr = + "pto.auto_sync_tail_barrier"; +static constexpr llvm::StringLiteral kAutoSyncTailHintAttr = + "pto.auto_sync_tail_hint"; +static constexpr llvm::StringLiteral kAutoSyncTailPolicyBarrierAll = + "barrier_all"; +static constexpr llvm::StringLiteral kAutoSyncTailPolicyMte3ToSEvent0 = + "setwait_mte3_to_s_event0"; +static constexpr llvm::StringLiteral kAutoSyncTailModeBarrierAllToken = + "PTOAutoSyncTailMode::kBarrierAll"; +static constexpr llvm::StringLiteral kAutoSyncTailModeMte3ToSEvent0Token = + "PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0"; + +static std::string getAutoSyncTailModeToken(Operation *op) { + if (op) { + if (auto hintAttr = op->getAttrOfType(kAutoSyncTailHintAttr)) { + if (hintAttr.getValue() == kAutoSyncTailPolicyBarrierAll) + return kAutoSyncTailModeBarrierAllToken.str(); + if (hintAttr.getValue() == kAutoSyncTailPolicyMte3ToSEvent0) + return kAutoSyncTailModeMte3ToSEvent0Token.str(); + } + } + + auto func = op ? op->getParentOfType() : func::FuncOp(); + if (!func) + return kAutoSyncTailModeBarrierAllToken.str(); + + auto hintAttr = func->getAttrOfType(kAutoSyncTailHintAttr); + if (!hintAttr) + return kAutoSyncTailModeBarrierAllToken.str(); + + if (hintAttr.getValue() == kAutoSyncTailPolicyBarrierAll) + return kAutoSyncTailModeBarrierAllToken.str(); + if (hintAttr.getValue() == kAutoSyncTailPolicyMte3ToSEvent0) + return kAutoSyncTailModeMte3ToSEvent0Token.str(); + + // Fallback to the conservative behavior when seeing unknown policies. + return kAutoSyncTailModeBarrierAllToken.str(); +} + +[[maybe_unused]] static std::string getPipeName(pto::PIPE pipe) { + switch (pipe) { + case pto::PIPE::PIPE_S: return "PIPE_S"; + case pto::PIPE::PIPE_V: return "PIPE_V"; + case pto::PIPE::PIPE_M: return "PIPE_M"; + case pto::PIPE::PIPE_MTE1: return "PIPE_MTE1"; + case pto::PIPE::PIPE_MTE2: return "PIPE_MTE2"; + case pto::PIPE::PIPE_MTE3: return "PIPE_MTE3"; + case pto::PIPE::PIPE_ALL: return "PIPE_ALL"; + case pto::PIPE::PIPE_MTE4: return "PIPE_MTE4"; + case pto::PIPE::PIPE_MTE5: return "PIPE_MTE5"; + case pto::PIPE::PIPE_V2: return "PIPE_V2"; + case pto::PIPE::PIPE_FIX: return "PIPE_FIX"; + case pto::PIPE::VIRTUAL_PIPE_MTE2_L1A: return "VIRTUAL_PIPE_MTE2_L1A"; + case pto::PIPE::VIRTUAL_PIPE_MTE2_L1B: return "VIRTUAL_PIPE_MTE2_L1B"; + // 默认回退 + default: return "PIPE_ALL"; + } +} + +//===----------------------------------------------------------------------===// +// pto.barrier lowering -> pipe_barrier(...) +//===----------------------------------------------------------------------===// +struct PTOBarrierToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::BarrierOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (op->hasAttr(kAutoSyncTailBarrierAttr)) { + auto modeAttr = rewriter.getStringAttr(getAutoSyncTailModeToken(op)); + if (auto emitcFunc = op->getParentOfType()) { + emitcFunc->setAttr(kAutoSyncTailPendingModeAttr, modeAttr); + } else if (auto funcOp = op->getParentOfType()) { + funcOp->setAttr(kAutoSyncTailPendingModeAttr, modeAttr); + } + rewriter.eraseOp(op); + return success(); + } + + // [FIX] op.getPipe() returns PipeAttr. + // We must call .getPipe() on the attribute to get the actual Enum value. + pto::PIPE pipeEnum = op.getPipe().getPipe(); + + // Convert Enum to String (e.g., PIPE_ALL -> "PIPE_ALL") + std::string pipeStr = pto::stringifyPIPE(pipeEnum).str(); + auto *ctx = rewriter.getContext(); + + auto args = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, pipeStr) + }); + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, // void return + "pipe_barrier", // function name + args, // arguments + ArrayAttr{}, // template args + ValueRange{} // operands + ); + + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Sync lowering (robust for bracket form pto.set_flag[...] / pto.wait_flag[...]) +// Replace your PTOSyncToRuntimeCall with the code below. +//===----------------------------------------------------------------------===// + +static bool tryConvertPipeAttrToToken(Attribute attr, std::string &token) { + if (!attr) + return false; + if (auto pipe = dyn_cast(attr)) { + token = mlir::pto::stringifyPIPE(pipe.getPipe()).str(); + return true; + } + if (auto stringAttr = dyn_cast(attr)) { + token = stringAttr.getValue().str(); + return true; + } + return false; +} + +static bool tryConvertEventAttrToToken(Attribute attr, std::string &token) { + if (!attr) + return false; + if (auto event = dyn_cast(attr)) { + token = mlir::pto::stringifyEVENT(event.getEvent()).str(); + return true; + } + if (auto stringAttr = dyn_cast(attr)) { + token = stringAttr.getValue().str(); + return true; + } + return false; +} + +static bool tryAssignSyncTokens(Attribute srcAttr, Attribute dstAttr, + Attribute evtAttr, std::string &srcTok, + std::string &dstTok, std::string &evtTok) { + std::string localSrc; + std::string localDst; + std::string localEvt; + if (!tryConvertPipeAttrToToken(srcAttr, localSrc) || + !tryConvertPipeAttrToToken(dstAttr, localDst) || + !tryConvertEventAttrToToken(evtAttr, localEvt)) { + return false; + } + srcTok = std::move(localSrc); + dstTok = std::move(localDst); + evtTok = std::move(localEvt); + return true; +} + +static bool tryExtractSyncTokensFromNamedAttrs(Operation *op, + StringRef srcName, + StringRef dstName, + StringRef evtName, + std::string &srcTok, + std::string &dstTok, + std::string &evtTok) { + return tryAssignSyncTokens(op->getAttr(srcName), op->getAttr(dstName), + op->getAttr(evtName), srcTok, dstTok, evtTok); +} + +static bool tryExtractSyncTokensFromArrayAttr(Operation *op, StringRef attrName, + std::string &srcTok, + std::string &dstTok, + std::string &evtTok) { + auto arrayAttr = op->getAttrOfType(attrName); + if (!arrayAttr || arrayAttr.size() < 3) + return false; + return tryAssignSyncTokens(arrayAttr[0], arrayAttr[1], arrayAttr[2], srcTok, + dstTok, evtTok); +} + +static bool tryExtractFallbackSyncTokens(Operation *op, std::string &srcTok, + std::string &dstTok, + std::string &evtTok) { + SmallVector pipes; + std::string event; + for (NamedAttribute namedAttr : op->getAttrs()) { + std::string token; + if (tryConvertPipeAttrToToken(namedAttr.getValue(), token)) { + pipes.push_back(std::move(token)); + continue; + } + if (event.empty() && + tryConvertEventAttrToToken(namedAttr.getValue(), token)) { + event = std::move(token); + } + } + if (pipes.size() < 2 || event.empty()) + return false; + srcTok = pipes[0]; + dstTok = pipes[1]; + evtTok = event; + return true; +} + +static LogicalResult extractSyncTripletTokens(Operation *op, + std::string &srcTok, + std::string &dstTok, + std::string &evtTok, + ConversionPatternRewriter &rewriter) { + if (tryExtractSyncTokensFromNamedAttrs(op, "src_pipe", "dst_pipe", "event_id", + srcTok, dstTok, evtTok) || + tryExtractSyncTokensFromNamedAttrs(op, "srcPipe", "dstPipe", "eventId", + srcTok, dstTok, evtTok) || + tryExtractSyncTokensFromNamedAttrs(op, "src", "dst", "event", srcTok, + dstTok, evtTok)) { + return success(); + } + + for (StringRef attrName : {"args", "pipes", "sync", "triplet", "attrs"}) { + if (tryExtractSyncTokensFromArrayAttr(op, attrName, srcTok, dstTok, + evtTok)) { + return success(); + } + } + + if (tryExtractFallbackSyncTokens(op, srcTok, dstTok, evtTok)) + return success(); + return rewriter.notifyMatchFailure( + op, "cannot extract PIPE/PIPE/EVENT tokens from pto.{set,wait}_flag"); +} +static inline std::string pipeTokFromPipeEnum(mlir::pto::PIPE p) { + return mlir::pto::stringifyPIPE(p).str(); +} +[[maybe_unused]] static inline std::string evtTokFromEventEnum(mlir::pto::EVENT e) { + return mlir::pto::stringifyEVENT(e).str(); +} +static inline std::string pipeTokFromPipeAttr(mlir::pto::PipeAttr a) { + return mlir::pto::stringifyPIPE(a.getPipe()).str(); +} +static inline std::string evtTokFromEventAttr(mlir::pto::EventAttr a) { + return mlir::pto::stringifyEVENT(a.getEvent()).str(); +} + +template +struct HasGetSrcPipe : std::false_type {}; +template +struct HasGetSrcPipe().getSrcPipe())>> : std::true_type {}; + +template +struct HasGetDstPipe : std::false_type {}; +template +struct HasGetDstPipe().getDstPipe())>> : std::true_type {}; + +template +struct HasGetEventId : std::false_type {}; +template +struct HasGetEventId().getEventId())>> : std::true_type {}; + +template +struct HasGetSrcPipeAttr : std::false_type {}; +template +struct HasGetSrcPipeAttr().getSrcPipeAttr())>> : std::true_type {}; + +template +struct HasGetDstPipeAttr : std::false_type {}; +template +struct HasGetDstPipeAttr().getDstPipeAttr())>> : std::true_type {}; + +template +struct HasGetEventIdAttr : std::false_type {}; +template +struct HasGetEventIdAttr().getEventIdAttr())>> : std::true_type {}; + +template +static LogicalResult extractSyncTokens(SyncOpT op, + std::string &srcTok, + std::string &dstTok, + std::string &evtTok, + ConversionPatternRewriter &rewriter) { + if constexpr (HasGetSrcPipe::value && + HasGetDstPipe::value && + HasGetEventId::value) { + auto s = op.getSrcPipe(); + auto d = op.getDstPipe(); + auto e = op.getEventId(); + + if constexpr (std::is_same::value) srcTok = pipeTokFromPipeEnum(s); + else srcTok = pipeTokFromPipeAttr(s); + + if constexpr (std::is_same::value) dstTok = pipeTokFromPipeEnum(d); + else dstTok = pipeTokFromPipeAttr(d); + + if constexpr (std::is_same::value) evtTok = evtTokFromEventEnum(e); + else evtTok = evtTokFromEventAttr(e); + + return success(); + } + + if constexpr (HasGetSrcPipeAttr::value && + HasGetDstPipeAttr::value && + HasGetEventIdAttr::value) { + auto s = op.getSrcPipeAttr(); + auto d = op.getDstPipeAttr(); + auto e = op.getEventIdAttr(); + srcTok = pipeTokFromPipeAttr(s); + dstTok = pipeTokFromPipeAttr(d); + evtTok = evtTokFromEventAttr(e); + return success(); + } + + return extractSyncTripletTokens(op.getOperation(), srcTok, dstTok, evtTok, rewriter); +} +struct PTOSetFlagToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::SetFlagOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + auto *ctx = rewriter.getContext(); + + std::string srcTok, dstTok, evtTok; + if (failed(extractSyncTokens(op, srcTok, dstTok, evtTok, rewriter))) + return failure(); + + auto argsAttr = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, srcTok), + emitc::OpaqueAttr::get(ctx, dstTok), + emitc::OpaqueAttr::get(ctx, evtTok), + }); + + rewriter.replaceOpWithNewOp( + op, TypeRange{}, "set_flag", + /*args=*/argsAttr, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{}); + return success(); + } +}; + +struct PTOWaitFlagToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::WaitFlagOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + auto *ctx = rewriter.getContext(); + + std::string srcTok, dstTok, evtTok; + if (failed(extractSyncTokens(op, srcTok, dstTok, evtTok, rewriter))) + return failure(); + + auto argsAttr = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, srcTok), + emitc::OpaqueAttr::get(ctx, dstTok), + emitc::OpaqueAttr::get(ctx, evtTok), + }); + + rewriter.replaceOpWithNewOp( + op, TypeRange{}, "wait_flag", + /*args=*/argsAttr, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{}); + return success(); + } +}; + +struct PTOSyncToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::TSyncOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector operands; + operands.reserve(adaptor.getEvents().size()); + for (Value event : adaptor.getEvents()) + operands.push_back(peelUnrealized(event)); + + rewriter.create( + op.getLoc(), TypeRange{}, "TSYNC", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange(operands)); + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOSyncAllToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + static StringRef coreTypeTok(pto::SyncCoreType coreType) { + switch (coreType) { + case pto::SyncCoreType::AIVOnly: + return "SyncCoreType::AIVOnly"; + case pto::SyncCoreType::AICOnly: + return "SyncCoreType::AICOnly"; + case pto::SyncCoreType::Mix: + return "SyncCoreType::Mix"; + } + llvm_unreachable("unhandled SyncCoreType"); + } + + LogicalResult matchAndRewrite(mlir::pto::SyncAllOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto mode = op.getMode().getValue(); + auto coreType = op.getCoreType().getValue(); + + auto buildGmWorkspace = [&]() -> FailureOr { + Value gm = peelUnrealized(adaptor.getGmWorkspace()); + if (isEmitCGlobalTensorLikeType(gm.getType())) + return gm; + + auto memTy = dyn_cast(op.getGmWorkspace().getType()); + if (!memTy) + return failure(); + + Value gt = buildGlobalTensorFromMemref(rewriter, op.getLoc(), gm, memTy, + op.getGmWorkspace().getDefiningOp() + ? op.getGmWorkspace().getDefiningOp() + : op.getOperation()); + if (!gt) + return failure(); + return gt; + }; + + if (mode == pto::SyncAllMode::Hard) { + std::string callee = "SYNCALL<" + coreTypeTok(coreType).str() + ">"; + rewriter.create(op.getLoc(), TypeRange{}, callee, + ArrayAttr{}, ArrayAttr{}, + ValueRange{}); + rewriter.eraseOp(op); + return success(); + } + + FailureOr gmWorkspace = buildGmWorkspace(); + if (failed(gmWorkspace)) + return rewriter.notifyMatchFailure(op, + "failed to build gm_workspace GlobalTensor"); + + auto i32Ty = emitc::OpaqueType::get(rewriter.getContext(), "int32_t"); + Value usedCores = adaptor.getUsedCores() + ? peelUnrealized(adaptor.getUsedCores()) + : makeEmitCIntConstant(rewriter, op.getLoc(), i32Ty, 0); + if (usedCores.getType() != i32Ty) + usedCores = rewriter.create(op.getLoc(), i32Ty, usedCores) + .getResult(); + + std::string callee = + "SYNCALL"; + + SmallVector operands{*gmWorkspace}; + switch (coreType) { + case pto::SyncCoreType::AIVOnly: { + FailureOr ubWorkspace = + buildSyncAllWorkspaceTileValue(rewriter, op.getLoc(), + op.getUbWorkspace(), + adaptor.getUbWorkspace()); + if (failed(ubWorkspace)) + return rewriter.notifyMatchFailure( + op, "failed to materialize ub_workspace tile"); + operands.push_back(*ubWorkspace); + break; + } + case pto::SyncCoreType::AICOnly: { + FailureOr l1Workspace = + buildSyncAllWorkspaceTileValue(rewriter, op.getLoc(), + op.getL1Workspace(), + adaptor.getL1Workspace()); + if (failed(l1Workspace)) + return rewriter.notifyMatchFailure( + op, "failed to materialize l1_workspace tile"); + operands.push_back(*l1Workspace); + break; + } + case pto::SyncCoreType::Mix: { + FailureOr ubWorkspace = + buildSyncAllWorkspaceTileValue(rewriter, op.getLoc(), + op.getUbWorkspace(), + adaptor.getUbWorkspace()); + FailureOr l1Workspace = + buildSyncAllWorkspaceTileValue(rewriter, op.getLoc(), + op.getL1Workspace(), + adaptor.getL1Workspace()); + if (failed(ubWorkspace) || failed(l1Workspace)) + return rewriter.notifyMatchFailure( + op, "failed to materialize mixed syncall workspace tiles"); + operands.push_back(*ubWorkspace); + operands.push_back(*l1Workspace); + break; + } + } + + operands.push_back(usedCores); + rewriter.create(op.getLoc(), TypeRange{}, callee, + ArrayAttr{}, ArrayAttr{}, + ValueRange(operands)); + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOSyncFlagDynToEmitC : public ConversionPattern { + PTOSyncFlagDynToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + StringRef opName, StringRef callee) + : ConversionPattern(typeConverter, opName, /*benefit=*/1, ctx), + callee(callee.str()) {} + + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + if (operands.size() != 1) + return rewriter.notifyMatchFailure(op, "expected exactly one dynamic event-id operand"); + + auto srcAttr = op->getAttrOfType("src_pipe"); + auto dstAttr = op->getAttrOfType("dst_pipe"); + if (!srcAttr || !dstAttr) + return rewriter.notifyMatchFailure(op, "missing PipeAttr src_pipe/dst_pipe attrs"); + + auto *ctx = rewriter.getContext(); + std::string srcTok = pipeTokFromPipeAttr(srcAttr); + std::string dstTok = pipeTokFromPipeAttr(dstAttr); + + Value eventVal = operands.front(); + eventVal = + emitCCast(rewriter, op->getLoc(), emitc::OpaqueType::get(ctx, "event_t"), eventVal); + + auto argsAttr = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, srcTok), + emitc::OpaqueAttr::get(ctx, dstTok), + IntegerAttr::get(IndexType::get(ctx), 0), + }); + + rewriter.replaceOpWithNewOp( + op, TypeRange{}, callee, + /*args=*/argsAttr, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{eventVal}); + return success(); + } + +private: + std::string callee; +}; + +struct PTOGetBufToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::GetBufOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + auto *ctx = rewriter.getContext(); + + auto opTypeOr = parseSyncOpTypeLikeAttr(op.getOpTypeAttr()); + if (failed(opTypeOr)) + return rewriter.notifyMatchFailure(op, "get_buf expects pipe_event_type/sync_op_type attr"); + auto pipe = mapSyncOpTypeToPipe(*opTypeOr); + if (!isConcreteSyncPipe(pipe)) + return rewriter.notifyMatchFailure(op, "get_buf op_type cannot map to a concrete pipe"); + std::string pipeTok = pipeTokFromPipeEnum(pipe); + auto argsAttr = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, pipeTok), + op.getBufIdAttr(), + op.getModeAttr(), + }); + + rewriter.replaceOpWithNewOp( + op, TypeRange{}, "get_buf", + /*args=*/argsAttr, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{}); + return success(); + } +}; + +struct PTORlsBufToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::RlsBufOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + auto *ctx = rewriter.getContext(); + + auto opTypeOr = parseSyncOpTypeLikeAttr(op.getOpTypeAttr()); + if (failed(opTypeOr)) + return rewriter.notifyMatchFailure(op, "rls_buf expects pipe_event_type/sync_op_type attr"); + auto pipe = mapSyncOpTypeToPipe(*opTypeOr); + if (!isConcreteSyncPipe(pipe)) + return rewriter.notifyMatchFailure(op, "rls_buf op_type cannot map to a concrete pipe"); + std::string pipeTok = pipeTokFromPipeEnum(pipe); + auto argsAttr = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, pipeTok), + op.getBufIdAttr(), + op.getModeAttr(), + }); + + rewriter.replaceOpWithNewOp( + op, TypeRange{}, "rls_buf", + /*args=*/argsAttr, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{}); + return success(); + } +}; + +struct PTOSetFFTsToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::SetFFTsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto *ctx = rewriter.getContext(); + auto loc = op.getLoc(); + + Value fftsAddr = peelUnrealized(adaptor.getFfts()); + auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); + + if (isSetFFTsPointerLikeType(fftsAddr.getType())) { + auto castTyAttr = + rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); + fftsAddr = + rewriter + .create(loc, u64Ty, "reinterpret_cast", + /*args=*/ArrayAttr{}, + /*templateArgs=*/castTyAttr, + /*operands=*/ValueRange{fftsAddr}) + .getResult(0); + } else if (fftsAddr.getType() != u64Ty) { + fftsAddr = + rewriter.create(loc, u64Ty, fftsAddr).getResult(); + } + + rewriter.replaceOpWithNewOp( + op, TypeRange{}, "set_ffts_base_addr", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{fftsAddr}); + return success(); + } +}; + +struct PTOSyncSetToEmitC : public OpConversionPattern { + PTOSyncSetToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + PTOArch targetArch) + : OpConversionPattern(typeConverter, ctx), + targetArch(targetArch) {} + + LogicalResult + matchAndRewrite(mlir::pto::SyncSetOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto *ctx = rewriter.getContext(); + IntegerAttr eventIdAttr = op.getEventIdAttr(); + Value eventIdDyn = adaptor.getEventIdDyn(); + int64_t fftsMode = 2; + if (IntegerAttr fftsModeAttr = op.getFftsModeAttr()) + fftsMode = fftsModeAttr.getInt(); + + if ((eventIdAttr != nullptr) == static_cast(eventIdDyn)) + return rewriter.notifyMatchFailure( + op, "expects exactly one of static event_id attr or dynamic event_id operand"); + + // A5 inter-core sync mirrors +16 only for cube-side producer (PIPE_FIX). + // Vec-side producer (PIPE_MTE3) emits a single set; hardware handles the + // subblock mapping in PTO-ISA custom flow. + if (targetArch == PTOArch::A5) { + pto::PIPE pipe = op.getPipe().getPipe(); + bool needsMirrorPlus16 = (pipe == pto::PIPE::PIPE_FIX); + std::string pipeTok = pipeTokFromPipeAttr(op.getPipe()); + auto emitSet = [&](Value eventOperand, IntegerAttr eventLiteral, + bool isDynamic) { + if (isDynamic) { + auto args = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, pipeTok), + IntegerAttr::get(IndexType::get(ctx), 0), + }); + rewriter.create(loc, TypeRange{}, "set_intra_block", + /*args=*/args, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{eventOperand}); + return; + } + auto args = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, pipeTok), + eventLiteral, + }); + rewriter.create(loc, TypeRange{}, "set_intra_block", + /*args=*/args, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{}); + }; + + if (eventIdAttr) { + emitSet(Value{}, eventIdAttr, /*isDynamic=*/false); + if (needsMirrorPlus16) { + auto plus16 = IntegerAttr::get(eventIdAttr.getType(), + eventIdAttr.getInt() + 16); + emitSet(Value{}, plus16, /*isDynamic=*/false); + } + } else { + Value eventI32 = castInterCoreEventIdToI32(rewriter, loc, eventIdDyn); + emitSet(eventI32, IntegerAttr{}, /*isDynamic=*/true); + if (needsMirrorPlus16) { + auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); + Value c16 = makeEmitCIntConstant(rewriter, loc, i32Ty, 16); + Value eventI32Plus16 = + rewriter.create(loc, i32Ty, eventI32, c16).getResult(); + emitSet(eventI32Plus16, IntegerAttr{}, /*isDynamic=*/true); + } + } + + rewriter.eraseOp(op); + return success(); + } + + InterCoreSyncCallDesc desc; + if (eventIdAttr) { + desc = buildInterCoreSyncSetCall(rewriter, loc, targetArch, op.getPipe(), + eventIdAttr, fftsMode); + } else { + desc = buildInterCoreSyncSetCallDyn(rewriter, loc, targetArch, op.getPipe(), + eventIdDyn, fftsMode); + } + rewriter.create(loc, TypeRange{}, desc.callee, + /*args=*/desc.args, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/desc.operands); + + rewriter.eraseOp(op); + return success(); + } + + PTOArch targetArch; +}; + +struct PTOSyncWaitToEmitC : public OpConversionPattern { + PTOSyncWaitToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, + PTOArch targetArch) + : OpConversionPattern(typeConverter, ctx), + targetArch(targetArch) {} + + LogicalResult + matchAndRewrite(mlir::pto::SyncWaitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + IntegerAttr eventIdAttr = op.getEventIdAttr(); + Value eventIdDyn = adaptor.getEventIdDyn(); + + if ((eventIdAttr != nullptr) == static_cast(eventIdDyn)) + return rewriter.notifyMatchFailure( + op, "expects exactly one of static event_id attr or dynamic event_id operand"); + + InterCoreSyncCallDesc desc; + if (eventIdAttr) { + desc = buildInterCoreSyncWaitCall(rewriter, targetArch, op.getPipe(), + eventIdAttr); + } else { + desc = buildInterCoreSyncWaitCallDyn(rewriter, loc, targetArch, op.getPipe(), + eventIdDyn); + } + rewriter.create(loc, TypeRange{}, desc.callee, + desc.args, ArrayAttr{}, desc.operands); + + rewriter.eraseOp(op); + return success(); + } + + PTOArch targetArch; +}; + + +} // namespace + +void populatePTOToEmitCSyncPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter, + MLIRContext *ctx, PTOArch targetArch) { + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx, "pto.set_flag_dyn", + "set_flag"); + patterns.add(typeConverter, ctx, "pto.wait_flag_dyn", + "wait_flag"); + patterns.add(typeConverter, ctx, "pto.set_flag_d", + "set_flag"); + patterns.add(typeConverter, ctx, "pto.wait_flag_d", + "wait_flag"); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx, targetArch); + patterns.add(typeConverter, ctx, targetArch); +} + +} // namespace mlir::pto diff --git a/lib/PTO/Transforms/PTOToEmitCTileMaterialization.cpp b/lib/PTO/Transforms/PTOToEmitCTileMaterialization.cpp new file mode 100644 index 000000000..5fedd725c --- /dev/null +++ b/lib/PTO/Transforms/PTOToEmitCTileMaterialization.cpp @@ -0,0 +1,923 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- PTOToEmitCTileMaterialization.cpp ----------------------------------===// +//===----------------------------------------------------------------------===// + +#include "PTOToEmitCInternal.h" + +#include "PTO/IR/PTO.h" +#include "PTO/IR/PTOTypeUtils.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" + +#include + +using namespace mlir; +using namespace mlir::pto; + +namespace mlir::pto { +namespace { + +static constexpr llvm::StringLiteral kForceDynamicValidShapeAttrName = + "__pto.force_dynamic_valid_shape"; + +// ============================================================================= +// 2. BindTileOp Lowering (FIX: Trace back to physical address) +// ============================================================================= +struct PTOBindTileToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + struct TileBuildSpec { + std::string tileTypeStr; + bool useConstructor = false; + SmallVector constructorArgs; + }; + + static bool getIndexConst(Value v, int64_t &out) { + if (!v) + return false; + if (auto cst = v.getDefiningOp()) { + if (auto ia = dyn_cast(cst.getValue())) { + out = ia.getValue().getSExtValue(); + return true; + } + } + return false; + } + + static bool getTilePointerStrides(pto::TileBufConfigAttr configAttr, + Type elemTy, int64_t rows, int64_t cols, + int64_t &rowStride, + int64_t &colStride) { + if (rows == ShapedType::kDynamic || cols == ShapedType::kDynamic) + return false; + + int32_t blVal = 0; + if (auto blAttr = dyn_cast(configAttr.getBLayout())) + blVal = static_cast(blAttr.getValue()); + else if (auto intAttr = dyn_cast(configAttr.getBLayout())) + blVal = static_cast(intAttr.getInt()); + + int32_t slVal = 0; + if (auto slAttr = dyn_cast(configAttr.getSLayout())) + slVal = static_cast(slAttr.getValue()); + else if (auto intAttr = dyn_cast(configAttr.getSLayout())) + slVal = static_cast(intAttr.getInt()); + + bool boxed = slVal != 0; + int64_t innerRows = 1; + int64_t innerCols = 1; + if (boxed) { + int32_t fractal = 512; + if (auto frAttr = dyn_cast(configAttr.getSFractalSize())) + fractal = static_cast(frAttr.getInt()); + + unsigned elemBytes = pto::getPTOStorageElemByteSize(elemTy); + if (elemBytes == 0) + return false; + + switch (fractal) { + case 1024: + innerRows = 16; + innerCols = 16; + break; + case 32: + innerRows = 16; + innerCols = 2; + break; + case 512: + if (slVal == 1) { + innerRows = 16; + innerCols = 32 / elemBytes; + } else if (slVal == 2) { + innerRows = 32 / elemBytes; + innerCols = 16; + } else { + return false; + } + break; + default: + return false; + } + if (innerRows <= 0 || innerCols <= 0) + return false; + } + + if (!boxed) { + if (blVal == 1) { + rowStride = 1; + colStride = rows; + } else { + rowStride = cols; + colStride = 1; + } + return true; + } + + if (blVal == 1) { + if (slVal != 1) + return false; + rowStride = innerCols; + colStride = rows; + return true; + } + + rowStride = cols; + colStride = innerRows; + return true; + } + + LogicalResult matchAndRewrite(pto::BindTileOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + auto configAttr = op.getConfigAttr(); + auto viewSemantics = op->getAttrOfType("pto.view_semantics"); + bool isSubView = viewSemantics && viewSemantics.getValue() == "subview"; + + auto peelAllCasts = [](Value v) { + while (auto castOp = v.getDefiningOp()) + v = castOp.getOperand(0); + if (auto castOp = v.getDefiningOp()) + v = castOp.getOperand(); + return v; + }; + auto isTileLike = [](Value v) -> bool { + auto ot = dyn_cast(v.getType()); + if (!ot) + return false; + StringRef s = ot.getValue(); + return s.contains("Tile<") || s.contains("ConvTile<"); + }; + auto buildTileSpec = [&]() -> FailureOr { + auto resMrTy = dyn_cast(op.getType()); + if (!resMrTy) + return failure(); + + const char *roleTok = "TileType::Vec"; + if (auto asAttr = + dyn_cast_or_null(resMrTy.getMemorySpace())) { + switch (asAttr.getAddressSpace()) { + case pto::AddressSpace::VEC: + roleTok = "TileType::Vec"; + break; + case pto::AddressSpace::MAT: + roleTok = "TileType::Mat"; + break; + case pto::AddressSpace::LEFT: + roleTok = "TileType::Left"; + break; + case pto::AddressSpace::RIGHT: + roleTok = "TileType::Right"; + break; + case pto::AddressSpace::ACC: + roleTok = "TileType::Acc"; + break; + case pto::AddressSpace::BIAS: + roleTok = "TileType::Bias"; + break; + case pto::AddressSpace::SCALING: + roleTok = "TileType::Scaling"; + break; + case pto::AddressSpace::GM: + case pto::AddressSpace::Zero: + roleTok = "TileType::Vec"; + break; + } + } + + Type elemTy = resMrTy.getElementType(); + Type emitElemTy = getTypeConverter()->convertType(elemTy); + if (!emitElemTy) + return failure(); + auto emitElemOpaque = dyn_cast(emitElemTy); + if (!emitElemOpaque) + return failure(); + std::string elemTypeStr = emitElemOpaque.getValue().str(); + + if (resMrTy.getRank() < 2) + return failure(); + int64_t rows = resMrTy.getDimSize(0); + int64_t cols = resMrTy.getDimSize(1); + if (rows == ShapedType::kDynamic || cols == ShapedType::kDynamic) + return failure(); + + std::string blTok = "BLayout::RowMajor"; + if (auto blAttr = dyn_cast(configAttr.getBLayout())) { + if (static_cast(blAttr.getValue()) == 1) + blTok = "BLayout::ColMajor"; + } + pto::BLayout blayout = getTileBufBLayoutValue(configAttr); + + if (isSubView) { + auto subMrTy = dyn_cast(op.getSource().getType()); + auto subViewOp = op.getSource().getDefiningOp(); + if (subMrTy && subMrTy.getRank() >= 2 && subViewOp) { + int64_t subRows = subMrTy.getDimSize(0); + int64_t subCols = subMrTy.getDimSize(1); + SmallVector inheritedStrides; + int64_t inheritedOffset = ShapedType::kDynamic; + + if (!pto::isPTOFloat4PackedType(elemTy) && + subRows != ShapedType::kDynamic && + subCols != ShapedType::kDynamic && + succeeded(getStridesAndOffset(subMrTy, inheritedStrides, + inheritedOffset)) && + inheritedStrides.size() >= 2) { + int64_t childRowStride = 0; + int64_t childColStride = 0; + bool sameStrides = getTilePointerStrides( + configAttr, elemTy, subRows, subCols, childRowStride, + childColStride); + sameStrides = sameStrides && + inheritedStrides[0] == childRowStride && + inheritedStrides[1] == childColStride; + if (sameStrides) { + rows = subRows; + cols = subCols; + } + } + } + } + + std::string slTok = "SLayout::NoneBox"; + if (auto slAttr = dyn_cast(configAttr.getSLayout())) { + int32_t slVal = static_cast(slAttr.getValue()); + slTok = (slVal == 1) ? "SLayout::RowMajor" + : (slVal == 2) ? "SLayout::ColMajor" + : "SLayout::NoneBox"; + } + + int32_t fractal = 512; + if (auto frAttr = dyn_cast(configAttr.getSFractalSize())) + fractal = frAttr.getInt(); + + std::string padTok = "PadValue::Null"; + if (auto padAttr = dyn_cast(configAttr.getPad())) { + switch (static_cast(padAttr.getValue())) { + case 1: + padTok = "PadValue::Zero"; + break; + case 2: + padTok = "PadValue::Max"; + break; + case 3: + padTok = "PadValue::Min"; + break; + default: + padTok = "PadValue::Null"; + break; + } + } + + std::string compactTok = "CompactMode::Null"; + if (auto compactAttr = dyn_cast(configAttr.getCompactMode())) { + switch (static_cast(compactAttr.getValue())) { + case 1: + compactTok = "CompactMode::Normal"; + break; + case 2: + compactTok = "CompactMode::RowPlusOne"; + break; + default: + compactTok = "CompactMode::Null"; + break; + } + } + + std::string vrowTok, vcolTok; + bool useConstructor = false; + bool rowIsDynamic = false; + bool colIsDynamic = false; + SmallVector constructorArgs; + + Value vRow = op.getValidRow(); + Value vCol = op.getValidCol(); + Value vRowEmitC = adaptor.getValidRow(); + Value vColEmitC = adaptor.getValidCol(); + bool forceDynamicValid = op->hasAttr(kForceDynamicValidShapeAttrName); + int64_t cRow = 0, cCol = 0; + bool rowIsConst = vRow && getIndexConst(vRow, cRow); + bool colIsConst = vCol && getIndexConst(vCol, cCol); + + auto makeCtorDimValue = [&](Value emitted, int64_t fallback) -> Value { + if (emitted) + return emitted; + return makeEmitCIntConstant( + rewriter, loc, emitc::OpaqueType::get(ctx, "int32_t"), fallback); + }; + auto maybeScaleDynamicValid = [&](Value emitted, int dimIdx) -> Value { + if (!emitted || !pto::isPTOFloat4PackedType(elemTy)) + return emitted; + int packedDim = blayout == pto::BLayout::ColMajor ? 0 : 1; + if (dimIdx != packedDim) + return emitted; + auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); + Value two = makeEmitCIntConstant(rewriter, loc, i32Ty, 2); + return rewriter.create(loc, i32Ty, emitted, two).getResult(); + }; + + if (forceDynamicValid) { + vrowTok = "-1"; + vcolTok = "-1"; + useConstructor = true; + constructorArgs.push_back( + makeCtorDimValue(maybeScaleDynamicValid(vRowEmitC, 0), + renderTileTemplateDim(rowIsConst ? cRow : rows, + elemTy, blayout, 0))); + constructorArgs.push_back( + makeCtorDimValue(maybeScaleDynamicValid(vColEmitC, 1), + renderTileTemplateDim(colIsConst ? cCol : cols, + elemTy, blayout, 1))); + } else { + if (rowIsConst) { + vrowTok = std::to_string( + renderTileTemplateDim(cRow, elemTy, blayout, 0)); + } else if (vRow) { + vrowTok = "-1"; + rowIsDynamic = true; + useConstructor = true; + } else { + vrowTok = std::to_string( + renderTileTemplateDim(rows, elemTy, blayout, 0)); + } + + if (colIsConst) { + vcolTok = std::to_string( + renderTileTemplateDim(cCol, elemTy, blayout, 1)); + } else if (vCol) { + vcolTok = "-1"; + colIsDynamic = true; + useConstructor = true; + } else { + vcolTok = std::to_string( + renderTileTemplateDim(cols, elemTy, blayout, 1)); + } + + if (useConstructor) { + if (rowIsDynamic && vRowEmitC) + constructorArgs.push_back(maybeScaleDynamicValid(vRowEmitC, 0)); + if (colIsDynamic && vColEmitC) + constructorArgs.push_back(maybeScaleDynamicValid(vColEmitC, 1)); + } + } + + std::string tileTypeStr = std::string("Tile<") + roleTok + ", " + + elemTypeStr + ", " + + std::to_string(renderTileTemplateDim( + rows, elemTy, blayout, 0)) + + ", " + + std::to_string(renderTileTemplateDim( + cols, elemTy, blayout, 1)) + + ", " + blTok + + ", " + vrowTok + ", " + vcolTok + ", " + slTok + + ", " + std::to_string(fractal) + ", " + padTok + + ", " + compactTok + + ">"; + return TileBuildSpec{tileTypeStr, useConstructor, constructorArgs}; + }; + + auto buildTileValue = [&](const TileBuildSpec &spec, + bool forceDeclaration = false) -> Value { + auto tileType = emitc::OpaqueType::get(ctx, spec.tileTypeStr); + if (spec.useConstructor && !forceDeclaration) { + return rewriter + .create(loc, tileType, spec.tileTypeStr, + ArrayAttr{}, ArrayAttr{}, + ValueRange(spec.constructorArgs)) + .getResult(0); + } + + return rewriter + .create(loc, tileType, emitc::OpaqueAttr::get(ctx, "")) + .getResult(); + }; + + auto emitElemTypeToString = [&](Type elemTy) -> std::string { + return getEmitCScalarTypeToken(elemTy); + }; + + auto buildIntegralAddress = [&](Value sourceValue) -> FailureOr { + auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); + auto rcU64 = + rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); + + Value rawPtr = sourceValue; + if (auto ot = dyn_cast(sourceValue.getType())) { + StringRef tyStr = ot.getValue(); + if (tyStr.contains("Tile<") || tyStr.contains("ConvTile<")) { + auto srcMrTy = dyn_cast(op.getSource().getType()); + if (!srcMrTy) + return failure(); + std::string elemTok = emitElemTypeToString(srcMrTy.getElementType()); + pto::AddressSpace as = pto::AddressSpace::GM; + if (auto asAttr = + dyn_cast_or_null(srcMrTy.getMemorySpace())) + as = asAttr.getAddressSpace(); + rawPtr = materializeTileDataValue(rewriter, loc, sourceValue, as, + elemTok); + } + } + + if (isSetFFTsPointerLikeType(rawPtr.getType())) { + return rewriter + .create(loc, u64Ty, "reinterpret_cast", + ArrayAttr{}, rcU64, ValueRange{rawPtr}) + .getResult(0); + } + + if (rawPtr.getType() == u64Ty) + return rawPtr; + return rewriter.create(loc, u64Ty, rawPtr).getResult(); + }; + + if (op.getSource().getDefiningOp()) { + FailureOr tileSpec = buildTileSpec(); + if (failed(tileSpec)) + return failure(); + rewriter.replaceOp(op, buildTileValue(*tileSpec)); + return success(); + } + + Value tileCandidate = peelAllCasts(adaptor.getSource()); + if (viewSemantics && viewSemantics.getValue() == "bitcast" && + isTileLike(tileCandidate)) { + FailureOr tileSpec = buildTileSpec(); + if (failed(tileSpec)) + return failure(); + Value dstTile = buildTileValue(*tileSpec); + FailureOr addr = buildIntegralAddress(tileCandidate); + if (failed(addr)) + return failure(); + + rewriter.create(loc, TypeRange{}, "TASSIGN", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dstTile, *addr}); + rewriter.replaceOp(op, dstTile); + return success(); + } + + if (viewSemantics && viewSemantics.getValue() == "treshape" && + isTileLike(tileCandidate)) { + FailureOr tileSpec = buildTileSpec(); + if (failed(tileSpec)) + return failure(); + Value dstTile = buildTileValue(*tileSpec, /*forceDeclaration=*/true); + + rewriter.create(loc, TypeRange{}, "TRESHAPE", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dstTile, tileCandidate}); + rewriter.replaceOp(op, dstTile); + return success(); + } + + // Subview origins are kept distinct from generic tile rebinding: + // even when source/destination C++ tile types match, subview may carry + // shifted base address semantics and should materialize a fresh handle. + if (isSubView) { + FailureOr tileSpec = buildTileSpec(); + if (failed(tileSpec)) + return failure(); + Value dstTile = buildTileValue(*tileSpec); + FailureOr addr = buildIntegralAddress(tileCandidate); + if (failed(addr)) + return failure(); + + rewriter.create(loc, TypeRange{}, "TASSIGN", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dstTile, *addr}); + rewriter.replaceOp(op, dstTile); + return success(); + } + + // Generic tile-to-tile rebind path: preserve the same backing storage and + // rebuild a sibling tile with updated metadata/valid dims. + if (isTileLike(tileCandidate)) { + FailureOr tileSpec = buildTileSpec(); + if (failed(tileSpec)) + return failure(); + + if (!tileSpec->useConstructor) { + if (auto srcTy = dyn_cast(tileCandidate.getType())) { + if (srcTy.getValue() == tileSpec->tileTypeStr) { + rewriter.replaceOp(op, tileCandidate); + return success(); + } + } + } + + Value dstTile = buildTileValue(*tileSpec); + FailureOr addr = buildIntegralAddress(tileCandidate); + if (failed(addr)) + return failure(); + + rewriter.create(loc, TypeRange{}, "TASSIGN", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dstTile, *addr}); + rewriter.replaceOp(op, dstTile); + return success(); + } + + SmallVector physAddrs; + Value source = op.getSource(); + + while (auto castOp = source.getDefiningOp()) + source = castOp.getOperand(0); + + if (auto upstreamCast = source.getDefiningOp()) { + auto upstreamOperands = upstreamCast.getAddrs(); + physAddrs.append(upstreamOperands.begin(), upstreamOperands.end()); + } else { + physAddrs.push_back(adaptor.getSource()); + } + + Value vRow = op.getValidRow(); + Value vCol = op.getValidCol(); + + auto newCast = rewriter.create( + loc, op.getType(), physAddrs, vRow ? vRow : Value(), + vCol ? vCol : Value(), configAttr); + if (viewSemantics) + newCast->setAttr("pto.view_semantics", viewSemantics); + if (op->hasAttr(kForceDynamicValidShapeAttrName)) + newCast->setAttr(kForceDynamicValidShapeAttrName, + op->getAttr(kForceDynamicValidShapeAttrName)); + rewriter.replaceOp(op, newCast.getResult()); + + return success(); + } +}; + +struct PTOAllocTileToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::AllocTileOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + MLIRContext *ctx = rewriter.getContext(); + auto tileTy = cast(op.getResult().getType()); + auto tileTypeString = getEmitCTileTypeString(tileTy); + if (!tileTypeString) + return rewriter.notifyMatchFailure( + op, "only rank-2 alloc_tile handles can be converted to EmitC"); + + Type convertedTy = getTypeConverter()->convertType(tileTy); + if (!convertedTy) + convertedTy = emitc::OpaqueType::get(ctx, *tileTypeString); + + auto validShape = tileTy.getValidShape(); + bool hasDynamicValidDim = + llvm::any_of(validShape, [](int64_t dim) { return dim < 0; }); + bool useConstructor = hasDynamicValidDim; + + SmallVector constructorArgs; + if (useConstructor) { + Type elemTy = tileTy.getElementType(); + pto::BLayout blayout = getTileBufBLayoutValue(tileTy.getConfigAttr()); + auto maybeScaleDynamicValid = [&](Value emitted, int dimIdx) -> Value { + if (!emitted || !pto::isPTOFloat4PackedType(elemTy)) + return emitted; + int packedDim = blayout == pto::BLayout::ColMajor ? 0 : 1; + if (dimIdx != packedDim) + return emitted; + auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); + Value two = makeEmitCIntConstant(rewriter, loc, i32Ty, 2); + return rewriter.create(loc, i32Ty, emitted, two) + .getResult(); + }; + + if (validShape.size() > 0 && validShape[0] < 0) { + Value validRow = adaptor.getValidRow(); + if (!validRow) + return rewriter.notifyMatchFailure( + op, "dynamic alloc_tile valid row must have an operand"); + if (validRow) + validRow = peelUnrealized(validRow); + constructorArgs.push_back(maybeScaleDynamicValid(validRow, 0)); + } + if (validShape.size() > 1 && validShape[1] < 0) { + Value validCol = adaptor.getValidCol(); + if (!validCol) + return rewriter.notifyMatchFailure( + op, "dynamic alloc_tile valid col must have an operand"); + if (validCol) + validCol = peelUnrealized(validCol); + constructorArgs.push_back(maybeScaleDynamicValid(validCol, 1)); + } + } + + Value tile; + if (useConstructor) { + tile = rewriter + .create( + loc, convertedTy, *tileTypeString, ArrayAttr{}, + ArrayAttr{}, ValueRange(constructorArgs)) + .getResult(0); + } else { + tile = + rewriter + .create( + loc, convertedTy, emitc::OpaqueAttr::get(ctx, "")) + .getResult(); + } + + Value addr = adaptor.getAddr(); + if (addr) { + addr = peelUnrealized(addr); + auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); + if (isa(addr.getType()) || + (isa(addr.getType()) && + cast(addr.getType()).getValue().ends_with("*"))) { + auto rcU64 = + rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); + addr = rewriter + .create(loc, u64Ty, "reinterpret_cast", + ArrayAttr{}, rcU64, + ValueRange{addr}) + .getResult(0); + } else if (addr.getType() != u64Ty) { + addr = rewriter.create(loc, u64Ty, addr).getResult(); + } + + rewriter.create(loc, TypeRange{}, "TASSIGN", + ArrayAttr{}, ArrayAttr{}, + ValueRange{tile, addr}); + } + + rewriter.replaceOp(op, tile); + return success(); + } +}; + +static FailureOr +createEmitCTileVariable(ConversionPatternRewriter &rewriter, Location loc, + const TypeConverter *typeConverter, + pto::TileBufType tileTy) { + auto tileTypeString = getEmitCTileTypeString(tileTy); + if (!tileTypeString) + return failure(); + + Type convertedTy = typeConverter->convertType(tileTy); + if (!convertedTy) + convertedTy = emitc::OpaqueType::get(rewriter.getContext(), *tileTypeString); + + return rewriter + .create( + loc, convertedTy, emitc::OpaqueAttr::get(rewriter.getContext(), "")) + .getResult(); +} + +struct PTOTReshapeToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TReshapeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto tileTy = dyn_cast(op.getResult().getType()); + if (!tileTy) + return failure(); + + FailureOr dst = + createEmitCTileVariable(rewriter, op.getLoc(), getTypeConverter(), tileTy); + if (failed(dst)) + return failure(); + + Value src = peelUnrealized(adaptor.getSrc()); + if (auto castOp = src.getDefiningOp()) + src = castOp.getOperand(); + + rewriter.create(op.getLoc(), TypeRange{}, "TRESHAPE", + ArrayAttr{}, ArrayAttr{}, + ValueRange{*dst, src}); + rewriter.replaceOp(op, *dst); + return success(); + } +}; + +struct PTOBitcastToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::BitcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto dstTy = dyn_cast(op.getResult().getType()); + auto srcTy = dyn_cast(op.getSrc().getType()); + if (!dstTy || !srcTy) + return failure(); + + FailureOr dst = + createEmitCTileVariable(rewriter, op.getLoc(), getTypeConverter(), dstTy); + if (failed(dst)) + return failure(); + + Value src = peelUnrealized(adaptor.getSrc()); + if (auto castOp = src.getDefiningOp()) + src = castOp.getOperand(); + + pto::AddressSpace as = pto::AddressSpace::GM; + if (auto asAttr = + dyn_cast_or_null(srcTy.getMemorySpace())) + as = asAttr.getAddressSpace(); + std::string elemTok = getEmitCScalarTypeToken(srcTy.getElementType()); + + Value rawPtr = materializeTileDataValue(rewriter, op.getLoc(), src, as, elemTok); + auto u64Ty = emitc::OpaqueType::get(rewriter.getContext(), "uint64_t"); + Value addr = rawPtr; + if (isSetFFTsPointerLikeType(rawPtr.getType())) { + auto rcU64 = + rewriter.getArrayAttr({emitc::OpaqueAttr::get(rewriter.getContext(), + "uint64_t")}); + addr = rewriter + .create(op.getLoc(), u64Ty, + "reinterpret_cast", ArrayAttr{}, + rcU64, ValueRange{rawPtr}) + .getResult(0); + } else if (addr.getType() != u64Ty) { + addr = rewriter.create(op.getLoc(), u64Ty, addr).getResult(); + } + + rewriter.create(op.getLoc(), TypeRange{}, "TASSIGN", + ArrayAttr{}, ArrayAttr{}, + ValueRange{*dst, addr}); + rewriter.replaceOp(op, *dst); + return success(); + } +}; + +struct PTOMaterializeTileToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + static bool isTileLike(Value v) { + auto ot = dyn_cast(v.getType()); + if (!ot) + return false; + StringRef s = ot.getValue(); + return s.contains("Tile<") || s.contains("ConvTile<"); + } + + LogicalResult matchAndRewrite(pto::MaterializeTileOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + MLIRContext *ctx = rewriter.getContext(); + auto tileTy = cast(op.getResult().getType()); + auto tileTypeString = getEmitCTileTypeString(tileTy); + if (!tileTypeString) + return rewriter.notifyMatchFailure( + op, "only rank-2 tile_buf handles can be materialized to EmitC"); + + Type convertedTy = getTypeConverter()->convertType(tileTy); + if (!convertedTy) + convertedTy = emitc::OpaqueType::get(ctx, *tileTypeString); + + Value source = peelUnrealized(adaptor.getSource()); + if (auto castOp = source.getDefiningOp()) + source = castOp.getOperand(); + + auto viewSemantics = op->getAttrOfType("pto.view_semantics"); + bool forceDynamicValid = op->hasAttr(kForceDynamicValidShapeAttrName); + bool isReshape = viewSemantics && viewSemantics.getValue() == "treshape"; + bool isSubview = viewSemantics && viewSemantics.getValue() == "subview"; + bool sourceIsDeclaredTile = + op.getSource().getDefiningOp(); + + auto createTileValue = [&]() -> Value { + SmallVector constructorArgs; + bool useConstructor = false; + pto::BLayout blayout = getTileBufBLayoutValue(tileTy.getConfigAttr()); + Type elemTy = tileTy.getElementType(); + auto shape = tileTy.getShape(); + auto validShape = tileTy.getValidShape(); + + auto makeCtorDimValue = [&](Value emitted, int64_t fallback) -> Value { + if (emitted) + return emitted; + return makeEmitCIntConstant( + rewriter, loc, emitc::OpaqueType::get(ctx, "int32_t"), fallback); + }; + auto maybeScaleDynamicValid = [&](Value emitted, int dimIdx) -> Value { + if (!emitted || !pto::isPTOFloat4PackedType(elemTy)) + return emitted; + int packedDim = blayout == pto::BLayout::ColMajor ? 0 : 1; + if (dimIdx != packedDim) + return emitted; + auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); + Value two = makeEmitCIntConstant(rewriter, loc, i32Ty, 2); + return rewriter.create(loc, i32Ty, emitted, two).getResult(); + }; + auto fallbackDim = [&](int dimIdx) { + return renderTileTemplateDim(shape[dimIdx], elemTy, blayout, dimIdx); + }; + + if (forceDynamicValid) { + useConstructor = true; + constructorArgs.push_back(makeCtorDimValue( + maybeScaleDynamicValid(adaptor.getValidRow(), 0), fallbackDim(0))); + constructorArgs.push_back(makeCtorDimValue( + maybeScaleDynamicValid(adaptor.getValidCol(), 1), fallbackDim(1))); + } else { + if (validShape[0] == ShapedType::kDynamic) { + useConstructor = true; + constructorArgs.push_back(makeCtorDimValue( + maybeScaleDynamicValid(adaptor.getValidRow(), 0), fallbackDim(0))); + } + if (validShape[1] == ShapedType::kDynamic) { + useConstructor = true; + constructorArgs.push_back(makeCtorDimValue( + maybeScaleDynamicValid(adaptor.getValidCol(), 1), fallbackDim(1))); + } + } + + if (useConstructor) { + return rewriter + .create(loc, convertedTy, *tileTypeString, + ArrayAttr{}, ArrayAttr{}, + ValueRange(constructorArgs)) + .getResult(0); + } + + return rewriter + .create(loc, convertedTy, + emitc::OpaqueAttr::get(ctx, "")) + .getResult(); + }; + + if (!isSubview && !forceDynamicValid && isTileLike(source)) { + if (auto srcTy = dyn_cast(source.getType())) { + if (srcTy.getValue() == *tileTypeString) { + rewriter.replaceOp(op, source); + return success(); + } + } + } + + Value tile = createTileValue(); + if (sourceIsDeclaredTile) { + rewriter.replaceOp(op, tile); + return success(); + } + + if (isReshape && isTileLike(source)) { + rewriter.create(loc, TypeRange{}, "TRESHAPE", + ArrayAttr{}, ArrayAttr{}, + ValueRange{tile, source}); + rewriter.replaceOp(op, tile); + return success(); + } + + pto::AddressSpace as = pto::AddressSpace::GM; + if (auto asAttr = + dyn_cast_or_null(tileTy.getMemorySpace())) + as = asAttr.getAddressSpace(); + std::string elemTok = getEmitCScalarTypeToken(tileTy.getElementType()); + + Value rawPtr = source; + if (isTileLike(rawPtr)) + rawPtr = materializeTileDataValue(rewriter, loc, rawPtr, as, elemTok); + + auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); + Value addr = rawPtr; + if (isSetFFTsPointerLikeType(rawPtr.getType())) { + auto rcU64 = + rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); + addr = rewriter + .create(loc, u64Ty, "reinterpret_cast", + ArrayAttr{}, rcU64, + ValueRange{rawPtr}) + .getResult(0); + } else if (rawPtr.getType() != u64Ty) { + addr = rewriter.create(loc, u64Ty, rawPtr).getResult(); + } + + rewriter.create(loc, TypeRange{}, "TASSIGN", + ArrayAttr{}, ArrayAttr{}, + ValueRange{tile, addr}); + rewriter.replaceOp(op, tile); + return success(); + } +}; + + +} // namespace + +void populatePTOToEmitCTileMaterializationPatterns( + RewritePatternSet &patterns, TypeConverter &typeConverter, MLIRContext *ctx) { + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); +} + +} // namespace mlir::pto diff --git a/lib/PTO/Transforms/PTOToEmitCTilePatterns.cpp b/lib/PTO/Transforms/PTOToEmitCTilePatterns.cpp new file mode 100644 index 000000000..e723c2c9c --- /dev/null +++ b/lib/PTO/Transforms/PTOToEmitCTilePatterns.cpp @@ -0,0 +1,1439 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- PTOToEmitCTilePatterns.cpp ----------------------------------------===// +//===----------------------------------------------------------------------===// + +#include "PTOToEmitCInternal.h" + +#include "PTO/IR/PTO.h" + +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" + +#include + +using namespace mlir; +using namespace mlir::pto; + +namespace mlir::pto { +namespace { + +struct PTOTAndToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TAndOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value a = peelUnrealized(adaptor.getSrc0()); + Value b = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + op.getLoc(), TypeRange{}, "TAND", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, a, b}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOConcatToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TConcatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + op.getLoc(), TypeRange{}, "TCONCAT", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src0, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; +struct PTOConcatidxToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TConcatidxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value src0Idx = peelUnrealized(adaptor.getSrc0Idx()); + Value src1Idx = peelUnrealized(adaptor.getSrc1Idx()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + op.getLoc(), TypeRange{}, "TCONCAT", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src0, src1, src0Idx, src1Idx}); + + rewriter.eraseOp(op); + return success(); + } +}; +struct PTOAndSToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TAndSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src = peelUnrealized(adaptor.getSrc()); + Value scalar = peelUnrealized(adaptor.getScalar()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + op.getLoc(), TypeRange{}, "TANDS", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src, scalar}); + + rewriter.eraseOp(op); + return success(); + } +}; + + +struct PTOTCIToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TCIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value dst = peelUnrealized(adaptor.getDst()); + Value S = peelUnrealized(adaptor.getOperands()[0]); + + // The TCI scalar template parameter should follow the original PTO IR + // scalar type, not the converted EmitC value type. + std::string scalarTok = "int32_t"; + if (auto it = dyn_cast(op->getOperand(0).getType())) { + bool isUnsigned = it.isUnsigned(); + if (it.getWidth() == 16) + scalarTok = isUnsigned ? "uint16_t" : "int16_t"; + else + scalarTok = isUnsigned ? "uint32_t" : "int32_t"; + } + + // descending -> "0"/"1" + std::string descTok = op.getDescending() ? "1" : "0"; + + ArrayAttr targs; + if (auto ot = mlir::dyn_cast(dst.getType())) { + std::string tileTok = ot.getValue().str(); // "Tile<...>" + targs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, tileTok), + emitc::OpaqueAttr::get(ctx, scalarTok), + emitc::OpaqueAttr::get(ctx, descTok), + }); + } else { + targs = rewriter.getArrayAttr({}); + } + + rewriter.create( + loc, TypeRange{}, "TCI", + /*args=*/ArrayAttr{}, + /*templateArgs=*/targs, + /*operands=*/ValueRange{dst, S}); + + rewriter.eraseOp(op); + return success(); + } +}; +static std::string cmpModeTok(pto::CmpModeAttr a) { + // 生成 "CmpMode::GT" 这种 token + auto m = a.getValue(); // 取 enum + switch (m) { + case pto::CmpMode::EQ: return "CmpMode::EQ"; + case pto::CmpMode::NE: return "CmpMode::NE"; + case pto::CmpMode::LT: return "CmpMode::LT"; + case pto::CmpMode::LE: return "CmpMode::LE"; + case pto::CmpMode::GT: return "CmpMode::GT"; + case pto::CmpMode::GE: return "CmpMode::GE"; + } + return "CmpMode::EQ"; +} +struct PTOColExpandToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColExpandOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value dst = peelUnrealized(adaptor.getDst()); + Value src = peelUnrealized(adaptor.getSrc()); + + rewriter.create( + loc, TypeRange{}, "TCOLEXPAND", + /*args=*/ArrayAttr(), + /*templateArgs=*/ArrayAttr(), + /*operands=*/ValueRange{dst, src}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColExpandMulToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColExpandMulOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TCOLEXPANDMUL", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColExpandAddToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColExpandAddOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TCOLEXPANDADD", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColExpandDivToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColExpandDivOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TCOLEXPANDDIV", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColExpandExpdifToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColExpandExpdifOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TCOLEXPANDEXPDIF", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColExpandSubToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColExpandSubOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TCOLEXPANDSUB", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColExpandMaxToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColExpandMaxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TCOLEXPANDMAX", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColExpandMinToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColExpandMinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TCOLEXPANDMIN", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOTTriToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TTriOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value dst = peelUnrealized(adaptor.getDst()); + Value diagonal = peelUnrealized(adaptor.getDiagonal()); + + ArrayAttr templateArgs; + if (auto dstOT = mlir::dyn_cast(dst.getType())) { + templateArgs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), + emitc::OpaqueAttr::get(ctx, std::to_string(op.getUpperOrLower())), + }); + } else { + templateArgs = ArrayAttr{}; + } + + SmallVector operands{dst, diagonal}; + rewriter.create( + loc, TypeRange{}, "TTRI", + /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOCmpToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TCmpOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value dst = peelUnrealized(adaptor.getDst()); + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + + std::string tok = "CmpMode::EQ"; + if (auto a = op.getCmpModeAttr()) + tok = cmpModeTok(a); + + auto modeTy = emitc::OpaqueType::get(ctx, "CmpMode"); + Value modeVal = rewriter.create( + loc, modeTy, emitc::OpaqueAttr::get(ctx, tok)); + + rewriter.create( + loc, + TypeRange{}, + "TCMP", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, src1, modeVal}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOCmpSToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TCmpSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value dst = peelUnrealized(adaptor.getDst()); + Value src = peelUnrealized(adaptor.getSrc()); + Value scalar = peelUnrealized(adaptor.getScalar()); + + // cmpMode -> token + auto cmpAttr = op.getCmpModeAttr(); // PTO_CmpModeAttr + std::string tok = cmpModeTok(cmpAttr); + + auto modeTy = emitc::OpaqueType::get(ctx, "CmpMode"); + Value modeVal = rewriter.create( + loc, modeTy, emitc::OpaqueAttr::get(ctx, tok)); + + rewriter.create( + loc, + TypeRange{}, + "TCMPS", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src, scalar, modeVal}); + + rewriter.eraseOp(op); + return success(); + } +}; + + +struct PTOColMaxToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColMaxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + // intrinsic: TCOLMAX(dst, src) + rewriter.create( + loc, TypeRange{}, "TCOLMAX", + /*args=*/ArrayAttr{}, // default: print all operands + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColArgMaxToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColArgMaxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TCOLARGMAX", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src, tmp}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColMinToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColMinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + // intrinsic: TCOLMIN(dst, src) + rewriter.create( + loc, TypeRange{}, "TCOLMIN", + /*args=*/ArrayAttr{}, // default: print all operands + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColArgMinToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColArgMinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TCOLARGMIN", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src, tmp}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColSumToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColSumOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + // Check if tmp exists before accessing it + if (op.getTmp()) { + // Format 2: with tmp and isBinary + Value tmp = peelUnrealized(adaptor.getTmp()); + bool isBinary = false; + if (auto a = op.getIsBinaryAttr()) + isBinary = a.getValue(); + + auto boolTy = emitc::OpaqueType::get(ctx, "bool"); + auto tok = isBinary ? "true" : "false"; + Value isBinaryVal = rewriter.create( + loc, boolTy, emitc::OpaqueAttr::get(ctx, tok)); + + rewriter.create( + loc, TypeRange{}, "TCOLSUM", + /*args=*/ArrayAttr(), + /*templateArgs=*/ArrayAttr(), + /*operands=*/ValueRange{dst, src, tmp, isBinaryVal}); + } else { + // Format 1: without tmp and isBinary + rewriter.create( + loc, TypeRange{}, "TCOLSUM", + /*args=*/ArrayAttr(), + /*templateArgs=*/ArrayAttr(), + /*operands=*/ValueRange{dst, src}); + } + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOColProdToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TColProdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TCOLPROD", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src}); + + rewriter.eraseOp(op); + return success(); + } +}; +static std::string roundModeTok(mlir::pto::RoundModeAttr attr) { + using RM = mlir::pto::RoundMode; + switch (attr.getValue()) { + case RM::NONE: return "RoundMode::CAST_NONE"; + case RM::RINT: return "RoundMode::CAST_RINT"; + case RM::ROUND: return "RoundMode::CAST_ROUND"; + case RM::FLOOR: return "RoundMode::CAST_FLOOR"; + case RM::CEIL: return "RoundMode::CAST_CEIL"; + case RM::TRUNC: return "RoundMode::CAST_TRUNC"; + case RM::ODD: return "RoundMode::CAST_ODD"; + case RM::CAST_RINT: return "RoundMode::CAST_RINT"; + } + return "RoundMode::CAST_RINT"; +} +static std::string saturationModeTok(mlir::pto::SaturationModeAttr attr) { + using SM = mlir::pto::SaturationMode; + switch (attr.getValue()) { + case SM::ON: return "SaturationMode::ON"; + case SM::OFF: return "SaturationMode::OFF"; + } + return "SaturationMode::OFF"; +} +struct PTOCvtToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TCvtOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + pto::RoundModeAttr rmAttr = op.getRmodeAttr(); + std::string rmTok = rmAttr ? roundModeTok(rmAttr) + : std::string("RoundMode::CAST_RINT"); + auto rmodeTy = emitc::OpaqueType::get(ctx, "RoundMode"); + Value rmodeVal = rewriter.create( + loc, rmodeTy, emitc::OpaqueAttr::get(ctx, rmTok)); + + auto satModeTy = emitc::OpaqueType::get(ctx, "SaturationMode"); + auto satAttr = op.getSatModeAttr(); + std::string satTok = satAttr ? saturationModeTok(satAttr) + : std::string("SaturationMode::OFF"); + Value satModeVal = rewriter.create( + loc, satModeTy, emitc::OpaqueAttr::get(ctx, satTok)); + + SmallVector operands{dst, src, rmodeVal, satModeVal}; + + rewriter.create( + loc, TypeRange{}, "TCVT", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +struct PTORandomToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRandomOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value dst = peelUnrealized(adaptor.getDst()); + SmallVector operands{ + dst, + peelUnrealized(adaptor.getKey0()), + peelUnrealized(adaptor.getKey1()), + peelUnrealized(adaptor.getCounter0()), + peelUnrealized(adaptor.getCounter1()), + peelUnrealized(adaptor.getCounter2()), + peelUnrealized(adaptor.getCounter3()), + }; + ArrayAttr templateArgs = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, std::to_string(op.getRounds()))}); + + rewriter.create( + loc, TypeRange{}, "PTOAS__TRANDOM", + /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, operands); + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.tdiv lowering -> TDIV(dst, src0, src1) +//===----------------------------------------------------------------------===// + +struct PTODivToTDIV : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TDivOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + op.getLoc(), TypeRange{}, "TDIV", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src0, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.tdivs lowering -> TDIVS(dst, src, scalar) or TDIVS(dst, scalar, src) +// Order is determined by operand types: if src is tile_buf, order is (tile, scalar) +// Otherwise, order is (scalar, tile) +//===----------------------------------------------------------------------===// + +struct PTODivSToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TDivSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value scalar = peelUnrealized(adaptor.getScalar()); + Value dst = peelUnrealized(adaptor.getDst()); + // Preserve source order from textual parse: + // ins(tile, scalar) -> TDIVS(dst, tile, scalar) + // ins(scalar, tile) -> TDIVS(dst, scalar, tile) + rewriter.create( + loc, TypeRange{}, "TDIVS", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src, scalar}); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// pto.tdivs (TDivSOp) lowering -> TDIVS(dst, src, scalar) or TDIVS(dst, scalar, src) +// Order is determined by operand types: if src is tile_buf, order is (tile, scalar) +// Otherwise, order is (scalar, tile) +//===----------------------------------------------------------------------===// + +struct PTOTDivSToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TDivSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value scalar = peelUnrealized(adaptor.getScalar()); + Value dst = peelUnrealized(adaptor.getDst()); + rewriter.create( + loc, TypeRange{}, "TDIVS", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src, scalar}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.texp lowering -> TEXP(dst, src) +//===----------------------------------------------------------------------===// + +struct PTOExpToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TExpOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TEXP", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, src}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.texpands lowering -> TEXPANDS(dst, scalar) +//===----------------------------------------------------------------------===// + +struct PTOExpandsToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TExpandsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value scalar = peelUnrealized(adaptor.getScalar()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TEXPANDS", + ArrayAttr{}, ArrayAttr{}, + ValueRange{dst, scalar}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.textract lowering -> TEXTRACT(dst, src, indexRow, indexCol) +//===----------------------------------------------------------------------===// + +struct PTOExtractToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TExtractOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + Value r0 = peelUnrealized(adaptor.getIndexRow()); + Value c0 = peelUnrealized(adaptor.getIndexCol()); + + rewriter.create( + loc, TypeRange{}, "TEXTRACT", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src, r0, c0}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.textract_fp lowering -> TEXTRACT_FP(dst, src, fp, indexRow, indexCol) +//===----------------------------------------------------------------------===// + +struct PTOExtractFPToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TExtractFPOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value fp = peelUnrealized(adaptor.getFp()); + Value dst = peelUnrealized(adaptor.getDst()); + Value r0 = peelUnrealized(adaptor.getIndexRow()); + Value c0 = peelUnrealized(adaptor.getIndexCol()); + + rewriter.create( + loc, TypeRange{}, "TEXTRACT_FP", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src, fp, r0, c0}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.tinsert lowering -> TINSERT(dst, src, indexRow, indexCol) +// Keep lowering arch-agnostic and let PTO-ISA infer proper A5 path. +//===----------------------------------------------------------------------===// + +struct PTOInsertToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TInsertOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + Value r0 = peelUnrealized(adaptor.getIndexRow()); + Value c0 = peelUnrealized(adaptor.getIndexCol()); + + rewriter.create( + loc, TypeRange{}, "TINSERT", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src, r0, c0}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.tinsert_fp lowering -> TINSERT_FP(dst, src, fp, indexRow, indexCol) +//===----------------------------------------------------------------------===// + +struct PTOInsertFPToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TInsertFPOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value fp = peelUnrealized(adaptor.getFp()); + Value dst = peelUnrealized(adaptor.getDst()); + Value r0 = peelUnrealized(adaptor.getIndexRow()); + Value c0 = peelUnrealized(adaptor.getIndexCol()); + + rewriter.create( + loc, TypeRange{}, "TINSERT_FP", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src, fp, r0, c0}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.tfillpad lowering -> TFILLPAD(dst, src) +//===----------------------------------------------------------------------===// + +struct PTOFillPadToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TFillPadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TFILLPAD", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.tfillpad_inplace lowering -> TFILLPAD_INPLACE(dst, src) +//===----------------------------------------------------------------------===// + +struct PTOFillPadInplaceToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TFillPadInplaceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TFILLPAD_INPLACE", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.tfillpad_expand lowering -> TFILLPAD_EXPAND(dst, src) +//===----------------------------------------------------------------------===// + +struct PTOFillPadExpandToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TFillPadExpandOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TFILLPAD_EXPAND", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// pto.tgather lowering +// - Index form : TGATHER(dst, src0, indices, tmp) +// - Compare form: TGATHER(dst, src0, kValue, cdst, tmp) +// - Mask form : TGATHER(dst, src0) +//===----------------------------------------------------------------------===// + +[[maybe_unused]] static std::string maskPatternTok(mlir::pto::MaskPatternAttr a) { + + auto v = a.getValue(); // enum + return (std::string("pto::MaskPattern::") + mlir::pto::stringifyMaskPattern(v).str()); +} + +struct PTOGatherToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TGatherOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value dst = peelUnrealized(adaptor.getDst()); + Value src0 = peelUnrealized(adaptor.getSrc()); + + auto getOpaqueTok = [&](Value v, StringRef name) -> FailureOr { + if (auto ot = mlir::dyn_cast(v.getType())) + return ot.getValue().str(); + return rewriter.notifyMatchFailure(op, (name + " must be emitc::OpaqueType (tile)").str()); + }; + + // Case 1: index-based TGATHER(dst, src0, indices, tmp) + if (Value idx = adaptor.getIndices()) { + idx = peelUnrealized(idx); + Value tmp = peelUnrealized(adaptor.getTmp()); + + rewriter.create( + loc, TypeRange{}, "TGATHER", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, idx, tmp}); + + rewriter.eraseOp(op); + return success(); + } + + // Case 2: compare-based TGATHER( + // dst, src0, kValue, tmp, cdst, offset) + if (Value cdst = adaptor.getCdst()) { + cdst = peelUnrealized(cdst); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value kValue = peelUnrealized(adaptor.getKValue()); + + auto dstTokOr = getOpaqueTok(dst, "dst"); + auto srcTokOr = getOpaqueTok(src0, "src0"); + auto cdstTokOr = getOpaqueTok(cdst, "cdst"); + auto tmpTokOr = getOpaqueTok(tmp, "tmp"); + if (failed(dstTokOr) || failed(srcTokOr) || failed(cdstTokOr) || failed(tmpTokOr)) + return failure(); + + auto cmpAttr = op.getCmpModeAttr(); + std::string cmpTok = cmpAttr ? cmpModeTok(cmpAttr) : "CmpMode::EQ"; + int64_t offset = 0; + if (auto offsetAttr = op.getOffsetAttr()) + offset = offsetAttr.getInt(); + auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); + Value offsetVal = makeEmitCIntConstant(rewriter, loc, i32Ty, offset); + + auto targs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, *dstTokOr), + emitc::OpaqueAttr::get(ctx, *srcTokOr), + emitc::OpaqueAttr::get(ctx, *tmpTokOr), + emitc::OpaqueAttr::get(ctx, *cdstTokOr), + emitc::OpaqueAttr::get(ctx, cmpTok), + }); + + rewriter.create( + loc, TypeRange{}, "TGATHER", + /*args=*/ArrayAttr{}, /*templateArgs=*/targs, + /*operands=*/ValueRange{dst, src0, kValue, tmp, cdst, offsetVal}); + + rewriter.eraseOp(op); + return success(); + } + + // Case 3: mask-pattern TGATHER(dst, src0) + auto mp = op.getMaskPatternAttr(); + if (!mp) + return rewriter.notifyMatchFailure(op, "expected maskPattern, indices, or cdst on tgather"); + + auto dstTokOr = getOpaqueTok(dst, "dst"); + auto srcTokOr = getOpaqueTok(src0, "src0"); + if (failed(dstTokOr) || failed(srcTokOr)) + return failure(); + + // mp is an EnumAttr; stringify name is "P0101" etc. + // We emit MaskPattern::P0101 (because generated C++ has `using namespace pto;`) + std::string mpTok = std::string("MaskPattern::") + + mlir::pto::stringifyMaskPattern(mp.getValue()).str(); + + auto targs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, *dstTokOr), + emitc::OpaqueAttr::get(ctx, *srcTokOr), + emitc::OpaqueAttr::get(ctx, mpTok), + }); + + rewriter.create( + loc, TypeRange{}, "TGATHER", + /*args=*/ArrayAttr{}, + /*templateArgs=*/targs, + /*operands=*/ValueRange{dst, src0}); + + rewriter.eraseOp(op); + return success(); + } +}; + + +struct PTOGatherbToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TGatherBOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value offsets = peelUnrealized(adaptor.getOffsets()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TGATHERB", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src, offsets}); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// TLOG lowering to EmitC (PTOConvert.cpp) +//===----------------------------------------------------------------------===// + +struct PTOLogToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TLogOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src}; + rewriter.create( + loc, TypeRange{}, "TLOG", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + + + +//===----------------------------------------------------------------------===// +// TLRELU lowering to EmitC (PTOConvert.cpp) +//===----------------------------------------------------------------------===// + + struct PTOLReluToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TLReluOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value slope = peelUnrealized(adaptor.getSlope()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src, slope}; + + rewriter.create( + loc, TypeRange{}, "TLRELU", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// TMAX lowering to EmitC (PTOConvert.cpp) +//===----------------------------------------------------------------------===// + +struct PTOMaxToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMaxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src0, src1}; + rewriter.create( + loc, TypeRange{}, "TMAX", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// TMAXS lowering to EmitC (PTOConvert.cpp) +//===----------------------------------------------------------------------===// + + struct PTOMaxSToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMaxSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc()); + Value scalar = peelUnrealized(adaptor.getScalar()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src0, scalar}; + rewriter.create( + loc, TypeRange{}, "TMAXS", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + + +//===----------------------------------------------------------------------===// +// TMIN lowering to EmitC (PTOConvert.cpp) +//===----------------------------------------------------------------------===// + +struct PTOMinToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src0, src1}; + rewriter.create( + loc, TypeRange{}, "TMIN", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// TMINS lowering to EmitC (PTOConvert.cpp) +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// TMINS lowering to EmitC (fix APFloat -> FloatAttr) (PTOConvert.cpp) +//===----------------------------------------------------------------------===// + +struct PTOMinsToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMinSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + Value scalar = peelUnrealized(adaptor.getScalar()); + + SmallVector operands{dst, src, scalar}; + rewriter.create( + loc, TypeRange{}, "TMINS", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering for TMOV op -> EmitC) +//===----------------------------------------------------------------------===// + +struct PTOMovToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMovOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + Value fp; + if (op.getFp()) + fp = peelUnrealized(adaptor.getFp()); + Value preQuantScalar; + if (op.getPreQuantScalar()) + preQuantScalar = peelUnrealized(adaptor.getPreQuantScalar()); + + auto dstOT = mlir::dyn_cast(dst.getType()); + auto srcOT = mlir::dyn_cast(src.getType()); + if (!dstOT || !srcOT) + return rewriter.notifyMatchFailure( + op, "tmov lowering expects opaque dst/src types"); + + auto modeTok = [&](pto::AccToVecMode mode) -> StringRef { + switch (mode) { + case pto::AccToVecMode::SingleModeVec0: + return "pto::AccToVecMode::SingleModeVec0"; + case pto::AccToVecMode::SingleModeVec1: + return "pto::AccToVecMode::SingleModeVec1"; + case pto::AccToVecMode::DualModeSplitM: + return "pto::AccToVecMode::DualModeSplitM"; + case pto::AccToVecMode::DualModeSplitN: + return "pto::AccToVecMode::DualModeSplitN"; + } + llvm_unreachable("unknown AccToVecMode"); + }; + + auto modeAttr = op.getAccToVecModeAttr(); + auto reluTok = [&](pto::ReluPreMode mode) -> StringRef { + switch (mode) { + case pto::ReluPreMode::NoRelu: + return "ReluPreMode::NoRelu"; + case pto::ReluPreMode::NormalRelu: + return "ReluPreMode::NormalRelu"; + } + llvm_unreachable("unknown ReluPreMode"); + }; + + const bool hasFp = static_cast(fp); + const bool hasPreQuantScalar = static_cast(preQuantScalar); + const bool hasMode = static_cast(modeAttr); + const bool reluNonDefault = op.getReluPreMode() != pto::ReluPreMode::NoRelu; + + SmallVector operands{dst, src}; + SmallVector templateArgVec{ + emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), + emitc::OpaqueAttr::get(ctx, srcOT.getValue().str()), + }; + StringRef callee = "TMOV"; + + if (hasFp) { + auto fpOT = mlir::dyn_cast(fp.getType()); + if (!fpOT) + return rewriter.notifyMatchFailure( + op, "tmov fp lowering expects opaque fp type"); + operands.push_back(fp); + templateArgVec.push_back(emitc::OpaqueAttr::get(ctx, fpOT.getValue().str())); + if (hasMode) + templateArgVec.push_back( + emitc::OpaqueAttr::get(ctx, modeTok(modeAttr.getValue()))); + if (hasMode || reluNonDefault) + templateArgVec.push_back( + emitc::OpaqueAttr::get(ctx, reluTok(op.getReluPreMode()))); + callee = hasMode ? "TMOV" : "TMOV_FP"; + } else if (hasPreQuantScalar) { + operands.push_back(preQuantScalar); + if (hasMode) + templateArgVec.push_back( + emitc::OpaqueAttr::get(ctx, modeTok(modeAttr.getValue()))); + if (hasMode || reluNonDefault) + templateArgVec.push_back( + emitc::OpaqueAttr::get(ctx, reluTok(op.getReluPreMode()))); + } else if (hasMode) { + templateArgVec.push_back( + emitc::OpaqueAttr::get(ctx, modeTok(modeAttr.getValue()))); + templateArgVec.push_back( + emitc::OpaqueAttr::get(ctx, reluTok(op.getReluPreMode()))); + } else if (reluNonDefault) { + templateArgVec.push_back( + emitc::OpaqueAttr::get(ctx, reluTok(op.getReluPreMode()))); + } + + ArrayAttr templateArgs = + templateArgVec.size() == 2 && !hasFp && !hasPreQuantScalar && + !hasMode && !reluNonDefault + ? ArrayAttr{} + : rewriter.getArrayAttr(templateArgVec); + + rewriter.create( + loc, TypeRange{}, callee, + /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +} // namespace + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TMOV_FP DPS/memref op) +//===----------------------------------------------------------------------===// + +void populatePTOToEmitCTilePatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter, + MLIRContext *ctx) { + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add( + typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + populatePTOToEmitCTileExtraPatterns(patterns, typeConverter, ctx); +} + +} // namespace mlir::pto diff --git a/lib/PTO/Transforms/PTOToEmitCTilePatternsExtra.cpp b/lib/PTO/Transforms/PTOToEmitCTilePatternsExtra.cpp new file mode 100644 index 000000000..e7c5b93cc --- /dev/null +++ b/lib/PTO/Transforms/PTOToEmitCTilePatternsExtra.cpp @@ -0,0 +1,1819 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- PTOToEmitCTilePatternsExtra.cpp -----------------------------------===// +//===----------------------------------------------------------------------===// + +#include "PTOToEmitCInternal.h" + +#include "PTO/IR/PTO.h" + +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" + +#include + +using namespace mlir; +using namespace mlir::pto; + +namespace mlir::pto { +namespace { + +[[maybe_unused]] static std::string maskPatternTok(mlir::pto::MaskPatternAttr a) { + auto value = a.getValue(); + return (std::string("pto::MaskPattern::") + + mlir::pto::stringifyMaskPattern(value).str()); +} + +struct PTOMovFPToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMovFPOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value dst = peelUnrealized(adaptor.getDst()); + Value src = peelUnrealized(adaptor.getSrc()); + Value fp = peelUnrealized(adaptor.getFp()); + + // TMOV_FP(dstTileData, cTile, fbTile) + ArrayAttr templateArgs; + auto dstOT = mlir::dyn_cast(dst.getType()); + auto srcOT = mlir::dyn_cast(src.getType()); + auto fpOT = mlir::dyn_cast(fp.getType()); + if (dstOT && srcOT && fpOT) { + templateArgs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), + emitc::OpaqueAttr::get(ctx, srcOT.getValue().str()), + emitc::OpaqueAttr::get(ctx, fpOT.getValue().str()), + }); + } else { + templateArgs = ArrayAttr{}; + } + + SmallVector operands{dst, src, fp}; + rewriter.create( + loc, TypeRange{}, "TMOV_FP", + /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOQuantToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TQuantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value dst = peelUnrealized(adaptor.getDst()); + Value src = peelUnrealized(adaptor.getSrc()); + Value fp = peelUnrealized(adaptor.getFp()); + + // Optional offset (INT8_ASYM only): passed as pointer (&offset) + Value offsetPtr; + if (op.getOffset()) { + Value offset = peelUnrealized(adaptor.getOffset()); + auto offsetOT = mlir::dyn_cast(offset.getType()); + if (offsetOT) { + offsetPtr = rewriter + .create( + loc, emitc::PointerType::get(offsetOT), "&", offset) + .getResult(); + } + } + + // TQUANT(dst, src, fp[, &offset]) + std::string quantTypeStr = + op.getQuantType() == pto::QuantType::INT8_SYM + ? "pto::QuantType::INT8_SYM" + : "pto::QuantType::INT8_ASYM"; + ArrayAttr templateArgs; + auto dstOT = mlir::dyn_cast(dst.getType()); + auto srcOT = mlir::dyn_cast(src.getType()); + auto fpOT = mlir::dyn_cast(fp.getType()); + if (dstOT && srcOT && fpOT) { + templateArgs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, quantTypeStr), + emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), + emitc::OpaqueAttr::get(ctx, srcOT.getValue().str()), + emitc::OpaqueAttr::get(ctx, fpOT.getValue().str()), + }); + } else { + templateArgs = ArrayAttr{}; + } + + SmallVector operands{dst, src, fp}; + if (offsetPtr) + operands.push_back(offsetPtr); + + rewriter.create( + loc, TypeRange{}, "TQUANT", + /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +struct PTODequantToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TDequantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value dst = peelUnrealized(adaptor.getDst()); + Value src = peelUnrealized(adaptor.getSrc()); + Value scale = peelUnrealized(adaptor.getScale()); + Value offset = peelUnrealized(adaptor.getOffset()); + + // TDEQUANT(dst, src, scale, offset) + ArrayAttr templateArgs; + auto dstOT = mlir::dyn_cast(dst.getType()); + auto srcOT = mlir::dyn_cast(src.getType()); + auto scaleOT = mlir::dyn_cast(scale.getType()); + if (dstOT && srcOT && scaleOT) { + templateArgs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, dstOT.getValue().str()), + emitc::OpaqueAttr::get(ctx, srcOT.getValue().str()), + emitc::OpaqueAttr::get(ctx, scaleOT.getValue().str()), + }); + } else { + templateArgs = ArrayAttr{}; + } + + rewriter.create( + loc, TypeRange{}, "TDEQUANT", + /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, + /*operands=*/SmallVector{dst, src, scale, offset}); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TMRGSORT DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOMrgSortToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMrgSortOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + if (op.isFormat1()) { + Value src = peelUnrealized(adaptor.getSrcs().front()); + Value dst = peelUnrealized(adaptor.getDsts().front()); + Value blockLen = peelUnrealized(adaptor.getBlockLen()); + + SmallVector operands{dst, src, blockLen}; + rewriter.create( + loc, TypeRange{}, "TMRGSORT", + ArrayAttr{}, ArrayAttr{}, operands); + } else if (op.isFormat2()) { + // pto-isa API: + // TMRGSORT( + // dst, executedNumList, tmp, src0, src1[, src2[, src3]]); + auto *ctx = rewriter.getContext(); + + Value dst = peelUnrealized(adaptor.getDsts()[0]); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value excuted = peelUnrealized(adaptor.getExcuted()); + + SmallVector srcs; + srcs.reserve(adaptor.getSrcs().size()); + for (Value v : adaptor.getSrcs()) + srcs.push_back(peelUnrealized(v)); + + auto dstOT = mlir::dyn_cast(dst.getType()); + auto tmpOT = mlir::dyn_cast(tmp.getType()); + if (!dstOT || !tmpOT || srcs.size() < 2 || srcs.size() > 4) + return op.emitOpError("format2 expects dst/tmp tilebufs and 2 to 4 srcs"); + + SmallVector targs; + targs.reserve(2 + srcs.size() + 1); + targs.push_back(emitc::OpaqueAttr::get(ctx, dstOT.getValue().str())); + targs.push_back(emitc::OpaqueAttr::get(ctx, tmpOT.getValue().str())); + for (Value v : srcs) { + auto ot = mlir::dyn_cast(v.getType()); + if (!ot) + return op.emitOpError("format2 expects tilebuf srcs"); + targs.push_back(emitc::OpaqueAttr::get(ctx, ot.getValue().str())); + } + targs.push_back(emitc::OpaqueAttr::get(ctx, op.getExhausted() ? "true" : "false")); + ArrayAttr templateArgs = rewriter.getArrayAttr(targs); + + SmallVector operands{dst, excuted, tmp}; + operands.append(srcs.begin(), srcs.end()); + + rewriter.create( + loc, TypeRange{}, "TMRGSORT", + /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, operands); + } else { + return op.emitOpError("unsupported mrgsort_dps format"); + } + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TMUL DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOMulToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMulOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src0, src1}; + rewriter.create( + loc, TypeRange{}, "TMUL", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TMULS DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOMulsToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMulSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc0()); + Value dst = peelUnrealized(adaptor.getDst()); + Value scalar = peelUnrealized(adaptor.getScalar()); + + SmallVector operands{dst, src, scalar}; + rewriter.create( + loc, TypeRange{}, "TMULS", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TNEG DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTONegToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TNegOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src}; + rewriter.create( + loc, TypeRange{}, "TNEG", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TNOT DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTONotToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TNotOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src}; + rewriter.create( + loc, TypeRange{}, "TNOT", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TOR DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOOrToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TOrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src0, src1}; + rewriter.create( + loc, TypeRange{}, "TOR", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TORS DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOOrsToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TOrSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + // NOTE: The conversion type system may materialize integers as emitc.opaque + // (e.g. "int32_t"). For EmitC call emission we can pass the scalar through + // directly without arith casts here. + Value s = adaptor.getScalar(); + + SmallVector operands{dst, src0, s}; + rewriter.create( + loc, TypeRange{}, "TORS", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TPARTADD DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOPartAddToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TPartAddOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src0, src1}; + rewriter.create( + loc, TypeRange{}, "TPARTADD", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TPARTMAX DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOPartMaxToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TPartMaxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src0, src1}; + rewriter.create( + loc, TypeRange{}, "TPARTMAX", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TPARTMIN DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOPartMinToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TPartMinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src0, src1}; + rewriter.create( + loc, TypeRange{}, "TPARTMIN", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOPartArgMaxToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TPartArgMaxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value src0Idx = peelUnrealized(adaptor.getSrc0Idx()); + Value src1Idx = peelUnrealized(adaptor.getSrc1Idx()); + Value dst = peelUnrealized(adaptor.getDst()); + Value dstIdx = peelUnrealized(adaptor.getDstIdx()); + + rewriter.create( + op.getLoc(), TypeRange{}, "TPARTARGMAX", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, src1, dstIdx, src0Idx, src1Idx}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOPartArgMinToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TPartArgMinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value src0Idx = peelUnrealized(adaptor.getSrc0Idx()); + Value src1Idx = peelUnrealized(adaptor.getSrc1Idx()); + Value dst = peelUnrealized(adaptor.getDst()); + Value dstIdx = peelUnrealized(adaptor.getDstIdx()); + + rewriter.create( + op.getLoc(), TypeRange{}, "TPARTARGMIN", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, src1, dstIdx, src0Idx, src1Idx}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TPARTMUL DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOPartMulToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TPartMulOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src0, src1}; + rewriter.create( + loc, TypeRange{}, "TPARTMUL", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TPRELU DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOPreluToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TPReluOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value dst = peelUnrealized(adaptor.getDst()); + + // C++ interface: TPRELU(dst, src0, src1, tmp) — last parameter is tmp. + SmallVector operands{dst, src0, src1, tmp}; + rewriter.create( + loc, TypeRange{}, "TPRELU", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TRECIP DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTORecipToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRecipOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src}; + rewriter.create( + loc, TypeRange{}, "TRECIP", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TRELU DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOReluToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TReluOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src}; + rewriter.create( + loc, TypeRange{}, "TRELU", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TREM DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTORemToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRemOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value dst = peelUnrealized(adaptor.getDst()); + SmallVector operands{dst, src0, src1, tmp}; + rewriter.create( + loc, TypeRange{}, "TREM", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOFModToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TFModOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src0, src1}; + rewriter.create( + loc, TypeRange{}, "TFMOD", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TREMS DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTORemSToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRemSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value dst = peelUnrealized(adaptor.getDst()); + Value scalar = peelUnrealized(adaptor.getScalar()); + SmallVector operands{dst, src, scalar, tmp}; + rewriter.create( + loc, TypeRange{}, "TREMS", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOFModSToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TFModSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + Value scalar = peelUnrealized(adaptor.getScalar()); + + SmallVector operands{dst, src, scalar}; + rewriter.create( + loc, TypeRange{}, "TFMODS", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TROWEXPAND DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTORowExpandToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRowExpandOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src}; + rewriter.create( + loc, TypeRange{}, "TROWEXPAND", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTORowExpandAddToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRowExpandAddOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src0, src1}; + rewriter.create( + loc, TypeRange{}, "TROWEXPANDADD", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTORowExpandExpdifToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRowExpandExpdifOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value(); + + SmallVector operands; + if (tmp) + operands.assign({dst, src0, src1, tmp}); + else + operands.assign({dst, src0, src1}); + rewriter.create( + loc, TypeRange{}, "TROWEXPANDEXPDIF", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TROWEXPANDDIV DPS/memref op) +//===----------------------------------------------------------------------===// +// Helper: replace or erase based on whether op has results. +static void replaceOrEraseWithOpaqueCall(Operation *op, + StringRef callee, + ArrayRef args, + ConversionPatternRewriter &rewriter) { + TypeRange resultTypes = op->getResultTypes(); + auto call = rewriter.create( + op->getLoc(), resultTypes, callee, ArrayAttr{}, ArrayAttr{}, ValueRange(args)); + if (resultTypes.empty()) + rewriter.eraseOp(op); + else + rewriter.replaceOp(op, call.getResults()); +} + +static void replaceOrEraseWithOpaqueCallAndReturnDst(Operation *op, Value dst, + StringRef callee, + ArrayRef args, + ConversionPatternRewriter &rewriter) { + rewriter.create( + op->getLoc(), TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, ValueRange(args)); + if (op->getNumResults() == 1) + rewriter.replaceOp(op, dst); + else + rewriter.eraseOp(op); +} + +// ---------- TOp ---------- +struct PTOTGemvBiasToTGEMV_BIAS + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TGemvBiasOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value a = peelUnrealized(adaptor.getA()); + Value b = peelUnrealized(adaptor.getB()); + Value bias = peelUnrealized(adaptor.getBias()); + Value dst = peelUnrealized(adaptor.getDst()); + + replaceOrEraseWithOpaqueCall(op.getOperation(), "TGEMV_BIAS", + {dst, a, b, bias}, rewriter); + return success(); + } +}; + +struct PTOTGemvMXToTGEMV_MX + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TGemvMxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value a = peelUnrealized(adaptor.getA()); + Value aScale = peelUnrealized(adaptor.getAScale()); + Value b = peelUnrealized(adaptor.getB()); + Value bScale = peelUnrealized(adaptor.getBScale()); + Value dst = peelUnrealized(adaptor.getDst()); + + replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TGEMV_MX", + {dst, a, aScale, b, bScale}, rewriter); + return success(); + } +}; + +struct PTOTGemvMXAccToTGEMV_MX + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TGemvMxAccOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value cIn = peelUnrealized(adaptor.getCIn()); + Value a = peelUnrealized(adaptor.getA()); + Value aScale = peelUnrealized(adaptor.getAScale()); + Value b = peelUnrealized(adaptor.getB()); + Value bScale = peelUnrealized(adaptor.getBScale()); + Value dst = peelUnrealized(adaptor.getDst()); + + replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TGEMV_MX", + {dst, cIn, a, aScale, b, bScale}, rewriter); + return success(); + } +}; + +struct PTOTGemvMXBiasToTGEMV_MX + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TGemvMxBiasOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value a = peelUnrealized(adaptor.getA()); + Value aScale = peelUnrealized(adaptor.getAScale()); + Value b = peelUnrealized(adaptor.getB()); + Value bScale = peelUnrealized(adaptor.getBScale()); + Value bias = peelUnrealized(adaptor.getBias()); + Value dst = peelUnrealized(adaptor.getDst()); + + replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TGEMV_MX", + {dst, a, aScale, b, bScale, bias}, rewriter); + return success(); + } +}; + +struct PTOTMatmulBiasToTMATMUL_BIAS + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMatmulBiasOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value a = peelUnrealized(adaptor.getA()); + Value b = peelUnrealized(adaptor.getB()); + Value bias = peelUnrealized(adaptor.getBias()); + Value dst = peelUnrealized(adaptor.getDst()); + + replaceOrEraseWithOpaqueCall(op.getOperation(), "TMATMUL_BIAS", + {dst, a, b, bias}, rewriter); + return success(); + } +}; + +struct PTOTMatmulMXToTMATMUL_MX + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMatmulMxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value a = peelUnrealized(adaptor.getA()); + Value aScale = peelUnrealized(adaptor.getAScale()); + Value b = peelUnrealized(adaptor.getB()); + Value bScale = peelUnrealized(adaptor.getBScale()); + Value dst = peelUnrealized(adaptor.getDst()); + + replaceOrEraseWithOpaqueCall(op.getOperation(), "TMATMUL_MX", + {dst, a, aScale, b, bScale}, rewriter); + return success(); + } +}; + +struct PTOTMatmulMXAccToTMATMUL_MX_ACC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMatmulMxAccOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value cIn = peelUnrealized(adaptor.getCIn()); + Value a = peelUnrealized(adaptor.getA()); + Value aScale = peelUnrealized(adaptor.getAScale()); + Value b = peelUnrealized(adaptor.getB()); + Value bScale = peelUnrealized(adaptor.getBScale()); + Value dst = peelUnrealized(adaptor.getDst()); + + replaceOrEraseWithOpaqueCall(op.getOperation(), "TMATMUL_MX", + {dst, cIn, a, aScale, b, bScale}, rewriter); + return success(); + } +}; + +struct PTOTMatmulMXBiasToTMATMUL_MX_BIAS + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TMatmulMxBiasOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value a = peelUnrealized(adaptor.getA()); + Value aScale = peelUnrealized(adaptor.getAScale()); + Value b = peelUnrealized(adaptor.getB()); + Value bScale = peelUnrealized(adaptor.getBScale()); + Value bias = peelUnrealized(adaptor.getBias()); + Value dst = peelUnrealized(adaptor.getDst()); + + replaceOrEraseWithOpaqueCall(op.getOperation(), "TMATMUL_MX", + {dst, a, aScale, b, bScale, bias}, rewriter); + return success(); + } +}; + +struct PTORowExpandDivToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRowExpandDivOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value(); + + SmallVector operands; + if (tmp) + operands.assign({dst, src0, src1, tmp}); + else + operands.assign({dst, src0, src1}); + rewriter.create( + loc, TypeRange{}, "TROWEXPANDDIV", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TROWEXPANDMUL DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTORowExpandMulToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRowExpandMulOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value(); + + SmallVector operands; + if (tmp) + operands.assign({dst, src0, src1, tmp}); + else + operands.assign({dst, src0, src1}); + rewriter.create( + loc, TypeRange{}, "TROWEXPANDMUL", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TROWEXPANDSUB DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTORowExpandSubToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRowExpandSubOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value(); + + SmallVector operands; + if (tmp) + operands.assign({dst, src0, src1, tmp}); + else + operands.assign({dst, src0, src1}); + rewriter.create( + loc, TypeRange{}, "TROWEXPANDSUB", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTORowExpandMaxToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRowExpandMaxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value(); + + SmallVector operands; + if (tmp) + operands.assign({dst, src0, src1, tmp}); + else + operands.assign({dst, src0, src1}); + rewriter.create( + loc, TypeRange{}, "TROWEXPANDMAX", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTORowExpandMinToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRowExpandMinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value(); + + SmallVector operands; + if (tmp) + operands.assign({dst, src0, src1, tmp}); + else + operands.assign({dst, src0, src1}); + rewriter.create( + loc, TypeRange{}, "TROWEXPANDMIN", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TROWMAX DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTORowMaxToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRowMaxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src, tmp}; + rewriter.create( + loc, TypeRange{}, "TROWMAX", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTORowArgMaxToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRowArgMaxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TROWARGMAX", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src, tmp}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TROWMIN DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTORowMinToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRowMinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src, tmp}; + rewriter.create( + loc, TypeRange{}, "TROWMIN", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTORowArgMinToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRowArgMinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TROWARGMIN", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src, tmp}); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TROWSUM DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTORowSumToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRowSumOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src, tmp}; + rewriter.create( + loc, TypeRange{}, "TROWSUM", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTORowProdToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRowProdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src, tmp}; + rewriter.create( + loc, TypeRange{}, "TROWPROD", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TRSQRT DPS/memref op) +// - no-tmp form : TRSQRT(dst, src) +// - tmp form : TRSQRT(dst, src, tmp) +//===----------------------------------------------------------------------===// + +struct PTORsqrtToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRsqrtOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + SmallVector operands{dst, src}; + if (Value tmp = adaptor.getTmp()) + operands.push_back(peelUnrealized(tmp)); + rewriter.create( + loc, TypeRange{}, "TRSQRT", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TSCATTER DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOScatterToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TScatterOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + const bool hasMaskPattern = static_cast(op.getMaskPatternAttr()); + const bool hasIndexes = static_cast(op.getIndexes()); + if (hasMaskPattern == hasIndexes) { + return rewriter.notifyMatchFailure( + op, "expected exactly one of indexes operand or maskPattern attribute"); + } + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + if (auto mp = op.getMaskPatternAttr()) { + auto *ctx = rewriter.getContext(); + auto targs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, maskPatternTok(mp)), + }); + rewriter.create( + loc, TypeRange{}, "TSCATTER", + /*args=*/ArrayAttr{}, /*templateArgs=*/targs, + /*operands=*/ValueRange{dst, src}); + } else { + Value idx = peelUnrealized(adaptor.getIndexes()); + rewriter.create( + loc, TypeRange{}, "TSCATTER", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src, idx}); + } + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TSEL DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOSelToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TSelOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value mask = peelUnrealized(adaptor.getMask()); + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, mask, src0, src1, tmp}; + rewriter.create( + loc, TypeRange{}, "TSEL", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TSELS DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOSelSToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TSelSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value mask = peelUnrealized(adaptor.getMask()); + Value src = peelUnrealized(adaptor.getSrc()); + Value scalar = peelUnrealized(adaptor.getScalar()); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, mask, src, tmp, scalar}; + rewriter.create( + loc, TypeRange{}, "TSELS", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TSHL DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOShlSToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TShlOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src0, src1}; + rewriter.create( + loc, TypeRange{}, "TSHL", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TSHR DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOShrSToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TShrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src0, src1}; + rewriter.create( + loc, TypeRange{}, "TSHR", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering for TSHLS/TSHRS DPS: shift by scalar) +//===----------------------------------------------------------------------===// + +struct PTOShlSConstToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TShlSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value dst = peelUnrealized(adaptor.getDst()); + Value src = peelUnrealized(adaptor.getSrc()); + Value scalar = peelUnrealized(adaptor.getScalar()); + SmallVector operands{dst, src, scalar}; + rewriter.create( + loc, TypeRange{}, "TSHLS", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOShrSConstToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TShrSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value dst = peelUnrealized(adaptor.getDst()); + Value src = peelUnrealized(adaptor.getSrc()); + Value scalar = peelUnrealized(adaptor.getScalar()); + SmallVector operands{dst, src, scalar}; + rewriter.create( + loc, TypeRange{}, "TSHRS", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (TSORT32 DPS/memref op: ins(src, idx[, tmp]) outs(dst)) +//===----------------------------------------------------------------------===// + +struct PTOSORT32SToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TSort32Op op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + Value idx = peelUnrealized(adaptor.getIdx()); + Value tmp = op.getTmp() ? peelUnrealized(adaptor.getTmp()) : Value(); + + SmallVector operands; + if (tmp) + operands.assign({dst, src, idx, tmp}); + else + operands.assign({dst, src, idx}); + rewriter.create( + loc, TypeRange{}, "TSORT32", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TSQRT DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOSqrtSToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TSqrtOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src}; + rewriter.create( + loc, TypeRange{}, "TSQRT", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TSTORE_FP DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOStoreFPSToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TStoreFPOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value fp = peelUnrealized(adaptor.getFp()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src, fp}; + rewriter.create( + loc, TypeRange{}, "TSTORE_FP", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TSUB DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOSubSToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TSubOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src0, src1}; + rewriter.create( + loc, TypeRange{}, "TSUB", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TSUBC DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOSubCSToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TSubCOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value src2 = peelUnrealized(adaptor.getSrc2()); + Value dst = peelUnrealized(adaptor.getDst()); + + // pto-isa does not provide NPU implementation for TSUBC yet. + // Decompose: dst = src0 - src1 + src2 + rewriter.create( + loc, TypeRange{}, "TSUB", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, src1}); + rewriter.create( + loc, TypeRange{}, "TADD", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, dst, src2}); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TSUBS DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOSubSSToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TSubSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value scalar = peelUnrealized(adaptor.getScalar()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src, scalar}; + rewriter.create( + loc, TypeRange{}, "TSUBS", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TSUBSC DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOSubSCToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TSubSCOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value scalar = peelUnrealized(adaptor.getScalar()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + + // pto-isa does not provide NPU implementation for TSUBSC yet. + // Decompose: dst = src0 - scalar + src1 + rewriter.create( + loc, TypeRange{}, "TSUBS", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src0, scalar}); + rewriter.create( + loc, TypeRange{}, "TADD", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, dst, src1}); + + rewriter.eraseOp(op); + return success(); + } +}; + + +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TXOR DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOXORToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TXorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src0 = peelUnrealized(adaptor.getSrc0()); + Value src1 = peelUnrealized(adaptor.getSrc1()); + Value dst = peelUnrealized(adaptor.getDst()); + Value tmp = peelUnrealized(adaptor.getTmp()); + SmallVector operands{dst, src0, src1, tmp}; + rewriter.create( + loc, TypeRange{}, "TXOR", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +struct PTOTTransToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TTransOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src, tmp}; + rewriter.create( + loc, TypeRange{}, "TTRANS", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +//===----------------------------------------------------------------------===// +// PTOConvert.cpp (add lowering + patterns.add for TXORS DPS/memref op) +//===----------------------------------------------------------------------===// + +struct PTOXORSToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TXorSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value scalar = peelUnrealized(adaptor.getScalar()); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value dst = peelUnrealized(adaptor.getDst()); + + SmallVector operands{dst, src, scalar, tmp}; + rewriter.create( + loc, TypeRange{}, "TXORS", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; +struct PTOPrintToTPRINT : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TPrintOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + + SmallVector operands{src}; + rewriter.create( + loc, TypeRange{}, "TPRINT", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/operands); + + rewriter.eraseOp(op); + return success(); + } +}; + +// pto.print "format", %scalar -> PRINTF("format", scalar) + +} // namespace + +void populatePTOToEmitCTileExtraPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter, + MLIRContext *ctx) { + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add< + PTOTMatmulBiasToTMATMUL_BIAS, + PTOTMatmulMXToTMATMUL_MX, + PTOTMatmulMXAccToTMATMUL_MX_ACC, + PTOTMatmulMXBiasToTMATMUL_MX_BIAS, + PTOTGemvBiasToTGEMV_BIAS, + PTOTGemvMXToTGEMV_MX, + PTOTGemvMXAccToTGEMV_MX, + PTOTGemvMXBiasToTGEMV_MX>(typeConverter, ctx); +} + +} // namespace mlir::pto diff --git a/lib/PTO/Transforms/PTOViewToMemref.cpp b/lib/PTO/Transforms/PTOViewToMemref.cpp index c21669b81..63f2c8687 100644 --- a/lib/PTO/Transforms/PTOViewToMemref.cpp +++ b/lib/PTO/Transforms/PTOViewToMemref.cpp @@ -15,6 +15,7 @@ #include "PTO/IR/PTO.h" #include "PTO/IR/PTOTypeUtils.h" #include "PTO/Transforms/Passes.h" +#include "PTOViewToMemrefInternal.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -98,11 +99,6 @@ constexpr int32_t kSLayoutColMajor = constexpr int32_t kCompactModeRowPlusOne = static_cast(CompactMode::RowPlusOne); -constexpr unsigned kThirdOperandIndex = 2; -constexpr unsigned kFourthOperandIndex = 3; -constexpr unsigned kFifthOperandIndex = 4; -constexpr unsigned kSixthOperandIndex = 5; - template using SmallInlineVector = SmallVector; @@ -1804,1781 +1800,12 @@ struct PTOViewToMemrefPass // ------------------------------------------------------------------ // Stage 3: Rewrite Compute Ops - // [关键] 全面使用 op->getOperand(i) 避免 Typed Accessor Crash // ------------------------------------------------------------------ - - // --- TLoadOp [Src, Dst] --- - DefaultInlineVector loads; - func.walk([&](mlir::pto::TLoadOp op) { loads.push_back(op); }); - for (auto op : loads) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op->getOperand(0); - Value dst = op->getOperand(1); - - auto newOp = - rewriter.create(op.getLoc(), TypeRange{}, src, dst); - newOp->setAttrs(op->getAttrs()); - rewriter.replaceOp(op, newOp->getResults()); - } - - // --- TStoreOp [Src, Dst] --- - DefaultInlineVector storeops; - func.walk([&](mlir::pto::TStoreOp op) { storeops.push_back(op); }); - for (auto op : storeops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op->getOperand(0); - Value dst = op->getOperand(1); - Value preQuant = op.getPreQuantScalar(); - - pto::TStoreOp newOp; - if (preQuant) { - newOp = rewriter.create(op.getLoc(), TypeRange{}, - src, dst, preQuant); - } else { - newOp = rewriter.create(op.getLoc(), TypeRange{}, - src, dst, Value{}); - } - newOp->setAttrs(op->getAttrs()); - rewriter.replaceOp(op, newOp->getResults()); - } - - // --- TTransOp [Src, Tmp, Dst] --- - DefaultInlineVector trans; - func.walk([&](mlir::pto::TTransOp op) { trans.push_back(op); }); - for (auto op : trans) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex)); - } - - // --- TExpOp [Src, Dst] --- - DefaultInlineVector exp; - func.walk([&](mlir::pto::TExpOp op) { exp.push_back(op); }); - for (auto op : exp) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, op->getOperand(0), op->getOperand(1)); - } - - // --- TMulOp [Src, Scalar, Dst] --- - DefaultInlineVector mul; - func.walk([&](mlir::pto::TMulOp op) { mul.push_back(op); }); - for (auto op : mul) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, op->getOperand(0), op.getOperand(1), - op->getOperand(kThirdOperandIndex)); - } - - // --- TMulSOp [Src, Scalar, Dst] --- - DefaultInlineVector muls; - func.walk([&](mlir::pto::TMulSOp op) { muls.push_back(op); }); - for (auto op : muls) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, op->getOperand(0), op.getScalar(), - op->getOperand(kThirdOperandIndex)); - } - - // --- TAddOp [Src0, Src1, Dst] --- - DefaultInlineVector addops; - func.walk([&](mlir::pto::TAddOp op) { addops.push_back(op); }); - for (auto op : addops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex)); - } - - // --- TMatmulOp [Lhs, Rhs, Dst] (no optional bias in ODS) --- - DefaultInlineVector matmuls; - func.walk([&](mlir::pto::TMatmulOp op) { matmuls.push_back(op); }); - for (auto op : matmuls) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - Value lhs = op->getOperand(0); - Value rhs = op->getOperand(1); - Value dst = op->getOperand(kThirdOperandIndex); - - rewriter.replaceOpWithNewOp( - op, TypeRange{}, lhs, rhs, dst, op.getAccPhaseAttr()); - } - - // --- TMatmulAccOp [Acc, Lhs, Rhs, Dst] --- - DefaultInlineVector matmulAccs; - func.walk([&](mlir::pto::TMatmulAccOp op) { matmulAccs.push_back(op); }); - for (auto op : matmulAccs) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex), op.getAccPhaseAttr()); - } - - // --- TMatmulBiasOp [Acc, Lhs, Rhs, Bias, Dst] --- - DefaultInlineVector matmulBiass; - func.walk([&](mlir::pto::TMatmulBiasOp op) { matmulBiass.push_back(op); }); - for (auto op : matmulBiass) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex)); - } - - // --- TMatmulMxOp--- - DefaultInlineVector matmulMxs; - func.walk([&](mlir::pto::TMatmulMxOp op) { matmulMxs.push_back(op); }); - for (auto op : matmulMxs) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex), - op->getOperand(kFifthOperandIndex)); - } - - // --- TMatmulMxAccOp --- - DefaultInlineVector matmulMxAccs; - func.walk([&](mlir::pto::TMatmulMxAccOp op) { matmulMxAccs.push_back(op); }); - for (auto op : matmulMxAccs) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex), - op->getOperand(kFifthOperandIndex), - op->getOperand(kSixthOperandIndex)); - } - - // --- TMatmulMxBiasOp --- - DefaultInlineVector matmulMxBiass; - func.walk([&](mlir::pto::TMatmulMxBiasOp op) { matmulMxBiass.push_back(op); }); - for (auto op : matmulMxBiass) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex), - op->getOperand(kFifthOperandIndex), - op->getOperand(kSixthOperandIndex)); - } - - // --- TGemvOp [Lhs, Rhs, Dst] --- - DefaultInlineVector gemvs; - func.walk([&](mlir::pto::TGemvOp op) { gemvs.push_back(op); }); - for (auto op : gemvs) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value lhs = op->getOperand(0); - Value rhs = op->getOperand(1); - Value dst = op->getOperand(kThirdOperandIndex); - - rewriter.replaceOpWithNewOp( - op, TypeRange{}, lhs, rhs, dst); - } - - // --- TGemvAccOp [Acc, Lhs, Rhs, Dst] --- - DefaultInlineVector gemvAccs; - func.walk([&](mlir::pto::TGemvAccOp op) { gemvAccs.push_back(op); }); - for (auto op : gemvAccs) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex)); - } - - // --- TGemvBiasOp [Acc, Lhs, Rhs, Bias, Dst] --- - DefaultInlineVector gemvBiass; - func.walk([&](mlir::pto::TGemvBiasOp op) { gemvBiass.push_back(op); }); - for (auto op : gemvBiass) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex)); - } - - // --- TGemvMxOp [A, AScale, B, BScale, Dst] --- - DefaultInlineVector gemvMxs; - func.walk([&](mlir::pto::TGemvMxOp op) { gemvMxs.push_back(op); }); - for (auto op : gemvMxs) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex), - op->getOperand(kFifthOperandIndex)); - } - - // --- TGemvMxAccOp [CIn, A, AScale, B, BScale, Dst] --- - DefaultInlineVector gemvMxAccs; - func.walk([&](mlir::pto::TGemvMxAccOp op) { gemvMxAccs.push_back(op); }); - for (auto op : gemvMxAccs) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex), - op->getOperand(kFifthOperandIndex), - op->getOperand(kSixthOperandIndex)); - } - - // --- TGemvMxBiasOp [A, AScale, B, BScale, Bias, Dst] --- - DefaultInlineVector gemvMxBiass; - func.walk([&](mlir::pto::TGemvMxBiasOp op) { gemvMxBiass.push_back(op); }); - for (auto op : gemvMxBiass) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), op->getOperand(1), - op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex), - op->getOperand(kFifthOperandIndex), - op->getOperand(kSixthOperandIndex)); - } - - // --- TMovOp [Src, Dst] --- - DefaultInlineVector movs; - func.walk([&](mlir::pto::TMovOp op) { movs.push_back(op); }); - for (auto op : movs) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, op.getSrc(), op.getDst(), op.getFp(), - op.getPreQuantScalar(), op.getAccToVecModeAttr(), - op.getReluPreModeAttr()); - } - - DefaultInlineVector abseops; - func.walk([&](mlir::pto::TAbsOp op) { abseops.push_back(op); }); - - for (auto op : abseops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst); - } - - DefaultInlineVector addcops; - func.walk([&](mlir::pto::TAddCOp op) { addcops.push_back(op); }); - - for (auto op : addcops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value src2 = op.getSrc2(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto src2Ty = dyn_cast(src2.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !src2Ty ||!dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - src1, - src2, - dst); - } - - DefaultInlineVector addsops; - func.walk([&](mlir::pto::TAddSOp op) { addsops.push_back(op); }); - - for (auto op : addsops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value scalar = op.getScalar(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - scalar, - dst); - } - - DefaultInlineVector addscops; - func.walk([&](mlir::pto::TAddSCOp op) { addscops.push_back(op); }); - - for (auto op : addscops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value scalar = op.getScalar(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - scalar, - src1, - dst); - } - - DefaultInlineVector andops; - func.walk([&](mlir::pto::TAndOp op) { andops.push_back(op); }); - - for (auto op : andops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - src1, - dst); - } - - DefaultInlineVector concats; - func.walk([&](mlir::pto::TConcatOp op) { concats.push_back(op); }); - - for (auto op : concats) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - src1, - dst); - } - - DefaultInlineVector concatIdxs; - func.walk([&](mlir::pto::TConcatidxOp op) { concatIdxs.push_back(op); }); - - IRRewriter rewriter(ctx); - for (auto op : concatIdxs) { - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value src0Idx = op.getSrc0Idx(); - Value src1Idx = op.getSrc1Idx(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto src0IdxTy = dyn_cast(src0Idx.getType()); - auto src1IdxTy = dyn_cast(src1Idx.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !src0IdxTy || !src1IdxTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - src1, - src0Idx, - src1Idx, - dst); - } - - DefaultInlineVector andsops; - func.walk([&](mlir::pto::TAndSOp op) { andsops.push_back(op); }); - - for (auto op : andsops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value scalar = op.getScalar(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - scalar, - dst); - } - - DefaultInlineVector ciops; - func.walk([&](mlir::pto::TCIOp op) { ciops.push_back(op); }); - - for (auto op : ciops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value s = op->getOperand(0); - Value dst = op.getDst(); - bool descending = op.getDescending(); - - auto sTy = dyn_cast(s.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!sTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - s, - dst, - descending); - } - - DefaultInlineVector cmpops; - func.walk([&](mlir::pto::TCmpOp op) { cmpops.push_back(op); }); - - for (auto op : cmpops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - auto newOp = rewriter.create( - op.getLoc(), - TypeRange{}, - src0, - src1, - dst); - - if (auto a = op.getCmpModeAttr()) - newOp->setAttr("cmpMode", a); - - rewriter.replaceOp(op, newOp->getResults()); // 0 results -> OK - } - - DefaultInlineVector cmpsops; - func.walk([&](mlir::pto::TCmpSOp op) { cmpsops.push_back(op); }); - - for (auto op : cmpsops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value scalar = op.getScalar(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - auto scalarTy = scalar.getType(); - bool scalarOk = - isa(scalarTy); // ScalarType in ODS: int/float - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - if (!scalarOk) { - op.emitError("expects scalar to be an integer or float type"); - signalPassFailure(); - return; - } - - auto cmpMode = op.getCmpModeAttr(); - auto newOp = rewriter.create( - op.getLoc(), - TypeRange{}, - src, - scalar, - cmpMode, - dst); - - rewriter.replaceOp(op, newOp->getResults()); // 0 results -> OK - } - - DefaultInlineVector colexpand; - func.walk([&](mlir::pto::TColExpandOp op) { colexpand.push_back(op); }); - - for (auto op : colexpand) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if ( !srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst); - } - - DefaultInlineVector colmaxops; - func.walk([&](mlir::pto::TColMaxOp op) { colmaxops.push_back(op); }); - - for (auto op : colmaxops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if ( !srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst); - } - - DefaultInlineVector colminops; - func.walk([&](mlir::pto::TColMinOp op) { colminops.push_back(op); }); - - for (auto op : colminops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if ( !srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst); - } - - DefaultInlineVector colexpandmulops; - func.walk([&](mlir::pto::TColExpandMulOp op) { - colexpandmulops.push_back(op); - }); - - for (auto op : colexpandmulops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - src1, - dst); - } - - DefaultInlineVector colexpandmaxops; - func.walk([&](mlir::pto::TColExpandMaxOp op) { - colexpandmaxops.push_back(op); - }); - - for (auto op : colexpandmaxops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - src1, - dst); - } - - DefaultInlineVector colexpandminops; - func.walk([&](mlir::pto::TColExpandMinOp op) { - colexpandminops.push_back(op); - }); - - for (auto op : colexpandminops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - src1, - dst); - } - - DefaultInlineVector colsumops; - func.walk([&](mlir::pto::TColSumOp op) { colsumops.push_back(op); }); - - for (auto op : colsumops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - Value tmp = op.getTmp(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !dstTy) { - op.emitError("src/dst are not memref yet"); - signalPassFailure(); - return; - } - - // If tmp exists, it must have isBinary attribute - if (tmp) { - auto tmpTy = dyn_cast(tmp.getType()); - if (!tmpTy) { - op.emitError("tmp is not memref yet"); - signalPassFailure(); - return; - } - - // Get isBinary attribute (should exist if tmp exists) - BoolAttr isBinaryAttr = op.getIsBinaryAttr(); - if (!isBinaryAttr) { - isBinaryAttr = BoolAttr::get(ctx, false); - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - tmp, - dst, - isBinaryAttr); - } else { - // Format 1: no tmp, no isBinary - // Use generic builder to avoid adding default isBinary attribute - SmallVector operands = {src, dst}; - SmallVector attrs; - // Copy all attributes except isBinary - for (auto attr : op->getAttrs()) { - if (attr.getName() != "isBinary") { - attrs.push_back(attr); - } - } - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - operands, - attrs); - } - } - - DefaultInlineVector cvtops; - func.walk([&](mlir::pto::TCvtOp op) { cvtops.push_back(op); }); - - for (auto op : cvtops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - auto rmodeAttr = op.getRmodeAttr(); // PTO_RoundModeAttr - auto satModeAttr = op.getSatModeAttr(); - - auto newOp = rewriter.create( - op.getLoc(), - TypeRange{}, - src, - dst, - rmodeAttr, - satModeAttr); - - rewriter.replaceOp(op, newOp->getResults()); - } - - DefaultInlineVector divops; - func.walk([&](mlir::pto::TDivOp op) { divops.push_back(op); }); - - for (auto op : divops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - src1, - dst); - } - - DefaultInlineVector divsops; - func.walk([&](mlir::pto::TDivSOp op) { divsops.push_back(op); }); - - for (auto op : divsops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value scale = op.getScalar(); - Value dst = op.getDst(); - - // Check types - they might still be TileBufType or already converted to MemRefType - auto srcTy = dyn_cast(src.getType()); - auto srcTileTy = dyn_cast(src.getType()); - auto scaleTileTy = dyn_cast(scale.getType()); - auto dstTy = dyn_cast(dst.getType()); - auto dstTileTy = dyn_cast(dst.getType()); - - // Determine which operand is tile-like and which is scalar-like. - // Keep the original operand order (set by parser textual form). - // Check if src is memref/tensor/tile (not scalar) - bool srcIsMemref = (srcTy != nullptr || srcTileTy != nullptr || - isa(src.getType()) || - isa(src.getType())); - // Check if scale is memref/tensor/tile (not scalar) - bool scaleIsMemref = (isa(scale.getType()) || - scaleTileTy != nullptr || - isa(scale.getType()) || - isa(scale.getType())); - - // Type validation - ensure we have the right types - if (!srcIsMemref && !scaleIsMemref) { - op.emitError("at least one operand (src or scale) must be tile_buf or memref"); - signalPassFailure(); - return; - } - if (srcIsMemref && scaleIsMemref) { - op.emitError("exactly one operand (src or scale) must be tile_buf or memref, the other must be scalar"); - signalPassFailure(); - return; - } - - if (!dstTy && !dstTileTy) { - op.emitError("dst operand must be tile_buf or memref"); - signalPassFailure(); - return; - } - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - scale, - dst); - } - - DefaultInlineVector expandsops; - func.walk([&](mlir::pto::TExpandsOp op) { expandsops.push_back(op); }); - - for (auto op : expandsops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value scalar = op.getScalar(); - Value dst = op.getDst(); - - auto dstTy = dyn_cast(dst.getType()); - if (!dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - scalar, - dst); - } - - DefaultInlineVector extractops; - func.walk([&](mlir::pto::TExtractOp op) { extractops.push_back(op); }); - - for (auto op : extractops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value indexRow = op.getIndexRow(); - Value indexCol = op.getIndexCol(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto indexRowTy = dyn_cast(indexRow.getType()); - auto indexColTy = dyn_cast(indexCol.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !indexRowTy || !indexColTy || !dstTy) { - op.emitError("ins/outs are not correct yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - indexRow, - indexCol, - dst); - } - - DefaultInlineVector fillpadops; - func.walk([&](mlir::pto::TFillPadOp op) { fillpadops.push_back(op); }); - - for (auto op : fillpadops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst); - } - - DefaultInlineVector fillpadInplaceOps; - func.walk( - [&](mlir::pto::TFillPadInplaceOp op) { fillpadInplaceOps.push_back(op); }); - - for (auto op : fillpadInplaceOps) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst); - } - - // --- TSetValOp [Dst, Offset, Val] --- - // Lower tile-world scalar write to memref-world SETVAL DPS op. - DefaultInlineVector tsetvalops; - func.walk([&](mlir::pto::TSetValOp op) { tsetvalops.push_back(op); }); - - for (auto op : tsetvalops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value dst = op.getDst(); - Value offset = op.getOffset(); - Value val = op.getVal(); - - auto dstTy = dyn_cast(dst.getType()); - if (!dstTy) { - op.emitError("dst is not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - dst, - offset, - val); - } - - // --- TGetValOp [Src, Offset] -> Scalar --- - // Lower tile-world scalar read to memref-world GETVAL DPS op. - DefaultInlineVector tgetvalops; - func.walk([&](mlir::pto::TGetValOp op) { tgetvalops.push_back(op); }); - - for (auto op : tgetvalops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value offset = op.getOffset(); - Type dstType = op.getDst().getType(); - - auto srcTy = dyn_cast(src.getType()); - if (!srcTy) { - op.emitError("src is not memref yet"); - signalPassFailure(); - return; - } - - auto newOp = rewriter.create( - op.getLoc(), - dstType, - src, - offset); - rewriter.replaceOp(op, newOp.getDst()); - } - - DefaultInlineVector gatherops; - func.walk([&](mlir::pto::TGatherOp op) { gatherops.push_back(op); }); - - for (auto op : gatherops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - Value cdst = op.getCdst(); - Value indices = op.getIndices(); - Value tmp = op.getTmp(); - Value kValue = op.getKValue(); - auto maskPattern = op.getMaskPatternAttr(); - auto cmpMode = op.getCmpModeAttr(); - auto offset = op.getOffsetAttr(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - if (maskPattern) { - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst, - /*cdst=*/Value(), - /*indices=*/Value(), - /*tmp=*/Value(), - /*kValue=*/Value(), - /*maskPattern=*/maskPattern, - /*cmpMode=*/pto::CmpModeAttr(), - /*offset=*/IntegerAttr()); - continue; - } - - if (cdst || kValue) { - auto cdstTy = dyn_cast(cdst.getType()); - auto tmpTy = dyn_cast(tmp.getType()); - if (!cdstTy || !tmpTy) { - op.emitError("compare-form tgather expects cdst/tmp to be memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst, - cdst, - /*indices=*/Value(), - tmp, - kValue, - /*maskPattern=*/pto::MaskPatternAttr(), - cmpMode, - offset); - continue; - } - - if (indices || tmp) { - auto indicesTy = dyn_cast(indices.getType()); - auto tmpTy = dyn_cast(tmp.getType()); - if (!indicesTy || !tmpTy) { - op.emitError("index-form tgather expects indices/tmp to be memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst, - /*cdst=*/Value(), - indices, - tmp, - /*kValue=*/Value(), - /*maskPattern=*/pto::MaskPatternAttr(), - /*cmpMode=*/pto::CmpModeAttr(), - /*offset=*/IntegerAttr()); - continue; - } - - op.emitError("expects tgather to be in mask, index+tmp, or compare+tmp form"); + if (failed(lowerViewToMemrefComputeOps(func, ctx))) { signalPassFailure(); return; } - DefaultInlineVector gatherbops; - func.walk([&](mlir::pto::TGatherBOp op) { gatherbops.push_back(op); }); - - for (auto op : gatherbops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value offsets = op.getOffsets(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto offsetsTy = dyn_cast(offsets.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !offsetsTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - offsets, - dst); - } - - DefaultInlineVector logops; - func.walk([&](mlir::pto::TLogOp op) { logops.push_back(op); }); - - for (auto op : logops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst); - } - - DefaultInlineVector lreluops; - func.walk([&](mlir::pto::TLReluOp op) { lreluops.push_back(op); }); - - for (auto op : lreluops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value slope = op.getSlope(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto slopeTy = dyn_cast(slope.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !slopeTy || !dstTy) { - op.emitError("ins/outs are not correct type yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - slope, - dst); - } - - DefaultInlineVector maxops; - func.walk([&](mlir::pto::TMaxOp op) { maxops.push_back(op); }); - - for (auto op : maxops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - src1, - dst); - } - - DefaultInlineVector maxsops; - func.walk([&](mlir::pto::TMaxSOp op) { maxsops.push_back(op); }); - - for (auto op : maxsops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value scalar = op.getScalar(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - bool scalarIsScalar = isa(scalar.getType()); - if (!srcTy || !scalarIsScalar || !dstTy) { - op.emitError("expects src/dst to be memref and scalar to be integer/float"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - scalar, - dst); - } - - DefaultInlineVector minops; - func.walk([&](mlir::pto::TMinOp op) { minops.push_back(op); }); - - for (auto op : minops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - src1, - dst); - } - - DefaultInlineVector minsops; - func.walk([&](mlir::pto::TMinSOp op) { minsops.push_back(op); }); - - for (auto op : minsops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value scalar = op.getScalar(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - bool scalarIsScalar = isa(scalar.getType()); - if (!srcTy || !scalarIsScalar || !dstTy) { - op.emitError("expects src/dst to be memref and scalar to be integer/float"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - scalar, - dst); - } - - DefaultInlineVector movfpops; - func.walk([&](mlir::pto::TMovFPOp op) { movfpops.push_back(op); }); - - for (auto op : movfpops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value fp = op.getFp(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto fpTy = dyn_cast(fp.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !fpTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - fp, - dst); - } - - DefaultInlineVector quantops; - func.walk([&](mlir::pto::TQuantOp op) { quantops.push_back(op); }); - - for (auto op : quantops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value fp = op.getFp(); - Value offset = op.getOffset(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto fpTy = dyn_cast(fp.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !fpTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - if (offset && !dyn_cast(offset.getType())) { - op.emitError("offset is not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - fp, - offset, - dst, - op.getQuantTypeAttr()); - } - - DefaultInlineVector mrgsortops; - func.walk([&](mlir::pto::TMrgSortOp op) { mrgsortops.push_back(op); }); - - for (auto op : mrgsortops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - if (op.isFormat1()) { - Value src = op.getSrc(); - Value dst = op.getDst(); - Value blockLenVal = op.getBlockLen(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - ValueRange{src}, - blockLenVal, - ValueRange{dst}, - Value() /*tmp*/, - Value() /*excuted*/, - op.getExhaustedAttr()); - } else if (op.isFormat2()) { - bool allMemRef = true; - for (Value v : op.getSrcs()) - if (!dyn_cast(v.getType())) { allMemRef = false; break; } - if (!allMemRef) { - op.emitError("format2 ins/outs are not memref yet"); - signalPassFailure(); - return; - } - if (op.getDsts().size() != 1u || !op.getTmp()) { - op.emitError("format2 expects outs(dst) and ins(tmp)"); - signalPassFailure(); - return; - } - - Value dst = op.getDst(); - Value tmp = op.getTmp(); - Value excuted = op.getExcuted(); - if (!dyn_cast(dst.getType()) || !dyn_cast(tmp.getType())) { - op.emitError("format2 dst/tmp must be memref"); - signalPassFailure(); - return; - } - if (!dyn_cast(excuted.getType())) { - op.emitError("format2 outs(excuted) must be vector"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - op.getSrcs(), - Value() /*blockLen*/, - ValueRange{dst}, - tmp, - excuted, - op.getExhaustedAttr()); - } else { - op.emitError("tmrgsort must be format1 or format2"); - signalPassFailure(); - return; - } - } - - DefaultInlineVector negops; - func.walk([&](mlir::pto::TNegOp op) { negops.push_back(op); }); - - for (auto op : negops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst); - } - - DefaultInlineVector notops; - func.walk([&](mlir::pto::TNotOp op) { notops.push_back(op); }); - - for (auto op : notops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst); - } - - DefaultInlineVector orops; - func.walk([&](mlir::pto::TOrOp op) { orops.push_back(op); }); - - for (auto op : orops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - src0, - src1, - dst); - } - - DefaultInlineVector orsops; - func.walk([&](mlir::pto::TOrSOp op) { orsops.push_back(op); }); - - for (auto op : orsops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value scalar = op.getScalar(); - Value dst = op.getDst(); - - auto srcTy = dyn_cast(src.getType()); - auto scalarTy = dyn_cast(scalar.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!srcTy || !scalarTy || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - scalar, - dst); - } - - DefaultInlineVector partaddops; - func.walk([&](mlir::pto::TPartAddOp op) { partaddops.push_back(op); }); - - for (auto op : partaddops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - src0, - src1, - dst); - } - - DefaultInlineVector partmulops; - func.walk([&](mlir::pto::TPartMulOp op) { partmulops.push_back(op); }); - - for (auto op : partmulops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src0 = op.getSrc0(); - Value src1 = op.getSrc1(); - Value dst = op.getDst(); - - auto src0Ty = dyn_cast(src0.getType()); - auto src1Ty = dyn_cast(src1.getType()); - auto dstTy = dyn_cast(dst.getType()); - if (!src0Ty || !src1Ty || !dstTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - src0, - src1, - dst); - } - - DefaultInlineVector mgatherops; - func.walk([&](mlir::pto::MGatherOp op) { mgatherops.push_back(op); }); - - for (auto op : mgatherops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value dst = op.getDst(); - Value idx = op.getIdx(); - Value mem = op.getMem(); - - auto dstTy = dyn_cast(dst.getType()); - auto idxTy = dyn_cast(idx.getType()); - auto memTy = dyn_cast(mem.getType()); - if (!dstTy || !idxTy || !memTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - mem, - idx, - dst, - op.getGatherOobAttr()); - } - - DefaultInlineVector mascatterops; - func.walk([&](mlir::pto::MScatterOp op) { mascatterops.push_back(op); }); - - for (auto op : mascatterops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - Value idx = op.getIdx(); - Value mem = op.getMem(); - - auto srcTy = dyn_cast(src.getType()); - auto idxTy = dyn_cast(idx.getType()); - auto memTy = dyn_cast(mem.getType()); - if (!srcTy || !idxTy || !memTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - idx, - mem, - op.getScatterAtomicOpAttr(), - op.getScatterOobAttr()); - } - DefaultInlineVector printops; - func.walk([&](mlir::pto::TPrintOp op) { printops.push_back(op); }); - - for (auto op : printops) { - IRRewriter rewriter(ctx); - rewriter.setInsertionPoint(op); - - Value src = op.getSrc(); - - auto srcTy = dyn_cast(src.getType()); - if (!srcTy) { - op.emitError("ins/outs are not memref yet"); - signalPassFailure(); - return; - } - - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src); - } - // ------------------------------------------------------------------ // Stage 4: Reconcile control-flow result types // ------------------------------------------------------------------ diff --git a/lib/PTO/Transforms/PTOViewToMemrefCompute.cpp b/lib/PTO/Transforms/PTOViewToMemrefCompute.cpp new file mode 100644 index 000000000..47558fda0 --- /dev/null +++ b/lib/PTO/Transforms/PTOViewToMemrefCompute.cpp @@ -0,0 +1,1760 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- PTOViewToMemrefCompute.cpp ----------------------------------------===// +//===----------------------------------------------------------------------===// + +#include "PTOViewToMemrefInternal.h" + +#include "PTO/IR/PTO.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Support/LLVM.h" + +using namespace mlir; + +namespace mlir::pto { + +namespace { + +template +using DefaultInlineVector = SmallVector; + +constexpr unsigned kThirdOperandIndex = 2; +constexpr unsigned kFourthOperandIndex = 3; +constexpr unsigned kFifthOperandIndex = 4; +constexpr unsigned kSixthOperandIndex = 5; + +} // namespace + +LogicalResult lowerViewToMemrefComputeOps(func::FuncOp func, MLIRContext *ctx) { +// ------------------------------------------------------------------ +// Stage 3: Rewrite Compute Ops +// [关键] 全面使用 op->getOperand(i) 避免 Typed Accessor Crash +// ------------------------------------------------------------------ + +// --- TLoadOp [Src, Dst] --- +DefaultInlineVector loads; +func.walk([&](mlir::pto::TLoadOp op) { loads.push_back(op); }); +for (auto op : loads) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op->getOperand(0); + Value dst = op->getOperand(1); + + auto newOp = + rewriter.create(op.getLoc(), TypeRange{}, src, dst); + newOp->setAttrs(op->getAttrs()); + rewriter.replaceOp(op, newOp->getResults()); +} + +// --- TStoreOp [Src, Dst] --- +DefaultInlineVector storeops; +func.walk([&](mlir::pto::TStoreOp op) { storeops.push_back(op); }); +for (auto op : storeops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op->getOperand(0); + Value dst = op->getOperand(1); + Value preQuant = op.getPreQuantScalar(); + + pto::TStoreOp newOp; + if (preQuant) { + newOp = rewriter.create(op.getLoc(), TypeRange{}, + src, dst, preQuant); + } else { + newOp = rewriter.create(op.getLoc(), TypeRange{}, + src, dst, Value{}); + } + newOp->setAttrs(op->getAttrs()); + rewriter.replaceOp(op, newOp->getResults()); +} + + // --- TTransOp [Src, Tmp, Dst] --- +DefaultInlineVector trans; +func.walk([&](mlir::pto::TTransOp op) { trans.push_back(op); }); +for (auto op : trans) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex)); +} + +// --- TExpOp [Src, Dst] --- +DefaultInlineVector exp; +func.walk([&](mlir::pto::TExpOp op) { exp.push_back(op); }); +for (auto op : exp) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, op->getOperand(0), op->getOperand(1)); +} + +// --- TMulOp [Src, Scalar, Dst] --- +DefaultInlineVector mul; +func.walk([&](mlir::pto::TMulOp op) { mul.push_back(op); }); +for (auto op : mul) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, op->getOperand(0), op.getOperand(1), + op->getOperand(kThirdOperandIndex)); +} + +// --- TMulSOp [Src, Scalar, Dst] --- +DefaultInlineVector muls; +func.walk([&](mlir::pto::TMulSOp op) { muls.push_back(op); }); +for (auto op : muls) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, op->getOperand(0), op.getScalar(), + op->getOperand(kThirdOperandIndex)); +} + +// --- TAddOp [Src0, Src1, Dst] --- +DefaultInlineVector addops; +func.walk([&](mlir::pto::TAddOp op) { addops.push_back(op); }); +for (auto op : addops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex)); +} + +// --- TMatmulOp [Lhs, Rhs, Dst] (no optional bias in ODS) --- +DefaultInlineVector matmuls; +func.walk([&](mlir::pto::TMatmulOp op) { matmuls.push_back(op); }); +for (auto op : matmuls) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Value lhs = op->getOperand(0); + Value rhs = op->getOperand(1); + Value dst = op->getOperand(kThirdOperandIndex); + + rewriter.replaceOpWithNewOp( + op, TypeRange{}, lhs, rhs, dst, op.getAccPhaseAttr()); +} + +// --- TMatmulAccOp [Acc, Lhs, Rhs, Dst] --- +DefaultInlineVector matmulAccs; +func.walk([&](mlir::pto::TMatmulAccOp op) { matmulAccs.push_back(op); }); +for (auto op : matmulAccs) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex), + op->getOperand(kFourthOperandIndex), op.getAccPhaseAttr()); +} + +// --- TMatmulBiasOp [Acc, Lhs, Rhs, Bias, Dst] --- +DefaultInlineVector matmulBiass; +func.walk([&](mlir::pto::TMatmulBiasOp op) { matmulBiass.push_back(op); }); +for (auto op : matmulBiass) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex), + op->getOperand(kFourthOperandIndex)); +} + +// --- TMatmulMxOp--- +DefaultInlineVector matmulMxs; +func.walk([&](mlir::pto::TMatmulMxOp op) { matmulMxs.push_back(op); }); +for (auto op : matmulMxs) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex), + op->getOperand(kFourthOperandIndex), + op->getOperand(kFifthOperandIndex)); +} + +// --- TMatmulMxAccOp --- +DefaultInlineVector matmulMxAccs; +func.walk([&](mlir::pto::TMatmulMxAccOp op) { matmulMxAccs.push_back(op); }); +for (auto op : matmulMxAccs) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex), + op->getOperand(kFourthOperandIndex), + op->getOperand(kFifthOperandIndex), + op->getOperand(kSixthOperandIndex)); +} + +// --- TMatmulMxBiasOp --- +DefaultInlineVector matmulMxBiass; +func.walk([&](mlir::pto::TMatmulMxBiasOp op) { matmulMxBiass.push_back(op); }); +for (auto op : matmulMxBiass) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex), + op->getOperand(kFourthOperandIndex), + op->getOperand(kFifthOperandIndex), + op->getOperand(kSixthOperandIndex)); +} + +// --- TGemvOp [Lhs, Rhs, Dst] --- +DefaultInlineVector gemvs; +func.walk([&](mlir::pto::TGemvOp op) { gemvs.push_back(op); }); +for (auto op : gemvs) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value lhs = op->getOperand(0); + Value rhs = op->getOperand(1); + Value dst = op->getOperand(kThirdOperandIndex); + + rewriter.replaceOpWithNewOp( + op, TypeRange{}, lhs, rhs, dst); +} + +// --- TGemvAccOp [Acc, Lhs, Rhs, Dst] --- +DefaultInlineVector gemvAccs; +func.walk([&](mlir::pto::TGemvAccOp op) { gemvAccs.push_back(op); }); +for (auto op : gemvAccs) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex), + op->getOperand(kFourthOperandIndex)); +} + +// --- TGemvBiasOp [Acc, Lhs, Rhs, Bias, Dst] --- +DefaultInlineVector gemvBiass; +func.walk([&](mlir::pto::TGemvBiasOp op) { gemvBiass.push_back(op); }); +for (auto op : gemvBiass) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex), + op->getOperand(kFourthOperandIndex)); +} + +// --- TGemvMxOp [A, AScale, B, BScale, Dst] --- +DefaultInlineVector gemvMxs; +func.walk([&](mlir::pto::TGemvMxOp op) { gemvMxs.push_back(op); }); +for (auto op : gemvMxs) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex), + op->getOperand(kFourthOperandIndex), + op->getOperand(kFifthOperandIndex)); +} + +// --- TGemvMxAccOp [CIn, A, AScale, B, BScale, Dst] --- +DefaultInlineVector gemvMxAccs; +func.walk([&](mlir::pto::TGemvMxAccOp op) { gemvMxAccs.push_back(op); }); +for (auto op : gemvMxAccs) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex), + op->getOperand(kFourthOperandIndex), + op->getOperand(kFifthOperandIndex), + op->getOperand(kSixthOperandIndex)); +} + +// --- TGemvMxBiasOp [A, AScale, B, BScale, Bias, Dst] --- +DefaultInlineVector gemvMxBiass; +func.walk([&](mlir::pto::TGemvMxBiasOp op) { gemvMxBiass.push_back(op); }); +for (auto op : gemvMxBiass) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, + op->getOperand(0), op->getOperand(1), + op->getOperand(kThirdOperandIndex), + op->getOperand(kFourthOperandIndex), + op->getOperand(kFifthOperandIndex), + op->getOperand(kSixthOperandIndex)); +} + +// --- TMovOp [Src, Dst] --- +DefaultInlineVector movs; +func.walk([&](mlir::pto::TMovOp op) { movs.push_back(op); }); +for (auto op : movs) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, TypeRange{}, op.getSrc(), op.getDst(), op.getFp(), + op.getPreQuantScalar(), op.getAccToVecModeAttr(), + op.getReluPreModeAttr()); +} + +DefaultInlineVector abseops; +func.walk([&](mlir::pto::TAbsOp op) { abseops.push_back(op); }); + +for (auto op : abseops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst); +} + +DefaultInlineVector addcops; +func.walk([&](mlir::pto::TAddCOp op) { addcops.push_back(op); }); + +for (auto op : addcops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value src2 = op.getSrc2(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto src2Ty = dyn_cast(src2.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !src2Ty ||!dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src0, + src1, + src2, + dst); +} + +DefaultInlineVector addsops; +func.walk([&](mlir::pto::TAddSOp op) { addsops.push_back(op); }); + +for (auto op : addsops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value scalar = op.getScalar(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + scalar, + dst); +} + +DefaultInlineVector addscops; +func.walk([&](mlir::pto::TAddSCOp op) { addscops.push_back(op); }); + +for (auto op : addscops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value scalar = op.getScalar(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src0, + scalar, + src1, + dst); +} + +DefaultInlineVector andops; +func.walk([&](mlir::pto::TAndOp op) { andops.push_back(op); }); + +for (auto op : andops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src0, + src1, + dst); +} + +DefaultInlineVector concats; +func.walk([&](mlir::pto::TConcatOp op) { concats.push_back(op); }); + +for (auto op : concats) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src0, + src1, + dst); +} + +DefaultInlineVector concatIdxs; +func.walk([&](mlir::pto::TConcatidxOp op) { concatIdxs.push_back(op); }); + +IRRewriter rewriter(ctx); +for (auto op : concatIdxs) { + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value src0Idx = op.getSrc0Idx(); + Value src1Idx = op.getSrc1Idx(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto src0IdxTy = dyn_cast(src0Idx.getType()); + auto src1IdxTy = dyn_cast(src1Idx.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !src0IdxTy || !src1IdxTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src0, + src1, + src0Idx, + src1Idx, + dst); +} + +DefaultInlineVector andsops; +func.walk([&](mlir::pto::TAndSOp op) { andsops.push_back(op); }); + +for (auto op : andsops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value scalar = op.getScalar(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + scalar, + dst); +} + +DefaultInlineVector ciops; +func.walk([&](mlir::pto::TCIOp op) { ciops.push_back(op); }); + +for (auto op : ciops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value s = op->getOperand(0); + Value dst = op.getDst(); + bool descending = op.getDescending(); + + auto sTy = dyn_cast(s.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!sTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + s, + dst, + descending); +} + +DefaultInlineVector cmpops; +func.walk([&](mlir::pto::TCmpOp op) { cmpops.push_back(op); }); + +for (auto op : cmpops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + auto newOp = rewriter.create( + op.getLoc(), + TypeRange{}, + src0, + src1, + dst); + + if (auto a = op.getCmpModeAttr()) + newOp->setAttr("cmpMode", a); + + rewriter.replaceOp(op, newOp->getResults()); // 0 results -> OK +} + +DefaultInlineVector cmpsops; +func.walk([&](mlir::pto::TCmpSOp op) { cmpsops.push_back(op); }); + +for (auto op : cmpsops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value scalar = op.getScalar(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + auto scalarTy = scalar.getType(); + bool scalarOk = + isa(scalarTy); // ScalarType in ODS: int/float + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + if (!scalarOk) { + op.emitError("expects scalar to be an integer or float type"); + return failure(); + } + + auto cmpMode = op.getCmpModeAttr(); + auto newOp = rewriter.create( + op.getLoc(), + TypeRange{}, + src, + scalar, + cmpMode, + dst); + + rewriter.replaceOp(op, newOp->getResults()); // 0 results -> OK +} + +DefaultInlineVector colexpand; +func.walk([&](mlir::pto::TColExpandOp op) { colexpand.push_back(op); }); + +for (auto op : colexpand) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if ( !srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst); +} + +DefaultInlineVector colmaxops; +func.walk([&](mlir::pto::TColMaxOp op) { colmaxops.push_back(op); }); + +for (auto op : colmaxops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if ( !srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst); +} + +DefaultInlineVector colminops; +func.walk([&](mlir::pto::TColMinOp op) { colminops.push_back(op); }); + +for (auto op : colminops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if ( !srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst); +} + +DefaultInlineVector colexpandmulops; +func.walk([&](mlir::pto::TColExpandMulOp op) { + colexpandmulops.push_back(op); +}); + +for (auto op : colexpandmulops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src0, + src1, + dst); +} + +DefaultInlineVector colexpandmaxops; +func.walk([&](mlir::pto::TColExpandMaxOp op) { + colexpandmaxops.push_back(op); +}); + +for (auto op : colexpandmaxops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src0, + src1, + dst); +} + +DefaultInlineVector colexpandminops; +func.walk([&](mlir::pto::TColExpandMinOp op) { + colexpandminops.push_back(op); +}); + +for (auto op : colexpandminops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src0, + src1, + dst); +} + +DefaultInlineVector colsumops; +func.walk([&](mlir::pto::TColSumOp op) { colsumops.push_back(op); }); + +for (auto op : colsumops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + Value tmp = op.getTmp(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !dstTy) { + op.emitError("src/dst are not memref yet"); + return failure(); + } + + // If tmp exists, it must have isBinary attribute + if (tmp) { + auto tmpTy = dyn_cast(tmp.getType()); + if (!tmpTy) { + op.emitError("tmp is not memref yet"); + return failure(); + } + + // Get isBinary attribute (should exist if tmp exists) + BoolAttr isBinaryAttr = op.getIsBinaryAttr(); + if (!isBinaryAttr) { + isBinaryAttr = BoolAttr::get(ctx, false); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + tmp, + dst, + isBinaryAttr); + } else { + // Format 1: no tmp, no isBinary + // Use generic builder to avoid adding default isBinary attribute + SmallVector operands = {src, dst}; + SmallVector attrs; + // Copy all attributes except isBinary + for (auto attr : op->getAttrs()) { + if (attr.getName() != "isBinary") { + attrs.push_back(attr); + } + } + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + operands, + attrs); + } +} + +DefaultInlineVector cvtops; +func.walk([&](mlir::pto::TCvtOp op) { cvtops.push_back(op); }); + +for (auto op : cvtops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + auto rmodeAttr = op.getRmodeAttr(); // PTO_RoundModeAttr + auto satModeAttr = op.getSatModeAttr(); + + auto newOp = rewriter.create( + op.getLoc(), + TypeRange{}, + src, + dst, + rmodeAttr, + satModeAttr); + + rewriter.replaceOp(op, newOp->getResults()); +} + +DefaultInlineVector divops; +func.walk([&](mlir::pto::TDivOp op) { divops.push_back(op); }); + +for (auto op : divops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src0, + src1, + dst); +} + +DefaultInlineVector divsops; +func.walk([&](mlir::pto::TDivSOp op) { divsops.push_back(op); }); + +for (auto op : divsops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value scale = op.getScalar(); + Value dst = op.getDst(); + + // Check types - they might still be TileBufType or already converted to MemRefType + auto srcTy = dyn_cast(src.getType()); + auto srcTileTy = dyn_cast(src.getType()); + auto scaleTileTy = dyn_cast(scale.getType()); + auto dstTy = dyn_cast(dst.getType()); + auto dstTileTy = dyn_cast(dst.getType()); + + // Determine which operand is tile-like and which is scalar-like. + // Keep the original operand order (set by parser textual form). + // Check if src is memref/tensor/tile (not scalar) + bool srcIsMemref = (srcTy != nullptr || srcTileTy != nullptr || + isa(src.getType()) || + isa(src.getType())); + // Check if scale is memref/tensor/tile (not scalar) + bool scaleIsMemref = (isa(scale.getType()) || + scaleTileTy != nullptr || + isa(scale.getType()) || + isa(scale.getType())); + + // Type validation - ensure we have the right types + if (!srcIsMemref && !scaleIsMemref) { + op.emitError("at least one operand (src or scale) must be tile_buf or memref"); + return failure(); + } + if (srcIsMemref && scaleIsMemref) { + op.emitError("exactly one operand (src or scale) must be tile_buf or memref, the other must be scalar"); + return failure(); + } + + if (!dstTy && !dstTileTy) { + op.emitError("dst operand must be tile_buf or memref"); + return failure(); + } + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + scale, + dst); +} + +DefaultInlineVector expandsops; +func.walk([&](mlir::pto::TExpandsOp op) { expandsops.push_back(op); }); + +for (auto op : expandsops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value scalar = op.getScalar(); + Value dst = op.getDst(); + + auto dstTy = dyn_cast(dst.getType()); + if (!dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + scalar, + dst); +} + +DefaultInlineVector extractops; +func.walk([&](mlir::pto::TExtractOp op) { extractops.push_back(op); }); + +for (auto op : extractops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value indexRow = op.getIndexRow(); + Value indexCol = op.getIndexCol(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto indexRowTy = dyn_cast(indexRow.getType()); + auto indexColTy = dyn_cast(indexCol.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !indexRowTy || !indexColTy || !dstTy) { + op.emitError("ins/outs are not correct yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + indexRow, + indexCol, + dst); +} + +DefaultInlineVector fillpadops; +func.walk([&](mlir::pto::TFillPadOp op) { fillpadops.push_back(op); }); + +for (auto op : fillpadops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst); +} + +DefaultInlineVector fillpadInplaceOps; +func.walk( + [&](mlir::pto::TFillPadInplaceOp op) { fillpadInplaceOps.push_back(op); }); + +for (auto op : fillpadInplaceOps) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst); +} + +// --- TSetValOp [Dst, Offset, Val] --- +// Lower tile-world scalar write to memref-world SETVAL DPS op. +DefaultInlineVector tsetvalops; +func.walk([&](mlir::pto::TSetValOp op) { tsetvalops.push_back(op); }); + +for (auto op : tsetvalops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value dst = op.getDst(); + Value offset = op.getOffset(); + Value val = op.getVal(); + + auto dstTy = dyn_cast(dst.getType()); + if (!dstTy) { + op.emitError("dst is not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + dst, + offset, + val); +} + +// --- TGetValOp [Src, Offset] -> Scalar --- +// Lower tile-world scalar read to memref-world GETVAL DPS op. +DefaultInlineVector tgetvalops; +func.walk([&](mlir::pto::TGetValOp op) { tgetvalops.push_back(op); }); + +for (auto op : tgetvalops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value offset = op.getOffset(); + Type dstType = op.getDst().getType(); + + auto srcTy = dyn_cast(src.getType()); + if (!srcTy) { + op.emitError("src is not memref yet"); + return failure(); + } + + auto newOp = rewriter.create( + op.getLoc(), + dstType, + src, + offset); + rewriter.replaceOp(op, newOp.getDst()); +} + +DefaultInlineVector gatherops; +func.walk([&](mlir::pto::TGatherOp op) { gatherops.push_back(op); }); + +for (auto op : gatherops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + Value cdst = op.getCdst(); + Value indices = op.getIndices(); + Value tmp = op.getTmp(); + Value kValue = op.getKValue(); + auto maskPattern = op.getMaskPatternAttr(); + auto cmpMode = op.getCmpModeAttr(); + auto offset = op.getOffsetAttr(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + if (maskPattern) { + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst, + /*cdst=*/Value(), + /*indices=*/Value(), + /*tmp=*/Value(), + /*kValue=*/Value(), + /*maskPattern=*/maskPattern, + /*cmpMode=*/pto::CmpModeAttr(), + /*offset=*/IntegerAttr()); + continue; + } + + if (cdst || kValue) { + auto cdstTy = dyn_cast(cdst.getType()); + auto tmpTy = dyn_cast(tmp.getType()); + if (!cdstTy || !tmpTy) { + op.emitError("compare-form tgather expects cdst/tmp to be memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst, + cdst, + /*indices=*/Value(), + tmp, + kValue, + /*maskPattern=*/pto::MaskPatternAttr(), + cmpMode, + offset); + continue; + } + + if (indices || tmp) { + auto indicesTy = dyn_cast(indices.getType()); + auto tmpTy = dyn_cast(tmp.getType()); + if (!indicesTy || !tmpTy) { + op.emitError("index-form tgather expects indices/tmp to be memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst, + /*cdst=*/Value(), + indices, + tmp, + /*kValue=*/Value(), + /*maskPattern=*/pto::MaskPatternAttr(), + /*cmpMode=*/pto::CmpModeAttr(), + /*offset=*/IntegerAttr()); + continue; + } + + op.emitError("expects tgather to be in mask, index+tmp, or compare+tmp form"); + return failure(); +} + +DefaultInlineVector gatherbops; +func.walk([&](mlir::pto::TGatherBOp op) { gatherbops.push_back(op); }); + +for (auto op : gatherbops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value offsets = op.getOffsets(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto offsetsTy = dyn_cast(offsets.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !offsetsTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + offsets, + dst); +} + +DefaultInlineVector logops; +func.walk([&](mlir::pto::TLogOp op) { logops.push_back(op); }); + +for (auto op : logops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst); +} + +DefaultInlineVector lreluops; +func.walk([&](mlir::pto::TLReluOp op) { lreluops.push_back(op); }); + +for (auto op : lreluops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value slope = op.getSlope(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto slopeTy = dyn_cast(slope.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !slopeTy || !dstTy) { + op.emitError("ins/outs are not correct type yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + slope, + dst); +} + +DefaultInlineVector maxops; +func.walk([&](mlir::pto::TMaxOp op) { maxops.push_back(op); }); + +for (auto op : maxops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src0, + src1, + dst); +} + +DefaultInlineVector maxsops; +func.walk([&](mlir::pto::TMaxSOp op) { maxsops.push_back(op); }); + +for (auto op : maxsops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value scalar = op.getScalar(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + bool scalarIsScalar = isa(scalar.getType()); + if (!srcTy || !scalarIsScalar || !dstTy) { + op.emitError("expects src/dst to be memref and scalar to be integer/float"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + scalar, + dst); +} + +DefaultInlineVector minops; +func.walk([&](mlir::pto::TMinOp op) { minops.push_back(op); }); + +for (auto op : minops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src0, + src1, + dst); +} + +DefaultInlineVector minsops; +func.walk([&](mlir::pto::TMinSOp op) { minsops.push_back(op); }); + +for (auto op : minsops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value scalar = op.getScalar(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + bool scalarIsScalar = isa(scalar.getType()); + if (!srcTy || !scalarIsScalar || !dstTy) { + op.emitError("expects src/dst to be memref and scalar to be integer/float"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + scalar, + dst); +} + +DefaultInlineVector movfpops; +func.walk([&](mlir::pto::TMovFPOp op) { movfpops.push_back(op); }); + +for (auto op : movfpops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value fp = op.getFp(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto fpTy = dyn_cast(fp.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !fpTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + fp, + dst); +} + +DefaultInlineVector quantops; +func.walk([&](mlir::pto::TQuantOp op) { quantops.push_back(op); }); + +for (auto op : quantops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value fp = op.getFp(); + Value offset = op.getOffset(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto fpTy = dyn_cast(fp.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !fpTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + if (offset && !dyn_cast(offset.getType())) { + op.emitError("offset is not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + fp, + offset, + dst, + op.getQuantTypeAttr()); +} + +DefaultInlineVector mrgsortops; +func.walk([&](mlir::pto::TMrgSortOp op) { mrgsortops.push_back(op); }); + +for (auto op : mrgsortops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + if (op.isFormat1()) { + Value src = op.getSrc(); + Value dst = op.getDst(); + Value blockLenVal = op.getBlockLen(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + ValueRange{src}, + blockLenVal, + ValueRange{dst}, + Value() /*tmp*/, + Value() /*excuted*/, + op.getExhaustedAttr()); + } else if (op.isFormat2()) { + bool allMemRef = true; + for (Value v : op.getSrcs()) + if (!dyn_cast(v.getType())) { allMemRef = false; break; } + if (!allMemRef) { + op.emitError("format2 ins/outs are not memref yet"); + return failure(); + } + if (op.getDsts().size() != 1u || !op.getTmp()) { + op.emitError("format2 expects outs(dst) and ins(tmp)"); + return failure(); + } + + Value dst = op.getDst(); + Value tmp = op.getTmp(); + Value excuted = op.getExcuted(); + if (!dyn_cast(dst.getType()) || !dyn_cast(tmp.getType())) { + op.emitError("format2 dst/tmp must be memref"); + return failure(); + } + if (!dyn_cast(excuted.getType())) { + op.emitError("format2 outs(excuted) must be vector"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + op.getSrcs(), + Value() /*blockLen*/, + ValueRange{dst}, + tmp, + excuted, + op.getExhaustedAttr()); + } else { + op.emitError("tmrgsort must be format1 or format2"); + return failure(); + } +} + +DefaultInlineVector negops; +func.walk([&](mlir::pto::TNegOp op) { negops.push_back(op); }); + +for (auto op : negops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst); +} + +DefaultInlineVector notops; +func.walk([&](mlir::pto::TNotOp op) { notops.push_back(op); }); + +for (auto op : notops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + dst); +} + +DefaultInlineVector orops; +func.walk([&](mlir::pto::TOrOp op) { orops.push_back(op); }); + +for (auto op : orops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + src0, + src1, + dst); +} + +DefaultInlineVector orsops; +func.walk([&](mlir::pto::TOrSOp op) { orsops.push_back(op); }); + +for (auto op : orsops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value scalar = op.getScalar(); + Value dst = op.getDst(); + + auto srcTy = dyn_cast(src.getType()); + auto scalarTy = dyn_cast(scalar.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!srcTy || !scalarTy || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + scalar, + dst); +} + +DefaultInlineVector partaddops; +func.walk([&](mlir::pto::TPartAddOp op) { partaddops.push_back(op); }); + +for (auto op : partaddops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + src0, + src1, + dst); +} + +DefaultInlineVector partmulops; +func.walk([&](mlir::pto::TPartMulOp op) { partmulops.push_back(op); }); + +for (auto op : partmulops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src0 = op.getSrc0(); + Value src1 = op.getSrc1(); + Value dst = op.getDst(); + + auto src0Ty = dyn_cast(src0.getType()); + auto src1Ty = dyn_cast(src1.getType()); + auto dstTy = dyn_cast(dst.getType()); + if (!src0Ty || !src1Ty || !dstTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + src0, + src1, + dst); +} + +DefaultInlineVector mgatherops; +func.walk([&](mlir::pto::MGatherOp op) { mgatherops.push_back(op); }); + +for (auto op : mgatherops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value dst = op.getDst(); + Value idx = op.getIdx(); + Value mem = op.getMem(); + + auto dstTy = dyn_cast(dst.getType()); + auto idxTy = dyn_cast(idx.getType()); + auto memTy = dyn_cast(mem.getType()); + if (!dstTy || !idxTy || !memTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + mem, + idx, + dst, + op.getGatherOobAttr()); +} + +DefaultInlineVector mascatterops; +func.walk([&](mlir::pto::MScatterOp op) { mascatterops.push_back(op); }); + +for (auto op : mascatterops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + Value idx = op.getIdx(); + Value mem = op.getMem(); + + auto srcTy = dyn_cast(src.getType()); + auto idxTy = dyn_cast(idx.getType()); + auto memTy = dyn_cast(mem.getType()); + if (!srcTy || !idxTy || !memTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src, + idx, + mem, + op.getScatterAtomicOpAttr(), + op.getScatterOobAttr()); +} +DefaultInlineVector printops; +func.walk([&](mlir::pto::TPrintOp op) { printops.push_back(op); }); + +for (auto op : printops) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + + Value src = op.getSrc(); + + auto srcTy = dyn_cast(src.getType()); + if (!srcTy) { + op.emitError("ins/outs are not memref yet"); + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, + TypeRange{}, + src); +} + + + return success(); +} + +} // namespace mlir::pto diff --git a/lib/PTO/Transforms/PTOViewToMemrefInternal.h b/lib/PTO/Transforms/PTOViewToMemrefInternal.h new file mode 100644 index 000000000..8cfb80d2c --- /dev/null +++ b/lib/PTO/Transforms/PTOViewToMemrefInternal.h @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef MLIR_DIALECT_PTO_TRANSFORMS_PTOVIEWTOMEMREFINTERNAL_H +#define MLIR_DIALECT_PTO_TRANSFORMS_PTOVIEWTOMEMREFINTERNAL_H + +#include "mlir/IR/MLIRContext.h" +#include "mlir/Support/LogicalResult.h" + +namespace mlir::func { +class FuncOp; +} // namespace mlir::func + +namespace mlir::pto { + +LogicalResult lowerViewToMemrefComputeOps(func::FuncOp func, MLIRContext *ctx); + +} // namespace mlir::pto + +#endif // MLIR_DIALECT_PTO_TRANSFORMS_PTOVIEWTOMEMREFINTERNAL_H diff --git a/python/pto/dialects/pto.py b/python/pto/dialects/pto.py index a0b5037c7..eca06e81b 100644 --- a/python/pto/dialects/pto.py +++ b/python/pto/dialects/pto.py @@ -1,4 +1,5 @@ # Copyright (c) 2026 Huawei Technologies Co., Ltd. +# -*- coding: utf-8 -*- # This program is free software, you can redistribute it and/or modify it under the terms and conditions of # CANN Open Software License Agreement Version 2.0 (the "License"). # Please refer to the License for details. You may not use this file except in compliance with the License. diff --git a/tools/ptobc/CMakeLists.txt b/tools/ptobc/CMakeLists.txt index 8224cd637..e21e7bc3d 100644 --- a/tools/ptobc/CMakeLists.txt +++ b/tools/ptobc/CMakeLists.txt @@ -23,6 +23,7 @@ add_library(ptobc_lib STATIC src/mlir_encode.cpp src/canonical_printer.cpp src/ptobc_decode_print.cpp + generated/ptobc_opcodes_v0.cpp ) target_include_directories(ptobc_lib PUBLIC diff --git a/tools/ptobc/generated/ptobc_opcodes_v0.cpp b/tools/ptobc/generated/ptobc_opcodes_v0.cpp new file mode 100644 index 000000000..233854ead --- /dev/null +++ b/tools/ptobc/generated/ptobc_opcodes_v0.cpp @@ -0,0 +1,679 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Generated by docs/bytecode/tools/gen_v0_tables.py + +#include "ptobc_opcodes_v0.h" + +namespace ptobc::v0 { + +const OpInfo kOpTable[] = { + {0x0000, "pto.get_block_idx", 0, 0x01, 0x00, 0, 1, 0, 0x00}, + {0x0001, "pto.get_block_num", 0, 0x01, 0x00, 0, 1, 0, 0x00}, + {0x0002, "pto.get_subblock_idx", 0, 0x01, 0x00, 0, 1, 0, 0x00}, + {0x0003, "pto.get_subblock_num", 0, 0x01, 0x00, 0, 1, 0, 0x00}, + {0x0004, "pto.make_tensor_view", 0, 0x01, 0x03, 1, 1, 0, 0x06}, + {0x0005, "pto.partition_view", 0, 0x01, 0x03, 1, 1, 0, 0x07}, + {0x0006, "pto.section", 1, 0x00, 0x00, 0, 0, 1, 0x00}, + {0x1000, "pto.addptr", 0, 0x01, 0x00, 2, 1, 0, 0x00}, + {0x1001, "pto.alloc_tile", 0, 0x01, 0x04, 0, 1, 0, 0x08}, + {0x1002, "pto.barrier", 0, 0x00, 0x00, 0, 0, 0, 0x00}, + {0x1003, "pto.mgather", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1004, "pto.mscatter", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1005, "pto.record_event", 0, 0x00, 0x00, 0, 0, 0, 0x02}, + {0x1006, "pto.tabs", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1007, "pto.tadd", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1008, "pto.taddc", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x1009, "pto.tadds", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x100A, "pto.taddsc", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x100B, "pto.tand", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x100C, "pto.tands", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x100D, "pto.tci", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x100E, "pto.tcmp", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x100F, "pto.tcmps", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1010, "pto.tcolexpand", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1011, "pto.tcolexpandadd", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1012, "pto.tcolexpanddiv", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1013, "pto.tcolexpandexpdif", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1014, "pto.tcolexpandmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1015, "pto.tcolexpandmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1016, "pto.tcolexpandmul", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1017, "pto.tcolexpandsub", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1018, "pto.tcolmax", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1019, "pto.tcolmin", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x101A, "pto.tcolprod", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x101B, "pto.tcolsum", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x101C, "pto.tcvt", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x101D, "pto.tdiv", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x101E, "pto.tdivs", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x101F, "pto.texp", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1020, "pto.texpands", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1021, "pto.textract", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x1022, "pto.textract_fp", 0, 0x00, 0x00, 5, 0, 0, 0x00}, + {0x1023, "pto.tfillpad", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1024, "pto.tfillpad_expand", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1025, "pto.tfillpad_inplace", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1026, "pto.tfmod", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1027, "pto.tfmods", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1028, "pto.tgather", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1029, "pto.tgatherb", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x102A, "pto.tgemv", 1, 0x00, 0x01, 0, 0, 0, 0x00}, + {0x102B, "pto.tgetval", 0, 0x01, 0x00, 2, 1, 0, 0x00}, + {0x102C, "pto.timg2col", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x102D, "pto.tinsert", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x102E, "pto.tinsert_fp", 0, 0x00, 0x00, 5, 0, 0, 0x00}, + {0x102F, "pto.tload", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1030, "pto.tlog", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1031, "pto.tlrelu", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1032, "pto.tmatmul", 1, 0x00, 0x01, 0, 0, 0, 0x00}, + {0x1033, "pto.tmatmul.mx", 1, 0x00, 0x01, 0, 0, 0, 0x00}, + {0x1034, "pto.tmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1035, "pto.tmaxs", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1036, "pto.tmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1037, "pto.tmins", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1038, "pto.tmov", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1039, "pto.tmov.fp", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x103A, "pto.tmrgsort", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x103B, "pto.tmul", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x103C, "pto.tmuls", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x103D, "pto.tneg", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x103E, "pto.tnot", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x103F, "pto.tor", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1040, "pto.tors", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1041, "pto.tpartadd", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1042, "pto.tpartmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1043, "pto.tpartmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1044, "pto.tpartmul", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1045, "pto.tprefetch", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1046, "pto.tprelu", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x1047, "pto.tquant", 0, 0x00, 0x02, 3, 0, 0, 0x00}, + {0x1048, "pto.trecip", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1049, "pto.trelu", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x104A, "pto.trem", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x104B, "pto.trems", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x104C, "pto.treshape", 0, 0x01, 0x00, 1, 1, 0, 0x00}, + {0x104D, "pto.trowexpand", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x104E, "pto.trowexpandadd", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x104F, "pto.trowexpandexpdif", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1050, "pto.trowexpandmax", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1051, "pto.trowexpandmin", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1052, "pto.trowmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1053, "pto.trowmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1054, "pto.trowsum", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1055, "pto.trsqrt", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1056, "pto.tscatter", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1057, "pto.tsel", 0, 0x00, 0x00, 5, 0, 0, 0x00}, + {0x1058, "pto.tsels", 0, 0x00, 0x00, 5, 0, 0, 0x00}, + {0x1059, "pto.tset_img2col_padding", 0, 0x00, 0x00, 1, 0, 0, 0x00}, + {0x105A, "pto.tset_img2col_rpt", 0, 0x00, 0x00, 1, 0, 0, 0x00}, + {0x105B, "pto.tsetfmatrix", 0, 0x00, 0x00, 1, 0, 0, 0x00}, + {0x105C, "pto.tsethf32mode", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x105D, "pto.tsettf32mode", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x105E, "pto.tsetval", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x105F, "pto.tshl", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1060, "pto.tshls", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1061, "pto.tshr", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1062, "pto.tshrs", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1063, "pto.tsort32", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1064, "pto.tsqrt", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1065, "pto.tstore", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1066, "pto.tstore_fp", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1067, "pto.tsub", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1068, "pto.tsubc", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x1069, "pto.tsubs", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x106A, "pto.tsubsc", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x106B, "pto.trowexpandsub", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x106C, "pto.ttrans", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x106D, "pto.ttri", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x106E, "pto.txor", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x106F, "pto.txors", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x1070, "pto.wait_event", 0, 0x00, 0x00, 0, 0, 0, 0x02}, + {0x1071, "pto.tprint", 0, 0x00, 0x00, 1, 0, 0, 0x00}, + {0x1072, "pto.subview", 0, 0x01, 0x02, 0, 1, 0, 0x00}, + {0x1073, "pto.trowexpanddiv", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1074, "pto.trowexpandmul", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1075, "pto.tdequant", 0, 0x00, 0x00, 4, 0, 0, 0x00}, + {0x1076, "pto.taxpy", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1077, "pto.thistogram", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1078, "pto.tget_scale_addr", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1079, "pto.trowargmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x107A, "pto.trowargmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x107B, "pto.tcolargmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x107C, "pto.tcolargmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x107D, "pto.tsync", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x107E, "pto.reserve_buffer", 0, 0x01, 0x00, 0, 1, 0, 0x00}, + {0x107F, "pto.import_reserved_buffer", 0, 0x01, 0x00, 0, 1, 0, 0x00}, + {0x1080, "pto.aic_initialize_pipe", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1081, "pto.aiv_initialize_pipe", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1082, "pto.tpush_to_aiv", 0, 0x00, 0x00, 1, 0, 0, 0x00}, + {0x1083, "pto.tpush_to_aic", 0, 0x00, 0x00, 1, 0, 0, 0x00}, + {0x1084, "pto.tpop_from_aic", 0, 0x01, 0x02, 0, 1, 0, 0x00}, + {0x1085, "pto.tpop_from_aiv", 0, 0x01, 0x02, 0, 1, 0, 0x00}, + {0x1086, "pto.tfree_from_aic", 0, 0x00, 0x00, 0, 0, 0, 0x00}, + {0x1087, "pto.tfree_from_aiv", 0, 0x00, 0x00, 0, 0, 0, 0x00}, + {0x1088, "pto.set_validshape", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1089, "pto.tconcat", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x108A, "pto.trowprod", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x108B, "pto.initialize_l2g2l_pipe", 0, 0x01, 0x02, 0, 1, 0, 0x00}, + {0x108C, "pto.initialize_l2l_pipe", 0, 0x01, 0x02, 0, 1, 0, 0x00}, + {0x108D, "pto.tpush", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x108E, "pto.declare_tile", 0, 0x01, 0x00, 0, 1, 0, 0x00}, + {0x108F, "pto.tpop", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1090, "pto.tfree", 0, 0x00, 0x00, 1, 0, 0, 0x00}, + {0x1091, "pto.comm.tput", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1092, "pto.comm.tget", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1093, "pto.comm.tnotify", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1094, "pto.comm.twait", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1095, "pto.comm.ttest", 0, 0x01, 0x02, 0, 1, 0, 0x00}, + {0x1096, "pto.comm.tbroadcast", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1097, "pto.comm.tgather", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1098, "pto.comm.tscatter", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x1099, "pto.comm.treduce", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x109A, "pto.tpartargmax", 0, 0x00, 0x00, 6, 0, 0, 0x00}, + {0x109B, "pto.tpartargmin", 0, 0x00, 0x00, 6, 0, 0, 0x00}, + {0x109C, "pto.tscatter.maskpattern", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x2000, "arith.addi", 0, 0x01, 0x00, 2, 1, 0, 0x00}, + {0x2001, "arith.ceildivsi", 0, 0x01, 0x00, 2, 1, 0, 0x00}, + {0x2002, "arith.cmpi", 0, 0x01, 0x00, 2, 1, 0, 0x01}, + {0x2003, "arith.constant", 0, 0x01, 0x00, 0, 1, 0, 0x05}, + {0x2004, "arith.index_cast", 0, 0x01, 0x00, 1, 1, 0, 0x00}, + {0x2005, "arith.minui", 0, 0x01, 0x00, 2, 1, 0, 0x00}, + {0x2006, "arith.muli", 0, 0x01, 0x00, 2, 1, 0, 0x00}, + {0x2007, "arith.select", 0, 0x01, 0x00, 3, 1, 0, 0x00}, + {0x2008, "arith.subi", 0, 0x01, 0x00, 2, 1, 0, 0x00}, + {0x4000, "scf.for", 0, 0x00, 0x00, 3, 0, 1, 0x00}, + {0x4001, "scf.if", 0, 0x00, 0x00, 1, 0, 2, 0x00}, + {0x4002, "scf.yield", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x6000, "func.func", 0, 0x00, 0x00, 0, 0, 0, 0x00}, + {0x6001, "func.return", 0, 0x00, 0x02, 0, 0, 0, 0x00}, + {0x6002, "func.call", 0, 0x02, 0x02, 0, 0, 0, 0x00}, +}; + +const OpInfo *lookupByOpcode(uint16_t opcode) { + // Binary search on kOpTable (sorted by opcode). + size_t lo = 0, hi = sizeof(kOpTable) / sizeof(kOpTable[0]); + while (lo < hi) { + size_t mid = lo + (hi - lo) / 2; + uint16_t v = kOpTable[mid].opcode; + if (v == opcode) return &kOpTable[mid]; + if (v < opcode) lo = mid + 1; else hi = mid; + } + return nullptr; +} + +std::optional lookupOpcodeByName(llvm::StringRef name) { + uint16_t v = llvm::StringSwitch(name) + .Case("arith.addi", 0x2000) + .Case("arith.ceildivsi", 0x2001) + .Case("arith.cmpi", 0x2002) + .Case("arith.constant", 0x2003) + .Case("arith.index_cast", 0x2004) + .Case("arith.minui", 0x2005) + .Case("arith.muli", 0x2006) + .Case("arith.select", 0x2007) + .Case("arith.subi", 0x2008) + .Case("func.func", 0x6000) + .Case("func.return", 0x6001) + .Case("func.call", 0x6002) + .Case("pto.addptr", 0x1000) + .Case("pto.alloc_tile", 0x1001) + .Case("pto.barrier", 0x1002) + .Case("pto.get_block_idx", 0x0000) + .Case("pto.get_block_num", 0x0001) + .Case("pto.get_subblock_idx", 0x0002) + .Case("pto.get_subblock_num", 0x0003) + .Case("pto.make_tensor_view", 0x0004) + .Case("pto.mgather", 0x1003) + .Case("pto.mscatter", 0x1004) + .Case("pto.partition_view", 0x0005) + .Case("pto.record_event", 0x1005) + .Case("pto.section", 0x0006) + .Case("pto.tabs", 0x1006) + .Case("pto.tadd", 0x1007) + .Case("pto.taddc", 0x1008) + .Case("pto.tadds", 0x1009) + .Case("pto.taddsc", 0x100A) + .Case("pto.tand", 0x100B) + .Case("pto.tands", 0x100C) + .Case("pto.tci", 0x100D) + .Case("pto.tcmp", 0x100E) + .Case("pto.tcmps", 0x100F) + .Case("pto.tcolexpand", 0x1010) + .Case("pto.tcolexpandadd", 0x1011) + .Case("pto.tcolexpanddiv", 0x1012) + .Case("pto.tcolexpandexpdif", 0x1013) + .Case("pto.tcolexpandmax", 0x1014) + .Case("pto.tcolexpandmin", 0x1015) + .Case("pto.tcolexpandmul", 0x1016) + .Case("pto.tcolexpandsub", 0x1017) + .Case("pto.tcolmax", 0x1018) + .Case("pto.tcolmin", 0x1019) + .Case("pto.tcolprod", 0x101A) + .Case("pto.tcolsum", 0x101B) + .Case("pto.tcvt", 0x101C) + .Case("pto.tdiv", 0x101D) + .Case("pto.tdivs", 0x101E) + .Case("pto.texp", 0x101F) + .Case("pto.texpands", 0x1020) + .Case("pto.textract", 0x1021) + .Case("pto.textract_fp", 0x1022) + .Case("pto.tfillpad", 0x1023) + .Case("pto.tfillpad_expand", 0x1024) + .Case("pto.tfillpad_inplace", 0x1025) + .Case("pto.tfmod", 0x1026) + .Case("pto.tfmods", 0x1027) + .Case("pto.tgather", 0x1028) + .Case("pto.tgatherb", 0x1029) + .Case("pto.tgemv", 0x102A) + .Case("pto.tgetval", 0x102B) + .Case("pto.timg2col", 0x102C) + .Case("pto.tinsert", 0x102D) + .Case("pto.tinsert_fp", 0x102E) + .Case("pto.tload", 0x102F) + .Case("pto.tlog", 0x1030) + .Case("pto.tlrelu", 0x1031) + .Case("pto.tmatmul", 0x1032) + .Case("pto.tmatmul.mx", 0x1033) + .Case("pto.tmax", 0x1034) + .Case("pto.tmaxs", 0x1035) + .Case("pto.tmin", 0x1036) + .Case("pto.tmins", 0x1037) + .Case("pto.tmov", 0x1038) + .Case("pto.tmov.fp", 0x1039) + .Case("pto.tmrgsort", 0x103A) + .Case("pto.tmul", 0x103B) + .Case("pto.tmuls", 0x103C) + .Case("pto.tneg", 0x103D) + .Case("pto.tnot", 0x103E) + .Case("pto.tor", 0x103F) + .Case("pto.tors", 0x1040) + .Case("pto.tpartadd", 0x1041) + .Case("pto.tpartmax", 0x1042) + .Case("pto.tpartmin", 0x1043) + .Case("pto.tpartmul", 0x1044) + .Case("pto.tprefetch", 0x1045) + .Case("pto.tprelu", 0x1046) + .Case("pto.tquant", 0x1047) + .Case("pto.trecip", 0x1048) + .Case("pto.trelu", 0x1049) + .Case("pto.trem", 0x104A) + .Case("pto.trems", 0x104B) + .Case("pto.treshape", 0x104C) + .Case("pto.trowexpand", 0x104D) + .Case("pto.trowexpandadd", 0x104E) + .Case("pto.trowexpandexpdif", 0x104F) + .Case("pto.trowexpandmax", 0x1050) + .Case("pto.trowexpandmin", 0x1051) + .Case("pto.trowmax", 0x1052) + .Case("pto.trowmin", 0x1053) + .Case("pto.trowsum", 0x1054) + .Case("pto.trsqrt", 0x1055) + .Case("pto.tscatter", 0x1056) + .Case("pto.tsel", 0x1057) + .Case("pto.tsels", 0x1058) + .Case("pto.tset_img2col_padding", 0x1059) + .Case("pto.tset_img2col_rpt", 0x105A) + .Case("pto.tsetfmatrix", 0x105B) + .Case("pto.tsethf32mode", 0x105C) + .Case("pto.tsettf32mode", 0x105D) + .Case("pto.tsetval", 0x105E) + .Case("pto.tshl", 0x105F) + .Case("pto.tshls", 0x1060) + .Case("pto.tshr", 0x1061) + .Case("pto.tshrs", 0x1062) + .Case("pto.tsort32", 0x1063) + .Case("pto.tsqrt", 0x1064) + .Case("pto.tstore", 0x1065) + .Case("pto.tstore_fp", 0x1066) + .Case("pto.tsub", 0x1067) + .Case("pto.tsubc", 0x1068) + .Case("pto.tsubs", 0x1069) + .Case("pto.tsubsc", 0x106A) + .Case("pto.trowexpandsub", 0x106B) + .Case("pto.ttrans", 0x106C) + .Case("pto.ttri", 0x106D) + .Case("pto.txor", 0x106E) + .Case("pto.txors", 0x106F) + .Case("pto.wait_event", 0x1070) + .Case("pto.tprint", 0x1071) + .Case("pto.subview", 0x1072) + .Case("pto.trowexpanddiv", 0x1073) + .Case("pto.trowexpandmul", 0x1074) + .Case("pto.tdequant", 0x1075) + .Case("pto.taxpy", 0x1076) + .Case("pto.thistogram", 0x1077) + .Case("pto.tget_scale_addr", 0x1078) + .Case("pto.trowargmax", 0x1079) + .Case("pto.trowargmin", 0x107A) + .Case("pto.tcolargmax", 0x107B) + .Case("pto.tcolargmin", 0x107C) + .Case("pto.tsync", 0x107D) + .Case("pto.reserve_buffer", 0x107E) + .Case("pto.import_reserved_buffer", 0x107F) + .Case("pto.aic_initialize_pipe", 0x1080) + .Case("pto.aiv_initialize_pipe", 0x1081) + .Case("pto.tpush_to_aiv", 0x1082) + .Case("pto.tpush_to_aic", 0x1083) + .Case("pto.tpop_from_aic", 0x1084) + .Case("pto.tpop_from_aiv", 0x1085) + .Case("pto.tfree_from_aic", 0x1086) + .Case("pto.tfree_from_aiv", 0x1087) + .Case("pto.set_validshape", 0x1088) + .Case("pto.tconcat", 0x1089) + .Case("pto.trowprod", 0x108A) + .Case("pto.initialize_l2g2l_pipe", 0x108B) + .Case("pto.initialize_l2l_pipe", 0x108C) + .Case("pto.tpush", 0x108D) + .Case("pto.declare_tile", 0x108E) + .Case("pto.tpop", 0x108F) + .Case("pto.tfree", 0x1090) + .Case("pto.comm.tput", 0x1091) + .Case("pto.comm.tget", 0x1092) + .Case("pto.comm.tnotify", 0x1093) + .Case("pto.comm.twait", 0x1094) + .Case("pto.comm.ttest", 0x1095) + .Case("pto.comm.tbroadcast", 0x1096) + .Case("pto.comm.tgather", 0x1097) + .Case("pto.comm.tscatter", 0x1098) + .Case("pto.comm.treduce", 0x1099) + .Case("pto.tpartargmax", 0x109A) + .Case("pto.tpartargmin", 0x109B) + .Case("scf.for", 0x4000) + .Case("scf.if", 0x4001) + .Case("scf.yield", 0x4002) + .Default(0xFFFF); + if (v == 0xFFFF) return std::nullopt; + return v; +} + +const OpInfo *lookupByName(llvm::StringRef name) { + auto o = lookupOpcodeByName(name); + if (!o) return nullptr; + return lookupByOpcode(*o); +} + +std::optional lookupOpcodeAndVariantByFullName(llvm::StringRef fullName) { + // For non-family ops, variant is 0. For family ops, variant is the assigned u8. + // NOTE: `pto.section` is not a real op name; use `pto.section.cube`/`pto.section.vector`. + return llvm::StringSwitch>(fullName) + .Case("arith.addi", OpcodeAndVariant{0x2000, 0, 0}) + .Case("arith.ceildivsi", OpcodeAndVariant{0x2001, 0, 0}) + .Case("arith.cmpi", OpcodeAndVariant{0x2002, 0, 0}) + .Case("arith.constant", OpcodeAndVariant{0x2003, 0, 0}) + .Case("arith.index_cast", OpcodeAndVariant{0x2004, 0, 0}) + .Case("arith.minui", OpcodeAndVariant{0x2005, 0, 0}) + .Case("arith.muli", OpcodeAndVariant{0x2006, 0, 0}) + .Case("arith.select", OpcodeAndVariant{0x2007, 0, 0}) + .Case("arith.subi", OpcodeAndVariant{0x2008, 0, 0}) + .Case("func.func", OpcodeAndVariant{0x6000, 0, 0}) + .Case("func.return", OpcodeAndVariant{0x6001, 0, 0}) + .Case("func.call", OpcodeAndVariant{0x6002, 0, 0}) + .Case("pto.addptr", OpcodeAndVariant{0x1000, 0, 0}) + .Case("pto.alloc_tile", OpcodeAndVariant{0x1001, 0, 0}) + .Case("pto.barrier", OpcodeAndVariant{0x1002, 0, 0}) + .Case("pto.get_block_idx", OpcodeAndVariant{0x0000, 0, 0}) + .Case("pto.get_block_num", OpcodeAndVariant{0x0001, 0, 0}) + .Case("pto.get_subblock_idx", OpcodeAndVariant{0x0002, 0, 0}) + .Case("pto.get_subblock_num", OpcodeAndVariant{0x0003, 0, 0}) + .Case("pto.make_tensor_view", OpcodeAndVariant{0x0004, 0, 0}) + .Case("pto.mgather", OpcodeAndVariant{0x1003, 0, 0}) + .Case("pto.mscatter", OpcodeAndVariant{0x1004, 0, 0}) + .Case("pto.partition_view", OpcodeAndVariant{0x0005, 0, 0}) + .Case("pto.record_event", OpcodeAndVariant{0x1005, 0, 0}) + .Case("pto.tabs", OpcodeAndVariant{0x1006, 0, 0}) + .Case("pto.tadd", OpcodeAndVariant{0x1007, 0, 0}) + .Case("pto.taddc", OpcodeAndVariant{0x1008, 0, 0}) + .Case("pto.tadds", OpcodeAndVariant{0x1009, 0, 0}) + .Case("pto.taddsc", OpcodeAndVariant{0x100A, 0, 0}) + .Case("pto.tand", OpcodeAndVariant{0x100B, 0, 0}) + .Case("pto.tands", OpcodeAndVariant{0x100C, 0, 0}) + .Case("pto.tci", OpcodeAndVariant{0x100D, 0, 0}) + .Case("pto.tcmp", OpcodeAndVariant{0x100E, 0, 0}) + .Case("pto.tcmps", OpcodeAndVariant{0x100F, 0, 0}) + .Case("pto.tcolexpand", OpcodeAndVariant{0x1010, 0, 0}) + .Case("pto.tcolexpandadd", OpcodeAndVariant{0x1011, 0, 0}) + .Case("pto.tcolexpanddiv", OpcodeAndVariant{0x1012, 0, 0}) + .Case("pto.tcolexpandexpdif", OpcodeAndVariant{0x1013, 0, 0}) + .Case("pto.tcolexpandmax", OpcodeAndVariant{0x1014, 0, 0}) + .Case("pto.tcolexpandmin", OpcodeAndVariant{0x1015, 0, 0}) + .Case("pto.tcolexpandmul", OpcodeAndVariant{0x1016, 0, 0}) + .Case("pto.tcolexpandsub", OpcodeAndVariant{0x1017, 0, 0}) + .Case("pto.tcolmax", OpcodeAndVariant{0x1018, 0, 0}) + .Case("pto.tcolmin", OpcodeAndVariant{0x1019, 0, 0}) + .Case("pto.tcolprod", OpcodeAndVariant{0x101A, 0, 0}) + .Case("pto.tcolsum", OpcodeAndVariant{0x101B, 0, 0}) + .Case("pto.tcvt", OpcodeAndVariant{0x101C, 0, 0}) + .Case("pto.tdiv", OpcodeAndVariant{0x101D, 0, 0}) + .Case("pto.tdivs", OpcodeAndVariant{0x101E, 0, 0}) + .Case("pto.texp", OpcodeAndVariant{0x101F, 0, 0}) + .Case("pto.texpands", OpcodeAndVariant{0x1020, 0, 0}) + .Case("pto.textract", OpcodeAndVariant{0x1021, 0, 0}) + .Case("pto.textract_fp", OpcodeAndVariant{0x1022, 0, 0}) + .Case("pto.tfillpad", OpcodeAndVariant{0x1023, 0, 0}) + .Case("pto.tfillpad_expand", OpcodeAndVariant{0x1024, 0, 0}) + .Case("pto.tfillpad_inplace", OpcodeAndVariant{0x1025, 0, 0}) + .Case("pto.tfmod", OpcodeAndVariant{0x1026, 0, 0}) + .Case("pto.tfmods", OpcodeAndVariant{0x1027, 0, 0}) + .Case("pto.tgather", OpcodeAndVariant{0x1028, 0, 0}) + .Case("pto.tgatherb", OpcodeAndVariant{0x1029, 0, 0}) + .Case("pto.tgetval", OpcodeAndVariant{0x102B, 0, 0}) + .Case("pto.timg2col", OpcodeAndVariant{0x102C, 0, 0}) + .Case("pto.tinsert", OpcodeAndVariant{0x102D, 0, 0}) + .Case("pto.tinsert_fp", OpcodeAndVariant{0x102E, 0, 0}) + .Case("pto.tload", OpcodeAndVariant{0x102F, 0, 0}) + .Case("pto.tlog", OpcodeAndVariant{0x1030, 0, 0}) + .Case("pto.tlrelu", OpcodeAndVariant{0x1031, 0, 0}) + .Case("pto.tmax", OpcodeAndVariant{0x1034, 0, 0}) + .Case("pto.tmaxs", OpcodeAndVariant{0x1035, 0, 0}) + .Case("pto.tmin", OpcodeAndVariant{0x1036, 0, 0}) + .Case("pto.tmins", OpcodeAndVariant{0x1037, 0, 0}) + .Case("pto.tmov", OpcodeAndVariant{0x1038, 0, 0}) + .Case("pto.tmov.fp", OpcodeAndVariant{0x1039, 0, 0}) + .Case("pto.tmrgsort", OpcodeAndVariant{0x103A, 0, 0}) + .Case("pto.tmul", OpcodeAndVariant{0x103B, 0, 0}) + .Case("pto.tmuls", OpcodeAndVariant{0x103C, 0, 0}) + .Case("pto.tneg", OpcodeAndVariant{0x103D, 0, 0}) + .Case("pto.tnot", OpcodeAndVariant{0x103E, 0, 0}) + .Case("pto.tor", OpcodeAndVariant{0x103F, 0, 0}) + .Case("pto.tors", OpcodeAndVariant{0x1040, 0, 0}) + .Case("pto.tpartadd", OpcodeAndVariant{0x1041, 0, 0}) + .Case("pto.tpartmax", OpcodeAndVariant{0x1042, 0, 0}) + .Case("pto.tpartmin", OpcodeAndVariant{0x1043, 0, 0}) + .Case("pto.tpartmul", OpcodeAndVariant{0x1044, 0, 0}) + .Case("pto.tprefetch", OpcodeAndVariant{0x1045, 0, 0}) + .Case("pto.tprelu", OpcodeAndVariant{0x1046, 0, 0}) + .Case("pto.tquant", OpcodeAndVariant{0x1047, 0, 0}) + .Case("pto.trecip", OpcodeAndVariant{0x1048, 0, 0}) + .Case("pto.trelu", OpcodeAndVariant{0x1049, 0, 0}) + .Case("pto.trem", OpcodeAndVariant{0x104A, 0, 0}) + .Case("pto.trems", OpcodeAndVariant{0x104B, 0, 0}) + .Case("pto.treshape", OpcodeAndVariant{0x104C, 0, 0}) + .Case("pto.trowexpand", OpcodeAndVariant{0x104D, 0, 0}) + .Case("pto.trowexpandadd", OpcodeAndVariant{0x104E, 0, 0}) + .Case("pto.trowexpandexpdif", OpcodeAndVariant{0x104F, 0, 0}) + .Case("pto.trowexpandmax", OpcodeAndVariant{0x1050, 0, 0}) + .Case("pto.trowexpandmin", OpcodeAndVariant{0x1051, 0, 0}) + .Case("pto.trowmax", OpcodeAndVariant{0x1052, 0, 0}) + .Case("pto.trowmin", OpcodeAndVariant{0x1053, 0, 0}) + .Case("pto.trowsum", OpcodeAndVariant{0x1054, 0, 0}) + .Case("pto.trsqrt", OpcodeAndVariant{0x1055, 0, 0}) + .Case("pto.tscatter", OpcodeAndVariant{0x1056, 0, 0}) + .Case("pto.tsel", OpcodeAndVariant{0x1057, 0, 0}) + .Case("pto.tsels", OpcodeAndVariant{0x1058, 0, 0}) + .Case("pto.tset_img2col_padding", OpcodeAndVariant{0x1059, 0, 0}) + .Case("pto.tset_img2col_rpt", OpcodeAndVariant{0x105A, 0, 0}) + .Case("pto.tsetfmatrix", OpcodeAndVariant{0x105B, 0, 0}) + .Case("pto.tsethf32mode", OpcodeAndVariant{0x105C, 0, 0}) + .Case("pto.tsettf32mode", OpcodeAndVariant{0x105D, 0, 0}) + .Case("pto.tsetval", OpcodeAndVariant{0x105E, 0, 0}) + .Case("pto.tshl", OpcodeAndVariant{0x105F, 0, 0}) + .Case("pto.tshls", OpcodeAndVariant{0x1060, 0, 0}) + .Case("pto.tshr", OpcodeAndVariant{0x1061, 0, 0}) + .Case("pto.tshrs", OpcodeAndVariant{0x1062, 0, 0}) + .Case("pto.tsort32", OpcodeAndVariant{0x1063, 0, 0}) + .Case("pto.tsqrt", OpcodeAndVariant{0x1064, 0, 0}) + .Case("pto.tstore", OpcodeAndVariant{0x1065, 0, 0}) + .Case("pto.tstore_fp", OpcodeAndVariant{0x1066, 0, 0}) + .Case("pto.tsub", OpcodeAndVariant{0x1067, 0, 0}) + .Case("pto.tsubc", OpcodeAndVariant{0x1068, 0, 0}) + .Case("pto.tsubs", OpcodeAndVariant{0x1069, 0, 0}) + .Case("pto.tsubsc", OpcodeAndVariant{0x106A, 0, 0}) + .Case("pto.trowexpandsub", OpcodeAndVariant{0x106B, 0, 0}) + .Case("pto.ttrans", OpcodeAndVariant{0x106C, 0, 0}) + .Case("pto.ttri", OpcodeAndVariant{0x106D, 0, 0}) + .Case("pto.txor", OpcodeAndVariant{0x106E, 0, 0}) + .Case("pto.txors", OpcodeAndVariant{0x106F, 0, 0}) + .Case("pto.wait_event", OpcodeAndVariant{0x1070, 0, 0}) + .Case("pto.tprint", OpcodeAndVariant{0x1071, 0, 0}) + .Case("pto.subview", OpcodeAndVariant{0x1072, 0, 0}) + .Case("pto.trowexpanddiv", OpcodeAndVariant{0x1073, 0, 0}) + .Case("pto.trowexpandmul", OpcodeAndVariant{0x1074, 0, 0}) + .Case("pto.tdequant", OpcodeAndVariant{0x1075, 0, 0}) + .Case("pto.taxpy", OpcodeAndVariant{0x1076, 0, 0}) + .Case("pto.thistogram", OpcodeAndVariant{0x1077, 0, 0}) + .Case("pto.tget_scale_addr", OpcodeAndVariant{0x1078, 0, 0}) + .Case("pto.trowargmax", OpcodeAndVariant{0x1079, 0, 0}) + .Case("pto.trowargmin", OpcodeAndVariant{0x107A, 0, 0}) + .Case("pto.tcolargmax", OpcodeAndVariant{0x107B, 0, 0}) + .Case("pto.tcolargmin", OpcodeAndVariant{0x107C, 0, 0}) + .Case("pto.tsync", OpcodeAndVariant{0x107D, 0, 0}) + .Case("pto.reserve_buffer", OpcodeAndVariant{0x107E, 0, 0}) + .Case("pto.import_reserved_buffer", OpcodeAndVariant{0x107F, 0, 0}) + .Case("pto.aic_initialize_pipe", OpcodeAndVariant{0x1080, 0, 0}) + .Case("pto.aiv_initialize_pipe", OpcodeAndVariant{0x1081, 0, 0}) + .Case("pto.tpush_to_aiv", OpcodeAndVariant{0x1082, 0, 0}) + .Case("pto.tpush_to_aic", OpcodeAndVariant{0x1083, 0, 0}) + .Case("pto.tpop_from_aic", OpcodeAndVariant{0x1084, 0, 0}) + .Case("pto.tpop_from_aiv", OpcodeAndVariant{0x1085, 0, 0}) + .Case("pto.tfree_from_aic", OpcodeAndVariant{0x1086, 0, 0}) + .Case("pto.tfree_from_aiv", OpcodeAndVariant{0x1087, 0, 0}) + .Case("pto.set_validshape", OpcodeAndVariant{0x1088, 0, 0}) + .Case("pto.tconcat", OpcodeAndVariant{0x1089, 0, 0}) + .Case("pto.trowprod", OpcodeAndVariant{0x108A, 0, 0}) + .Case("pto.initialize_l2g2l_pipe", OpcodeAndVariant{0x108B, 0, 0}) + .Case("pto.initialize_l2l_pipe", OpcodeAndVariant{0x108C, 0, 0}) + .Case("pto.tpush", OpcodeAndVariant{0x108D, 0, 0}) + .Case("pto.declare_tile", OpcodeAndVariant{0x108E, 0, 0}) + .Case("pto.tpop", OpcodeAndVariant{0x108F, 0, 0}) + .Case("pto.tfree", OpcodeAndVariant{0x1090, 0, 0}) + .Case("pto.comm.tput", OpcodeAndVariant{0x1091, 0, 0}) + .Case("pto.comm.tget", OpcodeAndVariant{0x1092, 0, 0}) + .Case("pto.comm.tnotify", OpcodeAndVariant{0x1093, 0, 0}) + .Case("pto.comm.twait", OpcodeAndVariant{0x1094, 0, 0}) + .Case("pto.comm.ttest", OpcodeAndVariant{0x1095, 0, 0}) + .Case("pto.comm.tbroadcast", OpcodeAndVariant{0x1096, 0, 0}) + .Case("pto.comm.tgather", OpcodeAndVariant{0x1097, 0, 0}) + .Case("pto.comm.tscatter", OpcodeAndVariant{0x1098, 0, 0}) + .Case("pto.comm.treduce", OpcodeAndVariant{0x1099, 0, 0}) + .Case("pto.tpartargmax", OpcodeAndVariant{0x109A, 0, 0}) + .Case("pto.tpartargmin", OpcodeAndVariant{0x109B, 0, 0}) + .Case("scf.for", OpcodeAndVariant{0x4000, 0, 0}) + .Case("scf.if", OpcodeAndVariant{0x4001, 0, 0}) + .Case("scf.yield", OpcodeAndVariant{0x4002, 0, 0}) + .Case("pto.section.cube", + OpcodeAndVariant{0x0006, kHasVariant, kSectionCubeVariant}) + .Case("pto.section.vector", + OpcodeAndVariant{0x0006, kHasVariant, kSectionVectorVariant}) + .Case("pto.tgemv", + OpcodeAndVariant{0x102A, kHasVariant, kVariantDefault}) + .Case("pto.tgemv.acc", + OpcodeAndVariant{0x102A, kHasVariant, kVariantAcc}) + .Case("pto.tgemv.bias", + OpcodeAndVariant{0x102A, kHasVariant, kVariantBias}) + .Case("pto.tgemv.mx", + OpcodeAndVariant{0x102A, kHasVariant, kVariantMx}) + .Case("pto.tgemv.mx.acc", + OpcodeAndVariant{0x102A, kHasVariant, kVariantMxAcc}) + .Case("pto.tgemv.mx.bias", + OpcodeAndVariant{0x102A, kHasVariant, kVariantMxBias}) + .Case("pto.tmatmul", + OpcodeAndVariant{0x1032, kHasVariant, kVariantDefault}) + .Case("pto.tmatmul.acc", + OpcodeAndVariant{0x1032, kHasVariant, kVariantAcc}) + .Case("pto.tmatmul.bias", + OpcodeAndVariant{0x1032, kHasVariant, kVariantBias}) + .Case("pto.tmatmul.mx", + OpcodeAndVariant{0x1033, kHasVariant, kVariantDefault}) + .Case("pto.tmatmul.mx.acc", + OpcodeAndVariant{0x1033, kHasVariant, kVariantAcc}) + .Case("pto.tmatmul.mx.bias", + OpcodeAndVariant{0x1033, kHasVariant, kVariantBias}) + .Default(std::nullopt); +} + +const char *fullNameFromOpcodeVariant(uint16_t opcode, uint8_t variant) { + const OpInfo *info = lookupByOpcode(opcode); + if (!info) return nullptr; + if (opcode == kTscatterMaskOpcode) return "pto.tscatter"; + if (!info->has_variant_u8) return info->name; + switch (opcode) { + case 0x0006: + switch (variant) { + case kSectionCubeVariant: return "pto.section.cube"; + case kSectionVectorVariant: return "pto.section.vector"; + default: return info->name; + } + case 0x102A: + switch (variant) { + case kVariantDefault: return "pto.tgemv"; + case kVariantAcc: return "pto.tgemv.acc"; + case kVariantBias: return "pto.tgemv.bias"; + case kVariantMx: return "pto.tgemv.mx"; + case kVariantMxAcc: return "pto.tgemv.mx.acc"; + case kVariantMxBias: return "pto.tgemv.mx.bias"; + default: return info->name; + } + case 0x1032: + switch (variant) { + case kVariantDefault: return "pto.tmatmul"; + case kVariantAcc: return "pto.tmatmul.acc"; + case kVariantBias: return "pto.tmatmul.bias"; + default: return info->name; + } + case 0x1033: + switch (variant) { + case kVariantDefault: return "pto.tmatmul.mx"; + case kVariantAcc: return "pto.tmatmul.mx.acc"; + case kVariantBias: return "pto.tmatmul.mx.bias"; + default: return info->name; + } + default: return info->name; + } +} + +std::optional lookupOperandsByVariant(uint16_t opcode, uint8_t variant) { + switch (opcode) { + case 0x102A: + switch (variant) { + case kVariantDefault: return kTgemvOperandCount; + case kVariantAcc: return kTgemvAccOperandCount; + case kVariantBias: return kTgemvBiasOperandCount; + case kVariantMx: return kTgemvMxOperandCount; + case kVariantMxAcc: return kTgemvMxAccOperandCount; + case kVariantMxBias: return kTgemvMxBiasOperandCount; + default: return std::nullopt; + } + case 0x1032: + switch (variant) { + case kVariantDefault: return kTmatmulOperandCount; + case kVariantAcc: return kTmatmulAccOperandCount; + case kVariantBias: return kTmatmulBiasOperandCount; + default: return std::nullopt; + } + case 0x1033: + switch (variant) { + case kVariantDefault: return kTmatmulMxOperandCount; + case kVariantAcc: return kTmatmulMxAccOperandCount; + case kVariantBias: return kTmatmulMxBiasOperandCount; + default: return std::nullopt; + } + default: return std::nullopt; + } +} + +} // namespace ptobc::v0 diff --git a/tools/ptobc/generated/ptobc_opcodes_v0.h b/tools/ptobc/generated/ptobc_opcodes_v0.h index 8303e1261..9c7a9a1d0 100644 --- a/tools/ptobc/generated/ptobc_opcodes_v0.h +++ b/tools/ptobc/generated/ptobc_opcodes_v0.h @@ -53,670 +53,16 @@ struct OpInfo { uint8_t imm_kind; }; -inline constexpr OpInfo kOpTable[] = { - {0x0000, "pto.get_block_idx", 0, 0x01, 0x00, 0, 1, 0, 0x00}, - {0x0001, "pto.get_block_num", 0, 0x01, 0x00, 0, 1, 0, 0x00}, - {0x0002, "pto.get_subblock_idx", 0, 0x01, 0x00, 0, 1, 0, 0x00}, - {0x0003, "pto.get_subblock_num", 0, 0x01, 0x00, 0, 1, 0, 0x00}, - {0x0004, "pto.make_tensor_view", 0, 0x01, 0x03, 1, 1, 0, 0x06}, - {0x0005, "pto.partition_view", 0, 0x01, 0x03, 1, 1, 0, 0x07}, - {0x0006, "pto.section", 1, 0x00, 0x00, 0, 0, 1, 0x00}, - {0x1000, "pto.addptr", 0, 0x01, 0x00, 2, 1, 0, 0x00}, - {0x1001, "pto.alloc_tile", 0, 0x01, 0x04, 0, 1, 0, 0x08}, - {0x1002, "pto.barrier", 0, 0x00, 0x00, 0, 0, 0, 0x00}, - {0x1003, "pto.mgather", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1004, "pto.mscatter", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1005, "pto.record_event", 0, 0x00, 0x00, 0, 0, 0, 0x02}, - {0x1006, "pto.tabs", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1007, "pto.tadd", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1008, "pto.taddc", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x1009, "pto.tadds", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x100A, "pto.taddsc", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x100B, "pto.tand", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x100C, "pto.tands", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x100D, "pto.tci", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x100E, "pto.tcmp", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x100F, "pto.tcmps", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1010, "pto.tcolexpand", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1011, "pto.tcolexpandadd", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1012, "pto.tcolexpanddiv", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1013, "pto.tcolexpandexpdif", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1014, "pto.tcolexpandmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1015, "pto.tcolexpandmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1016, "pto.tcolexpandmul", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1017, "pto.tcolexpandsub", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1018, "pto.tcolmax", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1019, "pto.tcolmin", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x101A, "pto.tcolprod", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x101B, "pto.tcolsum", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x101C, "pto.tcvt", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x101D, "pto.tdiv", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x101E, "pto.tdivs", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x101F, "pto.texp", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1020, "pto.texpands", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1021, "pto.textract", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x1022, "pto.textract_fp", 0, 0x00, 0x00, 5, 0, 0, 0x00}, - {0x1023, "pto.tfillpad", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1024, "pto.tfillpad_expand", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1025, "pto.tfillpad_inplace", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1026, "pto.tfmod", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1027, "pto.tfmods", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1028, "pto.tgather", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1029, "pto.tgatherb", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x102A, "pto.tgemv", 1, 0x00, 0x01, 0, 0, 0, 0x00}, - {0x102B, "pto.tgetval", 0, 0x01, 0x00, 2, 1, 0, 0x00}, - {0x102C, "pto.timg2col", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x102D, "pto.tinsert", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x102E, "pto.tinsert_fp", 0, 0x00, 0x00, 5, 0, 0, 0x00}, - {0x102F, "pto.tload", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1030, "pto.tlog", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1031, "pto.tlrelu", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1032, "pto.tmatmul", 1, 0x00, 0x01, 0, 0, 0, 0x00}, - {0x1033, "pto.tmatmul.mx", 1, 0x00, 0x01, 0, 0, 0, 0x00}, - {0x1034, "pto.tmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1035, "pto.tmaxs", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1036, "pto.tmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1037, "pto.tmins", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1038, "pto.tmov", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1039, "pto.tmov.fp", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x103A, "pto.tmrgsort", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x103B, "pto.tmul", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x103C, "pto.tmuls", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x103D, "pto.tneg", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x103E, "pto.tnot", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x103F, "pto.tor", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1040, "pto.tors", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1041, "pto.tpartadd", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1042, "pto.tpartmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1043, "pto.tpartmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1044, "pto.tpartmul", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1045, "pto.tprefetch", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1046, "pto.tprelu", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x1047, "pto.tquant", 0, 0x00, 0x02, 3, 0, 0, 0x00}, - {0x1048, "pto.trecip", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1049, "pto.trelu", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x104A, "pto.trem", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x104B, "pto.trems", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x104C, "pto.treshape", 0, 0x01, 0x00, 1, 1, 0, 0x00}, - {0x104D, "pto.trowexpand", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x104E, "pto.trowexpandadd", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x104F, "pto.trowexpandexpdif", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1050, "pto.trowexpandmax", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1051, "pto.trowexpandmin", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1052, "pto.trowmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1053, "pto.trowmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1054, "pto.trowsum", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1055, "pto.trsqrt", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1056, "pto.tscatter", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1057, "pto.tsel", 0, 0x00, 0x00, 5, 0, 0, 0x00}, - {0x1058, "pto.tsels", 0, 0x00, 0x00, 5, 0, 0, 0x00}, - {0x1059, "pto.tset_img2col_padding", 0, 0x00, 0x00, 1, 0, 0, 0x00}, - {0x105A, "pto.tset_img2col_rpt", 0, 0x00, 0x00, 1, 0, 0, 0x00}, - {0x105B, "pto.tsetfmatrix", 0, 0x00, 0x00, 1, 0, 0, 0x00}, - {0x105C, "pto.tsethf32mode", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x105D, "pto.tsettf32mode", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x105E, "pto.tsetval", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x105F, "pto.tshl", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1060, "pto.tshls", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1061, "pto.tshr", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1062, "pto.tshrs", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1063, "pto.tsort32", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1064, "pto.tsqrt", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1065, "pto.tstore", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1066, "pto.tstore_fp", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1067, "pto.tsub", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1068, "pto.tsubc", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x1069, "pto.tsubs", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x106A, "pto.tsubsc", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x106B, "pto.trowexpandsub", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x106C, "pto.ttrans", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x106D, "pto.ttri", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x106E, "pto.txor", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x106F, "pto.txors", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x1070, "pto.wait_event", 0, 0x00, 0x00, 0, 0, 0, 0x02}, - {0x1071, "pto.tprint", 0, 0x00, 0x00, 1, 0, 0, 0x00}, - {0x1072, "pto.subview", 0, 0x01, 0x02, 0, 1, 0, 0x00}, - {0x1073, "pto.trowexpanddiv", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1074, "pto.trowexpandmul", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1075, "pto.tdequant", 0, 0x00, 0x00, 4, 0, 0, 0x00}, - {0x1076, "pto.taxpy", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1077, "pto.thistogram", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1078, "pto.tget_scale_addr", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1079, "pto.trowargmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x107A, "pto.trowargmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x107B, "pto.tcolargmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x107C, "pto.tcolargmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x107D, "pto.tsync", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x107E, "pto.reserve_buffer", 0, 0x01, 0x00, 0, 1, 0, 0x00}, - {0x107F, "pto.import_reserved_buffer", 0, 0x01, 0x00, 0, 1, 0, 0x00}, - {0x1080, "pto.aic_initialize_pipe", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1081, "pto.aiv_initialize_pipe", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1082, "pto.tpush_to_aiv", 0, 0x00, 0x00, 1, 0, 0, 0x00}, - {0x1083, "pto.tpush_to_aic", 0, 0x00, 0x00, 1, 0, 0, 0x00}, - {0x1084, "pto.tpop_from_aic", 0, 0x01, 0x02, 0, 1, 0, 0x00}, - {0x1085, "pto.tpop_from_aiv", 0, 0x01, 0x02, 0, 1, 0, 0x00}, - {0x1086, "pto.tfree_from_aic", 0, 0x00, 0x00, 0, 0, 0, 0x00}, - {0x1087, "pto.tfree_from_aiv", 0, 0x00, 0x00, 0, 0, 0, 0x00}, - {0x1088, "pto.set_validshape", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x1089, "pto.tconcat", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x108A, "pto.trowprod", 0, 0x00, 0x00, 3, 0, 0, 0x00}, - {0x108B, "pto.initialize_l2g2l_pipe", 0, 0x01, 0x02, 0, 1, 0, 0x00}, - {0x108C, "pto.initialize_l2l_pipe", 0, 0x01, 0x02, 0, 1, 0, 0x00}, - {0x108D, "pto.tpush", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x108E, "pto.declare_tile", 0, 0x01, 0x00, 0, 1, 0, 0x00}, - {0x108F, "pto.tpop", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x1090, "pto.tfree", 0, 0x00, 0x00, 1, 0, 0, 0x00}, - {0x1091, "pto.comm.tput", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1092, "pto.comm.tget", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1093, "pto.comm.tnotify", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1094, "pto.comm.twait", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1095, "pto.comm.ttest", 0, 0x01, 0x02, 0, 1, 0, 0x00}, - {0x1096, "pto.comm.tbroadcast", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1097, "pto.comm.tgather", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1098, "pto.comm.tscatter", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x1099, "pto.comm.treduce", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x109A, "pto.tpartargmax", 0, 0x00, 0x00, 6, 0, 0, 0x00}, - {0x109B, "pto.tpartargmin", 0, 0x00, 0x00, 6, 0, 0, 0x00}, - {0x109C, "pto.tscatter.maskpattern", 0, 0x00, 0x00, 2, 0, 0, 0x00}, - {0x2000, "arith.addi", 0, 0x01, 0x00, 2, 1, 0, 0x00}, - {0x2001, "arith.ceildivsi", 0, 0x01, 0x00, 2, 1, 0, 0x00}, - {0x2002, "arith.cmpi", 0, 0x01, 0x00, 2, 1, 0, 0x01}, - {0x2003, "arith.constant", 0, 0x01, 0x00, 0, 1, 0, 0x05}, - {0x2004, "arith.index_cast", 0, 0x01, 0x00, 1, 1, 0, 0x00}, - {0x2005, "arith.minui", 0, 0x01, 0x00, 2, 1, 0, 0x00}, - {0x2006, "arith.muli", 0, 0x01, 0x00, 2, 1, 0, 0x00}, - {0x2007, "arith.select", 0, 0x01, 0x00, 3, 1, 0, 0x00}, - {0x2008, "arith.subi", 0, 0x01, 0x00, 2, 1, 0, 0x00}, - {0x4000, "scf.for", 0, 0x00, 0x00, 3, 0, 1, 0x00}, - {0x4001, "scf.if", 0, 0x00, 0x00, 1, 0, 2, 0x00}, - {0x4002, "scf.yield", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x6000, "func.func", 0, 0x00, 0x00, 0, 0, 0, 0x00}, - {0x6001, "func.return", 0, 0x00, 0x02, 0, 0, 0, 0x00}, - {0x6002, "func.call", 0, 0x02, 0x02, 0, 0, 0, 0x00}, -}; - -inline const OpInfo *lookupByOpcode(uint16_t opcode) { - // Binary search on kOpTable (sorted by opcode). - size_t lo = 0, hi = sizeof(kOpTable) / sizeof(kOpTable[0]); - while (lo < hi) { - size_t mid = lo + (hi - lo) / 2; - uint16_t v = kOpTable[mid].opcode; - if (v == opcode) return &kOpTable[mid]; - if (v < opcode) lo = mid + 1; else hi = mid; - } - return nullptr; -} - -inline std::optional lookupOpcodeByName(llvm::StringRef name) { - uint16_t v = llvm::StringSwitch(name) - .Case("arith.addi", 0x2000) - .Case("arith.ceildivsi", 0x2001) - .Case("arith.cmpi", 0x2002) - .Case("arith.constant", 0x2003) - .Case("arith.index_cast", 0x2004) - .Case("arith.minui", 0x2005) - .Case("arith.muli", 0x2006) - .Case("arith.select", 0x2007) - .Case("arith.subi", 0x2008) - .Case("func.func", 0x6000) - .Case("func.return", 0x6001) - .Case("func.call", 0x6002) - .Case("pto.addptr", 0x1000) - .Case("pto.alloc_tile", 0x1001) - .Case("pto.barrier", 0x1002) - .Case("pto.get_block_idx", 0x0000) - .Case("pto.get_block_num", 0x0001) - .Case("pto.get_subblock_idx", 0x0002) - .Case("pto.get_subblock_num", 0x0003) - .Case("pto.make_tensor_view", 0x0004) - .Case("pto.mgather", 0x1003) - .Case("pto.mscatter", 0x1004) - .Case("pto.partition_view", 0x0005) - .Case("pto.record_event", 0x1005) - .Case("pto.section", 0x0006) - .Case("pto.tabs", 0x1006) - .Case("pto.tadd", 0x1007) - .Case("pto.taddc", 0x1008) - .Case("pto.tadds", 0x1009) - .Case("pto.taddsc", 0x100A) - .Case("pto.tand", 0x100B) - .Case("pto.tands", 0x100C) - .Case("pto.tci", 0x100D) - .Case("pto.tcmp", 0x100E) - .Case("pto.tcmps", 0x100F) - .Case("pto.tcolexpand", 0x1010) - .Case("pto.tcolexpandadd", 0x1011) - .Case("pto.tcolexpanddiv", 0x1012) - .Case("pto.tcolexpandexpdif", 0x1013) - .Case("pto.tcolexpandmax", 0x1014) - .Case("pto.tcolexpandmin", 0x1015) - .Case("pto.tcolexpandmul", 0x1016) - .Case("pto.tcolexpandsub", 0x1017) - .Case("pto.tcolmax", 0x1018) - .Case("pto.tcolmin", 0x1019) - .Case("pto.tcolprod", 0x101A) - .Case("pto.tcolsum", 0x101B) - .Case("pto.tcvt", 0x101C) - .Case("pto.tdiv", 0x101D) - .Case("pto.tdivs", 0x101E) - .Case("pto.texp", 0x101F) - .Case("pto.texpands", 0x1020) - .Case("pto.textract", 0x1021) - .Case("pto.textract_fp", 0x1022) - .Case("pto.tfillpad", 0x1023) - .Case("pto.tfillpad_expand", 0x1024) - .Case("pto.tfillpad_inplace", 0x1025) - .Case("pto.tfmod", 0x1026) - .Case("pto.tfmods", 0x1027) - .Case("pto.tgather", 0x1028) - .Case("pto.tgatherb", 0x1029) - .Case("pto.tgemv", 0x102A) - .Case("pto.tgetval", 0x102B) - .Case("pto.timg2col", 0x102C) - .Case("pto.tinsert", 0x102D) - .Case("pto.tinsert_fp", 0x102E) - .Case("pto.tload", 0x102F) - .Case("pto.tlog", 0x1030) - .Case("pto.tlrelu", 0x1031) - .Case("pto.tmatmul", 0x1032) - .Case("pto.tmatmul.mx", 0x1033) - .Case("pto.tmax", 0x1034) - .Case("pto.tmaxs", 0x1035) - .Case("pto.tmin", 0x1036) - .Case("pto.tmins", 0x1037) - .Case("pto.tmov", 0x1038) - .Case("pto.tmov.fp", 0x1039) - .Case("pto.tmrgsort", 0x103A) - .Case("pto.tmul", 0x103B) - .Case("pto.tmuls", 0x103C) - .Case("pto.tneg", 0x103D) - .Case("pto.tnot", 0x103E) - .Case("pto.tor", 0x103F) - .Case("pto.tors", 0x1040) - .Case("pto.tpartadd", 0x1041) - .Case("pto.tpartmax", 0x1042) - .Case("pto.tpartmin", 0x1043) - .Case("pto.tpartmul", 0x1044) - .Case("pto.tprefetch", 0x1045) - .Case("pto.tprelu", 0x1046) - .Case("pto.tquant", 0x1047) - .Case("pto.trecip", 0x1048) - .Case("pto.trelu", 0x1049) - .Case("pto.trem", 0x104A) - .Case("pto.trems", 0x104B) - .Case("pto.treshape", 0x104C) - .Case("pto.trowexpand", 0x104D) - .Case("pto.trowexpandadd", 0x104E) - .Case("pto.trowexpandexpdif", 0x104F) - .Case("pto.trowexpandmax", 0x1050) - .Case("pto.trowexpandmin", 0x1051) - .Case("pto.trowmax", 0x1052) - .Case("pto.trowmin", 0x1053) - .Case("pto.trowsum", 0x1054) - .Case("pto.trsqrt", 0x1055) - .Case("pto.tscatter", 0x1056) - .Case("pto.tsel", 0x1057) - .Case("pto.tsels", 0x1058) - .Case("pto.tset_img2col_padding", 0x1059) - .Case("pto.tset_img2col_rpt", 0x105A) - .Case("pto.tsetfmatrix", 0x105B) - .Case("pto.tsethf32mode", 0x105C) - .Case("pto.tsettf32mode", 0x105D) - .Case("pto.tsetval", 0x105E) - .Case("pto.tshl", 0x105F) - .Case("pto.tshls", 0x1060) - .Case("pto.tshr", 0x1061) - .Case("pto.tshrs", 0x1062) - .Case("pto.tsort32", 0x1063) - .Case("pto.tsqrt", 0x1064) - .Case("pto.tstore", 0x1065) - .Case("pto.tstore_fp", 0x1066) - .Case("pto.tsub", 0x1067) - .Case("pto.tsubc", 0x1068) - .Case("pto.tsubs", 0x1069) - .Case("pto.tsubsc", 0x106A) - .Case("pto.trowexpandsub", 0x106B) - .Case("pto.ttrans", 0x106C) - .Case("pto.ttri", 0x106D) - .Case("pto.txor", 0x106E) - .Case("pto.txors", 0x106F) - .Case("pto.wait_event", 0x1070) - .Case("pto.tprint", 0x1071) - .Case("pto.subview", 0x1072) - .Case("pto.trowexpanddiv", 0x1073) - .Case("pto.trowexpandmul", 0x1074) - .Case("pto.tdequant", 0x1075) - .Case("pto.taxpy", 0x1076) - .Case("pto.thistogram", 0x1077) - .Case("pto.tget_scale_addr", 0x1078) - .Case("pto.trowargmax", 0x1079) - .Case("pto.trowargmin", 0x107A) - .Case("pto.tcolargmax", 0x107B) - .Case("pto.tcolargmin", 0x107C) - .Case("pto.tsync", 0x107D) - .Case("pto.reserve_buffer", 0x107E) - .Case("pto.import_reserved_buffer", 0x107F) - .Case("pto.aic_initialize_pipe", 0x1080) - .Case("pto.aiv_initialize_pipe", 0x1081) - .Case("pto.tpush_to_aiv", 0x1082) - .Case("pto.tpush_to_aic", 0x1083) - .Case("pto.tpop_from_aic", 0x1084) - .Case("pto.tpop_from_aiv", 0x1085) - .Case("pto.tfree_from_aic", 0x1086) - .Case("pto.tfree_from_aiv", 0x1087) - .Case("pto.set_validshape", 0x1088) - .Case("pto.tconcat", 0x1089) - .Case("pto.trowprod", 0x108A) - .Case("pto.initialize_l2g2l_pipe", 0x108B) - .Case("pto.initialize_l2l_pipe", 0x108C) - .Case("pto.tpush", 0x108D) - .Case("pto.declare_tile", 0x108E) - .Case("pto.tpop", 0x108F) - .Case("pto.tfree", 0x1090) - .Case("pto.comm.tput", 0x1091) - .Case("pto.comm.tget", 0x1092) - .Case("pto.comm.tnotify", 0x1093) - .Case("pto.comm.twait", 0x1094) - .Case("pto.comm.ttest", 0x1095) - .Case("pto.comm.tbroadcast", 0x1096) - .Case("pto.comm.tgather", 0x1097) - .Case("pto.comm.tscatter", 0x1098) - .Case("pto.comm.treduce", 0x1099) - .Case("pto.tpartargmax", 0x109A) - .Case("pto.tpartargmin", 0x109B) - .Case("scf.for", 0x4000) - .Case("scf.if", 0x4001) - .Case("scf.yield", 0x4002) - .Default(0xFFFF); - if (v == 0xFFFF) return std::nullopt; - return v; -} +extern const OpInfo kOpTable[]; -inline const OpInfo *lookupByName(llvm::StringRef name) { - auto o = lookupOpcodeByName(name); - if (!o) return nullptr; - return lookupByOpcode(*o); -} +const OpInfo *lookupByOpcode(uint16_t opcode); +std::optional lookupOpcodeByName(llvm::StringRef name); +const OpInfo *lookupByName(llvm::StringRef name); struct OpcodeAndVariant { uint16_t opcode; uint8_t hasVariant; uint8_t variant; }; -inline std::optional lookupOpcodeAndVariantByFullName(llvm::StringRef fullName) { - // For non-family ops, variant is 0. For family ops, variant is the assigned u8. - // NOTE: `pto.section` is not a real op name; use `pto.section.cube`/`pto.section.vector`. - return llvm::StringSwitch>(fullName) - .Case("arith.addi", OpcodeAndVariant{0x2000, 0, 0}) - .Case("arith.ceildivsi", OpcodeAndVariant{0x2001, 0, 0}) - .Case("arith.cmpi", OpcodeAndVariant{0x2002, 0, 0}) - .Case("arith.constant", OpcodeAndVariant{0x2003, 0, 0}) - .Case("arith.index_cast", OpcodeAndVariant{0x2004, 0, 0}) - .Case("arith.minui", OpcodeAndVariant{0x2005, 0, 0}) - .Case("arith.muli", OpcodeAndVariant{0x2006, 0, 0}) - .Case("arith.select", OpcodeAndVariant{0x2007, 0, 0}) - .Case("arith.subi", OpcodeAndVariant{0x2008, 0, 0}) - .Case("func.func", OpcodeAndVariant{0x6000, 0, 0}) - .Case("func.return", OpcodeAndVariant{0x6001, 0, 0}) - .Case("func.call", OpcodeAndVariant{0x6002, 0, 0}) - .Case("pto.addptr", OpcodeAndVariant{0x1000, 0, 0}) - .Case("pto.alloc_tile", OpcodeAndVariant{0x1001, 0, 0}) - .Case("pto.barrier", OpcodeAndVariant{0x1002, 0, 0}) - .Case("pto.get_block_idx", OpcodeAndVariant{0x0000, 0, 0}) - .Case("pto.get_block_num", OpcodeAndVariant{0x0001, 0, 0}) - .Case("pto.get_subblock_idx", OpcodeAndVariant{0x0002, 0, 0}) - .Case("pto.get_subblock_num", OpcodeAndVariant{0x0003, 0, 0}) - .Case("pto.make_tensor_view", OpcodeAndVariant{0x0004, 0, 0}) - .Case("pto.mgather", OpcodeAndVariant{0x1003, 0, 0}) - .Case("pto.mscatter", OpcodeAndVariant{0x1004, 0, 0}) - .Case("pto.partition_view", OpcodeAndVariant{0x0005, 0, 0}) - .Case("pto.record_event", OpcodeAndVariant{0x1005, 0, 0}) - .Case("pto.tabs", OpcodeAndVariant{0x1006, 0, 0}) - .Case("pto.tadd", OpcodeAndVariant{0x1007, 0, 0}) - .Case("pto.taddc", OpcodeAndVariant{0x1008, 0, 0}) - .Case("pto.tadds", OpcodeAndVariant{0x1009, 0, 0}) - .Case("pto.taddsc", OpcodeAndVariant{0x100A, 0, 0}) - .Case("pto.tand", OpcodeAndVariant{0x100B, 0, 0}) - .Case("pto.tands", OpcodeAndVariant{0x100C, 0, 0}) - .Case("pto.tci", OpcodeAndVariant{0x100D, 0, 0}) - .Case("pto.tcmp", OpcodeAndVariant{0x100E, 0, 0}) - .Case("pto.tcmps", OpcodeAndVariant{0x100F, 0, 0}) - .Case("pto.tcolexpand", OpcodeAndVariant{0x1010, 0, 0}) - .Case("pto.tcolexpandadd", OpcodeAndVariant{0x1011, 0, 0}) - .Case("pto.tcolexpanddiv", OpcodeAndVariant{0x1012, 0, 0}) - .Case("pto.tcolexpandexpdif", OpcodeAndVariant{0x1013, 0, 0}) - .Case("pto.tcolexpandmax", OpcodeAndVariant{0x1014, 0, 0}) - .Case("pto.tcolexpandmin", OpcodeAndVariant{0x1015, 0, 0}) - .Case("pto.tcolexpandmul", OpcodeAndVariant{0x1016, 0, 0}) - .Case("pto.tcolexpandsub", OpcodeAndVariant{0x1017, 0, 0}) - .Case("pto.tcolmax", OpcodeAndVariant{0x1018, 0, 0}) - .Case("pto.tcolmin", OpcodeAndVariant{0x1019, 0, 0}) - .Case("pto.tcolprod", OpcodeAndVariant{0x101A, 0, 0}) - .Case("pto.tcolsum", OpcodeAndVariant{0x101B, 0, 0}) - .Case("pto.tcvt", OpcodeAndVariant{0x101C, 0, 0}) - .Case("pto.tdiv", OpcodeAndVariant{0x101D, 0, 0}) - .Case("pto.tdivs", OpcodeAndVariant{0x101E, 0, 0}) - .Case("pto.texp", OpcodeAndVariant{0x101F, 0, 0}) - .Case("pto.texpands", OpcodeAndVariant{0x1020, 0, 0}) - .Case("pto.textract", OpcodeAndVariant{0x1021, 0, 0}) - .Case("pto.textract_fp", OpcodeAndVariant{0x1022, 0, 0}) - .Case("pto.tfillpad", OpcodeAndVariant{0x1023, 0, 0}) - .Case("pto.tfillpad_expand", OpcodeAndVariant{0x1024, 0, 0}) - .Case("pto.tfillpad_inplace", OpcodeAndVariant{0x1025, 0, 0}) - .Case("pto.tfmod", OpcodeAndVariant{0x1026, 0, 0}) - .Case("pto.tfmods", OpcodeAndVariant{0x1027, 0, 0}) - .Case("pto.tgather", OpcodeAndVariant{0x1028, 0, 0}) - .Case("pto.tgatherb", OpcodeAndVariant{0x1029, 0, 0}) - .Case("pto.tgetval", OpcodeAndVariant{0x102B, 0, 0}) - .Case("pto.timg2col", OpcodeAndVariant{0x102C, 0, 0}) - .Case("pto.tinsert", OpcodeAndVariant{0x102D, 0, 0}) - .Case("pto.tinsert_fp", OpcodeAndVariant{0x102E, 0, 0}) - .Case("pto.tload", OpcodeAndVariant{0x102F, 0, 0}) - .Case("pto.tlog", OpcodeAndVariant{0x1030, 0, 0}) - .Case("pto.tlrelu", OpcodeAndVariant{0x1031, 0, 0}) - .Case("pto.tmax", OpcodeAndVariant{0x1034, 0, 0}) - .Case("pto.tmaxs", OpcodeAndVariant{0x1035, 0, 0}) - .Case("pto.tmin", OpcodeAndVariant{0x1036, 0, 0}) - .Case("pto.tmins", OpcodeAndVariant{0x1037, 0, 0}) - .Case("pto.tmov", OpcodeAndVariant{0x1038, 0, 0}) - .Case("pto.tmov.fp", OpcodeAndVariant{0x1039, 0, 0}) - .Case("pto.tmrgsort", OpcodeAndVariant{0x103A, 0, 0}) - .Case("pto.tmul", OpcodeAndVariant{0x103B, 0, 0}) - .Case("pto.tmuls", OpcodeAndVariant{0x103C, 0, 0}) - .Case("pto.tneg", OpcodeAndVariant{0x103D, 0, 0}) - .Case("pto.tnot", OpcodeAndVariant{0x103E, 0, 0}) - .Case("pto.tor", OpcodeAndVariant{0x103F, 0, 0}) - .Case("pto.tors", OpcodeAndVariant{0x1040, 0, 0}) - .Case("pto.tpartadd", OpcodeAndVariant{0x1041, 0, 0}) - .Case("pto.tpartmax", OpcodeAndVariant{0x1042, 0, 0}) - .Case("pto.tpartmin", OpcodeAndVariant{0x1043, 0, 0}) - .Case("pto.tpartmul", OpcodeAndVariant{0x1044, 0, 0}) - .Case("pto.tprefetch", OpcodeAndVariant{0x1045, 0, 0}) - .Case("pto.tprelu", OpcodeAndVariant{0x1046, 0, 0}) - .Case("pto.tquant", OpcodeAndVariant{0x1047, 0, 0}) - .Case("pto.trecip", OpcodeAndVariant{0x1048, 0, 0}) - .Case("pto.trelu", OpcodeAndVariant{0x1049, 0, 0}) - .Case("pto.trem", OpcodeAndVariant{0x104A, 0, 0}) - .Case("pto.trems", OpcodeAndVariant{0x104B, 0, 0}) - .Case("pto.treshape", OpcodeAndVariant{0x104C, 0, 0}) - .Case("pto.trowexpand", OpcodeAndVariant{0x104D, 0, 0}) - .Case("pto.trowexpandadd", OpcodeAndVariant{0x104E, 0, 0}) - .Case("pto.trowexpandexpdif", OpcodeAndVariant{0x104F, 0, 0}) - .Case("pto.trowexpandmax", OpcodeAndVariant{0x1050, 0, 0}) - .Case("pto.trowexpandmin", OpcodeAndVariant{0x1051, 0, 0}) - .Case("pto.trowmax", OpcodeAndVariant{0x1052, 0, 0}) - .Case("pto.trowmin", OpcodeAndVariant{0x1053, 0, 0}) - .Case("pto.trowsum", OpcodeAndVariant{0x1054, 0, 0}) - .Case("pto.trsqrt", OpcodeAndVariant{0x1055, 0, 0}) - .Case("pto.tscatter", OpcodeAndVariant{0x1056, 0, 0}) - .Case("pto.tsel", OpcodeAndVariant{0x1057, 0, 0}) - .Case("pto.tsels", OpcodeAndVariant{0x1058, 0, 0}) - .Case("pto.tset_img2col_padding", OpcodeAndVariant{0x1059, 0, 0}) - .Case("pto.tset_img2col_rpt", OpcodeAndVariant{0x105A, 0, 0}) - .Case("pto.tsetfmatrix", OpcodeAndVariant{0x105B, 0, 0}) - .Case("pto.tsethf32mode", OpcodeAndVariant{0x105C, 0, 0}) - .Case("pto.tsettf32mode", OpcodeAndVariant{0x105D, 0, 0}) - .Case("pto.tsetval", OpcodeAndVariant{0x105E, 0, 0}) - .Case("pto.tshl", OpcodeAndVariant{0x105F, 0, 0}) - .Case("pto.tshls", OpcodeAndVariant{0x1060, 0, 0}) - .Case("pto.tshr", OpcodeAndVariant{0x1061, 0, 0}) - .Case("pto.tshrs", OpcodeAndVariant{0x1062, 0, 0}) - .Case("pto.tsort32", OpcodeAndVariant{0x1063, 0, 0}) - .Case("pto.tsqrt", OpcodeAndVariant{0x1064, 0, 0}) - .Case("pto.tstore", OpcodeAndVariant{0x1065, 0, 0}) - .Case("pto.tstore_fp", OpcodeAndVariant{0x1066, 0, 0}) - .Case("pto.tsub", OpcodeAndVariant{0x1067, 0, 0}) - .Case("pto.tsubc", OpcodeAndVariant{0x1068, 0, 0}) - .Case("pto.tsubs", OpcodeAndVariant{0x1069, 0, 0}) - .Case("pto.tsubsc", OpcodeAndVariant{0x106A, 0, 0}) - .Case("pto.trowexpandsub", OpcodeAndVariant{0x106B, 0, 0}) - .Case("pto.ttrans", OpcodeAndVariant{0x106C, 0, 0}) - .Case("pto.ttri", OpcodeAndVariant{0x106D, 0, 0}) - .Case("pto.txor", OpcodeAndVariant{0x106E, 0, 0}) - .Case("pto.txors", OpcodeAndVariant{0x106F, 0, 0}) - .Case("pto.wait_event", OpcodeAndVariant{0x1070, 0, 0}) - .Case("pto.tprint", OpcodeAndVariant{0x1071, 0, 0}) - .Case("pto.subview", OpcodeAndVariant{0x1072, 0, 0}) - .Case("pto.trowexpanddiv", OpcodeAndVariant{0x1073, 0, 0}) - .Case("pto.trowexpandmul", OpcodeAndVariant{0x1074, 0, 0}) - .Case("pto.tdequant", OpcodeAndVariant{0x1075, 0, 0}) - .Case("pto.taxpy", OpcodeAndVariant{0x1076, 0, 0}) - .Case("pto.thistogram", OpcodeAndVariant{0x1077, 0, 0}) - .Case("pto.tget_scale_addr", OpcodeAndVariant{0x1078, 0, 0}) - .Case("pto.trowargmax", OpcodeAndVariant{0x1079, 0, 0}) - .Case("pto.trowargmin", OpcodeAndVariant{0x107A, 0, 0}) - .Case("pto.tcolargmax", OpcodeAndVariant{0x107B, 0, 0}) - .Case("pto.tcolargmin", OpcodeAndVariant{0x107C, 0, 0}) - .Case("pto.tsync", OpcodeAndVariant{0x107D, 0, 0}) - .Case("pto.reserve_buffer", OpcodeAndVariant{0x107E, 0, 0}) - .Case("pto.import_reserved_buffer", OpcodeAndVariant{0x107F, 0, 0}) - .Case("pto.aic_initialize_pipe", OpcodeAndVariant{0x1080, 0, 0}) - .Case("pto.aiv_initialize_pipe", OpcodeAndVariant{0x1081, 0, 0}) - .Case("pto.tpush_to_aiv", OpcodeAndVariant{0x1082, 0, 0}) - .Case("pto.tpush_to_aic", OpcodeAndVariant{0x1083, 0, 0}) - .Case("pto.tpop_from_aic", OpcodeAndVariant{0x1084, 0, 0}) - .Case("pto.tpop_from_aiv", OpcodeAndVariant{0x1085, 0, 0}) - .Case("pto.tfree_from_aic", OpcodeAndVariant{0x1086, 0, 0}) - .Case("pto.tfree_from_aiv", OpcodeAndVariant{0x1087, 0, 0}) - .Case("pto.set_validshape", OpcodeAndVariant{0x1088, 0, 0}) - .Case("pto.tconcat", OpcodeAndVariant{0x1089, 0, 0}) - .Case("pto.trowprod", OpcodeAndVariant{0x108A, 0, 0}) - .Case("pto.initialize_l2g2l_pipe", OpcodeAndVariant{0x108B, 0, 0}) - .Case("pto.initialize_l2l_pipe", OpcodeAndVariant{0x108C, 0, 0}) - .Case("pto.tpush", OpcodeAndVariant{0x108D, 0, 0}) - .Case("pto.declare_tile", OpcodeAndVariant{0x108E, 0, 0}) - .Case("pto.tpop", OpcodeAndVariant{0x108F, 0, 0}) - .Case("pto.tfree", OpcodeAndVariant{0x1090, 0, 0}) - .Case("pto.comm.tput", OpcodeAndVariant{0x1091, 0, 0}) - .Case("pto.comm.tget", OpcodeAndVariant{0x1092, 0, 0}) - .Case("pto.comm.tnotify", OpcodeAndVariant{0x1093, 0, 0}) - .Case("pto.comm.twait", OpcodeAndVariant{0x1094, 0, 0}) - .Case("pto.comm.ttest", OpcodeAndVariant{0x1095, 0, 0}) - .Case("pto.comm.tbroadcast", OpcodeAndVariant{0x1096, 0, 0}) - .Case("pto.comm.tgather", OpcodeAndVariant{0x1097, 0, 0}) - .Case("pto.comm.tscatter", OpcodeAndVariant{0x1098, 0, 0}) - .Case("pto.comm.treduce", OpcodeAndVariant{0x1099, 0, 0}) - .Case("pto.tpartargmax", OpcodeAndVariant{0x109A, 0, 0}) - .Case("pto.tpartargmin", OpcodeAndVariant{0x109B, 0, 0}) - .Case("scf.for", OpcodeAndVariant{0x4000, 0, 0}) - .Case("scf.if", OpcodeAndVariant{0x4001, 0, 0}) - .Case("scf.yield", OpcodeAndVariant{0x4002, 0, 0}) - .Case("pto.section.cube", - OpcodeAndVariant{0x0006, kHasVariant, kSectionCubeVariant}) - .Case("pto.section.vector", - OpcodeAndVariant{0x0006, kHasVariant, kSectionVectorVariant}) - .Case("pto.tgemv", - OpcodeAndVariant{0x102A, kHasVariant, kVariantDefault}) - .Case("pto.tgemv.acc", - OpcodeAndVariant{0x102A, kHasVariant, kVariantAcc}) - .Case("pto.tgemv.bias", - OpcodeAndVariant{0x102A, kHasVariant, kVariantBias}) - .Case("pto.tgemv.mx", - OpcodeAndVariant{0x102A, kHasVariant, kVariantMx}) - .Case("pto.tgemv.mx.acc", - OpcodeAndVariant{0x102A, kHasVariant, kVariantMxAcc}) - .Case("pto.tgemv.mx.bias", - OpcodeAndVariant{0x102A, kHasVariant, kVariantMxBias}) - .Case("pto.tmatmul", - OpcodeAndVariant{0x1032, kHasVariant, kVariantDefault}) - .Case("pto.tmatmul.acc", - OpcodeAndVariant{0x1032, kHasVariant, kVariantAcc}) - .Case("pto.tmatmul.bias", - OpcodeAndVariant{0x1032, kHasVariant, kVariantBias}) - .Case("pto.tmatmul.mx", - OpcodeAndVariant{0x1033, kHasVariant, kVariantDefault}) - .Case("pto.tmatmul.mx.acc", - OpcodeAndVariant{0x1033, kHasVariant, kVariantAcc}) - .Case("pto.tmatmul.mx.bias", - OpcodeAndVariant{0x1033, kHasVariant, kVariantBias}) - .Default(std::nullopt); -} - -inline const char *fullNameFromOpcodeVariant(uint16_t opcode, uint8_t variant) { - const OpInfo *info = lookupByOpcode(opcode); - if (!info) return nullptr; - if (opcode == kTscatterMaskOpcode) return "pto.tscatter"; - if (!info->has_variant_u8) return info->name; - switch (opcode) { - case 0x0006: - switch (variant) { - case kSectionCubeVariant: return "pto.section.cube"; - case kSectionVectorVariant: return "pto.section.vector"; - default: return info->name; - } - case 0x102A: - switch (variant) { - case kVariantDefault: return "pto.tgemv"; - case kVariantAcc: return "pto.tgemv.acc"; - case kVariantBias: return "pto.tgemv.bias"; - case kVariantMx: return "pto.tgemv.mx"; - case kVariantMxAcc: return "pto.tgemv.mx.acc"; - case kVariantMxBias: return "pto.tgemv.mx.bias"; - default: return info->name; - } - case 0x1032: - switch (variant) { - case kVariantDefault: return "pto.tmatmul"; - case kVariantAcc: return "pto.tmatmul.acc"; - case kVariantBias: return "pto.tmatmul.bias"; - default: return info->name; - } - case 0x1033: - switch (variant) { - case kVariantDefault: return "pto.tmatmul.mx"; - case kVariantAcc: return "pto.tmatmul.mx.acc"; - case kVariantBias: return "pto.tmatmul.mx.bias"; - default: return info->name; - } - default: return info->name; - } -} - -inline std::optional lookupOperandsByVariant(uint16_t opcode, uint8_t variant) { - switch (opcode) { - case 0x102A: - switch (variant) { - case kVariantDefault: return kTgemvOperandCount; - case kVariantAcc: return kTgemvAccOperandCount; - case kVariantBias: return kTgemvBiasOperandCount; - case kVariantMx: return kTgemvMxOperandCount; - case kVariantMxAcc: return kTgemvMxAccOperandCount; - case kVariantMxBias: return kTgemvMxBiasOperandCount; - default: return std::nullopt; - } - case 0x1032: - switch (variant) { - case kVariantDefault: return kTmatmulOperandCount; - case kVariantAcc: return kTmatmulAccOperandCount; - case kVariantBias: return kTmatmulBiasOperandCount; - default: return std::nullopt; - } - case 0x1033: - switch (variant) { - case kVariantDefault: return kTmatmulMxOperandCount; - case kVariantAcc: return kTmatmulMxAccOperandCount; - case kVariantBias: return kTmatmulMxBiasOperandCount; - default: return std::nullopt; - } - default: return std::nullopt; - } -} +std::optional lookupOpcodeAndVariantByFullName(llvm::StringRef fullName); +const char *fullNameFromOpcodeVariant(uint16_t opcode, uint8_t variant); +std::optional lookupOperandsByVariant(uint16_t opcode, uint8_t variant); } // namespace ptobc::v0 diff --git a/tools/ptobc/tests/opcode_coverage_check.py b/tools/ptobc/tests/opcode_coverage_check.py index 757c7ea40..c4d99ff5b 100755 --- a/tools/ptobc/tests/opcode_coverage_check.py +++ b/tools/ptobc/tests/opcode_coverage_check.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# -*- coding: utf-8 -*- # Copyright (c) 2026 Huawei Technologies Co., Ltd. # This program is free software, you can redistribute it and/or modify it under the terms and conditions of # CANN Open Software License Agreement Version 2.0 (the "License").