Skip to content

Commit

Permalink
Add pattern to map iota->sort->slice to topK (#13972)
Browse files Browse the repository at this point in the history
  • Loading branch information
NatashaKnk authored Jun 9, 2023
1 parent e4c27f5 commit 7514d3e
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1540,6 +1540,87 @@ struct CustomCallIsTopK final
}
};

/// Rewrites the sequence
///
///   %sorted:2 = stablehlo.sort(%input, %iota) { compare GT/GE on %input }
///   %values   = stablehlo.slice %sorted#0
///   %indices  = stablehlo.slice %sorted#1
///
/// into a single `chlo.top_k(%input, k)`.
///
/// The created `chlo.top_k` is given only the input and `k` (no dimension), so
/// the match is restricted to the case where everything happens along the
/// innermost dimension. The pattern only fires when all of the following hold:
///   * the sort has exactly two operands/results, operand 0 being the data and
///     operand 1 being produced by `stablehlo.iota`;
///   * the sort dimension and the iota dimension are both the innermost
///     dimension of a ranked input;
///   * the comparator starts with a GT/GE `stablehlo.compare` of block
///     arguments 0 and 1 (the data elements) and returns its result;
///   * each sort result has exactly one user, a unit-stride, zero-offset
///     `stablehlo.slice` that only truncates the innermost dimension, and both
///     slices keep the same number of elements `k`.
///
/// On success the two slices are replaced by the `chlo.top_k` results and the
/// now-unused sort is erased.
struct IotaSortSliceIsTopK final : OpRewritePattern<mlir::stablehlo::SortOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(mlir::stablehlo::SortOp op,
                                PatternRewriter &rewriter) const override {
    auto opOperands = op.getOperands();
    auto opResults = op.getResults();
    if (opOperands.size() != 2 || opResults.size() != 2) {
      return rewriter.notifyMatchFailure(
          op, "Sort that maps to TopK must have exactly two inputs/outputs");
    }

    // Result 0 is consumed as the values and result 1 as the indices below, so
    // the operands must come in that same order: data first, iota second.
    Value topKInput = opOperands[0];
    auto iotaOp = dyn_cast_or_null<mlir::stablehlo::IotaOp>(
        opOperands[1].getDefiningOp());
    if (!iotaOp) {
      return rewriter.notifyMatchFailure(op, "Sort isn't called from Iota.");
    }

    auto inputType = dyn_cast<mlir::RankedTensorType>(topKInput.getType());
    if (!inputType || inputType.getRank() == 0) {
      return rewriter.notifyMatchFailure(
          op, "TopK input must be a ranked tensor of rank >= 1");
    }
    const int64_t rank = inputType.getRank();
    const int64_t lastDim = rank - 1;

    // A negative sort dimension counts from the back; normalize it first.
    int64_t sortDim = static_cast<int64_t>(op.getDimension());
    if (sortDim < 0) {
      sortDim += rank;
    }
    if (sortDim != lastDim ||
        static_cast<int64_t>(iotaOp.getIotaDimension()) != lastDim) {
      return rewriter.notifyMatchFailure(
          op, "Sort and iota must both operate on the innermost dimension");
    }

    Block &block = op.getRegion().front();
    auto stablehloCompareOp =
        dyn_cast<mlir::stablehlo::CompareOp>(block.front());
    if (!stablehloCompareOp) {
      return rewriter.notifyMatchFailure(op, "not stablehlo compare op");
    }

    // The comparison has to order the data elements (block arguments 0 and 1),
    // and its result has to be what the comparator region yields.
    Value compareLhs = stablehloCompareOp.getLhs();
    Value compareRhs = stablehloCompareOp.getRhs();
    Operation *terminator = block.getTerminator();
    if (block.getNumArguments() < 2 || compareLhs != block.getArgument(0) ||
        compareRhs != block.getArgument(1) ||
        terminator->getNumOperands() != 1 ||
        terminator->getOperand(0) != stablehloCompareOp.getResult()) {
      return rewriter.notifyMatchFailure(
          op, "Comparator must return a compare of the first operand's "
              "elements");
    }

    auto direction = stablehloCompareOp.getComparisonDirection();
    bool getTop = direction == mlir::stablehlo::ComparisonDirection::GT ||
                  direction == mlir::stablehlo::ComparisonDirection::GE;
    if (!getTop) {
      return rewriter.notifyMatchFailure(op,
                                         "Unsupported comparison direction");
    }

    // slices[0] consumes the sorted values, slices[1] the sorted indices.
    mlir::stablehlo::SliceOp slices[2];
    int64_t k = -1;
    for (auto [idx, result] : llvm::enumerate(opResults)) {
      // Requiring a single user makes it safe to look at the first user and,
      // later, to erase the sort once both slices are replaced.
      if (!result.hasOneUse()) {
        return rewriter.notifyMatchFailure(
            op, "Each sort result must have exactly one user.");
      }
      auto sliceOp =
          dyn_cast<mlir::stablehlo::SliceOp>(*result.getUsers().begin());
      if (!sliceOp) {
        return rewriter.notifyMatchFailure(
            op, "Sort isn't calling into a slice op.");
      }

      auto starts = sliceOp.getStartIndices().getValues<int64_t>();
      auto limits = sliceOp.getLimitIndices().getValues<int64_t>();
      auto strides = sliceOp.getStrides().getValues<int64_t>();
      if (starts.size() != static_cast<size_t>(rank) ||
          limits.size() != static_cast<size_t>(rank) ||
          strides.size() != static_cast<size_t>(rank)) {
        return rewriter.notifyMatchFailure(
            op, "Slice attributes must match the rank of the sort result.");
      }

      for (int64_t dim = 0; dim < rank; ++dim) {
        if (strides[dim] != 1) {
          return rewriter.notifyMatchFailure(
              op, "All slice strides must be 1 in order to match to TopK.");
        }
        if (starts[dim] != 0) {
          return rewriter.notifyMatchFailure(
              op, "All slice start indices must be 0 in order to match to "
                  "TopK.");
        }
        // Only the innermost dimension may be truncated; the others must be
        // taken in full.
        if (dim != lastDim && limits[dim] != inputType.getDimSize(dim)) {
          return rewriter.notifyMatchFailure(
              op, "Slice may only truncate the innermost dimension.");
        }
      }

      // With start 0 and stride 1 the innermost limit is the element count.
      const int64_t sliceK = limits[lastDim];
      if (idx == 0) {
        k = sliceK;
      } else if (sliceK != k) {
        return rewriter.notifyMatchFailure(
            op, "Value and index slices must keep the same number of "
                "elements.");
      }
      slices[idx] = sliceOp;
    }

    auto topK = rewriter.create<chlo::TopKOp>(
        op.getLoc(),
        TypeRange{slices[0].getResult().getType(),
                  slices[1].getResult().getType()},
        topKInput, k);
    // Go through the rewriter so the pattern driver is notified of the changes.
    rewriter.replaceOp(slices[0], topK.getResults()[0]);
    rewriter.replaceOp(slices[1], topK.getResults()[1]);
    rewriter.eraseOp(op);
    return success();
  }
};

struct StableHLOToStableHLOPreprocessing final
: impl::StableHLOToStableHLOPreprocessingBase<
StableHLOToStableHLOPreprocessing> {
Expand Down Expand Up @@ -1605,6 +1686,9 @@ struct StableHLOToStableHLOPreprocessing final
// Identify known custom calls and convert them to equivalent StableHLO.
patterns.insert<CustomCallIsTopK>(context);

// Identify an iota->sort->slice pattern that maps to TopK.
patterns.insert<IotaSortSliceIsTopK>(context);

// Additional canonicalizers that simplify to computationally
// less-complex operations.
patterns.insert<DotToMul>(context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -390,3 +390,22 @@ func.func private @comparison(%arg0: tensor<bf16>, %arg1: tensor<bf16>, %arg2: t
// CHECK-SAME: %[[ARG0:[a-z0-9]+]]
// CHECK: %[[VALUES:.+]], %[[INDICES:.+]] = chlo.top_k(%[[ARG0]], k = 40) : tensor<4x8000xbf16> -> (tensor<4x40xbf16>, tensor<4x40xi32>)
// CHECK: return %[[VALUES]], %[[INDICES]] : tensor<4x40xbf16>, tensor<4x40xi32>

// -----

// Test input for the iota->sort->slice to chlo.top_k rewrite: %in is sorted
// along dimension 1 together with an iota of column indices, and both sorted
// results are then truncated to their first 8 columns.
func.func @iota_sort_slice_is_topk(%in : tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32>) {
// Index tensor counting along dimension 1; it is carried through the sort.
%iota = "stablehlo.iota"() { iota_dimension = 1 : i64 } : () -> tensor<16x16xi32>
%0:2 = "stablehlo.sort"(%in, %iota) ({
// %arg0/%arg1 are elements of %in, %arg2/%arg3 are elements of %iota.
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
// GT on the f32 elements; the i32 index arguments are not compared.
%7 = "stablehlo.compare"(%arg0, %arg1) {comparison_direction = #stablehlo<comparison_direction GT>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"stablehlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>)
// Unit-stride slices from offset [0, 0] to [16, 8]: all 16 rows, first 8 columns.
%1 = "stablehlo.slice"(%0#0) { start_indices = dense<[0, 0]> : tensor<2xi64>, limit_indices = dense<[16, 8]> : tensor<2xi64>, strides = dense<[1, 1]> : tensor<2xi64> } : (tensor<16x16xf32>) -> tensor<16x8xf32>
%2 = "stablehlo.slice"(%0#1) { start_indices = dense<[0, 0]> : tensor<2xi64>, limit_indices = dense<[16, 8]> : tensor<2xi64>, strides = dense<[1, 1]> : tensor<2xi64> } : (tensor<16x16xi32>) -> tensor<16x8xi32>
return %1, %2 : tensor<16x8xf32>, tensor<16x8xi32>
}

// CHECK-LABEL: @iota_sort_slice_is_topk
// CHECK-SAME: %[[IN:[a-z0-9]+]]
// CHECK: %[[VALUES:.+]], %[[INDICES:.+]] = chlo.top_k(%[[IN]], k = 8) : tensor<16x16xf32> -> (tensor<16x8xf32>, tensor<16x8xi32>)
// CHECK: return %[[VALUES]], %[[INDICES]] : tensor<16x8xf32>, tensor<16x8xi32>

0 comments on commit 7514d3e

Please sign in to comment.