Skip to content

Commit

Permalink
add helper lambda fn
Browse files Browse the repository at this point in the history
Signed-off-by: Stanley Winata <stanley.winata@amd.com>
  • Loading branch information
raikonenfnu committed Oct 29, 2024
1 parent d6cb7fd commit 1c245b4
Showing 1 changed file with 38 additions and 25 deletions.
63 changes: 38 additions & 25 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,25 +309,29 @@ setConvolutionVectorDistributionConfig(IREE::GPU::TargetAttr target,
GPUMatmulShapeType problem{bounds[mDim], bounds[nDim], bounds[kDim],
lhsElemType, rhsElemType, initElemType};

// Helper fn to store mma information.
auto storeMmaInfo = [](IREE::GPU::MMAAttr mma,
SmallVector<GPUMatmulShapeType> &intrinsics,
SmallVector<IREE::GPU::MMAAttr> &mmaAttrs) {
auto [mSize, nSize, kSize] = mma.getMNKShape();
auto [aType, bType, cType] = mma.getABCElementTypes();
intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType);
mmaAttrs.emplace_back(mma);
};

SmallVector<GPUMatmulShapeType> intrinsics;
intrinsics.reserve(target.getWgp().getMma().size());
SmallVector<IREE::GPU::MMAAttr> mmaAttrs;
MLIRContext *context = op.getContext();
for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) {
if (mma.getSubgroupSize() != targetSubgroupSize)
continue;
auto [mSize, nSize, kSize] = mma.getMNKShape();
auto [aType, bType, cType] = mma.getABCElementTypes();
intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType);
mmaAttrs.emplace_back(mma);
storeMmaInfo(mma, intrinsics, mmaAttrs);
// Store info on virtual intrinsics based on current mma, if any.
for (IREE::GPU::MMAIntrinsic virtualIntrinsic :
mma.getVirtualIntrinsics()) {
auto virtualMma = IREE::GPU::MMAAttr::get(context, virtualIntrinsic);
auto [mSize, nSize, kSize] = virtualMma.getMNKShape();
auto [aType, bType, cType] = virtualMma.getABCElementTypes();
intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType);
mmaAttrs.emplace_back(virtualMma);
storeMmaInfo(virtualMma, intrinsics, mmaAttrs);
}
}

Expand Down Expand Up @@ -515,25 +519,29 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target,
GPUMatmulShapeType problem{bounds[mDim], bounds[nDim], bounds[kDim],
lhsElemType, rhsElemType, initElemType};

// Helper fn to store mma information.
auto storeMmaInfo = [](IREE::GPU::MMAAttr mma,
SmallVector<GPUMatmulShapeType> &intrinsics,
SmallVector<IREE::GPU::MMAAttr> &mmaAttrs) {
auto [mSize, nSize, kSize] = mma.getMNKShape();
auto [aType, bType, cType] = mma.getABCElementTypes();
intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType);
mmaAttrs.emplace_back(mma);
};

SmallVector<GPUMatmulShapeType> intrinsics;
intrinsics.reserve(target.getWgp().getMma().size());
SmallVector<IREE::GPU::MMAAttr> mmaAttrs;
MLIRContext *context = op.getContext();
for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) {
if (mma.getSubgroupSize() != targetSubgroupSize)
continue;
auto [mSize, nSize, kSize] = mma.getMNKShape();
auto [aType, bType, cType] = mma.getABCElementTypes();
intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType);
mmaAttrs.emplace_back(mma);
storeMmaInfo(mma, intrinsics, mmaAttrs);
// Store info on virtual intrinsics based on current mma, if any.
for (IREE::GPU::MMAIntrinsic virtualIntrinsic :
mma.getVirtualIntrinsics()) {
auto virtualMma = IREE::GPU::MMAAttr::get(context, virtualIntrinsic);
auto [mSize, nSize, kSize] = virtualMma.getMNKShape();
auto [aType, bType, cType] = virtualMma.getABCElementTypes();
intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType);
mmaAttrs.emplace_back(virtualMma);
storeMmaInfo(virtualMma, intrinsics, mmaAttrs);
}
}

Expand Down Expand Up @@ -727,27 +735,32 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target,
Value kMatrix = op.getKey();
Value vMatrix = op.getValue();

// Helper fn to store mma information.
auto storeMmaInfo = [](IREE::GPU::MMAAttr mma,
SmallVector<GPUMatmulShapeType> &intrinsics,
SmallVector<IREE::GPU::MMAAttr> &mmaAttrs) {
auto [mSize, nSize, kSize] = mma.getMNKShape();
auto [aType, bType, cType] = mma.getABCElementTypes();
intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType);
mmaAttrs.emplace_back(mma);
};

SmallVector<GPUMatmulShapeType> intrinsics;
intrinsics.reserve(target.getWgp().getMma().size());
SmallVector<IREE::GPU::MMAAttr> mmaAttrs;
MLIRContext *context = op.getContext();
for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) {
if (mma.getSubgroupSize() != targetSubgroupSize)
continue;
auto [mSize, nSize, kSize] = mma.getMNKShape();
auto [aType, bType, cType] = mma.getABCElementTypes();
intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType);
mmaAttrs.emplace_back(mma);
// Store info on virtual intrinsics based on current mma if any.
storeMmaInfo(mma, intrinsics, mmaAttrs);
// Store info on virtual intrinsics based on current mma, if any.
for (IREE::GPU::MMAIntrinsic virtualIntrinsic :
mma.getVirtualIntrinsics()) {
auto virtualMma = IREE::GPU::MMAAttr::get(context, virtualIntrinsic);
auto [mSize, nSize, kSize] = virtualMma.getMNKShape();
auto [aType, bType, cType] = virtualMma.getABCElementTypes();
intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType);
mmaAttrs.emplace_back(virtualMma);
storeMmaInfo(virtualMma, intrinsics, mmaAttrs);
}
}

if (intrinsics.empty())
return failure();

Expand Down

0 comments on commit 1c245b4

Please sign in to comment.