Skip to content

Commit

Permalink
Do not keep intermediates
Browse files Browse the repository at this point in the history
  • Loading branch information
Groverkss committed Aug 22, 2023
1 parent 8813bbd commit 0e075cf
Showing 1 changed file with 20 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -157,18 +157,16 @@ raiseTensorExtractToInput(linalg::GenericOp linalgOp, RewriterBase &rewriter) {
linalgOp.getIteratorTypesAttr(), linalgOp.getDocAttr(),
linalgOp.getLibraryCallAttr(), bodyBuilder);

rewriter.replaceOp(linalgOp, newLinalgOp.getResults());
return newLinalgOp;
}

/// Given a linalg.generic operation, and input/output tensors with their
/// indexing maps, tries to raise the operation to a tensor.extract_slice
/// operation. The tensor.extract_slice produced can be rank reducing.
static LogicalResult tryRaiseToExtractSlice(AffineMap inputIndexingMap,
AffineMap outputIndexingMap,
Value input, Value output,
linalg::GenericOp linalgOp,
RewriterBase &rewriter) {
static FailureOr<tensor::ExtractSliceOp>
tryRaiseToExtractSlice(AffineMap inputIndexingMap, AffineMap outputIndexingMap,
Value input, Value output, linalg::GenericOp linalgOp,
RewriterBase &rewriter) {
// Output shape must be smaller than input shape.
if (outputIndexingMap.getNumResults() >= inputIndexingMap.getNumResults()) {
return failure();
Expand Down Expand Up @@ -229,15 +227,14 @@ static LogicalResult tryRaiseToExtractSlice(AffineMap inputIndexingMap,
// will always be 1.
SmallVector<OpFoldResult> strides(inputIndexingMap.getNumResults(), one);

rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(linalgOp, outType, input,
offsets, sizes, strides);
return success();
return rewriter.create<tensor::ExtractSliceOp>(
linalgOp.getLoc(), outType, input, offsets, sizes, strides);
}

/// Matches a linalg.generic operation with a single input and init output
/// tensor, and tries to raise it to a view-like operation on the input tensor.
static LogicalResult tryRaiseToView(linalg::GenericOp linalgOp,
RewriterBase &rewriter) {
static FailureOr<Operation *> tryRaiseToView(linalg::GenericOp linalgOp,
RewriterBase &rewriter) {
if (!linalgOp.hasTensorSemantics()) {
return failure();
}
Expand Down Expand Up @@ -286,17 +283,24 @@ struct RaiseSpecialOpsPass : public RaiseSpecialOpsBase<RaiseSpecialOpsPass> {
IRRewriter rewriter(&getContext());

getOperation()->walk([&](linalg::GenericOp op) {
// Try raising to tensor.export.
linalg::GenericOp linalgOp = op;

// Try raising to tensor.export and create an intermediate linalg.generic.
rewriter.setInsertionPoint(op);
FailureOr<linalg::GenericOp> maybeNewOp =
raiseTensorExtractToInput(op, rewriter);
raiseTensorExtractToInput(linalgOp, rewriter);
if (succeeded(maybeNewOp)) {
op = *maybeNewOp;
linalgOp = *maybeNewOp;
}

// Try raising to a view-like operation.
// Try raising to a view-like operation. Replace if the op raising was
// successful.
rewriter.setInsertionPoint(op);
(void)tryRaiseToView(op, rewriter);
FailureOr<Operation *> maybeRaisedView =
tryRaiseToView(linalgOp, rewriter);
if (succeeded(maybeRaisedView)) {
rewriter.replaceOp(op, *maybeRaisedView);
}
});

SmallVector<std::pair<linalg::LinalgOp, Value>> softmaxRoots;
Expand Down

0 comments on commit 0e075cf

Please sign in to comment.