diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index 10fe8f47170f..5426df90de48 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -309,6 +309,16 @@ setConvolutionVectorDistributionConfig(IREE::GPU::TargetAttr target,
   GPUMatmulShapeType problem{bounds[mDim], bounds[nDim], bounds[kDim],
                              lhsElemType, rhsElemType, initElemType};
 
+  // Helper fn to store mma information.
+  auto storeMmaInfo = [](IREE::GPU::MMAAttr mma,
+                         SmallVector<GPUMatmulShapeType> &intrinsics,
+                         SmallVector<IREE::GPU::MMAAttr> &mmaAttrs) {
+    auto [mSize, nSize, kSize] = mma.getMNKShape();
+    auto [aType, bType, cType] = mma.getABCElementTypes();
+    intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType);
+    mmaAttrs.emplace_back(mma);
+  };
+
   SmallVector<GPUMatmulShapeType> intrinsics;
   intrinsics.reserve(target.getWgp().getMma().size());
   SmallVector<IREE::GPU::MMAAttr> mmaAttrs;
@@ -316,18 +326,12 @@ setConvolutionVectorDistributionConfig(IREE::GPU::TargetAttr target,
   for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) {
     if (mma.getSubgroupSize() != targetSubgroupSize)
       continue;
-    auto [mSize, nSize, kSize] = mma.getMNKShape();
-    auto [aType, bType, cType] = mma.getABCElementTypes();
-    intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType);
-    mmaAttrs.emplace_back(mma);
+    storeMmaInfo(mma, intrinsics, mmaAttrs);
     // Store info on virtual intrinsics based on current mma if any
     for (IREE::GPU::MMAIntrinsic virtualIntrinsic :
          mma.getVirtualIntrinsics()) {
       auto virtualMma = IREE::GPU::MMAAttr::get(context, virtualIntrinsic);
-      auto [mSize, nSize, kSize] = virtualMma.getMNKShape();
-      auto [aType, bType, cType] = virtualMma.getABCElementTypes();
-      intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType);
-      mmaAttrs.emplace_back(virtualMma);
+      storeMmaInfo(virtualMma, intrinsics, mmaAttrs);
     }
   }
 
@@ -515,6 +519,16 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target,
   GPUMatmulShapeType problem{bounds[mDim], bounds[nDim], bounds[kDim],
                              lhsElemType, rhsElemType, initElemType};
 
+  // Helper fn to store mma information.
+  auto storeMmaInfo = [](IREE::GPU::MMAAttr mma,
+                         SmallVector<GPUMatmulShapeType> &intrinsics,
+                         SmallVector<IREE::GPU::MMAAttr> &mmaAttrs) {
+    auto [mSize, nSize, kSize] = mma.getMNKShape();
+    auto [aType, bType, cType] = mma.getABCElementTypes();
+    intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType);
+    mmaAttrs.emplace_back(mma);
+  };
+
   SmallVector<GPUMatmulShapeType> intrinsics;
   intrinsics.reserve(target.getWgp().getMma().size());
   SmallVector<IREE::GPU::MMAAttr> mmaAttrs;
@@ -522,18 +536,12 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target,
   for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) {
     if (mma.getSubgroupSize() != targetSubgroupSize)
       continue;
-    auto [mSize, nSize, kSize] = mma.getMNKShape();
-    auto [aType, bType, cType] = mma.getABCElementTypes();
-    intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType);
-    mmaAttrs.emplace_back(mma);
+    storeMmaInfo(mma, intrinsics, mmaAttrs);
     // Store info on virtual intrinsics based on current mma if any
     for (IREE::GPU::MMAIntrinsic virtualIntrinsic :
          mma.getVirtualIntrinsics()) {
       auto virtualMma = IREE::GPU::MMAAttr::get(context, virtualIntrinsic);
-      auto [mSize, nSize, kSize] = virtualMma.getMNKShape();
-      auto [aType, bType, cType] = virtualMma.getABCElementTypes();
-      intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType);
-      mmaAttrs.emplace_back(virtualMma);
+      storeMmaInfo(virtualMma, intrinsics, mmaAttrs);
     }
   }
 
@@ -727,6 +735,16 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target,
   Value kMatrix = op.getKey();
   Value vMatrix = op.getValue();
 
+  // Helper fn to store mma information.
+  auto storeMmaInfo = [](IREE::GPU::MMAAttr mma,
+                         SmallVector<GPUMatmulShapeType> &intrinsics,
+                         SmallVector<IREE::GPU::MMAAttr> &mmaAttrs) {
+    auto [mSize, nSize, kSize] = mma.getMNKShape();
+    auto [aType, bType, cType] = mma.getABCElementTypes();
+    intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType);
+    mmaAttrs.emplace_back(mma);
+  };
+
   SmallVector<GPUMatmulShapeType> intrinsics;
   intrinsics.reserve(target.getWgp().getMma().size());
   SmallVector<IREE::GPU::MMAAttr> mmaAttrs;
@@ -734,20 +752,15 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target,
   for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) {
     if (mma.getSubgroupSize() != targetSubgroupSize)
       continue;
-    auto [mSize, nSize, kSize] = mma.getMNKShape();
-    auto [aType, bType, cType] = mma.getABCElementTypes();
-    intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType);
-    mmaAttrs.emplace_back(mma);
-    // Store info on virtual intrinsics based on current mma if any.
+    storeMmaInfo(mma, intrinsics, mmaAttrs);
+    // Store info on virtual intrinsics based on current mma if any
    for (IREE::GPU::MMAIntrinsic virtualIntrinsic :
         mma.getVirtualIntrinsics()) {
      auto virtualMma = IREE::GPU::MMAAttr::get(context, virtualIntrinsic);
-      auto [mSize, nSize, kSize] = virtualMma.getMNKShape();
-      auto [aType, bType, cType] = virtualMma.getABCElementTypes();
-      intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType);
-      mmaAttrs.emplace_back(virtualMma);
+      storeMmaInfo(virtualMma, intrinsics, mmaAttrs);
     }
   }
 
+  if (intrinsics.empty())
+    return failure();
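For readers skimming the PR: the same duplicated unpack-and-emplace block appears in the convolution, matmul, and attention configuration paths, and the change replaces each copy with a call to a local storeMmaInfo lambda. Below is a minimal, self-contained sketch of that pattern; MmaLike, ShapeInfo, and the main driver are hypothetical stand-ins rather than IREE types, included only to show how a lambda capturing the destination vectors removes the repeated sequence.

#include <cstdio>
#include <tuple>
#include <vector>

// Hypothetical stand-ins for IREE::GPU::MMAAttr and GPUMatmulShapeType,
// used only to illustrate the refactoring pattern from the diff.
struct MmaLike {
  int m, n, k;
  std::tuple<int, int, int> getMNKShape() const { return {m, n, k}; }
};
struct ShapeInfo {
  int m, n, k;
};

int main() {
  std::vector<MmaLike> available = {{16, 16, 16}, {32, 32, 8}};

  std::vector<ShapeInfo> intrinsics;
  std::vector<MmaLike> mmaAttrs;

  // Same idea as storeMmaInfo in the diff: one lambda replaces the
  // duplicated "unpack shape, append to both vectors" sequence at
  // every call site.
  auto storeMmaInfo = [&](const MmaLike &mma) {
    auto [mSize, nSize, kSize] = mma.getMNKShape();
    intrinsics.push_back({mSize, nSize, kSize});
    mmaAttrs.push_back(mma);
  };

  for (const MmaLike &mma : available)
    storeMmaInfo(mma);

  std::printf("stored %zu intrinsics\n", intrinsics.size());
  return 0;
}

In the actual change the lambda takes the two vectors as explicit reference parameters instead of capturing them, which keeps the three call sites (real intrinsic and virtual intrinsics, per configuration function) identical in shape.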