Skip to content

Commit

Permalink
MLIR: Add ReplaceOpAction into DialectConversion and PatternMatch
Browse files Browse the repository at this point in the history
  • Loading branch information
mgehre-amd committed Jul 28, 2023
1 parent 45c7a6b commit af9dcbd
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 0 deletions.
20 changes: 20 additions & 0 deletions mlir/include/mlir/IR/PatternMatchAction.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#ifndef MLIR_IR_PATTERNMATCHACTION_H
#define MLIR_IR_PATTERNMATCHACTION_H

#include "mlir/IR/Action.h"

namespace mlir {
struct ReplaceOpAction : public tracing::ActionImpl<ReplaceOpAction> {
using Base = tracing::ActionImpl<ReplaceOpAction>;
ReplaceOpAction(ArrayRef<IRUnit> irUnits, ValueRange replacement);
static constexpr StringLiteral tag = "op-replacement";
void print(raw_ostream &os) const override;

Operation *getOp() const;

public:
ValueRange replacement;
};
}

#endif
32 changes: 32 additions & 0 deletions mlir/lib/IR/PatternMatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
//===----------------------------------------------------------------------===//

#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/PatternMatchAction.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Action.h"

using namespace mlir;

Expand Down Expand Up @@ -269,13 +271,43 @@ void RewriterBase::replaceOpWithinBlock(Operation *op, ValueRange newValues,
});
}

ReplaceOpAction::ReplaceOpAction(ArrayRef<IRUnit> irUnits, ValueRange replacement)
: Base(irUnits), replacement(replacement) {
assert(irUnits.size() == 1);
assert(irUnits[0]);
assert(isa<Operation*>(irUnits[0]));
}

void ReplaceOpAction::print(raw_ostream &os) const {
OpPrintingFlags flags;
flags.elideLargeElementsAttrs(10);
os << "`" << tag << "` replacing operation `";
getOp()->print(os, flags);
os << "` by ";
bool first = true;
for(auto r : replacement) {
if (!first)
os << ", ";
os << "`";
r.print(os, flags);
os << "`";
first = false;
}
}

Operation *ReplaceOpAction::getOp() const {
return llvm::dyn_cast<Operation *>(getContextIRUnits()[0]);
}

/// This method replaces the results of the operation with the specified list of
/// values. The number of provided values must match the number of results of
/// the operation.
void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
assert(op->getNumResults() == newValues.size() &&
"incorrect # of replacement values");

getContext()->executeAction<ReplaceOpAction>([]() {}, {op}, newValues);

// Notify the listener that we're about to remove this op.
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
rewriteListener->notifyOperationReplaced(op, newValues);
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/DialectConversion.h"
#include "mlir/IR/PatternMatchAction.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
Expand Down Expand Up @@ -1458,6 +1459,7 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
<< "** Replace : '" << op->getName() << "'(" << op << ")\n";
});
impl->notifyOpReplaced(op, newValues);
getContext()->executeAction<ReplaceOpAction>([]() {}, {op}, newValues);
}

void ConversionPatternRewriter::eraseOp(Operation *op) {
Expand Down

0 comments on commit af9dcbd

Please sign in to comment.