Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ if (TRITON_BUILD_PYTHON_MODULE)

LinalgExtTransforms
TritonExtTransforms
LinalgExtAnalysis

LinalgToLinked
LinkedToHIVM
Expand Down
5 changes: 3 additions & 2 deletions backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def __init__(self, target: str) -> None:
self.binary_ext = "mcfatbin"
elif self.driver.target == "ascend":
self.binary_ext = "npubin"
self.capability = target.arch
else:
raise RuntimeError(f"Target '{self.target_type}' is not supported.")

Expand Down Expand Up @@ -249,7 +250,7 @@ def add_stages(self, stages, options, language=None):
)
stages["npubin"] = (
lambda src, metadata: linalg_to_bin_enable_npu_compile(
src, metadata, options
src, metadata, options, self.capability
)
)
else:
Expand All @@ -264,7 +265,7 @@ def add_stages(self, stages, options, language=None):
)
stages["npubin"] = (
lambda src, metadata: linalg_to_bin_enable_npu_compile(
src, metadata, options
src, metadata, options, self.capability
)
)
else:
Expand Down
5 changes: 4 additions & 1 deletion backend/npu.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,7 @@ def ttsharedir_to_linkedir(mod, metadata, opt, *, named_ops=False):
dicp_triton.passes.linked_npu.add_scalar_to_1d_tensor(pm)
dicp_triton.passes.linked_npu.add_linalg_to_linked(pm, False, True)
dicp_triton.passes.linked_npu.add_linked_to_hivm(pm)
dicp_triton.passes.linked_npu.add_npu_unroll_pipeline(pm)
pm.run(mod)

# TODO(zmz): Rewrite the contents of test_path. Handled in Python for now;
# bishengir-compile will support this later, at which point this logic can be removed.
Expand Down Expand Up @@ -683,7 +684,7 @@ def _parse_linalg_metadata(linalg: str, metadata: dict):
return linalg, metadata


def linalg_to_bin_enable_npu_compile(linalg: str, metadata, opt):
def linalg_to_bin_enable_npu_compile(linalg: str, metadata, opt, capability):
linalg, metadata = _parse_linalg_metadata(linalg, metadata)
with tempfile.TemporaryDirectory() as tmpdir:
ttadapter_path = os.path.join(tmpdir, "kernel.ttadapter.mlir")
Expand All @@ -706,6 +707,8 @@ def linalg_to_bin_enable_npu_compile(linalg: str, metadata, opt):
_compile_option_list += [
f"--enable-auto-multi-buffer={multibuffer}",
]
if capability:
_compile_option_list += [f"--target={capability}"]

if _is_ascend_sanitizer_enabled():
_compile_option_list += ["--enable-sanitizer=true"]
Expand Down
166 changes: 166 additions & 0 deletions compiler/include/dicp/Dialect/LinalgExt/Analysis/DimAnalyzer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
#ifndef DICP_DIALECT_LINALGEXT_ANALYSIS_DIMANALYZER_H
#define DICP_DIALECT_LINALGEXT_ANALYSIS_DIMANALYZER_H

#include "dicp/Dialect/LinalgExt/Analysis/StageDependencyAnalyzer.h"

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SmallVector.h"

#include <cstdint>
#include <numeric>
#include <queue>
#include <string>
#include <vector>

namespace mlir {
namespace dicp {

/// Classification of a dimension's role in the computation graph.
/// This helps determine if a dimension is safe to tile or parallelize.
enum class DimKind {
  Unknown,   // No specific property inferred yet.
  Parallel,  // Dimension implies independent iterations (safe to tile).
  Reduction, // Dimension is collapsed/reduced (requires accumulation).
  Broadcast, // Dimension is replicated (data invariant along this axis).
  Complex    // Dimension undergoes complex transformation (e.g., non-affine
             // reshape).
};

/// Returns a human-readable name for `k` (for debugging/diagnostics).
std::string toString(DimKind k);

/// Disjoint Set Union (DSU) for tracking dimension equivalence and properties.
///
/// This class implements a Disjoint Set data structure (Union-Find)
/// specifically designed for Tensor/MemRef dimensions. It serves two main
/// purposes:
/// 1. **Equivalence Tracking**: Determines which dimensions across different
///    values represent the same logical axis (e.g., the 'N' dimension in a
///    Matmul propagating through element-wise adds).
/// 2. **Property Propagation**: Merges semantic properties (DimKind) when
///    dimensions are unified. For example, if a dimension is used as a
///    Reduction iterator in one operation, that property propagates to all
///    equivalent dimensions in the set.
class DimensionDisjointSet {
public:
  explicit DimensionDisjointSet(size_t size = 0) { resize(size); }

  /// Allocates `n` new dimension IDs in the set.
  /// \return The ID of the first allocated dimension.
  int64_t allocate(size_t n = 1);

  /// Finds the representative (root) ID for the set containing dimension `i`.
  /// Implements path compression for amortized constant time lookups.
  int64_t find(int64_t i);

  /// Merges the sets containing dimensions `i` and `j`.
  /// This also merges the `DimKind` properties of both roots using
  /// `mergeKinds`.
  void unionSets(int64_t i, int64_t j);

  /// Updates the DimKind property for the set containing dimension `i`.
  /// The new kind is merged with the existing kind to ensure safety (e.g.,
  /// Reduction is sticky).
  void setKind(int64_t i, DimKind k);

  /// Retrieves the DimKind property of the set containing dimension `i`.
  DimKind getKind(int64_t i);

private:
  /// Resizes the internal storage to accommodate `n` dimensions.
  void resize(size_t n);

  /// Defines the logic for combining two dimension kinds.
  /// Hierarchy of "stickiness": Complex > Reduction > Broadcast/Parallel.
  DimKind mergeKinds(DimKind a, DimKind b);

  std::vector<int64_t> parent; // Parent pointers for DSU.
  std::vector<DimKind> kind;   // Properties associated with each root.
};

/// DimAnalyzer:
/// Analyzes a specific execution stage (StageInfo) to determine tiling
/// strategies.
///
/// The analyzer constructs a constraint graph where nodes are tensor
/// dimensions and edges represent data flow relationships. It uses a
/// Breadth-First Search (BFS) approach to traverse operations and propagate
/// dimension IDs.
///
/// Algorithm Overview:
/// 1. **Initialization**: Seeds the analysis with stage inputs (operands
///    defined outside the stage).
/// 2. **BFS Propagation**: Traverses the def-use chains. For each operation,
///    it uses specific handlers (e.g., processMatmulOp) to bind input
///    dimensions to output dimensions.
/// 3. **Anchor Heuristic**: Identifies the "Anchor" operation (typically the
///    final LinalgOp) to interpret the resulting loops.
/// 4. **Tiling Selection**: Checks the properties of the Anchor's loops in the
///    DSU to recommend outermost parallel loops for tiling.
class DimAnalyzer {
public:
  explicit DimAnalyzer(const StageInfo &stage);

  /// Analyzes the stage operations and returns indices of loops recommended
  /// for tiling. The indices correspond to the loop nest of the "Anchor"
  /// operation.
  SmallVector<int64_t> analyzeAndGetTilingDims();

private:
  const StageInfo &stage_;
  // Quick lookup for ops belonging to this stage.
  DenseSet<Operation *> stageOps_;
  DimensionDisjointSet dsu_;
  // Maps SSA Value -> [Dim IDs]
  DenseMap<Value, std::vector<int64_t>> valueDims_;

  // BFS State passed to handlers to allow them to enqueue new values.
  using BFSQueue = std::queue<Value>;
  using VisitedSet = DenseSet<Value>;

  /// Drives the traversal of the data flow graph.
  void processBFS();

  /// Dispatches the operation to the appropriate handler.
  /// \return true if the operation was handled, false otherwise.
  bool processOperation(Operation *op, Value current, BFSQueue &q,
                        VisitedSet &v);

  /// Lazily retrieves or allocates unique IDs for the dimensions of a Value.
  std::vector<int64_t> getOrAllocateDims(Value v);

  /// Helper to strictly bind all dimensions of v1 to v2 (1-to-1 mapping).
  /// Used for Elementwise, Copy, etc.
  void bindDimensions(Value v1, Value v2);

  // --- Op Handlers ---
  // Each handler interprets the semantics of the op to union input/output
  // dimensions correctly.

  void processElementwise(Operation *op, Value current);
  void processMatmulOp(linalg::MatmulOp op);
  void processReduceOp(linalg::ReduceOp op);
  void processTransposeOp(linalg::TransposeOp op);
  void processBroadcastOp(linalg::BroadcastOp op);
  void processLinalgOpGeneric(linalg::LinalgOp op);
  void processReshapeOp(Operation *op);
  void processConcatOp(tensor::ConcatOp op);
  void processPadOp(tensor::PadOp op);
  void processExtractSliceOp(tensor::ExtractSliceOp op);
  void processInsertSliceOp(tensor::InsertSliceOp op);

  // Handlers that may need to continue BFS propagation explicitly
  void processMemrefCopyOp(memref::CopyOp op, Value current, BFSQueue &q,
                           VisitedSet &v);
  void processMemrefCastOp(Operation *op);
  void processBufferizationToTensor(bufferization::ToTensorOp op);
  void processMaterializeOp(bufferization::MaterializeInDestinationOp op,
                            Value current, BFSQueue &q, VisitedSet &v);
};

} // namespace dicp
} // namespace mlir

#endif // DICP_DIALECT_LINALGEXT_ANALYSIS_DIMANALYZER_H
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
#ifndef DICP_DIALECT_LINALGEXT_ANALYSIS_STAGEDEPENDENCYANALYZER_H
#define DICP_DIALECT_LINALGEXT_ANALYSIS_STAGEDEPENDENCYANALYZER_H

#include "mlir/Analysis/AliasAnalysis.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h" // RewriterBase

#include "llvm/ADT/SetVector.h"

#include <set>
#include <vector>

namespace mlir {
namespace dicp {

/// Represents a single pipeline stage.
/// A stage is a sequence of operations that execute together.
/// Synchronization operations (SyncBlockWaitOp) typically delimit stage
/// boundaries.
struct StageInfo {
  int id = -1;
  std::vector<Operation *> ops;
  // IDs of stages that this stage depends on
  std::set<int> preds;
  // IDs of stages that depend on this stage
  std::set<int> succs;
  bool hasSync = false;
};

// StageDependencyAnalyzer:
// 1. Partitioning a loop body into "stages" based on synchronization
//    primitives (hivm::SyncBlockWaitOp).
// 2. Building a dependency graph between these stages considering both:
//    - SSA Data Flow (Producer-Consumer relationships).
//    - Memory Dependencies (Read-After-Write via AliasAnalysis).
// 3. Computing a topological ordering (levels) to detect cycles and determine
//    a valid execution schedule.
// 4. Physically reordering the IR operations to match the valid schedule.
//
class StageDependencyAnalyzer {
public:
  StageDependencyAnalyzer(scf::ForOp forOp, AliasAnalysis &aliasAnalysis)
      : forOp(forOp), aliasAnalysis(aliasAnalysis) {}

  /// Runs the analysis, computes the topological sort, and physically
  /// reorders the operations in the loop body.
  /// Returns the ordered list of StageInfo on success, or failure if a cycle
  /// is detected.
  FailureOr<std::vector<StageInfo>> runAndReorder(RewriterBase &rewriter);

private:
  /// Internal node structure for the dependency graph.
  struct StageNode {
    int id;
    StageInfo *stageInfo;
    int level = 0; // Topological level (depth)

    // Memory dependencies
    llvm::SetVector<Value> readValues;
    llvm::SetVector<Value> writeValues;

    // SSA Value dependencies
    llvm::SetVector<Value> producedValues; // Values defined in this stage
    llvm::SetVector<Value> consumedValues; // Values used in this stage
  };

  scf::ForOp forOp;
  AliasAnalysis &aliasAnalysis;
  std::vector<StageInfo> stages;
  std::vector<StageNode> nodes;

  /// Scans the loop body to populate the `stages` vector.
  void collectStages();

  /// Collects SSA definitions/uses and Memory Read/Write effects for each
  /// stage.
  void collectEffects();

  /// Builds the directed graph edges based on SSA and Memory conflicts.
  void buildDependencyGraph();

  /// Computes the topological level of each node using DFS.
  /// Returns failure if a cycle is detected.
  LogicalResult computeStageLevels();

  /// Sorts the `stages` vector based on the computed topological levels.
  void reorderStagesLogical();

  /// Moves the operations in the IR to match the logical order of `stages`.
  void materializeScheduleToIR();
};

} // namespace dicp
} // namespace mlir

#endif // DICP_DIALECT_LINALGEXT_ANALYSIS_STAGEDEPENDENCYANALYZER_H
3 changes: 3 additions & 0 deletions compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ std::unique_ptr<OperationPass<mlir::func::FuncOp>> createScalarTo1DTensorPass();
/// Factory for the "normalize-slice-ops" pass (defined in Passes.td).
std::unique_ptr<OperationPass<mlir::func::FuncOp>>
createNormalizeSliceOpsPass();

/// Factory for the "npu-unrool-pipeline" pass (defined in Passes.td).
/// NOTE(review): "Unrool" looks like a typo for "Unroll"; the spelling is
/// load-bearing (it matches the NPUUnroolPipeline constructor in Passes.td),
/// so renaming must be coordinated there and at every registration site.
std::unique_ptr<OperationPass<mlir::func::FuncOp>>
createNPUUnroolPipelinePass();

#define GEN_PASS_REGISTRATION
#include "dicp/Dialect/LinalgExt/Transforms/Passes.h.inc"

Expand Down
12 changes: 12 additions & 0 deletions compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,16 @@ def NormalizeSliceOps : Pass<"normalize-slice-ops", "func::FuncOp"> {
let dependentDialects = ["mlir::tensor::TensorDialect"];
}

// Unroll-pipeline pass for NPU targets; added to the linked-NPU pipeline as
// add_npu_unroll_pipeline in backend/npu.py.
// NOTE(review): "Unrool" in the def name and the "npu-unrool-pipeline"
// mnemonic looks like a typo for "Unroll"; fixing it requires a coordinated
// rename of createNPUUnroolPipelinePass() in Passes.h and the pass bindings.
def NPUUnroolPipeline : Pass<"npu-unrool-pipeline", "func::FuncOp"> {
  let summary = "DLC Pipelines.";
  let constructor = "mlir::dicp::LinalgExt::createNPUUnroolPipelinePass()";
  let dependentDialects = [
    "mlir::arith::ArithDialect",
    "mlir::memref::MemRefDialect",
    "mlir::tensor::TensorDialect",
    "mlir::bufferization::BufferizationDialect",
    "mlir::func::FuncDialect"
  ];
}

#endif
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include "dicp/Conversion/TritonToUnstructure/UnstructureConversionPass.h"
#include "dicp/Utils/Utils.h"

#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h"

#include "bishengir/Dialect/Annotation/IR/Annotation.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
Expand Down Expand Up @@ -680,7 +682,7 @@ void replacePtrLoopArguments(Operation *rootOp,
op.getLoc(), rewriter.getI32Type(), ValueRange({}))
->getResult(0);
if (auto forOp = dyn_cast<scf::ForOp>(op.getOperation())) {
newOp = rewriter.create<scf::ForOp>(
auto createdFor = rewriter.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(),
constructOperands(forOp.getInitArgs(), tempVar, mapping),
Expand All @@ -701,6 +703,13 @@ void replacePtrLoopArguments(Operation *rootOp,
yieldOp.getLoc(),
constructOperands(yieldOp.getOperands(), tempVar, mapping));
});

// propagate Triton-specific loop attribute if present on the old for
if (forOp->hasAttr(triton::kNumStagesAttrName))
createdFor->setAttr(triton::kNumStagesAttrName,
forOp->getAttr(triton::kNumStagesAttrName));

newOp = createdFor;
} else if (auto whileOp = dyn_cast<scf::WhileOp>(op.getOperation())) {
newOp = rewriter.create<scf::WhileOp>(
whileOp.getLoc(), constructTypes(whileOp->getResultTypes()),
Expand Down
Loading