Skip to content

Commit 62a6850

Browse files
committed
Initial work on the MLIR codegen for switch statements
1 parent 16124bc commit 62a6850

File tree

7 files changed

+168
-15
lines changed

7 files changed

+168
-15
lines changed

src/mlir/cxx/mlir/CxxOps.td

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -424,12 +424,52 @@ def Cxx_GotoOp : Cxx_Op<"goto"> {
424424
let arguments = (ins StringProp:$label);
425425
}
426426

427-
def CondBranchOp : Cxx_Op<"cond_br", [ AttrSizedOperandSegments, Terminator ]> {
427+
def Cxx_CondBranchOp : Cxx_Op<"cond_br", [ AttrSizedOperandSegments, Terminator ]> {
428428
let arguments = (ins Cxx_BoolType:$condition, Variadic<AnyType>:$trueDestOperands, Variadic<AnyType>:$falseDestOperands);
429429

430430
let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
431431
}
432432

433+
def Cxx_SwitchOp : Cxx_Op<"switch", [ AttrSizedOperandSegments, Terminator ]> {
434+
let arguments = (ins
435+
Cxx_IntegerType:$value,
436+
Variadic<AnyType>:$defaultOperands,
437+
VariadicOfVariadic<AnyType, "case_operand_segments">:$caseOperands,
438+
OptionalAttr<AnyIntElementsAttr>:$case_values,
439+
DenseI32ArrayAttr:$case_operand_segments
440+
);
441+
442+
let successors = (successor
443+
AnySuccessor:$defaultDestination,
444+
VariadicSuccessor<AnySuccessor>:$caseDestinations
445+
);
446+
447+
let builders = [
448+
OpBuilder<(ins "Value":$value,
449+
"Block *":$defaultDestination,
450+
"ValueRange":$defaultOperands,
451+
CArg<"ArrayRef<std::int64_t>", "{}">:$caseValues,
452+
CArg<"BlockRange", "{}">:$caseDestinations,
453+
CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>,
454+
OpBuilder<(ins "Value":$value,
455+
"Block *":$defaultDestination,
456+
"ValueRange":$defaultOperands,
457+
CArg<"DenseIntElementsAttr", "{}">:$caseValues,
458+
CArg<"BlockRange", "{}">:$caseDestinations,
459+
CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>
460+
];
461+
462+
let extraClassDeclaration = [{
463+
auto getCaseOperands(unsigned index) -> OperandRange{
464+
return getCaseOperands()[index];
465+
}
466+
467+
auto getCaseOperandsMutable(unsigned index) -> MutableOperandRange {
468+
return getCaseOperandsMutable()[index];
469+
}
470+
}];
471+
}
472+
433473
//
434474
// todo ops
435475
//

src/mlir/cxx/mlir/codegen.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,12 @@ class Codegen {
286286
: continueBlock(continueBlock), breakBlock(breakBlock) {}
287287
};
288288

289+
struct Switch {
290+
std::vector<std::int64_t> caseValues;
291+
std::vector<mlir::Block*> caseDestinations;
292+
mlir::Block* defaultDestination = nullptr;
293+
};
294+
289295
struct UnitVisitor;
290296
struct DeclarationVisitor;
291297
struct StatementVisitor;
@@ -324,6 +330,7 @@ class Codegen {
324330
std::unordered_map<std::string_view, int> uniqueSymbolNames_;
325331
std::unordered_map<const StringLiteral*, mlir::StringAttr> stringLiterals_;
326332
Loop loop_;
333+
Switch switch_;
327334
int count_ = 0;
328335
};
329336

src/mlir/cxx/mlir/codegen_statements.cc

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
// cxx
2424
#include <cxx/ast.h>
2525
#include <cxx/control.h>
26+
#include <cxx/memory_layout.h>
2627
#include <cxx/names.h>
2728

2829
// mlir
@@ -69,7 +70,8 @@ struct Codegen::ExceptionDeclarationVisitor {
6970
void Codegen::statement(StatementAST* ast) {
7071
if (!ast) return;
7172

72-
if (currentBlockMightHaveTerminator()) return;
73+
// TODO: move to the op visitors
74+
// if (currentBlockMightHaveTerminator()) return;
7375

7476
visit(StatementVisitor{*this}, ast);
7577
}
@@ -104,15 +106,21 @@ void Codegen::StatementVisitor::operator()(LabeledStatementAST* ast) {
104106
}
105107

106108
void Codegen::StatementVisitor::operator()(CaseStatementAST* ast) {
107-
(void)gen.emitTodoStmt(ast->firstSourceLocation(), to_string(ast->kind()));
109+
auto block = gen.newBlock();
108110

109-
#if false
110-
auto expressionResult = gen.expression(ast->expression);
111-
#endif
111+
gen.branch(gen.getLocation(ast->firstSourceLocation()), block);
112+
gen.builder_.setInsertionPointToEnd(block);
113+
114+
gen.switch_.caseValues.push_back(ast->caseValue);
115+
gen.switch_.caseDestinations.push_back(block);
112116
}
113117

114118
void Codegen::StatementVisitor::operator()(DefaultStatementAST* ast) {
115-
(void)gen.emitTodoStmt(ast->firstSourceLocation(), to_string(ast->kind()));
119+
auto block = gen.newBlock();
120+
gen.branch(gen.getLocation(ast->firstSourceLocation()), block);
121+
gen.builder_.setInsertionPointToEnd(block);
122+
123+
gen.switch_.defaultDestination = block;
116124
}
117125

118126
void Codegen::StatementVisitor::operator()(ExpressionStatementAST* ast) {
@@ -155,13 +163,41 @@ void Codegen::StatementVisitor::operator()(ConstevalIfStatementAST* ast) {
155163
}
156164

157165
void Codegen::StatementVisitor::operator()(SwitchStatementAST* ast) {
158-
(void)gen.emitTodoStmt(ast->firstSourceLocation(), to_string(ast->kind()));
159-
160-
#if false
161166
gen.statement(ast->initializer);
162-
auto conditionResult = gen.expression(ast->condition);
167+
168+
Switch previousSwitch;
169+
std::swap(gen.switch_, previousSwitch);
170+
171+
auto beginSwitchBlock = gen.newBlock();
172+
auto bodySwitchBlock = gen.newBlock();
173+
auto endSwitchBlock = gen.newBlock();
174+
175+
gen.branch(gen.getLocation(ast->firstSourceLocation()), beginSwitchBlock);
176+
177+
gen.builder_.setInsertionPointToEnd(bodySwitchBlock);
178+
179+
Loop previousLoop{gen.loop_.continueBlock, endSwitchBlock};
180+
std::swap(gen.loop_, previousLoop);
181+
163182
gen.statement(ast->statement);
164-
#endif
183+
gen.branch(gen.getLocation(ast->lastSourceLocation()), endSwitchBlock);
184+
185+
gen.builder_.setInsertionPointToEnd(beginSwitchBlock);
186+
187+
auto conditionResult = gen.expression(ast->condition);
188+
189+
mlir::cxx::SwitchOp::create(
190+
gen.builder_, gen.getLocation(ast->firstSourceLocation()),
191+
conditionResult.value, gen.switch_.defaultDestination, {},
192+
gen.switch_.caseValues, gen.switch_.caseDestinations,
193+
mlir::SmallVector<mlir::ValueRange>(gen.switch_.caseValues.size()));
194+
195+
std::swap(gen.switch_, previousSwitch);
196+
std::swap(gen.loop_, previousLoop);
197+
198+
gen.builder_.setInsertionPointToEnd(endSwitchBlock);
199+
200+
bodySwitchBlock->erase();
165201
}
166202

167203
void Codegen::StatementVisitor::operator()(WhileStatementAST* ast) {
@@ -358,4 +394,4 @@ auto Codegen::ExceptionDeclarationVisitor::operator()(
358394
return {};
359395
}
360396

361-
} // namespace cxx
397+
} // namespace cxx

src/mlir/cxx/mlir/cxx_dialect.cc

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
#include <mlir/IR/OpImplementation.h>
3333
#include <mlir/Interfaces/FunctionImplementation.h>
3434

35+
#include <numeric>
36+
3537
namespace mlir::cxx {
3638

3739
struct detail::ClassTypeStorage : public TypeStorage {
@@ -161,6 +163,46 @@ auto StoreOp::verify() -> LogicalResult {
161163
return success();
162164
}
163165

166+
void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
167+
Block *defaultDestination, ValueRange defaultOperands,
168+
DenseIntElementsAttr caseValues,
169+
BlockRange caseDestinations,
170+
ArrayRef<ValueRange> caseOperands) {
171+
build(builder, result, value, defaultOperands, caseOperands, caseValues,
172+
defaultDestination, caseDestinations);
173+
}
174+
175+
void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
176+
Block *defaultDestination, ValueRange defaultOperands,
177+
ArrayRef<std::int64_t> caseValues,
178+
BlockRange caseDestinations,
179+
ArrayRef<ValueRange> caseOperands) {
180+
DenseIntElementsAttr caseValuesAttr;
181+
182+
if (!caseValues.empty()) {
183+
auto elementTy =
184+
mlir::TypeSwitch<mlir::Type, mlir::IntegerType>(value.getType())
185+
.Case<mlir::cxx::IntegerType>(
186+
[&](mlir::cxx::IntegerType ty) -> mlir::IntegerType {
187+
return builder.getIntegerType(ty.getWidth());
188+
})
189+
.Default([](mlir::Type ty) -> mlir::IntegerType { return {}; });
190+
191+
auto shapeType =
192+
VectorType::get(static_cast<std::int64_t>(caseValues.size()),
193+
builder.getIntegerType(64));
194+
195+
caseValuesAttr = mlir::cast<DenseIntElementsAttr>(
196+
DenseIntElementsAttr::get(shapeType, caseValues)
197+
.mapValues(elementTy, [&](APInt v) {
198+
return APInt(elementTy.getIntOrFloatBitWidth(), v.getZExtValue());
199+
}));
200+
}
201+
202+
build(builder, result, value, defaultDestination, defaultOperands,
203+
caseValuesAttr, caseDestinations, caseOperands);
204+
}
205+
164206
auto FunctionType::clone(TypeRange inputs, TypeRange results) const
165207
-> FunctionType {
166208
return get(getContext(), llvm::to_vector(inputs), llvm::to_vector(results),

src/mlir/cxx/mlir/cxx_dialect_conversions.cc

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1269,6 +1269,24 @@ class LabelOpLowering : public OpConversionPattern<cxx::LabelOp> {
12691269
}
12701270
};
12711271

1272+
class SwitchOpLowering : public OpConversionPattern<cxx::SwitchOp> {
1273+
public:
1274+
using OpConversionPattern::OpConversionPattern;
1275+
1276+
auto matchAndRewrite(cxx::SwitchOp op, OpAdaptor adaptor,
1277+
ConversionPatternRewriter& rewriter) const
1278+
-> LogicalResult override {
1279+
auto context = getContext();
1280+
1281+
rewriter.replaceOpWithNewOp<cf::SwitchOp>(
1282+
op, adaptor.getValue(), op.getDefaultDestination(),
1283+
adaptor.getDefaultOperands(), *adaptor.getCaseValues(),
1284+
op.getCaseDestinations(), adaptor.getCaseOperands());
1285+
1286+
return success();
1287+
}
1288+
};
1289+
12721290
class CxxToLLVMLoweringPass
12731291
: public PassWrapper<CxxToLLVMLoweringPass, OperationPass<ModuleOp>> {
12741292
public:
@@ -1442,8 +1460,8 @@ void CxxToLLVMLoweringPass::runOnOperation() {
14421460
FloatToIntOpLowering>(typeConverter, context);
14431461

14441462
// control flow operations
1445-
patterns.insert<CondBranchOpLowering>(typeConverter, context);
1446-
patterns.insert<LabelOpLowering>(typeConverter, context);
1463+
patterns.insert<CondBranchOpLowering, LabelOpLowering, SwitchOpLowering>(
1464+
typeConverter, context);
14471465
patterns.insert<GotoOpLowering>(typeConverter, labelConverter, context);
14481466

14491467
populateFunctionOpInterfaceTypeConversionPattern<cxx::FuncOp>(patterns,

src/parser/cxx/ast.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1300,6 +1300,7 @@ class CaseStatementAST final : public StatementAST {
13001300
SourceLocation caseLoc;
13011301
ExpressionAST* expression = nullptr;
13021302
SourceLocation colonLoc;
1303+
std::int64_t caseValue = 0;
13031304

13041305
void accept(ASTVisitor* visitor) override { visitor->visit(this); }
13051306

src/parser/cxx/parser.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3414,6 +3414,15 @@ auto Parser::parse_case_statement(StatementAST*& yyast) -> bool {
34143414
ast->expression = expression;
34153415
ast->colonLoc = colonLoc;
34163416

3417+
if (value.has_value()) {
3418+
auto interp = ASTInterpreter{unit};
3419+
if (control()->is_unsigned(expression->type)) {
3420+
ast->caseValue = *interp.toUInt(*value);
3421+
} else {
3422+
ast->caseValue = *interp.toInt(*value);
3423+
}
3424+
}
3425+
34173426
return true;
34183427
}
34193428

0 commit comments

Comments
 (0)