Skip to content

Commit 3fc3290

Browse files
srinathava (Srinath Avadhanula)
and authored
Add support for integer division to TCP (#97)
As titled, add support for integer division in TCP. This closely mimics the existing capabilities in pytorch and arith. --------- Co-authored-by: Srinath Avadhanula <srinath.avadhanula@getcruise.com>
1 parent 36418dc commit 3fc3290

File tree

7 files changed

+151
-14
lines changed

7 files changed

+151
-14
lines changed

include/mlir-tcp/Dialect/IR/TcpEnums.td

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,22 @@ def Tcp_Signedness : I32EnumAttr<"Signedness",
3232

3333
def Tcp_SignednessAttr : EnumAttr<Tcp_Dialect, Tcp_Signedness, "signedness">;
3434

35+
// TCP rounding mode
36+
def Tcp_RoundingMode_Trunc : I32EnumAttrCase<"Trunc", 0>;
37+
def Tcp_RoundingMode_Floor : I32EnumAttrCase<"Floor", 1>;
38+
def Tcp_RoundingMode_Ceil : I32EnumAttrCase<"Ceil", 2>;
39+
40+
def Tcp_RoundingMode : I32EnumAttr<"RoundingMode",
41+
"Rounding mode for integer operations which need a rounding mode",
42+
[
43+
Tcp_RoundingMode_Trunc,
44+
Tcp_RoundingMode_Floor,
45+
Tcp_RoundingMode_Ceil
46+
]> {
47+
let genSpecializedAttr = 0;
48+
let cppNamespace = "::mlir::tcp";
49+
}
50+
51+
def Tcp_RoundingModeAttr : EnumAttr<Tcp_Dialect, Tcp_RoundingMode, "roundingMode">;
52+
3553
#endif // TCP_ENUMS

include/mlir-tcp/Dialect/IR/TcpOps.td

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,46 @@ def Tcp_DivFOp : Tcp_BinaryElementwiseOp<"divf", [SameOperandsAndResultElementTy
160160
let assemblyFormat = "$in1 `,` $in2 attr-dict `:` type($in1) `,` type($in2) `->` type($out)";
161161
}
162162

163+
def Tcp_DivSIOp : Tcp_BinaryElementwiseOp<"divsi", [SameOperandsAndResultElementType]> {
164+
let summary = "Computes elementwise signed integer division";
165+
166+
let description = [{
167+
Computes the signed integer division of `in1` and `in2`.
168+
}];
169+
170+
let arguments = (ins
171+
Tcp_IntTensor:$in1,
172+
Tcp_IntTensor:$in2,
173+
Tcp_RoundingModeAttr:$rounding_mode
174+
);
175+
176+
let results = (outs
177+
Tcp_IntTensor:$out
178+
);
179+
180+
let assemblyFormat = "$in1 `,` $in2 attr-dict `:` type($in1) `,` type($in2) `->` type($out)";
181+
}
182+
183+
def Tcp_DivUIOp : Tcp_BinaryElementwiseOp<"divui", [SameOperandsAndResultElementType]> {
184+
let summary = "Computes elementwise unsigned integer division";
185+
186+
let description = [{
187+
Computes the unsigned integer division of `in1` and `in2`.
188+
}];
189+
190+
let arguments = (ins
191+
Tcp_IntTensor:$in1,
192+
Tcp_IntTensor:$in2,
193+
Tcp_RoundingModeAttr:$rounding_mode
194+
);
195+
196+
let results = (outs
197+
Tcp_IntTensor:$out
198+
);
199+
200+
let assemblyFormat = "$in1 `,` $in2 attr-dict `:` type($in1) `,` type($in2) `->` type($out)";
201+
}
202+
163203
def Tcp_ConstOp : Tcp_Op<"const", [ConstantLike, Pure]> {
164204
let summary = "Constant op";
165205

lib/Conversion/TcpToLinalg/Elementwise.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,32 @@ createLinalgPayloadForElementwiseOp(Operation *op,
195195
"createLinalgPayloadForElementwiseOp for tcp.divf");
196196
}
197197

198+
if (auto divOp = dyn_cast<DivSIOp>(op)) {
199+
if (!elemType.isa<mlir::IntegerType>())
200+
llvm_unreachable("unsupported element type in "
201+
"createLinalgPayloadForElementwiseOp for tcp.divsi");
202+
if (divOp.getRoundingMode() == RoundingMode::Trunc)
203+
return {b.create<arith::DivSIOp>(loc, payloadArgs[0], payloadArgs[1])};
204+
else if (divOp.getRoundingMode() == RoundingMode::Ceil)
205+
return {
206+
b.create<arith::CeilDivSIOp>(loc, payloadArgs[0], payloadArgs[1])};
207+
else
208+
return {
209+
b.create<arith::FloorDivSIOp>(loc, payloadArgs[0], payloadArgs[1])};
210+
}
211+
212+
if (auto divOp = dyn_cast<DivUIOp>(op)) {
213+
if (!elemType.isa<mlir::IntegerType>())
214+
llvm_unreachable("unsupported element type in "
215+
"createLinalgPayloadForElementwiseOp for tcp.divui");
216+
if (divOp.getRoundingMode() == RoundingMode::Trunc ||
217+
divOp.getRoundingMode() == RoundingMode::Floor)
218+
return {b.create<arith::DivUIOp>(loc, payloadArgs[0], payloadArgs[1])};
219+
else
220+
return {
221+
b.create<arith::CeilDivUIOp>(loc, payloadArgs[0], payloadArgs[1])};
222+
}
223+
198224
if (isa<Atan2Op>(op)) {
199225
if (elemType.isa<mlir::FloatType>())
200226
return {b.create<math::Atan2Op>(loc, payloadArgs[0], payloadArgs[1])};
@@ -330,6 +356,8 @@ void mlir::TcpToLinalg::populateElementwisePatternsAndLegality(
330356
INSERT_TCP_TO_LINALG_PATTERN(ClampOp);
331357
INSERT_TCP_TO_LINALG_PATTERN(MulOp);
332358
INSERT_TCP_TO_LINALG_PATTERN(DivFOp);
359+
INSERT_TCP_TO_LINALG_PATTERN(DivSIOp);
360+
INSERT_TCP_TO_LINALG_PATTERN(DivUIOp);
333361
INSERT_TCP_TO_LINALG_PATTERN(SubOp);
334362
INSERT_TCP_TO_LINALG_PATTERN(TanhOp);
335363
INSERT_TCP_TO_LINALG_PATTERN(SigmoidOp);

lib/Conversion/TorchToTcp/Elementwise.cpp

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
290290
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
291291
ConversionPatternRewriter &rewriter) const override {
292292
Value lhs = adaptor.getSelf();
293-
RankedTensorType lhsType = lhs.getType().dyn_cast<RankedTensorType>();
293+
RankedTensorType lhsType = dyn_cast<RankedTensorType>(lhs.getType());
294294

295295
Value rhs = adaptor.getOther();
296296

@@ -303,13 +303,6 @@ class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
303303
return rewriter.notifyMatchFailure(
304304
op, "Only Ranked Tensor types are supported in TCP");
305305

306-
// TODO: Add integer conversions once `tcp.divsi` and `tcp.divui` are
307-
// added
308-
if (resultType.getElementType().isa<mlir::IntegerType>()) {
309-
return rewriter.notifyMatchFailure(
310-
op, "Only floating point division supported for now");
311-
}
312-
313306
auto inputAType = op.getSelf()
314307
.getType()
315308
.template dyn_cast<torch::Torch::ValueTensorType>()
@@ -318,17 +311,20 @@ class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
318311
.template dyn_cast<torch::Torch::ValueTensorType>()
319312
.getDtype();
320313

314+
Type inputBType = nullptr;
321315
if (isa<AtenDivScalarOp>(op)) {
316+
inputBType = adaptor.getOther().getType();
317+
322318
rhs = convertScalarOperandToTensor(rewriter, op, op.getOther(),
323319
adaptor.getOther(), outputType,
324320
resultType.getElementType());
325321
if (!rhs)
326322
return rewriter.notifyMatchFailure(op, "Unsupported rhs data type");
327323
} else {
328-
auto inputBType = op.getOther()
329-
.getType()
330-
.template dyn_cast<torch::Torch::ValueTensorType>()
331-
.getDtype();
324+
inputBType = op.getOther()
325+
.getType()
326+
.template dyn_cast<torch::Torch::ValueTensorType>()
327+
.getDtype();
332328
rhs = torch_to_tcp::castTensorToDtype(rewriter, inputBType, outputType,
333329
rhs, resultType.getElementType());
334330
}
@@ -337,7 +333,29 @@ class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
337333
std::tie(lhs, rhs) =
338334
torch_to_tcp::broadcastToMatchShape(rewriter, lhs, rhs);
339335

340-
rewriter.replaceOpWithNewOp<tcp::DivFOp>(op, resultType, lhs, rhs);
336+
if (isa<mlir::FloatType>(outputType)) {
337+
rewriter.replaceOpWithNewOp<tcp::DivFOp>(op, resultType, lhs, rhs);
338+
} else {
339+
auto in1IntType = cast<mlir::IntegerType>(inputAType);
340+
auto in2IntType = cast<mlir::IntegerType>(inputBType);
341+
auto outIntType = cast<mlir::IntegerType>(outputType);
342+
if ((in1IntType.getSignedness() != in2IntType.getSignedness()) ||
343+
(in1IntType.getSignedness() != outIntType.getSignedness()))
344+
return rewriter.notifyMatchFailure(op,
345+
"Mixed signedness not supported");
346+
if (in1IntType.getSignedness() ==
347+
mlir::IntegerType::SignednessSemantics::Signless)
348+
return rewriter.notifyMatchFailure(
349+
op, "Signless division not supported in TCP");
350+
351+
if (outIntType.getSignedness() ==
352+
mlir::IntegerType::SignednessSemantics::Unsigned)
353+
rewriter.replaceOpWithNewOp<tcp::DivUIOp>(op, resultType, lhs, rhs,
354+
tcp::RoundingMode::Trunc);
355+
else
356+
rewriter.replaceOpWithNewOp<tcp::DivSIOp>(op, resultType, lhs, rhs,
357+
tcp::RoundingMode::Trunc);
358+
}
341359
return success();
342360
}
343361
};

lib/Conversion/TorchToTcp/Utils.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,14 @@ getTcpSignednessAttr(MLIRContext *context,
3838
return SignednessAttr::get(context, Signedness::Unsigned);
3939
}
4040

41+
Signedness getTcpSignedness(IntegerType::SignednessSemantics signednessInfo) {
42+
if (signednessInfo == IntegerType::SignednessSemantics::Signless)
43+
return Signedness::Signless;
44+
if (signednessInfo == IntegerType::SignednessSemantics::Signed)
45+
return Signedness::Signed;
46+
return Signedness::Unsigned;
47+
}
48+
4149
// The parameter input is expected to be of RankedTensorType.
4250
Value broadcastRankInLeadingDims(ConversionPatternRewriter &rewriter,
4351
Value input, int64_t rankIncrease) {

lib/Conversion/TorchToTcp/Utils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ mlir::tcp::SignednessAttr
2323
getTcpSignednessAttr(MLIRContext *context,
2424
IntegerType::SignednessSemantics signednessInfo);
2525

26+
mlir::tcp::Signedness
27+
getTcpSignedness(IntegerType::SignednessSemantics signednessInfo);
28+
2629
// Helper function to expand the rank of the input tensor. Works by
2730
// adding 1-dim shape to the leading dims using `tensor::ExpandShapeOp`.
2831
Value broadcastRankInLeadingDims(ConversionPatternRewriter &rewriter,

test/Pipeline/torch_to_tcp_pipeline.mlir

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,30 @@ func.func @torch.aten.div.Tensor$mixed_type_fp(%arg0: !torch.vtensor<[?, ?],f32>
108108

109109
// -----
110110

111+
// CHECK: func.func @torch.aten.div.Tensor$mixed_type_int(%[[ARG0:.+]]: tensor<?x?xi16>, %[[ARG1:.+]]: tensor<?x?xi32>) -> tensor<?x?xi32> {
112+
// CHECK: %[[V0:.+]] = tcp.cast %[[ARG0]] {in_int_signedness = #tcp<signedness Signed>, out_int_signedness = #tcp<signedness Signed>} : tensor<?x?xi16> -> tensor<?x?xi32>
113+
// CHECK: %[[V1:.+]] = tcp.divsi %[[V0]], %[[ARG1]] {rounding_mode = #tcp<roundingMode Trunc>} : tensor<?x?xi32>, tensor<?x?xi32> -> tensor<?x?xi32>
114+
// CHECK: return %[[V1]] : tensor<?x?xi32>
111115
func.func @torch.aten.div.Tensor$mixed_type_int(%arg0: !torch.vtensor<[?, ?],si16>, %arg1: !torch.vtensor<[?, ?],si32>) -> !torch.vtensor<[?, ?],si32> {
112-
// expected-error @below {{failed to legalize operation 'torch.aten.div.Tensor' that was explicitly marked illegal}}
113116
%0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],si16>, !torch.vtensor<[?, ?],si32> -> !torch.vtensor<[?, ?],si32>
114117
return %0 : !torch.vtensor<[?, ?],si32>
115118
}
119+
120+
// -----
121+
122+
// CHECK: func.func @torch.aten.div.Tensor$mixed_type_uint(%[[ARG0:.+]]: tensor<?x?xi16>, %[[ARG1:.+]]: tensor<?x?xi32>) -> tensor<?x?xi32> {
123+
// CHECK: %[[V0:.+]] = tcp.cast %[[ARG0]] {in_int_signedness = #tcp<signedness Unsigned>, out_int_signedness = #tcp<signedness Unsigned>} : tensor<?x?xi16> -> tensor<?x?xi32>
124+
// CHECK: %[[V1:.+]] = tcp.divui %[[V0]], %[[ARG1]] {rounding_mode = #tcp<roundingMode Trunc>} : tensor<?x?xi32>, tensor<?x?xi32> -> tensor<?x?xi32>
125+
// CHECK: return %[[V1]] : tensor<?x?xi32>
126+
func.func @torch.aten.div.Tensor$mixed_type_uint(%arg0: !torch.vtensor<[?, ?],ui16>, %arg1: !torch.vtensor<[?, ?],ui32>) -> !torch.vtensor<[?, ?],ui32> {
127+
%0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],ui16>, !torch.vtensor<[?, ?],ui32> -> !torch.vtensor<[?, ?],ui32>
128+
return %0 : !torch.vtensor<[?, ?],ui32>
129+
}
130+
131+
// -----
132+
133+
func.func @torch.aten.div.Tensor$mixed_signed_int_div(%arg0: !torch.vtensor<[?, ?],si16>, %arg1: !torch.vtensor<[?, ?],ui32>) -> !torch.vtensor<[?, ?],ui32> {
134+
// expected-error @below {{failed to legalize operation 'torch.aten.div.Tensor' that was explicitly marked illegal}}
135+
%0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],si16>, !torch.vtensor<[?, ?],ui32> -> !torch.vtensor<[?, ?],ui32>
136+
return %0 : !torch.vtensor<[?, ?],ui32>
137+
}

0 commit comments

Comments (0)