Skip to content

Commit

Permalink
Merge pull request #157 from Xilinx/rogarcia.allow_parsing_of_diction…
Browse files Browse the repository at this point in the history
…aries_in_constraint

Allow parsing of dictionaries in constraint section
  • Loading branch information
roberteg16 authored Apr 9, 2024
2 parents 23171db + 3280a80 commit d288ea3
Show file tree
Hide file tree
Showing 11 changed files with 220 additions and 69 deletions.
10 changes: 6 additions & 4 deletions mlir/include/mlir/Dialect/PDL/IR/Builtins.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@ namespace pdl {
void registerBuiltins(PDLPatternModule &pdlPattern);

namespace builtin {
Attribute createDictionaryAttr(PatternRewriter &rewriter);
Attribute addEntryToDictionaryAttr(PatternRewriter &rewriter,
Attribute dictAttr, Attribute attrName,
Attribute attrEntry);
LogicalResult createDictionaryAttr(PatternRewriter &rewriter,
PDLResultList &results,
ArrayRef<PDLValue> args);
LogicalResult addEntryToDictionaryAttr(PatternRewriter &rewriter,
PDLResultList &results,
ArrayRef<PDLValue> args);
Attribute createArrayAttr(PatternRewriter &rewriter);
Attribute addElemToArrayAttr(PatternRewriter &rewriter, Attribute attr,
Attribute element);
Expand Down
4 changes: 3 additions & 1 deletion mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ def PDL_ApplyNativeConstraintOp
Variadic<PDL_AnyType>:$args,
DefaultValuedAttr<BoolAttr, "false">:$isNegated);
let results = (outs Variadic<PDL_AnyType>:$results);
let assemblyFormat = "$name `(` $args `:` type($args) `)` (`:` type($results)^ )? attr-dict";
let assemblyFormat = [{
$name (`(` $args^ `:` type($args) `)`)? (`:` type($results)^)? attr-dict
}];
let hasVerifier = 1;
}

Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
DefaultValuedAttr<BoolAttr, "false">:$isNegated);
let results = (outs Variadic<PDL_AnyType>:$results);
let assemblyFormat = [{
$name `(` $args `:` type($args) `)` (`:` type($results)^)? attr-dict `->` successors
$name (`(` $args^ `:` type($args) `)`)? (`:` type($results)^)? attr-dict `->` successors
}];
}

Expand Down
22 changes: 16 additions & 6 deletions mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,14 +265,24 @@ static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
DenseMap<Value, Position *> &inputs) {
OperandRange arguments = op.getArgs();

Position *pos = nullptr;
std::vector<Position *> allPositions;
allPositions.reserve(arguments.size());
for (Value arg : arguments)
allPositions.push_back(inputs.lookup(arg));

// Push the constraint to the furthest position.
Position *pos = *std::max_element(allPositions.begin(), allPositions.end(),
comparePosDepth);
// If this constraint has no arguments, this means it has no dependencies, and
// the same applies to all results
if (arguments.empty()) {
pos = builder.getRoot();
} else {
allPositions.reserve(arguments.size());
for (Value arg : arguments)
allPositions.push_back(inputs.lookup(arg));

// Push the constraint to the furthest position.
pos = *std::max_element(allPositions.begin(), allPositions.end(),
comparePosDepth);
}
assert(pos && "Must have a non-null value");

ResultRange results = op.getResults();
PredicateBuilder::Predicate pred = builder.getConstraint(
op.getName(), allPositions, SmallVector<Type>(results.getTypes()),
Expand Down
38 changes: 24 additions & 14 deletions mlir/lib/Dialect/PDL/IR/Builtins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,31 @@ using namespace mlir;

namespace mlir::pdl {
namespace builtin {
mlir::Attribute createDictionaryAttr(mlir::PatternRewriter &rewriter) {
return rewriter.getDictionaryAttr({});

LogicalResult createDictionaryAttr(PatternRewriter &rewriter,
PDLResultList &results,
ArrayRef<PDLValue> args) {
results.push_back(rewriter.getDictionaryAttr({}));
return success();
}

mlir::Attribute addEntryToDictionaryAttr(mlir::PatternRewriter &rewriter,
mlir::Attribute dictAttr,
mlir::Attribute attrName,
mlir::Attribute attrEntry) {
assert(isa<DictionaryAttr>(dictAttr));
auto attr = dictAttr.cast<DictionaryAttr>();
auto name = attrName.cast<StringAttr>();
std::vector<NamedAttribute> values = attr.getValue().vec();
LogicalResult addEntryToDictionaryAttr(PatternRewriter &rewriter,
PDLResultList &results,
ArrayRef<PDLValue> args) {
auto dictAttr = cast<DictionaryAttr>(args[0].cast<Attribute>());
auto name = cast<StringAttr>(args[1].cast<Attribute>());
auto attrEntry = args[2].cast<Attribute>();

std::vector<NamedAttribute> values = dictAttr.getValue().vec();

// Remove entry if it exists in the dictionary.
llvm::erase_if(values, [&](NamedAttribute &namedAttr) {
return namedAttr.getName() == name.getValue();
});

values.push_back(rewriter.getNamedAttr(name, attrEntry));
return rewriter.getDictionaryAttr(values);
results.push_back(rewriter.getDictionaryAttr(values));
return success();
}

mlir::Attribute createArrayAttr(mlir::PatternRewriter &rewriter) {
Expand Down Expand Up @@ -110,14 +115,19 @@ LogicalResult add(mlir::PatternRewriter &rewriter, mlir::PDLResultList &results,
void registerBuiltins(PDLPatternModule &pdlPattern) {
using namespace builtin;
// See Parser::defineBuiltins()
pdlPattern.registerRewriteFunction("__builtin_createDictionaryAttr",
pdlPattern.registerRewriteFunction("__builtin_createDictionaryAttr_rewrite",
createDictionaryAttr);
pdlPattern.registerRewriteFunction("__builtin_addEntryToDictionaryAttr",
addEntryToDictionaryAttr);
pdlPattern.registerRewriteFunction(
"__builtin_addEntryToDictionaryAttr_rewrite", addEntryToDictionaryAttr);
pdlPattern.registerRewriteFunction("__builtin_createArrayAttr",
createArrayAttr);
pdlPattern.registerRewriteFunction("__builtin_addElemToArrayAttr",
addElemToArrayAttr);
pdlPattern.registerConstraintFunctionWithResults(
"__builtin_createDictionaryAttr_constraint", createDictionaryAttr);
pdlPattern.registerConstraintFunctionWithResults(
"__builtin_addEntryToDictionaryAttr_constraint",
addEntryToDictionaryAttr);
pdlPattern.registerConstraintFunctionWithResults("__builtin_add", add);
}
} // namespace mlir::pdl
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/PDL/IR/PDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ static void visit(Operation *op, DenseSet<Operation *> &visited) {
//===----------------------------------------------------------------------===//

LogicalResult ApplyNativeConstraintOp::verify() {
if (getNumOperands() == 0)
return emitOpError("expected at least one argument");
if (getNumOperands() == 0 && getNumResults() == 0)
return emitOpError("expected at least one argument or result");
return success();
}

Expand Down
56 changes: 42 additions & 14 deletions mlir/lib/Tools/PDLL/Parser/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -600,8 +600,10 @@ class Parser {
CodeCompleteContext *codeCompleteContext;

struct {
ast::UserRewriteDecl *createDictionaryAttr;
ast::UserRewriteDecl *addEntryToDictionaryAttr;
ast::UserRewriteDecl *createDictionaryAttr_Rewrite;
ast::UserConstraintDecl *createDictionaryAttr_Constraint;
ast::UserRewriteDecl *addEntryToDictionaryAttr_Rewrite;
ast::UserConstraintDecl *addEntryToDictionaryAttr_Constraint;
ast::UserRewriteDecl *createArrayAttr;
ast::UserRewriteDecl *addElemToArrayAttr;
ast::UserConstraintDecl *add;
Expand Down Expand Up @@ -640,11 +642,22 @@ T *Parser::declareBuiltin(StringRef name, ArrayRef<StringRef> argNames,
}

void Parser::declareBuiltins() {
builtins.createDictionaryAttr = declareBuiltin<ast::UserRewriteDecl>(
"__builtin_createDictionaryAttr", {}, /*returnsAttr=*/true);
builtins.addEntryToDictionaryAttr = declareBuiltin<ast::UserRewriteDecl>(
"__builtin_addEntryToDictionaryAttr", {"attr", "attrName", "attrEntry"},
/*returnsAttr=*/true);
builtins.createDictionaryAttr_Rewrite = declareBuiltin<ast::UserRewriteDecl>(
"__builtin_createDictionaryAttr_rewrite", {}, /*returnsAttr=*/true);
builtins.createDictionaryAttr_Constraint =
declareBuiltin<ast::UserConstraintDecl>(
"__builtin_createDictionaryAttr_constraint", {},
/*returnsAttr=*/true);
builtins.addEntryToDictionaryAttr_Rewrite =
declareBuiltin<ast::UserRewriteDecl>(
"__builtin_addEntryToDictionaryAttr_rewrite",
{"attr", "attrName", "attrEntry"},
/*returnsAttr=*/true);
builtins.addEntryToDictionaryAttr_Constraint =
declareBuiltin<ast::UserConstraintDecl>(
"__builtin_addEntryToDictionaryAttr_constraint",
{"attr", "attrName", "attrEntry"},
/*returnsAttr=*/true);
builtins.createArrayAttr = declareBuiltin<ast::UserRewriteDecl>(
"__builtin_createArrayAttr", {}, /*returnsAttr=*/true);
builtins.addElemToArrayAttr = declareBuiltin<ast::UserRewriteDecl>(
Expand Down Expand Up @@ -2110,14 +2123,22 @@ FailureOr<ast::Expr *> Parser::parseDictAttrExpr() {
consumeToken(Token::l_brace);
SMRange loc = curToken.getLoc();

if (parserContext != ParserContext::Rewrite)
return emitError(
"Parsing of dictionary attributes as constraint not supported!");

auto dictAttrCall = createBuiltinCall(loc, builtins.createDictionaryAttr, {});
FailureOr<ast::Expr *> dictAttrCall;
if (parserContext == ParserContext::Rewrite) {
dictAttrCall =
createBuiltinCall(loc, builtins.createDictionaryAttr_Rewrite, {});
} else {
dictAttrCall =
createBuiltinCall(loc, builtins.createDictionaryAttr_Constraint, {});
}
if (failed(dictAttrCall))
return failure();

// No key-values inside dictionary
if (consumeIf(Token::r_brace)) {
return dictAttrCall;
}

// Add each nested attribute to the dict
do {
FailureOr<ast::NamedAttributeDecl *> decl =
Expand Down Expand Up @@ -2148,8 +2169,15 @@ FailureOr<ast::Expr *> Parser::parseDictAttrExpr() {
// Create addEntryToDictionaryAttr native call.
SmallVector<ast::Expr *> arrayAttrArgs{*dictAttrCall, *stringAttrRef,
namedDecl->getValue()};
auto entryToDictionaryCall = createBuiltinCall(
loc, builtins.addEntryToDictionaryAttr, arrayAttrArgs);

FailureOr<ast::Expr *> entryToDictionaryCall;
if (parserContext == ParserContext::Rewrite) {
entryToDictionaryCall = createBuiltinCall(
loc, builtins.addEntryToDictionaryAttr_Rewrite, arrayAttrArgs);
} else {
entryToDictionaryCall = createBuiltinCall(
loc, builtins.addEntryToDictionaryAttr_Constraint, arrayAttrArgs);
}
if (failed(entryToDictionaryCall))
return failure();

Expand Down
14 changes: 14 additions & 0 deletions mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,20 @@ module @constraint_with_result_multiple {

// -----

// CHECK-LABEL: module @constraint_with_no_inputs
module @constraint_with_no_inputs {
pdl.pattern : benefit(0) {
%attr = apply_native_constraint "constraint_no_inputs" : !pdl.attribute
%root = operation
rewrite %root with "rewriter"
}
}

// CHECK: func @matcher(%arg0: !pdl.operation) {
// CHECK: pdl_interp.apply_constraint "constraint_no_inputs" : !pdl.attribute -> ^{{.*}}, ^{{.*}}

// -----

// CHECK-LABEL: module @negated_constraint
module @negated_constraint {
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
Expand Down
50 changes: 40 additions & 10 deletions mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,9 @@ Pattern TypeExpr => erase op<> -> (type<"i32">);
// CHECK: %[[VAL_1:.*]] = operation "test.op"
// CHECK: %[[VAL_2:.*]] = attribute = "test"
// CHECK: rewrite %[[VAL_1]] {
// CHECK: %[[VAL_3:.*]] = apply_native_rewrite "__builtin_createDictionaryAttr"
// CHECK: %[[VAL_3:.*]] = apply_native_rewrite "__builtin_createDictionaryAttr_rewrite"
// CHECK: %[[VAL_4:.*]] = attribute = "firstAttr"
// CHECK: %[[VAL_5:.*]] = apply_native_rewrite "__builtin_addEntryToDictionaryAttr"(%[[VAL_3]], %[[VAL_4]], %[[VAL_2]]
// CHECK: %[[VAL_5:.*]] = apply_native_rewrite "__builtin_addEntryToDictionaryAttr_rewrite"(%[[VAL_3]], %[[VAL_4]], %[[VAL_2]]
// CHECK: %[[VAL_6:.*]] = operation "test.success" {"some_dictionary" = %[[VAL_5]]}
// CHECK: replace %[[VAL_1]] with %[[VAL_6]]
Pattern RewriteOneEntryDictionary {
Expand All @@ -188,14 +188,14 @@ Pattern RewriteOneEntryDictionary {
// CHECK: %[[VAL_2:.*]] = attribute = "test2"
// CHECK: %[[VAL_3:.*]] = attribute = "test3"
// CHECK: rewrite %[[VAL_1]] {
// CHECK: %[[VAL_4:.*]] = apply_native_rewrite "__builtin_createDictionaryAttr"
// CHECK: %[[VAL_4:.*]] = apply_native_rewrite "__builtin_createDictionaryAttr_rewrite"
// CHECK: %[[VAL_5:.*]] = attribute = "firstAttr"
// CHECK: %[[VAL_6:.*]] = attribute = "test1"
// CHECK: %[[VAL_7:.*]] = apply_native_rewrite "__builtin_addEntryToDictionaryAttr"(%[[VAL_4]], %[[VAL_5]], %[[VAL_6]]
// CHECK: %[[VAL_7:.*]] = apply_native_rewrite "__builtin_addEntryToDictionaryAttr_rewrite"(%[[VAL_4]], %[[VAL_5]], %[[VAL_6]]
// CHECK: %[[VAL_8:.*]] = attribute = "secondAttr"
// CHECK: %[[VAL_9:.*]] = apply_native_rewrite "__builtin_addEntryToDictionaryAttr"(%[[VAL_7]], %[[VAL_8]], %[[VAL_2]]
// CHECK: %[[VAL_9:.*]] = apply_native_rewrite "__builtin_addEntryToDictionaryAttr_rewrite"(%[[VAL_7]], %[[VAL_8]], %[[VAL_2]]
// CHECK: %[[VAL_10:.*]] = attribute = "thirdAttr"
// CHECK: %[[VAL_11:.*]] = apply_native_rewrite "__builtin_addEntryToDictionaryAttr"(%[[VAL_9]], %[[VAL_10]], %[[VAL_3]]
// CHECK: %[[VAL_11:.*]] = apply_native_rewrite "__builtin_addEntryToDictionaryAttr_rewrite"(%[[VAL_9]], %[[VAL_10]], %[[VAL_3]]
// CHECK: %[[VAL_12:.*]] = operation "test.success" {"some_dictionary" = %[[VAL_11]]}
// CHECK: replace %[[VAL_1]] with %[[VAL_12]]
Pattern RewriteMultipleEntriesDictionary {
Expand All @@ -214,10 +214,10 @@ Pattern RewriteMultipleEntriesDictionary {
// CHECK: %[[VAL_1:.*]] = operation "test.op"
// CHECK: rewrite %[[VAL_1]] {
// CHECK: %[[VAL_2:.*]] = apply_native_rewrite "__builtin_createArrayAttr"
// CHECK: %[[VAL_3:.*]] = apply_native_rewrite "__builtin_createDictionaryAttr"
// CHECK: %[[VAL_3:.*]] = apply_native_rewrite "__builtin_createDictionaryAttr_rewrite"
// CHECK: %[[VAL_4:.*]] = attribute = "firstAttr"
// CHECK: %[[VAL_5:.*]] = attribute = "test1"
// CHECK: %[[VAL_6:.*]] = apply_native_rewrite "__builtin_addEntryToDictionaryAttr"(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]]
// CHECK: %[[VAL_6:.*]] = apply_native_rewrite "__builtin_addEntryToDictionaryAttr_rewrite"(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]]
// CHECK: %[[VAL_7:.*]] = apply_native_rewrite "__builtin_addElemToArrayAttr"(%[[VAL_2]], %[[VAL_6]]
// CHECK: %[[VAL_8:.*]] = operation "test.success" {"some_array" = %[[VAL_7]]}
// CHECK: replace %[[VAL_1]] with %[[VAL_8]]
Expand All @@ -236,10 +236,10 @@ Pattern RewriteOneDictionaryArrayAttr {
// CHECK: %[[VAL_2:.*]] = attribute = "test2"
// CHECK: rewrite %[[VAL_1]] {
// CHECK: %[[VAL_3:.*]] = apply_native_rewrite "__builtin_createArrayAttr"
// CHECK: %[[VAL_4:.*]] = apply_native_rewrite "__builtin_createDictionaryAttr"
// CHECK: %[[VAL_4:.*]] = apply_native_rewrite "__builtin_createDictionaryAttr_rewrite"
// CHECK: %[[VAL_5:.*]] = attribute = "firstAttr"
// CHECK: %[[VAL_6:.*]] = attribute = "test1"
// CHECK: %[[VAL_7:.*]] = apply_native_rewrite "__builtin_addEntryToDictionaryAttr"(%[[VAL_4]], %[[VAL_5]], %[[VAL_6]]
// CHECK: %[[VAL_7:.*]] = apply_native_rewrite "__builtin_addEntryToDictionaryAttr_rewrite"(%[[VAL_4]], %[[VAL_5]], %[[VAL_6]]
// CHECK: %[[VAL_8:.*]] = apply_native_rewrite "__builtin_addElemToArrayAttr"(%[[VAL_3]], %[[VAL_7]]
// CHECK: %[[VAL_9:.*]] = apply_native_rewrite "__builtin_addElemToArrayAttr"(%[[VAL_8]], %[[VAL_2]]
// CHECK: %[[VAL_10:.*]] = operation "test.success" {"some_array" = %[[VAL_9]]}
Expand All @@ -253,6 +253,36 @@ Pattern RewriteMultiplyElementsArrayAttr {
};
}

// -----

// CHECK-LABEL: pdl.pattern @ConstraintWithEmptyDictionary : benefit(0) {
// CHECK: %[[DICT:.+]] = apply_native_constraint "__builtin_createDictionaryAttr_constraint" : !pdl.attribute
// CHECK: rewrite %{{.*}} {
// CHECK-NEXT: operation "test.success" {"importantAttr" = %[[DICT]]}

Pattern ConstraintWithEmptyDictionary {
let dict = {};
replace op<test.op> with op<test.success>() { importantAttr = dict };
}

// -----
// CHECK-LABEL: pdl.pattern @ConstraintWithTwoEntriesDictionary : benefit(0) {
// CHECK: %[[DICT:.+]] = apply_native_constraint "__builtin_createDictionaryAttr_constraint" : !pdl.attribute
// CHECK-NEXT: %[[KEY:.+]] = attribute = "hello"
// CHECK-NEXT: %[[VAL:.+]] = attribute = "world"
// CHECK-NEXT: %[[DICT2:.+]] = apply_native_constraint "__builtin_addEntryToDictionaryAttr_constraint"(%[[DICT]], %[[KEY]], %[[VAL]] : !pdl.attribute, !pdl.attribute, !pdl.attribute) : !pdl.attribute
// CHECK-NEXT: %[[KEY2:.+]] = attribute = "bye"
// CHECK-NEXT: %[[VAL2:.+]] = attribute = 100 : ui8
// CHECK-NEXT: %[[DICT3:.+]] = apply_native_constraint "__builtin_addEntryToDictionaryAttr_constraint"(%[[DICT2]], %[[KEY2]], %[[VAL2]] : !pdl.attribute, !pdl.attribute, !pdl.attribute) : !pdl.attribute
// CHECK: rewrite %{{.*}} {
// CHECK-NEXT: operation "test.success" {"importantAttr" = %[[DICT3]]}


Pattern ConstraintWithTwoEntriesDictionary {
let dict = { hello = "world", bye = 100 : ui8 };
replace op<test.op> with op<test.success>() { importantAttr = dict };
}

// -----
// CHECK-LABEL: pdl.pattern @TestAdd : benefit(0) {
// CHECK: %[[VAL_0:.*]] = attribute = 4 : i32
Expand Down
Loading

0 comments on commit d288ea3

Please sign in to comment.