Skip to content

Commit

Permalink
[uncategorized_lowerings] Add lowering for torch.aten.round.decimals
Browse files Browse the repository at this point in the history
Implement missing lowering for the op in a similar fashion as done by torch
inductor. Also fix data movement and reduce op variants patterns to correctly
handle explicitly declared legal ops.

Signed-off-by: Prathamesh Tagore <prathamesh+1@polymagelabs.com>
  • Loading branch information
meshtag committed Oct 22, 2024
1 parent 140cad5 commit 0f5a2dc
Show file tree
Hide file tree
Showing 8 changed files with 112 additions and 19 deletions.
2 changes: 1 addition & 1 deletion include/torch-mlir/Dialect/Torch/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ std::unique_ptr<OperationPass<ModuleOp>> createAdjustCallingConventionsPass();
std::unique_ptr<OperationPass<ModuleOp>> createInlineGlobalSlotsPass();

std::unique_ptr<OperationPass<func::FuncOp>>
createReduceOpVariantsPass(StringRef extraLibrary);
createReduceOpVariantsPass(StringRef extraLibrary, ArrayRef<std::string> = {});

std::unique_ptr<OperationPass<func::FuncOp>> createMaximizeValueSemanticsPass();

Expand Down
3 changes: 3 additions & 0 deletions include/torch-mlir/Dialect/Torch/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,9 @@ def ReduceOpVariants : Pass<"torch-reduce-op-variants", "func::FuncOp"> {
let options = [
Option<"extraLibrary", "extra-library", "std::string", /*default=*/"",
"MLIR module for verifying custom op value semantics">,
ListOption<"legalOps", "legal-ops", "std::string",
"Comma separated list of operation names that should be considered legal",
"llvm::cl::ZeroOrMore">
];
let description = [{
Replaces ops with other ops to reduce the number of variants that
Expand Down
5 changes: 4 additions & 1 deletion lib/Conversion/TorchToLinalg/DataMovement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2874,7 +2874,10 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
patterns.add<ConvertAtenDiagEmbedOp>(typeConverter, context);
// Rewrite all special sparse conversions hidden as operators.
target.addDynamicallyLegalOp<OperatorOp>([&](Torch::OperatorOp op) {
return !ConvertSparseOperatorOp::isSparsePrimitive(op.getNameAttr());
// Note: Legality behaviour of torch.operator ops that are not sparse
// primitives should be conserved and not modified by this block.
return !ConvertSparseOperatorOp::isSparsePrimitive(op.getNameAttr()) &&
typeConverter.isLegal(op);
});
patterns.add<ConvertSparseOperatorOp>(typeConverter, context);
}
46 changes: 44 additions & 2 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1557,6 +1557,43 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
threshold);
}

if (auto operatorOp = dyn_cast<OperatorOp>(op)) {
// We do not yet implement lowering for other variants of the op.
if (operatorOp.getNameAttr().str() != "torch.aten.round.decimals")
return nullptr;

// Lower the op in a similar fashion as described here:
// https://github.com/pytorch/pytorch/blob/de4c2a3b4e89d96334dc678d1c3f2ae51a6630a0/torch/_inductor/decomposition.py#L223.
// Note that `aten.round` is converted to `math.roundeven`, we do this
// implicitly here because `aten.round` cannot operate on a single input
// tensor element which is what we get as payload argument.

Location loc = op->getLoc();
Type i64Type = b.getI64Type();

auto torchIntOp = dyn_cast<ConstantIntOp>(
operatorOp.getOperands().back().getDefiningOp());
if (!torchIntOp)
return nullptr;
int64_t numDecimalsArg = torchIntOp.getValue();

Value inputTensorElem = payloadArgs[0];
Type inputTensorElemType = inputTensorElem.getType();

auto numDecimals = b.create<arith::ConstantOp>(
loc, i64Type, IntegerAttr::get(i64Type, numDecimalsArg));
auto const10 = b.create<arith::ConstantOp>(
loc, inputTensorElemType, FloatAttr::get(inputTensorElemType, 10));
auto tenPowDecimals = b.create<math::FPowIOp>(loc, const10, numDecimals);

auto mulTenPowDecimalsinputTensorElem =
b.create<arith::MulFOp>(loc, inputTensorElem, tenPowDecimals);
auto roundOp =
b.create<math::RoundEvenOp>(loc, mulTenPowDecimalsinputTensorElem);
auto res = b.create<arith::DivFOp>(loc, roundOp, tenPowDecimals);
return res;
}

op->emitError("unimplemented lowering in "
"createLinalgPayloadCalculationForElementwiseOp");
return nullptr;
Expand Down Expand Up @@ -1616,9 +1653,14 @@ class ConvertElementwiseOp : public ConversionPattern {
AtenFillScalarOp, AtenFillTensorOp, AtenAtanOp, AtenAcosOp,
AtenAtanhOp, AtenAcoshOp, AtenAsinOp, AtenAsinhOp, AtenRealOp,
AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp,
AtenQuantizePerTensorOp, AtenIscloseOp>(op))
AtenQuantizePerTensorOp, AtenIscloseOp, OperatorOp>(op))
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");

if (auto operatorOp = dyn_cast<OperatorOp>(op))
if (operatorOp.getNameAttr().str() != "torch.aten.round.decimals")
return rewriter.notifyMatchFailure(op,
"not a supported elementwise op");

if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();

Expand Down Expand Up @@ -3375,7 +3417,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
AtenTrilOp, AtenRemainderScalarOp, AtenRemainderTensorOp,
AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
AtenRealOp, AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp,
AtenQuantizePerTensorOp, AtenIscloseOp>();
AtenQuantizePerTensorOp, AtenIscloseOp, OperatorOp>();
patterns.add<ConvertElementwiseOp>(typeConverter, context);
target.addIllegalOp<AtenNllLossForwardOp>();
patterns.add<ConvertAtenDetachOp>(typeConverter, context);
Expand Down
8 changes: 4 additions & 4 deletions lib/Dialect/Torch/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ void mlir::torch::Torch::createTorchScriptModuleToTorchBackendPipeline(

void mlir::torch::Torch::createTorchDynamoExportToTorchBackendPipeline(
OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
pm.addNestedPass<func::FuncOp>(
createReduceOpVariantsPass(options.extraLibrary));
pm.addNestedPass<func::FuncOp>(createReduceOpVariantsPass(
options.extraLibrary, options.backendLegalOps));
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
if (options.decompose) {
pm.addNestedPass<func::FuncOp>(
Expand Down Expand Up @@ -161,8 +161,8 @@ void mlir::torch::Torch::createTorchSimplificationPipeline(
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(createRecomposeComplexOpsPass());
// Reduce variants of ops to a smaller set of primitives.
pm.addNestedPass<func::FuncOp>(
createReduceOpVariantsPass(options.extraLibrary));
pm.addNestedPass<func::FuncOp>(createReduceOpVariantsPass(
options.extraLibrary, options.backendLegalOps));
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// Remove dead global slots.
pm.addPass(createSymbolDCEPass());
Expand Down
24 changes: 14 additions & 10 deletions lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -403,8 +403,9 @@ namespace {
struct ReduceOpVariantsPass
: public ReduceOpVariantsBase<ReduceOpVariantsPass> {
ReduceOpVariantsPass() = default;
ReduceOpVariantsPass(StringRef extraLibrary) {
ReduceOpVariantsPass(StringRef extraLibrary, ArrayRef<std::string> legalOps) {
this->extraLibrary = extraLibrary.str();
this->legalOps = legalOps;
}
void runOnOperation() override {
MLIRContext *context = &getContext();
Expand Down Expand Up @@ -439,13 +440,15 @@ struct ReduceOpVariantsPass
target.addIllegalOp<NonValueTensorLiteralOp>();
target.addIllegalOp<AtenBernoulli_FloatOp>();
target.addIllegalOp<AtenArangeStartOutOp>();
target.markUnknownOpDynamicallyLegal([&extraLibraryModuleSymTable,
&specializedNames](Operation *op) {
if (isa<OperatorOp>(op)) {
if (specializedNames.contains(cast<OperatorOp>(op).getNameAttr())) {
return false;
}
}

target.addDynamicallyLegalOp<OperatorOp>([&](OperatorOp op) {
auto opNameAttr = op.getNameAttr();
return llvm::find(legalOps, opNameAttr.str()) != legalOps.end() &&
!specializedNames.contains(opNameAttr);
});

target.markUnknownOpDynamicallyLegal([&extraLibraryModuleSymTable](
Operation *op) {
if (op->hasTrait<Torch::OpTrait::HasValueSemantics>() ||
(isa<OperatorOp>(op) &&
operatorOpHasValueSemantics(cast<OperatorOp>(op),
Expand Down Expand Up @@ -479,6 +482,7 @@ struct ReduceOpVariantsPass
} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::Torch::createReduceOpVariantsPass(StringRef extraLibrary) {
return std::make_unique<ReduceOpVariantsPass>(extraLibrary);
mlir::torch::Torch::createReduceOpVariantsPass(StringRef extraLibrary,
ArrayRef<std::string> legalOps) {
return std::make_unique<ReduceOpVariantsPass>(extraLibrary, legalOps);
}
33 changes: 33 additions & 0 deletions test/Conversion/TorchToLinalg/elementwise.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,36 @@ func.func @elementwise_sinh(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3
%0 = torch.aten.sinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
return %0 : !torch.vtensor<[3],f32>
}

// -----

// CHECK-LABEL: func.func @torch_aten_round_decimals
// CHECK: %[[VAL2:.*]] = linalg.generic
// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %{{.*}}: f32):
// CHECK-NEXT: %[[CONST_64:.*]] = arith.constant
// CHECK-NEXT: %[[CONST_10:.*]] = arith.constant 1.000000e+01
// CHECK-NEXT: %[[VAL4:.*]] = math.fpowi %[[CONST_10]], %[[CONST_64]]
// CHECK-NEXT: %[[VAL5:.*]] = arith.mulf %[[IN]], %[[VAL4]]
// CHECK-NEXT: %[[VAL6:.*]] = math.roundeven %[[VAL5]]
// CHECK-NEXT: %[[VAL7:.*]] = arith.divf %[[VAL6]], %[[VAL4]]
// CHECK-NEXT: linalg.yield %[[VAL7]]
// CHECK: %[[CAST:.*]] = tensor.cast %[[VAL2]]
// CHECK-NEXT: %[[VAL3:.*]] = torch_c.from_builtin_tensor %[[CAST]]
// CHECK-NEXT: return %[[VAL3]]
func.func @torch_aten_round_decimals(%0: !torch.vtensor<[1,1024,1024,3],f32>) -> !torch.vtensor<[1, 1024,1024,3],f32> {
%int0 = torch.constant.int 0
%1 = torch.operator "torch.aten.round.decimals"(%0, %int0) : (!torch.vtensor<[1,1024,1024,3],f32>, !torch.int) -> !torch.vtensor<[1,1024,1024,3],f32>
return %1 : !torch.vtensor<[1, 1024,1024,3],f32>
}

// -----

// Test that unhandled versions of `torch.operator` op are not legalized.
func.func @torch.prims.device_put(%arg0: !torch.vtensor<[1,77],si64>) -> !torch.vtensor<[1,77],si64> {
%cuda3A0 = torch.constant.device "cuda:0"
// expected-error @+1 {{failed to legalize operation 'torch.operator' that was explicitly marked illegal}}
%0 = torch.operator "torch.prims.device_put"(%arg0, %cuda3A0) : (!torch.vtensor<[1,77],si64>, !torch.Device) -> !torch.vtensor<[1,77],si64>
%int4 = torch.constant.int 4
%1 = torch.prims.convert_element_type %0, %int4 : !torch.vtensor<[1,77],si64>, !torch.int -> !torch.vtensor<[1,77],si64>
return %1 : !torch.vtensor<[1,77],si64>
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: torch-mlir-opt -pass-pipeline='builtin.module(torch-function-to-torch-backend-pipeline{backend-legal-ops=aten.square,aten.argmax})' -split-input-file %s | FileCheck %s
// RUN: torch-mlir-opt -pass-pipeline='builtin.module(torch-function-to-torch-backend-pipeline{backend-legal-ops=aten.square,aten.argmax,torch.aten.round.decimals})' -split-input-file %s | FileCheck %s

// CHECK-LABEL: func.func @torch.aten.square
func.func @torch.aten.square(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
Expand All @@ -25,3 +25,11 @@ func.func @torch.uint8(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[
%1 = torch.aten.reshape %arg0, %0 : !torch.tensor, !torch.list<int> -> !torch.tensor
return %1 : !torch.tensor
}

// Test that "torch.aten.round.decimals" was considered legal after explicitly specifying it in pass options.
// CHECK-LABEL: func.func @torch_aten_round_decimals
func.func @torch_aten_round_decimals(%0: !torch.vtensor<[1,1024,1024,3],f32>) -> !torch.vtensor<[1, 1024,1024,3],f32> {
%int0 = torch.constant.int 0
%1 = torch.operator "torch.aten.round.decimals"(%0, %int0) : (!torch.vtensor<[1,1024,1024,3],f32>, !torch.int) -> !torch.vtensor<[1,1024,1024,3],f32>
return %1 : !torch.vtensor<[1, 1024,1024,3],f32>
}

0 comments on commit 0f5a2dc

Please sign in to comment.