Skip to content

Commit

Permalink
[xls][mlir] Speed up lower_counted_for
Browse files Browse the repository at this point in the history
On large modules this pass could dominate due to SymbolTable lookups.

Before: 2m05s. After 2s.
PiperOrigin-RevId: 685238926
  • Loading branch information
James Molloy authored and copybara-github committed Oct 12, 2024
1 parent a81cbf3 commit 203e3e7
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 22 deletions.
8 changes: 4 additions & 4 deletions xls/contrib/mlir/testdata/lower_counted_for.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func.func @reduce_arity_2(%arg0: i32, %arg1: i32) -> i32 attributes {xls = true}
return %0#0 : i32
}

// CHECK-LABEL: func.func private @for_body_3(
// CHECK-LABEL: func.func private @for_body_1(
// CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32) -> i32
// CHECK: %[[VAL_4:.*]] = arith.index_cast %[[VAL_0]] : i32 to index
// CHECK: %[[VAL_5:.*]] = arith.addi %[[VAL_2]], %[[VAL_3]] : i32
Expand All @@ -74,10 +74,10 @@ func.func @reduce_arity_2(%arg0: i32, %arg1: i32) -> i32 attributes {xls = true}
// CHECK-LABEL: func.func private @for_body_2(
// CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32) -> i32
// CHECK: %[[VAL_3:.*]] = arith.index_cast %[[VAL_0]] : i32 to index
// CHECK: %[[VAL_4:.*]] = "xls.counted_for"(%[[VAL_1]], %[[VAL_1]], %[[VAL_2]]) <{stride = 1 : i64, to_apply = @for_body_3, trip_count = 1024 : i64}> : (i32, i32, i32) -> i32
// CHECK: %[[VAL_4:.*]] = "xls.counted_for"(%[[VAL_1]], %[[VAL_1]], %[[VAL_2]]) <{stride = 1 : i64, to_apply = @for_body_1, trip_count = 1024 : i64}> : (i32, i32, i32) -> i32
// CHECK: return %[[VAL_4]] : i32

// CHECK-LABEL: func.func private @for_body_1(
// CHECK-LABEL: func.func private @for_body_3(
// CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32) -> i32
// CHECK: %[[VAL_3:.*]] = arith.index_cast %[[VAL_0]] : i32 to index
// CHECK: %[[VAL_4:.*]] = "xls.counted_for"(%[[VAL_1]], %[[VAL_2]]) <{stride = 1 : i64, to_apply = @for_body_2, trip_count = 1024 : i64}> : (i32, i32) -> i32
Expand All @@ -89,7 +89,7 @@ func.func @reduce_arity_2(%arg0: i32, %arg1: i32) -> i32 attributes {xls = true}
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 1024 : index
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : i32
// CHECK: %[[VAL_5:.*]] = "xls.counted_for"(%[[VAL_4]], %[[VAL_0]]) <{stride = 1 : i64, to_apply = @for_body_1, trip_count = 1024 : i64}> : (i32, i32) -> i32
// CHECK: %[[VAL_5:.*]] = "xls.counted_for"(%[[VAL_4]], %[[VAL_0]]) <{stride = 1 : i64, to_apply = @for_body_3, trip_count = 1024 : i64}> : (i32, i32) -> i32
// CHECK: return %[[VAL_5]] : i32
func.func @triple_nest(%arg0: i32) -> i32 attributes {xls = true} {
%c0 = arith.constant 0 : index
Expand Down
36 changes: 18 additions & 18 deletions xls/contrib/mlir/transforms/lower_counted_for.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,8 @@ namespace mlir::xls {
#define GEN_PASS_DEF_LOWERCOUNTEDFORPASS
#include "xls/contrib/mlir/transforms/passes.h.inc" // IWYU pragma: keep

using ::llvm::SmallVector;

namespace {
using ::llvm::SmallVector;

namespace fixed {
// TODO(jmolloy): This is a copy of the one in SCF utils. But that version is
Expand Down Expand Up @@ -229,24 +228,25 @@ class TuplifyRewrite : public OpConversionPattern<ForOp> {
}
};

std::string createUniqueName(Operation *op, std::string prefix) {
// TODO(jpienaar): This could be made more efficient. Current approach does
// work that could be cached and reused.
mlir::Operation *symbolTableOp =
op->getParentWithTrait<mlir::OpTrait::SymbolTable>();
if (mlir::SymbolTable::lookupSymbolIn(symbolTableOp, prefix) == nullptr) {
return prefix;
StringAttr createUniqueName(MLIRContext &context, SymbolTable &symbolTable,
DenseSet<StringRef> &addedSymbols,
StringRef prefix) {
if (symbolTable.lookup(prefix) == nullptr && !addedSymbols.contains(prefix)) {
addedSymbols.insert(prefix);
return StringAttr::get(&context, prefix);
}

unsigned uniquingCounter = 0;
llvm::SmallString<128> name = SymbolTable::generateSymbolName<128>(
prefix,
[&](llvm::StringRef candidate) {
return mlir::SymbolTable::lookupSymbolIn(symbolTableOp, candidate) !=
nullptr;
return symbolTable.lookup(candidate) ||
addedSymbols.contains(candidate.str());
},
uniquingCounter);
return std::string(name.str());
auto result = StringAttr::get(&context, name);
addedSymbols.insert(result);
return result;
}

class ForToCountedForRewrite : public OpConversionPattern<ForOp> {
Expand All @@ -267,9 +267,7 @@ class ForToCountedForRewrite : public OpConversionPattern<ForOp> {
// not be updated until after rewrites have completed (meaning
// createUniqueName would always return the same value in the same rewrite
// cycle causing clashes).
std::string preferredName =
cast<StringAttr>(op->getAttr(kPreferredNameAttr)).str();
std::string name = createUniqueName(op, preferredName);
std::string name = cast<StringAttr>(op->getAttr(kPreferredNameAttr)).str();

mlir::func::CallOp callOp;
auto func = fixed::outlineSingleBlockRegion(
Expand All @@ -293,10 +291,12 @@ class LowerCountedForPass
private:
void runOnOperation() override {
// See comment in ForToCountedForRewrite for why we do this.
DenseSet<StringRef> addedSymbols;
SymbolTable symbolTable(getOperation());
getOperation().walk([&](ForOp op) {
op->setAttr(
kPreferredNameAttr,
StringAttr::get(op->getContext(), createUniqueName(op, "for_body")));
op->setAttr(kPreferredNameAttr,
createUniqueName(getContext(), symbolTable, addedSymbols,
"for_body"));
});

ConversionTarget target(getContext());
Expand Down

0 comments on commit 203e3e7

Please sign in to comment.