Skip to content

Commit

Permalink
[mlir][Transforms] Dialect conversion: Unify materialization of value…
Browse files Browse the repository at this point in the history
… replacements (llvm#108381)

PR llvm#106760 aligned the handling of dropped block arguments and dropped
op results. The two helper functions that insert source materializations
for uses of replaced block arguments / op results that survived the
conversion are now almost identical (`legalizeConvertedArgumentTypes`
and `legalizeConvertedOpResultTypes`). This PR merges the two functions
and moves the implementation directly into `finalize`.

This PR simplifies the code base and improves the efficiency a bit:
previously, `finalize` iterated over
`ConversionPatternRewriterImpl::rewrites` twice. Now, only one iteration
is needed.

---------

Co-authored-by: Jakub Kuderski <jakub@nod-labs.com>
  • Loading branch information
matthias-springer and kuhar authored Sep 21, 2024
1 parent c57b9f5 commit 8527861
Showing 1 changed file with 41 additions and 92 deletions.
133 changes: 41 additions & 92 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2338,17 +2338,6 @@ struct OperationConverter {
/// remaining artifacts and complete the conversion.
LogicalResult finalize(ConversionPatternRewriter &rewriter);

/// Legalize the types of converted block arguments.
LogicalResult
legalizeConvertedArgumentTypes(ConversionPatternRewriter &rewriter,
ConversionPatternRewriterImpl &rewriterImpl);

/// Legalize the types of converted op results.
LogicalResult legalizeConvertedOpResultTypes(
ConversionPatternRewriter &rewriter,
ConversionPatternRewriterImpl &rewriterImpl,
DenseMap<Value, SmallVector<Value>> &inverseMapping);

/// Dialect conversion configuration.
ConversionConfig config;

Expand Down Expand Up @@ -2512,19 +2501,6 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
return success();
}

LogicalResult
OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
return failure();
DenseMap<Value, SmallVector<Value>> inverseMapping =
rewriterImpl.mapping.getInverse();
if (failed(legalizeConvertedOpResultTypes(rewriter, rewriterImpl,
inverseMapping)))
return failure();
return success();
}

/// Finds a user of the given value, or of any other value that the given value
/// replaced, that was not replaced in the conversion process.
static Operation *findLiveUserOfReplaced(
Expand All @@ -2548,87 +2524,60 @@ static Operation *findLiveUserOfReplaced(
return nullptr;
}

LogicalResult OperationConverter::legalizeConvertedOpResultTypes(
ConversionPatternRewriter &rewriter,
ConversionPatternRewriterImpl &rewriterImpl,
DenseMap<Value, SmallVector<Value>> &inverseMapping) {
// Process requested operation replacements.
for (unsigned i = 0; i < rewriterImpl.rewrites.size(); ++i) {
auto *opReplacement =
dyn_cast<ReplaceOperationRewrite>(rewriterImpl.rewrites[i].get());
if (!opReplacement)
continue;
Operation *op = opReplacement->getOperation();
for (OpResult result : op->getResults()) {
// If the type of this op result changed and the result is still live,
// we need to materialize a conversion.
if (rewriterImpl.mapping.lookupOrNull(result, result.getType()))
/// Helper function that returns the replaced values and the type converter if
/// the given rewrite object is an "operation replacement" or a "block type
/// conversion" (which corresponds to a "block replacement"). Otherwise, return
/// an empty ValueRange and a null type converter pointer.
static std::pair<ValueRange, const TypeConverter *>
getReplacedValues(IRRewrite *rewrite) {
if (auto *opRewrite = dyn_cast<ReplaceOperationRewrite>(rewrite))
return {opRewrite->getOperation()->getResults(), opRewrite->getConverter()};
if (auto *blockRewrite = dyn_cast<BlockTypeConversionRewrite>(rewrite))
return {blockRewrite->getOrigBlock()->getArguments(),
blockRewrite->getConverter()};
return {};
}

LogicalResult
OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
DenseMap<Value, SmallVector<Value>> inverseMapping =
rewriterImpl.mapping.getInverse();

// Process requested value replacements.
for (unsigned i = 0, e = rewriterImpl.rewrites.size(); i < e; ++i) {
ValueRange replacedValues;
const TypeConverter *converter;
std::tie(replacedValues, converter) =
getReplacedValues(rewriterImpl.rewrites[i].get());
for (Value originalValue : replacedValues) {
// If the type of this value changed and the value is still live, we need
// to materialize a conversion.
if (rewriterImpl.mapping.lookupOrNull(originalValue,
originalValue.getType()))
continue;
Operation *liveUser =
findLiveUserOfReplaced(result, rewriterImpl, inverseMapping);
findLiveUserOfReplaced(originalValue, rewriterImpl, inverseMapping);
if (!liveUser)
continue;

// Legalize this result.
Value newValue = rewriterImpl.mapping.lookupOrNull(result);
// Legalize this value replacement.
Value newValue = rewriterImpl.mapping.lookupOrNull(originalValue);
assert(newValue && "replacement value not found");
Value castValue = rewriterImpl.buildUnresolvedMaterialization(
MaterializationKind::Source, computeInsertPoint(result), op->getLoc(),
/*inputs=*/newValue, /*outputType=*/result.getType(),
opReplacement->getConverter());
rewriterImpl.mapping.map(result, castValue);
inverseMapping[castValue].push_back(result);
llvm::erase(inverseMapping[newValue], result);
MaterializationKind::Source, computeInsertPoint(newValue),
originalValue.getLoc(),
/*inputs=*/newValue, /*outputType=*/originalValue.getType(),
converter);
rewriterImpl.mapping.map(originalValue, castValue);
inverseMapping[castValue].push_back(originalValue);
llvm::erase(inverseMapping[newValue], originalValue);
}
}

return success();
}

LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
ConversionPatternRewriter &rewriter,
ConversionPatternRewriterImpl &rewriterImpl) {
// Functor used to check if all users of a value will be dead after
// conversion.
// TODO: This should probably query the inverse mapping, same as in
// `legalizeConvertedOpResultTypes`.
auto findLiveUser = [&](Value val) {
auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](Operation *user) {
return rewriterImpl.isOpIgnored(user);
});
return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
};
// Note: `rewrites` may be reallocated as the loop is running.
for (int64_t i = 0; i < static_cast<int64_t>(rewriterImpl.rewrites.size());
++i) {
auto &rewrite = rewriterImpl.rewrites[i];
if (auto *blockTypeConversionRewrite =
dyn_cast<BlockTypeConversionRewrite>(rewrite.get())) {
// Process the remapping for each of the original arguments.
for (Value origArg :
blockTypeConversionRewrite->getOrigBlock()->getArguments()) {
// If the type of this argument changed and the argument is still live,
// we need to materialize a conversion.
if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
continue;
Operation *liveUser = findLiveUser(origArg);
if (!liveUser)
continue;

Value replacementValue = rewriterImpl.mapping.lookupOrNull(origArg);
assert(replacementValue && "replacement value not found");
Value repl = rewriterImpl.buildUnresolvedMaterialization(
MaterializationKind::Source, computeInsertPoint(replacementValue),
origArg.getLoc(), /*inputs=*/replacementValue,
/*outputType=*/origArg.getType(),
blockTypeConversionRewrite->getConverter());
rewriterImpl.mapping.map(origArg, repl);
}
}
}
return success();
}

//===----------------------------------------------------------------------===//
// Reconcile Unrealized Casts
//===----------------------------------------------------------------------===//
Expand Down

0 comments on commit 8527861

Please sign in to comment.