Skip to content

Commit

Permalink
[Codegen][LLVMGPU] Set global read layouts at linalg level (iree-org#…
Browse files Browse the repository at this point in the history
…18860)

- Operand promotion is now done the same way as TileAndFuse pipeline, by
reading promote_operands config from lowering_config.
- Moves global read layout setting to LLVMGPUConfigureTensorLayouts,
from LLVMGPUConfigureVectorLayouts pass, anchoring based on lowering
config.

These changes by side effect allow setting layouts on gathers in
VectorDistribute pipeline.
  • Loading branch information
Groverkss authored Oct 30, 2024
1 parent 12cb042 commit a744285
Show file tree
Hide file tree
Showing 14 changed files with 357 additions and 461 deletions.
1 change: 0 additions & 1 deletion compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ iree_compiler_cc_library(
"LLVMGPUCastAddressSpaceFunction.cpp",
"LLVMGPUCastTypeToFitMMA.cpp",
"LLVMGPUConfigureTensorLayouts.cpp",
"LLVMGPUConfigureVectorLayouts.cpp",
"LLVMGPUConvolutionToIGEMM.cpp",
"LLVMGPULinkExecutables.cpp",
"LLVMGPULowerExecutableTarget.cpp",
Expand Down
1 change: 0 additions & 1 deletion compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ iree_cc_library(
"LLVMGPUCastAddressSpaceFunction.cpp"
"LLVMGPUCastTypeToFitMMA.cpp"
"LLVMGPUConfigureTensorLayouts.cpp"
"LLVMGPUConfigureVectorLayouts.cpp"
"LLVMGPUConvolutionToIGEMM.cpp"
"LLVMGPULinkExecutables.cpp"
"LLVMGPULowerExecutableTarget.cpp"
Expand Down
22 changes: 22 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ setConvolutionVectorDistributionConfig(IREE::GPU::TargetAttr target,
b.getI64ArrayAttr(workgroupTileSizes));
attrs.emplace_back(StringAttr::get(context, "reduction"),
b.getI64ArrayAttr(reductionTileSizes));
IREE::GPU::LoweringConfigAttr::setPromotedOperandList(context, attrs, {0, 1});

auto configDict = DictionaryAttr::get(context, attrs);
auto loweringConfig = IREE::GPU::LoweringConfigAttr::get(context, configDict);
Expand Down Expand Up @@ -633,6 +634,7 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target,
b.getI64ArrayAttr(workgroupTileSizes));
attrs.emplace_back(StringAttr::get(context, "reduction"),
b.getI64ArrayAttr(reductionTileSizes));
IREE::GPU::LoweringConfigAttr::setPromotedOperandList(context, attrs, {0, 1});

auto configDict = DictionaryAttr::get(context, attrs);
auto loweringConfig = IREE::GPU::LoweringConfigAttr::get(context, configDict);
Expand Down Expand Up @@ -830,13 +832,33 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target,
b.getI64ArrayAttr(workgroupTileSizes));
attrs.emplace_back(StringAttr::get(context, "reduction"),
b.getI64ArrayAttr(reductionTileSizes));
IREE::GPU::LoweringConfigAttr::setPromotedOperandList(context, attrs,
{0, 1, 2});

SmallVector<NamedAttribute, 2> qkConfig;
SmallVector<NamedAttribute, 2> pvConfig;

IREE::GPU::LoweringConfigAttr::setPromotedOperandList(context, qkConfig,
{0, 1});
IREE::GPU::LoweringConfigAttr::setPromotedOperandList(context, pvConfig, {1});

SmallVector<NamedAttribute, 2> qkAttrs;
SmallVector<NamedAttribute, 2> pvAttrs;

qkAttrs.emplace_back(b.getNamedAttr("attention_qk_matmul", b.getUnitAttr()));
pvAttrs.emplace_back(b.getNamedAttr("attention_pv_matmul", b.getUnitAttr()));

auto qkConfigDict = b.getDictionaryAttr(qkConfig);
auto pvConfigDict = b.getDictionaryAttr(pvConfig);

auto qkLoweringConfig =
IREE::GPU::LoweringConfigAttr::get(context, qkConfigDict);
auto pvLoweringConfig =
IREE::GPU::LoweringConfigAttr::get(context, pvConfigDict);

qkAttrs.emplace_back(b.getNamedAttr("lowering_config", qkLoweringConfig));
pvAttrs.emplace_back(b.getNamedAttr("lowering_config", pvLoweringConfig));

auto qkAttrDict = b.getDictionaryAttr(qkAttrs);
auto pvAttrDict = b.getDictionaryAttr(pvAttrs);

Expand Down
Loading

0 comments on commit a744285

Please sign in to comment.