[Flow] Teach RaiseSpecialOps to raise tensor.extract and views (#14718)
This patch teaches the RaiseSpecialOps pass to raise tensor.extract into an
input argument of a linalg.generic. It also teaches it to raise such
linalg.generic ops to view-like operations (tensor.extract_slice).

Fixes #14742
Groverkss authored Aug 24, 2023
1 parent fdc873c commit ffde368
Showing 2 changed files with 301 additions and 2 deletions.
@@ -76,12 +76,261 @@ std::optional<Value> matchATransposeBMatmul(linalg::LinalgOp matmulOp) {
return std::nullopt;
}

/// Matches a linalg.generic operation reading data from a tensor `source` using
/// tensor.extract, and raises the `source` tensor to an input of the linalg
/// operation.
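///
/// As a rough, illustrative sketch (value names are made up), a generic such
/// as:
///
///   %c0 = arith.constant 0 : index
///   %0 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>],
///                        iterator_types = ["parallel"]}
///       outs(%init : tensor<5120xf32>) {
///   ^bb0(%out: f32):
///     %i = linalg.index 0 : index
///     %e = tensor.extract %source[%c0, %c0, %i] : tensor<1x1x5120xf32>
///     linalg.yield %e : f32
///   } -> tensor<5120xf32>
///
/// is rewritten (after dead-code cleanup) so that `%source` is read directly
/// as an input of the generic:
///
///   %0 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, 0, d0)>,
///                                         affine_map<(d0) -> (d0)>],
///                        iterator_types = ["parallel"]}
///       ins(%source : tensor<1x1x5120xf32>) outs(%init : tensor<5120xf32>) {
///   ^bb0(%in: f32, %out: f32):
///     linalg.yield %in : f32
///   } -> tensor<5120xf32>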
static FailureOr<linalg::GenericOp>
raiseTensorExtractToInput(linalg::GenericOp linalgOp, RewriterBase &rewriter) {
if (!linalgOp.hasTensorSemantics()) {
return failure();
}
if (!isElementwise(linalgOp)) {
return failure();
}
if (!llvm::hasSingleElement(linalgOp.getResults())) {
return failure();
}

// Find a tensor.extract op in the linalgOp body.
auto extractOps = linalgOp.getBody()->getOps<tensor::ExtractOp>();
if (!llvm::hasSingleElement(extractOps)) {
return failure();
}
tensor::ExtractOp extractOp = *extractOps.begin();
auto resultType = dyn_cast<TensorType>(linalgOp.getResult(0).getType());
if (!resultType) {
return failure();
}

ArrayRef<int64_t> sourceShape = extractOp.getTensor().getType().getShape();
ArrayRef<int64_t> resultShape = resultType.getShape();

// Raise the tensor.extract op to an input.
SmallVector<AffineExpr> exprs;
for (auto [idx, indexValue] : llvm::enumerate(extractOp.getIndices())) {
// For raising, the indexing value must be one of the following:
// 1. A constant value.
// 2. A linalg.index.

// 1. Indexing value is a constant.
APInt constantIndex;
if (matchPattern(indexValue, m_ConstantInt(&constantIndex))) {
// Restrict to cases where the constant is 0, because handling constants
// other than 0 in the indexing map may cause problems later in the
// lowering pipeline.
if (constantIndex.getLimitedValue() != 0)
return failure();
exprs.push_back(getAffineConstantExpr(0, rewriter.getContext()));
continue;
}
// 2. The indexing value is a linalg.index.
if (auto indexOp = indexValue.getDefiningOp<linalg::IndexOp>()) {
// Make sure that for this index the input and output sizes match and
// are not dynamic. We need this to keep the op elementwise.
// TODO: This restriction can be relaxed by adding an extract_slice op
// on the `source` tensor. This is not the same as raising the whole
// operation to an extract_slice, as there can be permutations and
// projections involved.
if (sourceShape[idx] == ShapedType::kDynamic ||
resultShape[indexOp.getDim()] == ShapedType::kDynamic ||
sourceShape[idx] != resultShape[indexOp.getDim()]) {
return failure();
}
exprs.push_back(
getAffineDimExpr(indexOp.getDim(), rewriter.getContext()));
continue;
}
return failure();
}
AffineMap indexingMap = AffineMap::get(
/*dimCount=*/linalgOp.getNumLoops(),
/*symbolCount=*/0, exprs, rewriter.getContext());

// Replace the linalgOp with a new linalgOp where the source tensor is
// an input with the indexing map.
SmallVector<Value> newInputs = linalgOp.getInputs();
newInputs.insert(newInputs.begin(), extractOp.getTensor());
SmallVector<Attribute> newIndexingMaps;
newIndexingMaps.push_back(AffineMapAttr::get(indexingMap));
for (AffineMap map : linalgOp.getIndexingMapsArray()) {
newIndexingMaps.push_back(AffineMapAttr::get(map));
}

auto bodyBuilder = [&](OpBuilder &builder, Location loc, ValueRange args) {
// Create an IR mapping from old block arguments to new ones.
IRMapping mapper;
ArrayRef<BlockArgument> oldArgs = linalgOp.getBody()->getArguments();
// Map i^th old argument to (i + 1)^th new argument.
for (unsigned i = 0; i < oldArgs.size(); ++i) {
mapper.map(oldArgs[i], args[i + 1]);
}
// Clone the body of the linalgOp.
for (Operation &op : linalgOp.getBody()->getOperations()) {
// Replace the extractOp with the first block argument.
if (&op == extractOp) {
mapper.map(op.getResult(0), args[0]);
} else {
builder.clone(op, mapper);
}
}
};

linalg::GenericOp newLinalgOp = rewriter.create<linalg::GenericOp>(
linalgOp.getLoc(), linalgOp.getResultTypes(), newInputs,
linalgOp.getOutputs(),
ArrayAttr::get(linalgOp->getContext(), newIndexingMaps),
linalgOp.getIteratorTypesAttr(), linalgOp.getDocAttr(),
linalgOp.getLibraryCallAttr(), bodyBuilder);

return newLinalgOp;
}

/// Given a linalg.generic operation, and input/output tensors with their
/// indexing maps, tries to raise the operation to a tensor.extract_slice
/// operation. The tensor.extract_slice produced can be rank reducing.
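///
/// As an illustrative sketch (shapes made up): with an input indexing map
/// (d0, d1) -> (d0, 0, d1), an identity output map, and an output type of
/// tensor<10x20xf32>, the raised op would be a rank-reducing slice roughly
/// of the form:
///
///   tensor.extract_slice %input[0, 0, 0] [10, 1, 20] [1, 1, 1]
///       : tensor<10x3x20xf32> to tensor<10x20xf32>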
static FailureOr<tensor::ExtractSliceOp>
tryRaiseToExtractSlice(AffineMap inputIndexingMap, AffineMap outputIndexingMap,
Value input, Value output, linalg::GenericOp linalgOp,
RewriterBase &rewriter) {
// The output rank must be strictly smaller than the input rank.
if (outputIndexingMap.getNumResults() >= inputIndexingMap.getNumResults()) {
return failure();
}
// Output map should be identity.
if (!outputIndexingMap.isIdentity()) {
return failure();
}

auto outType = dyn_cast<RankedTensorType>(output.getType());
if (!outType) {
return failure();
}
ArrayRef<int64_t> outShape = outType.getShape();

// Try to match each output dimension to an input dimension, in order.
// If we find a constant access, we assume that dimension is supposed to be
// rank reduced.
// TODO: Support cases where the constant access matches the output dimension.
SmallVector<OpFoldResult> offsets;
SmallVector<OpFoldResult> sizes;
IntegerAttr zero = rewriter.getI64IntegerAttr(0);
IntegerAttr one = rewriter.getI64IntegerAttr(1);
unsigned currOutDim = 0;
for (auto [idx, expr] : llvm::enumerate(inputIndexingMap.getResults())) {
// Check if the input dimension matches the current output dimension.
if (expr == outputIndexingMap.getResult(currOutDim)) {
offsets.push_back(zero);
// Get the dim size from the output tensor.
if (outShape[currOutDim] == ShapedType::kDynamic) {
auto dim = rewriter.create<tensor::DimOp>(linalgOp.getLoc(), output,
currOutDim);
sizes.push_back(dim.getResult());
} else {
sizes.push_back(rewriter.getI64IntegerAttr(outShape[currOutDim]));
}
++currOutDim;
continue;
}
// Assume that the constant access is a rank reducing access.
if (expr.isa<AffineConstantExpr>()) {
IntegerAttr constIdx = rewriter.getI64IntegerAttr(
expr.cast<AffineConstantExpr>().getValue());
offsets.push_back(constIdx);
sizes.push_back(one);
continue;
}
// Unknown access, fail.
return failure();
}

// Not all output dimensions were matched to an input dimension.
if (currOutDim != outputIndexingMap.getNumResults()) {
return failure();
}

// We only support dim expr or a constant expr on the input map, so strides
// will always be 1.
SmallVector<OpFoldResult> strides(inputIndexingMap.getNumResults(), one);

return rewriter.create<tensor::ExtractSliceOp>(
linalgOp.getLoc(), outType, input, offsets, sizes, strides);
}

/// Matches a linalg.generic operation with a single input and init output
/// tensor, and tries to raise it to a view-like operation on the input tensor.
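/// The body of such an op must simply forward the input element, i.e. it
/// must look roughly like (sketch):
///
///   ^bb0(%in: f32, %out: f32):
///     linalg.yield %in : f32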
static FailureOr<Operation *> tryRaiseToView(linalg::GenericOp linalgOp,
RewriterBase &rewriter) {
if (!linalgOp.hasTensorSemantics()) {
return failure();
}

// Only handle ops with exactly one input and one init tensor.
if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1) {
return failure();
}
OpOperand *inputOperand = linalgOp.getDpsInputOperand(0);
OpOperand *outputOperand = linalgOp.getDpsInitOperand(0);

// Check if linalg.yield yields a block argument.
auto yieldOp = dyn_cast<linalg::YieldOp>(linalgOp.getBody()->getTerminator());
if (!yieldOp) {
return failure();
}
auto blockArg = dyn_cast<BlockArgument>(yieldOp.getOperand(0));
if (!blockArg) {
return failure();
}
// Check if the block argument is an argument of the linalgOp.
if (blockArg.getOwner() != linalgOp.getBody()) {
return failure();
}
// Check that the block argument corresponds to the input.
if (blockArg.getArgNumber() != 0) {
return failure();
}

Value input = inputOperand->get();
Value output = outputOperand->get();
AffineMap inputIndexingMap = linalgOp.getMatchingIndexingMap(inputOperand);
AffineMap outputIndexingMap = linalgOp.getMatchingIndexingMap(outputOperand);

// Try raising to a rank-reducing tensor.extract_slice.
return tryRaiseToExtractSlice(inputIndexingMap, outputIndexingMap, input,
output, linalgOp, rewriter);
}

struct RaiseSpecialOpsPass : public RaiseSpecialOpsBase<RaiseSpecialOpsPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<IREE::LinalgExt::IREELinalgExtDialect>();
}

void runOnOperation() override {
IRRewriter rewriter(&getContext());

getOperation()->walk([&](linalg::GenericOp op) {
linalg::GenericOp linalgOp = op;

OpBuilder::InsertionGuard guard(rewriter);

// Try raising tensor.extract to an input, creating an intermediate
// linalg.generic.
rewriter.setInsertionPoint(op);
FailureOr<linalg::GenericOp> maybeNewOp =
raiseTensorExtractToInput(linalgOp, rewriter);
if (succeeded(maybeNewOp)) {
linalgOp = *maybeNewOp;
}

// Try raising to a view-like operation, and replace the original op only
// if that raising was successful.
rewriter.setInsertionPoint(op);
FailureOr<Operation *> maybeRaisedView =
tryRaiseToView(linalgOp, rewriter);
if (succeeded(maybeRaisedView)) {
rewriter.replaceOp(op, *maybeRaisedView);
}
});

SmallVector<std::pair<linalg::LinalgOp, Value>> softmaxRoots;
SmallVector<std::pair<linalg::MatmulOp, Value>> transposeMatmulRoots;
getOperation()->walk([&](linalg::LinalgOp op) {
@@ -100,14 +349,15 @@ struct RaiseSpecialOpsPass : public RaiseSpecialOpsBase<RaiseSpecialOpsPass> {
}
}
});
-IRRewriter rewriter(&getContext());

for (std::pair<linalg::LinalgOp, Value> softmax : softmaxRoots) {
linalg::LinalgOp op = softmax.first;
Value src = softmax.second;
rewriter.setInsertionPoint(softmax.first);
rewriter.replaceOpWithNewOp<IREE::LinalgExt::SoftmaxOp>(
op, src, op.getDpsInitOperand(0)->get(), op.getNumLoops() - 1);
}

for (std::pair<linalg::MatmulOp, Value> aTransposeBMatmul :
transposeMatmulRoots) {
auto matmulOp = aTransposeBMatmul.first;
@@ -1,4 +1,4 @@
-// RUN: iree-opt --iree-flow-raise-special-ops -canonicalize %s | FileCheck %s
+// RUN: iree-opt --iree-flow-raise-special-ops -canonicalize --split-input-file %s | FileCheck %s

// CHECK-LABEL: @softmax
// CHECK-SAME: %[[ARG:.+]]: tensor<?x?x?xf32>
@@ -186,3 +186,52 @@ func.func @aTransposeBMatmul(%arg0 : tensor<10x20xf32>,
// CHECK: %[[RESULT:.+]] = linalg.matmul_transpose_b
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] :
// CHECK: return %[[RESULT]]

// -----

#map = affine_map<(d0) -> (d0)>
func.func @test(%A : tensor<1x1x5120xf32>, %B : tensor<5120xf32>) -> tensor<5120xf32> {
%c0 = arith.constant 0 : index
// CHECK: tensor.extract_slice
%0 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs(%B : tensor<5120xf32>) {
^bb0(%out: f32):
%12 = linalg.index 0 : index
%extracted = tensor.extract %A[%c0, %c0, %12] : tensor<1x1x5120xf32>
linalg.yield %extracted : f32
} -> tensor<5120xf32>
return %0 : tensor<5120xf32>
}
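// With this patch, %A is first raised to an input of the generic (with
// indexing map (d0) -> (0, 0, d0)), and the resulting copy-like generic is
// then raised to a view; the end result is expected to be roughly:
//   tensor.extract_slice %A[0, 0, 0] [1, 1, 5120] [1, 1, 1]
//       : tensor<1x1x5120xf32> to tensor<5120xf32>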

// -----

// This should currently not be raised, as the operation would not remain
// elementwise after raising the tensor.extract to an input.
#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @test(%A : tensor<128x128x128xf32>, %B : tensor<64x64xf32>) -> tensor<64x64xf32> {
%c0 = arith.constant 0 : index
// CHECK: linalg.generic
%0 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%B : tensor<64x64xf32>) {
^bb0(%out: f32):
%i1 = linalg.index 0 : index
%i2 = linalg.index 1 : index
%extracted = tensor.extract %A[%i1, %c0, %i2] : tensor<128x128x128xf32>
linalg.yield %extracted : f32
} -> tensor<64x64xf32>
return %0 : tensor<64x64xf32>
}

// -----

#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @test(%A : tensor<64x64x64xf32>, %B : tensor<64x64xf32>) -> tensor<64x64xf32> {
%c0 = arith.constant 0 : index
%0 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%B : tensor<64x64xf32>) {
^bb0(%out: f32):
%i1 = linalg.index 0 : index
%i2 = linalg.index 1 : index
// CHECK: tensor.extract_slice
%extracted = tensor.extract %A[%i1, %c0, %i2] : tensor<64x64x64xf32>
linalg.yield %extracted : f32
} -> tensor<64x64xf32>
return %0 : tensor<64x64xf32>
}
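// As above, %A is raised to an input with indexing map (d0, d1) -> (d0, 0, d1)
// and the generic is then expected to fold to roughly:
//   tensor.extract_slice %A[0, 0, 0] [64, 1, 64] [1, 1, 1]
//       : tensor<64x64x64xf32> to tensor<64x64xf32>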
