From 7514d3e8c28c53053f434fbfcf1c75a73eaaa09e Mon Sep 17 00:00:00 2001 From: NatashaKnk Date: Fri, 9 Jun 2023 12:25:02 -0700 Subject: [PATCH] Add pattern to map iota->sort->slice to topK (#13972) --- .../Preprocessing/StableHLOToStableHLO.cpp | 84 +++++++++++++++++++ .../test/stablehlo_to_stablehlo.mlir | 19 +++++ 2 files changed, 103 insertions(+) diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/StableHLOToStableHLO.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/StableHLOToStableHLO.cpp index 6dad3e3277b7..15692381bfb8 100644 --- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/StableHLOToStableHLO.cpp +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/StableHLOToStableHLO.cpp @@ -1540,6 +1540,87 @@ struct CustomCallIsTopK final } }; +struct IotaSortSliceIsTopK final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::stablehlo::SortOp op, + PatternRewriter &rewriter) const override { + auto opOperands = op.getOperands(); + auto opResults = op.getResults(); + Value topKInput; + if (opOperands.size() != 2 || opResults.size() != 2) { + return rewriter.notifyMatchFailure( + op, "Slice that maps to TopK must have exactly two inputs/outputs"); + } + + Value inputIota; + // Check that one of the inputs is iota, assume that the other one is the + // input. + for (Value operand : opOperands) { + auto iotaOp = + dyn_cast_or_null(operand.getDefiningOp()); + if (iotaOp) { + inputIota = iotaOp.getResult(); + } else { + topKInput = operand; + } + } + + if (!inputIota) { + return rewriter.notifyMatchFailure(op, "Sort isn't called from Iota."); + } + + Block &block = op.getRegion().front(); + auto stablehloCompareOp = + dyn_cast(block.front()); + if (!stablehloCompareOp) { + return rewriter.notifyMatchFailure(op, "not stablehlo compare op"); + } + + auto direction = stablehloCompareOp.getComparisonDirection(); + bool getTop = direction == mlir::stablehlo::ComparisonDirection::GT || + direction == mlir::stablehlo::ComparisonDirection::GE; + + if (!getTop) { + return rewriter.notifyMatchFailure(op, + "Unsupported comparison direction"); + } + + Value topV, topI; + int64_t k; + // Check that the output of the sort op gets fed into a slice. + for (auto [idx, result] : llvm::enumerate(opResults)) { + auto sliceOp = + dyn_cast(*result.getUsers().begin()); + if (!sliceOp) { + return rewriter.notifyMatchFailure( + op, "Sort isn't calling into a slice op."); + } + + for (auto stride : sliceOp.getStrides().getValues()) { + if (stride != 1) { + return rewriter.notifyMatchFailure( + op, "All slice strides must be 1 in order to match to TopK."); + } + } + + // Treat the first slice as inputs, the second as indices. + if (idx == 0) { + topV = sliceOp.getResult(); + k = sliceOp.getLimitIndices().getValues()[1]; + } else { + topI = sliceOp.getResult(); + } + } + + auto topK = rewriter.create( + op.getLoc(), TypeRange{topV.getType(), topI.getType()}, topKInput, k); + topV.replaceAllUsesWith(topK.getResults()[0]); + topI.replaceAllUsesWith(topK.getResults()[1]); + return success(); + } +}; + struct StableHLOToStableHLOPreprocessing final : impl::StableHLOToStableHLOPreprocessingBase< StableHLOToStableHLOPreprocessing> { @@ -1605,6 +1686,9 @@ struct StableHLOToStableHLOPreprocessing final // Identify known custom calls and convert them to equivalent StableHLO. patterns.insert(context); + // Identify an iota->sort->slice pattern that maps to TopK. + patterns.insert(context); + // Additional canonicalizers that simplify to computationally // less-complex operations. patterns.insert(context); diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/stablehlo_to_stablehlo.mlir b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/stablehlo_to_stablehlo.mlir index 90ffb6c01517..8d7c9d5874e7 100644 --- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/stablehlo_to_stablehlo.mlir +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/stablehlo_to_stablehlo.mlir @@ -390,3 +390,22 @@ func.func private @comparison(%arg0: tensor, %arg1: tensor, %arg2: t // CHECK-SAME: %[[ARG0:[a-z0-9]+]] // CHECK: %[[VALUES:.+]], %[[INDICES:.+]] = chlo.top_k(%[[ARG0]], k = 40) : tensor<4x8000xbf16> -> (tensor<4x40xbf16>, tensor<4x40xi32>) // CHECK: return %[[VALUES]], %[[INDICES]] : tensor<4x40xbf16>, tensor<4x40xi32> + +// ----- + +func.func @iota_sort_slice_is_topk(%in : tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32>) { + %iota = "stablehlo.iota"() { iota_dimension = 1 : i64 } : () -> tensor<16x16xi32> + %0:2 = "stablehlo.sort"(%in, %iota) ({ + ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): + %7 = "stablehlo.compare"(%arg0, %arg1) {comparison_direction = #stablehlo} : (tensor, tensor) -> tensor + "stablehlo.return"(%7) : (tensor) -> () + }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) + %1 = "stablehlo.slice"(%0#0) { start_indices = dense<[0, 0]> : tensor<2xi64>, limit_indices = dense<[16, 8]> : tensor<2xi64>, strides = dense<[1, 1]> : tensor<2xi64> } : (tensor<16x16xf32>) -> tensor<16x8xf32> + %2 = "stablehlo.slice"(%0#1) { start_indices = dense<[0, 0]> : tensor<2xi64>, limit_indices = dense<[16, 8]> : tensor<2xi64>, strides = dense<[1, 1]> : tensor<2xi64> } : (tensor<16x16xi32>) -> tensor<16x8xi32> + return %1, %2 : tensor<16x8xf32>, tensor<16x8xi32> +} + +// CHECK-LABEL: @iota_sort_slice_is_topk +// CHECK-SAME: %[[IN:[a-z0-9]+]] +// CHECK: %[[VALUES:.+]], %[[INDICES:.+]] = chlo.top_k(%[[IN]], k = 8) : tensor<16x16xf32> -> (tensor<16x8xf32>, tensor<16x8xi32>) +// CHECK: return %[[VALUES]], %[[INDICES]] : tensor<16x8xf32>, tensor<16x8xi32>