Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/mlir/cxx/mlir/CxxOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,25 @@ def Cxx_NotEqualFOp : Cxx_Op<"nef"> {
let results = (outs Cxx_BoolType:$result);
}

// bitwise ops
def Cxx_AndOp : Cxx_Op<"and"> {
let arguments = (ins Cxx_IntegerType:$lhs, Cxx_IntegerType:$rhs);

let results = (outs Cxx_IntegerType:$result);
}

def Cxx_OrOp : Cxx_Op<"or"> {
let arguments = (ins Cxx_IntegerType:$lhs, Cxx_IntegerType:$rhs);

let results = (outs Cxx_IntegerType:$result);
}

def Cxx_XorOp : Cxx_Op<"xor"> {
let arguments = (ins Cxx_IntegerType:$lhs, Cxx_IntegerType:$rhs);

let results = (outs Cxx_IntegerType:$result);
}

//
// control flow ops
//
Expand Down
47 changes: 34 additions & 13 deletions src/mlir/cxx/mlir/codegen_expressions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1578,7 +1578,7 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast)
}

case TokenKind::T_LESS_LESS: {
if (control()->is_integral(ast->type)) {
if (control()->is_integral_or_unscoped_enum(ast->type)) {
auto op = mlir::cxx::ShiftLeftOp::create(gen.builder_, loc, resultType,
leftExpressionResult.value,
rightExpressionResult.value);
Expand All @@ -1589,7 +1589,7 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast)
}

case TokenKind::T_GREATER_GREATER: {
if (control()->is_integral(ast->type)) {
if (control()->is_integral_or_unscoped_enum(ast->type)) {
auto op = mlir::cxx::ShiftRightOp::create(gen.builder_, loc, resultType,
leftExpressionResult.value,
rightExpressionResult.value);
Expand All @@ -1600,7 +1600,7 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast)
}

case TokenKind::T_EQUAL_EQUAL: {
if (control()->is_integral(ast->type)) {
if (control()->is_integral_or_unscoped_enum(ast->type)) {
auto op = mlir::cxx::EqualOp::create(gen.builder_, loc, resultType,
leftExpressionResult.value,
rightExpressionResult.value);
Expand All @@ -1618,14 +1618,14 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast)
}

case TokenKind::T_EXCLAIM_EQUAL: {
if (control()->is_integral(ast->type)) {
if (control()->is_integral_or_unscoped_enum(ast->leftExpression->type)) {
auto op = mlir::cxx::NotEqualOp::create(gen.builder_, loc, resultType,
leftExpressionResult.value,
rightExpressionResult.value);
return {op};
}

if (control()->is_floating_point(ast->type)) {
if (control()->is_floating_point(ast->leftExpression->type)) {
auto op = mlir::cxx::NotEqualFOp::create(gen.builder_, loc, resultType,
leftExpressionResult.value,
rightExpressionResult.value);
Expand All @@ -1636,14 +1636,14 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast)
}

case TokenKind::T_LESS: {
if (control()->is_integral(ast->type)) {
if (control()->is_integral_or_unscoped_enum(ast->leftExpression->type)) {
auto op = mlir::cxx::LessThanOp::create(gen.builder_, loc, resultType,
leftExpressionResult.value,
rightExpressionResult.value);
return {op};
}

if (control()->is_floating_point(ast->type)) {
if (control()->is_floating_point(ast->leftExpression->type)) {
auto op = mlir::cxx::LessThanFOp::create(gen.builder_, loc, resultType,
leftExpressionResult.value,
rightExpressionResult.value);
Expand All @@ -1654,14 +1654,14 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast)
}

case TokenKind::T_LESS_EQUAL: {
if (control()->is_integral(ast->type)) {
if (control()->is_integral_or_unscoped_enum(ast->leftExpression->type)) {
auto op = mlir::cxx::LessEqualOp::create(gen.builder_, loc, resultType,
leftExpressionResult.value,
rightExpressionResult.value);
return {op};
}

if (control()->is_floating_point(ast->type)) {
if (control()->is_floating_point(ast->leftExpression->type)) {
auto op = mlir::cxx::LessEqualFOp::create(gen.builder_, loc, resultType,
leftExpressionResult.value,
rightExpressionResult.value);
Expand All @@ -1672,14 +1672,14 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast)
}

case TokenKind::T_GREATER: {
if (control()->is_integral(ast->type)) {
if (control()->is_integral_or_unscoped_enum(ast->leftExpression->type)) {
auto op = mlir::cxx::GreaterThanOp::create(
gen.builder_, loc, resultType, leftExpressionResult.value,
rightExpressionResult.value);
return {op};
}

if (control()->is_floating_point(ast->type)) {
if (control()->is_floating_point(ast->leftExpression->type)) {
auto op = mlir::cxx::GreaterThanFOp::create(
gen.builder_, loc, resultType, leftExpressionResult.value,
rightExpressionResult.value);
Expand All @@ -1690,14 +1690,14 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast)
}

case TokenKind::T_GREATER_EQUAL: {
if (control()->is_integral(ast->type)) {
if (control()->is_integral_or_unscoped_enum(ast->leftExpression->type)) {
auto op = mlir::cxx::GreaterEqualOp::create(
gen.builder_, loc, resultType, leftExpressionResult.value,
rightExpressionResult.value);
return {op};
}

if (control()->is_floating_point(ast->type)) {
if (control()->is_floating_point(ast->leftExpression->type)) {
auto op = mlir::cxx::GreaterEqualFOp::create(
gen.builder_, loc, resultType, leftExpressionResult.value,
rightExpressionResult.value);
Expand All @@ -1707,6 +1707,27 @@ auto Codegen::ExpressionVisitor::operator()(BinaryExpressionAST* ast)
break;
}

case TokenKind::T_CARET: {
auto op = mlir::cxx::XorOp::create(gen.builder_, loc, resultType,
leftExpressionResult.value,
rightExpressionResult.value);
return {op};
}

case TokenKind::T_AMP: {
auto op = mlir::cxx::AndOp::create(gen.builder_, loc, resultType,
leftExpressionResult.value,
rightExpressionResult.value);
return {op};
}

case TokenKind::T_BAR: {
auto op = mlir::cxx::OrOp::create(gen.builder_, loc, resultType,
leftExpressionResult.value,
rightExpressionResult.value);
return {op};
}

default:
break;
} // switch
Expand Down
77 changes: 77 additions & 0 deletions src/mlir/cxx/mlir/cxx_dialect_conversions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1036,6 +1036,79 @@ class GreaterEqualOpLowering : public OpConversionPattern<cxx::GreaterEqualOp> {
}
};

//
// bitwise operations
//

class AndOpLowering : public OpConversionPattern<cxx::AndOp> {
public:
using OpConversionPattern::OpConversionPattern;

auto matchAndRewrite(cxx::AndOp op, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const
-> LogicalResult override {
auto typeConverter = getTypeConverter();
auto context = getContext();

auto resultType = typeConverter->convertType(op.getType());
if (!resultType) {
return rewriter.notifyMatchFailure(
op, "failed to convert and operation type");
}

rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, resultType, adaptor.getLhs(),
adaptor.getRhs());

return success();
}
};

class OrOpLowering : public OpConversionPattern<cxx::OrOp> {
public:
using OpConversionPattern::OpConversionPattern;

auto matchAndRewrite(cxx::OrOp op, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const
-> LogicalResult override {
auto typeConverter = getTypeConverter();
auto context = getContext();

auto resultType = typeConverter->convertType(op.getType());
if (!resultType) {
return rewriter.notifyMatchFailure(op,
"failed to convert or operation type");
}

rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, resultType, adaptor.getLhs(),
adaptor.getRhs());

return success();
}
};

class XorOpLowering : public OpConversionPattern<cxx::XorOp> {
public:
using OpConversionPattern::OpConversionPattern;

auto matchAndRewrite(cxx::XorOp op, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const
-> LogicalResult override {
auto typeConverter = getTypeConverter();
auto context = getContext();

auto resultType = typeConverter->convertType(op.getType());
if (!resultType) {
return rewriter.notifyMatchFailure(
op, "failed to convert xor operation type");
}

rewriter.replaceOpWithNewOp<LLVM::XOrOp>(op, resultType, adaptor.getLhs(),
adaptor.getRhs());

return success();
}
};

//
// floating point operations
//
Expand Down Expand Up @@ -1470,6 +1543,10 @@ void CxxToLLVMLoweringPass::runOnOperation() {
LessEqualOpLowering, GreaterThanOpLowering,
GreaterEqualOpLowering>(typeConverter, context);

// bitwise operations
patterns.insert<AndOpLowering, OrOpLowering, XorOpLowering>(typeConverter,
context);

// floating point operations
patterns
.insert<AddFOpLowering, SubFOpLowering, MulFOpLowering, DivFOpLowering>(
Expand Down