diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index c3e0141530e4..56ab43a23c50 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -14091,6 +14091,36 @@ def Torch_AtenGridSamplerOp : Torch_Op<"aten.grid_sampler", [ }]; } +def Torch_Aten_TrilinearOp : Torch_Op<"aten._trilinear", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::_trilinear : (Tensor, Tensor, Tensor, int[], int[], int[], int[], int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$i1, + AnyTorchTensorType:$i2, + AnyTorchTensorType:$i3, + AnyTorchListOfTorchIntType:$expand1, + AnyTorchListOfTorchIntType:$expand2, + AnyTorchListOfTorchIntType:$expand3, + AnyTorchListOfTorchIntType:$sumdim, + Torch_IntType:$unroll_dim + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult Aten_TrilinearOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 8, 1); + } + void Aten_TrilinearOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 8, 1); + } + }]; +} + def Torch_Aten__Contains__StrOp : Torch_Op<"aten.__contains__.str", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index f2963f7c803d..58f7565e9096 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -8852,6 +8852,112 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.linear(%arg0, %arg1, %arg2) : (!torch.list, !torch.list, !torch.optional>) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten._trilinear\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.list, %arg7: !torch.int) -> !torch.list {\n" +" %int3 = torch.constant.int 3\n" +" %int-1 = torch.constant.int -1\n" +" %str = torch.constant.str \"AssertionError: number of dimensions must match\"\n" +" %str_0 = torch.constant.str \"expand dimension {} is out of bounds for input of shape {}\"\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str_1 = torch.constant.str \"AssertionError: \"\n" +" %str_2 = torch.constant.str \"unroll_dim must be in [0, {}]\"\n" +" %false = torch.constant.bool false\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.len.t %arg3 : !torch.list -> !torch.int\n" +" %2 = torch.aten.add.int %0, %1 : !torch.int, !torch.int -> !torch.int\n" +" %3 = torch.aten.ge.int %arg7, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" %23 = torch.aten.lt.int %arg7, %2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %23 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" %23 = torch.aten.sub.int %2, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %24 = torch.aten.format(%str_2, %23) : !torch.str, !torch.int -> !torch.str\n" +" %25 = torch.aten.add.str 
%str_1, %24 : !torch.str, !torch.str -> !torch.str\n" +" torch.prim.RaiseException %25, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %5 = call @__torch__.torch.jit._shape_functions._copy(%arg0) : (!torch.list) -> !torch.list\n" +" %6 = call @__torch__.torch.jit._shape_functions._copy(%arg1) : (!torch.list) -> !torch.list\n" +" %7 = call @__torch__.torch.jit._shape_functions._copy(%arg2) : (!torch.list) -> !torch.list\n" +" %8 = torch.prim.ListConstruct %5, %6, %7 : (!torch.list, !torch.list, !torch.list) -> !torch.list>\n" +" %9 = torch.prim.ListConstruct %arg3, %arg4, %arg5 : (!torch.list, !torch.list, !torch.list) -> !torch.list>\n" +" torch.prim.Loop %int3, %true, init() {\n" +" ^bb0(%arg8: !torch.int):\n" +" %23 = torch.aten.__getitem__.t %9, %arg8 : !torch.list>, !torch.int -> !torch.list\n" +" %24 = torch.aten.__getitem__.t %8, %arg8 : !torch.list>, !torch.int -> !torch.list\n" +" %25 = torch.aten.len.t %24 : !torch.list -> !torch.int\n" +" %26 = torch.aten.len.t %23 : !torch.list -> !torch.int\n" +" torch.prim.Loop %26, %true, init() {\n" +" ^bb0(%arg9: !torch.int):\n" +" %27 = torch.aten.__getitem__.t %23, %arg9 : !torch.list, !torch.int -> !torch.int\n" +" %28 = torch.aten.le.int %27, %25 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %28 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" %30 = torch.aten.__getitem__.t %8, %arg8 : !torch.list>, !torch.int -> !torch.list\n" +" %31 = torch.aten.format(%str_0, %27, %30) : !torch.str, !torch.int, !torch.list -> !torch.str\n" +" %32 = torch.aten.add.str %str_1, %31 : !torch.str, !torch.str -> !torch.str\n" +" torch.prim.RaiseException %32, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %29 = torch.aten.__getitem__.t %8, %arg8 : !torch.list>, !torch.int -> !torch.list\n" +" torch.aten.insert.t %29, %27, %int1 : !torch.list, !torch.int, !torch.int\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %10 = torch.aten.len.t %5 : !torch.list -> !torch.int\n" +" %11 = torch.aten.len.t %6 : !torch.list -> !torch.int\n" +" %12 = torch.aten.eq.int %10, %11 : !torch.int, !torch.int -> !torch.bool\n" +" %13 = torch.prim.If %12 -> (!torch.bool) {\n" +" %23 = torch.aten.len.t %6 : !torch.list -> !torch.int\n" +" %24 = torch.aten.len.t %7 : !torch.list -> !torch.int\n" +" %25 = torch.aten.eq.int %23, %24 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %25 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %13 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %14 = call @__torch__.torch.jit._shape_functions.broadcast_three(%5, %6, %7) : (!torch.list, !torch.list, !torch.list) -> !torch.list\n" +" %15 = torch.prim.ListConstruct %false : (!torch.bool) -> !torch.list\n" +" %16 = torch.aten.len.t %14 : !torch.list -> !torch.int\n" +" %17 = torch.operator \"aten.mul.left_t\"(%15, %16) : (!torch.list, !torch.int) -> !torch.list \n" +" %18 = torch.aten.len.t %arg6 : !torch.list -> !torch.int\n" +" torch.prim.Loop %18, %true, init() {\n" +" ^bb0(%arg8: !torch.int):\n" +" %23 = torch.aten.__getitem__.t %arg6, %arg8 : !torch.list, !torch.int -> !torch.int\n" +" %24 = torch.aten._set_item.t %17, %23, %true : !torch.list, !torch.int, !torch.bool -> !torch.list\n" +" torch.prim.Loop.condition 
%true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %19 = torch.aten.len.t %14 : !torch.list -> !torch.int\n" +" %20 = torch.aten.sub.int %19, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %21 = torch.aten.__range_length %20, %int-1, %int-1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" %22 = torch.prim.Loop %21, %true, init(%14) {\n" +" ^bb0(%arg8: !torch.int, %arg9: !torch.list):\n" +" %23 = torch.aten.__derive_index %arg8, %20, %int-1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" %24 = torch.aten.__getitem__.t %17, %23 : !torch.list, !torch.int -> !torch.bool\n" +" %25 = torch.prim.If %24 -> (!torch.list) {\n" +" %26 = func.call @__torch__.torch.jit._shape_functions._reduce_along_dim(%arg9, %23, %false) : (!torch.list, !torch.int, !torch.bool) -> !torch.list\n" +" torch.prim.If.yield %26 : !torch.list\n" +" } else {\n" +" torch.prim.If.yield %arg9 : !torch.list\n" +" }\n" +" torch.prim.Loop.condition %true, iter(%25 : !torch.list)\n" +" } : (!torch.int, !torch.bool, !torch.list) -> !torch.list\n" +" return %22 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.scaled_dot_product_attention\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.optional>, %arg4: !torch.float, %arg5: !torch.bool, %arg6: !torch.optional, %arg7: !torch.bool) -> !torch.list {\n" " %int-1 = torch.constant.int -1\n" " %0 = torch.aten.__getitem__.t %arg2, %int-1 : !torch.list, !torch.int -> !torch.int\n" @@ -15206,6 +15312,15 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._trilinear\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.list, %arg7: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple -> !torch.int, !torch.int\n" +" %3 = torch.prim.ListConstruct %0#0, %1#0, %2#0 : (!torch.int, !torch.int, !torch.int) -> !torch.list>\n" +" %4 = torch.prim.ListConstruct %0#1, %1#1, %2#1 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" +" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%3, %4) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %5 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.cat\"(%arg0: !torch.list>, %arg1: !torch.int) -> !torch.int {\n" " %true = torch.constant.bool true\n" " %none = torch.constant.none\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 9b24d0e959f3..29c5679573f9 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -9,6 +9,7 @@ #include "PassDetail.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -399,9 +400,9 @@ static Value collapseDimForMatmul(PatternRewriter &rewriter, Location loc, auto inputType = cast(input.getType()); auto inputRank = batchDimsLength + contractingDimsLength + otherDimsLength + reduceDimsLength; - 
SmallVector<Value> inputShapeTensor;
+  SmallVector<OpFoldResult> inputShapeTensor;
   for (auto i = 0; i < inputRank; ++i) {
-    inputShapeTensor.emplace_back(rewriter.create<AtenSizeIntOp>(
+    inputShapeTensor.emplace_back(rewriter.createOrFold<AtenSizeIntOp>(
         loc, input,
         rewriter.create<Torch::ConstantIntOp>(loc,
                                               rewriter.getI64IntegerAttr(i))));
@@ -412,13 +413,23 @@ static Value collapseDimForMatmul(PatternRewriter &rewriter, Location loc,
       rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
   auto dimOffset = 0;
 
+  auto materializeIntFold = [&](OpFoldResult thing) {
+    if (auto attr = dyn_cast<Attribute>(thing)) {
+      Value result = rewriter.create<Torch::ConstantIntOp>(
+          loc, cast<IntegerAttr>(attr));
+      return result;
+    }
+    return cast<Value>(thing);
+  };
+
   auto appendDims = [&](int64_t dimLength) {
-    Value prod = constOne;
+    OpFoldResult prod = getAsOpFoldResult(constOne);
     for (auto i = 0; i < dimLength; ++i) {
-      prod = rewriter.create<AtenMulIntOp>(loc, prod,
-                                           inputShapeTensor[i + dimOffset]);
+      prod = rewriter.createOrFold<AtenMulIntOp>(
+          loc, materializeIntFold(prod),
+          materializeIntFold(inputShapeTensor[i + dimOffset]));
     }
-    outShapeTensor.emplace_back(prod);
+    outShapeTensor.emplace_back(materializeIntFold(prod));
     dimOffset += dimLength;
   };
 
@@ -570,21 +581,32 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc,
   Type outputDType = lhsType.hasDtype() ? lhsType.getOptionalDtype()
                                         : rhsType.getOptionalDtype();
 
+  auto materializeIntFold = [&](OpFoldResult thing) {
+    if (auto attr = dyn_cast<Attribute>(thing)) {
+      Value result = rewriter.create<Torch::ConstantIntOp>(
+          loc, cast<IntegerAttr>(attr));
+      return result;
+    }
+    return cast<Value>(thing);
+  };
+
   llvm::SmallDenseMap<char, Value> lhsDimShapeMap;
   for (size_t idx = 0; idx < lhsTokens.size(); ++idx) {
     char d = lhsTokens[idx];
-    lhsDimShapeMap[d] = rewriter.create<AtenSizeIntOp>(
+    OpFoldResult lhsFold = rewriter.createOrFold<AtenSizeIntOp>(
         loc, lhs,
         rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(idx)));
+    lhsDimShapeMap[d] = materializeIntFold(lhsFold);
   }
   llvm::SmallDenseMap<char, Value> rhsDimShapeMap;
   for (size_t idx = 0; idx < rhsTokens.size(); ++idx) {
     char d = rhsTokens[idx];
-    rhsDimShapeMap[d] = rewriter.create<AtenSizeIntOp>(
+    OpFoldResult rhsFold = rewriter.createOrFold<AtenSizeIntOp>(
         loc, rhs,
         rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(idx)));
+    rhsDimShapeMap[d] = materializeIntFold(rhsFold);
   }
 
   // parse batch, contracting, other, reduce dims of lhs and rhs
@@ -604,8 +626,9 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc,
       bool lhsContains = lhsDimShapeMap.count(d) > 0;
       bool rhsContains = rhsDimShapeMap.count(d) > 0;
       if (lhsContains && rhsContains) {
-        outDimShapeMap[d] = rewriter.create<Torch::PrimMaxIntOp>(
+        OpFoldResult out = rewriter.createOrFold<Torch::PrimMaxIntOp>(
             loc, lhsDimShapeMap[d], rhsDimShapeMap[d]);
+        outDimShapeMap[d] = materializeIntFold(out);
       } else if (lhsContains) {
         outDimShapeMap[d] = lhsDimShapeMap[d];
       } else if (rhsContains) {
@@ -1973,6 +1996,125 @@ class DecomposeAtenEinsumOp : public OpRewritePattern<AtenEinsumOp> {
 };
 } // namespace
 
+namespace {
+// Trilinear Einstein summation, decomposed to:
+//   (i1.unsqueeze(expand1) * i2.unsqueeze(expand2) * i3.unsqueeze(expand3))
+//       .sum(sumdim)
+// The unroll_dim operand does not impact the output of the operation, so
+// it is ignored.
+
+class DecomposeAten_TrilinearOp : public OpRewritePattern<Aten_TrilinearOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(Aten_TrilinearOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+
+    Value input1 = op.getI1();
+    Value input2 = op.getI2();
+    Value input3 = op.getI3();
+
+    // Expansions
+    SmallVector<int64_t> expand1;
+    SmallVector<int64_t> expand2;
+    SmallVector<int64_t> expand3;
+    if (!matchPattern(op.getExpand1(), m_TorchListOfConstantInts(expand1))) {
+      return rewriter.notifyMatchFailure(op, "expand1 should be constant");
+    }
+    if (!matchPattern(op.getExpand2(), m_TorchListOfConstantInts(expand2))) {
+      return rewriter.notifyMatchFailure(op, "expand2 should be constant");
+    }
+    if (!matchPattern(op.getExpand3(), m_TorchListOfConstantInts(expand3))) {
+      return rewriter.notifyMatchFailure(op, "expand3 should be constant");
+    }
+
+    SmallVector<int64_t> sumDim;
+    if (!matchPattern(op.getSumdim(), m_TorchListOfConstantInts(sumDim))) {
+      return rewriter.notifyMatchFailure(op, "sumDim should be constant");
+    }
+
+    // Check if there are any dimensions that intersect between expand1,
+    // expand2, and expand3.
+    int64_t totalDims =
+        cast<BaseTensorType>(input1.getType()).getSizes().size() +
+        expand1.size();
+    if (sharedExpandDims(totalDims, expand1, expand2, expand3, sumDim)) {
+      // pytorch issue filed: https://github.com/pytorch/pytorch/issues/138353
+      // TODO: Remove warning when issue gets resolved.
+      op->emitWarning("aten::_trilinear implementation in this case is "
+                      "non-functional (returns an empty dimension). We will "
+                      "intentionally deviate from this behavior.");
+    }
+
+    // Apply unsqueeze to respective input tensors at the specified dimensions
+    SmallVector<int64_t> sortedExpand1 = expand1;
+    std::sort(sortedExpand1.begin(), sortedExpand1.end());
+    for (auto expand : sortedExpand1) {
+      Value expandDim = rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(expand));
+      input1 = *unsqueezeTensor(rewriter, op, input1, expandDim);
+    }
+    SmallVector<int64_t> sortedExpand2 = expand2;
+    std::sort(sortedExpand2.begin(), sortedExpand2.end());
+    for (auto expand : sortedExpand2) {
+      Value expandDim = rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(expand));
+      input2 = *unsqueezeTensor(rewriter, op, input2, expandDim);
+    }
+    SmallVector<int64_t> sortedExpand3 = expand3;
+    std::sort(sortedExpand3.begin(), sortedExpand3.end());
+    for (auto expand : sortedExpand3) {
+      Value expandDim = rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(expand));
+      input3 = *unsqueezeTensor(rewriter, op, input3, expandDim);
+    }
+
+    // Apply multiplication operation.
+    auto mul1 =
+        rewriter.create<AtenMulTensorOp>(loc, op.getType(), input1, input2);
+    auto mul2 =
+        rewriter.create<AtenMulTensorOp>(loc, op.getType(), mul1, input3);
+
+    // Apply sum operation.
+    // Parse sumDim in descending order to avoid any issues with the
+    // dimensions being removed.
+    Value result = mul2;
+    SmallVector<int64_t> sortedSumDims = sumDim;
+    std::sort(sortedSumDims.rbegin(), sortedSumDims.rend());
+    for (int64_t dim : sortedSumDims) {
+      Value dimValue = rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(dim));
+      result =
+          createSumAlongDimension(rewriter, loc, op, result, dimValue, false);
+    }
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+
+private:
+  // Determine if there are any dimensions that intersect between expand1,
+  // expand2, and expand3.
+  bool sharedExpandDims(const int64_t &totalDims,
+                        const SmallVector<int64_t> &expand1,
+                        const SmallVector<int64_t> &expand2,
+                        const SmallVector<int64_t> &expand3,
+                        const SmallVector<int64_t> &sumDim) const {
+    for (int64_t i = 0; i < totalDims; ++i) {
+      if (!contains(sumDim, i) && contains(expand1, i) &&
+          contains(expand2, i) && contains(expand3, i)) {
+        return true;
+      }
+    }
+    return false;
+  }
+  bool contains(const SmallVector<int64_t> &vec, int64_t value) const {
+    return std::find(vec.begin(), vec.end(), value) != vec.end();
+  }
+};
+} // namespace
+
 namespace {
 // Calculate the trace of the input tensor as the sum over its diagonal
 // elements. This computation is performed as:
@@ -9928,6 +10070,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAten_TrilinearOp>(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index ebc43faa595c..bba4a8dde40c 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -400,6 +400,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
+  target.addIllegalOp<Aten_TrilinearOp>();
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 1755806a0e66..275536601efc 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -29,6 +29,10 @@
     "DeformConv2D_basic",
     "ReduceAnyDimFloatModule_basic",
     "UnfoldModule_basic",
+    # _trilinear is an implementation of einsum, but it produces a size-zero
+    # dimension when a dimension appears in all three expand lists and not in
+    # the sumdim list. This is a bug in PyTorch's implementation of _trilinear.
+ "Aten_TrilinearModuleZerodDimBug_basic", } if torch_version_for_comparison() < version.parse("2.5.0.dev"): @@ -403,6 +407,8 @@ "AtenMmQMixedSigni8_basic", "AtenMmQint8_basic", "AtenMmQuint8_basic", + "Aten_TrilinearModuleVaryingRanks_basic", + "Aten_TrilinearModuleZerodDimBug_basic", "QuantizedReluInt32_basic", "QuantizedReluInt8_basic", "QuantizedReluUint8_basic", @@ -548,6 +554,9 @@ "_SoftmaxModule_basic", "UpSampleNearest2dDynamicFactor_basic", "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic", + "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic", + "Aten_TrilinearModuleSumAllDims_basic", + "Aten_TrilinearModuleSumdims_basic", } FX_IMPORTER_STABLEHLO_XFAIL_SET = { @@ -658,6 +667,8 @@ "AtenTopKModule_basic", "AtenTopKSmallestModule_basic", "Aten_EmbeddingBagExample_basic", + "Aten_TrilinearModuleVaryingRanks_basic", + "Aten_TrilinearModuleZerodDimBug_basic", "AvgPool2dDivisorOverrideModule_basic", "BernoulliTensorModule_basic", "BincountMinlengthModule_basic", @@ -945,11 +956,6 @@ "AtenItemIntOpModule_basic", "CrossEntropyLossModule_basic", "CrossEntropyLossNoReductionModule_basic", - "EinsumStaticContractRhsModule_basic", - "EinsumStaticFourDimensionModule_basic", - "EinsumStaticModule_basic", - "EinsumStaticWithEllipsisSlicingAndBroadcastModule_basic", - "EinsumStaticWithEllipsisSlicingModule_basic", "ElementwiseExpm1IntModule_basic", "ElementwiseExpm1Module_basic", "InterpolateDynamicModule_sizes_nearest", @@ -996,6 +1002,9 @@ # materialization callback produced value of incorrect type failed "ReduceMaxAlongDimUnsignedInt_basic", "ReduceMinAlongDimUnsignedInt_basic", + "Aten_TrilinearModuleSumdims_basic", + "Aten_TrilinearModuleSumAllDims_basic", + "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic", } STABLEHLO_PASS_SET = { @@ -3254,6 +3263,12 @@ "Unfold_Module_Rank_Zero_Size_Zero_basic", "Unfold_Module_Dynamic_basic", "ViewDtypeStaticModule_basic", + "Aten_TrilinearModule_basic", + "Aten_TrilinearModuleSumdims_basic", + "Aten_TrilinearModuleSumAllDims_basic", + "Aten_TrilinearModuleVaryingRanks_basic", + "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic", + "Aten_TrilinearModuleZerodDimBug_basic", } if torch_version_for_comparison() < version.parse("2.3.0.dev"): @@ -4075,6 +4090,12 @@ "AtenSubFloatModule_basic", "AtenTopKModule_basic", "AtenTopKSmallestModule_basic", + "Aten_TrilinearModule_basic", + "Aten_TrilinearModuleSumdims_basic", + "Aten_TrilinearModuleSumAllDims_basic", + "Aten_TrilinearModuleVaryingRanks_basic", + "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic", + "Aten_TrilinearModuleZerodDimBug_basic", "AtenTrilModule_basic", "AtenTrilWithNegDiagonalModule_basic", "AtenTrilWithPosDiagonalModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index d632e9815443..5b772fe26ee9 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -1286,6 +1286,44 @@ def aten〇unflatten〇int〡shape(self: List[int], dim: int, sizes: List[int]) def aten〇linear〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None) -> List[int]: return upstream_shape_functions.linear(input, weight, bias) +@check_shape_function([ + Invocation(TensorOfShape(3, 3, 3), TensorOfShape(3, 3, 3), TensorOfShape(3, 3, 3), [], [], [], [], 0), # Basic case + Invocation(TensorOfShape(4, 5, 6), 
TensorOfShape(4, 5, 6), TensorOfShape(4, 5, 6), [1], [0], [0], [], 2), # Expansions w/ Non-Zero unroll_dim + Invocation(TensorOfShape(3, 3, 3), TensorOfShape(3, 3, 3), TensorOfShape(3, 3, 3), [1, 2], [1, 2], [1, 2], [1, 2], 0), # Multiple expansions + Invocation(TensorOfShape(3, 3, 3), TensorOfShape(3, 3, 3), TensorOfShape(3, 3, 3), [1, 2], [2, 1], [1, 2], [1, 2], 0), # Unordered expansion + ErrorInvocation(TensorOfShape(4, 5, 1), TensorOfShape(4, 5, 3), TensorOfShape(1, 5, 3), [], [], [0], [2], 0), # Num dimensions don't match +]) +def aten〇_trilinear〡shape(i1: List[int], i2: List[int], i3: List[int], expand1: List[int], expand2: List[int], expand3: List[int], sumdim: List[int], unroll_dim: int = 1) -> List[int]: + total_dims = len(i1) + len(expand1) + + assert unroll_dim >= 0 and unroll_dim < total_dims, f"unroll_dim must be in [0, {total_dims - 1}]" + + i1_copy = upstream_shape_functions._copy(i1) + i2_copy = upstream_shape_functions._copy(i2) + i3_copy = upstream_shape_functions._copy(i3) + + # Expand dimensions based on args + inputs = [i1_copy, i2_copy, i3_copy] + expands = [expand1, expand2, expand3] + for index, expand in enumerate(expands): + size = len(inputs[index]) + for dim in expand: + assert dim <= size, f"expand dimension {dim} is out of bounds for input of shape {inputs[index]}" + inputs[index].insert(dim, 1) + + assert len(i1_copy) == len(i2_copy) == len(i3_copy), "number of dimensions must match" + + output_shape = upstream_shape_functions.broadcast_three(i1_copy, i2_copy, i3_copy) + sumdim_bools = [False] * len(output_shape) + for dim in sumdim: + sumdim_bools[dim] = True + + for i in range(len(output_shape) - 1, -1, -1): + if sumdim_bools[i]: + output_shape = upstream_shape_functions._reduce_along_dim(output_shape, i, False) + + return output_shape + @check_shape_function([ Invocation(TensorOfShape(3, 2, 8, 4), TensorOfShape(3, 2, 8, 4), TensorOfShape(3, 2, 8, 4)), # Same shape Invocation(TensorOfShape(3, 2, 16, 8), TensorOfShape(3, 2, 8, 8), TensorOfShape(3, 2, 8, 4)), # Different shape @@ -5332,6 +5370,21 @@ def aten〇linear〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: promoted_dtype = promote_dtypes(ranks, dtypes) return promoted_dtype +@check_dtype_function( + _check_tensors_with_the_same_dtype(3, None, None, None, expand1 = [], expand2 = [], expand3 = [], sumdim = [], unroll_dim = 0), +) +def aten〇_trilinear〡dtype(i1_rank_dtype: Tuple[int, int], i2_rank_dtype: Tuple[int, int], i3_rank_dtype: Tuple[int, int], expand1: List[int], expand2: List[int], expand3: List[int], sumdim: List[int], unroll_dim: int = 1) -> int: + i1_rank, i1_dtype = i1_rank_dtype + i2_rank, i2_dtype = i2_rank_dtype + i3_rank, i3_dtype = i3_rank_dtype + + ranks: List[Optional[int]] = [i1_rank, i2_rank, i3_rank] + dtypes = [i1_dtype, i2_dtype, i3_dtype] + return promote_dtypes( + ranks, + dtypes, + ) + @check_dtype_function( [Invocation([NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32)]), Invocation([NonZeroDTensorWithDtype(torch.float16), NonZeroDTensorWithDtype(torch.float64)]), diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index e5dcc913527f..f8e67bfa5e06 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -1013,6 +1013,9 @@ def emit_with_mutating_variants(key, **kwargs): "aten::scaled_dot_product_attention : 
(Tensor, Tensor, Tensor, Tensor?, float, bool, float?, bool) -> (Tensor)" ) emit("aten::grid_sampler : (Tensor, Tensor, int, int, bool) -> (Tensor)") + emit( + "aten::_trilinear : (Tensor, Tensor, Tensor, int[], int[], int[], int[], int) -> (Tensor)" + ) # Dict ops. emit("aten::__contains__.str : (Dict(str, t), str) -> (bool)", has_folder=True) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py index 9e2d2693b62b..a8820f59c373 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py @@ -1674,6 +1674,9 @@ def Rot90NegativeEvenRotationsModule_basic(module, tu: TestUtils): module.forward(tu.rand(6, 5, 1, 7, 3)) +# ============================================================================== + + class Unfold_Module(torch.nn.Module): def __init__(self): super().__init__() @@ -1772,3 +1775,173 @@ def forward(self, x): @register_test_case(module_factory=lambda: Unfold_Module_Dynamic()) def Unfold_Module_Dynamic_basic(module, tu: TestUtils): module.forward(tu.rand(6, 4, 4, 4)) + + +# ============================================================================== + + +class Aten_TrilinearModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([3, 3, 3], torch.float32, True), + ([3, 3, 3], torch.float32, True), + ([3, 3, 3], torch.float32, True), + ] + ) + def forward(self, i1, i2, i3): + return torch.ops.aten._trilinear( + i1, i2, i3, expand1=[], expand2=[], expand3=[], sumdim=[], unroll_dim=0 + ) + + +@register_test_case(module_factory=lambda: Aten_TrilinearModule()) +def Aten_TrilinearModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 3, 3), tu.rand(3, 3, 3), tu.rand(3, 3, 3)) + + +class Aten_TrilinearModuleSumdims(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 6], torch.float32, True), + ([2, 3, 6], torch.float32, True), + ([2, 3, 6], torch.float32, True), + ] + ) + def forward(self, i1, i2, i3): + return torch.ops.aten._trilinear( + i1, i2, i3, expand1=[1], expand2=[], expand3=[], sumdim=[0, 2], unroll_dim=0 + ) + + +@register_test_case(module_factory=lambda: Aten_TrilinearModuleSumdims()) +def Aten_TrilinearModuleSumdims_basic(module, tu: TestUtils): + return module.forward(tu.rand(2, 6), tu.rand(2, 3, 6), tu.rand(2, 3, 6)) + + +class Aten_TrilinearModuleSumAllDims(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 6], torch.float32, True), + ([2, 3, 6], torch.float32, True), + ([2, 3, 6], torch.float32, True), + ] + ) + def forward(self, i1, i2, i3): + return torch.ops.aten._trilinear( + i1, + i2, + i3, + expand1=[1], + expand2=[], + expand3=[], + sumdim=[0, 1, 2], + unroll_dim=0, + ) + + +@register_test_case(module_factory=lambda: Aten_TrilinearModuleSumAllDims()) +def Aten_TrilinearModuleSumAllDims_basic(module, tu: TestUtils): + return module.forward(tu.rand(2, 6), tu.rand(2, 3, 6), tu.rand(2, 3, 6)) + + +class Aten_TrilinearModuleVaryingRanks(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 6], torch.float32, True), + ([2, 3, 6], torch.float32, True), + ([6], torch.float32, True), + ] + ) + def forward(self, i1, i2, i3): + return torch.ops.aten._trilinear( + i1, + i2, + i3, + expand1=[1], + expand2=[], + expand3=[0, 1], + sumdim=[0], + unroll_dim=0, 
+ ) + + +@register_test_case(module_factory=lambda: Aten_TrilinearModuleVaryingRanks()) +def Aten_TrilinearModuleVaryingRanks_basic(module, tu: TestUtils): + return module.forward(tu.rand(2, 6), tu.rand(2, 3, 6), tu.rand(6)) + + +class Aten_TrilinearModuleVaryingRanksUnorderedExpands(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 6], torch.float32, True), + ([2, 3, 6], torch.float32, True), + ([6], torch.float32, True), + ] + ) + def forward(self, i1, i2, i3): + return torch.ops.aten._trilinear( + i1, + i2, + i3, + expand1=[1], + expand2=[], + expand3=[1, 0], + sumdim=[2, 0], + unroll_dim=0, + ) + + +@register_test_case( + module_factory=lambda: Aten_TrilinearModuleVaryingRanksUnorderedExpands() +) +def Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic(module, tu: TestUtils): + return module.forward(tu.rand(2, 6), tu.rand(2, 3, 6), tu.rand(6)) + + +class Aten_TrilinearModuleZerodDimBug(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 3, 6], torch.float32, True), + ([2, 3, 6], torch.float32, True), + ([2, 3, 6], torch.float32, True), + ] + ) + def forward(self, i1, i2, i3): + return torch.ops.aten._trilinear( + i1, i2, i3, expand1=[0], expand2=[0], expand3=[0], sumdim=[2], unroll_dim=0 + ) + + +@register_test_case(module_factory=lambda: Aten_TrilinearModuleZerodDimBug()) +def Aten_TrilinearModuleZerodDimBug_basic(module, tu: TestUtils): + return module.forward(tu.rand(2, 3, 6), tu.rand(2, 3, 6), tu.rand(2, 3, 6))
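For reference, a minimal eager-mode sketch of the decomposition that DecomposeAten_TrilinearOp implements: unsqueeze each input at its expand dims (in ascending order), multiply, then sum over sumdim. The shapes and arguments mirror the Aten_TrilinearModuleSumdims test above; this is illustrative only and not part of the patch.

import torch

i1, i2, i3 = torch.rand(2, 6), torch.rand(2, 3, 6), torch.rand(2, 3, 6)
expand1, expand2, expand3, sumdim = [1], [], [], [0, 2]

expected = torch.ops.aten._trilinear(
    i1, i2, i3, expand1, expand2, expand3, sumdim, unroll_dim=0
)

# Unsqueeze each input at its expand dims, multiply, then sum over sumdim.
a, b, c = i1, i2, i3
for d in sorted(expand1):
    a = a.unsqueeze(d)
for d in sorted(expand2):
    b = b.unsqueeze(d)
for d in sorted(expand3):
    c = c.unsqueeze(d)
decomposed = (a * b * c).sum(dim=sumdim)

assert torch.allclose(expected, decomposed)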