@@ -290,7 +290,7 @@ class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
   matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Value lhs = adaptor.getSelf();
-    RankedTensorType lhsType = lhs.getType().dyn_cast<RankedTensorType>();
+    RankedTensorType lhsType = dyn_cast<RankedTensorType>(lhs.getType());
 
     Value rhs = adaptor.getOther();
 
@@ -303,13 +303,6 @@ class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
       return rewriter.notifyMatchFailure(
           op, "Only Ranked Tensor types are supported in TCP");
 
-    // TODO: Add integer conversions once `tcp.divsi` and `tcp.divui` are
-    // added
-    if (resultType.getElementType().isa<mlir::IntegerType>()) {
-      return rewriter.notifyMatchFailure(
-          op, "Only floating point division supported for now");
-    }
-
     auto inputAType = op.getSelf()
                           .getType()
                           .template dyn_cast<torch::Torch::ValueTensorType>()
@@ -318,17 +311,20 @@ class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
                           .template dyn_cast<torch::Torch::ValueTensorType>()
                           .getDtype();
 
+    Type inputBType = nullptr;
     if (isa<AtenDivScalarOp>(op)) {
+      inputBType = adaptor.getOther().getType();
+
       rhs = convertScalarOperandToTensor(rewriter, op, op.getOther(),
                                          adaptor.getOther(), outputType,
                                          resultType.getElementType());
       if (!rhs)
         return rewriter.notifyMatchFailure(op, "Unsupported rhs data type");
     } else {
-      auto inputBType = op.getOther()
-                            .getType()
-                            .template dyn_cast<torch::Torch::ValueTensorType>()
-                            .getDtype();
+      inputBType = op.getOther()
+                       .getType()
+                       .template dyn_cast<torch::Torch::ValueTensorType>()
+                       .getDtype();
       rhs = torch_to_tcp::castTensorToDtype(rewriter, inputBType, outputType,
                                             rhs, resultType.getElementType());
     }
@@ -337,7 +333,29 @@ class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
     std::tie(lhs, rhs) =
         torch_to_tcp::broadcastToMatchShape(rewriter, lhs, rhs);
 
-    rewriter.replaceOpWithNewOp<tcp::DivFOp>(op, resultType, lhs, rhs);
+    if (isa<mlir::FloatType>(outputType)) {
+      rewriter.replaceOpWithNewOp<tcp::DivFOp>(op, resultType, lhs, rhs);
+    } else {
+      auto in1IntType = cast<mlir::IntegerType>(inputAType);
+      auto in2IntType = cast<mlir::IntegerType>(inputBType);
+      auto outIntType = cast<mlir::IntegerType>(outputType);
+      if ((in1IntType.getSignedness() != in2IntType.getSignedness()) ||
+          (in1IntType.getSignedness() != outIntType.getSignedness()))
+        return rewriter.notifyMatchFailure(op,
+                                           "Mixed signedness not supported");
+      if (in1IntType.getSignedness() ==
+          mlir::IntegerType::SignednessSemantics::Signless)
+        return rewriter.notifyMatchFailure(
+            op, "Signless division not supported in TCP");
+
+      if (outIntType.getSignedness() ==
+          mlir::IntegerType::SignednessSemantics::Unsigned)
+        rewriter.replaceOpWithNewOp<tcp::DivUIOp>(op, resultType, lhs, rhs,
+                                                  tcp::RoundingMode::Trunc);
+      else
+        rewriter.replaceOpWithNewOp<tcp::DivSIOp>(op, resultType, lhs, rhs,
+                                                  tcp::RoundingMode::Trunc);
+    }
     return success();
   }
 };
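
The branch added in the final hunk keys entirely off mlir::IntegerType signedness: mixed-signedness and signless operands are rejected, an unsigned output selects tcp::DivUIOp, and otherwise tcp::DivSIOp is used, both with tcp::RoundingMode::Trunc. The standalone C++ sketch below isolates that decision so it can be read or unit-tested apart from the conversion pattern; the helper name pickIntDivKind and the IntDivKind enum are hypothetical, introduced only for illustration, and only mlir::IntegerType from upstream MLIR is assumed.

// Hypothetical illustration only; not part of the patch above. It mirrors the
// signedness-based dispatch that ConvertAtenDivOp now performs for integer
// element types.
#include "mlir/IR/BuiltinTypes.h"

#include <optional>

enum class IntDivKind { Signed, Unsigned };

// Returns which TCP integer division flavor applies, or std::nullopt when the
// combination is rejected (mixed signedness or signless operands), matching
// the notifyMatchFailure paths in the diff.
static std::optional<IntDivKind> pickIntDivKind(mlir::IntegerType inA,
                                                mlir::IntegerType inB,
                                                mlir::IntegerType out) {
  using Signedness = mlir::IntegerType::SignednessSemantics;
  // All three types must agree on signedness.
  if (inA.getSignedness() != inB.getSignedness() ||
      inA.getSignedness() != out.getSignedness())
    return std::nullopt;
  // Signless integers carry no sign information, so neither tcp.divsi nor
  // tcp.divui can be chosen.
  if (inA.getSignedness() == Signedness::Signless)
    return std::nullopt;
  return out.getSignedness() == Signedness::Unsigned ? IntDivKind::Unsigned
                                                     : IntDivKind::Signed;
}

In the pattern itself, IntDivKind::Unsigned corresponds to emitting tcp::DivUIOp and IntDivKind::Signed to tcp::DivSIOp, both created with tcp::RoundingMode::Trunc as shown in the diff.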