Skip to content

Commit

Permalink
[LinalgExt] Generalize attribute setting for attention decomposition (#18780)
Browse files Browse the repository at this point in the history

This PR teaches attention decomposition to set attributes on the attention
matmuls by passing attribute dictionaries to the
iree_linalg_ext.online_attention operation. This allows us to further
control codegen of the matmuls (generally the root operations) after
decomposition (for example, setting a lowering config on the decomposed
matmuls).
  • Loading branch information
Groverkss authored Oct 28, 2024
1 parent a041798 commit e66171a
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 16 deletions.
23 changes: 22 additions & 1 deletion compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -825,7 +825,25 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target,
attrs.emplace_back(StringAttr::get(context, "reduction"),
b.getI64ArrayAttr(reductionTileSizes));

auto configDict = DictionaryAttr::get(context, attrs);
SmallVector<NamedAttribute, 2> qkAttrs;
SmallVector<NamedAttribute, 2> pvAttrs;

qkAttrs.emplace_back(b.getNamedAttr("attention_qk_matmul", b.getUnitAttr()));
pvAttrs.emplace_back(b.getNamedAttr("attention_pv_matmul", b.getUnitAttr()));

auto qkAttrDict = b.getDictionaryAttr(qkAttrs);
auto pvAttrDict = b.getDictionaryAttr(pvAttrs);

SmallVector<NamedAttribute, 2> decompositionConfig;
decompositionConfig.emplace_back(
b.getNamedAttr(IREE::LinalgExt::AttentionOp::getQKAttrStr(), qkAttrDict));
decompositionConfig.emplace_back(
b.getNamedAttr(IREE::LinalgExt::AttentionOp::getPVAttrStr(), pvAttrDict));

DictionaryAttr decompositionConfigDict =
b.getDictionaryAttr(decompositionConfig);

auto configDict = b.getDictionaryAttr(attrs);
auto loweringConfig = IREE::GPU::LoweringConfigAttr::get(context, configDict);

// Attach the MMA schedule as an attribute to the entry point export function
Expand All @@ -843,6 +861,9 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target,

auto pipelineConfig = DictionaryAttr::get(context, pipelineAttrs);

// Set attention decomposition control config.
op.setDecompositionConfigAttr(decompositionConfigDict);

return setOpConfigAndEntryPointFnTranslation(
entryPoint, op, loweringConfig, CodeGenPipeline::LLVMGPUVectorDistribute,
workgroupSize, targetSubgroupSize, pipelineConfig);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,9 @@ hal.executable private @attention_20x4096x64x4096x64 {
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
affine_map<(d0, d1, d2, d3, d4) -> ()>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>],
lowering_config = #config}
lowering_config = #config,
decomposition_config = {qk_attrs = {attention_qk_matmul},
pv_attrs = {attention_pv_matmul}}}
ins(%4, %5, %6, %cst : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, f16) outs(%7 : tensor<20x4096x64xf16>) {
^bb0(%score: f32):
iree_linalg_ext.yield %score : f32
Expand Down Expand Up @@ -753,7 +755,15 @@ hal.executable private @attention_multiple_m_transpose {
%6 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [24, 4608, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<24x4608x128xf16>> -> tensor<24x4608x128xf16>
%7 = tensor.empty() : tensor<64x4608x24x128xf16>
%8 = tensor.empty() : tensor<24x64x4608x128xf16>
%9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>], lowering_config = #config} ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) {
%9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>,
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>,
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>,
affine_map<(d0, d1, d2, d3, d4, d5) -> ()>,
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>],
lowering_config = #config,
decomposition_config = {qk_attrs = {attention_qk_matmul},
pv_attrs = {attention_pv_matmul}}}
ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) {
^bb0(%score: f32):
iree_linalg_ext.yield %score : f32
} -> tensor<24x64x4608x128xf16>
Expand Down Expand Up @@ -811,7 +821,15 @@ hal.executable private @attention_mfma_32x32x8 {
%6 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [24, 4608, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<24x4608x128xf16>> -> tensor<24x4608x128xf16>
%7 = tensor.empty() : tensor<64x4608x24x128xf16>
%8 = tensor.empty() : tensor<24x64x4608x128xf16>
%9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>], lowering_config = #config} ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) {
%9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>,
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>,
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>,
affine_map<(d0, d1, d2, d3, d4, d5) -> ()>,
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>],
lowering_config = #config,
decomposition_config = {qk_attrs = {attention_qk_matmul},
pv_attrs = {attention_pv_matmul}}}
ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) {
^bb0(%score: f32):
iree_linalg_ext.yield %score : f32
} -> tensor<24x64x4608x128xf16>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,13 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
Value oldMax = getMax();
Value oldSum = getSum();
Type elementType = getElementTypeOrSelf(getOutput().getType());
DictionaryAttr config = getDecompositionConfigAttr();

DictionaryAttr qkAttrs, pvAttrs;
if (config) {
qkAttrs = config.getAs<DictionaryAttr>(getQKAttrStr());
pvAttrs = config.getAs<DictionaryAttr>(getPVAttrStr());
}

FailureOr<AttentionOpDetail> maybeOpInfo =
AttentionOpDetail::get(getIndexingMapsArray());
Expand Down Expand Up @@ -368,10 +375,9 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
Value s = b.create<linalg::FillOp>(loc, sZero, emptyS).getResult(0);

s = computeMatmul(b, loc, getQueryMap(), getKeyMap(), sMap, query, key, s);

// TODO: We shouldn't be relying on such attributes. We need a better
// mechanism to identify attention matmuls.
s.getDefiningOp()->setAttr("attention_qk_matmul", b.getUnitAttr());
if (qkAttrs) {
s.getDefiningOp()->setDiscardableAttrs(qkAttrs);
}

s = applyPostQKMatmulElementwise(b, loc, getRegion(), s);

Expand Down Expand Up @@ -448,9 +454,9 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {

// newAcc = P @ V + newAcc
newAcc = computeMatmul(b, loc, pMap, getValueMap(), accMap, p, value, newAcc);
// TODO: We shouldn't be relying on such attributes. We need a better
// mechanism to identify attention matmuls.
newAcc.getDefiningOp()->setAttr("attention_pv_matmul", b.getUnitAttr());
if (pvAttrs) {
newAcc.getDefiningOp()->setDiscardableAttrs(pvAttrs);
}

return SmallVector<Value>{newAcc, newMax, newSum};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1213,7 +1213,7 @@ void AttentionOp::build(OpBuilder &odsBuilder, OperationState &odsState,
std::optional<Value> mask) {
Value maskIn = mask.value_or(Value());
build(odsBuilder, odsState, results, query, key, value, scale, maskIn, output,
indexingMaps);
indexingMaps, DictionaryAttr());
}

LogicalResult AttentionOp::verify() {
Expand Down Expand Up @@ -1388,7 +1388,7 @@ void OnlineAttentionOp::build(OpBuilder &odsBuilder, OperationState &odsState,
std::optional<Value> mask) {
Value maskIn = mask.value_or(Value());
build(odsBuilder, odsState, results, query, key, value, maskIn, scale, output,
max, sum, indexingMaps);
max, sum, indexingMaps, DictionaryAttr());
}

LogicalResult OnlineAttentionOp::verify() {
Expand Down
18 changes: 16 additions & 2 deletions compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,8 @@ def IREELinalgExt_AttentionOp : IREELinalgExt_PureOp<"attention",
AnyFloat:$scale,
Optional<AnyShaped>:$mask,
AnyShaped:$output,
AffineMapArrayAttr:$indexing_maps
AffineMapArrayAttr:$indexing_maps,
OptionalAttr<DictionaryAttr>:$decomposition_config
);
let regions = (region SizedRegion<1>:$region);

Expand Down Expand Up @@ -558,6 +559,12 @@ def IREELinalgExt_AttentionOp : IREELinalgExt_PureOp<"attention",
int64_t getIterationDomainRank() {
return getQueryMap().getNumDims();
}

/* Decomposition control attributes */

// Attributes to set on QK and PV matmul after decomposition.
static StringRef getQKAttrStr() { return "qk_attrs"; }
static StringRef getPVAttrStr() { return "pv_attrs"; }
}];
}

Expand Down Expand Up @@ -612,7 +619,8 @@ def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_PureOp<"online_attention",
AnyShaped:$output,
AnyShaped:$max,
AnyShaped:$sum,
AffineMapArrayAttr:$indexing_maps
AffineMapArrayAttr:$indexing_maps,
OptionalAttr<DictionaryAttr>:$decomposition_config
);
let regions = (region SizedRegion<1>:$region);

Expand Down Expand Up @@ -679,6 +687,12 @@ def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_PureOp<"online_attention",
int64_t getIterationDomainRank() {
return getQueryMap().getNumDims();
}

/* Decomposition control attributes */

// Attributes to set on QK and PV matmul after decomposition.
static StringRef getQKAttrStr() { return "qk_attrs"; }
static StringRef getPVAttrStr() { return "pv_attrs"; }
}];
}
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ void convertToOnlineAttention(IREE::LinalgExt::AttentionOp attnOp,
loc, TypeRange{accFill.getType(), maxFill.getType(), sumFill.getType()},
attnOp.getQuery(), attnOp.getKey(), attnOp.getValue(), attnOp.getScale(),
mask, accFill, maxFill, sumFill,
rewriter.getAffineMapArrayAttr(indexingMaps));
rewriter.getAffineMapArrayAttr(indexingMaps),
attnOp.getDecompositionConfigAttr());

rewriter.cloneRegionBefore(attnOp.getRegion(), onlineAttn.getRegion(),
onlineAttn.getRegion().begin());
Expand Down

0 comments on commit e66171a

Please sign in to comment.