Skip to content

Commit

Permalink
[LinalgExt] Remove default implementation for getStaticLoopRanges (#1…
Browse files Browse the repository at this point in the history
…8745)

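The default implementation of getStaticLoopRanges is dangerous and
causes unexpected bugs. It concatenates the shapes of all shaped operands,
which is only correct when every operand dimension corresponds to a
distinct loop. It is better to have each operation specify its loop ranges
explicitly; the default implementation now returns failure().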
  • Loading branch information
Groverkss authored Oct 14, 2024
1 parent d7378bb commit 1e6bbb8
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 22 deletions.
7 changes: 6 additions & 1 deletion compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,12 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target,

// Get iteration domain bounds.
OpBuilder b(op);
SmallVector<int64_t, 4> bounds = op.getStaticLoopRanges();
FailureOr<SmallVector<int64_t>> maybeBounds = op.getStaticLoopRanges();
if (failed(maybeBounds)) {
return failure();
}

ArrayRef<int64_t> bounds = maybeBounds.value();

auto opInfo =
IREE::LinalgExt::AttentionOpDetail::get(op.getIndexingMapsArray())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,10 @@ struct LinalgFusionOpInterfaceAdapter
return (llvm::cast<ConcreteType>(op).getNumLoops());
}

SmallVector<int64_t, 4> getStaticLoopRanges(mlir::Operation *op) const {
return (llvm::cast<ConcreteType>(op).getStaticLoopRanges());
FailureOr<SmallVector<int64_t>>
getStaticLoopRanges(mlir::Operation *op) const {
return SmallVector<int64_t>(
llvm::cast<ConcreteType>(op).getStaticLoopRanges());
}

AffineMap getIndexingMapMatchingResult(mlir::Operation *op,
Expand Down Expand Up @@ -128,6 +130,12 @@ struct SoftmaxFusionOpInterfaceAdapter
}));
}

/// Returns the static loop ranges of the softmax op wrapped by `op`. They
/// are read directly from the shape of the op's input operand type.
FailureOr<SmallVector<int64_t>> getStaticLoopRanges(Operation *op) const {
  auto inputType = cast<linalg::SoftmaxOp>(op).getInputOperandType();
  auto inputShape = inputType.getShape();
  return SmallVector<int64_t>(inputShape);
}

AffineMap getIndexingMapMatchingResult(mlir::Operation *op,
OpResult result) const {
return getIndexingMapsForResults(op)[result.getResultNumber()];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,18 +110,12 @@ def LinalgFusionInterface : OpInterface<"LinalgFusionOpInterface"> {
/*desc=*/[{
Return the static loop ranges.
}],
/*retTy=*/"SmallVector<int64_t, 4>",
/*retTy=*/"FailureOr<SmallVector<int64_t>>",
/*methodName=*/"getStaticLoopRanges",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
SmallVector<int64_t, 4> loopRanges;
llvm::for_each($_op.getOperands(), [&](Value operand) {
if (auto shapedType = dyn_cast<ShapedType>(operand.getType())) {
llvm::append_range(loopRanges, shapedType.getShape());
}
});
return loopRanges;
return failure();
}]
>,
InterfaceMethod<
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,11 @@ ScatterOp::reifyResultShapes(OpBuilder &b,
.reifyResultShapes(b, reifiedReturnShapes);
}

/// Returns the static loop ranges of the scatter op. They are given by the
/// shape of the update operand type.
FailureOr<SmallVector<int64_t>> ScatterOp::getStaticLoopRanges() {
  auto updateShape = getUpdateType().getShape();
  return SmallVector<int64_t>(updateShape);
}

SmallVector<AffineMap> ScatterOp::getIndexingMapsForOperands() {
Builder builder(getContext());
return {builder.getMultiDimIdentityMap(getUpdateType().getRank()),
Expand Down Expand Up @@ -1321,8 +1326,8 @@ SmallVector<AffineMap> AttentionOp::getIndexingMapsArray() {
getIndexingMaps().getAsValueRange<AffineMapAttr>());
}

SmallVector<int64_t, 4> AttentionOp::getStaticLoopRanges() {
SmallVector<int64_t, 4> bounds(getIterationDomainRank());
FailureOr<SmallVector<int64_t>> AttentionOp::getStaticLoopRanges() {
SmallVector<int64_t> bounds(getIterationDomainRank());
SmallVector<bool> dimsFound(getIterationDomainRank(), false);

// batch(s), m, k1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ let opDocGroup = OpGroupNonStructuredOps in {

def IREELinalgExt_ScatterOp : IREELinalgExt_Op<"scatter",
[DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
DeclareOpInterfaceMethods<LinalgFusionInterface>,
DeclareOpInterfaceMethods<LinalgFusionInterface,
["getIndexingMapsForResults", "getIndexingMapsForOperands",
"getStaticLoopRanges"]>,
DeclareOpInterfaceMethods<TilingInterface,
["generateScalarImplementation",
"getIterationDomain",
Expand Down Expand Up @@ -469,7 +471,8 @@ def IREELinalgExt_AttentionOp : IREELinalgExt_PureOp<"attention",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DestinationStyleOpInterface, LinalgExtInterface,
DeclareOpInterfaceMethods<LinalgFusionInterface,
["getIndexingMapsForResults", "getIndexingMapsForOperands"]>,
["getIndexingMapsForResults", "getIndexingMapsForOperands",
"getStaticLoopRanges"]>,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
DeclareOpInterfaceMethods<TilingInterface,
["getIterationDomain",
Expand Down Expand Up @@ -528,8 +531,6 @@ def IREELinalgExt_AttentionOp : IREELinalgExt_PureOp<"attention",

SmallVector<AffineMap> getIndexingMapsArray();

SmallVector<int64_t, 4> getStaticLoopRanges();

AffineMap getQueryMap() {
return cast<AffineMap>(getIndexingMapsArray()[0]);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,12 @@ LogicalResult ExpansionInfo::compute(OpTy op, OpOperand *fusableOpOperand,
if (reassociationMaps.empty())
return failure();
AffineMap fusedIndexMap = op.getMatchingIndexingMap(fusableOpOperand);
SmallVector<int64_t, 4> originalLoopRange = op.getStaticLoopRanges();
originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end());
FailureOr<SmallVector<int64_t>> originalLoopRange = op.getStaticLoopRanges();
if (failed(originalLoopRange)) {
return failure();
}
originalLoopExtent.assign(originalLoopRange->begin(),
originalLoopRange->end());

reassociation.clear();
expandedShapeMap.clear();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -568,9 +568,15 @@ isFusableWithConsumer(OpOperand &fusedOperand,
// TODO(#12664): This is unnecessary requirement, but we need a better config
// to tile the consumer with a larger iteration space.
if (!options.aggressiveFusion) {
auto producerIterationSpace = producerFusionOp.getStaticLoopRanges();
auto consumerIterationSpace = consumerFusionOp.getStaticLoopRanges();
if (producerIterationSpace.size() < consumerIterationSpace.size()) {
FailureOr<SmallVector<int64_t>> producerIterationSpace =
producerFusionOp.getStaticLoopRanges();
FailureOr<SmallVector<int64_t>> consumerIterationSpace =
consumerFusionOp.getStaticLoopRanges();
if (failed(producerIterationSpace) || failed(consumerIterationSpace)) {
return false;
}
if (producerIterationSpace.value().size() <
consumerIterationSpace.value().size()) {
return false;
}
}
Expand Down

0 comments on commit 1e6bbb8

Please sign in to comment.