Skip to content

Commit 06c6d10

Browse files
committed
Add bitwise ops and their lowering patterns
Signed-off-by: Roberto Raggi <roberto.raggi@gmail.com>
1 parent cb771cf commit 06c6d10

File tree

3 files changed

+130
-13
lines changed

3 files changed

+130
-13
lines changed

src/mlir/cxx/mlir/CxxOps.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,25 @@ def Cxx_NotEqualFOp : Cxx_Op<"nef"> {
418418
let results = (outs Cxx_BoolType:$result);
419419
}
420420

421+
// bitwise ops
422+
def Cxx_AndOp : Cxx_Op<"and"> {
423+
let arguments = (ins Cxx_IntegerType:$lhs, Cxx_IntegerType:$rhs);
424+
425+
let results = (outs Cxx_IntegerType:$result);
426+
}
427+
428+
def Cxx_OrOp : Cxx_Op<"or"> {
429+
let arguments = (ins Cxx_IntegerType:$lhs, Cxx_IntegerType:$rhs);
430+
431+
let results = (outs Cxx_IntegerType:$result);
432+
}
433+
434+
def Cxx_XorOp : Cxx_Op<"xor"> {
435+
let arguments = (ins Cxx_IntegerType:$lhs, Cxx_IntegerType:$rhs);
436+
437+
let results = (outs Cxx_IntegerType:$result);
438+
}
439+
421440
//
422441
// control flow ops
423442
//

src/mlir/cxx/mlir/codegen_expressions.cc

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1578,7 +1578,7 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast)
15781578
}
15791579

15801580
case TokenKind::T_LESS_LESS: {
1581-
if (control()->is_integral(ast->type)) {
1581+
if (control()->is_integral_or_unscoped_enum(ast->type)) {
15821582
auto op = mlir::cxx::ShiftLeftOp::create(gen.builder_, loc, resultType,
15831583
leftExpressionResult.value,
15841584
rightExpressionResult.value);
@@ -1589,7 +1589,7 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast)
15891589
}
15901590

15911591
case TokenKind::T_GREATER_GREATER: {
1592-
if (control()->is_integral(ast->type)) {
1592+
if (control()->is_integral_or_unscoped_enum(ast->type)) {
15931593
auto op = mlir::cxx::ShiftRightOp::create(gen.builder_, loc, resultType,
15941594
leftExpressionResult.value,
15951595
rightExpressionResult.value);
@@ -1600,7 +1600,7 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast)
16001600
}
16011601

16021602
case TokenKind::T_EQUAL_EQUAL: {
1603-
if (control()->is_integral(ast->type)) {
1603+
if (control()->is_integral_or_unscoped_enum(ast->type)) {
16041604
auto op = mlir::cxx::EqualOp::create(gen.builder_, loc, resultType,
16051605
leftExpressionResult.value,
16061606
rightExpressionResult.value);
@@ -1618,14 +1618,14 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast)
16181618
}
16191619

16201620
case TokenKind::T_EXCLAIM_EQUAL: {
1621-
if (control()->is_integral(ast->type)) {
1621+
if (control()->is_integral_or_unscoped_enum(ast->leftExpression->type)) {
16221622
auto op = mlir::cxx::NotEqualOp::create(gen.builder_, loc, resultType,
16231623
leftExpressionResult.value,
16241624
rightExpressionResult.value);
16251625
return {op};
16261626
}
16271627

1628-
if (control()->is_floating_point(ast->type)) {
1628+
if (control()->is_floating_point(ast->leftExpression->type)) {
16291629
auto op = mlir::cxx::NotEqualFOp::create(gen.builder_, loc, resultType,
16301630
leftExpressionResult.value,
16311631
rightExpressionResult.value);
@@ -1636,14 +1636,14 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast)
16361636
}
16371637

16381638
case TokenKind::T_LESS: {
1639-
if (control()->is_integral(ast->type)) {
1639+
if (control()->is_integral_or_unscoped_enum(ast->leftExpression->type)) {
16401640
auto op = mlir::cxx::LessThanOp::create(gen.builder_, loc, resultType,
16411641
leftExpressionResult.value,
16421642
rightExpressionResult.value);
16431643
return {op};
16441644
}
16451645

1646-
if (control()->is_floating_point(ast->type)) {
1646+
if (control()->is_floating_point(ast->leftExpression->type)) {
16471647
auto op = mlir::cxx::LessThanFOp::create(gen.builder_, loc, resultType,
16481648
leftExpressionResult.value,
16491649
rightExpressionResult.value);
@@ -1654,14 +1654,14 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast)
16541654
}
16551655

16561656
case TokenKind::T_LESS_EQUAL: {
1657-
if (control()->is_integral(ast->type)) {
1657+
if (control()->is_integral_or_unscoped_enum(ast->leftExpression->type)) {
16581658
auto op = mlir::cxx::LessEqualOp::create(gen.builder_, loc, resultType,
16591659
leftExpressionResult.value,
16601660
rightExpressionResult.value);
16611661
return {op};
16621662
}
16631663

1664-
if (control()->is_floating_point(ast->type)) {
1664+
if (control()->is_floating_point(ast->leftExpression->type)) {
16651665
auto op = mlir::cxx::LessEqualFOp::create(gen.builder_, loc, resultType,
16661666
leftExpressionResult.value,
16671667
rightExpressionResult.value);
@@ -1672,14 +1672,14 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast)
16721672
}
16731673

16741674
case TokenKind::T_GREATER: {
1675-
if (control()->is_integral(ast->type)) {
1675+
if (control()->is_integral_or_unscoped_enum(ast->leftExpression->type)) {
16761676
auto op = mlir::cxx::GreaterThanOp::create(
16771677
gen.builder_, loc, resultType, leftExpressionResult.value,
16781678
rightExpressionResult.value);
16791679
return {op};
16801680
}
16811681

1682-
if (control()->is_floating_point(ast->type)) {
1682+
if (control()->is_floating_point(ast->leftExpression->type)) {
16831683
auto op = mlir::cxx::GreaterThanFOp::create(
16841684
gen.builder_, loc, resultType, leftExpressionResult.value,
16851685
rightExpressionResult.value);
@@ -1690,14 +1690,14 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast)
16901690
}
16911691

16921692
case TokenKind::T_GREATER_EQUAL: {
1693-
if (control()->is_integral(ast->type)) {
1693+
if (control()->is_integral_or_unscoped_enum(ast->leftExpression->type)) {
16941694
auto op = mlir::cxx::GreaterEqualOp::create(
16951695
gen.builder_, loc, resultType, leftExpressionResult.value,
16961696
rightExpressionResult.value);
16971697
return {op};
16981698
}
16991699

1700-
if (control()->is_floating_point(ast->type)) {
1700+
if (control()->is_floating_point(ast->leftExpression->type)) {
17011701
auto op = mlir::cxx::GreaterEqualFOp::create(
17021702
gen.builder_, loc, resultType, leftExpressionResult.value,
17031703
rightExpressionResult.value);
@@ -1707,6 +1707,27 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast)
17071707
break;
17081708
}
17091709

1710+
case TokenKind::T_CARET: {
1711+
auto op = mlir::cxx::XorOp::create(gen.builder_, loc, resultType,
1712+
leftExpressionResult.value,
1713+
rightExpressionResult.value);
1714+
return {op};
1715+
}
1716+
1717+
case TokenKind::T_AMP: {
1718+
auto op = mlir::cxx::AndOp::create(gen.builder_, loc, resultType,
1719+
leftExpressionResult.value,
1720+
rightExpressionResult.value);
1721+
return {op};
1722+
}
1723+
1724+
case TokenKind::T_BAR: {
1725+
auto op = mlir::cxx::OrOp::create(gen.builder_, loc, resultType,
1726+
leftExpressionResult.value,
1727+
rightExpressionResult.value);
1728+
return {op};
1729+
}
1730+
17101731
default:
17111732
break;
17121733
} // switch

src/mlir/cxx/mlir/cxx_dialect_conversions.cc

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,6 +1036,79 @@ class GreaterEqualOpLowering : public OpConversionPattern<cxx::GreaterEqualOp> {
10361036
}
10371037
};
10381038

1039+
//
1040+
// bitwise operations
1041+
//
1042+
1043+
class AndOpLowering : public OpConversionPattern<cxx::AndOp> {
1044+
public:
1045+
using OpConversionPattern::OpConversionPattern;
1046+
1047+
auto matchAndRewrite(cxx::AndOp op, OpAdaptor adaptor,
1048+
ConversionPatternRewriter& rewriter) const
1049+
-> LogicalResult override {
1050+
auto typeConverter = getTypeConverter();
1051+
auto context = getContext();
1052+
1053+
auto resultType = typeConverter->convertType(op.getType());
1054+
if (!resultType) {
1055+
return rewriter.notifyMatchFailure(
1056+
op, "failed to convert and operation type");
1057+
}
1058+
1059+
rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, resultType, adaptor.getLhs(),
1060+
adaptor.getRhs());
1061+
1062+
return success();
1063+
}
1064+
};
1065+
1066+
class OrOpLowering : public OpConversionPattern<cxx::OrOp> {
1067+
public:
1068+
using OpConversionPattern::OpConversionPattern;
1069+
1070+
auto matchAndRewrite(cxx::OrOp op, OpAdaptor adaptor,
1071+
ConversionPatternRewriter& rewriter) const
1072+
-> LogicalResult override {
1073+
auto typeConverter = getTypeConverter();
1074+
auto context = getContext();
1075+
1076+
auto resultType = typeConverter->convertType(op.getType());
1077+
if (!resultType) {
1078+
return rewriter.notifyMatchFailure(op,
1079+
"failed to convert or operation type");
1080+
}
1081+
1082+
rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, resultType, adaptor.getLhs(),
1083+
adaptor.getRhs());
1084+
1085+
return success();
1086+
}
1087+
};
1088+
1089+
class XorOpLowering : public OpConversionPattern<cxx::XorOp> {
1090+
public:
1091+
using OpConversionPattern::OpConversionPattern;
1092+
1093+
auto matchAndRewrite(cxx::XorOp op, OpAdaptor adaptor,
1094+
ConversionPatternRewriter& rewriter) const
1095+
-> LogicalResult override {
1096+
auto typeConverter = getTypeConverter();
1097+
auto context = getContext();
1098+
1099+
auto resultType = typeConverter->convertType(op.getType());
1100+
if (!resultType) {
1101+
return rewriter.notifyMatchFailure(
1102+
op, "failed to convert xor operation type");
1103+
}
1104+
1105+
rewriter.replaceOpWithNewOp<LLVM::XOrOp>(op, resultType, adaptor.getLhs(),
1106+
adaptor.getRhs());
1107+
1108+
return success();
1109+
}
1110+
};
1111+
10391112
//
10401113
// floating point operations
10411114
//
@@ -1470,6 +1543,10 @@ void CxxToLLVMLoweringPass::runOnOperation() {
14701543
LessEqualOpLowering, GreaterThanOpLowering,
14711544
GreaterEqualOpLowering>(typeConverter, context);
14721545

1546+
// bitwise operations
1547+
patterns.insert<AndOpLowering, OrOpLowering, XorOpLowering>(typeConverter,
1548+
context);
1549+
14731550
// floating point operations
14741551
patterns
14751552
.insert<AddFOpLowering, SubFOpLowering, MulFOpLowering, DivFOpLowering>(

0 commit comments

Comments
 (0)