Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
Xinyu302 committed Aug 15, 2024
1 parent e929786 commit aa396a0
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 8 deletions.
10 changes: 10 additions & 0 deletions compiler/include/byteir/Dialect/GPU/Transforms/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,16 @@ static constexpr StringRef getCopyFromSharedMemoryAccMarker() {
return "__byteir_store_matrix_c__";
};

static constexpr StringRef getMatmulMainLoopMarker() {
return "__byteir_main_loop__";
}

constexpr StringRef getLinalgMMALevelAttrName() {
return "__byteir_mma_level__";
}

constexpr StringRef getMMAPatternAttrName() { return "__byteir_mma__"; }

static constexpr StringRef getEpilogueMarker() { return "__byteir_epilogue__"; }

std::optional<SmallVector<int64_t, 3>> getGemmTileSize(func::FuncOp funcOp);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,11 @@ modifyUseToGetValueIntoStoreSet(RewriterBase &rewriter,
OpOperand *inOperand = nullptr;
OpOperand *initOperand = nullptr;
for (OpOperand *in : genericOp.getDpsInputOperands()) {
// if operand is generated by a scf.for, then it's a result of matmul
if (isa<scf::ForOp>(in->get().getDefiningOp())) {
// if operand is generated by a op which has MainLoop Marker or it's a
// linalg.matmul
if (hasMarker(
in->get().getDefiningOp(),
ArrayRef{getMatmulMainLoopMarker(), getMMAPatternAttrName()})) {
inOperand = in;
} else {
newInputs.push_back(in->get());
Expand Down Expand Up @@ -122,6 +125,7 @@ class CanonicalizeMatmulEpiloguePass

// modify the epilogue to get the value into the store set
if (failed(modifyUseToGetValueIntoStoreSet(rewriter, epilogueOp))) {
llvm::errs() << "failed in modifyUseToGetValueIntoStoreSet\n";
return signalPassFailure();
}

Expand Down
14 changes: 8 additions & 6 deletions compiler/lib/Pipelines/GPU/GemmCodegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,6 @@ namespace {

constexpr StringRef getLinalgToGPUAttrName() { return "__byteir_to_gpu__"; }

constexpr StringRef getLinalgMMALevelAttrName() {
return "__byteir_mma_level__";
}

constexpr StringRef getMMAPatternAttrName() { return "__byteir_mma__"; }

constexpr StringRef getLinalgTargetAttrName() { return "__byteir_target__"; }

void createGPUTileGemmTransformImpl(OpPassManager &pm,
Expand Down Expand Up @@ -168,6 +162,14 @@ void createGPUTileGemmTransformImpl(OpPassManager &pm,
auto tileKMatmulOp =
b.create<transform::TileUsingForOp>(tiledMatmulOp, reductionTileSizes);
auto matmulKOp = tileKMatmulOp.getTiledLinalgOp();
auto forLoops = tileKMatmulOp.getLoops();
if (!forLoops.empty()) {
b.create<transform::AnnotateOp>(forLoops[0], getMatmulMainLoopMarker(),
Value());
} else {
b.create<transform::AnnotateOp>(matmulKOp, getMatmulMainLoopMarker(),
Value());
}

b.create<transform::AnnotateOp>(matmulKOp, getLinalgMMALevelAttrName(),
mmaLevel);
Expand Down

0 comments on commit aa396a0

Please sign in to comment.