From 286443f75611766de087c95e148b666bfa0854d2 Mon Sep 17 00:00:00 2001
From: meshtag
Date: Tue, 22 Oct 2024 10:54:26 +0530
Subject: [PATCH] [uncategorized_lowerings] Add lowering for torch.aten.round.decimals

Implement the missing lowering for this op, following the same
decomposition that torch inductor uses. Also fix the data-movement and
reduce-op-variants patterns so that explicitly declared legal ops are
handled correctly.
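The decomposition scales by 10^decimals, rounds half to even, then divides
the scale back out. A standalone C++ sketch of what each payload element
computes (illustrative only; `round_decimals` is not a name introduced by
this patch):

    #include <cmath>
    #include <cstdint>

    // Equivalent of the emitted mulf -> roundeven -> divf sequence.
    // std::nearbyint rounds ties to even under the default FE_TONEAREST
    // rounding mode, matching math.roundeven.
    double round_decimals(double x, int64_t decimals) {
      double scale = std::pow(10.0, static_cast<double>(decimals));
      return std::nearbyint(x * scale) / scale;
    }

    // round_decimals(1.2345, 2) == 1.23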
Signed-off-by: Prathamesh Tagore
---
 .../Dialect/Torch/Transforms/Passes.h         |  2 +-
 .../Dialect/Torch/Transforms/Passes.td        |  3 ++
 lib/Conversion/TorchToLinalg/DataMovement.cpp |  5 +-
 .../TorchToLinalg/Uncategorized.cpp           | 46 ++++++++++++++++++-
 lib/Dialect/Torch/Transforms/Passes.cpp       |  8 ++--
 .../Torch/Transforms/ReduceOpVariants.cpp     | 20 ++++----
 .../Conversion/TorchToLinalg/elementwise.mlir | 33 +++++++++++++
 ...ch-function-to-torch-backend-pipeline.mlir | 10 +++-
 8 files changed, 110 insertions(+), 17 deletions(-)

diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h
index 13d3a8de9463..01c138c2aabd 100644
--- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h
+++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h
@@ -109,7 +109,7 @@ std::unique_ptr<OperationPass<ModuleOp>>
 createAdjustCallingConventionsPass();
 
 std::unique_ptr<OperationPass<ModuleOp>> createInlineGlobalSlotsPass();
 
 std::unique_ptr<OperationPass<func::FuncOp>>
-createReduceOpVariantsPass(StringRef extraLibrary);
+createReduceOpVariantsPass(StringRef extraLibrary, ArrayRef<std::string> legalOps = {});
 
 std::unique_ptr<OperationPass<func::FuncOp>> createMaximizeValueSemanticsPass();
diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.td b/include/torch-mlir/Dialect/Torch/Transforms/Passes.td
index e6b19201e85b..b3f55b64d74a 100644
--- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.td
+++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.td
@@ -148,6 +148,9 @@ def ReduceOpVariants : Pass<"torch-reduce-op-variants", "func::FuncOp"> {
   let options = [
     Option<"extraLibrary", "extra-library", "std::string",
            /*default=*/"", "MLIR module for verifying custom op value semantics">,
+    ListOption<"legalOps", "legal-ops", "std::string",
+               "Comma-separated list of operation names that should be considered legal",
+               "llvm::cl::ZeroOrMore">
   ];
   let description = [{
     Replaces ops with other ops to reduce the number of variants that
diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp
index a18c0bae01fc..a2a738bf414d 100644
--- a/lib/Conversion/TorchToLinalg/DataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -2874,7 +2874,10 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
   // Rewrite all special sparse conversions hidden as operators.
   target.addDynamicallyLegalOp<Torch::OperatorOp>([&](Torch::OperatorOp op) {
-    return !ConvertSparseOperatorOp::isSparsePrimitive(op.getNameAttr());
+    // Note: torch.operator ops that are not sparse primitives must keep
+    // their default type-based legality; this callback must not mark them
+    // legal unconditionally.
+    return !ConvertSparseOperatorOp::isSparsePrimitive(op.getNameAttr()) &&
+           typeConverter.isLegal(op);
   });
   patterns.add<ConvertSparseOperatorOp>(typeConverter, context);
 }
diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
index 0f6f92bd7c2c..bb6286a008d4 100644
--- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp
+++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -1557,6 +1557,43 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
                                     threshold);
   }
 
+  if (auto operatorOp = dyn_cast<OperatorOp>(op)) {
+    // We do not yet implement lowerings for other variants of this op.
+    if (operatorOp.getNameAttr().str() != "torch.aten.round.decimals")
+      return nullptr;
+
+    // Lower the op using the same decomposition as described here:
+    // https://github.com/pytorch/pytorch/blob/de4c2a3b4e89d96334dc678d1c3f2ae51a6630a0/torch/_inductor/decomposition.py#L223.
+    // Note that `aten.round` normally converts to `math.roundeven`; we emit
+    // `math.roundeven` directly here because `aten.round` cannot operate on
+    // a single input tensor element, which is what we get as the payload
+    // argument.
+
+    Location loc = op->getLoc();
+    Type i64Type = b.getI64Type();
+
+    auto torchIntOp = dyn_cast_or_null<Torch::ConstantIntOp>(
+        operatorOp.getOperands().back().getDefiningOp());
+    if (!torchIntOp)
+      return nullptr;
+    int64_t numDecimalsArg = torchIntOp.getValue();
+
+    Value inputTensorElem = payloadArgs[0];
+    Type inputTensorElemType = inputTensorElem.getType();
+
+    auto numDecimals = b.create<arith::ConstantOp>(
+        loc, i64Type, IntegerAttr::get(i64Type, numDecimalsArg));
+    auto const10 = b.create<arith::ConstantOp>(
+        loc, inputTensorElemType, FloatAttr::get(inputTensorElemType, 10));
+    auto tenPowDecimals = b.create<math::FPowIOp>(loc, const10, numDecimals);
+
+    auto scaledInput =
+        b.create<arith::MulFOp>(loc, inputTensorElem, tenPowDecimals);
+    auto roundOp = b.create<math::RoundEvenOp>(loc, scaledInput);
+    auto res = b.create<arith::DivFOp>(loc, roundOp, tenPowDecimals);
+    return res;
+  }
+
   op->emitError("unimplemented lowering in "
                 "createLinalgPayloadCalculationForElementwiseOp");
   return nullptr;
@@ -1616,9 +1653,14 @@ class ConvertElementwiseOp : public ConversionPattern {
             AtenFillScalarOp, AtenFillTensorOp, AtenAtanOp, AtenAcosOp,
             AtenAtanhOp, AtenAcoshOp, AtenAsinOp, AtenAsinhOp, AtenRealOp,
             AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp,
-            AtenQuantizePerTensorOp, AtenIscloseOp>(op))
+            AtenQuantizePerTensorOp, AtenIscloseOp, OperatorOp>(op))
       return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
 
+    if (auto operatorOp = dyn_cast<OperatorOp>(op))
+      if (operatorOp.getNameAttr().str() != "torch.aten.round.decimals")
+        return rewriter.notifyMatchFailure(op,
+                                           "not a supported elementwise op");
+
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
@@ -3375,7 +3417,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
       AtenTrilOp, AtenRemainderScalarOp, AtenRemainderTensorOp,
       AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
       AtenRealOp, AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp,
-      AtenQuantizePerTensorOp, AtenIscloseOp>();
+      AtenQuantizePerTensorOp, AtenIscloseOp, OperatorOp>();
   patterns.add<ConvertElementwiseOp>(typeConverter, context);
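A note on the payload decomposition above: `decimals` may also be negative
(PyTorch's round then rounds to tens, hundreds, and so on), and `math.fpowi`
accepts negative exponents, so the same mulf/roundeven/divf sequence covers
that case. A worked example, assuming the round-half-to-even semantics
emitted above:

    round_decimals(1234.5, -2)
      = roundeven(1234.5 * 10^-2) / 10^-2
      = roundeven(12.345) * 100
      = 12.0 * 100
      = 1200.0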
diff --git a/lib/Dialect/Torch/Transforms/Passes.cpp b/lib/Dialect/Torch/Transforms/Passes.cpp
index 846470202c15..6eb09e4239e5 100644
--- a/lib/Dialect/Torch/Transforms/Passes.cpp
+++ b/lib/Dialect/Torch/Transforms/Passes.cpp
@@ -70,8 +70,8 @@ void mlir::torch::Torch::createTorchScriptModuleToTorchBackendPipeline(
 void mlir::torch::Torch::createTorchDynamoExportToTorchBackendPipeline(
     OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
-  pm.addNestedPass<func::FuncOp>(
-      createReduceOpVariantsPass(options.extraLibrary));
+  pm.addNestedPass<func::FuncOp>(createReduceOpVariantsPass(
+      options.extraLibrary, options.backendLegalOps));
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
   if (options.decompose) {
     pm.addNestedPass<func::FuncOp>(
@@ -161,8 +161,8 @@ void mlir::torch::Torch::createTorchSimplificationPipeline(
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
   pm.addNestedPass<func::FuncOp>(createRecomposeComplexOpsPass());
   // Reduce variants of ops to a smaller set of primitives.
-  pm.addNestedPass<func::FuncOp>(
-      createReduceOpVariantsPass(options.extraLibrary));
+  pm.addNestedPass<func::FuncOp>(createReduceOpVariantsPass(
+      options.extraLibrary, options.backendLegalOps));
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
   // Remove dead global slots.
   pm.addPass(createSymbolDCEPass());
diff --git a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp
index 5712b66f6c1d..2dd84f6d5b7e 100644
--- a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp
+++ b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp
@@ -403,8 +403,9 @@ namespace {
 struct ReduceOpVariantsPass
     : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
   ReduceOpVariantsPass() = default;
-  ReduceOpVariantsPass(StringRef extraLibrary) {
+  ReduceOpVariantsPass(StringRef extraLibrary, ArrayRef<std::string> legalOps) {
     this->extraLibrary = extraLibrary.str();
+    this->legalOps = legalOps;
   }
   void runOnOperation() override {
     MLIRContext *context = &getContext();
@@ -439,13 +440,15 @@ struct ReduceOpVariantsPass
     target.addIllegalOp<NonValueTensorLiteralOp>();
     target.addIllegalOp<AtenBernoulli_FloatOp>();
     target.addIllegalOp<AtenArangeStartOutOp>();
+
+    target.addDynamicallyLegalOp<OperatorOp>([&](OperatorOp op) {
+      auto opNameAttr = op.getNameAttr();
+      return llvm::find(legalOps, opNameAttr.str()) != legalOps.end() &&
+             !specializedNames.contains(opNameAttr);
+    });
+
     target.markUnknownOpDynamicallyLegal([&extraLibraryModuleSymTable,
                                           &specializedNames](Operation *op) {
-      if (isa<OperatorOp>(op)) {
-        if (specializedNames.contains(cast<OperatorOp>(op).getNameAttr())) {
-          return false;
-        }
-      }
       if (op->hasTrait<Torch::OpTrait::HasValueSemantics>() ||
           (isa<OperatorOp>(op) &&
            operatorOpHasValueSemantics(cast<OperatorOp>(op),
@@ -479,6 +482,7 @@ struct ReduceOpVariantsPass
 } // namespace
 
 std::unique_ptr<OperationPass<func::FuncOp>>
-mlir::torch::Torch::createReduceOpVariantsPass(StringRef extraLibrary) {
-  return std::make_unique<ReduceOpVariantsPass>(extraLibrary);
+mlir::torch::Torch::createReduceOpVariantsPass(StringRef extraLibrary,
+                                               ArrayRef<std::string> legalOps) {
+  return std::make_unique<ReduceOpVariantsPass>(extraLibrary, legalOps);
 }
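With the new overload of createReduceOpVariantsPass, the set of
torch.operator names to treat as legal can also be supplied when assembling
a pipeline by hand. A minimal C++ sketch, assuming the usual torch-mlir
pass-manager setup (`pm` is an existing OpPassManager; includes omitted):

    // Keep "torch.aten.round.decimals" legal here so that the TorchToLinalg
    // elementwise lowering added by this patch can convert it later.
    std::vector<std::string> legalOps = {"torch.aten.round.decimals"};
    pm.addNestedPass<func::FuncOp>(
        mlir::torch::Torch::createReduceOpVariantsPass(
            /*extraLibrary=*/"", legalOps));

From the command line, the same list is reachable through the new
ListOption, e.g. -torch-reduce-op-variants="legal-ops=torch.aten.round.decimals".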
diff --git a/test/Conversion/TorchToLinalg/elementwise.mlir b/test/Conversion/TorchToLinalg/elementwise.mlir
index aa2be74f5d7e..4b08eeecb36f 100644
--- a/test/Conversion/TorchToLinalg/elementwise.mlir
+++ b/test/Conversion/TorchToLinalg/elementwise.mlir
@@ -102,3 +102,36 @@ func.func @elementwise_sinh(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> {
   %0 = torch.aten.sinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
   return %0 : !torch.vtensor<[3],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @torch_aten_round_decimals
+// CHECK:      %[[VAL2:.*]] = linalg.generic
+// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %{{.*}}: f32):
+// CHECK-NEXT:   %[[CONST_64:.*]] = arith.constant
+// CHECK-NEXT:   %[[CONST_10:.*]] = arith.constant 1.000000e+01
+// CHECK-NEXT:   %[[VAL4:.*]] = math.fpowi %[[CONST_10]], %[[CONST_64]]
+// CHECK-NEXT:   %[[VAL5:.*]] = arith.mulf %[[IN]], %[[VAL4]]
+// CHECK-NEXT:   %[[VAL6:.*]] = math.roundeven %[[VAL5]]
+// CHECK-NEXT:   %[[VAL7:.*]] = arith.divf %[[VAL6]], %[[VAL4]]
+// CHECK-NEXT:   linalg.yield %[[VAL7]]
+// CHECK:      %[[CAST:.*]] = tensor.cast %[[VAL2]]
+// CHECK-NEXT: %[[VAL3:.*]] = torch_c.from_builtin_tensor %[[CAST]]
+// CHECK-NEXT: return %[[VAL3]]
+func.func @torch_aten_round_decimals(%0: !torch.vtensor<[1,1024,1024,3],f32>) -> !torch.vtensor<[1,1024,1024,3],f32> {
+  %int0 = torch.constant.int 0
+  %1 = torch.operator "torch.aten.round.decimals"(%0, %int0) : (!torch.vtensor<[1,1024,1024,3],f32>, !torch.int) -> !torch.vtensor<[1,1024,1024,3],f32>
+  return %1 : !torch.vtensor<[1,1024,1024,3],f32>
+}
+
+// -----
+
+// Test that unhandled variants of the `torch.operator` op are not legalized.
+func.func @torch.prims.device_put(%arg0: !torch.vtensor<[1,77],si64>) -> !torch.vtensor<[1,77],si64> {
+  %cuda3A0 = torch.constant.device "cuda:0"
+  // expected-error @+1 {{failed to legalize operation 'torch.operator' that was explicitly marked illegal}}
+  %0 = torch.operator "torch.prims.device_put"(%arg0, %cuda3A0) : (!torch.vtensor<[1,77],si64>, !torch.Device) -> !torch.vtensor<[1,77],si64>
+  %int4 = torch.constant.int 4
+  %1 = torch.prims.convert_element_type %0, %int4 : !torch.vtensor<[1,77],si64>, !torch.int -> !torch.vtensor<[1,77],si64>
+  return %1 : !torch.vtensor<[1,77],si64>
+}
diff --git a/test/Dialect/Torch/torch-function-to-torch-backend-pipeline.mlir b/test/Dialect/Torch/torch-function-to-torch-backend-pipeline.mlir
index 2cc024ee40eb..850b782faef3 100644
--- a/test/Dialect/Torch/torch-function-to-torch-backend-pipeline.mlir
+++ b/test/Dialect/Torch/torch-function-to-torch-backend-pipeline.mlir
@@ -1,4 +1,4 @@
-// RUN: torch-mlir-opt -pass-pipeline='builtin.module(torch-function-to-torch-backend-pipeline{backend-legal-ops=aten.square,aten.argmax})' -split-input-file %s | FileCheck %s
+// RUN: torch-mlir-opt -pass-pipeline='builtin.module(torch-function-to-torch-backend-pipeline{backend-legal-ops=aten.square,aten.argmax,torch.aten.round.decimals})' -split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: func.func @torch.aten.square
 func.func @torch.aten.square(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
@@ -25,3 +25,11 @@ func.func @torch.uint8(
   %1 = torch.aten.reshape %arg0, %0 : !torch.tensor, !torch.list<int> -> !torch.tensor
   return %1 : !torch.tensor
 }
+
+// Test that "torch.aten.round.decimals" is treated as legal after explicitly specifying it in the pass options.
+// CHECK-LABEL: func.func @torch_aten_round_decimals
+func.func @torch_aten_round_decimals(%0: !torch.vtensor<[1,1024,1024,3],f32>) -> !torch.vtensor<[1,1024,1024,3],f32> {
+  %int0 = torch.constant.int 0
+  %1 = torch.operator "torch.aten.round.decimals"(%0, %int0) : (!torch.vtensor<[1,1024,1024,3],f32>, !torch.int) -> !torch.vtensor<[1,1024,1024,3],f32>
+  return %1 : !torch.vtensor<[1,1024,1024,3],f32>
+}
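One behavioral detail worth noting about the `math.roundeven` the tests
check for: it rounds ties to even, matching PyTorch's `torch.round`, so
halfway values do not always round up:

    round_decimals(0.125, 2)
      = roundeven(0.125 * 10^2) / 10^2
      = roundeven(12.5) / 100
      = 12.0 / 100
      = 0.12   (not 0.13)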