diff --git a/compiler/include/byteir/Dialect/GPU/Transforms/Utils.h b/compiler/include/byteir/Dialect/GPU/Transforms/Utils.h index 782de0ae7..79d3ef645 100644 --- a/compiler/include/byteir/Dialect/GPU/Transforms/Utils.h +++ b/compiler/include/byteir/Dialect/GPU/Transforms/Utils.h @@ -75,6 +75,16 @@ static constexpr StringRef getCopyFromSharedMemoryAccMarker() { return "__byteir_store_matrix_c__"; }; +static constexpr StringRef getMatmulMainLoopMarker() { + return "__byteir_main_loop__"; +} + +constexpr StringRef getLinalgMMALevelAttrName() { + return "__byteir_mma_level__"; +} + +constexpr StringRef getMMAPatternAttrName() { return "__byteir_mma__"; } + static constexpr StringRef getEpilogueMarker() { return "__byteir_epilogue__"; } std::optional> getGemmTileSize(func::FuncOp funcOp); diff --git a/compiler/lib/Dialect/Linalg/Transforms/CanonicalizeMatmulEpilogue.cpp b/compiler/lib/Dialect/Linalg/Transforms/CanonicalizeMatmulEpilogue.cpp index 1670a1df0..8ecdab386 100644 --- a/compiler/lib/Dialect/Linalg/Transforms/CanonicalizeMatmulEpilogue.cpp +++ b/compiler/lib/Dialect/Linalg/Transforms/CanonicalizeMatmulEpilogue.cpp @@ -44,8 +44,11 @@ modifyUseToGetValueIntoStoreSet(RewriterBase &rewriter, OpOperand *inOperand = nullptr; OpOperand *initOperand = nullptr; for (OpOperand *in : genericOp.getDpsInputOperands()) { - // if operand is generated by a scf.for, then it's a result of matmul - if (isa(in->get().getDefiningOp())) { + // if operand is generated by a op which has MainLoop Marker or it's a + // linalg.matmul + if (hasMarker( + in->get().getDefiningOp(), + ArrayRef{getMatmulMainLoopMarker(), getMMAPatternAttrName()})) { inOperand = in; } else { newInputs.push_back(in->get()); @@ -122,6 +125,7 @@ class CanonicalizeMatmulEpiloguePass // modify the epilogue to get the value into the store set if (failed(modifyUseToGetValueIntoStoreSet(rewriter, epilogueOp))) { + llvm::errs() << "failed in modifyUseToGetValueIntoStoreSet\n"; return signalPassFailure(); } diff --git a/compiler/lib/Pipelines/GPU/GemmCodegen.cpp b/compiler/lib/Pipelines/GPU/GemmCodegen.cpp index 86064a66d..fdf6e1b19 100644 --- a/compiler/lib/Pipelines/GPU/GemmCodegen.cpp +++ b/compiler/lib/Pipelines/GPU/GemmCodegen.cpp @@ -43,12 +43,6 @@ namespace { constexpr StringRef getLinalgToGPUAttrName() { return "__byteir_to_gpu__"; } -constexpr StringRef getLinalgMMALevelAttrName() { - return "__byteir_mma_level__"; -} - -constexpr StringRef getMMAPatternAttrName() { return "__byteir_mma__"; } - constexpr StringRef getLinalgTargetAttrName() { return "__byteir_target__"; } void createGPUTileGemmTransformImpl(OpPassManager &pm, @@ -168,6 +162,14 @@ void createGPUTileGemmTransformImpl(OpPassManager &pm, auto tileKMatmulOp = b.create(tiledMatmulOp, reductionTileSizes); auto matmulKOp = tileKMatmulOp.getTiledLinalgOp(); + auto forLoops = tileKMatmulOp.getLoops(); + if (!forLoops.empty()) { + b.create(forLoops[0], getMatmulMainLoopMarker(), + Value()); + } else { + b.create(matmulKOp, getMatmulMainLoopMarker(), + Value()); + } b.create(matmulKOp, getLinalgMMALevelAttrName(), mmaLevel);