diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp index 4b64cda3adc9..0d9c7f9ad2e6 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp @@ -825,7 +825,25 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target, attrs.emplace_back(StringAttr::get(context, "reduction"), b.getI64ArrayAttr(reductionTileSizes)); - auto configDict = DictionaryAttr::get(context, attrs); + SmallVector qkAttrs; + SmallVector pvAttrs; + + qkAttrs.emplace_back(b.getNamedAttr("attention_qk_matmul", b.getUnitAttr())); + pvAttrs.emplace_back(b.getNamedAttr("attention_pv_matmul", b.getUnitAttr())); + + auto qkAttrDict = b.getDictionaryAttr(qkAttrs); + auto pvAttrDict = b.getDictionaryAttr(pvAttrs); + + SmallVector decompositionConfig; + decompositionConfig.emplace_back( + b.getNamedAttr(IREE::LinalgExt::AttentionOp::getQKAttrStr(), qkAttrDict)); + decompositionConfig.emplace_back( + b.getNamedAttr(IREE::LinalgExt::AttentionOp::getPVAttrStr(), pvAttrDict)); + + DictionaryAttr decompositionConfigDict = + b.getDictionaryAttr(decompositionConfig); + + auto configDict = b.getDictionaryAttr(attrs); auto loweringConfig = IREE::GPU::LoweringConfigAttr::get(context, configDict); // Attach the MMA schedule as an attribute to the entry point export function @@ -843,6 +861,9 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target, auto pipelineConfig = DictionaryAttr::get(context, pipelineAttrs); + // Set attention decomposition control config. + op.setDecompositionConfigAttr(decompositionConfigDict); + return setOpConfigAndEntryPointFnTranslation( entryPoint, op, loweringConfig, CodeGenPipeline::LLVMGPUVectorDistribute, workgroupSize, targetSubgroupSize, pipelineConfig); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir index d21faf8867b1..4334e79d6f88 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir @@ -688,7 +688,9 @@ hal.executable private @attention_20x4096x64x4096x64 { affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> ()>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>], - lowering_config = #config} + lowering_config = #config, + decomposition_config = {qk_attrs = {attention_qk_matmul}, + pv_attrs = {attention_pv_matmul}}} ins(%4, %5, %6, %cst : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, f16) outs(%7 : tensor<20x4096x64xf16>) { ^bb0(%score: f32): iree_linalg_ext.yield %score : f32 @@ -753,7 +755,15 @@ hal.executable private @attention_multiple_m_transpose { %6 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [24, 4608, 128], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<24x4608x128xf16> %7 = tensor.empty() : tensor<64x4608x24x128xf16> %8 = tensor.empty() : tensor<24x64x4608x128xf16> - %9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>], lowering_config = #config} ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) { + %9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>], + lowering_config = #config, + decomposition_config = {qk_attrs = {attention_qk_matmul}, + pv_attrs = {attention_pv_matmul}}} + ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) { ^bb0(%score: f32): iree_linalg_ext.yield %score : f32 } -> tensor<24x64x4608x128xf16> @@ -811,7 +821,15 @@ hal.executable private @attention_mfma_32x32x8 { %6 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [24, 4608, 128], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<24x4608x128xf16> %7 = tensor.empty() : tensor<64x4608x24x128xf16> %8 = tensor.empty() : tensor<24x64x4608x128xf16> - %9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>], lowering_config = #config} ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) { + %9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>], + lowering_config = #config, + decomposition_config = {qk_attrs = {attention_qk_matmul}, + pv_attrs = {attention_pv_matmul}}} + ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) { ^bb0(%score: f32): iree_linalg_ext.yield %score : f32 } -> tensor<24x64x4608x128xf16> diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp index 02d1e71e423c..204ae3533c7b 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp @@ -313,6 +313,13 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) { Value oldMax = getMax(); Value oldSum = getSum(); Type elementType = getElementTypeOrSelf(getOutput().getType()); + DictionaryAttr config = getDecompositionConfigAttr(); + + DictionaryAttr qkAttrs, pvAttrs; + if (config) { + qkAttrs = config.getAs(getQKAttrStr()); + pvAttrs = config.getAs(getPVAttrStr()); + } FailureOr maybeOpInfo = AttentionOpDetail::get(getIndexingMapsArray()); @@ -368,10 +375,9 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) { Value s = b.create(loc, sZero, emptyS).getResult(0); s = computeMatmul(b, loc, getQueryMap(), getKeyMap(), sMap, query, key, s); - - // TODO: We shouldn't be relying on such attributes. We need a better - // mechanism to identify attention matmuls. - s.getDefiningOp()->setAttr("attention_qk_matmul", b.getUnitAttr()); + if (qkAttrs) { + s.getDefiningOp()->setDiscardableAttrs(qkAttrs); + } s = applyPostQKMatmulElementwise(b, loc, getRegion(), s); @@ -448,9 +454,9 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) { // newAcc = P @ V + newAcc newAcc = computeMatmul(b, loc, pMap, getValueMap(), accMap, p, value, newAcc); - // TODO: We shouldn't be relying on such attributes. We need a better - // mechanism to identify attention matmuls. - newAcc.getDefiningOp()->setAttr("attention_pv_matmul", b.getUnitAttr()); + if (pvAttrs) { + newAcc.getDefiningOp()->setDiscardableAttrs(pvAttrs); + } return SmallVector{newAcc, newMax, newSum}; } diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp index 6abaec41f91a..77a2d518acb2 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp @@ -1213,7 +1213,7 @@ void AttentionOp::build(OpBuilder &odsBuilder, OperationState &odsState, std::optional mask) { Value maskIn = mask.value_or(Value()); build(odsBuilder, odsState, results, query, key, value, scale, maskIn, output, - indexingMaps); + indexingMaps, DictionaryAttr()); } LogicalResult AttentionOp::verify() { @@ -1388,7 +1388,7 @@ void OnlineAttentionOp::build(OpBuilder &odsBuilder, OperationState &odsState, std::optional mask) { Value maskIn = mask.value_or(Value()); build(odsBuilder, odsState, results, query, key, value, maskIn, scale, output, - max, sum, indexingMaps); + max, sum, indexingMaps, DictionaryAttr()); } LogicalResult OnlineAttentionOp::verify() { diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td index e097ce5a9089..329c79ca5297 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td @@ -501,7 +501,8 @@ def IREELinalgExt_AttentionOp : IREELinalgExt_PureOp<"attention", AnyFloat:$scale, Optional:$mask, AnyShaped:$output, - AffineMapArrayAttr:$indexing_maps + AffineMapArrayAttr:$indexing_maps, + OptionalAttr:$decomposition_config ); let regions = (region SizedRegion<1>:$region); @@ -558,6 +559,12 @@ def IREELinalgExt_AttentionOp : IREELinalgExt_PureOp<"attention", int64_t getIterationDomainRank() { return getQueryMap().getNumDims(); } + + /* Decomposition control attributes */ + + // Attributes to set on QK and PV matmul after decomposition. + static StringRef getQKAttrStr() { return "qk_attrs"; } + static StringRef getPVAttrStr() { return "pv_attrs"; } }]; } @@ -612,7 +619,8 @@ def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_PureOp<"online_attention", AnyShaped:$output, AnyShaped:$max, AnyShaped:$sum, - AffineMapArrayAttr:$indexing_maps + AffineMapArrayAttr:$indexing_maps, + OptionalAttr:$decomposition_config ); let regions = (region SizedRegion<1>:$region); @@ -679,6 +687,12 @@ def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_PureOp<"online_attention", int64_t getIterationDomainRank() { return getQueryMap().getNumDims(); } + + /* Decomposition control attributes */ + + // Attributes to set on QK and PV matmul after decomposition. + static StringRef getQKAttrStr() { return "qk_attrs"; } + static StringRef getPVAttrStr() { return "pv_attrs"; } }]; } //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp index 0aa3a37aa5fe..d9a48736fdd4 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp @@ -106,7 +106,8 @@ void convertToOnlineAttention(IREE::LinalgExt::AttentionOp attnOp, loc, TypeRange{accFill.getType(), maxFill.getType(), sumFill.getType()}, attnOp.getQuery(), attnOp.getKey(), attnOp.getValue(), attnOp.getScale(), mask, accFill, maxFill, sumFill, - rewriter.getAffineMapArrayAttr(indexingMaps)); + rewriter.getAffineMapArrayAttr(indexingMaps), + attnOp.getDecompositionConfigAttr()); rewriter.cloneRegionBefore(attnOp.getRegion(), onlineAttn.getRegion(), onlineAttn.getRegion().begin());