Skip to content

Commit

Permalink
[stablehlo] Add matcher for ApproxTopK custom call (#14899)
Browse files Browse the repository at this point in the history
  • Loading branch information
rsuderman authored Sep 18, 2023
1 parent bef0e2d commit 2f72249
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -1756,6 +1756,85 @@ struct IotaSortSliceIsTopK final : OpRewritePattern<mlir::stablehlo::SortOp> {
}
};

// Rewrites a `stablehlo.custom_call @ApproxTopK` into an exact `chlo.top_k`.
//
// The pattern matches when:
//   * the custom call has at least two operands (input, indices) and exactly
//     two results (values, indices), the values result being a rank-2 shaped
//     type with a static trailing dimension, which is used as `k`;
//   * if the indices operand is produced by a `stablehlo.iota`, that iota is
//     rank-2 and iterates over its last dimension;
//   * the first called computation resolves to a `func.func` whose body is a
//     single `stablehlo.compare` of block arguments 0 and 1 (in that order)
//     followed by a return of the compare result.
//
// GT/GE comparators lower to `chlo.top_k` of the input. LT/LE comparators
// lower to `chlo.top_k` of the negated input.
//
// NOTE(review): on the LT/LE path the values result of `chlo.top_k` is taken
// from the *negated* input and is not negated back, so the returned values
// appear to have flipped sign relative to the original operand. Confirm
// whether a second negate on result #0 is required; fixing it also changes the
// lit expectations for the bottom-k test.
struct ApproxTopK final : OpRewritePattern<mlir::stablehlo::CustomCallOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(mlir::stablehlo::CustomCallOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getCallTargetName() != "ApproxTopK")
      return rewriter.notifyMatchFailure(op, "not ApproxTopK operation.");

    // The rewrite reads operands 0/1 and forwards both result types to
    // chlo.top_k, so reject malformed calls instead of asserting on them.
    if (op->getNumOperands() < 2 || op->getNumResults() != 2)
      return rewriter.notifyMatchFailure(op,
                                         "unexpected operand/result count.");

    auto calledComputations = op.getCalledComputationsAttr();
    if (!calledComputations || calledComputations.empty())
      return rewriter.notifyMatchFailure(op, "no called computation.");
    auto computationName = dyn_cast<SymbolRefAttr>(calledComputations[0]);
    if (!computationName)
      return rewriter.notifyMatchFailure(
          op, "called computation is not a symbol reference.");

    // Walk outwards through the enclosing ops until the comparator resolves.
    // Must start out null: the loop may never assign it.
    Operation *funcOp = nullptr;
    for (Operation *parent = op->getParentOp(); parent;
         parent = parent->getParentOp()) {
      funcOp = SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
          parent, computationName);
      if (funcOp)
        break;
    }
    if (!funcOp)
      return rewriter.notifyMatchFailure(op, "computation function not found.");

    // `k` is the static size of the trailing dimension of the values result.
    auto valuesType = dyn_cast<ShapedType>(op.getType(0));
    if (!valuesType || !valuesType.hasRank() || valuesType.getRank() != 2)
      return rewriter.notifyMatchFailure(op, "expected rank-2 values result.");
    int64_t k = valuesType.getDimSize(1);
    if (ShapedType::isDynamic(k))
      return rewriter.notifyMatchFailure(op, "dynamic k is not supported.");

    Value input = op.getOperand(0);
    Value iota = op.getOperand(1);

    // Only validated when the indices come straight from an iota; any other
    // producer (including a block argument) is accepted as-is.
    if (auto iotaOp =
            dyn_cast_or_null<mlir::stablehlo::IotaOp>(iota.getDefiningOp())) {
      int64_t iotaDim = iotaOp.getIotaDimension();
      auto iotaLastDim = cast<ShapedType>(iotaOp.getType()).getRank() - 1;
      if (iotaDim != iotaLastDim || iotaLastDim != 1) {
        return rewriter.notifyMatchFailure(op, "Iota of last dim not found.");
      }
    }

    // An external declaration has no body to inspect.
    Region &body = funcOp->getRegion(0);
    if (body.empty() || body.front().empty())
      return rewriter.notifyMatchFailure(op, "computation has no body.");

    Block &block = body.front();
    auto stablehloCompareOp =
        dyn_cast<mlir::stablehlo::CompareOp>(block.front());
    if (!stablehloCompareOp) {
      return rewriter.notifyMatchFailure(op, "not stablehlo compare op");
    }

    // The direction only tells us top vs. bottom if the comparator compares
    // its first two arguments in order; swapped operands would invert it.
    if (block.getNumArguments() < 2 ||
        stablehloCompareOp.getLhs() != block.getArgument(0) ||
        stablehloCompareOp.getRhs() != block.getArgument(1)) {
      return rewriter.notifyMatchFailure(
          op, "compare op does not compare the first two arguments");
    }

    Operation *returnOp = block.getTerminator();
    if (!isa<func::ReturnOp>(returnOp) &&
        !isa<mlir::stablehlo::ReturnOp>(returnOp)) {
      return rewriter.notifyMatchFailure(op, "could not find ReturnOp");
    }

    if (returnOp->getNumOperands() != 1 ||
        returnOp->getOperand(0) != stablehloCompareOp.getResult()) {
      return rewriter.notifyMatchFailure(op, "ReturnOp operand not compare op");
    }

    auto direction = stablehloCompareOp.getComparisonDirection();
    const bool getTop =
        direction == mlir::stablehlo::ComparisonDirection::GT ||
        direction == mlir::stablehlo::ComparisonDirection::GE;
    const bool getBottom =
        direction == mlir::stablehlo::ComparisonDirection::LT ||
        direction == mlir::stablehlo::ComparisonDirection::LE;
    if (!getTop && !getBottom)
      return rewriter.notifyMatchFailure(op,
                                         "unsupported comparison direction");

    // Bottom-k is expressed as top-k over the negated input.
    if (getBottom)
      input = rewriter.create<mlir::stablehlo::NegOp>(op.getLoc(), input);

    auto topK =
        rewriter.create<chlo::TopKOp>(op.getLoc(), op.getResultTypes(), input,
                                      rewriter.getI64IntegerAttr(k));
    rewriter.replaceOp(op, topK);
    return success();
  }
};

struct StableHLOToStableHLOPreprocessing final
: impl::StableHLOToStableHLOPreprocessingBase<
StableHLOToStableHLOPreprocessing> {
Expand Down Expand Up @@ -1825,7 +1904,7 @@ struct StableHLOToStableHLOPreprocessing final
patterns.insert<CustomCallIsTopK>(context);

// Identify an iota->sort->slice pattern that maps to TopK.
patterns.insert<IotaSortSliceIsTopK>(context);
patterns.insert<IotaSortSliceIsTopK, ApproxTopK>(context);

// Additional canonicalizers that simplify to computationally
// less-complex operations.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -473,3 +473,40 @@ func.func @concat_remove_zero_extents(%arg0: tensor<2x3xi32>, %arg1 : tensor<2x3
// CHECK: [[R0:%.+]] = stablehlo.concatenate %[[ARG0]], %[[ARG1]], dim = 1 : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x6xi32>
return %0 : tensor<2x6xi32>
}

// -----

// Comparator referenced through `called_computations` by the ApproxTopK
// custom call in @custom_call_topk. It returns `%arg0 > %arg1` on the two f32
// arguments; the two i32 arguments (%arg2, %arg3) are unused.
func.func private @top_k_gt_f32_comparator(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>) -> tensor<i1> {
%0 = stablehlo.compare GT, %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<i1>
stablehlo.return %0 : tensor<i1>
}

// An ApproxTopK custom call whose comparator uses GT is expected to be
// rewritten to chlo.top_k applied directly to the function argument, with
// k = 16 taken from the trailing dimension of the 1x16 result, and both
// top_k results returned.
// CHECK-LABEL: @custom_call_topk
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x160xf32>
func.func @custom_call_topk(%arg0 : tensor<1x160xf32>, %arg1 : tensor<f32>, %arg2 : tensor<i32>) -> (tensor<1x16xf32>, tensor<1x16xi32>) {
// CHECK: %[[TOPK:.+]], %[[IND:.+]] = chlo.top_k(%[[ARG0]], k = 16)
// Indices operand: a rank-2 iota over its last dimension (dim = 1).
%iota = stablehlo.iota dim = 1 : tensor<1x16xi32>
%approx:2 = stablehlo.custom_call @ApproxTopK(%arg0, %iota, %arg1, %arg2) {called_computations = [@top_k_gt_f32_comparator], mhlo.backend_config = {aggregate_to_topk = true, is_fallback = true, recall_target = 8.500000e-01 : f32, reduction_dim = 1 : i64, reduction_input_size_override = -1 : i64, top_k = 16 : i64}} : (tensor<1x160xf32>, tensor<1x16xi32>, tensor<f32>, tensor<i32>) -> (tensor<1x16xf32>, tensor<1x16xi32>)

// CHECK: return %[[TOPK]], %[[IND]]
return %approx#0, %approx#1 : tensor<1x16xf32>, tensor<1x16xi32>
}

// -----

// Comparator referenced through `called_computations` by the ApproxTopK
// custom call in @custom_call_bottomk. It returns `%arg0 < %arg1` on the two
// f32 arguments; the two i32 arguments (%arg2, %arg3) are unused.
// NOTE(review): the symbol name says "gt" although the comparison direction
// is LT.
func.func private @bottom_k_gt_f32_comparator(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>) -> tensor<i1> {
%0 = stablehlo.compare LT, %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<i1>
stablehlo.return %0 : tensor<i1>
}

// An ApproxTopK custom call whose comparator uses LT is expected to be
// rewritten to a negate of the function argument followed by chlo.top_k with
// k = 16 on the negated tensor. The expectations below have both top_k
// results returned directly, i.e. the values result is not negated back.
// CHECK-LABEL: @custom_call_bottomk
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x160xf32>
func.func @custom_call_bottomk(%arg0 : tensor<1x160xf32>, %arg1 : tensor<f32>, %arg2 : tensor<i32>) -> (tensor<1x16xf32>, tensor<1x16xi32>) {
// CHECK: %[[NEG:.+]] = stablehlo.negate %[[ARG0]]
// CHECK: %[[TOPK:.+]], %[[IND:.+]] = chlo.top_k(%[[NEG]], k = 16)
// Indices operand: a rank-2 iota over its last dimension (dim = 1).
%iota = stablehlo.iota dim = 1 : tensor<1x16xi32>
%approx:2 = stablehlo.custom_call @ApproxTopK(%arg0, %iota, %arg1, %arg2) {called_computations = [@bottom_k_gt_f32_comparator], mhlo.backend_config = {aggregate_to_topk = true, is_fallback = true, recall_target = 8.500000e-01 : f32, reduction_dim = 1 : i64, reduction_input_size_override = -1 : i64, top_k = 16 : i64}} : (tensor<1x160xf32>, tensor<1x16xi32>, tensor<f32>, tensor<i32>) -> (tensor<1x16xf32>, tensor<1x16xi32>)

// CHECK: return %[[TOPK]], %[[IND]]
return %approx#0, %approx#1 : tensor<1x16xf32>, tensor<1x16xi32>
}

0 comments on commit 2f72249

Please sign in to comment.