Skip to content

Commit

Permalink
[Codegen] Set global read layouts at linalg level
Browse files: browse the repository at this point in the history
  • Loading branch information
Groverkss committed Oct 29, 2024
1 parent: 3cf5b65; commit: 21e5d57
Show file tree
Hide file tree
Showing 15 changed files with 369 additions and 465 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -53,9 +53,17 @@ void promoteOperand(OpBuilder &builder, Operation *op, unsigned index) {
return;
}
}
setLoweringConfig(producer, IREE::GPU::DerivedThreadConfigAttr::get(
builder.getContext()));
return;

bool promoteProducer = true;
if (isa<tensor::PadOp>(producer)) {
promoteProducer = false;
}

if (promoteProducer) {
setLoweringConfig(producer, IREE::GPU::DerivedThreadConfigAttr::get(
builder.getContext()));
return;
}
}

auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
Expand Down
1 change: 0 additions & 1 deletion compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
Original file line number | Diff line number | Diff line change
Expand Up @@ -94,7 +94,6 @@ iree_compiler_cc_library(
"LLVMGPUCastAddressSpaceFunction.cpp",
"LLVMGPUCastTypeToFitMMA.cpp",
"LLVMGPUConfigureTensorLayouts.cpp",
"LLVMGPUConfigureVectorLayouts.cpp",
"LLVMGPUConvolutionToIGEMM.cpp",
"LLVMGPULowerExecutableTarget.cpp",
"LLVMGPUPackSharedMemoryAlloc.cpp",
Expand Down
1 change: 0 additions & 1 deletion compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
Original file line number | Diff line number | Diff line change
Expand Up @@ -79,7 +79,6 @@ iree_cc_library(
"LLVMGPUCastAddressSpaceFunction.cpp"
"LLVMGPUCastTypeToFitMMA.cpp"
"LLVMGPUConfigureTensorLayouts.cpp"
"LLVMGPUConfigureVectorLayouts.cpp"
"LLVMGPUConvolutionToIGEMM.cpp"
"LLVMGPULowerExecutableTarget.cpp"
"LLVMGPUPackSharedMemoryAlloc.cpp"
Expand Down
22 changes: 22 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -384,6 +384,7 @@ setConvolutionVectorDistributionConfig(IREE::GPU::TargetAttr target,
b.getI64ArrayAttr(workgroupTileSizes));
attrs.emplace_back(StringAttr::get(context, "reduction"),
b.getI64ArrayAttr(reductionTileSizes));
IREE::GPU::LoweringConfigAttr::setPromotedOperandList(context, attrs, {0, 1});

auto configDict = DictionaryAttr::get(context, attrs);
auto loweringConfig = IREE::GPU::LoweringConfigAttr::get(context, configDict);
Expand Down Expand Up @@ -629,6 +630,7 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target,
b.getI64ArrayAttr(workgroupTileSizes));
attrs.emplace_back(StringAttr::get(context, "reduction"),
b.getI64ArrayAttr(reductionTileSizes));
IREE::GPU::LoweringConfigAttr::setPromotedOperandList(context, attrs, {0, 1});

auto configDict = DictionaryAttr::get(context, attrs);
auto loweringConfig = IREE::GPU::LoweringConfigAttr::get(context, configDict);
Expand Down Expand Up @@ -824,13 +826,33 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target,
b.getI64ArrayAttr(workgroupTileSizes));
attrs.emplace_back(StringAttr::get(context, "reduction"),
b.getI64ArrayAttr(reductionTileSizes));
IREE::GPU::LoweringConfigAttr::setPromotedOperandList(context, attrs,
{0, 1, 2});

SmallVector<NamedAttribute, 2> qkConfig;
SmallVector<NamedAttribute, 2> pvConfig;

IREE::GPU::LoweringConfigAttr::setPromotedOperandList(context, qkConfig,
{0, 1});
IREE::GPU::LoweringConfigAttr::setPromotedOperandList(context, pvConfig, {1});

SmallVector<NamedAttribute, 2> qkAttrs;
SmallVector<NamedAttribute, 2> pvAttrs;

qkAttrs.emplace_back(b.getNamedAttr("attention_qk_matmul", b.getUnitAttr()));
pvAttrs.emplace_back(b.getNamedAttr("attention_pv_matmul", b.getUnitAttr()));

auto qkConfigDict = b.getDictionaryAttr(qkConfig);
auto pvConfigDict = b.getDictionaryAttr(pvConfig);

auto qkLoweringConfig =
IREE::GPU::LoweringConfigAttr::get(context, qkConfigDict);
auto pvLoweringConfig =
IREE::GPU::LoweringConfigAttr::get(context, pvConfigDict);

qkAttrs.emplace_back(b.getNamedAttr("lowering_config", qkLoweringConfig));
pvAttrs.emplace_back(b.getNamedAttr("lowering_config", pvLoweringConfig));

auto qkAttrDict = b.getDictionaryAttr(qkAttrs);
auto pvAttrDict = b.getDictionaryAttr(pvAttrs);

Expand Down
Loading

0 comments on commit 21e5d57

Please sign in to comment.