diff --git a/include/PTO/IR/PTOAttrs.td b/include/PTO/IR/PTOAttrs.td index fc1e0ca2f..79b3ff227 100644 --- a/include/PTO/IR/PTOAttrs.td +++ b/include/PTO/IR/PTOAttrs.td @@ -424,6 +424,110 @@ def PTO_SaturationModeAttr : EnumAttr, + I32EnumAttrCase<"HighPrecision", 1, "high_precision"> + ]>; + +def PTO_DivPrecisionAttr : EnumAttr { + let summary = "TDIV precision mode attribute"; +} +//===----------------------------------------------------------------------===// +// TEXP template controls +//===----------------------------------------------------------------------===// + +def PTO_ExpPrecisionEnum : PTO_I32Enum< + "ExpPrecision", "PTO TEXP precision mode", [ + I32EnumAttrCase<"Default", 0, "default">, + I32EnumAttrCase<"HighPrecision", 1, "high_precision"> + ]>; + +def PTO_ExpPrecisionAttr : EnumAttr { + let summary = "TEXP precision mode attribute"; +} +//===----------------------------------------------------------------------===// +// TLOG template controls +//===----------------------------------------------------------------------===// + +def PTO_LogPrecisionEnum : PTO_I32Enum< + "LogPrecision", "PTO TLOG precision mode", [ + I32EnumAttrCase<"Default", 0, "default">, + I32EnumAttrCase<"HighPrecision", 1, "high_precision"> + ]>; + +def PTO_LogPrecisionAttr : EnumAttr { + let summary = "TLOG precision mode attribute"; +} +//===----------------------------------------------------------------------===// +// TRECIP template controls +//===----------------------------------------------------------------------===// + +def PTO_RecipPrecisionEnum : PTO_I32Enum< + "RecipPrecision", "PTO TRECIP precision mode", [ + I32EnumAttrCase<"Default", 0, "default">, + I32EnumAttrCase<"HighPrecision", 1, "high_precision"> + ]>; + +def PTO_RecipPrecisionAttr : EnumAttr { + let summary = "TRECIP precision mode attribute"; +} +//===----------------------------------------------------------------------===// +// TREM template controls +//===----------------------------------------------------------------------===// + +def PTO_RemPrecisionEnum : PTO_I32Enum< + "RemPrecision", "PTO TREM precision mode", [ + I32EnumAttrCase<"Default", 0, "default">, + I32EnumAttrCase<"HighPrecision", 1, "high_precision"> + ]>; + +def PTO_RemPrecisionAttr : EnumAttr { + let summary = "TREM precision mode attribute"; +} +//===----------------------------------------------------------------------===// +// TRSQRT template controls +//===----------------------------------------------------------------------===// + +def PTO_RsqrtPrecisionEnum : PTO_I32Enum< + "RsqrtPrecision", "PTO TRSQRT precision mode", [ + I32EnumAttrCase<"Default", 0, "default">, + I32EnumAttrCase<"HighPrecision", 1, "high_precision"> + ]>; + +def PTO_RsqrtPrecisionAttr : EnumAttr { + let summary = "TRSQRT precision mode attribute"; +} +//===----------------------------------------------------------------------===// +// TSQRT template controls +//===----------------------------------------------------------------------===// + +def PTO_SqrtPrecisionEnum : PTO_I32Enum< + "SqrtPrecision", "PTO TSQRT precision mode", [ + I32EnumAttrCase<"Default", 0, "default">, + I32EnumAttrCase<"HighPrecision", 1, "high_precision"> + ]>; + +def PTO_SqrtPrecisionAttr : EnumAttr { + let summary = "TSQRT precision mode attribute"; +} +//===----------------------------------------------------------------------===// +// TFMOD template controls +//===----------------------------------------------------------------------===// + +def PTO_FmodPrecisionEnum : PTO_I32Enum< + "FmodPrecision", "PTO TFMOD precision mode", [ + I32EnumAttrCase<"Default", 0, "default">, + I32EnumAttrCase<"HighPrecision", 1, "high_precision"> + ]>; + +def PTO_FmodPrecisionAttr : EnumAttr { + let summary = "TFMOD precision mode attribute"; +} +//===----------------------------------------------------------------------===// // TStore template controls //===----------------------------------------------------------------------===// diff --git a/include/PTO/IR/PTOOps.td b/include/PTO/IR/PTOOps.td index 344b399a6..99b00c28d 100644 --- a/include/PTO/IR/PTOOps.td +++ b/include/PTO/IR/PTOOps.td @@ -817,7 +817,8 @@ def TMatmulBiasOp : PTO_TOp<"tmatmul.bias", [ PTODpsType:$a, PTODpsType:$b, PTODpsType:$bias, - PTODpsType:$dst + PTODpsType:$dst, + DefaultValuedAttr:$accPhase ); let results = (outs Optional:$result); @@ -856,7 +857,8 @@ def TMatmulMxOp : PTO_TOp<"tmatmul.mx", [ PTODpsType:$a_scale, PTODpsType:$b, PTODpsType:$b_scale, - PTODpsType:$dst); + PTODpsType:$dst, + DefaultValuedAttr:$accPhase); let results = (outs Optional:$result); let hasVerifier = 1; @@ -892,7 +894,8 @@ def TMatmulMxAccOp : PTO_TOp<"tmatmul.mx.acc", [ PTODpsType:$a_scale, PTODpsType:$b, PTODpsType:$b_scale, - PTODpsType:$dst); + PTODpsType:$dst, + DefaultValuedAttr:$accPhase); let results = (outs Optional:$result); let hasVerifier = 1; @@ -1031,7 +1034,8 @@ def TGemvOp : PTO_TOp<"tgemv", [ let arguments = (ins PTODpsType:$lhs, PTODpsType:$rhs, - PTODpsType:$dst + PTODpsType:$dst, + DefaultValuedAttr:$accPhase ); let results = (outs @@ -1067,7 +1071,8 @@ def TGemvAccOp : PTO_TOp<"tgemv.acc", [ PTODpsType:$acc_in, PTODpsType:$lhs, PTODpsType:$rhs, - PTODpsType:$dst + PTODpsType:$dst, + DefaultValuedAttr:$accPhase ); let results = (outs @@ -1102,7 +1107,8 @@ def TGemvBiasOp : PTO_TOp<"tgemv.bias", [ PTODpsType :$a, PTODpsType :$b, PTODpsType :$bias, - PTODpsType :$dst + PTODpsType :$dst, + DefaultValuedAttr:$accPhase ); let results = (outs Optional:$result); @@ -1137,7 +1143,8 @@ def TGemvMxOp : PTO_TOp<"tgemv.mx", [ PTODpsType:$a_scale, PTODpsType:$b, PTODpsType:$b_scale, - PTODpsType:$dst + PTODpsType:$dst, + DefaultValuedAttr:$accPhase ); let results = (outs Optional:$result); @@ -1171,7 +1178,8 @@ def TGemvMxAccOp : PTO_TOp<"tgemv.mx.acc", [ PTODpsType:$a_scale, PTODpsType:$b, PTODpsType:$b_scale, - PTODpsType:$dst + PTODpsType:$dst, + DefaultValuedAttr:$accPhase ); let results = (outs Optional:$result); @@ -3425,7 +3433,8 @@ def TColExpandDivOp : PTO_TOp<"tcolexpanddiv", [ let arguments = (ins PTODpsType:$src0, PTODpsType:$src1, - PTODpsType:$dst + PTODpsType:$dst, + DefaultValuedAttr:$precisionType ); let results = (outs); @@ -3804,7 +3813,8 @@ def TDivOp : PTO_TOp<"tdiv", [ let arguments = (ins PTODpsType:$src0, PTODpsType:$src1, - PTODpsType:$dst + PTODpsType:$dst, + DefaultValuedAttr:$precisionType ); let results = (outs); @@ -3860,7 +3870,8 @@ def TFModOp : PTO_TOp<"tfmod", [ let arguments = (ins PTODpsType:$src0, PTODpsType:$src1, - PTODpsType:$dst + PTODpsType:$dst, + DefaultValuedAttr:$precisionType ); let results = (outs); @@ -3917,7 +3928,8 @@ def TExpOp : PTO_TOp<"texp", [ let arguments = (ins PTODpsType:$src, - PTODpsType:$dst + PTODpsType:$dst, + DefaultValuedAttr:$precisionType ); let results = (outs); @@ -4355,7 +4367,8 @@ def TLogOp : PTO_TOp<"tlog", [ let arguments = (ins PTODpsType:$src, - PTODpsType:$dst + PTODpsType:$dst, + DefaultValuedAttr:$precisionType ); let results = (outs); @@ -5118,7 +5131,8 @@ def TRecipOp: PTO_TOp<"trecip", [ let arguments = (ins PTODpsType:$src, - PTODpsType:$dst + PTODpsType:$dst, + DefaultValuedAttr:$precisionType ); let results = (outs); @@ -5182,7 +5196,8 @@ def TRemOp: PTO_TOp<"trem", [ PTODpsType:$src0, PTODpsType:$src1, PTODpsType:$tmp, - PTODpsType:$dst + PTODpsType:$dst, + DefaultValuedAttr:$precisionType ); let results = (outs); @@ -5311,7 +5326,8 @@ def TRowExpandDivOp: PTO_TOp<"trowexpanddiv", [ PTODpsType:$src0, PTODpsType:$src1, Optional:$tmp, - PTODpsType:$dst + PTODpsType:$dst, + DefaultValuedAttr:$precisionType ); let results = (outs); @@ -5702,7 +5718,8 @@ def TRsqrtOp: PTO_TOp<"trsqrt", [ let arguments = (ins PTODpsType:$src, PTODpsType:$dst, - Optional:$tmp + Optional:$tmp, + DefaultValuedAttr:$precisionType ); let results = (outs); @@ -6012,7 +6029,8 @@ def TSqrtOp: PTO_TOp<"tsqrt", [ let arguments = (ins PTODpsType:$src, - PTODpsType:$dst + PTODpsType:$dst, + DefaultValuedAttr:$precisionType ); let results = (outs); diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index 5427ac36e..e9d4b3f73 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -4121,6 +4121,81 @@ static LogicalResult verifyGemvTileOperands(Operation *op, Type lhsTy, Type rhsT return failure(); } +static LogicalResult verifyA5MxGemvTileOperands(Operation *op, Type lhsTy, + Type rhsTy, Type dstTy) { + if (failed(verifyTileBufCommon(op, lhsTy, "lhs", /*allowLowPrecision=*/true)) || + failed(verifyTileBufCommon(op, rhsTy, "rhs", /*allowLowPrecision=*/true)) || + failed(verifyAccTileCommon(op, dstTy, "dst"))) + return failure(); + + auto lhsSpace = getPTOMemorySpaceEnum(lhsTy); + auto rhsSpace = getPTOMemorySpaceEnum(rhsTy); + auto dstSpace = getPTOMemorySpaceEnum(dstTy); + if (!lhsSpace || !rhsSpace || !dstSpace) + return op->emitOpError("expects lhs, rhs, and dst to have explicit address spaces"); + if (*lhsSpace != pto::AddressSpace::LEFT || *rhsSpace != pto::AddressSpace::RIGHT || + *dstSpace != pto::AddressSpace::ACC) + return op->emitOpError( + "expects lhs, rhs, and dst to use the left, right, and acc address spaces"); + + auto lhsShape = getMatmulLogicalShapeVec(lhsTy); + auto rhsShape = getMatmulLogicalShapeVec(rhsTy); + auto dstShape = getMatmulLogicalShapeVec(dstTy); + if ((lhsShape[0] != dstShape[0] || rhsShape[1] != dstShape[1] || + lhsShape[1] != rhsShape[0])) + return op->emitOpError( + "expects static matmul tile shapes lhs[M,K], rhs[K,N], and dst[M,N]"); + + auto lhsValid = getValidShapeVec(lhsTy); + auto rhsValid = getValidShapeVec(rhsTy); + auto dstValid = getValidShapeVec(dstTy); + if (lhsValid.size() == 2 && rhsValid.size() == 2) { + int64_t m = lhsValid[0]; + int64_t k = lhsValid[1]; + int64_t n = rhsValid[1]; + if ((m != ShapedType::kDynamic && (m < 1 || m > 4095)) || + (k != ShapedType::kDynamic && (k < 1 || k > 4095)) || + (n != ShapedType::kDynamic && (n < 1 || n > 4095))) + return op->emitOpError("expects m, k, and n valid sizes to be in [1, 4095]"); + } + + if (lhsValid[0] != ShapedType::kDynamic && lhsValid[0] != 1) + return op->emitOpError("expects lhs valid_shape[0] to be 1 for tgemv"); + if (dstValid[0] != ShapedType::kDynamic && dstValid[0] != 1) + return op->emitOpError("expects dst valid_shape[0] to be 1 for tgemv"); + if (lhsValid[1] != ShapedType::kDynamic && rhsValid[0] != ShapedType::kDynamic && + lhsValid[1] != rhsValid[0]) + return op->emitOpError() + << "expects lhs valid_shape[1] to equal rhs valid_shape[0], but got " + << lhsValid[1] << " vs " << rhsValid[0]; + if (rhsValid[1] != ShapedType::kDynamic && dstValid[1] != ShapedType::kDynamic && + rhsValid[1] != dstValid[1]) + return op->emitOpError() + << "expects rhs valid_shape[1] to equal dst valid_shape[1], but got " + << rhsValid[1] << " vs " << dstValid[1]; + + auto lhsTb = dyn_cast(lhsTy); + auto rhsTb = dyn_cast(rhsTy); + auto dstTb = dyn_cast(dstTy); + if (!lhsTb || !rhsTb || !dstTb) + return success(); + + if (lhsTb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor)) + return op->emitOpError("expects lhs to use the col_major blayout on A5"); + if (rhsTb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor)) + return op->emitOpError("expects rhs to use the row_major blayout on A5"); + if (dstTb.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor)) + return op->emitOpError("expects dst to use the col_major blayout on A5"); + + if (lhsTb.getSLayoutValueI32() != static_cast(pto::SLayout::RowMajor)) + return op->emitOpError("expects lhs to use the row_major slayout on A5"); + if (rhsTb.getSLayoutValueI32() != static_cast(pto::SLayout::ColMajor)) + return op->emitOpError("expects rhs to use the col_major slayout on A5"); + if (dstTb.getSLayoutValueI32() != static_cast(pto::SLayout::RowMajor)) + return op->emitOpError("expects dst to use the row_major slayout on A5"); + return success(); +} + static LogicalResult verifyMatBiasTileA2A3(Operation *op, Type biasTy, Type dstTy, bool requireFloatBias) { if (failed(verifyTileBufCommon(op, biasTy, "bias"))) @@ -5625,7 +5700,9 @@ static bool isA5Fp8LikeType(Type ty) { } static bool isA5MxInputType(Type ty) { - return isA5Fp8LikeType(ty); + if (auto ft = dyn_cast(ty)) + return ft.isFloat8E4M3FN(); + return false; } static LogicalResult verifyA5MxTypeTriple(Operation *op, Type lhsTy, Type rhsTy, @@ -5635,10 +5712,12 @@ static LogicalResult verifyA5MxTypeTriple(Operation *op, Type lhsTy, Type rhsTy, Type rhsElem = getElemTy(rhsTy); Type dstElem = getElemTy(dstTy); - if (!isA5MxInputType(lhsElem) || !isA5MxInputType(rhsElem)) - return op->emitOpError() - << "expects A5 mx operands " << lhsName << " and " << rhsName - << " to use fp8 element types"; + if (!isA5MxInputType(lhsElem)) + return op->emitOpError() << lhsName << ": dtype " << lhsElem + << " is not supported by this op yet"; + if (!isA5MxInputType(rhsElem)) + return op->emitOpError() << rhsName << ": dtype " << rhsElem + << " is not supported by this op yet"; if (!dstElem.isF32()) return op->emitOpError() @@ -6609,11 +6688,11 @@ LogicalResult TGemvMxOp::verify() { getA().getType(), "a_scale", "a")) || failed(verifyScaleTileMatchesOperand(*this, getBScale().getType(), getB().getType(), "b_scale", "b")) || - failed(verifyGemvTileOperands(*this, getA().getType(), getB().getType(), - getDst().getType()))) + failed(verifyA5MxGemvTileOperands(*this, getA().getType(), getB().getType(), + getDst().getType()))) return failure(); if (failed(verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), - getDst().getType(), "a", "b", "dst"))) + getDst().getType(), "lhs", "rhs", "dst"))) return failure(); return verifyMatmulLike(*this, getA().getType(), getB().getType(), getDst().getType()); @@ -6631,11 +6710,11 @@ LogicalResult TGemvMxAccOp::verify() { getA().getType(), "a_scale", "a")) || failed(verifyScaleTileMatchesOperand(*this, getBScale().getType(), getB().getType(), "b_scale", "b")) || - failed(verifyGemvTileOperands(*this, getA().getType(), getB().getType(), - getDst().getType()))) + failed(verifyA5MxGemvTileOperands(*this, getA().getType(), getB().getType(), + getDst().getType()))) return failure(); if (failed(verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), - getDst().getType(), "a", "b", "dst"))) + getDst().getType(), "lhs", "rhs", "dst"))) return failure(); if (failed(verifyTileBufSameElemType(*this, getCIn().getType(), getDst().getType(), "c_in", "dst")) || @@ -6663,7 +6742,7 @@ LogicalResult TGemvMxBiasOp::verify() { /*requireFloatBias=*/true))) return failure(); if (failed(verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), - getDst().getType(), "a", "b", "dst"))) + getDst().getType(), "lhs", "rhs", "dst"))) return failure(); auto biasShape = getShapeVec(getBias().getType()); auto dstShape = getShapeVec(getDst().getType()); @@ -6710,7 +6789,7 @@ LogicalResult TMatmulMxOp::verify() { if (failed(verifyA2A3())) return failure(); return verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), - getDst().getType(), "a", "b", "dst"); + getDst().getType(), "lhs", "rhs", "dst"); }; return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); } @@ -6727,7 +6806,7 @@ LogicalResult TMatmulMxAccOp::verify() { if (failed(verifyA2A3())) return failure(); if (failed(verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), - getDst().getType(), "a", "b", "dst"))) + getDst().getType(), "lhs", "rhs", "dst"))) return failure(); if (failed(verifyTileBufSameElemType(*this, getCIn().getType(), getDst().getType(), "c_in", "dst")) || @@ -6754,7 +6833,7 @@ LogicalResult TMatmulMxBiasOp::verify() { if (failed(verifyA2A3())) return failure(); return verifyA5MxTypeTriple(*this, getA().getType(), getB().getType(), - getDst().getType(), "a", "b", "dst"); + getDst().getType(), "lhs", "rhs", "dst"); }; return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); } diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index b3f0c6bbd..4d2b0f6a5 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -4520,8 +4520,8 @@ struct PTOTStoreToTSTORE : public OpConversionPattern { // Render `pto.tmatmul` as one of three forms depending on the optional // `acc_phase` attribute: // * absent / Unspecified -> `TMATMUL(dst, lhs, rhs)` -// * Partial -> `TMATMUL(dst, lhs, rhs)` -// * Final -> `TMATMUL(dst, lhs, rhs)` +// * Partial -> `TMATMUL(dst, lhs, rhs)` +// * Final -> `TMATMUL(dst, lhs, rhs)` // The Unspecified default keeps backward compatibility with all upstream IR // that does not yet emit an explicit phase attribute. static ArrayAttr buildAccPhaseTemplateArgs(ConversionPatternRewriter &rewriter, @@ -4531,10 +4531,10 @@ static ArrayAttr buildAccPhaseTemplateArgs(ConversionPatternRewriter &rewriter, case pto::AccPhase::Unspecified: return ArrayAttr{}; case pto::AccPhase::Partial: - tmpl = "AccPhase::Partial"; + tmpl = "pto::AccPhase::Partial"; break; case pto::AccPhase::Final: - tmpl = "AccPhase::Final"; + tmpl = "pto::AccPhase::Final"; break; } if (tmpl.empty()) @@ -4585,10 +4585,12 @@ struct PTOTGemvToTGEMV : public OpConversionPattern { Value rhs = peelUnrealized(adaptor.getRhs()); // B (Vector) Value dst = peelUnrealized(adaptor.getDst()); // C (Result) - // 2. 直接生成函数调用 TGEMV(dst, lhs, rhs) + ArrayAttr templateArgs = + buildAccPhaseTemplateArgs(rewriter, op.getAccPhase()); + rewriter.create( op.getLoc(), TypeRange{}, "TGEMV", - ArrayAttr{}, ArrayAttr{}, + /*args=*/ArrayAttr{}, /*template_args=*/templateArgs, ValueRange{dst, lhs, rhs}); // 3. 处理 Op 替换/删除 @@ -4618,10 +4620,12 @@ struct PTOTGemvAccToTGEMVACC : public OpConversionPattern { Value rhs = peelUnrealized(adaptor.getRhs()); // B (Vector) Value dst = peelUnrealized(adaptor.getDst()); // AccNew - // 2. 直接生成函数调用 TGEMV_ACC(dst, accIn, lhs, rhs) + ArrayAttr templateArgs = + buildAccPhaseTemplateArgs(rewriter, op.getAccPhase()); + rewriter.create( op.getLoc(), TypeRange{}, "TGEMV_ACC", - ArrayAttr{}, ArrayAttr{}, + /*args=*/ArrayAttr{}, /*template_args=*/templateArgs, ValueRange{dst, accIn, lhs, rhs}); // 3. 处理 Op 替换/删除 @@ -7683,15 +7687,31 @@ struct PTOColExpandDivToEmitC : public OpConversionPattern LogicalResult matchAndRewrite(pto::TColExpandDivOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); Value src0 = peelUnrealized(adaptor.getSrc0()); Value src1 = peelUnrealized(adaptor.getSrc1()); Value dst = peelUnrealized(adaptor.getDst()); + ArrayAttr templateArgs; + if (op.getPrecisionType() != pto::DivPrecision::Default) { + StringRef precisionTok; + switch (op.getPrecisionType()) { + case pto::DivPrecision::Default: + precisionTok = "pto::DivAlgorithm::DEFAULT"; + break; + case pto::DivPrecision::HighPrecision: + precisionTok = "pto::DivAlgorithm::HIGH_PRECISION"; + break; + } + templateArgs = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, precisionTok)}); + } + rewriter.create( loc, TypeRange{}, "TCOLEXPANDDIV", /*args=*/ArrayAttr{}, - /*templateArgs=*/ArrayAttr{}, + /*templateArgs=*/templateArgs, /*operands=*/ValueRange{dst, src0, src1}); rewriter.eraseOp(op); @@ -8136,10 +8156,26 @@ struct PTODivToTDIV : public OpConversionPattern { Value src0 = peelUnrealized(adaptor.getSrc0()); Value src1 = peelUnrealized(adaptor.getSrc1()); Value dst = peelUnrealized(adaptor.getDst()); + auto *ctx = rewriter.getContext(); + + ArrayAttr templateArgs; + if (op.getPrecisionType() != pto::DivPrecision::Default) { + StringRef precisionTok; + switch (op.getPrecisionType()) { + case pto::DivPrecision::Default: + precisionTok = "pto::DivAlgorithm::DEFAULT"; + break; + case pto::DivPrecision::HighPrecision: + precisionTok = "pto::DivAlgorithm::HIGH_PRECISION"; + break; + } + templateArgs = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, precisionTok)}); + } rewriter.create( op.getLoc(), TypeRange{}, "TDIV", - ArrayAttr{}, ArrayAttr{}, + ArrayAttr{}, templateArgs, ValueRange{dst, src0, src1}); rewriter.eraseOp(op); @@ -8210,13 +8246,29 @@ struct PTOExpToEmitC : public OpConversionPattern { LogicalResult matchAndRewrite(pto::TExpOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); Value src = peelUnrealized(adaptor.getSrc()); Value dst = peelUnrealized(adaptor.getDst()); + ArrayAttr templateArgs; + if (op.getPrecisionType() != pto::ExpPrecision::Default) { + StringRef precisionTok; + switch (op.getPrecisionType()) { + case pto::ExpPrecision::Default: + precisionTok = "pto::ExpAlgorithm::DEFAULT"; + break; + case pto::ExpPrecision::HighPrecision: + precisionTok = "pto::ExpAlgorithm::HIGH_PRECISION"; + break; + } + templateArgs = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, precisionTok)}); + } + rewriter.create( loc, TypeRange{}, "TEXP", - ArrayAttr{}, ArrayAttr{}, + ArrayAttr{}, templateArgs, ValueRange{dst, src}); rewriter.eraseOp(op); @@ -8567,14 +8619,29 @@ struct PTOLogToEmitC : public OpConversionPattern { LogicalResult matchAndRewrite(pto::TLogOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); Value src = peelUnrealized(adaptor.getSrc()); Value dst = peelUnrealized(adaptor.getDst()); SmallVector operands{dst, src}; + ArrayAttr templateArgs; + if (op.getPrecisionType() != pto::LogPrecision::Default) { + StringRef precisionTok; + switch (op.getPrecisionType()) { + case pto::LogPrecision::Default: + precisionTok = "pto::LogAlgorithm::DEFAULT"; + break; + case pto::LogPrecision::HighPrecision: + precisionTok = "pto::LogAlgorithm::HIGH_PRECISION"; + break; + } + templateArgs = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, precisionTok)}); + } rewriter.create( loc, TypeRange{}, "TLOG", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, /*operands=*/operands); rewriter.eraseOp(op); @@ -9374,14 +9441,29 @@ struct PTORecipToEmitC : public OpConversionPattern { LogicalResult matchAndRewrite(pto::TRecipOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); Value src = peelUnrealized(adaptor.getSrc()); Value dst = peelUnrealized(adaptor.getDst()); SmallVector operands{dst, src}; + ArrayAttr templateArgs; + if (op.getPrecisionType() != pto::RecipPrecision::Default) { + StringRef precisionTok; + switch (op.getPrecisionType()) { + case pto::RecipPrecision::Default: + precisionTok = "pto::RecipAlgorithm::DEFAULT"; + break; + case pto::RecipPrecision::HighPrecision: + precisionTok = "pto::RecipAlgorithm::HIGH_PRECISION"; + break; + } + templateArgs = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, precisionTok)}); + } rewriter.create( loc, TypeRange{}, "TRECIP", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, /*operands=*/operands); rewriter.eraseOp(op); @@ -9422,15 +9504,30 @@ struct PTORemToEmitC : public OpConversionPattern { LogicalResult matchAndRewrite(pto::TRemOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); Value src0 = peelUnrealized(adaptor.getSrc0()); Value src1 = peelUnrealized(adaptor.getSrc1()); Value tmp = peelUnrealized(adaptor.getTmp()); Value dst = peelUnrealized(adaptor.getDst()); SmallVector operands{dst, src0, src1, tmp}; + ArrayAttr templateArgs; + if (op.getPrecisionType() != pto::RemPrecision::Default) { + StringRef precisionTok; + switch (op.getPrecisionType()) { + case pto::RemPrecision::Default: + precisionTok = "pto::RemAlgorithm::DEFAULT"; + break; + case pto::RemPrecision::HighPrecision: + precisionTok = "pto::RemAlgorithm::HIGH_PRECISION"; + break; + } + templateArgs = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, precisionTok)}); + } rewriter.create( loc, TypeRange{}, "TREM", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, /*operands=*/operands); rewriter.eraseOp(op); @@ -9444,15 +9541,30 @@ struct PTOFModToEmitC : public OpConversionPattern { LogicalResult matchAndRewrite(pto::TFModOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); Value src0 = peelUnrealized(adaptor.getSrc0()); Value src1 = peelUnrealized(adaptor.getSrc1()); Value dst = peelUnrealized(adaptor.getDst()); SmallVector operands{dst, src0, src1}; + ArrayAttr templateArgs; + if (op.getPrecisionType() != pto::FmodPrecision::Default) { + StringRef precisionTok; + switch (op.getPrecisionType()) { + case pto::FmodPrecision::Default: + precisionTok = "pto::FmodAlgorithm::DEFAULT"; + break; + case pto::FmodPrecision::HighPrecision: + precisionTok = "pto::FmodAlgorithm::HIGH_PRECISION"; + break; + } + templateArgs = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, precisionTok)}); + } rewriter.create( loc, TypeRange{}, "TFMOD", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, /*operands=*/operands); rewriter.eraseOp(op); @@ -9586,25 +9698,13 @@ struct PTORowExpandExpdifToEmitC // PTOConvert.cpp (add lowering + patterns.add for TROWEXPANDDIV DPS/memref op) //===----------------------------------------------------------------------===// // Helper: replace or erase based on whether op has results. -static void replaceOrEraseWithOpaqueCall(Operation *op, - StringRef callee, - ArrayRef args, - ConversionPatternRewriter &rewriter) { - TypeRange resultTypes = op->getResultTypes(); - auto call = rewriter.create( - op->getLoc(), resultTypes, callee, ArrayAttr{}, ArrayAttr{}, ValueRange(args)); - if (resultTypes.empty()) - rewriter.eraseOp(op); - else - rewriter.replaceOp(op, call.getResults()); -} - static void replaceOrEraseWithOpaqueCallAndReturnDst(Operation *op, Value dst, StringRef callee, ArrayRef args, + ArrayAttr templateArgs, ConversionPatternRewriter &rewriter) { rewriter.create( - op->getLoc(), TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, ValueRange(args)); + op->getLoc(), TypeRange{}, callee, ArrayAttr{}, templateArgs, ValueRange(args)); if (op->getNumResults() == 1) rewriter.replaceOp(op, dst); else @@ -9623,8 +9723,10 @@ struct PTOTGemvBiasToTGEMV_BIAS Value bias = peelUnrealized(adaptor.getBias()); Value dst = peelUnrealized(adaptor.getDst()); - replaceOrEraseWithOpaqueCall(op.getOperation(), "TGEMV_BIAS", - {dst, a, b, bias}, rewriter); + ArrayAttr templateArgs = + buildAccPhaseTemplateArgs(rewriter, op.getAccPhase()); + replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TGEMV_BIAS", + {dst, a, b, bias}, templateArgs, rewriter); return success(); } }; @@ -9641,8 +9743,11 @@ struct PTOTGemvMXToTGEMV_MX Value bScale = peelUnrealized(adaptor.getBScale()); Value dst = peelUnrealized(adaptor.getDst()); + ArrayAttr templateArgs = + buildAccPhaseTemplateArgs(rewriter, op.getAccPhase()); replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TGEMV_MX", - {dst, a, aScale, b, bScale}, rewriter); + {dst, a, aScale, b, bScale}, templateArgs, + rewriter); return success(); } }; @@ -9660,8 +9765,11 @@ struct PTOTGemvMXAccToTGEMV_MX Value bScale = peelUnrealized(adaptor.getBScale()); Value dst = peelUnrealized(adaptor.getDst()); + ArrayAttr templateArgs = + buildAccPhaseTemplateArgs(rewriter, op.getAccPhase()); replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TGEMV_MX", - {dst, cIn, a, aScale, b, bScale}, rewriter); + {dst, cIn, a, aScale, b, bScale}, templateArgs, + rewriter); return success(); } }; @@ -9680,7 +9788,8 @@ struct PTOTGemvMXBiasToTGEMV_MX Value dst = peelUnrealized(adaptor.getDst()); replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TGEMV_MX", - {dst, a, aScale, b, bScale, bias}, rewriter); + {dst, a, aScale, b, bScale, bias}, ArrayAttr{}, + rewriter); return success(); } }; @@ -9696,8 +9805,10 @@ struct PTOTMatmulBiasToTMATMUL_BIAS Value bias = peelUnrealized(adaptor.getBias()); Value dst = peelUnrealized(adaptor.getDst()); - replaceOrEraseWithOpaqueCall(op.getOperation(), "TMATMUL_BIAS", - {dst, a, b, bias}, rewriter); + ArrayAttr templateArgs = + buildAccPhaseTemplateArgs(rewriter, op.getAccPhase()); + replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TMATMUL_BIAS", + {dst, a, b, bias}, templateArgs, rewriter); return success(); } }; @@ -9714,8 +9825,11 @@ struct PTOTMatmulMXToTMATMUL_MX Value bScale = peelUnrealized(adaptor.getBScale()); Value dst = peelUnrealized(adaptor.getDst()); - replaceOrEraseWithOpaqueCall(op.getOperation(), "TMATMUL_MX", - {dst, a, aScale, b, bScale}, rewriter); + ArrayAttr templateArgs = + buildAccPhaseTemplateArgs(rewriter, op.getAccPhase()); + replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TMATMUL_MX", + {dst, a, aScale, b, bScale}, templateArgs, + rewriter); return success(); } }; @@ -9733,8 +9847,11 @@ struct PTOTMatmulMXAccToTMATMUL_MX_ACC Value bScale = peelUnrealized(adaptor.getBScale()); Value dst = peelUnrealized(adaptor.getDst()); - replaceOrEraseWithOpaqueCall(op.getOperation(), "TMATMUL_MX", - {dst, cIn, a, aScale, b, bScale}, rewriter); + ArrayAttr templateArgs = + buildAccPhaseTemplateArgs(rewriter, op.getAccPhase()); + replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TMATMUL_MX", + {dst, cIn, a, aScale, b, bScale}, templateArgs, + rewriter); return success(); } }; @@ -9752,8 +9869,9 @@ struct PTOTMatmulMXBiasToTMATMUL_MX_BIAS Value bias = peelUnrealized(adaptor.getBias()); Value dst = peelUnrealized(adaptor.getDst()); - replaceOrEraseWithOpaqueCall(op.getOperation(), "TMATMUL_MX", - {dst, a, aScale, b, bScale, bias}, rewriter); + replaceOrEraseWithOpaqueCallAndReturnDst(op.getOperation(), dst, "TMATMUL_MX", + {dst, a, aScale, b, bScale, bias}, ArrayAttr{}, + rewriter); return success(); } }; @@ -9764,6 +9882,7 @@ struct PTORowExpandDivToEmitC : public OpConversionPattern LogicalResult matchAndRewrite(pto::TRowExpandDivOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); Value src0 = peelUnrealized(adaptor.getSrc0()); Value src1 = peelUnrealized(adaptor.getSrc1()); @@ -9775,9 +9894,23 @@ struct PTORowExpandDivToEmitC : public OpConversionPattern operands.assign({dst, src0, src1, tmp}); else operands.assign({dst, src0, src1}); + ArrayAttr templateArgs; + if (op.getPrecisionType() != pto::DivPrecision::Default) { + StringRef precisionTok; + switch (op.getPrecisionType()) { + case pto::DivPrecision::Default: + precisionTok = "pto::DivAlgorithm::DEFAULT"; + break; + case pto::DivPrecision::HighPrecision: + precisionTok = "pto::DivAlgorithm::HIGH_PRECISION"; + break; + } + templateArgs = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, precisionTok)}); + } rewriter.create( loc, TypeRange{}, "TROWEXPANDDIV", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, /*operands=*/operands); rewriter.eraseOp(op); @@ -10054,15 +10187,30 @@ struct PTORsqrtToEmitC : public OpConversionPattern { LogicalResult matchAndRewrite(pto::TRsqrtOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); Value src = peelUnrealized(adaptor.getSrc()); Value dst = peelUnrealized(adaptor.getDst()); SmallVector operands{dst, src}; if (Value tmp = adaptor.getTmp()) operands.push_back(peelUnrealized(tmp)); + ArrayAttr templateArgs; + if (op.getPrecisionType() != pto::RsqrtPrecision::Default) { + StringRef precisionTok; + switch (op.getPrecisionType()) { + case pto::RsqrtPrecision::Default: + precisionTok = "pto::RsqrtAlgorithm::DEFAULT"; + break; + case pto::RsqrtPrecision::HighPrecision: + precisionTok = "pto::RsqrtAlgorithm::HIGH_PRECISION"; + break; + } + templateArgs = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, precisionTok)}); + } rewriter.create( loc, TypeRange{}, "TRSQRT", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, /*operands=*/operands); rewriter.eraseOp(op); @@ -10298,14 +10446,29 @@ struct PTOSqrtSToEmitC : public OpConversionPattern { LogicalResult matchAndRewrite(pto::TSqrtOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); Value src = peelUnrealized(adaptor.getSrc()); Value dst = peelUnrealized(adaptor.getDst()); SmallVector operands{dst, src}; + ArrayAttr templateArgs; + if (op.getPrecisionType() != pto::SqrtPrecision::Default) { + StringRef precisionTok; + switch (op.getPrecisionType()) { + case pto::SqrtPrecision::Default: + precisionTok = "pto::SqrtAlgorithm::DEFAULT"; + break; + case pto::SqrtPrecision::HighPrecision: + precisionTok = "pto::SqrtAlgorithm::HIGH_PRECISION"; + break; + } + templateArgs = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, precisionTok)}); + } rewriter.create( loc, TypeRange{}, "TSQRT", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, /*operands=*/operands); rewriter.eraseOp(op); diff --git a/lib/PTO/Transforms/PTOViewToMemref.cpp b/lib/PTO/Transforms/PTOViewToMemref.cpp index c21669b81..075d37a7c 100644 --- a/lib/PTO/Transforms/PTOViewToMemref.cpp +++ b/lib/PTO/Transforms/PTOViewToMemref.cpp @@ -1864,7 +1864,8 @@ struct PTOViewToMemrefPass IRRewriter rewriter(ctx); rewriter.setInsertionPoint(op); rewriter.replaceOpWithNewOp( - op, TypeRange{}, op->getOperand(0), op->getOperand(1)); + op, TypeRange{}, op->getOperand(0), op->getOperand(1), + op.getPrecisionTypeAttr()); } // --- TMulOp [Src, Scalar, Dst] --- @@ -1939,7 +1940,7 @@ struct PTOViewToMemrefPass op, TypeRange{}, op->getOperand(0), op->getOperand(1), op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex)); + op->getOperand(kFourthOperandIndex), op.getAccPhaseAttr()); } // --- TMatmulMxOp--- @@ -1953,7 +1954,7 @@ struct PTOViewToMemrefPass op->getOperand(0), op->getOperand(1), op->getOperand(kThirdOperandIndex), op->getOperand(kFourthOperandIndex), - op->getOperand(kFifthOperandIndex)); + op->getOperand(kFifthOperandIndex), op.getAccPhaseAttr()); } // --- TMatmulMxAccOp --- @@ -1968,7 +1969,7 @@ struct PTOViewToMemrefPass op->getOperand(kThirdOperandIndex), op->getOperand(kFourthOperandIndex), op->getOperand(kFifthOperandIndex), - op->getOperand(kSixthOperandIndex)); + op->getOperand(kSixthOperandIndex), op.getAccPhaseAttr()); } // --- TMatmulMxBiasOp --- @@ -1998,7 +1999,7 @@ struct PTOViewToMemrefPass Value dst = op->getOperand(kThirdOperandIndex); rewriter.replaceOpWithNewOp( - op, TypeRange{}, lhs, rhs, dst); + op, TypeRange{}, lhs, rhs, dst, op.getAccPhaseAttr()); } // --- TGemvAccOp [Acc, Lhs, Rhs, Dst] --- @@ -2011,7 +2012,7 @@ struct PTOViewToMemrefPass op, TypeRange{}, op->getOperand(0), op->getOperand(1), op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex)); + op->getOperand(kFourthOperandIndex), op.getAccPhaseAttr()); } // --- TGemvBiasOp [Acc, Lhs, Rhs, Bias, Dst] --- @@ -2024,7 +2025,7 @@ struct PTOViewToMemrefPass op, TypeRange{}, op->getOperand(0), op->getOperand(1), op->getOperand(kThirdOperandIndex), - op->getOperand(kFourthOperandIndex)); + op->getOperand(kFourthOperandIndex), op.getAccPhaseAttr()); } // --- TGemvMxOp [A, AScale, B, BScale, Dst] --- @@ -2038,7 +2039,7 @@ struct PTOViewToMemrefPass op->getOperand(0), op->getOperand(1), op->getOperand(kThirdOperandIndex), op->getOperand(kFourthOperandIndex), - op->getOperand(kFifthOperandIndex)); + op->getOperand(kFifthOperandIndex), op.getAccPhaseAttr()); } // --- TGemvMxAccOp [CIn, A, AScale, B, BScale, Dst] --- @@ -2053,7 +2054,7 @@ struct PTOViewToMemrefPass op->getOperand(kThirdOperandIndex), op->getOperand(kFourthOperandIndex), op->getOperand(kFifthOperandIndex), - op->getOperand(kSixthOperandIndex)); + op->getOperand(kSixthOperandIndex), op.getAccPhaseAttr()); } // --- TGemvMxBiasOp [A, AScale, B, BScale, Bias, Dst] --- @@ -2694,7 +2695,8 @@ struct PTOViewToMemrefPass TypeRange{}, src0, src1, - dst); + dst, + op.getPrecisionTypeAttr()); } DefaultInlineVector divsops; @@ -3059,7 +3061,8 @@ struct PTOViewToMemrefPass op, TypeRange{}, src, - dst); + dst, + op.getPrecisionTypeAttr()); } DefaultInlineVector lreluops; diff --git a/test/lit/pto/tcolexpanddiv_precision_emitc.pto b/test/lit/pto/tcolexpanddiv_precision_emitc.pto new file mode 100644 index 000000000..ea46d994f --- /dev/null +++ b/test/lit/pto/tcolexpanddiv_precision_emitc.pto @@ -0,0 +1,19 @@ +// RUN: ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s --check-prefix=A5 + +module { + func.func @tcolexpanddiv_precision_emitc() { + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tcolexpanddiv ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tcolexpanddiv ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) + outs(%dst : !pto.tile_buf) + {precisionType = #pto} + return + } +} + +// A5: TCOLEXPANDDIV([[VDST:v[0-9]+]], [[VSRC0:v[0-9]+]], [[VSRC1:v[0-9]+]]); +// A5: TCOLEXPANDDIV([[VDST]], [[VSRC0]], [[VSRC1]]); diff --git a/test/lit/pto/tdiv_precision_emitc.pto b/test/lit/pto/tdiv_precision_emitc.pto new file mode 100644 index 000000000..16aad222d --- /dev/null +++ b/test/lit/pto/tdiv_precision_emitc.pto @@ -0,0 +1,21 @@ +// RUN: ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s --check-prefix=A5 +// RUN: ptoas --pto-arch=a5 --enable-insert-sync %s 2>&1 | FileCheck %s --check-prefix=A5 + +module { + func.func @tdiv_precision_emitc() { + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tdiv ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tdiv ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) + outs(%dst : !pto.tile_buf) + {precisionType = #pto} + return + } +} + +// A5: TDIV([[VDST:v[0-9]+]], [[VSRC0:v[0-9]+]], [[VSRC1:v[0-9]+]]); +// A5: TDIV([[VDST]], [[VSRC0]], [[VSRC1]]); diff --git a/test/lit/pto/texp_precision_emitc.pto b/test/lit/pto/texp_precision_emitc.pto new file mode 100644 index 000000000..68f300115 --- /dev/null +++ b/test/lit/pto/texp_precision_emitc.pto @@ -0,0 +1,18 @@ +// RUN: ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s --check-prefix=A5 + +module { + func.func @texp_precision_emitc() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.texp ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.texp ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + {precisionType = #pto} + return + } +} + +// A5: TEXP([[VDST:v[0-9]+]], [[VSRC:v[0-9]+]]); +// A5: TEXP([[VDST]], [[VSRC]]); diff --git a/test/lit/pto/tfmod_precision_emitc.pto b/test/lit/pto/tfmod_precision_emitc.pto new file mode 100644 index 000000000..bbd94ebc2 --- /dev/null +++ b/test/lit/pto/tfmod_precision_emitc.pto @@ -0,0 +1,19 @@ +// RUN: ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s --check-prefix=A5 + +module { + func.func @tfmod_precision_emitc() { + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tfmod ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tfmod ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) + outs(%dst : !pto.tile_buf) + {precisionType = #pto} + return + } +} + +// A5: TFMOD([[VDST:v[0-9]+]], [[VSRC0:v[0-9]+]], [[VSRC1:v[0-9]+]]); +// A5: TFMOD([[VDST]], [[VSRC0]], [[VSRC1]]); diff --git a/test/lit/pto/tgemv_accphase_emitc.pto b/test/lit/pto/tgemv_accphase_emitc.pto new file mode 100644 index 000000000..a7aa3ce88 --- /dev/null +++ b/test/lit/pto/tgemv_accphase_emitc.pto @@ -0,0 +1,34 @@ +// RUN: ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s --check-prefix=A5 + +module attributes {pto.target_arch = "a5"} { + func.func @tgemv_accphase_emitc() attributes {pto.kernel_kind = #pto.kernel_kind} { + %lhs = pto.alloc_tile : !pto.tile_buf + %rhs = pto.alloc_tile : !pto.tile_buf + %acc_in = pto.alloc_tile : !pto.tile_buf + %bias = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tgemv ins(%lhs, %rhs : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) + pto.tgemv ins(%lhs, %rhs : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {accPhase = #pto} + pto.tgemv ins(%lhs, %rhs : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {accPhase = #pto} + + pto.tgemv.acc ins(%acc_in, %lhs, %rhs : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) + pto.tgemv.acc ins(%acc_in, %lhs, %rhs : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {accPhase = #pto} + pto.tgemv.acc ins(%acc_in, %lhs, %rhs : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {accPhase = #pto} + + pto.tgemv.bias ins(%lhs, %rhs, %bias : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) + pto.tgemv.bias ins(%lhs, %rhs, %bias : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {accPhase = #pto} + pto.tgemv.bias ins(%lhs, %rhs, %bias : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {accPhase = #pto} + return + } +} + +// A5: TGEMV([[D0:v[0-9]+]], [[L0:v[0-9]+]], [[R0:v[0-9]+]]); +// A5: TGEMV([[D0]], [[L0]], [[R0]]); +// A5: TGEMV([[D0]], [[L0]], [[R0]]); +// A5: TGEMV_ACC([[D1:v[0-9]+]], [[CIN:v[0-9]+]], [[L1:v[0-9]+]], [[R1:v[0-9]+]]); +// A5: TGEMV_ACC([[D1]], [[CIN]], [[L1]], [[R1]]); +// A5: TGEMV_ACC([[D1]], [[CIN]], [[L1]], [[R1]]); +// A5: TGEMV_BIAS([[D2:v[0-9]+]], [[L2:v[0-9]+]], [[R2:v[0-9]+]], [[B2:v[0-9]+]]); +// A5: TGEMV_BIAS([[D2]], [[L2]], [[R2]], [[B2]]); +// A5: TGEMV_BIAS([[D2]], [[L2]], [[R2]], [[B2]]); diff --git a/test/lit/pto/tgemv_mx_accphase_emitc.pto b/test/lit/pto/tgemv_mx_accphase_emitc.pto new file mode 100644 index 000000000..dc1b70601 --- /dev/null +++ b/test/lit/pto/tgemv_mx_accphase_emitc.pto @@ -0,0 +1,32 @@ +// RUN: ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s --check-prefix=A5 + +module attributes {pto.target_arch = "a5"} { + func.func @tgemv_mx_accphase_emitc() attributes {pto.kernel_kind = #pto.kernel_kind} { + %a = pto.alloc_tile : !pto.tile_buf + %b = pto.alloc_tile : !pto.tile_buf + %a_scale = pto.alloc_tile : !pto.tile_buf + %b_scale = pto.alloc_tile : !pto.tile_buf + %c_in = pto.alloc_tile : !pto.tile_buf + %dst0 = pto.alloc_tile : !pto.tile_buf + %dst1 = pto.alloc_tile : !pto.tile_buf + %dst2 = pto.alloc_tile : !pto.tile_buf + %dst3 = pto.alloc_tile : !pto.tile_buf + %dst4 = pto.alloc_tile : !pto.tile_buf + %dst5 = pto.alloc_tile : !pto.tile_buf + + pto.tgemv.mx ins(%a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst0 : !pto.tile_buf) + pto.tgemv.mx ins(%a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst1 : !pto.tile_buf) {accPhase = #pto} + pto.tgemv.mx ins(%a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst2 : !pto.tile_buf) {accPhase = #pto} + pto.tgemv.mx.acc ins(%c_in, %a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst3 : !pto.tile_buf) + pto.tgemv.mx.acc ins(%c_in, %a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst4 : !pto.tile_buf) {accPhase = #pto} + pto.tgemv.mx.acc ins(%c_in, %a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst5 : !pto.tile_buf) {accPhase = #pto} + return + } +} + +// A5: TGEMV_MX([[D0:v[0-9]+]], [[A0:v[0-9]+]], [[AS0:v[0-9]+]], [[B0:v[0-9]+]], [[BS0:v[0-9]+]]); +// A5: TGEMV_MX([[D1:v[0-9]+]], [[A1:v[0-9]+]], [[AS1:v[0-9]+]], [[B1:v[0-9]+]], [[BS1:v[0-9]+]]); +// A5: TGEMV_MX([[D2:v[0-9]+]], [[A2:v[0-9]+]], [[AS2:v[0-9]+]], [[B2:v[0-9]+]], [[BS2:v[0-9]+]]); +// A5: TGEMV_MX([[D3:v[0-9]+]], [[C3:v[0-9]+]], [[A3:v[0-9]+]], [[AS3:v[0-9]+]], [[B3:v[0-9]+]], [[BS3:v[0-9]+]]); +// A5: TGEMV_MX([[D4:v[0-9]+]], [[C4:v[0-9]+]], [[A4:v[0-9]+]], [[AS4:v[0-9]+]], [[B4:v[0-9]+]], [[BS4:v[0-9]+]]); +// A5: TGEMV_MX([[D5:v[0-9]+]], [[C5:v[0-9]+]], [[A5:v[0-9]+]], [[AS5:v[0-9]+]], [[B5:v[0-9]+]], [[BS5:v[0-9]+]]); diff --git a/test/lit/pto/tgemv_mx_verify_invalid_shape_a5.pto b/test/lit/pto/tgemv_mx_verify_invalid_shape_a5.pto new file mode 100644 index 000000000..a0f4da6fe --- /dev/null +++ b/test/lit/pto/tgemv_mx_verify_invalid_shape_a5.pto @@ -0,0 +1,15 @@ +// RUN: not ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s + +module attributes {pto.target_arch = "a5"} { + func.func @tgemv_mx_invalid_static_shape(%a : !pto.tile_buf, + %b : !pto.tile_buf) attributes {pto.kernel_kind = #pto.kernel_kind} { + %a_scale = pto.alloc_tile : !pto.tile_buf + %b_scale = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tgemv.mx ins(%a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) + return + } +} + +// CHECK: error: 'pto.tgemv.mx' op expects static matmul tile shapes lhs[M,K], rhs[K,N], and dst[M,N] diff --git a/test/lit/pto/tlog_precision_emitc.pto b/test/lit/pto/tlog_precision_emitc.pto new file mode 100644 index 000000000..f50d1aefb --- /dev/null +++ b/test/lit/pto/tlog_precision_emitc.pto @@ -0,0 +1,18 @@ +// RUN: ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s --check-prefix=A5 + +module { + func.func @tlog_precision_emitc() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tlog ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tlog ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + {precisionType = #pto} + return + } +} + +// A5: TLOG([[VDST:v[0-9]+]], [[VSRC:v[0-9]+]]); +// A5: TLOG([[VDST]], [[VSRC]]); diff --git a/test/lit/pto/tmatmul_acc_phase_emitc.pto b/test/lit/pto/tmatmul_acc_phase_emitc.pto index 868a991dc..41dad0049 100644 --- a/test/lit/pto/tmatmul_acc_phase_emitc.pto +++ b/test/lit/pto/tmatmul_acc_phase_emitc.pto @@ -49,8 +49,8 @@ module { } // A3: TMATMUL([[C:v[0-9]+]], [[A:v[0-9]+]], [[B:v[0-9]+]]); -// A3: TMATMUL([[C]], [[A]], [[B]]); -// A3: TMATMUL([[C]], [[A]], [[B]]); +// A3: TMATMUL([[C]], [[A]], [[B]]); +// A3: TMATMUL([[C]], [[A]], [[B]]); // A3: TMATMUL_ACC([[C]], [[C]], [[A]], [[B]]); -// A3: TMATMUL_ACC([[C]], [[C]], [[A]], [[B]]); -// A3: TMATMUL_ACC([[C]], [[C]], [[A]], [[B]]); +// A3: TMATMUL_ACC([[C]], [[C]], [[A]], [[B]]); +// A3: TMATMUL_ACC([[C]], [[C]], [[A]], [[B]]); diff --git a/test/lit/pto/tmatmul_accphase_emitc.pto b/test/lit/pto/tmatmul_accphase_emitc.pto new file mode 100644 index 000000000..ee02bb197 --- /dev/null +++ b/test/lit/pto/tmatmul_accphase_emitc.pto @@ -0,0 +1,34 @@ +// RUN: ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s --check-prefix=A5 + +module attributes {pto.target_arch = "a5"} { + func.func @tmatmul_accphase_emitc() attributes {pto.kernel_kind = #pto.kernel_kind} { + %lhs = pto.alloc_tile : !pto.tile_buf + %rhs = pto.alloc_tile : !pto.tile_buf + %acc_in = pto.alloc_tile : !pto.tile_buf + %bias = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tmatmul ins(%lhs, %rhs : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) + pto.tmatmul ins(%lhs, %rhs : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {accPhase = #pto} + pto.tmatmul ins(%lhs, %rhs : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {accPhase = #pto} + + pto.tmatmul.acc ins(%acc_in, %lhs, %rhs : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) + pto.tmatmul.acc ins(%acc_in, %lhs, %rhs : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {accPhase = #pto} + pto.tmatmul.acc ins(%acc_in, %lhs, %rhs : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {accPhase = #pto} + + pto.tmatmul.bias ins(%lhs, %rhs, %bias : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) + pto.tmatmul.bias ins(%lhs, %rhs, %bias : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {accPhase = #pto} + pto.tmatmul.bias ins(%lhs, %rhs, %bias : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {accPhase = #pto} + return + } +} + +// A5: TMATMUL([[D0:v[0-9]+]], [[L0:v[0-9]+]], [[R0:v[0-9]+]]); +// A5: TMATMUL([[D0]], [[L0]], [[R0]]); +// A5: TMATMUL([[D0]], [[L0]], [[R0]]); +// A5: TMATMUL_ACC([[D1:v[0-9]+]], [[CIN:v[0-9]+]], [[L1:v[0-9]+]], [[R1:v[0-9]+]]); +// A5: TMATMUL_ACC([[D1]], [[CIN]], [[L1]], [[R1]]); +// A5: TMATMUL_ACC([[D1]], [[CIN]], [[L1]], [[R1]]); +// A5: TMATMUL_BIAS([[D2:v[0-9]+]], [[L2:v[0-9]+]], [[R2:v[0-9]+]], [[B2:v[0-9]+]]); +// A5: TMATMUL_BIAS([[D2]], [[L2]], [[R2]], [[B2]]); +// A5: TMATMUL_BIAS([[D2]], [[L2]], [[R2]], [[B2]]); diff --git a/test/lit/pto/tmatmul_mx_accphase_emitc.pto b/test/lit/pto/tmatmul_mx_accphase_emitc.pto new file mode 100644 index 000000000..d62c8de84 --- /dev/null +++ b/test/lit/pto/tmatmul_mx_accphase_emitc.pto @@ -0,0 +1,31 @@ +// RUN: ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s --check-prefix=A5 + +module attributes {pto.target_arch = "a5"} { + func.func @tmatmul_mx_accphase_emitc(%a : !pto.tile_buf, + %b : !pto.tile_buf) attributes {pto.kernel_kind = #pto.kernel_kind} { + %a_scale = pto.alloc_tile : !pto.tile_buf + %b_scale = pto.alloc_tile : !pto.tile_buf + %c_in = pto.alloc_tile : !pto.tile_buf + %dst0 = pto.alloc_tile : !pto.tile_buf + %dst1 = pto.alloc_tile : !pto.tile_buf + %dst2 = pto.alloc_tile : !pto.tile_buf + %dst3 = pto.alloc_tile : !pto.tile_buf + %dst4 = pto.alloc_tile : !pto.tile_buf + %dst5 = pto.alloc_tile : !pto.tile_buf + + pto.tmatmul.mx ins(%a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst0 : !pto.tile_buf) + pto.tmatmul.mx ins(%a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst1 : !pto.tile_buf) {accPhase = #pto} + pto.tmatmul.mx ins(%a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst2 : !pto.tile_buf) {accPhase = #pto} + pto.tmatmul.mx.acc ins(%c_in, %a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst3 : !pto.tile_buf) + pto.tmatmul.mx.acc ins(%c_in, %a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst4 : !pto.tile_buf) {accPhase = #pto} + pto.tmatmul.mx.acc ins(%c_in, %a, %a_scale, %b, %b_scale : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%dst5 : !pto.tile_buf) {accPhase = #pto} + return + } +} + +// A5: TMATMUL_MX([[D0:v[0-9]+]], [[A0:v[0-9]+]], [[AS0:v[0-9]+]], [[B0:v[0-9]+]], [[BS0:v[0-9]+]]); +// A5: TMATMUL_MX([[D1:v[0-9]+]], [[A1:v[0-9]+]], [[AS1:v[0-9]+]], [[B1:v[0-9]+]], [[BS1:v[0-9]+]]); +// A5: TMATMUL_MX([[D2:v[0-9]+]], [[A2:v[0-9]+]], [[AS2:v[0-9]+]], [[B2:v[0-9]+]], [[BS2:v[0-9]+]]); +// A5: TMATMUL_MX([[D3:v[0-9]+]], [[C3:v[0-9]+]], [[A3:v[0-9]+]], [[AS3:v[0-9]+]], [[B3:v[0-9]+]], [[BS3:v[0-9]+]]); +// A5: TMATMUL_MX([[D4:v[0-9]+]], [[C4:v[0-9]+]], [[A4:v[0-9]+]], [[AS4:v[0-9]+]], [[B4:v[0-9]+]], [[BS4:v[0-9]+]]); +// A5: TMATMUL_MX([[D5:v[0-9]+]], [[C5:v[0-9]+]], [[A5:v[0-9]+]], [[AS5:v[0-9]+]], [[B5:v[0-9]+]], [[BS5:v[0-9]+]]); diff --git a/test/lit/pto/trecip_precision_emitc.pto b/test/lit/pto/trecip_precision_emitc.pto new file mode 100644 index 000000000..830d03d0e --- /dev/null +++ b/test/lit/pto/trecip_precision_emitc.pto @@ -0,0 +1,18 @@ +// RUN: ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s --check-prefix=A5 + +module { + func.func @trecip_precision_emitc() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.trecip ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.trecip ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + {precisionType = #pto} + return + } +} + +// A5: TRECIP([[VDST:v[0-9]+]], [[VSRC:v[0-9]+]]); +// A5: TRECIP([[VDST]], [[VSRC]]); diff --git a/test/lit/pto/trem_precision_emitc.pto b/test/lit/pto/trem_precision_emitc.pto new file mode 100644 index 000000000..462867bd4 --- /dev/null +++ b/test/lit/pto/trem_precision_emitc.pto @@ -0,0 +1,20 @@ +// RUN: ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s --check-prefix=A5 + +module { + func.func @trem_precision_emitc() { + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.trem ins(%src0, %src1, %tmp : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.trem ins(%src0, %src1, %tmp : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) + outs(%dst : !pto.tile_buf) + {precisionType = #pto} + return + } +} + +// A5: TREM([[VDST:v[0-9]+]], [[VSRC0:v[0-9]+]], [[VSRC1:v[0-9]+]], [[VTMP:v[0-9]+]]); +// A5: TREM([[VDST]], [[VSRC0]], [[VSRC1]], [[VTMP]]); diff --git a/test/lit/pto/trowexpanddiv_precision_emitc.pto b/test/lit/pto/trowexpanddiv_precision_emitc.pto new file mode 100644 index 000000000..6056f708f --- /dev/null +++ b/test/lit/pto/trowexpanddiv_precision_emitc.pto @@ -0,0 +1,20 @@ +// RUN: ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s --check-prefix=A5 + +module { + func.func @trowexpanddiv_precision_emitc() { + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.trowexpanddiv ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.trowexpanddiv ins(%src0, %src1, %tmp : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) + outs(%dst : !pto.tile_buf) + {precisionType = #pto} + return + } +} + +// A5: TROWEXPANDDIV([[VDST0:v[0-9]+]], [[VSRC00:v[0-9]+]], [[VSRC01:v[0-9]+]]); +// A5: TROWEXPANDDIV([[VDST1:v[0-9]+]], [[VSRC10:v[0-9]+]], [[VSRC11:v[0-9]+]], [[VTMP:v[0-9]+]]); diff --git a/test/lit/pto/trsqrt_precision_emitc.pto b/test/lit/pto/trsqrt_precision_emitc.pto new file mode 100644 index 000000000..0ff89ca39 --- /dev/null +++ b/test/lit/pto/trsqrt_precision_emitc.pto @@ -0,0 +1,19 @@ +// RUN: ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s --check-prefix=A5 + +module { + func.func @trsqrt_precision_emitc() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + + pto.trsqrt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.trsqrt ins(%src, %tmp : !pto.tile_buf, !pto.tile_buf) + outs(%dst : !pto.tile_buf) + {precisionType = #pto} + return + } +} + +// A5: TRSQRT([[VDST0:v[0-9]+]], [[VSRC0:v[0-9]+]]); +// A5: TRSQRT([[VDST1:v[0-9]+]], [[VSRC1:v[0-9]+]], [[VTMP:v[0-9]+]]); diff --git a/test/lit/pto/tsqrt_precision_emitc.pto b/test/lit/pto/tsqrt_precision_emitc.pto new file mode 100644 index 000000000..92a30a2ad --- /dev/null +++ b/test/lit/pto/tsqrt_precision_emitc.pto @@ -0,0 +1,18 @@ +// RUN: ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s --check-prefix=A5 + +module { + func.func @tsqrt_precision_emitc() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tsqrt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tsqrt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + {precisionType = #pto} + return + } +} + +// A5: TSQRT([[VDST:v[0-9]+]], [[VSRC:v[0-9]+]]); +// A5: TSQRT([[VDST]], [[VSRC]]);