diff --git a/build_tools/ci/cpu_comparison/matmul_template/matmul_truncf_MxK_KxN.mlir b/build_tools/ci/cpu_comparison/matmul_template/matmul_truncf_MxK_KxN.mlir
new file mode 100644
index 000000000..b5c71d965
--- /dev/null
+++ b/build_tools/ci/cpu_comparison/matmul_template/matmul_truncf_MxK_KxN.mlir
@@ -0,0 +1,13 @@
+// input ${M}x${K}x${TYPE1}
+// input ${K}x${N}x${TYPE1}
+
+func.func @matmul_truncf(%arg0: tensor<${M}x${K}x${TYPE1}>, %arg1: tensor<${K}x${N}x${TYPE1}>) -> tensor<${M}x${N}x${TYPE1}>
+{
+  %cst = arith.constant ${ZERO} : ${TYPE2}
+  %0 = tensor.empty() : tensor<${M}x${N}x${TYPE2}>
+  %1 = linalg.fill ins(%cst : ${TYPE2}) outs(%0 : tensor<${M}x${N}x${TYPE2}>) -> tensor<${M}x${N}x${TYPE2}>
+  %2 = linalg.matmul ins(%arg0, %arg1 : tensor<${M}x${K}x${TYPE1}>, tensor<${K}x${N}x${TYPE1}>)
+                     outs(%1: tensor<${M}x${N}x${TYPE2}>) -> tensor<${M}x${N}x${TYPE2}>
+  %3 = arith.truncf %2 : tensor<${M}x${N}x${TYPE2}> to tensor<${M}x${N}x${TYPE1}>
+  return %3: tensor<${M}x${N}x${TYPE1}>
+}
diff --git a/build_tools/ci/cpu_comparison/run.py b/build_tools/ci/cpu_comparison/run.py
index dc63990dc..fdacf7cbc 100755
--- a/build_tools/ci/cpu_comparison/run.py
+++ b/build_tools/ci/cpu_comparison/run.py
@@ -758,6 +758,31 @@ def run(self, config):
             lower_to_aie_pipeline="objectFifo",
         )
 
+        # Test(s) of the form matmul(A,B) + truncf(C) where A:MxK, B:KxN and C:MxN.
+        test_name = output_dir / f"test_from_template_matmul_truncf.mlir"
+        template_name = matmul_template_dir / "matmul_truncf_MxK_KxN.mlir"
+        generate_matmul_test(test_name, template_name, 8, 8, 8, "bf16", "f32")
+        identity_mat = np.eye(8, dtype=np.float32)
+        ones = np.ones(8 * 8, dtype=np.float32).reshape([8, 8])
+        lhs = ones * 101
+        rhs = identity_mat * 3
+        input_args = generate_inputs(test_name, output_dir, 1, {1: lhs, 2: rhs})
+        aie_vs_baseline(
+            config,
+            test_name,
+            input_args,
+            ones * 302,  # expected output
+            use_ukernel=False,
+            tile_pipeline="pack-peel",
+            lower_to_aie_pipeline="objectFifo",
+            function_name=None,
+            seed=1,
+            rtol=0,
+            atol=0,
+            n_repeats=1,
+            output_type=get_output_type(test_name),
+        )
+
 
 class SmokeSet(TestSet):
     def __init__(self):
diff --git a/compiler/plugins/target/AMD-AIE/aievec/VectorToVectorConversions.cpp b/compiler/plugins/target/AMD-AIE/aievec/VectorToVectorConversions.cpp
index 6488fbc1b..fed15b694 100644
--- a/compiler/plugins/target/AMD-AIE/aievec/VectorToVectorConversions.cpp
+++ b/compiler/plugins/target/AMD-AIE/aievec/VectorToVectorConversions.cpp
@@ -400,6 +400,52 @@ struct ToMinorIdentityTransferReadPattern
   }
 };
 
+// clang-format off
+/// Pattern to linearize arith.truncf, because the aievec.srs op that it later
+/// lowers to in AIEVecToLLVM expects a 1-D source and target.
+/// Refer: https://github.com/nod-ai/iree-amd-aie/blob/main/compiler/plugins/target/AMD-AIE/aievec/AIEVecToLLVM.cpp#L73-L74
+///
+/// Example of what this pattern achieves:
+///    INPUT
+///         %0 = arith.truncf %inp : vector<2x3xf32> to vector<2x3xbf16>
+///    OUTPUT
+///         %0 = vector.shape_cast %inp : vector<2x3xf32> to vector<6xf32>
+///         %1 = arith.truncf %0 : vector<6xf32> to vector<6xbf16>
+///         %2 = vector.shape_cast %1 : vector<6xbf16> to vector<2x3xbf16>
+// clang-format on
+struct FlattenArithTruncFOpPattern : public OpRewritePattern<arith::TruncFOp> {
+  using OpRewritePattern<arith::TruncFOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arith::TruncFOp op,
+                                PatternRewriter &rewriter) const override {
+    // Get old shape type.
+    auto oldShapedType = dyn_cast<VectorType>(op.getType());
+    if (!oldShapedType) return failure();
+    // Bail out if it's already linearized.
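+    // (A rank-1 truncf such as vector<6xf32> -> vector<6xbf16> is already in
+    // the form that aievec.srs expects, so there is nothing to do.)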
+    if (oldShapedType.getRank() == 1) return failure();
+    // Linearize the shape.
+    int64_t linearizedSize = oldShapedType.getNumElements();
+    // Fetch input.
+    Value origInputOfTruncFOp = op.getIn();
+    // Form linearized vector shape type for input and output.
+    VectorType newVectorTypeForInput = VectorType::get(
+        {linearizedSize},
+        cast<VectorType>(origInputOfTruncFOp.getType()).getElementType());
+    VectorType newVectorTypeForOutput =
+        VectorType::get({linearizedSize}, oldShapedType.getElementType());
+    // Shape cast the original input to linearized shape type.
+    Value newInputVector = rewriter.create<vector::ShapeCastOp>(
+        op.getLoc(), newVectorTypeForInput, origInputOfTruncFOp);
+    // Create new base operation with the linearized input/output.
+    Value newTruncFOp = rewriter.create<arith::TruncFOp>(
+        op.getLoc(), newVectorTypeForOutput, newInputVector);
+    // Delinearize the output back to the original type.
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(),
+                                                     newTruncFOp);
+    return success();
+  }
+};
+
 // This pattern extracts an implicit transposition of the 2 innermost
 // dimensions of `rhs` in a gemm-like contraction op, making it an explicit
 // `vector.transpose` op.
@@ -591,16 +637,16 @@ struct CanonicalizeVectorForAIEVecPass
   }
   {
     RewritePatternSet patterns(context);
-    patterns.add(context);
+    patterns
+        .add(context);
     patterns.add(context);
     mlir::vector::populateFlattenVectorTransferPatterns(patterns);
     mlir::vector::populateVectorBroadcastLoweringPatterns(patterns);
     (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
   }
-
   {
     // These must run after 'populateFlattenVectorTransferPatterns' because
    // vector.shape_casts are introduced. Merging into a single pass creates
diff --git a/compiler/plugins/target/AMD-AIE/aievec/test/precanonicalization-aieml-llvmir.mlir b/compiler/plugins/target/AMD-AIE/aievec/test/precanonicalization-aieml-llvmir.mlir
index f330315b1..bc80f51e6 100644
--- a/compiler/plugins/target/AMD-AIE/aievec/test/precanonicalization-aieml-llvmir.mlir
+++ b/compiler/plugins/target/AMD-AIE/aievec/test/precanonicalization-aieml-llvmir.mlir
@@ -154,3 +154,16 @@ func.func @noncontiguous_write(%v: vector<4x8xi8>) {
   vector.transfer_write %v, %alloc[%c0, %c0] : vector<4x8xi8>, memref<4x10xi8>
   return
 }
+
+// -----
+
+// CHECK-LABEL: @arith_truncf(
+// CHECK-SAME:  %[[INP:.*]]: vector<2x3xf32>)
+func.func @arith_truncf(%inp: vector<2x3xf32>) -> vector<2x3xbf16> {
+  // CHECK: %[[LINEARIZE:.*]] = vector.shape_cast %[[INP]] : vector<2x3xf32> to vector<6xf32>
+  // CHECK: %[[TRUNCF:.*]] = arith.truncf %[[LINEARIZE]] : vector<6xf32> to vector<6xbf16>
+  // CHECK: %[[DELINEARIZE:.*]] = vector.shape_cast %[[TRUNCF]] : vector<6xbf16> to vector<2x3xbf16>
+  // CHECK: return %[[DELINEARIZE]]
+  %0 = arith.truncf %inp : vector<2x3xf32> to vector<2x3xbf16>
+  return %0 : vector<2x3xbf16>
+}
diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEFlattenVectorizedOps.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEFlattenVectorizedOps.cpp
deleted file mode 100644
index 28a18a461..000000000
--- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEFlattenVectorizedOps.cpp
+++ /dev/null
@@ -1,67 +0,0 @@
-// Copyright 2024 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree-amd-aie/IR/AMDAIEOps.h"
-#include "iree-amd-aie/Transforms/Passes.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-
-#define DEBUG_TYPE "iree-amdaie-flatten-vectorized-ops"
-
-namespace mlir::iree_compiler::AMDAIE {
-
-namespace {
-
-class AMDAIEFlattenVectorizedOpsPass
-    : public impl::AMDAIEFlattenVectorizedOpsBase<
-          AMDAIEFlattenVectorizedOpsPass> {
- public:
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<vector::VectorDialect>();
-  }
-
-  void runOnOperation() override;
-};
-
-void AMDAIEFlattenVectorizedOpsPass::runOnOperation() {
-  MLIRContext *context = &getContext();
-  ModuleOp moduleOp = getOperation();
-  IRRewriter rewriter(context);
-  // TODO(avarma): Currently this is fixated on just `arith.truncf`. Follow-up
-  // on this later to generalize.
-  moduleOp->walk([&](arith::TruncFOp op) {
-    OpBuilder::InsertionGuard g(rewriter);
-    rewriter.setInsertionPoint(op);
-    // Get old shape type.
-    auto oldShapedType = cast<VectorType>(op.getType());
-    // Linearize the shape.
-    int64_t linearizedSize = oldShapedType.getNumElements();
-    // Fetch input(s).
-    Value origInputOfTruncFOp = op.getIn();
-    // Form linearized vector shape type for input and output.
-    VectorType newVectorTypeForInput = VectorType::get(
-        {linearizedSize},
-        cast<VectorType>(origInputOfTruncFOp.getType()).getElementType());
-    VectorType newVectorTypeForOutput =
-        VectorType::get({linearizedSize}, oldShapedType.getElementType());
-    // Shape cast the original input to linearized shape type.
-    Value newInputVector = rewriter.create<vector::ShapeCastOp>(
-        op.getLoc(), newVectorTypeForInput, origInputOfTruncFOp);
-    // Create new base operation with the linearized input/output.
-    Value newTruncFOp = rewriter.create<arith::TruncFOp>(
-        op.getLoc(), newVectorTypeForOutput, newInputVector);
-    // Delinearize the output back to the original type.
-    Value newOutputVector = rewriter.create<vector::ShapeCastOp>(
-        op.getLoc(), op.getType(), newTruncFOp);
-    rewriter.replaceOp(op, newOutputVector);
-  });
-}
-
-}  // namespace
-
-std::unique_ptr<Pass> createAMDAIEFlattenVectorizedOpsPass() {
-  return std::make_unique<AMDAIEFlattenVectorizedOpsPass>();
-}
-}  // namespace mlir::iree_compiler::AMDAIE
diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEInsertCores.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEInsertCores.cpp
index 84c756dd7..e5a797650 100644
--- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEInsertCores.cpp
+++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEInsertCores.cpp
@@ -29,6 +29,19 @@ namespace mlir::iree_compiler::AMDAIE {
 
 namespace {
 
+/// Utility which returns 'true' if the operation needs to be inserted within
+/// an `amdaie.core` op.
+/// Some ops are surrounded by scf.for loop nests. Place the entire
+/// loop nest inside the amdaie.core op here. Currently look for a
+/// subset of ops which we know should be in the core.
+/// TODO(newling) improve this design.
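+/// (For example, an arith.truncf together with the vector.transfer_read and
+/// vector.transfer_write around it, plus any scf.for nest enclosing them, all
+/// end up inside the amdaie.core op.)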
+static bool isCoreComputeOp(Operation *op) {
+  return isa<linalg::LinalgOp, vector::ContractionOp,
+             memref::ExtractStridedMetadataOp, func::CallOp, arith::TruncFOp,
+             vector::TransferReadOp, vector::TransferWriteOp>(op);
+}
+
 /// Utility to map the parallel mapping attributes to the corresponding
 /// induction variables.
 void getAttributeMapping(SmallVector<scf::ForallOp> forallOps,
@@ -128,14 +141,7 @@ LogicalResult insertCoreOps(mlir::ModuleOp moduleOp) {
       coreOp.setLinkWith(fnDecl->getAttrOfType<StringAttr>("link_with"));
     }
 
-    // Some ops are surrrounded by scf.for loop nests. Place the entire
-    // loop nest inside the amdaie.core op here. Currently look for a
-    // subset of ops which we know should be in the core.
-    // TODO(newling) improve this design.
-    bool insertInCore =
-        isa<linalg::LinalgOp>(op) || isa<vector::ContractionOp>(op) ||
-        isa<memref::ExtractStridedMetadataOp>(op) || isa<func::CallOp>(op);
-    if (insertInCore) {
+    if (isCoreComputeOp(op)) {
       // Most distant ancestor of 'op' that's a strict descendant of
       // 'forallOp'.
       Operation *ancestor = op;
diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEInsertLoopsForVectorization.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEInsertLoopsForVectorization.cpp
index 0b0a2fe0d..91de3639f 100644
--- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEInsertLoopsForVectorization.cpp
+++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEInsertLoopsForVectorization.cpp
@@ -66,12 +66,31 @@ class AMDAIEInsertLoopsForVectorizationPass
   // Return success if the generic op is rewritten, failure otherwise.
   LogicalResult maybeRewrite(linalg::GenericOp genericOp,
                              IRRewriter &rewriter) {
+    if (isa(genericOp)) return failure();
+
     auto iteratorTypes = genericOp.getIteratorTypesArray();
     auto numIterators = iteratorTypes.size();
-    // No outer dimensions to tile if fewer than 4 iterators.
-    if (numIterators < 4) return failure();
-
+    // No outer dimensions to tile if fewer than 3 iterators.
+    if (numIterators < 3) return failure();
+
+    // Enable generating loops for vectorization in the case of element-wise
+    // ops. We currently tile all but the innermost two dimensions, because
+    // those two form the smallest tiled M x N block of the matmul.
+    if (llvm::all_of(iteratorTypes, [&](mlir::utils::IteratorType iterator) {
+          return linalg::isParallelIterator(iterator);
+        })) {
+      assert(numIterators >= 2 && "expected at least 2 iterators here");
+      SmallVector<int64_t> tileSizes(numIterators, 1);
+      tileSizes[numIterators - 2] = 0;
+      tileSizes[numIterators - 1] = 0;
+      auto opts = linalg::LinalgTilingOptions().setTileSizes(tileSizes);
+      auto tiled = linalg::tileLinalgOp(rewriter, genericOp, opts);
+      const auto &loops = tiled.value().loops;
+      assert(!loops.empty() && "expected at least one loop here");
+      rewriter.replaceOp(genericOp, loops[0]->getResult(0));
+      return success();
+    }
+
     // Matmul-like ops have 3 operands.
     if (genericOp->getNumOperands() != 3) return failure();
diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEVectorization.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEVectorization.cpp
index cec0c635b..c0ba285e0 100644
--- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEVectorization.cpp
+++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEVectorization.cpp
@@ -84,9 +84,18 @@ void AMDAIEVectorizationPass::runOnOperation() {
     // edge case so just disable. See
     // https://github.com/nod-ai/iree-amd-aie/issues/594 for more info.
     if (isElementwise(cast<linalg::LinalgOp>(op))) {
-      op->emitRemark() << "not vectorizing linalg elementwise op";
-      return;
+      // TODO(avarma): Currently switching vectorization on only for
+      // arith.truncf. Improve this later by trying to bridge the
+      // gap between this pass and vector-to-aievec.
+      for (Operation &innerOps :
+           cast<linalg::GenericOp>(op).getBody()->getOperations()) {
+        if (!isa<arith::TruncFOp, linalg::YieldOp>(innerOps)) {
+          op->emitRemark() << "not vectorizing linalg elementwise op";
+          return;
+        }
+      }
     }
+    // AIE architecture has no vector instructions for 32/64-bit types.
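+    // (So bail out unless at least one operand has a small element type such
+    // as bf16 that the AIE vector unit can operate on.)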
     if (!hasOperandWithSmallElementType(op)) return;
diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/CMakeLists.txt b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/CMakeLists.txt
index b4234f621..1c91b4d90 100644
--- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/CMakeLists.txt
+++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/CMakeLists.txt
@@ -69,7 +69,6 @@ iree_cc_library(
     "AMDAIEDmaToCircularDma.cpp"
     "AMDAIEDmaUtils.cpp"
     "AMDAIEFlattenLogicalObjectFifo.cpp"
-    "AMDAIEFlattenVectorizedOps.cpp"
     "AMDAIEFuseConsumerIntoLoop.cpp"
     "AMDAIEFuseFillIntoForall.cpp"
     "AMDAIEFusePackIntoLoop.cpp"
diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/PassDetail.h b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/PassDetail.h
index b2ee057f4..5deb1c72d 100644
--- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/PassDetail.h
+++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/PassDetail.h
@@ -46,7 +46,6 @@ namespace mlir::iree_compiler::AMDAIE {
 #define GEN_PASS_DEF_AMDAIEDMALOOPSUBSUMPTION
 #define GEN_PASS_DEF_AMDAIEDMATOCIRCULARDMA
 #define GEN_PASS_DEF_AMDAIEFLATTENLOGICALOBJECTFIFO
-#define GEN_PASS_DEF_AMDAIEFLATTENVECTORIZEDOPS
 #define GEN_PASS_DEF_AMDAIEFUSECONSUMERINTOLOOP
 #define GEN_PASS_DEF_AMDAIEFUSEFILLINTOFORALL
 #define GEN_PASS_DEF_AMDAIEFUSEPACKINTOLOOP
diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.h b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.h
index 6ced4ff21..79618316d 100644
--- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.h
+++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.h
@@ -158,9 +158,6 @@ std::unique_ptr<Pass> createAMDAIEDmaToCircularDmaPass();
 
 /// Create a pass to flatten the logical objectFifos.
 std::unique_ptr<Pass> createAMDAIEFlattenLogicalObjectFifoPass();
 
-/// Create a pass to flatten vectorized ops.
-std::unique_ptr<Pass> createAMDAIEFlattenVectorizedOpsPass();
-
 /// Create a pass to fuse the consumer op into the innermost last scf loop.
 std::unique_ptr<Pass> createAMDAIEFuseConsumerIntoLoopPass(
     AMDAIEFuseConsumerIntoLoopOptions options = {});
diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.td b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.td
index 48c184359..bb8256ea6 100644
--- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.td
+++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.td
@@ -222,12 +222,6 @@ def AMDAIEFlattenLogicalObjectFifo :
   let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIEFlattenLogicalObjectFifoPass()";
 }
 
-def AMDAIEFlattenVectorizedOps :
-    Pass<"iree-amdaie-flatten-vectorized-ops", "ModuleOp"> {
-  let summary = "Flatten vectorized ops.";
-  let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIEFlattenVectorizedOpsPass()";
-}
-
 def AMDAIEFuseConsumerIntoLoop :
     InterfacePass<"iree-amdaie-fuse-consumer-into-loop", "mlir::FunctionOpInterface"> {
   let summary = "Fuse the consumer operation into the innermost last scf loop.";
diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/CMakeLists.txt b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/CMakeLists.txt
index a8856b2f9..c083125bd 100644
--- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/CMakeLists.txt
+++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/CMakeLists.txt
@@ -34,7 +34,6 @@ iree_lit_test_suite(
     "dma_loop_subsumption.mlir"
     "dma_to_circular_dma.mlir"
    "flatten_logical_objectfifo.mlir"
-    "flatten_vectorized_ops.mlir"
     "fuse_consumer_into_loop_scf_for.mlir"
     "fuse_consumer_into_loop_scf_forall.mlir"
     "fuse_fill_into_forall.mlir"
diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/flatten_vectorized_ops.mlir b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/flatten_vectorized_ops.mlir
deleted file mode 100644
index 385415190..000000000
--- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/flatten_vectorized_ops.mlir
+++ /dev/null
@@ -1,32 +0,0 @@
-
-// RUN: iree-opt --pass-pipeline="builtin.module(iree-amdaie-flatten-vectorized-ops)" --split-input-file %s | FileCheck %s
-
-// CHECK-LABEL: @flatten_truncf
-module {
-  func.func @flatten_truncf() attributes {translation_info = #iree_codegen.translation_info<Custom>} {
-    %cst = arith.constant 0.000000e+00 : f32
-    %c0 = arith.constant 0 : index
-    %c2 = arith.constant 2 : index
-    amdaie.workgroup {
-      %arg0 = memref.alloc() : memref<1x1x4x4x4x8xf32, 2 : i32>
-      %arg1 = memref.alloc() : memref<1x1x4x4x4x8xbf16, 2 : i32>
-      %tile_2 = amdaie.tile(%c0, %c2)
-      %0 = amdaie.core(%tile_2, in : [], out : []) {
-        // CHECK: %[[READ:.*]] = vector.transfer_read
-        // CHECK: %[[LINEARIZE:.*]] = vector.shape_cast %[[READ]] : vector<1x1x1x1x4x8xf32> to vector<32xf32>
-        // CHECK: %[[TRUNCF:.*]] = arith.truncf %[[LINEARIZE]] : vector<32xf32> to vector<32xbf16>
-        // CHECK: %[[DELINEARIZE:.*]] = vector.shape_cast %[[TRUNCF]] : vector<32xbf16> to vector<1x1x1x1x4x8xbf16>
-        // CHECK: vector.transfer_write %[[DELINEARIZE]]
-        // CHECK: amdaie.end
-        %1 = vector.transfer_read %arg0[%c0,%c0,%c0,%c0,%c0,%c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x4x4x4x8xf32, 2 : i32>, vector<1x1x1x1x4x8xf32>
-        %2 = arith.truncf %1 : vector<1x1x1x1x4x8xf32> to vector<1x1x1x1x4x8xbf16>
-        vector.transfer_write %2, %arg1[%c0,%c0,%c0,%c0,%c0,%c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x8xbf16>, memref<1x1x4x4x4x8xbf16, 2 : i32>
-        amdaie.end
-      }
-      amdaie.controlcode {
-        amdaie.end
-      }
-    }
-    return
-  }
-}
diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/insert_cores.mlir b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/insert_cores.mlir
index 171551ceb..808f6d188 100644
--- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/insert_cores.mlir
+++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/insert_cores.mlir
@@ -300,3 +300,28 @@ module {
     return
   }
 }
+
+// -----
+
+// CHECK-LABEL: @insert_truncf_within_core
+// CHECK:       scf.forall
+// CHECK:         amdaie.tile
+// CHECK:         amdaie.core
+// CHECK:           vector.transfer_read
+// CHECK:           arith.truncf
+// CHECK:           vector.transfer_write
+// CHECK:           amdaie.end
+module {
+  func.func @insert_truncf_within_core(%arg0: memref<10x10xf32, 2 : i32>, %arg1: memref<10x10xbf16, 2 : i32>) {
+    %cst = arith.constant 0.000000e+00 : f32
+    %c1 = arith.constant 1 : index
+    %c3 = arith.constant 3 : index
+    %c0 = arith.constant 0 : index
+    scf.forall (%arg3, %arg4) in (2, 2) {
+      %read = vector.transfer_read %arg0[%c0, %c1], %cst {in_bounds = [true, true]} : memref<10x10xf32, 2 : i32>, vector<1x1xf32>
+      %truncf = arith.truncf %read : vector<1x1xf32> to vector<1x1xbf16>
+      vector.transfer_write %truncf, %arg1[%c0, %c1] {in_bounds = [true, true]} : vector<1x1xbf16>, memref<10x10xbf16, 2 : i32>
+    } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
+    return
+  }
+}
diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/insert_loops_for_vectorization.mlir b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/insert_loops_for_vectorization.mlir
index 57c2cae52..96ac9048d 100644
--- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/insert_loops_for_vectorization.mlir
+++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/insert_loops_for_vectorization.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-amdaie-insert-loops-for-vectorization))" %s | FileCheck %s
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-amdaie-insert-loops-for-vectorization))" --split-input-file %s | FileCheck %s
 
 !t2_bf16 = tensor<64x64xbf16>
 !t3_bf16 = tensor<64x64x64xbf16>
@@ -137,25 +137,6 @@ module {
     return %0 : !t3_f32
   }
 
-  // A check that a linalg.generic where the number of operands is not 3, does
-  // not get transformed to have an scf.for
-  // CHECK-LABEL: funcWithTwoOperands
-  // CHECK-NOT: scf.for
-  func.func @funcWithTwoOperands(%arg0: !t4_bf16, %arg1: !t4_bf16) -> !t4_bf16 {
-    %0 = linalg.generic {indexing_maps =
-           [
-             affine_map<(b0, d0, d1, d2) -> (b0, d0, d1, d2)>,
-             affine_map<(b0, d0, d1, d2) -> (d0, d1, d2, b0)>
-           ],
-           iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
-           ins(%arg0 : !t4_bf16) outs(%arg1 : !t4_bf16) {
-      ^bb0(%in: bf16, %out: bf16):
-        linalg.yield %in : bf16
-      } -> !t4_bf16
-    return %0 : !t4_bf16
-  }
-
   // Check that the final 3 dimensions do have the pattern of a matmul (or matmul transpose)
   // CHECK-LABEL: batched1
   // CHECK-NOT: scf.for
@@ -246,4 +227,33 @@ module {
   }
 }
 
+// -----
+
+// CHECK-LABEL: @element_wise
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<4x6x8xf32>, %[[ARG1:.*]]: tensor<4x6x8xbf16>)
+module {
+  func.func @element_wise(%arg0: tensor<4x6x8xf32>, %arg1: tensor<4x6x8xbf16>) -> tensor<4x6x8xbf16> {
+    %cst = arith.constant 0.000000e+00 : bf16
+    // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+    // CHECK:     scf.for %[[IV:.*]] = %{{.*}} to %[[C4]]
+    // CHECK-SAME:  iter_args(%[[ARG3:.*]] = %[[ARG1]])
+    // CHECK-NOT:   scf.for
+    // CHECK:       tensor.extract_slice %[[ARG0]][%[[IV]], 0, 0] [1, 6, 8] [1, 1, 1]
+    // CHECK:       tensor.extract_slice %[[ARG3]][%[[IV]], 0, 0] [1, 6, 8] [1, 1, 1]
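+    // (Tile sizes here are [1, 0, 0]: only the outermost of the three
+    // parallel dimensions is tiled, so each iteration works on a 1x6x8 slice.)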
+    // CHECK:       %[[RES:.*]] = linalg.generic
+    // CHECK:       tensor.insert_slice %[[RES]] into %[[ARG3]][%[[IV]], 0, 0] [1, 6, 8] [1, 1, 1]
+    %0 = linalg.generic {
+           indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                            affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+           ],
+           iterator_types = ["parallel", "parallel", "parallel"]
+         } ins(%arg0 : tensor<4x6x8xf32>)
+           outs(%arg1 : tensor<4x6x8xbf16>) {
+      ^bb0(%in: f32, %out: bf16):
+        %1 = arith.truncf %in : f32 to bf16
+        %2 = arith.maximumf %1, %cst : bf16
+        linalg.yield %2 : bf16
+    } -> tensor<4x6x8xbf16>
+    return %0 : tensor<4x6x8xbf16>
+  }
+}
diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/vectorization.mlir b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/vectorization.mlir
index ade6a48e0..4f6c95dfe 100644
--- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/vectorization.mlir
+++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/vectorization.mlir
@@ -64,12 +64,16 @@ func.func @fillAndCopy() -> tensor<8xbf16> {
 }
 
 
-func.func @matmul_elementwise(%3 : tensor<4240x160xi8>, %ele : tensor<160xi8>) -> tensor<4240x160xi8> {
-  // expected-remark @below {{not vectorizing linalg elementwise op}}
-  %9 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%3, %ele : tensor<4240x160xi8>, tensor<160xi8>) outs(%3 : tensor<4240x160xi8>) {
-    ^bb0(%in: i8, %in_5: i8, %out: i8):
-      %10 = arith.addi %in, %in_5 : i8
-      linalg.yield %10 : i8
-  } -> tensor<4240x160xi8>
-  return %9 : tensor<4240x160xi8>
+// CHECK-LABEL: @matmul_elementwise
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<4240x160xf32>, %[[ARG1:.*]]: tensor<4240x160xbf16>)
+func.func @matmul_elementwise(%arg0: tensor<4240x160xf32>, %arg1: tensor<4240x160xbf16>) -> tensor<4240x160xbf16> {
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<4240x160xf32>) outs(%arg1 : tensor<4240x160xbf16>) {
+    ^bb0(%in: f32, %out: bf16):
+      %1 = arith.truncf %in : f32 to bf16
+      linalg.yield %1 : bf16
+  } -> tensor<4240x160xbf16>
+  return %0 : tensor<4240x160xbf16>
 }
+// CHECK: %[[VEC_OPERAND_0:.*]] = vector.transfer_read %[[ARG0]]{{.*}} vector<4240x160xf32>
+// CHECK: %[[TRUNCF:.*]] = arith.truncf %[[VEC_OPERAND_0]]
+// CHECK: vector.transfer_write %[[TRUNCF]], %[[ARG1]]