From af9dcbdac168e280282c2b57084f3a1c0c937432 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Thu, 27 Jul 2023 11:13:22 +0200 Subject: [PATCH] MLIR: Add ReplaceOpAction into DialectConversion and PatternMatch --- mlir/include/mlir/IR/PatternMatchAction.h | 20 ++++++++++++ mlir/lib/IR/PatternMatch.cpp | 32 +++++++++++++++++++ .../Transforms/Utils/DialectConversion.cpp | 2 ++ 3 files changed, 54 insertions(+) create mode 100644 mlir/include/mlir/IR/PatternMatchAction.h diff --git a/mlir/include/mlir/IR/PatternMatchAction.h b/mlir/include/mlir/IR/PatternMatchAction.h new file mode 100644 index 00000000000000..a7d9b27e46c414 --- /dev/null +++ b/mlir/include/mlir/IR/PatternMatchAction.h @@ -0,0 +1,20 @@ +#ifndef MLIR_IR_PATTERNMATCHACTION_H +#define MLIR_IR_PATTERNMATCHACTION_H + +#include "mlir/IR/Action.h" + +namespace mlir { +struct ReplaceOpAction : public tracing::ActionImpl { + using Base = tracing::ActionImpl; + ReplaceOpAction(ArrayRef irUnits, ValueRange replacement); + static constexpr StringLiteral tag = "op-replacement"; + void print(raw_ostream &os) const override; + + Operation *getOp() const; + +public: + ValueRange replacement; +}; +} + +#endif diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp index 46b5c1e6852de6..c757bb4519b19e 100644 --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -7,7 +7,9 @@ //===----------------------------------------------------------------------===// #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/PatternMatchAction.h" #include "mlir/IR/IRMapping.h" +#include "mlir/IR/Action.h" using namespace mlir; @@ -269,6 +271,34 @@ void RewriterBase::replaceOpWithinBlock(Operation *op, ValueRange newValues, }); } +ReplaceOpAction::ReplaceOpAction(ArrayRef irUnits, ValueRange replacement) + : Base(irUnits), replacement(replacement) { + assert(irUnits.size() == 1); + assert(irUnits[0]); + assert(isa(irUnits[0])); +} + +void ReplaceOpAction::print(raw_ostream &os) const { + OpPrintingFlags flags; + flags.elideLargeElementsAttrs(10); + os << "`" << tag << "` replacing operation `"; + getOp()->print(os, flags); + os << "` by "; + bool first = true; + for(auto r : replacement) { + if (!first) + os << ", "; + os << "`"; + r.print(os, flags); + os << "`"; + first = false; + } +} + +Operation *ReplaceOpAction::getOp() const { + return llvm::dyn_cast(getContextIRUnits()[0]); +} + /// This method replaces the results of the operation with the specified list of /// values. The number of provided values must match the number of results of /// the operation. @@ -276,6 +306,8 @@ void RewriterBase::replaceOp(Operation *op, ValueRange newValues) { assert(op->getNumResults() == newValues.size() && "incorrect # of replacement values"); + getContext()->executeAction([]() {}, {op}, newValues); + // Notify the listener that we're about to remove this op. if (auto *rewriteListener = dyn_cast_if_present(listener)) rewriteListener->notifyOperationReplaced(op, newValues); diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 615c8e4a99ceb7..a37770aa12b4a4 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Transforms/DialectConversion.h" +#include "mlir/IR/PatternMatchAction.h" #include "mlir/IR/Block.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" @@ -1458,6 +1459,7 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) { << "** Replace : '" << op->getName() << "'(" << op << ")\n"; }); impl->notifyOpReplaced(op, newValues); + getContext()->executeAction([]() {}, {op}, newValues); } void ConversionPatternRewriter::eraseOp(Operation *op) {