feat: add verifiers for concat, pad and sigmoid
sayeddla committed Jan 10, 2024
1 parent 83d5a86 commit ff218c1
Showing 3 changed files with 236 additions and 38 deletions.
15 changes: 12 additions & 3 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -422,6 +422,8 @@ def Tosa_SigmoidOp : Tosa_ElemWiseUnaryOp<"sigmoid"> {
let results = (outs
Tosa_Tensor:$output
);

let hasVerifier = 1;
}
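Setting hasVerifier = 1 in ODS declares a LogicalResult verify() member on the generated op class; the hand-written implementations for sigmoid, concat, and pad appear in the TosaOps.cpp hunks below.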

//===----------------------------------------------------------------------===//
@@ -1423,15 +1425,14 @@ def Tosa_ConcatOp : Tosa_Op<"concat", [
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
}];
let hasFolder = 1;
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// Operator: pad
//===----------------------------------------------------------------------===//
def Tosa_PadOp : Tosa_Op<"pad", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+    InferTensorType, Pure]> {
let summary = "Pads a tensor with value specified.";

let description = [{
@@ -1470,6 +1471,14 @@ def Tosa_PadOp : Tosa_Op<"pad", [

let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;

let extraClassDeclaration = [{
/// Returns true when two result types are compatible for this op;
/// Method used by InferTypeOpInterface.
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
}];

}

//===----------------------------------------------------------------------===//
139 changes: 107 additions & 32 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -100,7 +100,8 @@ Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
// TOSA Operator Verifiers.
//===----------------------------------------------------------------------===//

-template <typename T> static LogicalResult verifyConvOp(T op) {
+template <typename T>
+static LogicalResult verifyConvOp(T op) {
// All TOSA conv ops have an input() and weight().
auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
@@ -503,6 +504,41 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
return success();
}

LogicalResult ConcatOp::verify() {
OperandRange inputs = getInput1();

auto inputRank = ShapedType::kDynamic;
  bool hasRankedInputs = false;
for (auto input : inputs) {
auto inputType = llvm::cast<ShapedType>(input.getType());
if (inputType.hasRank()) {
hasRankedInputs = true;
inputRank = inputType.getRank();
break;
}
}

if (hasRankedInputs) {
int64_t axis = getAxis();
if (axis < 0 || axis >= std::max((int64_t)1, inputRank)) {
return emitOpError() << "axis must be in range 0 to " << inputRank - 1;
}

for (auto input : inputs) {
auto inputType = llvm::cast<ShapedType>(input.getType());
if (!inputType.hasRank()) {
continue;
}
if (inputRank != inputType.getRank()) {
return emitOpError()
<< "rank of input " << inputType
<< " does not match other input rank(s) (" << inputRank << ")";
}
}
}
return success();
}
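As an illustration, here is a minimal sketch of IR this verifier accepts and rejects, written in the generic TOSA assembly form (these lines are not taken from the commit's tests; the value names and the i32 axis attribute type are assumptions):

// Rejected: axis 2 is out of range for rank-2 inputs ("axis must be in range 0 to 1").
%0 = "tosa.concat"(%a, %b) {axis = 2 : i32} : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x6xf32>

// Rejected: the rank-3 input does not match the rank-2 input.
%1 = "tosa.concat"(%a, %c) {axis = 0 : i32} : (tensor<2x3xf32>, tensor<1x2x3xf32>) -> tensor<3x3xf32>

// Accepted: ranks match and the axis is in range.
%2 = "tosa.concat"(%a, %b) {axis = 1 : i32} : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x6xf32>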

LogicalResult tosa::EqualOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes,
@@ -590,6 +626,7 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
Type inputType = getElementTypeOrSelf(operands[0]);
ShapeAdaptor inputShape = operands.getShape(0);
ShapeAdaptor paddingShape = operands.getShape(1);
SmallVector<int64_t> outputShape;
@@ -610,15 +647,17 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
}

outputShape.resize(paddingShape.getDimSize(0), ShapedType::kDynamic);
-    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+    inferredReturnShapes.push_back(
+        ShapedTypeComponents(outputShape, inputType));
return success();
}

DenseIntElementsAttr paddings;
// If the paddings value is not a constant, all dimensions must be dynamic.
if (!matchPattern(operands[1], m_Constant(&paddings))) {
outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
-    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+    inferredReturnShapes.push_back(
+        ShapedTypeComponents(outputShape, inputType));
return success();
}

@@ -638,7 +677,35 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
paddingValues[i * 2 + 1]);
}

-  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
return success();
}
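Threading inputType through the ShapedTypeComponents above means shape inference for pad now produces an element type along with the shape, which InferTensorType needs to build a complete tensor type. A sketch of the effect, with illustrative values not taken from the commit's tests:

%padding = "tosa.const"() {value = dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
// For a tensor<1x2xf32> input, inference now yields the full type
// tensor<4x9xf32> (1+1+2 = 4 rows, 2+3+4 = 9 columns), not just the shape [4, 9].
%0 = "tosa.pad"(%input, %padding) : (tensor<1x2xf32>, tensor<2x2xi32>) -> tensor<4x9xf32>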

LogicalResult PadOp::verify() {
ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
ShapedType paddingType = llvm::cast<ShapedType>(getPadding().getType());
if (paddingType.hasRank()) {
if (paddingType.getRank() != 2) {
return emitOpError() << "paddings must be a tensor of rank 2";
}
if (inputType.hasRank() && !paddingType.isDynamicDim(0) &&
inputType.getRank() != paddingType.getDimSize(0)) {
return emitOpError() << "paddings must be a tensor of shape ["
<< inputType.getRank() << ", 2]";
}
    if (!paddingType.isDynamicDim(1) && paddingType.getDimSize(1) != 2) {
      // Don't query the input rank for the diagnostic here: the input may be
      // unranked even when the paddings shape is known.
      return emitOpError() << "paddings must be a tensor of shape [rank(input1), 2]";
    }

DenseIntElementsAttr paddings;
if (matchPattern(getPadding(), m_Constant(&paddings))) {
if (llvm::any_of(paddings,
[](auto val) { return val.getSExtValue() < 0; })) {
return emitOpError() << "number of pad elements must be positive";
}
}
}
return success();
}
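A sketch of operations the new pad verifier rejects, again in generic assembly form with assumed value names (not from the commit's tests):

// Rejected: paddings must have rank 2.
%0 = "tosa.pad"(%in, %p1) : (tensor<1x2xf32>, tensor<4xi32>) -> tensor<?x?xf32>

// Rejected: a rank-2 input requires paddings of shape [2, 2], not [3, 2].
%1 = "tosa.pad"(%in, %p2) : (tensor<1x2xf32>, tensor<3x2xi32>) -> tensor<?x?xf32>

// Rejected: pad amounts must be non-negative.
%neg = "tosa.const"() {value = dense<[[-1, 0], [0, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
%2 = "tosa.pad"(%in, %neg) : (tensor<1x2xf32>, tensor<2x2xi32>) -> tensor<?x?xf32>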

@@ -767,18 +834,18 @@ mlir::LogicalResult tosa::ReshapeOp::verify() {
}

if ((int64_t)getNewShape().size() != outputType.getRank()) {
return emitOpError() << "rank of newShape (" << getNewShape().size()
<< ") and output ("
<< outputType.getRank()
return emitOpError() << "rank of newShape (" << getNewShape().size()
<< ") and output (" << outputType.getRank()
<< ") must match";
}

-  for (int64_t dim=0; dim < outputType.getRank(); ++dim) {
-    if (getNewShape()[dim] != -1 && getNewShape()[dim] != outputType.getShape()[dim]) {
-      return emitOpError() << "newShape attribute (" << getNewShape()[dim]
-                           << ") does not match output type ("
-                           << outputType.getShape()[dim]
-                           << ") in dimension " << dim;
+  for (int64_t dim = 0; dim < outputType.getRank(); ++dim) {
+    if (getNewShape()[dim] != -1 &&
+        getNewShape()[dim] != outputType.getShape()[dim]) {
+      return emitOpError()
+          << "newShape attribute (" << getNewShape()[dim]
+          << ") does not match output type (" << outputType.getShape()[dim]
+          << ") in dimension " << dim;
}
}
}
Expand All @@ -792,38 +859,34 @@ mlir::LogicalResult tosa::SliceOp::verify() {

if (inputType.getRank() != outputType.getRank()) {
return emitOpError() << "rank of input (" << inputType.getRank()
<< ") and output ("
<< outputType.getRank()
<< ") must match";
<< ") and output (" << outputType.getRank()
<< ") must match";
}

if ((int64_t)getSize().size() != outputType.getRank()) {
return emitOpError() << "rank of size (" << getSize().size()
<< ") and output ("
<< outputType.getRank()
<< ") must match";
return emitOpError() << "rank of size (" << getSize().size()
<< ") and output (" << outputType.getRank()
<< ") must match";
}
-  for (int64_t dim=0; dim < outputType.getRank(); ++dim) {
-    if (getSize()[dim] != -1 && !outputType.isDynamicDim(dim) &&
-      getSize()[dim] != outputType.getShape()[dim]) {
+  for (int64_t dim = 0; dim < outputType.getRank(); ++dim) {
+    if (getSize()[dim] != -1 && !outputType.isDynamicDim(dim) &&
+        getSize()[dim] != outputType.getShape()[dim]) {
return emitOpError() << "size attribute (" << getSize()[dim]
<< ") does not match output type ("
<< outputType.getShape()[dim] << ") in dimension "
<< dim;
}
}
}

if ((int64_t)getStart().size() != inputType.getRank()) {
return emitOpError() << "rank of start (" << getStart().size()
<< ") and input ("
<< inputType.getRank()
<< ") must match";
return emitOpError() << "rank of start (" << getStart().size()
<< ") and input (" << inputType.getRank()
<< ") must match";
}
if ((int64_t)getSize().size() != inputType.getRank()) {
return emitOpError() << "rank of size (" << getSize().size()
<< ") and input ("
<< inputType.getRank()
<< ") must match";
return emitOpError() << "rank of size (" << getSize().size()
<< ") and input (" << inputType.getRank()
<< ") must match";
}

for (int i = 0; i < outputType.getRank(); ++i) {
@@ -1069,6 +1132,7 @@ REDUCE_SHAPE_INFER(tosa::ReduceProdOp)
REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
#undef REDUCE_SHAPE_INFER
COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
COMPATIBLE_RETURN_TYPES(tosa::PadOp)
#undef COMPATIBLE_RETURN_TYPES
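Registering PadOp in this macro supplies the isCompatibleReturnTypes hook that InferTensorType consults when comparing the declared result type against the inferred one, so a less precise declared type stays legal as long as the shapes are compatible and the element types agree. For example (illustrative, reusing the [[1, 2], [3, 4]] constant paddings from the sketch above):

// Accepted: the declared tensor<?x9xf32> is compatible with the inferred tensor<4x9xf32>.
%0 = "tosa.pad"(%input, %padding) : (tensor<1x2xf32>, tensor<2x2xi32>) -> tensor<?x9xf32>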

static LogicalResult NAryInferReturnTypes(
@@ -1561,6 +1625,17 @@ LogicalResult WhileOp::inferReturnTypeComponents(
return success();
}

LogicalResult SigmoidOp::verify() {
auto inputType = llvm::cast<ShapedType>(getInput().getType());
auto outputType = llvm::cast<ShapedType>(getOutput().getType());
  auto result = verifyCompatibleShape(inputType, outputType);
if (result.failed()) {
return emitOpError() << "input type " << inputType << " and output type "
<< outputType << " are not compatible";
}
return success();
}
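A sketch of what the sigmoid verifier checks, with assumed value names (not from the commit's tests):

// Accepted: a dynamic dimension is compatible with a static one.
%0 = "tosa.sigmoid"(%a) : (tensor<?x3xf32>) -> tensor<2x3xf32>

// Rejected: tensor<2x3xf32> and tensor<2x4xf32> have incompatible shapes.
%1 = "tosa.sigmoid"(%b) : (tensor<2x3xf32>) -> tensor<2x4xf32>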

std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
if (auto vt = llvm::dyn_cast<VectorType>(getType()))
return llvm::to_vector<4>(vt.getShape());