diff --git a/src/mlir/cxx/mlir/CxxOps.td b/src/mlir/cxx/mlir/CxxOps.td index 6f518739..9356a2c7 100644 --- a/src/mlir/cxx/mlir/CxxOps.td +++ b/src/mlir/cxx/mlir/CxxOps.td @@ -418,6 +418,25 @@ def Cxx_NotEqualFOp : Cxx_Op<"nef"> { let results = (outs Cxx_BoolType:$result); } +// bitwise ops +def Cxx_AndOp : Cxx_Op<"and"> { + let arguments = (ins Cxx_IntegerType:$lhs, Cxx_IntegerType:$rhs); + + let results = (outs Cxx_IntegerType:$result); +} + +def Cxx_OrOp : Cxx_Op<"or"> { + let arguments = (ins Cxx_IntegerType:$lhs, Cxx_IntegerType:$rhs); + + let results = (outs Cxx_IntegerType:$result); +} + +def Cxx_XorOp : Cxx_Op<"xor"> { + let arguments = (ins Cxx_IntegerType:$lhs, Cxx_IntegerType:$rhs); + + let results = (outs Cxx_IntegerType:$result); +} + // // control flow ops // diff --git a/src/mlir/cxx/mlir/codegen_expressions.cc b/src/mlir/cxx/mlir/codegen_expressions.cc index e6fe9e43..4662cb8c 100644 --- a/src/mlir/cxx/mlir/codegen_expressions.cc +++ b/src/mlir/cxx/mlir/codegen_expressions.cc @@ -1578,7 +1578,7 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast) } case TokenKind::T_LESS_LESS: { - if (control()->is_integral(ast->type)) { + if (control()->is_integral_or_unscoped_enum(ast->type)) { auto op = mlir::cxx::ShiftLeftOp::create(gen.builder_, loc, resultType, leftExpressionResult.value, rightExpressionResult.value); @@ -1589,7 +1589,7 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast) } case TokenKind::T_GREATER_GREATER: { - if (control()->is_integral(ast->type)) { + if (control()->is_integral_or_unscoped_enum(ast->type)) { auto op = mlir::cxx::ShiftRightOp::create(gen.builder_, loc, resultType, leftExpressionResult.value, rightExpressionResult.value); @@ -1600,7 +1600,7 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast) } case TokenKind::T_EQUAL_EQUAL: { - if (control()->is_integral(ast->type)) { + if (control()->is_integral_or_unscoped_enum(ast->type)) { auto op = mlir::cxx::EqualOp::create(gen.builder_, loc, resultType, leftExpressionResult.value, rightExpressionResult.value); @@ -1618,14 +1618,14 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast) } case TokenKind::T_EXCLAIM_EQUAL: { - if (control()->is_integral(ast->type)) { + if (control()->is_integral_or_unscoped_enum(ast->leftExpression->type)) { auto op = mlir::cxx::NotEqualOp::create(gen.builder_, loc, resultType, leftExpressionResult.value, rightExpressionResult.value); return {op}; } - if (control()->is_floating_point(ast->type)) { + if (control()->is_floating_point(ast->leftExpression->type)) { auto op = mlir::cxx::NotEqualFOp::create(gen.builder_, loc, resultType, leftExpressionResult.value, rightExpressionResult.value); @@ -1636,14 +1636,14 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast) } case TokenKind::T_LESS: { - if (control()->is_integral(ast->type)) { + if (control()->is_integral_or_unscoped_enum(ast->leftExpression->type)) { auto op = mlir::cxx::LessThanOp::create(gen.builder_, loc, resultType, leftExpressionResult.value, rightExpressionResult.value); return {op}; } - if (control()->is_floating_point(ast->type)) { + if (control()->is_floating_point(ast->leftExpression->type)) { auto op = mlir::cxx::LessThanFOp::create(gen.builder_, loc, resultType, leftExpressionResult.value, rightExpressionResult.value); @@ -1654,14 +1654,14 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast) } case TokenKind::T_LESS_EQUAL: { - if (control()->is_integral(ast->type)) { + if (control()->is_integral_or_unscoped_enum(ast->leftExpression->type)) { auto op = mlir::cxx::LessEqualOp::create(gen.builder_, loc, resultType, leftExpressionResult.value, rightExpressionResult.value); return {op}; } - if (control()->is_floating_point(ast->type)) { + if (control()->is_floating_point(ast->leftExpression->type)) { auto op = mlir::cxx::LessEqualFOp::create(gen.builder_, loc, resultType, leftExpressionResult.value, rightExpressionResult.value); @@ -1672,14 +1672,14 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast) } case TokenKind::T_GREATER: { - if (control()->is_integral(ast->type)) { + if (control()->is_integral_or_unscoped_enum(ast->leftExpression->type)) { auto op = mlir::cxx::GreaterThanOp::create( gen.builder_, loc, resultType, leftExpressionResult.value, rightExpressionResult.value); return {op}; } - if (control()->is_floating_point(ast->type)) { + if (control()->is_floating_point(ast->leftExpression->type)) { auto op = mlir::cxx::GreaterThanFOp::create( gen.builder_, loc, resultType, leftExpressionResult.value, rightExpressionResult.value); @@ -1690,14 +1690,14 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast) } case TokenKind::T_GREATER_EQUAL: { - if (control()->is_integral(ast->type)) { + if (control()->is_integral_or_unscoped_enum(ast->leftExpression->type)) { auto op = mlir::cxx::GreaterEqualOp::create( gen.builder_, loc, resultType, leftExpressionResult.value, rightExpressionResult.value); return {op}; } - if (control()->is_floating_point(ast->type)) { + if (control()->is_floating_point(ast->leftExpression->type)) { auto op = mlir::cxx::GreaterEqualFOp::create( gen.builder_, loc, resultType, leftExpressionResult.value, rightExpressionResult.value); @@ -1707,6 +1707,27 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast) break; } + case TokenKind::T_CARET: { + auto op = mlir::cxx::XorOp::create(gen.builder_, loc, resultType, + leftExpressionResult.value, + rightExpressionResult.value); + return {op}; + } + + case TokenKind::T_AMP: { + auto op = mlir::cxx::AndOp::create(gen.builder_, loc, resultType, + leftExpressionResult.value, + rightExpressionResult.value); + return {op}; + } + + case TokenKind::T_BAR: { + auto op = mlir::cxx::OrOp::create(gen.builder_, loc, resultType, + leftExpressionResult.value, + rightExpressionResult.value); + return {op}; + } + default: break; } // switch diff --git a/src/mlir/cxx/mlir/cxx_dialect_conversions.cc b/src/mlir/cxx/mlir/cxx_dialect_conversions.cc index bb58d482..b2cb2b75 100644 --- a/src/mlir/cxx/mlir/cxx_dialect_conversions.cc +++ b/src/mlir/cxx/mlir/cxx_dialect_conversions.cc @@ -1036,6 +1036,79 @@ class GreaterEqualOpLowering : public OpConversionPattern { } }; +// +// bitwise operations +// + +class AndOpLowering : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + auto matchAndRewrite(cxx::AndOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const + -> LogicalResult override { + auto typeConverter = getTypeConverter(); + auto context = getContext(); + + auto resultType = typeConverter->convertType(op.getType()); + if (!resultType) { + return rewriter.notifyMatchFailure( + op, "failed to convert and operation type"); + } + + rewriter.replaceOpWithNewOp(op, resultType, adaptor.getLhs(), + adaptor.getRhs()); + + return success(); + } +}; + +class OrOpLowering : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + auto matchAndRewrite(cxx::OrOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const + -> LogicalResult override { + auto typeConverter = getTypeConverter(); + auto context = getContext(); + + auto resultType = typeConverter->convertType(op.getType()); + if (!resultType) { + return rewriter.notifyMatchFailure(op, + "failed to convert or operation type"); + } + + rewriter.replaceOpWithNewOp(op, resultType, adaptor.getLhs(), + adaptor.getRhs()); + + return success(); + } +}; + +class XorOpLowering : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + auto matchAndRewrite(cxx::XorOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const + -> LogicalResult override { + auto typeConverter = getTypeConverter(); + auto context = getContext(); + + auto resultType = typeConverter->convertType(op.getType()); + if (!resultType) { + return rewriter.notifyMatchFailure( + op, "failed to convert xor operation type"); + } + + rewriter.replaceOpWithNewOp(op, resultType, adaptor.getLhs(), + adaptor.getRhs()); + + return success(); + } +}; + // // floating point operations // @@ -1470,6 +1543,10 @@ void CxxToLLVMLoweringPass::runOnOperation() { LessEqualOpLowering, GreaterThanOpLowering, GreaterEqualOpLowering>(typeConverter, context); + // bitwise operations + patterns.insert(typeConverter, + context); + // floating point operations patterns .insert(