Skip to content

Commit

Permalink
PDLL: Add equals operator
Browse files Browse the repository at this point in the history
  • Loading branch information
mgehre-amd committed Jan 27, 2024
1 parent 9d3b31f commit 9d16822
Show file tree
Hide file tree
Showing 11 changed files with 141 additions and 2 deletions.
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/PDL/IR/Builtins.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
#ifndef MLIR_DIALECT_PDL_IR_BUILTINS_H_
#define MLIR_DIALECT_PDL_IR_BUILTINS_H_

#include "mlir/Support/LogicalResult.h"

namespace mlir {
class PDLPatternModule;
class Attribute;
Expand All @@ -29,6 +31,7 @@ Attribute addEntryToDictionaryAttr(PatternRewriter &rewriter,
Attribute createArrayAttr(PatternRewriter &rewriter);
Attribute addElemToArrayAttr(PatternRewriter &rewriter, Attribute attr,
Attribute element);
LogicalResult equals(PatternRewriter &rewriter, Attribute lhs, Attribute rhs);
} // namespace builtin
} // namespace pdl
} // namespace mlir
Expand Down
25 changes: 25 additions & 0 deletions mlir/lib/Dialect/PDL/IR/Builtins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,30 @@ mlir::Attribute addElemToArrayAttr(mlir::PatternRewriter &rewriter,
values.push_back(element);
return rewriter.getArrayAttr(values);
}

LogicalResult equals(mlir::PatternRewriter &, mlir::Attribute lhs,
mlir::Attribute rhs) {
if (auto lhsAttr = dyn_cast_or_null<IntegerAttr>(lhs)) {
auto rhsAttr = dyn_cast_or_null<IntegerAttr>(rhs);
if (!rhsAttr || lhsAttr.getType() != rhsAttr.getType())
return failure();

APInt lhsVal = lhsAttr.getValue();
APInt rhsVal = rhsAttr.getValue();
return success(lhsVal.eq(rhsVal));
}

if (auto lhsAttr = dyn_cast_or_null<FloatAttr>(lhs)) {
auto rhsAttr = dyn_cast_or_null<FloatAttr>(rhs);
if (!rhsAttr || lhsAttr.getType() != rhsAttr.getType())
return failure();

APFloat lhsVal = lhsAttr.getValue();
APFloat rhsVal = rhsAttr.getValue();
return success(lhsVal.compare(rhsVal) == llvm::APFloatBase::cmpEqual);
}
return failure();
}
} // namespace builtin

void registerBuiltins(PDLPatternModule &pdlPattern) {
Expand All @@ -52,5 +76,6 @@ void registerBuiltins(PDLPatternModule &pdlPattern) {
createArrayAttr);
pdlPattern.registerRewriteFunction("__builtin_addElemToArrayAttr",
addElemToArrayAttr);
pdlPattern.registerConstraintFunction("__builtin_equals", equals);
}
} // namespace mlir::pdl
4 changes: 4 additions & 0 deletions mlir/lib/Tools/PDLL/Parser/Lexer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,10 @@ Token Lexer::lexToken() {
++curPtr;
return formToken(Token::equal_arrow, tokStart);
}
if (*curPtr == '=') {
++curPtr;
return formToken(Token::equal_equal, tokStart);
}
return formToken(Token::equal, tokStart);
case ';':
return formToken(Token::semicolon, tokStart);
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Tools/PDLL/Parser/Lexer.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ class Token {
dot,
equal,
equal_arrow,
equal_equal,
semicolon,
exclam,
/// Paired punctuation.
Expand Down
24 changes: 22 additions & 2 deletions mlir/lib/Tools/PDLL/Parser/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,7 @@ class Parser {
ast::UserRewriteDecl *addEntryToDictionaryAttr;
ast::UserRewriteDecl *createArrayAttr;
ast::UserRewriteDecl *addElemToArrayAttr;
ast::UserConstraintDecl *equals;
} builtins{};
};
} // namespace
Expand Down Expand Up @@ -629,7 +630,7 @@ T *Parser::declareBuiltin(StringRef name, ArrayRef<StringRef> argNames,
popDeclScope();

auto *constraintDecl = T::createNative(ctx, ast::Name::create(ctx, name, loc),
args, results, {}, attrTy);
args, results, {}, createUserConstraintRewriteResultType(results));
curDeclScope->add(constraintDecl);
return constraintDecl;
}
Expand All @@ -645,6 +646,10 @@ void Parser::declareBuiltins() {
builtins.addElemToArrayAttr = declareBuiltin<ast::UserRewriteDecl>(
"__builtin_addElemToArrayAttr", {"attr", "element"},
/*returnsAttr=*/true);

builtins.equals = declareBuiltin<ast::UserConstraintDecl>(
"__builtin_equals", {"lhs", "rhs"},
/*returnsAttr=*/false);
}

FailureOr<ast::Module *> Parser::parseModule() {
Expand Down Expand Up @@ -1892,7 +1897,22 @@ FailureOr<ast::Expr *> Parser::parseLogicalAndExpr() {
}

FailureOr<ast::Expr *> Parser::parseEqualityExpr() {
return parseRelationExpr();
auto lhs = parseRelationExpr();
if (failed(lhs))
return failure();

switch (curToken.getKind()) {
case Token::equal_equal: {
consumeToken();
auto rhs = parseRelationExpr();
if (failed(rhs))
return failure();
SmallVector<ast::Expr *> args{*lhs, *rhs};
return createBuiltinCall(curToken.getLoc(), builtins.equals, args);
}
default:
return lhs;
}
}

FailureOr<ast::Expr *> Parser::parseRelationExpr() { return parseAddSubExpr(); }
Expand Down
7 changes: 7 additions & 0 deletions mlir/test/lib/Tools/PDLL/TestPDLL.pdll
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,10 @@ Pattern TestSimplePattern => replace op<test.simple> with op<test.success>;

// Test the import of interfaces.
Pattern TestInterface => replace _: CastOpInterface with op<test.success>;

// Test equals builtin
Pattern TestEquals {
let op = op<test.equals> {val = val : Attr};
val == attr<"4 : i32">;
replace op with op<test.success>;
}
14 changes: 14 additions & 0 deletions mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll
Original file line number Diff line number Diff line change
Expand Up @@ -252,3 +252,17 @@ Pattern RewriteMultiplyElementsArrayAttr {
replace root with newRoot;
};
}

// -----
// CHECK-LABEL: pdl.pattern @TestEquals : benefit(0) {
// CHECK: %[[VAL_0:.*]] = operands
// CHECK: %[[VAL_1:.*]] = attribute
// CHECK: %[[VAL_2:.*]] = types
// CHECK: %[[VAL_3:.*]] = operation "test.op"(%[[VAL_0]] : !pdl.range<value>) {"val" = %[[VAL_1]]} -> (%[[VAL_2]] : !pdl.range<type>)
// CHECK: %[[VAL_4:.*]] = attribute = 4 : i32
// CHECK: apply_native_constraint "__builtin_equals"(%[[VAL_1]], %[[VAL_4]] : !pdl.attribute, !pdl.attribute)
Pattern TestEquals {
let op = op<test.op> {val = val : Attr};
val == attr<"4 : i32">;
replace op with op<test.success>;
}
9 changes: 9 additions & 0 deletions mlir/test/mlir-pdll/Integration/test-pdll.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,12 @@ func.func @testImportedInterface() -> i1 {
%value = "builtin.unrealized_conversion_cast"() : () -> (i1)
return %value : i1
}

// CHECK-LABEL: func @test_builtin
func.func @test_builtin() {
// CHECK: test.success
// CHECK: test.equals_neg
"test.equals"() { val = 4 : i32 }: () -> ()
"test.equals_neg"() { val = 4 : i32 }: () -> ()
return
}
20 changes: 20 additions & 0 deletions mlir/test/mlir-pdll/Parser/expr-failure.pdll
Original file line number Diff line number Diff line change
Expand Up @@ -412,3 +412,23 @@ Pattern {
// CHECK: expected `>` after type literal
let foo = type<"";
}

// -----

//===----------------------------------------------------------------------===//
// Builtins
//===----------------------------------------------------------------------===//

Pattern {
// CHECK: expected expression
==
erase _: Op;
}

// -----

Pattern {
// CHECK: expected expression
attr<"4 : i32"> ==
erase _: Op;
}
16 changes: 16 additions & 0 deletions mlir/test/mlir-pdll/Parser/expr.pdll
Original file line number Diff line number Diff line change
Expand Up @@ -359,3 +359,19 @@ Pattern {

erase _: Op;
}

// -----

//===----------------------------------------------------------------------===//
// Builtins
//===----------------------------------------------------------------------===//

// CHECK: Module {{.*}}
// CHECK: UserConstraintDecl {{.*}} Name<__builtin_equals> ResultType<Tuple<>>
// CHECK: UserConstraintDecl {{.*}} Name<__builtin_equals> ResultType<Tuple<>>
Pattern {
attr<"4 : i32"> == attr<"5 : i32">;
let a: Attr;
a == a;
erase _: Op;
}
20 changes: 20 additions & 0 deletions mlir/unittests/Dialect/PDL/BuiltinTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,24 @@ TEST_F(BuiltinTest, addElemToArrayAttr) {
cast<DictionaryAttr>(*cast<ArrayAttr>(updatedArrAttr).begin());
EXPECT_EQ(dictInsideArrAttr, dict);
}

TEST_F(BuiltinTest, equals) {
auto onei16 = rewriter.getI16IntegerAttr(1);
auto onei32 = rewriter.getI32IntegerAttr(1);
auto zeroi32 = rewriter.getI32IntegerAttr(0);

EXPECT_TRUE(builtin::equals(rewriter, onei16, onei16).succeeded());
EXPECT_TRUE(builtin::equals(rewriter, onei16, onei32).failed());
EXPECT_TRUE(builtin::equals(rewriter, zeroi32, onei32).failed());

auto onef32 = rewriter.getF32FloatAttr(1.0);
auto zerof32 = rewriter.getF32FloatAttr(0.0);
auto negzerof32 = rewriter.getF32FloatAttr(-0.0);
auto zerof64 = rewriter.getF64FloatAttr(0.0);

EXPECT_TRUE(builtin::equals(rewriter, onef32, onef32).succeeded());
EXPECT_TRUE(builtin::equals(rewriter, onef32, zerof32).failed());
EXPECT_TRUE(builtin::equals(rewriter, negzerof32, zerof32).succeeded());
EXPECT_TRUE(builtin::equals(rewriter, zerof32, zerof64).failed());
}
} // namespace

0 comments on commit 9d16822

Please sign in to comment.