From addb94039669b8094dea22885334e42a7909b761 Mon Sep 17 00:00:00 2001 From: jimmychou <1049859649@qq.com> Date: Mon, 25 May 2026 16:48:05 +0800 Subject: [PATCH 1/6] Add precision template attrs for vector math ops --- include/PTO/IR/PTOAttrs.td | 104 +++++++++++ include/PTO/IR/PTOOps.td | 30 ++- lib/PTO/Transforms/PTOToEmitC.cpp | 173 +++++++++++++++++- lib/PTO/Transforms/PTOViewToMemref.cpp | 9 +- .../lit/pto/tcolexpanddiv_precision_emitc.pto | 19 ++ test/lit/pto/tdiv_precision_emitc.pto | 21 +++ test/lit/pto/texp_precision_emitc.pto | 18 ++ test/lit/pto/tfmod_precision_emitc.pto | 19 ++ test/lit/pto/tlog_precision_emitc.pto | 18 ++ test/lit/pto/trecip_precision_emitc.pto | 18 ++ test/lit/pto/trem_precision_emitc.pto | 20 ++ .../lit/pto/trowexpanddiv_precision_emitc.pto | 20 ++ test/lit/pto/trsqrt_precision_emitc.pto | 19 ++ test/lit/pto/tsqrt_precision_emitc.pto | 18 ++ 14 files changed, 483 insertions(+), 23 deletions(-) create mode 100644 test/lit/pto/tcolexpanddiv_precision_emitc.pto create mode 100644 test/lit/pto/tdiv_precision_emitc.pto create mode 100644 test/lit/pto/texp_precision_emitc.pto create mode 100644 test/lit/pto/tfmod_precision_emitc.pto create mode 100644 test/lit/pto/tlog_precision_emitc.pto create mode 100644 test/lit/pto/trecip_precision_emitc.pto create mode 100644 test/lit/pto/trem_precision_emitc.pto create mode 100644 test/lit/pto/trowexpanddiv_precision_emitc.pto create mode 100644 test/lit/pto/trsqrt_precision_emitc.pto create mode 100644 test/lit/pto/tsqrt_precision_emitc.pto diff --git a/include/PTO/IR/PTOAttrs.td b/include/PTO/IR/PTOAttrs.td index fc1e0ca2f..79b3ff227 100644 --- a/include/PTO/IR/PTOAttrs.td +++ b/include/PTO/IR/PTOAttrs.td @@ -424,6 +424,110 @@ def PTO_SaturationModeAttr : EnumAttr, + I32EnumAttrCase<"HighPrecision", 1, "high_precision"> + ]>; + +def PTO_DivPrecisionAttr : EnumAttr { + let summary = "TDIV precision mode attribute"; +} +//===----------------------------------------------------------------------===// +// TEXP template controls +//===----------------------------------------------------------------------===// + +def PTO_ExpPrecisionEnum : PTO_I32Enum< + "ExpPrecision", "PTO TEXP precision mode", [ + I32EnumAttrCase<"Default", 0, "default">, + I32EnumAttrCase<"HighPrecision", 1, "high_precision"> + ]>; + +def PTO_ExpPrecisionAttr : EnumAttr { + let summary = "TEXP precision mode attribute"; +} +//===----------------------------------------------------------------------===// +// TLOG template controls +//===----------------------------------------------------------------------===// + +def PTO_LogPrecisionEnum : PTO_I32Enum< + "LogPrecision", "PTO TLOG precision mode", [ + I32EnumAttrCase<"Default", 0, "default">, + I32EnumAttrCase<"HighPrecision", 1, "high_precision"> + ]>; + +def PTO_LogPrecisionAttr : EnumAttr { + let summary = "TLOG precision mode attribute"; +} +//===----------------------------------------------------------------------===// +// TRECIP template controls +//===----------------------------------------------------------------------===// + +def PTO_RecipPrecisionEnum : PTO_I32Enum< + "RecipPrecision", "PTO TRECIP precision mode", [ + I32EnumAttrCase<"Default", 0, "default">, + I32EnumAttrCase<"HighPrecision", 1, "high_precision"> + ]>; + +def PTO_RecipPrecisionAttr : EnumAttr { + let summary = "TRECIP precision mode attribute"; +} +//===----------------------------------------------------------------------===// +// TREM template controls +//===----------------------------------------------------------------------===// + +def PTO_RemPrecisionEnum : PTO_I32Enum< + "RemPrecision", "PTO TREM precision mode", [ + I32EnumAttrCase<"Default", 0, "default">, + I32EnumAttrCase<"HighPrecision", 1, "high_precision"> + ]>; + +def PTO_RemPrecisionAttr : EnumAttr { + let summary = "TREM precision mode attribute"; +} +//===----------------------------------------------------------------------===// +// TRSQRT template controls +//===----------------------------------------------------------------------===// + +def PTO_RsqrtPrecisionEnum : PTO_I32Enum< + "RsqrtPrecision", "PTO TRSQRT precision mode", [ + I32EnumAttrCase<"Default", 0, "default">, + I32EnumAttrCase<"HighPrecision", 1, "high_precision"> + ]>; + +def PTO_RsqrtPrecisionAttr : EnumAttr { + let summary = "TRSQRT precision mode attribute"; +} +//===----------------------------------------------------------------------===// +// TSQRT template controls +//===----------------------------------------------------------------------===// + +def PTO_SqrtPrecisionEnum : PTO_I32Enum< + "SqrtPrecision", "PTO TSQRT precision mode", [ + I32EnumAttrCase<"Default", 0, "default">, + I32EnumAttrCase<"HighPrecision", 1, "high_precision"> + ]>; + +def PTO_SqrtPrecisionAttr : EnumAttr { + let summary = "TSQRT precision mode attribute"; +} +//===----------------------------------------------------------------------===// +// TFMOD template controls +//===----------------------------------------------------------------------===// + +def PTO_FmodPrecisionEnum : PTO_I32Enum< + "FmodPrecision", "PTO TFMOD precision mode", [ + I32EnumAttrCase<"Default", 0, "default">, + I32EnumAttrCase<"HighPrecision", 1, "high_precision"> + ]>; + +def PTO_FmodPrecisionAttr : EnumAttr { + let summary = "TFMOD precision mode attribute"; +} +//===----------------------------------------------------------------------===// // TStore template controls //===----------------------------------------------------------------------===// diff --git a/include/PTO/IR/PTOOps.td b/include/PTO/IR/PTOOps.td index 344b399a6..a944db149 100644 --- a/include/PTO/IR/PTOOps.td +++ b/include/PTO/IR/PTOOps.td @@ -3425,7 +3425,8 @@ def TColExpandDivOp : PTO_TOp<"tcolexpanddiv", [ let arguments = (ins PTODpsType:$src0, PTODpsType:$src1, - PTODpsType:$dst + PTODpsType:$dst, + DefaultValuedAttr:$precisionType ); let results = (outs); @@ -3804,7 +3805,8 @@ def TDivOp : PTO_TOp<"tdiv", [ let arguments = (ins PTODpsType:$src0, PTODpsType:$src1, - PTODpsType:$dst + PTODpsType:$dst, + DefaultValuedAttr:$precisionType ); let results = (outs); @@ -3860,7 +3862,8 @@ def TFModOp : PTO_TOp<"tfmod", [ let arguments = (ins PTODpsType:$src0, PTODpsType:$src1, - PTODpsType:$dst + PTODpsType:$dst, + DefaultValuedAttr:$precisionType ); let results = (outs); @@ -3917,7 +3920,8 @@ def TExpOp : PTO_TOp<"texp", [ let arguments = (ins PTODpsType:$src, - PTODpsType:$dst + PTODpsType:$dst, + DefaultValuedAttr:$precisionType ); let results = (outs); @@ -4355,7 +4359,8 @@ def TLogOp : PTO_TOp<"tlog", [ let arguments = (ins PTODpsType:$src, - PTODpsType:$dst + PTODpsType:$dst, + DefaultValuedAttr:$precisionType ); let results = (outs); @@ -5118,7 +5123,8 @@ def TRecipOp: PTO_TOp<"trecip", [ let arguments = (ins PTODpsType:$src, - PTODpsType:$dst + PTODpsType:$dst, + DefaultValuedAttr:$precisionType ); let results = (outs); @@ -5182,7 +5188,8 @@ def TRemOp: PTO_TOp<"trem", [ PTODpsType:$src0, PTODpsType:$src1, PTODpsType:$tmp, - PTODpsType:$dst + PTODpsType:$dst, + DefaultValuedAttr:$precisionType ); let results = (outs); @@ -5311,7 +5318,8 @@ def TRowExpandDivOp: PTO_TOp<"trowexpanddiv", [ PTODpsType:$src0, PTODpsType:$src1, Optional:$tmp, - PTODpsType:$dst + PTODpsType:$dst, + DefaultValuedAttr:$precisionType ); let results = (outs); @@ -5702,7 +5710,8 @@ def TRsqrtOp: PTO_TOp<"trsqrt", [ let arguments = (ins PTODpsType:$src, PTODpsType:$dst, - Optional:$tmp + Optional:$tmp, + DefaultValuedAttr:$precisionType ); let results = (outs); @@ -6012,7 +6021,8 @@ def TSqrtOp: PTO_TOp<"tsqrt", [ let arguments = (ins PTODpsType:$src, - PTODpsType:$dst + PTODpsType:$dst, + DefaultValuedAttr:$precisionType ); let results = (outs); diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index b3f0c6bbd..235f64d66 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -7683,15 +7683,31 @@ struct PTOColExpandDivToEmitC : public OpConversionPattern LogicalResult matchAndRewrite(pto::TColExpandDivOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); Value src0 = peelUnrealized(adaptor.getSrc0()); Value src1 = peelUnrealized(adaptor.getSrc1()); Value dst = peelUnrealized(adaptor.getDst()); + ArrayAttr templateArgs; + if (op.getPrecisionType() != pto::DivPrecision::Default) { + StringRef precisionTok; + switch (op.getPrecisionType()) { + case pto::DivPrecision::Default: + precisionTok = "pto::DivAlgorithm::DEFAULT"; + break; + case pto::DivPrecision::HighPrecision: + precisionTok = "pto::DivAlgorithm::HIGH_PRECISION"; + break; + } + templateArgs = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, precisionTok)}); + } + rewriter.create( loc, TypeRange{}, "TCOLEXPANDDIV", /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, + /*templateArgs=*/templateArgs, /*operands=*/ValueRange{dst, src0, src1}); rewriter.eraseOp(op); @@ -8136,10 +8152,26 @@ struct PTODivToTDIV : public OpConversionPattern { Value src0 = peelUnrealized(adaptor.getSrc0()); Value src1 = peelUnrealized(adaptor.getSrc1()); Value dst = peelUnrealized(adaptor.getDst()); + auto *ctx = rewriter.getContext(); + + ArrayAttr templateArgs; + if (op.getPrecisionType() != pto::DivPrecision::Default) { + StringRef precisionTok; + switch (op.getPrecisionType()) { + case pto::DivPrecision::Default: + precisionTok = "pto::DivAlgorithm::DEFAULT"; + break; + case pto::DivPrecision::HighPrecision: + precisionTok = "pto::DivAlgorithm::HIGH_PRECISION"; + break; + } + templateArgs = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, precisionTok)}); + } rewriter.create( op.getLoc(), TypeRange{}, "TDIV", - ArrayAttr{}, ArrayAttr{}, + ArrayAttr{}, templateArgs, ValueRange{dst, src0, src1}); rewriter.eraseOp(op); @@ -8210,13 +8242,29 @@ struct PTOExpToEmitC : public OpConversionPattern { LogicalResult matchAndRewrite(pto::TExpOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); Value src = peelUnrealized(adaptor.getSrc()); Value dst = peelUnrealized(adaptor.getDst()); + ArrayAttr templateArgs; + if (op.getPrecisionType() != pto::ExpPrecision::Default) { + StringRef precisionTok; + switch (op.getPrecisionType()) { + case pto::ExpPrecision::Default: + precisionTok = "pto::ExpAlgorithm::DEFAULT"; + break; + case pto::ExpPrecision::HighPrecision: + precisionTok = "pto::ExpAlgorithm::HIGH_PRECISION"; + break; + } + templateArgs = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, precisionTok)}); + } + rewriter.create( loc, TypeRange{}, "TEXP", - ArrayAttr{}, ArrayAttr{}, + ArrayAttr{}, templateArgs, ValueRange{dst, src}); rewriter.eraseOp(op); @@ -8567,14 +8615,29 @@ struct PTOLogToEmitC : public OpConversionPattern { LogicalResult matchAndRewrite(pto::TLogOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); Value src = peelUnrealized(adaptor.getSrc()); Value dst = peelUnrealized(adaptor.getDst()); SmallVector operands{dst, src}; + ArrayAttr templateArgs; + if (op.getPrecisionType() != pto::LogPrecision::Default) { + StringRef precisionTok; + switch (op.getPrecisionType()) { + case pto::LogPrecision::Default: + precisionTok = "pto::LogAlgorithm::DEFAULT"; + break; + case pto::LogPrecision::HighPrecision: + precisionTok = "pto::LogAlgorithm::HIGH_PRECISION"; + break; + } + templateArgs = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, precisionTok)}); + } rewriter.create( loc, TypeRange{}, "TLOG", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, /*operands=*/operands); rewriter.eraseOp(op); @@ -9374,14 +9437,29 @@ struct PTORecipToEmitC : public OpConversionPattern { LogicalResult matchAndRewrite(pto::TRecipOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); Value src = peelUnrealized(adaptor.getSrc()); Value dst = peelUnrealized(adaptor.getDst()); SmallVector operands{dst, src}; + ArrayAttr templateArgs; + if (op.getPrecisionType() != pto::RecipPrecision::Default) { + StringRef precisionTok; + switch (op.getPrecisionType()) { + case pto::RecipPrecision::Default: + precisionTok = "pto::RecipAlgorithm::DEFAULT"; + break; + case pto::RecipPrecision::HighPrecision: + precisionTok = "pto::RecipAlgorithm::HIGH_PRECISION"; + break; + } + templateArgs = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, precisionTok)}); + } rewriter.create( loc, TypeRange{}, "TRECIP", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, /*operands=*/operands); rewriter.eraseOp(op); @@ -9422,15 +9500,30 @@ struct PTORemToEmitC : public OpConversionPattern { LogicalResult matchAndRewrite(pto::TRemOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); Value src0 = peelUnrealized(adaptor.getSrc0()); Value src1 = peelUnrealized(adaptor.getSrc1()); Value tmp = peelUnrealized(adaptor.getTmp()); Value dst = peelUnrealized(adaptor.getDst()); SmallVector operands{dst, src0, src1, tmp}; + ArrayAttr templateArgs; + if (op.getPrecisionType() != pto::RemPrecision::Default) { + StringRef precisionTok; + switch (op.getPrecisionType()) { + case pto::RemPrecision::Default: + precisionTok = "pto::RemAlgorithm::DEFAULT"; + break; + case pto::RemPrecision::HighPrecision: + precisionTok = "pto::RemAlgorithm::HIGH_PRECISION"; + break; + } + templateArgs = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, precisionTok)}); + } rewriter.create( loc, TypeRange{}, "TREM", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, /*operands=*/operands); rewriter.eraseOp(op); @@ -9444,15 +9537,30 @@ struct PTOFModToEmitC : public OpConversionPattern { LogicalResult matchAndRewrite(pto::TFModOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); Value src0 = peelUnrealized(adaptor.getSrc0()); Value src1 = peelUnrealized(adaptor.getSrc1()); Value dst = peelUnrealized(adaptor.getDst()); SmallVector operands{dst, src0, src1}; + ArrayAttr templateArgs; + if (op.getPrecisionType() != pto::FmodPrecision::Default) { + StringRef precisionTok; + switch (op.getPrecisionType()) { + case pto::FmodPrecision::Default: + precisionTok = "pto::FmodAlgorithm::DEFAULT"; + break; + case pto::FmodPrecision::HighPrecision: + precisionTok = "pto::FmodAlgorithm::HIGH_PRECISION"; + break; + } + templateArgs = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, precisionTok)}); + } rewriter.create( loc, TypeRange{}, "TFMOD", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, /*operands=*/operands); rewriter.eraseOp(op); @@ -9764,6 +9872,7 @@ struct PTORowExpandDivToEmitC : public OpConversionPattern LogicalResult matchAndRewrite(pto::TRowExpandDivOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); Value src0 = peelUnrealized(adaptor.getSrc0()); Value src1 = peelUnrealized(adaptor.getSrc1()); @@ -9775,9 +9884,23 @@ struct PTORowExpandDivToEmitC : public OpConversionPattern operands.assign({dst, src0, src1, tmp}); else operands.assign({dst, src0, src1}); + ArrayAttr templateArgs; + if (op.getPrecisionType() != pto::DivPrecision::Default) { + StringRef precisionTok; + switch (op.getPrecisionType()) { + case pto::DivPrecision::Default: + precisionTok = "pto::DivAlgorithm::DEFAULT"; + break; + case pto::DivPrecision::HighPrecision: + precisionTok = "pto::DivAlgorithm::HIGH_PRECISION"; + break; + } + templateArgs = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, precisionTok)}); + } rewriter.create( loc, TypeRange{}, "TROWEXPANDDIV", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, /*operands=*/operands); rewriter.eraseOp(op); @@ -10054,15 +10177,30 @@ struct PTORsqrtToEmitC : public OpConversionPattern { LogicalResult matchAndRewrite(pto::TRsqrtOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); Value src = peelUnrealized(adaptor.getSrc()); Value dst = peelUnrealized(adaptor.getDst()); SmallVector operands{dst, src}; if (Value tmp = adaptor.getTmp()) operands.push_back(peelUnrealized(tmp)); + ArrayAttr templateArgs; + if (op.getPrecisionType() != pto::RsqrtPrecision::Default) { + StringRef precisionTok; + switch (op.getPrecisionType()) { + case pto::RsqrtPrecision::Default: + precisionTok = "pto::RsqrtAlgorithm::DEFAULT"; + break; + case pto::RsqrtPrecision::HighPrecision: + precisionTok = "pto::RsqrtAlgorithm::HIGH_PRECISION"; + break; + } + templateArgs = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, precisionTok)}); + } rewriter.create( loc, TypeRange{}, "TRSQRT", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, /*operands=*/operands); rewriter.eraseOp(op); @@ -10298,14 +10436,29 @@ struct PTOSqrtSToEmitC : public OpConversionPattern { LogicalResult matchAndRewrite(pto::TSqrtOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); Value src = peelUnrealized(adaptor.getSrc()); Value dst = peelUnrealized(adaptor.getDst()); SmallVector operands{dst, src}; + ArrayAttr templateArgs; + if (op.getPrecisionType() != pto::SqrtPrecision::Default) { + StringRef precisionTok; + switch (op.getPrecisionType()) { + case pto::SqrtPrecision::Default: + precisionTok = "pto::SqrtAlgorithm::DEFAULT"; + break; + case pto::SqrtPrecision::HighPrecision: + precisionTok = "pto::SqrtAlgorithm::HIGH_PRECISION"; + break; + } + templateArgs = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, precisionTok)}); + } rewriter.create( loc, TypeRange{}, "TSQRT", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, /*operands=*/operands); rewriter.eraseOp(op); diff --git a/lib/PTO/Transforms/PTOViewToMemref.cpp b/lib/PTO/Transforms/PTOViewToMemref.cpp index c21669b81..63849ef2a 100644 --- a/lib/PTO/Transforms/PTOViewToMemref.cpp +++ b/lib/PTO/Transforms/PTOViewToMemref.cpp @@ -1864,7 +1864,8 @@ struct PTOViewToMemrefPass IRRewriter rewriter(ctx); rewriter.setInsertionPoint(op); rewriter.replaceOpWithNewOp( - op, TypeRange{}, op->getOperand(0), op->getOperand(1)); + op, TypeRange{}, op->getOperand(0), op->getOperand(1), + op.getPrecisionTypeAttr()); } // --- TMulOp [Src, Scalar, Dst] --- @@ -2694,7 +2695,8 @@ struct PTOViewToMemrefPass TypeRange{}, src0, src1, - dst); + dst, + op.getPrecisionTypeAttr()); } DefaultInlineVector divsops; @@ -3059,7 +3061,8 @@ struct PTOViewToMemrefPass op, TypeRange{}, src, - dst); + dst, + op.getPrecisionTypeAttr()); } DefaultInlineVector lreluops; diff --git a/test/lit/pto/tcolexpanddiv_precision_emitc.pto b/test/lit/pto/tcolexpanddiv_precision_emitc.pto new file mode 100644 index 000000000..ea46d994f --- /dev/null +++ b/test/lit/pto/tcolexpanddiv_precision_emitc.pto @@ -0,0 +1,19 @@ +// RUN: ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s --check-prefix=A5 + +module { + func.func @tcolexpanddiv_precision_emitc() { + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tcolexpanddiv ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tcolexpanddiv ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) + outs(%dst : !pto.tile_buf) + {precisionType = #pto} + return + } +} + +// A5: TCOLEXPANDDIV([[VDST:v[0-9]+]], [[VSRC0:v[0-9]+]], [[VSRC1:v[0-9]+]]); +// A5: TCOLEXPANDDIV([[VDST]], [[VSRC0]], [[VSRC1]]); diff --git a/test/lit/pto/tdiv_precision_emitc.pto b/test/lit/pto/tdiv_precision_emitc.pto new file mode 100644 index 000000000..16aad222d --- /dev/null +++ b/test/lit/pto/tdiv_precision_emitc.pto @@ -0,0 +1,21 @@ +// RUN: ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s --check-prefix=A5 +// RUN: ptoas --pto-arch=a5 --enable-insert-sync %s 2>&1 | FileCheck %s --check-prefix=A5 + +module { + func.func @tdiv_precision_emitc() { + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tdiv ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tdiv ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) + outs(%dst : !pto.tile_buf) + {precisionType = #pto} + return + } +} + +// A5: TDIV([[VDST:v[0-9]+]], [[VSRC0:v[0-9]+]], [[VSRC1:v[0-9]+]]); +// A5: TDIV([[VDST]], [[VSRC0]], [[VSRC1]]); diff --git a/test/lit/pto/texp_precision_emitc.pto b/test/lit/pto/texp_precision_emitc.pto new file mode 100644 index 000000000..68f300115 --- /dev/null +++ b/test/lit/pto/texp_precision_emitc.pto @@ -0,0 +1,18 @@ +// RUN: ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s --check-prefix=A5 + +module { + func.func @texp_precision_emitc() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.texp ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.texp ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + {precisionType = #pto} + return + } +} + +// A5: TEXP([[VDST:v[0-9]+]], [[VSRC:v[0-9]+]]); +// A5: TEXP([[VDST]], [[VSRC]]); diff --git a/test/lit/pto/tfmod_precision_emitc.pto b/test/lit/pto/tfmod_precision_emitc.pto new file mode 100644 index 000000000..bbd94ebc2 --- /dev/null +++ b/test/lit/pto/tfmod_precision_emitc.pto @@ -0,0 +1,19 @@ +// RUN: ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s --check-prefix=A5 + +module { + func.func @tfmod_precision_emitc() { + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tfmod ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tfmod ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) + outs(%dst : !pto.tile_buf) + {precisionType = #pto} + return + } +} + +// A5: TFMOD([[VDST:v[0-9]+]], [[VSRC0:v[0-9]+]], [[VSRC1:v[0-9]+]]); +// A5: TFMOD([[VDST]], [[VSRC0]], [[VSRC1]]); diff --git a/test/lit/pto/tlog_precision_emitc.pto b/test/lit/pto/tlog_precision_emitc.pto new file mode 100644 index 000000000..f50d1aefb --- /dev/null +++ b/test/lit/pto/tlog_precision_emitc.pto @@ -0,0 +1,18 @@ +// RUN: ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s --check-prefix=A5 + +module { + func.func @tlog_precision_emitc() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tlog ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tlog ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + {precisionType = #pto} + return + } +} + +// A5: TLOG([[VDST:v[0-9]+]], [[VSRC:v[0-9]+]]); +// A5: TLOG([[VDST]], [[VSRC]]); diff --git a/test/lit/pto/trecip_precision_emitc.pto b/test/lit/pto/trecip_precision_emitc.pto new file mode 100644 index 000000000..830d03d0e --- /dev/null +++ b/test/lit/pto/trecip_precision_emitc.pto @@ -0,0 +1,18 @@ +// RUN: ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s --check-prefix=A5 + +module { + func.func @trecip_precision_emitc() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.trecip ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.trecip ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + {precisionType = #pto} + return + } +} + +// A5: TRECIP([[VDST:v[0-9]+]], [[VSRC:v[0-9]+]]); +// A5: TRECIP([[VDST]], [[VSRC]]); diff --git a/test/lit/pto/trem_precision_emitc.pto b/test/lit/pto/trem_precision_emitc.pto new file mode 100644 index 000000000..462867bd4 --- /dev/null +++ b/test/lit/pto/trem_precision_emitc.pto @@ -0,0 +1,20 @@ +// RUN: ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s --check-prefix=A5 + +module { + func.func @trem_precision_emitc() { + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.trem ins(%src0, %src1, %tmp : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.trem ins(%src0, %src1, %tmp : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) + outs(%dst : !pto.tile_buf) + {precisionType = #pto} + return + } +} + +// A5: TREM([[VDST:v[0-9]+]], [[VSRC0:v[0-9]+]], [[VSRC1:v[0-9]+]], [[VTMP:v[0-9]+]]); +// A5: TREM([[VDST]], [[VSRC0]], [[VSRC1]], [[VTMP]]); diff --git a/test/lit/pto/trowexpanddiv_precision_emitc.pto b/test/lit/pto/trowexpanddiv_precision_emitc.pto new file mode 100644 index 000000000..6056f708f --- /dev/null +++ b/test/lit/pto/trowexpanddiv_precision_emitc.pto @@ -0,0 +1,20 @@ +// RUN: ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s --check-prefix=A5 + +module { + func.func @trowexpanddiv_precision_emitc() { + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.trowexpanddiv ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.trowexpanddiv ins(%src0, %src1, %tmp : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) + outs(%dst : !pto.tile_buf) + {precisionType = #pto} + return + } +} + +// A5: TROWEXPANDDIV([[VDST0:v[0-9]+]], [[VSRC00:v[0-9]+]], [[VSRC01:v[0-9]+]]); +// A5: TROWEXPANDDIV([[VDST1:v[0-9]+]], [[VSRC10:v[0-9]+]], [[VSRC11:v[0-9]+]], [[VTMP:v[0-9]+]]); diff --git a/test/lit/pto/trsqrt_precision_emitc.pto b/test/lit/pto/trsqrt_precision_emitc.pto new file mode 100644 index 000000000..0ff89ca39 --- /dev/null +++ b/test/lit/pto/trsqrt_precision_emitc.pto @@ -0,0 +1,19 @@ +// RUN: ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s --check-prefix=A5 + +module { + func.func @trsqrt_precision_emitc() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + + pto.trsqrt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.trsqrt ins(%src, %tmp : !pto.tile_buf, !pto.tile_buf) + outs(%dst : !pto.tile_buf) + {precisionType = #pto} + return + } +} + +// A5: TRSQRT([[VDST0:v[0-9]+]], [[VSRC0:v[0-9]+]]); +// A5: TRSQRT([[VDST1:v[0-9]+]], [[VSRC1:v[0-9]+]], [[VTMP:v[0-9]+]]); diff --git a/test/lit/pto/tsqrt_precision_emitc.pto b/test/lit/pto/tsqrt_precision_emitc.pto new file mode 100644 index 000000000..92a30a2ad --- /dev/null +++ b/test/lit/pto/tsqrt_precision_emitc.pto @@ -0,0 +1,18 @@ +// RUN: ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s --check-prefix=A5 + +module { + func.func @tsqrt_precision_emitc() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tsqrt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tsqrt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + {precisionType = #pto} + return + } +} + +// A5: TSQRT([[VDST:v[0-9]+]], [[VSRC:v[0-9]+]]); +// A5: TSQRT([[VDST]], [[VSRC]]); From e2f75cf853606bde467effa8d0641df74258c8a0 Mon Sep 17 00:00:00 2001 From: jimmychou <1049859649@qq.com> Date: Tue, 26 May 2026 09:32:18 +0800 Subject: [PATCH 2/6] Add AccPhase support for matmul and gemv ops --- include/PTO/IR/PTOAttrs.td | 14 +++++++++ include/PTO/IR/PTOOps.td | 18 ++++++++---- test/lit/pto/tgemv_accphase_emitc.pto | 34 ++++++++++++++++++++++ test/lit/pto/tgemv_mx_accphase_emitc.pto | 19 ++++++++++++ test/lit/pto/tmatmul_accphase_emitc.pto | 34 ++++++++++++++++++++++ test/lit/pto/tmatmul_mx_accphase_emitc.pto | 21 +++++++++++++ 6 files changed, 134 insertions(+), 6 deletions(-) create mode 100644 test/lit/pto/tgemv_accphase_emitc.pto create mode 100644 test/lit/pto/tgemv_mx_accphase_emitc.pto create mode 100644 test/lit/pto/tmatmul_accphase_emitc.pto create mode 100644 test/lit/pto/tmatmul_mx_accphase_emitc.pto diff --git a/include/PTO/IR/PTOAttrs.td b/include/PTO/IR/PTOAttrs.td index 79b3ff227..c6a100455 100644 --- a/include/PTO/IR/PTOAttrs.td +++ b/include/PTO/IR/PTOAttrs.td @@ -528,6 +528,20 @@ def PTO_FmodPrecisionAttr : EnumAttr, + I32EnumAttrCase<"Partial", 2, "partial">, + I32EnumAttrCase<"Final", 3, "final"> + ]>; + +def PTO_AccPhaseAttr : EnumAttr { + let summary = "matmul/gemv accumulate phase attribute"; +} +//===----------------------------------------------------------------------===// // TStore template controls //===----------------------------------------------------------------------===// diff --git a/include/PTO/IR/PTOOps.td b/include/PTO/IR/PTOOps.td index a944db149..863ed8c51 100644 --- a/include/PTO/IR/PTOOps.td +++ b/include/PTO/IR/PTOOps.td @@ -817,7 +817,8 @@ def TMatmulBiasOp : PTO_TOp<"tmatmul.bias", [ PTODpsType:$a, PTODpsType:$b, PTODpsType:$bias, - PTODpsType:$dst + PTODpsType:$dst, + DefaultValuedAttr:$accPhase ); let results = (outs Optional:$result); @@ -856,7 +857,8 @@ def TMatmulMxOp : PTO_TOp<"tmatmul.mx", [ PTODpsType:$a_scale, PTODpsType:$b, PTODpsType:$b_scale, - PTODpsType:$dst); + PTODpsType:$dst, + DefaultValuedAttr:$accPhase); let results = (outs Optional:$result); let hasVerifier = 1; @@ -1031,7 +1033,8 @@ def TGemvOp : PTO_TOp<"tgemv", [ let arguments = (ins PTODpsType:$lhs, PTODpsType:$rhs, - PTODpsType:$dst + PTODpsType:$dst, + DefaultValuedAttr:$accPhase ); let results = (outs @@ -1067,7 +1070,8 @@ def TGemvAccOp : PTO_TOp<"tgemv.acc", [ PTODpsType:$acc_in, PTODpsType:$lhs, PTODpsType:$rhs, - PTODpsType:$dst + PTODpsType:$dst, + DefaultValuedAttr:$accPhase ); let results = (outs @@ -1102,7 +1106,8 @@ def TGemvBiasOp : PTO_TOp<"tgemv.bias", [ PTODpsType :$a, PTODpsType :$b, PTODpsType :$bias, - PTODpsType :$dst + PTODpsType :$dst, + DefaultValuedAttr:$accPhase ); let results = (outs Optional:$result); @@ -1137,7 +1142,8 @@ def TGemvMxOp : PTO_TOp<"tgemv.mx", [ PTODpsType:$a_scale, PTODpsType:$b, PTODpsType:$b_scale, - PTODpsType:$dst + PTODpsType:$dst, + DefaultValuedAttr:$accPhase ); let results = (outs Optional:$result); diff --git a/test/lit/pto/tgemv_accphase_emitc.pto b/test/lit/pto/tgemv_accphase_emitc.pto new file mode 100644 index 000000000..a7aa3ce88 --- /dev/null +++ b/test/lit/pto/tgemv_accphase_emitc.pto @@ -0,0 +1,34 @@ +// RUN: ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s --check-prefix=A5 + +module attributes {pto.target_arch = "a5"} { + func.func @tgemv_accphase_emitc() attributes {pto.kernel_kind = #pto.kernel_kind} { + %lhs = pto.alloc_tile : !pto.tile_buf + %rhs = pto.alloc_tile : !pto.tile_buf + %acc_in = pto.alloc_tile : !pto.tile_buf + %bias = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tgemv ins(%lhs, %rhs : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) + pto.tgemv ins(%lhs, %rhs : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {accPhase = #pto} + pto.tgemv ins(%lhs, %rhs : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {accPhase = #pto} + + pto.tgemv.acc ins(%acc_in, %lhs, %rhs : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) + pto.tgemv.acc ins(%acc_in, %lhs, %rhs : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {accPhase = #pto} + pto.tgemv.acc ins(%acc_in, %lhs, %rhs : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {accPhase = #pto} + + pto.tgemv.bias ins(%lhs, %rhs, %bias : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) + pto.tgemv.bias ins(%lhs, %rhs, %bias : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {accPhase = #pto} + pto.tgemv.bias ins(%lhs, %rhs, %bias : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {accPhase = #pto} + return + } +} + +// A5: TGEMV([[D0:v[0-9]+]], [[L0:v[0-9]+]], [[R0:v[0-9]+]]); +// A5: TGEMV([[D0]], [[L0]], [[R0]]); +// A5: TGEMV([[D0]], [[L0]], [[R0]]); +// A5: TGEMV_ACC([[D1:v[0-9]+]], [[CIN:v[0-9]+]], [[L1:v[0-9]+]], [[R1:v[0-9]+]]); +// A5: TGEMV_ACC([[D1]], [[CIN]], [[L1]], [[R1]]); +// A5: TGEMV_ACC([[D1]], [[CIN]], [[L1]], [[R1]]); +// A5: TGEMV_BIAS([[D2:v[0-9]+]], [[L2:v[0-9]+]], [[R2:v[0-9]+]], [[B2:v[0-9]+]]); +// A5: TGEMV_BIAS([[D2]], [[L2]], [[R2]], [[B2]]); +// A5: TGEMV_BIAS([[D2]], [[L2]], [[R2]], [[B2]]); diff --git a/test/lit/pto/tgemv_mx_accphase_emitc.pto b/test/lit/pto/tgemv_mx_accphase_emitc.pto new file mode 100644 index 000000000..f8c5d5876 --- /dev/null +++ b/test/lit/pto/tgemv_mx_accphase_emitc.pto @@ -0,0 +1,19 @@ +// RUN: ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s --check-prefix=A5 + +module attributes {pto.target_arch = "a5"} { + func.func @tgemv_mx_accphase_emitc(%a : !pto.tile_buf, + %b : !pto.tile_buf) attributes {pto.kernel_kind = #pto.kernel_kind} { + %a_scale = pto.alloc_tile : !pto.tile_buf + %b_scale = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tgemv.mx ins(%a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) + pto.tgemv.mx ins(%a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {accPhase = #pto} + pto.tgemv.mx ins(%a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {accPhase = #pto} + return + } +} + +// A5: TGEMV_MX([[D0:v[0-9]+]], [[A0:v[0-9]+]], [[AS0:v[0-9]+]], [[B0:v[0-9]+]], [[BS0:v[0-9]+]]); +// A5: TGEMV_MX([[D1:v[0-9]+]], [[A1:v[0-9]+]], [[AS1:v[0-9]+]], [[B1:v[0-9]+]], [[BS1:v[0-9]+]]); +// A5: TGEMV_MX([[D2:v[0-9]+]], [[A2:v[0-9]+]], [[AS2:v[0-9]+]], [[B2:v[0-9]+]], [[BS2:v[0-9]+]]); diff --git a/test/lit/pto/tmatmul_accphase_emitc.pto b/test/lit/pto/tmatmul_accphase_emitc.pto new file mode 100644 index 000000000..ee02bb197 --- /dev/null +++ b/test/lit/pto/tmatmul_accphase_emitc.pto @@ -0,0 +1,34 @@ +// RUN: ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s --check-prefix=A5 + +module attributes {pto.target_arch = "a5"} { + func.func @tmatmul_accphase_emitc() attributes {pto.kernel_kind = #pto.kernel_kind} { + %lhs = pto.alloc_tile : !pto.tile_buf + %rhs = pto.alloc_tile : !pto.tile_buf + %acc_in = pto.alloc_tile : !pto.tile_buf + %bias = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tmatmul ins(%lhs, %rhs : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) + pto.tmatmul ins(%lhs, %rhs : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {accPhase = #pto} + pto.tmatmul ins(%lhs, %rhs : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {accPhase = #pto} + + pto.tmatmul.acc ins(%acc_in, %lhs, %rhs : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) + pto.tmatmul.acc ins(%acc_in, %lhs, %rhs : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {accPhase = #pto} + pto.tmatmul.acc ins(%acc_in, %lhs, %rhs : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {accPhase = #pto} + + pto.tmatmul.bias ins(%lhs, %rhs, %bias : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) + pto.tmatmul.bias ins(%lhs, %rhs, %bias : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {accPhase = #pto} + pto.tmatmul.bias ins(%lhs, %rhs, %bias : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {accPhase = #pto} + return + } +} + +// A5: TMATMUL([[D0:v[0-9]+]], [[L0:v[0-9]+]], [[R0:v[0-9]+]]); +// A5: TMATMUL([[D0]], [[L0]], [[R0]]); +// A5: TMATMUL([[D0]], [[L0]], [[R0]]); +// A5: TMATMUL_ACC([[D1:v[0-9]+]], [[CIN:v[0-9]+]], [[L1:v[0-9]+]], [[R1:v[0-9]+]]); +// A5: TMATMUL_ACC([[D1]], [[CIN]], [[L1]], [[R1]]); +// A5: TMATMUL_ACC([[D1]], [[CIN]], [[L1]], [[R1]]); +// A5: TMATMUL_BIAS([[D2:v[0-9]+]], [[L2:v[0-9]+]], [[R2:v[0-9]+]], [[B2:v[0-9]+]]); +// A5: TMATMUL_BIAS([[D2]], [[L2]], [[R2]], [[B2]]); +// A5: TMATMUL_BIAS([[D2]], [[L2]], [[R2]], [[B2]]); diff --git a/test/lit/pto/tmatmul_mx_accphase_emitc.pto b/test/lit/pto/tmatmul_mx_accphase_emitc.pto new file mode 100644 index 000000000..bde097042 --- /dev/null +++ b/test/lit/pto/tmatmul_mx_accphase_emitc.pto @@ -0,0 +1,21 @@ +// RUN: ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s --check-prefix=A5 + +module attributes {pto.target_arch = "a5"} { + func.func @tmatmul_mx_accphase_emitc(%a : !pto.tile_buf, + %b : !pto.tile_buf) attributes {pto.kernel_kind = #pto.kernel_kind} { + %a_scale = pto.alloc_tile : !pto.tile_buf + %b_scale = pto.alloc_tile : !pto.tile_buf + %dst0 = pto.alloc_tile : !pto.tile_buf + %dst1 = pto.alloc_tile : !pto.tile_buf + %dst2 = pto.alloc_tile : !pto.tile_buf + + pto.tmatmul.mx ins(%a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst0 : !pto.tile_buf) + pto.tmatmul.mx ins(%a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst1 : !pto.tile_buf) {accPhase = #pto} + pto.tmatmul.mx ins(%a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst2 : !pto.tile_buf) {accPhase = #pto} + return + } +} + +// A5: TMATMUL_MX([[D0:v[0-9]+]], [[A0:v[0-9]+]], [[AS0:v[0-9]+]], [[B0:v[0-9]+]], [[BS0:v[0-9]+]]); +// A5: TMATMUL_MX([[D1:v[0-9]+]], [[A1:v[0-9]+]], [[AS1:v[0-9]+]], [[B1:v[0-9]+]], [[BS1:v[0-9]+]]); +// A5: TMATMUL_MX([[D2:v[0-9]+]], [[A2:v[0-9]+]], [[AS2:v[0-9]+]], [[B2:v[0-9]+]], [[BS2:v[0-9]+]]); From fca3aa98c4d403d9e12a5b873a4d43ca159fe371 Mon Sep 17 00:00:00 2001 From: jimmychou <1049859649@qq.com> Date: Tue, 26 May 2026 12:20:46 +0800 Subject: [PATCH 3/6] fix(pto): keep void EmitC calls void in matmul/gemv lowering --- lib/PTO/Transforms/PTOToEmitC.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index 235f64d66..bc6b39ad4 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -9698,10 +9698,10 @@ static void replaceOrEraseWithOpaqueCall(Operation *op, StringRef callee, ArrayRef args, ConversionPatternRewriter &rewriter) { - TypeRange resultTypes = op->getResultTypes(); auto call = rewriter.create( - op->getLoc(), resultTypes, callee, ArrayAttr{}, ArrayAttr{}, ValueRange(args)); - if (resultTypes.empty()) + op->getLoc(), TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, + ValueRange(args)); + if (op->getNumResults() == 0) rewriter.eraseOp(op); else rewriter.replaceOp(op, call.getResults()); From 880cfcda51eb36b9dce84bd200bd860cf91940a8 Mon Sep 17 00:00:00 2001 From: jimmychou <1049859649@qq.com> Date: Tue, 26 May 2026 14:53:49 +0800 Subject: [PATCH 4/6] fix(pto): reuse upstream AccPhase attr definition --- include/PTO/IR/PTOAttrs.td | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/include/PTO/IR/PTOAttrs.td b/include/PTO/IR/PTOAttrs.td index c6a100455..79b3ff227 100644 --- a/include/PTO/IR/PTOAttrs.td +++ b/include/PTO/IR/PTOAttrs.td @@ -528,20 +528,6 @@ def PTO_FmodPrecisionAttr : EnumAttr, - I32EnumAttrCase<"Partial", 2, "partial">, - I32EnumAttrCase<"Final", 3, "final"> - ]>; - -def PTO_AccPhaseAttr : EnumAttr { - let summary = "matmul/gemv accumulate phase attribute"; -} -//===----------------------------------------------------------------------===// // TStore template controls //===----------------------------------------------------------------------===// From 849a30193820538055f299acd7bd7dfd9282fb50 Mon Sep 17 00:00:00 2001 From: jimmychou <1049859649@qq.com> Date: Tue, 26 May 2026 19:37:01 +0800 Subject: [PATCH 5/6] fix(pto): complete accphase emitc lowering --- lib/PTO/IR/PTO.cpp | 87 ++++++++++++++++++++---- lib/PTO/Transforms/PTOToEmitC.cpp | 76 +++++++++++---------- lib/PTO/Transforms/PTOViewToMemref.cpp | 12 ++-- test/lit/pto/tgemv_mx_accphase_emitc.pto | 15 ++-- test/lit/pto/tmatmul_acc_phase_emitc.pto | 8 +-- 5 files changed, 134 insertions(+), 64 deletions(-) diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index 5427ac36e..d3d557a84 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -4121,6 +4121,63 @@ static LogicalResult verifyGemvTileOperands(Operation *op, Type lhsTy, Type rhsT return failure(); } +static LogicalResult verifyA5MxGemvTileOperands(Operation *op, Type lhsTy, + Type rhsTy, Type dstTy) { + if (failed(verifyTileBufCommon(op, lhsTy, "lhs", /*allowLowPrecision=*/true)) || + failed(verifyTileBufCommon(op, rhsTy, "rhs", /*allowLowPrecision=*/true)) || + failed(verifyAccTileCommon(op, dstTy, "dst"))) + return failure(); + + auto lhsSpace = getPTOMemorySpaceEnum(lhsTy); + auto rhsSpace = getPTOMemorySpaceEnum(rhsTy); + auto dstSpace = getPTOMemorySpaceEnum(dstTy); + if (!lhsSpace || !rhsSpace || !dstSpace) + return op->emitOpError("expects lhs, rhs, and dst to have explicit address spaces"); + if (*lhsSpace != pto::AddressSpace::LEFT || *rhsSpace != pto::AddressSpace::RIGHT || + *dstSpace != pto::AddressSpace::ACC) + return op->emitOpError( + "expects lhs, rhs, and dst to use the left, right, and acc address spaces"); + + auto lhsValid = getValidShapeVec(lhsTy); + auto rhsValid = getValidShapeVec(rhsTy); + auto dstValid = getValidShapeVec(dstTy); + if (lhsValid[0] != ShapedType::kDynamic && lhsValid[0] != 1) + return op->emitOpError("expects lhs valid_shape[0] to be 1 for tgemv"); + if (dstValid[0] != ShapedType::kDynamic && dstValid[0] != 1) + return op->emitOpError("expects dst valid_shape[0] to be 1 for tgemv"); + if (lhsValid[1] != ShapedType::kDynamic && rhsValid[0] != ShapedType::kDynamic && + lhsValid[1] != rhsValid[0]) + return op->emitOpError() + << "expects lhs valid_shape[1] to equal rhs valid_shape[0], but got " + << lhsValid[1] << " vs " << rhsValid[0]; + if (rhsValid[1] != ShapedType::kDynamic && dstValid[1] != ShapedType::kDynamic && + rhsValid[1] != dstValid[1]) + return op->emitOpError() + << "expects rhs valid_shape[1] to equal dst valid_shape[1], but got " + << rhsValid[1] << " vs " << dstValid[1]; + + auto lhsTb = dyn_cast(lhsTy); + auto rhsTb = dyn_cast(rhsTy); + auto dstTb = dyn_cast(dstTy); + if (!lhsTb || !rhsTb || !dstTb) + return success(); + + if (lhsTb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor)) + return op->emitOpError("expects lhs to use the col_major blayout on A5"); + if (rhsTb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor)) + return op->emitOpError("expects rhs to use the row_major blayout on A5"); + if (dstTb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor)) + return op->emitOpError("expects dst to use the col_major blayout on A5"); + + if (lhsTb.getSLayoutValueI32() != static_cast(pto::SLayout::RowMajor)) + return op->emitOpError("expects lhs to use the row_major slayout on A5"); + if (rhsTb.getSLayoutValueI32() != static_cast(pto::SLayout::ColMajor)) + return op->emitOpError("expects rhs to use the col_major slayout on A5"); + if (dstTb.getSLayoutValueI32() != static_cast(pto::SLayout::RowMajor)) + return op->emitOpError("expects dst to use the row_major slayout on A5"); + return success(); +} + static LogicalResult verifyMatBiasTileA2A3(Operation *op, Type biasTy, Type dstTy, bool requireFloatBias) { if (failed(verifyTileBufCommon(op, biasTy, "bias"))) @@ -5625,7 +5682,9 @@ static bool isA5Fp8LikeType(Type ty) { } static bool isA5MxInputType(Type ty) { - return isA5Fp8LikeType(ty); + if (auto ft = dyn_cast(ty)) + return ft.isFloat8E4M3FN(); + return false; } static LogicalResult verifyA5MxTypeTriple(Operation *op, Type lhsTy, Type rhsTy, @@ -5635,10 +5694,12 @@ static LogicalResult verifyA5MxTypeTriple(Operation *op, Type lhsTy, Type rhsTy, Type rhsElem = getElemTy(rhsTy); Type dstElem = getElemTy(dstTy); - if (!isA5MxInputType(lhsElem) || !isA5MxInputType(rhsElem)) - return op->emitOpError() - << "expects A5 mx operands " << lhsName << " and " << rhsName - << " to use fp8 element types"; + if (!isA5MxInputType(lhsElem)) + return op->emitOpError() << lhsName << ": dtype " << lhsElem + << " is not supported by this op yet"; + if (!isA5MxInputType(rhsElem)) + return op->emitOpError() << rhsName << ": dtype " << rhsElem + << " is not supported by this op yet"; if (!dstElem.isF32()) return op->emitOpError() @@ -6609,11 +6670,11 @@ LogicalResult TGemvMxOp::verify() { getA().getType(), "a_scale", "a")) || failed(verifyScaleTileMatchesOperand(*this, getBScale().getType(), getB().getType(), "b_scale", "b")) || - failed(verifyGemvTileOperands(*this, getA().getType(), getB().getType(), - getDst().getType()))) + failed(verifyA5MxGemvTileOperands(*this, getA().getType(), getB().getType(), + getDst().getType()))) return failure(); if (failed(verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), - getDst().getType(), "a", "b", "dst"))) + getDst().getType(), "lhs", "rhs", "dst"))) return failure(); return verifyMatmulLike(*this, getA().getType(), getB().getType(), getDst().getType()); @@ -6635,7 +6696,7 @@ LogicalResult TGemvMxAccOp::verify() { getDst().getType()))) return failure(); if (failed(verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), - getDst().getType(), "a", "b", "dst"))) + getDst().getType(), "lhs", "rhs", "dst"))) return failure(); if (failed(verifyTileBufSameElemType(*this, getCIn().getType(), getDst().getType(), "c_in", "dst")) || @@ -6663,7 +6724,7 @@ LogicalResult TGemvMxBiasOp::verify() { /*requireFloatBias=*/true))) return failure(); if (failed(verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), - getDst().getType(), "a", "b", "dst"))) + getDst().getType(), "lhs", "rhs", "dst"))) return failure(); auto biasShape = getShapeVec(getBias().getType()); auto dstShape = getShapeVec(getDst().getType()); @@ -6710,7 +6771,7 @@ LogicalResult TMatmulMxOp::verify() { if (failed(verifyA2A3())) return failure(); return verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), - getDst().getType(), "a", "b", "dst"); + getDst().getType(), "lhs", "rhs", "dst"); }; return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); } @@ -6727,7 +6788,7 @@ LogicalResult TMatmulMxAccOp::verify() { if (failed(verifyA2A3())) return failure(); if (failed(verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), - getDst().getType(), "a", "b", "dst"))) + getDst().getType(), "lhs", "rhs", "dst"))) return failure(); if (failed(verifyTileBufSameElemType(*this, getCIn().getType(), getDst().getType(), "c_in", "dst")) || @@ -6754,7 +6815,7 @@ LogicalResult TMatmulMxBiasOp::verify() { if (failed(verifyA2A3())) return failure(); return verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), - getDst().getType(), "a", "b", "dst"); + getDst().getType(), "lhs", "rhs", "dst"); }; return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); } diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index bc6b39ad4..e9e65050a 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -4520,8 +4520,8 @@ struct PTOTStoreToTSTORE : public OpConversionPattern { // Render `pto.tmatmul` as one of three forms depending on the optional // `acc_phase` attribute: // * absent / Unspecified -> `TMATMUL(dst, lhs, rhs)` -// * Partial -> `TMATMUL(dst, lhs, rhs)` -// * Final -> `TMATMUL(dst, lhs, rhs)` +// * Partial -> `TMATMUL(dst, lhs, rhs)` +// * Final -> `TMATMUL(dst, lhs, rhs)` // The Unspecified default keeps backward compatibility with all upstream IR // that does not yet emit an explicit phase attribute. static ArrayAttr buildAccPhaseTemplateArgs(ConversionPatternRewriter &rewriter, @@ -4531,10 +4531,10 @@ static ArrayAttr buildAccPhaseTemplateArgs(ConversionPatternRewriter &rewriter, case pto::AccPhase::Unspecified: return ArrayAttr{}; case pto::AccPhase::Partial: - tmpl = "AccPhase::Partial"; + tmpl = "pto::AccPhase::Partial"; break; case pto::AccPhase::Final: - tmpl = "AccPhase::Final"; + tmpl = "pto::AccPhase::Final"; break; } if (tmpl.empty()) @@ -4585,10 +4585,12 @@ struct PTOTGemvToTGEMV : public OpConversionPattern { Value rhs = peelUnrealized(adaptor.getRhs()); // B (Vector) Value dst = peelUnrealized(adaptor.getDst()); // C (Result) - // 2. 直接生成函数调用 TGEMV(dst, lhs, rhs) + ArrayAttr templateArgs = + buildAccPhaseTemplateArgs(rewriter, op.getAccPhase()); + rewriter.create( op.getLoc(), TypeRange{}, "TGEMV", - ArrayAttr{}, ArrayAttr{}, + /*args=*/ArrayAttr{}, /*template_args=*/templateArgs, ValueRange{dst, lhs, rhs}); // 3. 处理 Op 替换/删除 @@ -4618,10 +4620,12 @@ struct PTOTGemvAccToTGEMVACC : public OpConversionPattern { Value rhs = peelUnrealized(adaptor.getRhs()); // B (Vector) Value dst = peelUnrealized(adaptor.getDst()); // AccNew - // 2. 直接生成函数调用 TGEMV_ACC(dst, accIn, lhs, rhs) + ArrayAttr templateArgs = + buildAccPhaseTemplateArgs(rewriter, op.getAccPhase()); + rewriter.create( op.getLoc(), TypeRange{}, "TGEMV_ACC", - ArrayAttr{}, ArrayAttr{}, + /*args=*/ArrayAttr{}, /*template_args=*/templateArgs, ValueRange{dst, accIn, lhs, rhs}); // 3. 处理 Op 替换/删除 @@ -9694,25 +9698,13 @@ struct PTORowExpandExpdifToEmitC // PTOConvert.cpp (add lowering + patterns.add for TROWEXPANDDIV DPS/memref op) //===----------------------------------------------------------------------===// // Helper: replace or erase based on whether op has results. -static void replaceOrEraseWithOpaqueCall(Operation *op, - StringRef callee, - ArrayRef args, - ConversionPatternRewriter &rewriter) { - auto call = rewriter.create( - op->getLoc(), TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, - ValueRange(args)); - if (op->getNumResults() == 0) - rewriter.eraseOp(op); - else - rewriter.replaceOp(op, call.getResults()); -} - static void replaceOrEraseWithOpaqueCallAndReturnDst(Operation *op, Value dst, StringRef callee, ArrayRef args, + ArrayAttr templateArgs, ConversionPatternRewriter &rewriter) { rewriter.create( - op->getLoc(), TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, ValueRange(args)); + op->getLoc(), TypeRange{}, callee, ArrayAttr{}, templateArgs, ValueRange(args)); if (op->getNumResults() == 1) rewriter.replaceOp(op, dst); else @@ -9731,8 +9723,10 @@ struct PTOTGemvBiasToTGEMV_BIAS Value bias = peelUnrealized(adaptor.getBias()); Value dst = peelUnrealized(adaptor.getDst()); - replaceOrEraseWithOpaqueCall(op.getOperation(), "TGEMV_BIAS", - {dst, a, b, bias}, rewriter); + ArrayAttr templateArgs = + buildAccPhaseTemplateArgs(rewriter, op.getAccPhase()); + replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TGEMV_BIAS", + {dst, a, b, bias}, templateArgs, rewriter); return success(); } }; @@ -9749,8 +9743,11 @@ struct PTOTGemvMXToTGEMV_MX Value bScale = peelUnrealized(adaptor.getBScale()); Value dst = peelUnrealized(adaptor.getDst()); + ArrayAttr templateArgs = + buildAccPhaseTemplateArgs(rewriter, op.getAccPhase()); replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TGEMV_MX", - {dst, a, aScale, b, bScale}, rewriter); + {dst, a, aScale, b, bScale}, templateArgs, + rewriter); return success(); } }; @@ -9769,7 +9766,8 @@ struct PTOTGemvMXAccToTGEMV_MX Value dst = peelUnrealized(adaptor.getDst()); replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TGEMV_MX", - {dst, cIn, a, aScale, b, bScale}, rewriter); + {dst, cIn, a, aScale, b, bScale}, ArrayAttr{}, + rewriter); return success(); } }; @@ -9788,7 +9786,8 @@ struct PTOTGemvMXBiasToTGEMV_MX Value dst = peelUnrealized(adaptor.getDst()); replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TGEMV_MX", - {dst, a, aScale, b, bScale, bias}, rewriter); + {dst, a, aScale, b, bScale, bias}, ArrayAttr{}, + rewriter); return success(); } }; @@ -9804,8 +9803,10 @@ struct PTOTMatmulBiasToTMATMUL_BIAS Value bias = peelUnrealized(adaptor.getBias()); Value dst = peelUnrealized(adaptor.getDst()); - replaceOrEraseWithOpaqueCall(op.getOperation(), "TMATMUL_BIAS", - {dst, a, b, bias}, rewriter); + ArrayAttr templateArgs = + buildAccPhaseTemplateArgs(rewriter, op.getAccPhase()); + replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TMATMUL_BIAS", + {dst, a, b, bias}, templateArgs, rewriter); return success(); } }; @@ -9822,8 +9823,11 @@ struct PTOTMatmulMXToTMATMUL_MX Value bScale = peelUnrealized(adaptor.getBScale()); Value dst = peelUnrealized(adaptor.getDst()); - replaceOrEraseWithOpaqueCall(op.getOperation(), "TMATMUL_MX", - {dst, a, aScale, b, bScale}, rewriter); + ArrayAttr templateArgs = + buildAccPhaseTemplateArgs(rewriter, op.getAccPhase()); + replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TMATMUL_MX", + {dst, a, aScale, b, bScale}, templateArgs, + rewriter); return success(); } }; @@ -9841,8 +9845,9 @@ struct PTOTMatmulMXAccToTMATMUL_MX_ACC Value bScale = peelUnrealized(adaptor.getBScale()); Value dst = peelUnrealized(adaptor.getDst()); - replaceOrEraseWithOpaqueCall(op.getOperation(), "TMATMUL_MX", - {dst, cIn, a, aScale, b, bScale}, rewriter); + replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TMATMUL_MX", + {dst, cIn, a, aScale, b, bScale}, ArrayAttr{}, + rewriter); return success(); } }; @@ -9860,8 +9865,9 @@ struct PTOTMatmulMXBiasToTMATMUL_MX_BIAS Value bias = peelUnrealized(adaptor.getBias()); Value dst = peelUnrealized(adaptor.getDst()); - replaceOrEraseWithOpaqueCall(op.getOperation(), "TMATMUL_MX", - {dst, a, aScale, b, bScale, bias}, rewriter); + replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TMATMUL_MX", + {dst, a, aScale, b, bScale, bias}, ArrayAttr{}, + rewriter); return success(); } }; diff --git a/lib/PTO/Transforms/PTOViewToMemref.cpp b/lib/PTO/Transforms/PTOViewToMemref.cpp index 63849ef2a..f17d2e99d 100644 --- a/lib/PTO/Transforms/PTOViewToMemref.cpp +++ b/lib/PTO/Transforms/PTOViewToMemref.cpp @@ -1940,7 +1940,7 @@ struct PTOViewToMemrefPass op, TypeRange{}, op->getOperand(0), op->getOperand(1), op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex)); + op->getOperand(kFourthOperandIndex), op.getAccPhaseAttr()); } // --- TMatmulMxOp--- @@ -1954,7 +1954,7 @@ struct PTOViewToMemrefPass op->getOperand(0), op->getOperand(1), op->getOperand(kThirdOperandIndex), op->getOperand(kFourthOperandIndex), - op->getOperand(kFifthOperandIndex)); + op->getOperand(kFifthOperandIndex), op.getAccPhaseAttr()); } // --- TMatmulMxAccOp --- @@ -1999,7 +1999,7 @@ struct PTOViewToMemrefPass Value dst = op->getOperand(kThirdOperandIndex); rewriter.replaceOpWithNewOp( - op, TypeRange{}, lhs, rhs, dst); + op, TypeRange{}, lhs, rhs, dst, op.getAccPhaseAttr()); } // --- TGemvAccOp [Acc, Lhs, Rhs, Dst] --- @@ -2012,7 +2012,7 @@ struct PTOViewToMemrefPass op, TypeRange{}, op->getOperand(0), op->getOperand(1), op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex)); + op->getOperand(kFourthOperandIndex), op.getAccPhaseAttr()); } // --- TGemvBiasOp [Acc, Lhs, Rhs, Bias, Dst] --- @@ -2025,7 +2025,7 @@ struct PTOViewToMemrefPass op, TypeRange{}, op->getOperand(0), op->getOperand(1), op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex)); + op->getOperand(kFourthOperandIndex), op.getAccPhaseAttr()); } // --- TGemvMxOp [A, AScale, B, BScale, Dst] --- @@ -2039,7 +2039,7 @@ struct PTOViewToMemrefPass op->getOperand(0), op->getOperand(1), op->getOperand(kThirdOperandIndex), op->getOperand(kFourthOperandIndex), - op->getOperand(kFifthOperandIndex)); + op->getOperand(kFifthOperandIndex), op.getAccPhaseAttr()); } // --- TGemvMxAccOp [CIn, A, AScale, B, BScale, Dst] --- diff --git a/test/lit/pto/tgemv_mx_accphase_emitc.pto b/test/lit/pto/tgemv_mx_accphase_emitc.pto index f8c5d5876..7a434746a 100644 --- a/test/lit/pto/tgemv_mx_accphase_emitc.pto +++ b/test/lit/pto/tgemv_mx_accphase_emitc.pto @@ -1,15 +1,18 @@ // RUN: ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s --check-prefix=A5 module attributes {pto.target_arch = "a5"} { - func.func @tgemv_mx_accphase_emitc(%a : !pto.tile_buf, - %b : !pto.tile_buf) attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @tgemv_mx_accphase_emitc() attributes {pto.kernel_kind = #pto.kernel_kind} { + %a = pto.alloc_tile : !pto.tile_buf + %b = pto.alloc_tile : !pto.tile_buf %a_scale = pto.alloc_tile : !pto.tile_buf %b_scale = pto.alloc_tile : !pto.tile_buf - %dst = pto.alloc_tile : !pto.tile_buf + %dst0 = pto.alloc_tile : !pto.tile_buf + %dst1 = pto.alloc_tile : !pto.tile_buf + %dst2 = pto.alloc_tile : !pto.tile_buf - pto.tgemv.mx ins(%a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) - pto.tgemv.mx ins(%a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {accPhase = #pto} - pto.tgemv.mx ins(%a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {accPhase = #pto} + pto.tgemv.mx ins(%a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst0 : !pto.tile_buf) + pto.tgemv.mx ins(%a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst1 : !pto.tile_buf) {accPhase = #pto} + pto.tgemv.mx ins(%a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst2 : !pto.tile_buf) {accPhase = #pto} return } } diff --git a/test/lit/pto/tmatmul_acc_phase_emitc.pto b/test/lit/pto/tmatmul_acc_phase_emitc.pto index 868a991dc..41dad0049 100644 --- a/test/lit/pto/tmatmul_acc_phase_emitc.pto +++ b/test/lit/pto/tmatmul_acc_phase_emitc.pto @@ -49,8 +49,8 @@ module { } // A3: TMATMUL([[C:v[0-9]+]], [[A:v[0-9]+]], [[B:v[0-9]+]]); -// A3: TMATMUL([[C]], [[A]], [[B]]); -// A3: TMATMUL([[C]], [[A]], [[B]]); +// A3: TMATMUL([[C]], [[A]], [[B]]); +// A3: TMATMUL([[C]], [[A]], [[B]]); // A3: TMATMUL_ACC([[C]], [[C]], [[A]], [[B]]); -// A3: TMATMUL_ACC([[C]], [[C]], [[A]], [[B]]); -// A3: TMATMUL_ACC([[C]], [[C]], [[A]], [[B]]); +// A3: TMATMUL_ACC([[C]], [[C]], [[A]], [[B]]); +// A3: TMATMUL_ACC([[C]], [[C]], [[A]], [[B]]); From 89dc65404639deb94534fe45371b9965188c7e9f Mon Sep 17 00:00:00 2001 From: jimmychou <1049859649@qq.com> Date: Wed, 27 May 2026 17:38:47 +0800 Subject: [PATCH 6/6] fix(pto): complete MX accphase coverage --- include/PTO/IR/PTOOps.td | 6 +++-- lib/PTO/IR/PTO.cpp | 22 +++++++++++++++++-- lib/PTO/Transforms/PTOToEmitC.cpp | 8 +++++-- lib/PTO/Transforms/PTOViewToMemref.cpp | 4 ++-- test/lit/pto/tgemv_mx_accphase_emitc.pto | 10 +++++++++ .../pto/tgemv_mx_verify_invalid_shape_a5.pto | 15 +++++++++++++ test/lit/pto/tmatmul_mx_accphase_emitc.pto | 10 +++++++++ 7 files changed, 67 insertions(+), 8 deletions(-) create mode 100644 test/lit/pto/tgemv_mx_verify_invalid_shape_a5.pto diff --git a/include/PTO/IR/PTOOps.td b/include/PTO/IR/PTOOps.td index 863ed8c51..99b00c28d 100644 --- a/include/PTO/IR/PTOOps.td +++ b/include/PTO/IR/PTOOps.td @@ -894,7 +894,8 @@ def TMatmulMxAccOp : PTO_TOp<"tmatmul.mx.acc", [ PTODpsType:$a_scale, PTODpsType:$b, PTODpsType:$b_scale, - PTODpsType:$dst); + PTODpsType:$dst, + DefaultValuedAttr:$accPhase); let results = (outs Optional:$result); let hasVerifier = 1; @@ -1177,7 +1178,8 @@ def TGemvMxAccOp : PTO_TOp<"tgemv.mx.acc", [ PTODpsType:$a_scale, PTODpsType:$b, PTODpsType:$b_scale, - PTODpsType:$dst + PTODpsType:$dst, + DefaultValuedAttr:$accPhase ); let results = (outs Optional:$result); diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index d3d557a84..e9d4b3f73 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -4138,9 +4138,27 @@ static LogicalResult verifyA5MxGemvTileOperands(Operation *op, Type lhsTy, return op->emitOpError( "expects lhs, rhs, and dst to use the left, right, and acc address spaces"); + auto lhsShape = getMatmulLogicalShapeVec(lhsTy); + auto rhsShape = getMatmulLogicalShapeVec(rhsTy); + auto dstShape = getMatmulLogicalShapeVec(dstTy); + if ((lhsShape[0] != dstShape[0] || rhsShape[1] != dstShape[1] || + lhsShape[1] != rhsShape[0])) + return op->emitOpError( + "expects static matmul tile shapes lhs[M,K], rhs[K,N], and dst[M,N]"); + auto lhsValid = getValidShapeVec(lhsTy); auto rhsValid = getValidShapeVec(rhsTy); auto dstValid = getValidShapeVec(dstTy); + if (lhsValid.size() == 2 && rhsValid.size() == 2) { + int64_t m = lhsValid[0]; + int64_t k = lhsValid[1]; + int64_t n = rhsValid[1]; + if ((m != ShapedType::kDynamic && (m < 1 || m > 4095)) || + (k != ShapedType::kDynamic && (k < 1 || k > 4095)) || + (n != ShapedType::kDynamic && (n < 1 || n > 4095))) + return op->emitOpError("expects m, k, and n valid sizes to be in [1, 4095]"); + } + if (lhsValid[0] != ShapedType::kDynamic && lhsValid[0] != 1) return op->emitOpError("expects lhs valid_shape[0] to be 1 for tgemv"); if (dstValid[0] != ShapedType::kDynamic && dstValid[0] != 1) @@ -6692,8 +6710,8 @@ LogicalResult TGemvMxAccOp::verify() { getA().getType(), "a_scale", "a")) || failed(verifyScaleTileMatchesOperand(*this, getBScale().getType(), getB().getType(), "b_scale", "b")) || - failed(verifyGemvTileOperands(*this, getA().getType(), getB().getType(), - getDst().getType()))) + failed(verifyA5MxGemvTileOperands(*this, getA().getType(), getB().getType(), + getDst().getType()))) return failure(); if (failed(verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), getDst().getType(), "lhs", "rhs", "dst"))) diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index e9e65050a..4d2b0f6a5 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -9765,8 +9765,10 @@ struct PTOTGemvMXAccToTGEMV_MX Value bScale = peelUnrealized(adaptor.getBScale()); Value dst = peelUnrealized(adaptor.getDst()); + ArrayAttr templateArgs = + buildAccPhaseTemplateArgs(rewriter, op.getAccPhase()); replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TGEMV_MX", - {dst, cIn, a, aScale, b, bScale}, ArrayAttr{}, + {dst, cIn, a, aScale, b, bScale}, templateArgs, rewriter); return success(); } @@ -9845,8 +9847,10 @@ struct PTOTMatmulMXAccToTMATMUL_MX_ACC Value bScale = peelUnrealized(adaptor.getBScale()); Value dst = peelUnrealized(adaptor.getDst()); + ArrayAttr templateArgs = + buildAccPhaseTemplateArgs(rewriter, op.getAccPhase()); replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TMATMUL_MX", - {dst, cIn, a, aScale, b, bScale}, ArrayAttr{}, + {dst, cIn, a, aScale, b, bScale}, templateArgs, rewriter); return success(); } diff --git a/lib/PTO/Transforms/PTOViewToMemref.cpp b/lib/PTO/Transforms/PTOViewToMemref.cpp index f17d2e99d..075d37a7c 100644 --- a/lib/PTO/Transforms/PTOViewToMemref.cpp +++ b/lib/PTO/Transforms/PTOViewToMemref.cpp @@ -1969,7 +1969,7 @@ struct PTOViewToMemrefPass op->getOperand(kThirdOperandIndex), op->getOperand(kFourthOperandIndex), op->getOperand(kFifthOperandIndex), - op->getOperand(kSixthOperandIndex)); + op->getOperand(kSixthOperandIndex), op.getAccPhaseAttr()); } // --- TMatmulMxBiasOp --- @@ -2054,7 +2054,7 @@ struct PTOViewToMemrefPass op->getOperand(kThirdOperandIndex), op->getOperand(kFourthOperandIndex), op->getOperand(kFifthOperandIndex), - op->getOperand(kSixthOperandIndex)); + op->getOperand(kSixthOperandIndex), op.getAccPhaseAttr()); } // --- TGemvMxBiasOp [A, AScale, B, BScale, Bias, Dst] --- diff --git a/test/lit/pto/tgemv_mx_accphase_emitc.pto b/test/lit/pto/tgemv_mx_accphase_emitc.pto index 7a434746a..dc1b70601 100644 --- a/test/lit/pto/tgemv_mx_accphase_emitc.pto +++ b/test/lit/pto/tgemv_mx_accphase_emitc.pto @@ -6,13 +6,20 @@ module attributes {pto.target_arch = "a5"} { %b = pto.alloc_tile : !pto.tile_buf %a_scale = pto.alloc_tile : !pto.tile_buf %b_scale = pto.alloc_tile : !pto.tile_buf + %c_in = pto.alloc_tile : !pto.tile_buf %dst0 = pto.alloc_tile : !pto.tile_buf %dst1 = pto.alloc_tile : !pto.tile_buf %dst2 = pto.alloc_tile : !pto.tile_buf + %dst3 = pto.alloc_tile : !pto.tile_buf + %dst4 = pto.alloc_tile : !pto.tile_buf + %dst5 = pto.alloc_tile : !pto.tile_buf pto.tgemv.mx ins(%a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst0 : !pto.tile_buf) pto.tgemv.mx ins(%a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst1 : !pto.tile_buf) {accPhase = #pto} pto.tgemv.mx ins(%a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst2 : !pto.tile_buf) {accPhase = #pto} + pto.tgemv.mx.acc ins(%c_in, %a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst3 : !pto.tile_buf) + pto.tgemv.mx.acc ins(%c_in, %a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst4 : !pto.tile_buf) {accPhase = #pto} + pto.tgemv.mx.acc ins(%c_in, %a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst5 : !pto.tile_buf) {accPhase = #pto} return } } @@ -20,3 +27,6 @@ module attributes {pto.target_arch = "a5"} { // A5: TGEMV_MX([[D0:v[0-9]+]], [[A0:v[0-9]+]], [[AS0:v[0-9]+]], [[B0:v[0-9]+]], [[BS0:v[0-9]+]]); // A5: TGEMV_MX([[D1:v[0-9]+]], [[A1:v[0-9]+]], [[AS1:v[0-9]+]], [[B1:v[0-9]+]], [[BS1:v[0-9]+]]); // A5: TGEMV_MX([[D2:v[0-9]+]], [[A2:v[0-9]+]], [[AS2:v[0-9]+]], [[B2:v[0-9]+]], [[BS2:v[0-9]+]]); +// A5: TGEMV_MX([[D3:v[0-9]+]], [[C3:v[0-9]+]], [[A3:v[0-9]+]], [[AS3:v[0-9]+]], [[B3:v[0-9]+]], [[BS3:v[0-9]+]]); +// A5: TGEMV_MX([[D4:v[0-9]+]], [[C4:v[0-9]+]], [[A4:v[0-9]+]], [[AS4:v[0-9]+]], [[B4:v[0-9]+]], [[BS4:v[0-9]+]]); +// A5: TGEMV_MX([[D5:v[0-9]+]], [[C5:v[0-9]+]], [[A5:v[0-9]+]], [[AS5:v[0-9]+]], [[B5:v[0-9]+]], [[BS5:v[0-9]+]]); diff --git a/test/lit/pto/tgemv_mx_verify_invalid_shape_a5.pto b/test/lit/pto/tgemv_mx_verify_invalid_shape_a5.pto new file mode 100644 index 000000000..a0f4da6fe --- /dev/null +++ b/test/lit/pto/tgemv_mx_verify_invalid_shape_a5.pto @@ -0,0 +1,15 @@ +// RUN: not ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s + +module attributes {pto.target_arch = "a5"} { + func.func @tgemv_mx_invalid_static_shape(%a : !pto.tile_buf, + %b : !pto.tile_buf) attributes {pto.kernel_kind = #pto.kernel_kind} { + %a_scale = pto.alloc_tile : !pto.tile_buf + %b_scale = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tgemv.mx ins(%a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) + return + } +} + +// CHECK: error: 'pto.tgemv.mx' op expects static matmul tile shapes lhs[M,K], rhs[K,N], and dst[M,N] diff --git a/test/lit/pto/tmatmul_mx_accphase_emitc.pto b/test/lit/pto/tmatmul_mx_accphase_emitc.pto index bde097042..d62c8de84 100644 --- a/test/lit/pto/tmatmul_mx_accphase_emitc.pto +++ b/test/lit/pto/tmatmul_mx_accphase_emitc.pto @@ -5,13 +5,20 @@ module attributes {pto.target_arch = "a5"} { %b : !pto.tile_buf) attributes {pto.kernel_kind = #pto.kernel_kind} { %a_scale = pto.alloc_tile : !pto.tile_buf %b_scale = pto.alloc_tile : !pto.tile_buf + %c_in = pto.alloc_tile : !pto.tile_buf %dst0 = pto.alloc_tile : !pto.tile_buf %dst1 = pto.alloc_tile : !pto.tile_buf %dst2 = pto.alloc_tile : !pto.tile_buf + %dst3 = pto.alloc_tile : !pto.tile_buf + %dst4 = pto.alloc_tile : !pto.tile_buf + %dst5 = pto.alloc_tile : !pto.tile_buf pto.tmatmul.mx ins(%a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst0 : !pto.tile_buf) pto.tmatmul.mx ins(%a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst1 : !pto.tile_buf) {accPhase = #pto} pto.tmatmul.mx ins(%a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst2 : !pto.tile_buf) {accPhase = #pto} + pto.tmatmul.mx.acc ins(%c_in, %a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst3 : !pto.tile_buf) + pto.tmatmul.mx.acc ins(%c_in, %a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst4 : !pto.tile_buf) {accPhase = #pto} + pto.tmatmul.mx.acc ins(%c_in, %a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst5 : !pto.tile_buf) {accPhase = #pto} return } } @@ -19,3 +26,6 @@ module attributes {pto.target_arch = "a5"} { // A5: TMATMUL_MX([[D0:v[0-9]+]], [[A0:v[0-9]+]], [[AS0:v[0-9]+]], [[B0:v[0-9]+]], [[BS0:v[0-9]+]]); // A5: TMATMUL_MX([[D1:v[0-9]+]], [[A1:v[0-9]+]], [[AS1:v[0-9]+]], [[B1:v[0-9]+]], [[BS1:v[0-9]+]]); // A5: TMATMUL_MX([[D2:v[0-9]+]], [[A2:v[0-9]+]], [[AS2:v[0-9]+]], [[B2:v[0-9]+]], [[BS2:v[0-9]+]]); +// A5: TMATMUL_MX([[D3:v[0-9]+]], [[C3:v[0-9]+]], [[A3:v[0-9]+]], [[AS3:v[0-9]+]], [[B3:v[0-9]+]], [[BS3:v[0-9]+]]); +// A5: TMATMUL_MX([[D4:v[0-9]+]], [[C4:v[0-9]+]], [[A4:v[0-9]+]], [[AS4:v[0-9]+]], [[B4:v[0-9]+]], [[BS4:v[0-9]+]]); +// A5: TMATMUL_MX([[D5:v[0-9]+]], [[C5:v[0-9]+]], [[A5:v[0-9]+]], [[AS5:v[0-9]+]], [[B5:v[0-9]+]], [[BS5:v[0-9]+]]);