From 8314b5ccb30abac44f59672849bb385e212fa84d Mon Sep 17 00:00:00 2001 From: Jerry Wu Date: Thu, 24 Aug 2023 18:18:56 -0400 Subject: [PATCH] Support default configs for BATCH_MATMUL_* in MaterializeEncoding pass (#14762) --- .../compiler/Codegen/Common/CPU/BUILD.bazel | 1 + .../Codegen/Common/CPU/CMakeLists.txt | 1 + .../Common/CPU/CPUMaterializeEncodingPass.cpp | 4 +- .../compiler/Codegen/Common/EncodingInfo.h | 11 --- .../MaterializeEncodingIntoPackUnPack.cpp | 46 ----------- .../iree-dialects/BUILD.bazel | 29 +++++-- .../Dialect/LinalgExt/Utils/EncodingUtils.h | 33 ++++++++ .../Dialect/LinalgExt/Passes/CMakeLists.txt | 1 + .../LinalgExt/Passes/MaterializeEncoding.cpp | 29 ++++--- .../Dialect/LinalgExt/Utils/CMakeLists.txt | 8 ++ .../Dialect/LinalgExt/Utils/EncodingUtils.cpp | 63 +++++++++++++++ .../iree_linalg_ext/materialize_encoding.mlir | 77 +++++++++++++++++++ 12 files changed, 227 insertions(+), 76 deletions(-) create mode 100644 llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h create mode 100644 llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/EncodingUtils.cpp diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel index 88aea51c44d1..3ee0efde5106 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel @@ -60,6 +60,7 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Codegen/Utils", "//compiler/src/iree/compiler/Dialect/HAL/IR", "//llvm-external-projects/iree-dialects:IREELinalgExtDialect", + "//llvm-external-projects/iree-dialects:IREELinalgExtEncodingUtils", "//llvm-external-projects/iree-dialects:IREELinalgExtTransforms", "//llvm-external-projects/iree-dialects:IREELinalgExtUtils", "@llvm-project//llvm:Support", diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt index d1ed3f3e7c51..2721fbd6234b 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt @@ -48,6 +48,7 @@ iree_cc_library( ::PassHeaders ::PassesIncGen IREELinalgExtDialect + IREELinalgExtEncodingUtils IREELinalgExtTransforms IREELinalgExtUtils LLVMSupport diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp index c9558d5a1f69..782fff7d0db0 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp @@ -7,6 +7,7 @@ #include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h" #include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h" #include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h" +#include "iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h" #include "iree-dialects/Dialect/LinalgExt/Utils/Utils.h" #include "iree/compiler/Codegen/Common/CPU/PassDetail.h" #include "iree/compiler/Codegen/Common/CPU/Passes.h" @@ -171,7 +172,8 @@ materializeEncodingForTarget(RankedTensorType tensorType, auto user = encoding.getUser().getValue(); auto role = encoding.getRole().getValue(); MatmulTileParams tileParams = chooseMatmulTileParams(user, targetAttr); - auto encodingInfo = chooseEncodingInfoForMatmul(user, role, tileParams); + auto encodingInfo = + IREE::LinalgExt::chooseEncodingInfoForMatmul(user, role, tileParams); auto originalTypeAttr = encoding.getOriginalType(); auto originalType = originalTypeAttr ? originalTypeAttr.getValue().cast() diff --git a/compiler/src/iree/compiler/Codegen/Common/EncodingInfo.h b/compiler/src/iree/compiler/Codegen/Common/EncodingInfo.h index 31159388799a..3774a95833f2 100644 --- a/compiler/src/iree/compiler/Codegen/Common/EncodingInfo.h +++ b/compiler/src/iree/compiler/Codegen/Common/EncodingInfo.h @@ -14,21 +14,10 @@ namespace mlir { namespace iree_compiler { -struct MatmulTileParams { - int64_t M = 1; - int64_t K = 1; - int64_t N = 1; -}; - void adjustTileSizesToNarrowStaticShape( IREE::LinalgExt::MaterializeEncodingInfo &encodingInfo, ArrayRef shape); -IREE::LinalgExt::MaterializeEncodingInfo -chooseEncodingInfoForMatmul(IREE::LinalgExt::EncodingUser user, - IREE::LinalgExt::EncodingRole role, - MatmulTileParams tileParams); - IREE::LinalgExt::MaterializeEncodingValueFn getMaterializeEncodingValueFn(IREE::HAL::ExecutableTargetAttr targetAttr); diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp index b97b4d979c6d..b5d2a1b873d7 100644 --- a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp @@ -269,52 +269,6 @@ struct MaterializeFlowDispatchTensorStoreOp } // namespace -IREE::LinalgExt::MaterializeEncodingInfo -chooseEncodingInfoForMatmul(EncodingUser user, EncodingRole role, - MatmulTileParams tileParams) { - // Start dim of the MxK (LHS), KxN (RHS), or MxN (RESULT) 2D matrix. - int64_t matmulDimBase = 0; - switch (user) { - case EncodingUser::BATCH_MATMUL_F32F32F32: - case EncodingUser::BATCH_MATMUL_F16F16F32: - case EncodingUser::BATCH_MATMUL_F16F16F16: - case EncodingUser::BATCH_MATMUL_BF16BF16F32: - case EncodingUser::BATCH_MATMUL_BF16BF16BF16: - case EncodingUser::BATCH_MATMUL_I8I8I32: - matmulDimBase = 1; - break; - default: - break; - } - - MaterializeEncodingInfo encodingInfo; - encodingInfo.innerDimsPos = {matmulDimBase, matmulDimBase + 1}; - switch (role) { - case (EncodingRole::LHS): { - encodingInfo.innerTileSizes = {tileParams.M, tileParams.K}; - break; - } - case (EncodingRole::RHS): { - encodingInfo.innerTileSizes = {tileParams.N, tileParams.K}; - encodingInfo.innerDimsPos = {matmulDimBase + 1, matmulDimBase}; - encodingInfo.outerDimsPerm = - llvm::to_vector(llvm::seq(0, matmulDimBase)); - encodingInfo.outerDimsPerm.push_back(matmulDimBase + 1); - encodingInfo.outerDimsPerm.push_back(matmulDimBase); - break; - } - case (EncodingRole::RESULT): { - encodingInfo.innerTileSizes = {tileParams.M, tileParams.N}; - break; - } - default: { - assert(false); - return {}; - } - } - return encodingInfo; -} - void adjustTileSizesToNarrowStaticShape(MaterializeEncodingInfo &encodingInfo, ArrayRef shape) { for (size_t i = 0; i < encodingInfo.innerDimsPos.size(); i++) { diff --git a/llvm-external-projects/iree-dialects/BUILD.bazel b/llvm-external-projects/iree-dialects/BUILD.bazel index 1f8a3fd63d3b..a2945ed611a9 100644 --- a/llvm-external-projects/iree-dialects/BUILD.bazel +++ b/llvm-external-projects/iree-dialects/BUILD.bazel @@ -312,12 +312,13 @@ gentbl_cc_library( cc_library( name = "IREELinalgExtUtils", - srcs = glob([ - "lib/Dialect/LinalgExt/Utils/*.cpp", - ]), - hdrs = glob([ - "include/iree-dialects/Dialect/LinalgExt/Utils/*.h", - ]), + srcs = [ + "lib/Dialect/LinalgExt/Utils/Utils.cpp", + ], + hdrs = [ + "include/iree-dialects/Dialect/LinalgExt/Utils/Utils.h", + "include/iree-dialects/Dialect/LinalgExt/Utils/WinogradConstants.h", + ], includes = ["include"], deps = [ "@llvm-project//llvm:Support", @@ -331,6 +332,21 @@ cc_library( ], ) +cc_library( + name = "IREELinalgExtEncodingUtils", + srcs = [ + "lib/Dialect/LinalgExt/Utils/EncodingUtils.cpp", + ], + hdrs = [ + "include/iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h", + ], + includes = ["include"], + deps = [ + ":IREELinalgExtDialect", + ":IREELinalgExtUtils", + ], +) + cc_library( name = "IREELinalgExtDialect", srcs = glob([ @@ -437,6 +453,7 @@ cc_library( deps = [ ":IREEInputDialect", ":IREELinalgExtDialect", + ":IREELinalgExtEncodingUtils", ":IREELinalgExtPassIncGen", ":IREELinalgExtUtils", "@llvm-project//llvm:Support", diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h new file mode 100644 index 000000000000..700dce50e025 --- /dev/null +++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h @@ -0,0 +1,33 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_DIALECTS_DIALECT_LINALGEXT_UTILS_ENCODING_UTILS_H_ +#define IREE_DIALECTS_DIALECT_LINALGEXT_UTILS_ENCODING_UTILS_H_ + +#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h" +#include "iree-dialects/Dialect/LinalgExt/Utils/Utils.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace LinalgExt { + +struct MatmulTileParams { + int64_t M = 1; + int64_t K = 1; + int64_t N = 1; +}; + +MaterializeEncodingInfo +chooseEncodingInfoForMatmul(EncodingUser user, EncodingRole role, + MatmulTileParams tileParams); + +} // namespace LinalgExt +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_DIALECTS_DIALECT_LINALGEXT_UTILS_ENCODING_UTILS_H_ diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/CMakeLists.txt index 27a2e58e2434..5c921daf1a47 100644 --- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/CMakeLists.txt +++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/CMakeLists.txt @@ -16,6 +16,7 @@ add_mlir_library(IREELinalgExtPasses LINK_LIBS PUBLIC IREEInputDialect IREELinalgExtDialect + IREELinalgExtEncodingUtils IREELinalgExtUtils MLIRAffineDialect MLIRIR diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp index e510ebf81ba4..769541bc300d 100644 --- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp +++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp @@ -7,6 +7,7 @@ #include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h" #include "iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h" #include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h" +#include "iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h" #include "iree-dialects/Dialect/LinalgExt/Utils/Utils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -89,19 +90,23 @@ chooseEncodingInfo(RankedTensorType tensorType) { auto encoding = getEncodingAttr(tensorType); if (!encoding) return failure(); + + auto user = encoding.getUser().getValue(); auto role = encoding.getRole().getValue(); - switch (role) { - case EncodingRole::LHS: - return MaterializeEncodingInfo{{0, 1}, {8, 4}, {}}; - break; - case EncodingRole::RHS: - return MaterializeEncodingInfo{{1, 0}, {8, 4}, {1, 0}}; - break; - case EncodingRole::RESULT: - return MaterializeEncodingInfo{{0, 1}, {8, 8}, {}}; - break; - default: - return failure(); + switch (user) { + case EncodingUser::MATMUL_F32F32F32: + case EncodingUser::MATMUL_F16F16F32: + case EncodingUser::MATMUL_F16F16F16: + case EncodingUser::MATMUL_BF16BF16F32: + case EncodingUser::MATMUL_BF16BF16BF16: + case EncodingUser::MATMUL_I8I8I32: + case EncodingUser::BATCH_MATMUL_F32F32F32: + case EncodingUser::BATCH_MATMUL_F16F16F32: + case EncodingUser::BATCH_MATMUL_F16F16F16: + case EncodingUser::BATCH_MATMUL_BF16BF16F32: + case EncodingUser::BATCH_MATMUL_BF16BF16BF16: + case EncodingUser::BATCH_MATMUL_I8I8I32: + return chooseEncodingInfoForMatmul(user, role, /*tileParams=*/{8, 4, 8}); } } diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/CMakeLists.txt index 85bdc21e5efc..fcff05643681 100644 --- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/CMakeLists.txt +++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/CMakeLists.txt @@ -11,3 +11,11 @@ add_mlir_library(IREELinalgExtUtils MLIRTensorDialect MLIRMemRefDialect ) + +add_mlir_library(IREELinalgExtEncodingUtils + EncodingUtils.cpp + + LINK_LIBS PUBLIC + IREELinalgExtDialect + IREELinalgExtUtils +) diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/EncodingUtils.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/EncodingUtils.cpp new file mode 100644 index 000000000000..cae238d65fdc --- /dev/null +++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/EncodingUtils.cpp @@ -0,0 +1,63 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { +namespace LinalgExt { + +MaterializeEncodingInfo +chooseEncodingInfoForMatmul(EncodingUser user, EncodingRole role, + MatmulTileParams tileParams) { + // Start dim of the MxK (LHS), KxN (RHS), or MxN (RESULT) 2D matrix. + int64_t matmulDimBase = 0; + switch (user) { + case EncodingUser::BATCH_MATMUL_F32F32F32: + case EncodingUser::BATCH_MATMUL_F16F16F32: + case EncodingUser::BATCH_MATMUL_F16F16F16: + case EncodingUser::BATCH_MATMUL_BF16BF16F32: + case EncodingUser::BATCH_MATMUL_BF16BF16BF16: + case EncodingUser::BATCH_MATMUL_I8I8I32: + matmulDimBase = 1; + break; + default: + break; + } + + MaterializeEncodingInfo encodingInfo; + encodingInfo.innerDimsPos = {matmulDimBase, matmulDimBase + 1}; + switch (role) { + case (EncodingRole::LHS): { + encodingInfo.innerTileSizes = {tileParams.M, tileParams.K}; + break; + } + case (EncodingRole::RHS): { + encodingInfo.innerTileSizes = {tileParams.N, tileParams.K}; + encodingInfo.innerDimsPos = {matmulDimBase + 1, matmulDimBase}; + encodingInfo.outerDimsPerm = + llvm::to_vector(llvm::seq(0, matmulDimBase)); + encodingInfo.outerDimsPerm.push_back(matmulDimBase + 1); + encodingInfo.outerDimsPerm.push_back(matmulDimBase); + break; + } + case (EncodingRole::RESULT): { + encodingInfo.innerTileSizes = {tileParams.M, tileParams.N}; + break; + } + default: { + assert(false); + return {}; + } + } + return encodingInfo; +} + +} // namespace LinalgExt +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir index 60cc9bedb0f1..7f757aa1e829 100644 --- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir +++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir @@ -163,3 +163,80 @@ func.func @pack_gemm_fill_dynamic(%arg0 : tensor, %arg1 : tensor) -> tensor { + %0 = iree_linalg_ext.set_encoding %arg0 : tensor -> tensor> + %1 = iree_linalg_ext.unset_encoding %0 : tensor> -> tensor + return %1 : tensor +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)> +// CHECK: func @pack_unpack_batch_matmul_lhs( +// CHECK-SAME: %[[ARG0:.+]]: tensor +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]] +// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[ARG0]], %[[C2]] +// CHECK-DAG: %[[OUTER_D1:.+]] = affine.apply #[[MAP0]]()[%[[D1]]] +// CHECK-DAG: %[[OUTER_D2:.+]] = affine.apply #[[MAP1]]()[%[[D2]]] +// CHECK: %[[PACK_DEST:.+]] = tensor.empty(%[[D0]], %[[OUTER_D1]], %[[OUTER_D2]]) : tensor +// CHECK: %[[PACK:.+]] = tensor.pack +// CHECK-SAME: %[[ARG0]] inner_dims_pos = [1, 2] inner_tiles = [8, 4] into %[[PACK_DEST]] +// CHECK: %[[UNPACK_DEST:.+]] = tensor.empty(%[[D0]], %[[D1]], %[[D2]]) : tensor +// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[PACK]] inner_dims_pos = [1, 2] inner_tiles = [8, 4] into %[[UNPACK_DEST]] +// CHECK: return %[[UNPACK]] + +// ----- + +func.func @pack_unpack_batch_matmul_rhs(%arg0 : tensor) -> tensor { + %0 = iree_linalg_ext.set_encoding %arg0 : tensor -> tensor> + %1 = iree_linalg_ext.unset_encoding %0 : tensor> -> tensor + return %1 : tensor +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)> +// CHECK: func @pack_unpack_batch_matmul_rhs( +// CHECK-SAME: %[[ARG0:.+]]: tensor +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]] +// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[ARG0]], %[[C2]] +// CHECK-DAG: %[[OUTER_D1:.+]] = affine.apply #[[MAP0]]()[%[[D2]]] +// CHECK-DAG: %[[OUTER_D2:.+]] = affine.apply #[[MAP1]]()[%[[D1]]] +// CHECK: %[[PACK_DEST:.+]] = tensor.empty(%[[D0]], %[[OUTER_D1]], %[[OUTER_D2]]) : tensor +// CHECK: %[[PACK:.+]] = tensor.pack +// CHECK-SAME: %[[ARG0]] outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [8, 4] into %[[PACK_DEST]] +// CHECK: %[[UNPACK_DEST:.+]] = tensor.empty(%[[D0]], %[[D1]], %[[D2]]) : tensor +// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[PACK]] outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [8, 4] into %[[UNPACK_DEST]] +// CHECK: return %[[UNPACK]] + +// ----- + +func.func @pack_unpack_batch_matmul_result(%arg0 : tensor) -> tensor { + %0 = iree_linalg_ext.set_encoding %arg0 : tensor -> tensor> + %1 = iree_linalg_ext.unset_encoding %0 : tensor> -> tensor + return %1 : tensor +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)> +// CHECK: func @pack_unpack_batch_matmul_result( +// CHECK-SAME: %[[ARG0:.+]]: tensor +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]] +// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[ARG0]], %[[C2]] +// CHECK-DAG: %[[OUTER_D1:.+]] = affine.apply #[[MAP0]]()[%[[D1]]] +// CHECK-DAG: %[[OUTER_D2:.+]] = affine.apply #[[MAP0]]()[%[[D2]]] +// CHECK: %[[PACK_DEST:.+]] = tensor.empty(%[[D0]], %[[OUTER_D1]], %[[OUTER_D2]]) : tensor +// CHECK: %[[PACK:.+]] = tensor.pack +// CHECK-SAME: %[[ARG0]] inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %[[PACK_DEST]] +// CHECK: %[[UNPACK_DEST:.+]] = tensor.empty(%[[D0]], %[[D1]], %[[D2]]) : tensor +// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[PACK]] inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %[[UNPACK_DEST]] +// CHECK: return %[[UNPACK]]