Skip to content

Commit

Permalink
Merge commit '226f162d17f3' into HEAD
Browse files Browse the repository at this point in the history
  • Loading branch information
mgehre-amd committed Aug 11, 2023
2 parents af9dcbd + 226f162 commit 26a6043
Show file tree
Hide file tree
Showing 14 changed files with 1,184 additions and 10 deletions.
11 changes: 11 additions & 0 deletions mlir/include/mlir/IR/OperationSupport.h
Original file line number Diff line number Diff line change
Expand Up @@ -1096,6 +1096,10 @@ class OpPrintingFlags {
/// elements.
OpPrintingFlags &elideLargeElementsAttrs(int64_t largeElementLimit = 16);

/// Enables breaking attributes on individual lines when there are more than
/// the given number of attributes on an operation.
OpPrintingFlags& newlineAfterAttribute(int64_t attributeLimit = 2);

/// Enable or disable printing of debug information (based on `enable`). If
/// 'prettyForm' is set to true, debug information is printed in a more
/// readable 'pretty' form. Note: The IR generated with 'prettyForm' is not
Expand Down Expand Up @@ -1126,6 +1130,9 @@ class OpPrintingFlags {
/// Return the size limit for printing large ElementsAttr.
std::optional<int64_t> getLargeElementsAttrLimit() const;

/// Return the size limit for printing newlines after attributes.
std::optional<unsigned> getNewlineAfterAttrLimit() const;

/// Return if debug information should be printed.
bool shouldPrintDebugInfo() const;

Expand All @@ -1152,6 +1159,10 @@ class OpPrintingFlags {
/// the upper limit.
std::optional<int64_t> elementsAttrElementLimit;

/// Print newlines after each attribute when an operation has more than
/// the given number of attributes.
std::optional<unsigned> newlineAfterAttr;

/// Print debug information.
bool printDebugInfoFlag : 1;
bool printDebugInfoPrettyFormFlag : 1;
Expand Down
281 changes: 281 additions & 0 deletions mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1026,6 +1026,8 @@ struct TosaFoldConstantGreater : public TosaFoldConstantBinary<TosaFoldConstantG
return applyElementWise<APFloat, APInt>(
lhsValues, rhsValues, op.getType(),
[](const APFloat &first, const APFloat &second) {
if (first.isNaN() || second.isNaN())
return APInt(1, false);
return APInt(1, first > second);
});
}
Expand Down Expand Up @@ -1090,6 +1092,277 @@ struct TosaFoldConstantErf
}
};

struct TosaFoldConstantLog
: public TosaFoldConstantUnaryElementwise<TosaFoldConstantLog, LogOp> {
using TosaFoldConstantUnaryElementwise<
TosaFoldConstantLog, LogOp>::TosaFoldConstantUnaryElementwise;

DenseElementsAttr computeFloat(DenseElementsAttr values,
PatternRewriter &rewriter, TosaOp op) const {
return applyElementWise<APFloat, APFloat, FloatType>(
values,
[](const APFloat &val, FloatType) {
auto res = APFloat(std::log(val.convertToFloat()));
bool lostPrecision;
res.convert(val.getSemantics(), APFloat::rmNearestTiesToEven,
&lostPrecision);
return res;
},
cast<FloatType>(values.getElementType()));
}

bool isSupportedElementType(Type type) const {
// convertToFloat uses F32, so we specify the supported types to make sure
// to properly handle F64 if needed in the future.
return type.isBF16() || type.isF16() || type.isF32();
}
};

struct TosaFoldConstantBitwiseAnd
: public TosaFoldConstantBinary<TosaFoldConstantBitwiseAnd, BitwiseAndOp> {
using TosaFoldConstantBinary<TosaFoldConstantBitwiseAnd,
BitwiseAndOp>::TosaFoldConstantBinary;

DenseElementsAttr computeInteger(DenseElementsAttr lhsValues,
DenseElementsAttr rhsValues,
PatternRewriter &rewriter,
BitwiseAndOp op) const {
return applyElementWise<APInt, APInt>(
lhsValues, rhsValues, op.getType(),
[](const APInt &lhs, const APInt &rhs) { return lhs & rhs; });
}
};

struct TosaFoldConstantBitwiseOr
: public TosaFoldConstantBinary<TosaFoldConstantBitwiseOr, BitwiseOrOp> {
using TosaFoldConstantBinary<TosaFoldConstantBitwiseOr,
BitwiseOrOp>::TosaFoldConstantBinary;

DenseElementsAttr computeInteger(DenseElementsAttr lhsValues,
DenseElementsAttr rhsValues,
PatternRewriter &rewriter,
BitwiseOrOp op) const {
return applyElementWise<APInt, APInt>(
lhsValues, rhsValues, op.getType(),
[](const APInt &lhs, const APInt &rhs) { return lhs | rhs; });
}
};

struct TosaFoldConstantGreaterEqual
: public TosaFoldConstantBinary<TosaFoldConstantGreaterEqual,
GreaterEqualOp> {
using TosaFoldConstantBinary<TosaFoldConstantGreaterEqual,
GreaterEqualOp>::TosaFoldConstantBinary;

DenseElementsAttr computeInteger(DenseElementsAttr lhsValues,
DenseElementsAttr rhsValues,
PatternRewriter &rewriter,
GreaterEqualOp op) const {
return applyElementWise<APInt, APInt>(
lhsValues, rhsValues, op.getType(),
[](const APInt &first, const APInt &second) {
return APInt(1, first.sge(second));
});
}

DenseElementsAttr computeFloat(DenseElementsAttr lhsValues,
DenseElementsAttr rhsValues,
PatternRewriter &rewriter,
GreaterEqualOp op) const {
return applyElementWise<APFloat, APInt>(
lhsValues, rhsValues, op.getType(),
[](const APFloat &first, const APFloat &second) {
if (first.isNaN() || second.isNaN())
return APInt(1, false);
return APInt(1, first >= second);
});
}
};

struct TosaFoldConstantEqual
: public TosaFoldConstantBinary<TosaFoldConstantEqual, EqualOp> {
using TosaFoldConstantBinary<TosaFoldConstantEqual,
EqualOp>::TosaFoldConstantBinary;

DenseElementsAttr computeInteger(DenseElementsAttr lhsValues,
DenseElementsAttr rhsValues,
PatternRewriter &rewriter,
EqualOp op) const {
return applyElementWise<APInt, APInt>(
lhsValues, rhsValues, op.getType(),
[](const APInt &first, const APInt &second) {
return APInt(1, first.eq(second));
});
}

DenseElementsAttr computeFloat(DenseElementsAttr lhsValues,
DenseElementsAttr rhsValues,
PatternRewriter &rewriter, EqualOp op) const {
return applyElementWise<APFloat, APInt>(
lhsValues, rhsValues, op.getType(),
[](const APFloat &first, const APFloat &second) {
return APInt(1, first == second);
});
}
};

struct TosaFoldConstantMinimum
: public TosaFoldConstantBinary<TosaFoldConstantMinimum, MinimumOp> {
using TosaFoldConstantBinary<TosaFoldConstantMinimum,
MinimumOp>::TosaFoldConstantBinary;

DenseElementsAttr computeInteger(DenseElementsAttr lhsValues,
DenseElementsAttr rhsValues,
PatternRewriter &rewriter,
MinimumOp op) const {
return applyElementWise<APInt, APInt>(
lhsValues, rhsValues, op.getType(),
[](const APInt &first, const APInt &second) {
return first.slt(second) ? first : second;
});
}

DenseElementsAttr computeFloat(DenseElementsAttr lhsValues,
DenseElementsAttr rhsValues,
PatternRewriter &rewriter,
MinimumOp op) const {
return applyElementWise<APFloat, APFloat>(
lhsValues, rhsValues, op.getType(),
[](const APFloat &first, const APFloat &second) {
if (first.isNaN() || second.isNaN())
return first.isNaN() ? first : second;
return first < second ? first : second;
});
}
};

struct TosaFoldConstantMaximum
: public TosaFoldConstantBinary<TosaFoldConstantMaximum, MaximumOp> {
using TosaFoldConstantBinary<TosaFoldConstantMaximum,
MaximumOp>::TosaFoldConstantBinary;

DenseElementsAttr computeInteger(DenseElementsAttr lhsValues,
DenseElementsAttr rhsValues,
PatternRewriter &rewriter,
MaximumOp op) const {
return applyElementWise<APInt, APInt>(
lhsValues, rhsValues, op.getType(),
[](const APInt &first, const APInt &second) {
return first.sgt(second) ? first : second;
});
}

DenseElementsAttr computeFloat(DenseElementsAttr lhsValues,
DenseElementsAttr rhsValues,
PatternRewriter &rewriter,
MaximumOp op) const {
return applyElementWise<APFloat, APFloat>(
lhsValues, rhsValues, op.getType(),
[](const APFloat &first, const APFloat &second) {
if (first.isNaN() || second.isNaN())
return first.isNaN() ? first : second;
return first > second ? first : second;
});
}
};

template <typename BaseType>
DenseElementsAttr padType(ShapedType inputType, ElementsAttr inputValues,
DenseElementsAttr paddings,
std::optional<DenseElementsAttr> padConstValue,
ShapedType outputType, BaseType zero) {
BaseType padConst(zero);
if (padConstValue.has_value())
padConst = padConstValue.value().getSplatValue<BaseType>();

auto values = inputValues.getValues<BaseType>();
auto paddingVals = paddings.getValues<int64_t>();

auto outputShape = outputType.getShape();
auto inputShape = inputType.getShape();

// Implements the logic from
// https://www.mlplatform.org/tosa/tosa_spec.html#_pad
SmallVector<BaseType> outputValues(outputType.getNumElements(), padConst);
for (size_t outIndex = 0, e = outputValues.size(); outIndex < e; ++outIndex) {
auto indexInTarget = offsetToIndex(outputShape, outIndex);

bool isPad =
llvm::any_of(llvm::enumerate(indexInTarget), [&](const auto &dimInfo) {
auto index = dimInfo.index();
auto i = dimInfo.value() - paddingVals[index * 2];
return static_cast<bool>(i < 0 || i >= inputShape[index]);
});

auto inputIndexOffset = indexToOffset(outputShape, indexInTarget);
outputValues[outIndex] = isPad ? padConst : values[inputIndexOffset];
}
return DenseElementsAttr::get(outputType,
llvm::ArrayRef<BaseType>(outputValues));
}

DenseElementsAttr pad(ShapedType inputType, ElementsAttr inputValues,
DenseElementsAttr paddings,
std::optional<DenseElementsAttr> padConstValue,
ShapedType outputType) {

auto baseType = inputType.getElementType();

// Handle integer types with APInt
if (auto intType = dyn_cast<IntegerType>(baseType))
return padType<APInt>(inputType, inputValues, paddings, padConstValue,
outputType,
APInt(baseType.getIntOrFloatBitWidth(), 0));

assert(isa<FloatType>(baseType) && "Unknown element type.");
FloatType fpType = cast<FloatType>(baseType);

// Handle FP types with APFloat
APFloat zero(fpType.getFloatSemantics(), APInt::getZero(fpType.getWidth()));
return padType<APFloat>(inputType, inputValues, paddings, padConstValue,
outputType, zero);
}

struct TosaFoldConstantPad : public TosaFoldConstantBase<tosa::PadOp> {
using TosaFoldConstantBase::TosaFoldConstantBase;

LogicalResult matchAndRewrite(tosa::PadOp op,
PatternRewriter &rewriter) const override {
auto outputType = cast<ShapedType>(op.getType());
// TOSA doesn't support quantized types.
if (!outputType.getElementType().isIntOrIndexOrFloat())
return failure();

auto input = op.getInput1();
ElementsAttr inputValues;
if (!matchPattern(input, m_Constant(&inputValues)))
return failure();

// Only fold op with multiple users if foldSplatOrSingleUseOnly == true.
if (!llvm::hasSingleElement(input.getDefiningOp()->getUsers()) &&
foldSplatOrSingleUseOnly)
return failure();

std::optional<DenseElementsAttr> padConstValue;
if (op.getPadConst()) {
DenseElementsAttr attr;
if (!matchPattern(op.getPadConst(), m_Constant(&attr)))
return failure();
padConstValue = attr;
}

DenseElementsAttr paddings;
if (!matchPattern(op.getPadding(), m_Constant(&paddings)))
return failure();

auto resultAttr =
pad(input.getType(), inputValues, paddings, padConstValue, outputType);
rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputType, resultAttr);

return success();
}
};

} // namespace

void mlir::tosa::populateTosaFoldConstantPatterns(
Expand All @@ -1113,4 +1386,12 @@ void mlir::tosa::populateTosaFoldConstantPatterns(
patterns.add<TosaFoldConstantBitwiseNot>(ctx, foldSplatOrSingleUseOnly);
patterns.add<TosaFoldConstantCeil>(ctx, foldSplatOrSingleUseOnly);
patterns.add<TosaFoldConstantErf>(ctx, foldSplatOrSingleUseOnly);
patterns.add<TosaFoldConstantLog>(ctx, foldSplatOrSingleUseOnly);
patterns.add<TosaFoldConstantBitwiseAnd>(ctx, foldSplatOrSingleUseOnly);
patterns.add<TosaFoldConstantBitwiseOr>(ctx, foldSplatOrSingleUseOnly);
patterns.add<TosaFoldConstantGreaterEqual>(ctx, foldSplatOrSingleUseOnly);
patterns.add<TosaFoldConstantEqual>(ctx, foldSplatOrSingleUseOnly);
patterns.add<TosaFoldConstantMinimum>(ctx, foldSplatOrSingleUseOnly);
patterns.add<TosaFoldConstantMaximum>(ctx, foldSplatOrSingleUseOnly);
patterns.add<TosaFoldConstantPad>(ctx, foldSplatOrSingleUseOnly);
}
Loading

0 comments on commit 26a6043

Please sign in to comment.