Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support aten._trilinear and improve einsum decomposition #3784

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -14091,6 +14091,36 @@ def Torch_AtenGridSamplerOp : Torch_Op<"aten.grid_sampler", [
}];
}

def Torch_Aten_TrilinearOp : Torch_Op<"aten._trilinear", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::_trilinear : (Tensor, Tensor, Tensor, int[], int[], int[], int[], int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$i1,
AnyTorchTensorType:$i2,
AnyTorchTensorType:$i3,
AnyTorchListOfTorchIntType:$expand1,
AnyTorchListOfTorchIntType:$expand2,
AnyTorchListOfTorchIntType:$expand3,
AnyTorchListOfTorchIntType:$sumdim,
Torch_IntType:$unroll_dim
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult Aten_TrilinearOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 8, 1);
}
void Aten_TrilinearOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 8, 1);
}
}];
}

def Torch_Aten__Contains__StrOp : Torch_Op<"aten.__contains__.str", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
115 changes: 115 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8852,6 +8852,112 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.linear(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten._trilinear\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.list<int>, %arg7: !torch.int) -> !torch.list<int> {\n"
" %int3 = torch.constant.int 3\n"
" %int-1 = torch.constant.int -1\n"
" %str = torch.constant.str \"AssertionError: number of dimensions must match\"\n"
" %str_0 = torch.constant.str \"expand dimension {} is out of bounds for input of shape {}\"\n"
" %true = torch.constant.bool true\n"
" %none = torch.constant.none\n"
" %str_1 = torch.constant.str \"AssertionError: \"\n"
" %str_2 = torch.constant.str \"unroll_dim must be in [0, {}]\"\n"
" %false = torch.constant.bool false\n"
" %int0 = torch.constant.int 0\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %1 = torch.aten.len.t %arg3 : !torch.list<int> -> !torch.int\n"
" %2 = torch.aten.add.int %0, %1 : !torch.int, !torch.int -> !torch.int\n"
" %3 = torch.aten.ge.int %arg7, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %4 = torch.prim.If %3 -> (!torch.bool) {\n"
" %23 = torch.aten.lt.int %arg7, %2 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %23 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If %4 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" %23 = torch.aten.sub.int %2, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %24 = torch.aten.format(%str_2, %23) : !torch.str, !torch.int -> !torch.str\n"
" %25 = torch.aten.add.str %str_1, %24 : !torch.str, !torch.str -> !torch.str\n"
" torch.prim.RaiseException %25, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %5 = call @__torch__.torch.jit._shape_functions._copy(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" %6 = call @__torch__.torch.jit._shape_functions._copy(%arg1) : (!torch.list<int>) -> !torch.list<int>\n"
" %7 = call @__torch__.torch.jit._shape_functions._copy(%arg2) : (!torch.list<int>) -> !torch.list<int>\n"
" %8 = torch.prim.ListConstruct %5, %6, %7 : (!torch.list<int>, !torch.list<int>, !torch.list<int>) -> !torch.list<list<int>>\n"
" %9 = torch.prim.ListConstruct %arg3, %arg4, %arg5 : (!torch.list<int>, !torch.list<int>, !torch.list<int>) -> !torch.list<list<int>>\n"
" torch.prim.Loop %int3, %true, init() {\n"
" ^bb0(%arg8: !torch.int):\n"
" %23 = torch.aten.__getitem__.t %9, %arg8 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
" %24 = torch.aten.__getitem__.t %8, %arg8 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
" %25 = torch.aten.len.t %24 : !torch.list<int> -> !torch.int\n"
" %26 = torch.aten.len.t %23 : !torch.list<int> -> !torch.int\n"
" torch.prim.Loop %26, %true, init() {\n"
" ^bb0(%arg9: !torch.int):\n"
" %27 = torch.aten.__getitem__.t %23, %arg9 : !torch.list<int>, !torch.int -> !torch.int\n"
" %28 = torch.aten.le.int %27, %25 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %28 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" %30 = torch.aten.__getitem__.t %8, %arg8 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
" %31 = torch.aten.format(%str_0, %27, %30) : !torch.str, !torch.int, !torch.list<int> -> !torch.str\n"
" %32 = torch.aten.add.str %str_1, %31 : !torch.str, !torch.str -> !torch.str\n"
" torch.prim.RaiseException %32, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %29 = torch.aten.__getitem__.t %8, %arg8 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
" torch.aten.insert.t %29, %27, %int1 : !torch.list<int>, !torch.int, !torch.int\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" %10 = torch.aten.len.t %5 : !torch.list<int> -> !torch.int\n"
" %11 = torch.aten.len.t %6 : !torch.list<int> -> !torch.int\n"
" %12 = torch.aten.eq.int %10, %11 : !torch.int, !torch.int -> !torch.bool\n"
" %13 = torch.prim.If %12 -> (!torch.bool) {\n"
" %23 = torch.aten.len.t %6 : !torch.list<int> -> !torch.int\n"
" %24 = torch.aten.len.t %7 : !torch.list<int> -> !torch.int\n"
" %25 = torch.aten.eq.int %23, %24 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %25 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If %13 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %14 = call @__torch__.torch.jit._shape_functions.broadcast_three(%5, %6, %7) : (!torch.list<int>, !torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" %15 = torch.prim.ListConstruct %false : (!torch.bool) -> !torch.list<bool>\n"
" %16 = torch.aten.len.t %14 : !torch.list<int> -> !torch.int\n"
" %17 = torch.operator \"aten.mul.left_t\"(%15, %16) : (!torch.list<bool>, !torch.int) -> !torch.list<bool> \n"
" %18 = torch.aten.len.t %arg6 : !torch.list<int> -> !torch.int\n"
" torch.prim.Loop %18, %true, init() {\n"
" ^bb0(%arg8: !torch.int):\n"
" %23 = torch.aten.__getitem__.t %arg6, %arg8 : !torch.list<int>, !torch.int -> !torch.int\n"
" %24 = torch.aten._set_item.t %17, %23, %true : !torch.list<bool>, !torch.int, !torch.bool -> !torch.list<bool>\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" %19 = torch.aten.len.t %14 : !torch.list<int> -> !torch.int\n"
" %20 = torch.aten.sub.int %19, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %21 = torch.aten.__range_length %20, %int-1, %int-1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
" %22 = torch.prim.Loop %21, %true, init(%14) {\n"
" ^bb0(%arg8: !torch.int, %arg9: !torch.list<int>):\n"
" %23 = torch.aten.__derive_index %arg8, %20, %int-1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
" %24 = torch.aten.__getitem__.t %17, %23 : !torch.list<bool>, !torch.int -> !torch.bool\n"
" %25 = torch.prim.If %24 -> (!torch.list<int>) {\n"
" %26 = func.call @__torch__.torch.jit._shape_functions._reduce_along_dim(%arg9, %23, %false) : (!torch.list<int>, !torch.int, !torch.bool) -> !torch.list<int>\n"
" torch.prim.If.yield %26 : !torch.list<int>\n"
" } else {\n"
" torch.prim.If.yield %arg9 : !torch.list<int>\n"
" }\n"
" torch.prim.Loop.condition %true, iter(%25 : !torch.list<int>)\n"
" } : (!torch.int, !torch.bool, !torch.list<int>) -> !torch.list<int>\n"
" return %22 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.scaled_dot_product_attention\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<list<int>>, %arg4: !torch.float, %arg5: !torch.bool, %arg6: !torch.optional<float>, %arg7: !torch.bool) -> !torch.list<int> {\n"
" %int-1 = torch.constant.int -1\n"
" %0 = torch.aten.__getitem__.t %arg2, %int-1 : !torch.list<int>, !torch.int -> !torch.int\n"
Expand Down Expand Up @@ -15206,6 +15312,15 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten._trilinear\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.list<int>, %arg7: !torch.int) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %3 = torch.prim.ListConstruct %0#0, %1#0, %2#0 : (!torch.int, !torch.int, !torch.int) -> !torch.list<optional<int>>\n"
" %4 = torch.prim.ListConstruct %0#1, %1#1, %2#1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%3, %4) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %5 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.cat\"(%arg0: !torch.list<tuple<int, int>>, %arg1: !torch.int) -> !torch.int {\n"
" %true = torch.constant.bool true\n"
" %none = torch.constant.none\n"
Expand Down
161 changes: 152 additions & 9 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include "PassDetail.h"

#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
Expand Down Expand Up @@ -399,9 +400,9 @@ static Value collapseDimForMatmul(PatternRewriter &rewriter, Location loc,
auto inputType = cast<ValueTensorType>(input.getType());
auto inputRank = batchDimsLength + contractingDimsLength + otherDimsLength +
reduceDimsLength;
SmallVector<Value> inputShapeTensor;
SmallVector<OpFoldResult> inputShapeTensor;
for (auto i = 0; i < inputRank; ++i) {
inputShapeTensor.emplace_back(rewriter.create<AtenSizeIntOp>(
inputShapeTensor.emplace_back(rewriter.createOrFold<AtenSizeIntOp>(
loc, input,
rewriter.create<Torch::ConstantIntOp>(loc,
rewriter.getI64IntegerAttr(i))));
Expand All @@ -412,13 +413,23 @@ static Value collapseDimForMatmul(PatternRewriter &rewriter, Location loc,
rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
auto dimOffset = 0;

auto materializeIntFold = [&](OpFoldResult thing) {
if (auto attr = dyn_cast<mlir::Attribute>(thing)) {
Value result = rewriter.create<Torch::ConstantIntOp>(
loc, cast<mlir::IntegerAttr>(attr));
return result;
}
return cast<mlir::Value>(thing);
};

auto appendDims = [&](int64_t dimLength) {
Value prod = constOne;
OpFoldResult prod = getAsOpFoldResult(constOne);
for (auto i = 0; i < dimLength; ++i) {
prod = rewriter.create<AtenMulIntOp>(loc, prod,
inputShapeTensor[i + dimOffset]);
prod = rewriter.createOrFold<AtenMulIntOp>(
loc, materializeIntFold(prod),
materializeIntFold(inputShapeTensor[i + dimOffset]));
}
outShapeTensor.emplace_back(prod);
outShapeTensor.emplace_back(materializeIntFold(prod));
dimOffset += dimLength;
};

Expand Down Expand Up @@ -570,21 +581,32 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc,
Type outputDType = lhsType.hasDtype() ? lhsType.getOptionalDtype()
: rhsType.getOptionalDtype();

auto materializeIntFold = [&](OpFoldResult thing) {
if (auto attr = dyn_cast<mlir::Attribute>(thing)) {
Value result = rewriter.create<Torch::ConstantIntOp>(
loc, cast<mlir::IntegerAttr>(attr));
return result;
}
return cast<mlir::Value>(thing);
};

llvm::SmallDenseMap<char, Value> lhsDimShapeMap;
for (size_t idx = 0; idx < lhsTokens.size(); ++idx) {
char d = lhsTokens[idx];
lhsDimShapeMap[d] = rewriter.create<AtenSizeIntOp>(
OpFoldResult lhsFold = rewriter.createOrFold<AtenSizeIntOp>(
loc, lhs,
rewriter.create<Torch::ConstantIntOp>(loc,
rewriter.getI64IntegerAttr(idx)));
lhsDimShapeMap[d] = materializeIntFold(lhsFold);
}
llvm::SmallDenseMap<char, Value> rhsDimShapeMap;
for (size_t idx = 0; idx < rhsTokens.size(); ++idx) {
char d = rhsTokens[idx];
rhsDimShapeMap[d] = rewriter.create<AtenSizeIntOp>(
OpFoldResult rhsFold = rewriter.createOrFold<AtenSizeIntOp>(
loc, rhs,
rewriter.create<Torch::ConstantIntOp>(loc,
rewriter.getI64IntegerAttr(idx)));
rhsDimShapeMap[d] = materializeIntFold(rhsFold);
}

// parse batch, contracting, other, reduce dims of lhs and rhs
Expand All @@ -604,8 +626,9 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc,
bool lhsContains = lhsDimShapeMap.count(d) > 0;
bool rhsContains = rhsDimShapeMap.count(d) > 0;
if (lhsContains && rhsContains) {
outDimShapeMap[d] = rewriter.create<Torch::PrimMaxIntOp>(
OpFoldResult out = rewriter.createOrFold<Torch::PrimMaxIntOp>(
loc, lhsDimShapeMap[d], rhsDimShapeMap[d]);
outDimShapeMap[d] = materializeIntFold(out);
} else if (lhsContains) {
outDimShapeMap[d] = lhsDimShapeMap[d];
} else if (rhsContains) {
Expand Down Expand Up @@ -1973,6 +1996,125 @@ class DecomposeAtenEinsumOp : public OpRewritePattern<AtenEinsumOp> {
};
} // namespace

namespace {
// Trilinear einstein sum, decomposed to:
// (i1.unsqueeze(expand1) * i2.unsqueeze(expand2) * i3.unsqueeze(expand3))
// .sum(sumdim)
// The unrollDim operand does not impact the output of the operation, so
// it is ignored.

class DecomposeAten_TrilinearOp : public OpRewritePattern<Aten_TrilinearOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(Aten_TrilinearOp op,
PatternRewriter &rewriter) const override {

Location loc = op.getLoc();

Value input1 = op.getI1();
Value input2 = op.getI2();
Value input3 = op.getI3();

// Expansions
SmallVector<int64_t> expand1;
SmallVector<int64_t> expand2;
SmallVector<int64_t> expand3;
if (!matchPattern(op.getExpand1(), m_TorchListOfConstantInts(expand1))) {
return rewriter.notifyMatchFailure(op, "expand1 should be constant");
}
if (!matchPattern(op.getExpand2(), m_TorchListOfConstantInts(expand2))) {
return rewriter.notifyMatchFailure(op, "expand2 should be constant");
}
if (!matchPattern(op.getExpand3(), m_TorchListOfConstantInts(expand3))) {
return rewriter.notifyMatchFailure(op, "expand3 should be constant");
}
stbaione marked this conversation as resolved.
Show resolved Hide resolved

SmallVector<int64_t> sumDim;
if (!matchPattern(op.getSumdim(), m_TorchListOfConstantInts(sumDim))) {
return rewriter.notifyMatchFailure(op, "sumDim should be constant");
}

// Check if there are any dimensions that intersect between expand1,
// expand2, and expand3.
int64_t totalDims =
cast<BaseTensorType>(input1.getType()).getSizes().size() +
expand1.size();
if (sharedExpandDims(totalDims, expand1, expand2, expand3, sumDim)) {
// pytorch issue filed: https://github.com/pytorch/pytorch/issues/138353
// TODO: Remove warning when issue gets resolved.
op->emitWarning("aten::_trilinear implementation in this case is "
"non-functional (returns an empty dimension). We will "
"intentionally deviate from this behavior.");
}

// Apply unsqueeze to respective input tensors at the specified dimensions
SmallVector<int64_t> sortedExpand1 = expand1;
std::sort(sortedExpand1.begin(), sortedExpand1.end());
for (auto expand : sortedExpand1) {
Value expandDim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(expand));
input1 = *unsqueezeTensor(rewriter, op, input1, expandDim);
}
stbaione marked this conversation as resolved.
Show resolved Hide resolved
SmallVector<int64_t> sortedExpand2 = expand2;
std::sort(sortedExpand2.begin(), sortedExpand2.end());
for (auto expand : sortedExpand2) {
Value expandDim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(expand));
input2 = *unsqueezeTensor(rewriter, op, input2, expandDim);
}
SmallVector<int64_t> sortedExpand3 = expand3;
std::sort(sortedExpand3.begin(), sortedExpand3.end());
for (auto expand : sortedExpand3) {
Value expandDim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(expand));
input3 = *unsqueezeTensor(rewriter, op, input3, expandDim);
}

// Apply multiplication operation.
auto mul1 =
rewriter.create<AtenMulTensorOp>(loc, op.getType(), input1, input2);
auto mul2 =
rewriter.create<AtenMulTensorOp>(loc, op.getType(), mul1, input3);

// Apply sum operation.
// Parse sumDim in descending order to avoid any issues with the
// dimensions being removed.
Value result = mul2;
SmallVector<int64_t> sortedSumDims = sumDim;
std::sort(sortedSumDims.rbegin(), sortedSumDims.rend());
for (int64_t dim : sortedSumDims) {
Value dimValue = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(dim));
result =
createSumAlongDimension(rewriter, loc, op, result, dimValue, false);
}
stbaione marked this conversation as resolved.
Show resolved Hide resolved

rewriter.replaceOp(op, result);
return success();
}

private:
// Determine if there are any dimensions that intersect between expand1,
// expand2, and expand3.
bool sharedExpandDims(const int64_t &totalDims,
const SmallVector<int64_t> &expand1,
const SmallVector<int64_t> &expand2,
const SmallVector<int64_t> &expand3,
const SmallVector<int64_t> &sumDim) const {
for (int64_t i = 0; i < totalDims; ++i) {
if (!contains(sumDim, i) && contains(expand1, i) &&
contains(expand2, i) && contains(expand3, i)) {
return true;
}
}
return false;
}
bool contains(const SmallVector<int64_t> &vec, int64_t value) const {
return std::find(vec.begin(), vec.end(), value) != vec.end();
}
};
} // namespace

namespace {
// Calculate the trace of the input tensor as the sum over its diagonal
// elements. This computation is performed as:
Expand Down Expand Up @@ -9928,6 +10070,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast1dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast2dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenEinsumOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_TrilinearOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTraceOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenHardswishOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSoftplusOp>(patterns);
Expand Down
Loading
Loading