diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/StableHLOToStableHLO.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/StableHLOToStableHLO.cpp
index 17e7a8224fac8..be2460726c62c 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/StableHLOToStableHLO.cpp
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/StableHLOToStableHLO.cpp
@@ -1756,6 +1756,109 @@ struct IotaSortSliceIsTopK final : OpRewritePattern {
   }
 };
 
+// Rewrites a `stablehlo.custom_call @ApproxTopK` into `chlo.top_k`.
+//
+// The called comparator must consist of one `stablehlo.compare` of its first
+// two block arguments whose result is returned directly. GT/GE comparators
+// map to a plain top-k. LT/LE comparators select the smallest entries, which
+// is expressed as a top-k over the negated input whose values are negated
+// back afterwards.
+struct ApproxTopK final : OpRewritePattern<mlir::stablehlo::CustomCallOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::stablehlo::CustomCallOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.getCallTargetName() != "ApproxTopK")
+      return rewriter.notifyMatchFailure(op, "not an ApproxTopK call");
+
+    // Operands 0 and 1 are read below, and `chlo.top_k` is created with this
+    // op's result types, so exactly two results are required.
+    if (op->getNumOperands() < 2 || op->getNumResults() != 2)
+      return rewriter.notifyMatchFailure(op, "unexpected operand/result count");
+
+    ArrayAttr computations = op.getCalledComputationsAttr();
+    if (!computations || computations.empty())
+      return rewriter.notifyMatchFailure(op, "no called computation");
+    auto computationName = dyn_cast<SymbolRefAttr>(computations[0]);
+    if (!computationName)
+      return rewriter.notifyMatchFailure(op, "computation is not a symbol");
+
+    // Try each enclosing op in turn until one resolves the comparator symbol.
+    // Starts null so a failed lookup is detected rather than read
+    // uninitialized.
+    Operation *funcOp = nullptr;
+    for (Operation *parent = op->getParentOp(); parent && !funcOp;
+         parent = parent->getParentOp()) {
+      funcOp = SymbolTable::lookupNearestSymbolFrom<mlir::func::FuncOp>(
+          parent, computationName);
+    }
+    if (!funcOp || funcOp->getRegion(0).empty())
+      return rewriter.notifyMatchFailure(op, "comparator body not found");
+
+    // `k` is read from dimension 1 of the values result, so that result must
+    // be a rank-2 shaped type whose dimension 1 is static.
+    auto valuesType = dyn_cast<ShapedType>(op->getResult(0).getType());
+    if (!valuesType || !valuesType.hasRank() || valuesType.getRank() != 2 ||
+        valuesType.isDynamicDim(1))
+      return rewriter.notifyMatchFailure(op, "expected static rank-2 result");
+    int64_t k = valuesType.getDimSize(1);
+
+    Value input = op->getOperand(0);
+    Value iota = op->getOperand(1);
+
+    // When the index operand comes from an iota, it must count along
+    // dimension 1 of a rank-2 tensor. Other producers are not inspected.
+    if (auto iotaOp = iota.getDefiningOp<mlir::stablehlo::IotaOp>()) {
+      int64_t iotaDim = iotaOp.getIotaDimension();
+      int64_t iotaLastDim = cast<ShapedType>(iotaOp.getType()).getRank() - 1;
+      if (iotaDim != iotaLastDim || iotaLastDim != 1)
+        return rewriter.notifyMatchFailure(op, "iota not along dimension 1");
+    }
+
+    Block &block = funcOp->getRegion(0).front();
+    if (block.empty() || block.getNumArguments() < 2)
+      return rewriter.notifyMatchFailure(op, "malformed comparator body");
+    auto compareOp = dyn_cast<mlir::stablehlo::CompareOp>(block.front());
+    if (!compareOp)
+      return rewriter.notifyMatchFailure(op, "not stablehlo compare op");
+
+    // A comparator over swapped or unrelated arguments would not describe the
+    // ordering that the direction check below assumes.
+    if (compareOp.getLhs() != block.getArgument(0) ||
+        compareOp.getRhs() != block.getArgument(1))
+      return rewriter.notifyMatchFailure(op, "compare not on first two args");
+
+    Operation *returnOp = block.getTerminator();
+    if (!isa<mlir::func::ReturnOp, mlir::stablehlo::ReturnOp>(returnOp))
+      return rewriter.notifyMatchFailure(op, "could not find ReturnOp");
+
+    if (returnOp->getNumOperands() != 1 ||
+        returnOp->getOperand(0) != compareOp.getResult())
+      return rewriter.notifyMatchFailure(op, "ReturnOp operand not compare op");
+
+    using Direction = mlir::stablehlo::ComparisonDirection;
+    Direction direction = compareOp.getComparisonDirection();
+    bool getTop = direction == Direction::GT || direction == Direction::GE;
+    bool getBottom = direction == Direction::LT || direction == Direction::LE;
+    if (!getTop && !getBottom)
+      return rewriter.notifyMatchFailure(op, "unsupported compare direction");
+
+    Location loc = op.getLoc();
+    if (getBottom)
+      input = rewriter.create<mlir::stablehlo::NegOp>(loc, input);
+    auto topK = rewriter.create<mlir::chlo::TopKOp>(
+        loc, op.getResultTypes(), input, rewriter.getI64IntegerAttr(k));
+
+    // The top-k of the negated input holds negated values; flip the sign back
+    // so the smallest entries are returned unchanged. Indices need no fixup.
+    Value values = topK->getResult(0);
+    if (getBottom)
+      values = rewriter.create<mlir::stablehlo::NegOp>(loc, values);
+    rewriter.replaceOp(op, ValueRange{values, topK->getResult(1)});
+    return success();
+  }
+};
+
 struct StableHLOToStableHLOPreprocessing final
     : impl::StableHLOToStableHLOPreprocessingBase<
           StableHLOToStableHLOPreprocessing> {
@@ -1825,7 +1928,7 @@ struct StableHLOToStableHLOPreprocessing final
     patterns.insert(context);
 
     // Identify an iota->sort->slice pattern that maps to TopK.
-    patterns.insert<IotaSortSliceIsTopK>(context);
+    patterns.insert<IotaSortSliceIsTopK, ApproxTopK>(context);
 
     // Additional canonicalizers that simplify to computationally
     // less-complex operations.
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/stablehlo_to_stablehlo.mlir b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/stablehlo_to_stablehlo.mlir
index debce84d62892..56694fe4a93fb 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/stablehlo_to_stablehlo.mlir
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/stablehlo_to_stablehlo.mlir
@@ -474,3 +474,44 @@ func.func @concat_remove_zero_extents(%arg0: tensor<2x3xi32>, %arg1 : tensor<2x3
   return %0 : tensor<2x6xi32>
 }
 
+
+// -----
+
+// A GT comparator keeps the largest entries, so the call maps onto chlo.top_k.
+func.func private @top_k_gt_f32_comparator(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>) -> tensor<i1> {
+  %0 = stablehlo.compare GT, %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<i1>
+  stablehlo.return %0 : tensor<i1>
+}
+
+// CHECK-LABEL: @custom_call_topk
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x160xf32>
+func.func @custom_call_topk(%arg0 : tensor<1x160xf32>, %arg1 : tensor<f32>, %arg2 : tensor<i32>) -> (tensor<1x16xf32>, tensor<1x16xi32>) {
+  // CHECK: %[[TOPK:.+]], %[[IND:.+]] = chlo.top_k(%[[ARG0]], k = 16)
+  %iota = stablehlo.iota dim = 1 : tensor<1x16xi32>
+  %approx:2 = stablehlo.custom_call @ApproxTopK(%arg0, %iota, %arg1, %arg2) {called_computations = [@top_k_gt_f32_comparator], mhlo.backend_config = {aggregate_to_topk = true, is_fallback = true, recall_target = 8.500000e-01 : f32, reduction_dim = 1 : i64, reduction_input_size_override = -1 : i64, top_k = 16 : i64}} : (tensor<1x160xf32>, tensor<1x16xi32>, tensor<f32>, tensor<i32>) -> (tensor<1x16xf32>, tensor<1x16xi32>)
+
+  // CHECK: return %[[TOPK]], %[[IND]]
+  return %approx#0, %approx#1 : tensor<1x16xf32>, tensor<1x16xi32>
+}
+
+// -----
+
+// An LT comparator keeps the smallest entries: the input is negated before
+// chlo.top_k and the resulting values are negated back.
+func.func private @bottom_k_lt_f32_comparator(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>) -> tensor<i1> {
+  %0 = stablehlo.compare LT, %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<i1>
+  stablehlo.return %0 : tensor<i1>
+}
+
+// CHECK-LABEL: @custom_call_bottomk
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x160xf32>
+func.func @custom_call_bottomk(%arg0 : tensor<1x160xf32>, %arg1 : tensor<f32>, %arg2 : tensor<i32>) -> (tensor<1x16xf32>, tensor<1x16xi32>) {
+  // CHECK: %[[NEG:.+]] = stablehlo.negate %[[ARG0]]
+  // CHECK: %[[TOPK:.+]], %[[IND:.+]] = chlo.top_k(%[[NEG]], k = 16)
+  // CHECK: %[[RES:.+]] = stablehlo.negate %[[TOPK]]
+  %iota = stablehlo.iota dim = 1 : tensor<1x16xi32>
+  %approx:2 = stablehlo.custom_call @ApproxTopK(%arg0, %iota, %arg1, %arg2) {called_computations = [@bottom_k_lt_f32_comparator], mhlo.backend_config = {aggregate_to_topk = true, is_fallback = true, recall_target = 8.500000e-01 : f32, reduction_dim = 1 : i64, reduction_input_size_override = -1 : i64, top_k = 16 : i64}} : (tensor<1x160xf32>, tensor<1x16xi32>, tensor<f32>, tensor<i32>) -> (tensor<1x16xf32>, tensor<1x16xi32>)
+
+  // CHECK: return %[[RES]], %[[IND]]
+  return %approx#0, %approx#1 : tensor<1x16xf32>, tensor<1x16xi32>
+}