Skip to content

Commit

Permalink
[MLIR][TOSA] add additional verification to TOSA (llvm#108133)
Browse files Browse the repository at this point in the history
----------
Motivation:
----------

Spec conformance. Allows assumptions to be made in TOSA code.

------------
Changes Made:
------------

Add full permutation tensor verification to tosa.TRANSPOSE. Priorly
would not verify that permuted values were between 0 - (rank - 1).

Update tosa.TRANSPOSE perms data type to be strictly i32.

Verify input/output shapes for tosa.TRANSPOSE.

Add verifier to tosa.CONST, with consideration for quantization.

Fix TOSA conformance of tensor type to disallow dimensions with size 0
for ranked tensors, per spec.
This is not the same as rank 0 tensors. Here is an example of a
disallowed tensor: tensor<3x0xi32>. Naturally, this means that the
number of elements in a TOSA tensor will always be greater than 0.

Signed-off-by: Arteen Abrishami <arteen.abrishami@arm.com>
  • Loading branch information
arteen1000 authored Sep 11, 2024
1 parent e55d6f5 commit a54efdb
Show file tree
Hide file tree
Showing 12 changed files with 301 additions and 161 deletions.
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ add_mlir_doc(TosaOps TosaOps Dialects/ -gen-op-doc)
add_mlir_interface(TosaInterfaces)

set(LLVM_TARGET_DEFINITIONS TosaOps.td)
mlir_tablegen(TosaAttributes.h.inc -gen-attrdef-decls)
mlir_tablegen(TosaAttributes.cpp.inc -gen-attrdef-defs)
mlir_tablegen(TosaAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=tosa)
mlir_tablegen(TosaAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=tosa)
add_public_tablegen_target(MLIRTosaAttributesIncGen)

set(LLVM_TARGET_DEFINITIONS TosaDialectBytecode.td)
Expand Down
58 changes: 27 additions & 31 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {

let arguments = (ins
Tosa_Tensor4D:$input,

Tosa_IntArrayAttr2:$kernel,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr4:$pad,
Expand Down Expand Up @@ -102,9 +101,8 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {

let arguments = (ins
Tosa_Tensor4D:$input,
4DTensorOf<[Tosa_Weight]>:$weight,
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
Tosa_Tensor1D:$bias,

Tosa_IntArrayAttr4:$pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr2:$dilation,
Expand Down Expand Up @@ -132,9 +130,8 @@ def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> {

let arguments = (ins
Tosa_Tensor5D:$input,
TensorRankOf<[Tosa_Weight], [5]>:$weight,
TosaTensorRankOf<[Tosa_Weight], [5]>:$weight,
Tosa_Tensor1D:$bias,

Tosa_IntArrayAttr6:$pad,
Tosa_IntArrayAttr3:$stride,
Tosa_IntArrayAttr3:$dilation,
Expand Down Expand Up @@ -163,9 +160,8 @@ def Tosa_DepthwiseConv2DOp : Tosa_InferShapedTypeOp<"depthwise_conv2d"> {

let arguments = (ins
Tosa_Tensor4D:$input,
4DTensorOf<[Tosa_Weight]>:$weight,
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
Tosa_Tensor1D:$bias,

Tosa_IntArrayAttr4:$pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr2:$dilation,
Expand Down Expand Up @@ -232,7 +228,7 @@ def Tosa_FullyConnectedOp : Tosa_InferShapedTypeOp<"fully_connected"> {

let arguments = (ins
Tosa_Tensor2D:$input,
2DTensorOf<[Tosa_Weight]>:$weight,
TosaTensorRankOf<[Tosa_Weight], [2]>:$weight,
Tosa_Tensor1D:$bias,
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info
);
Expand Down Expand Up @@ -347,9 +343,8 @@ def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> {

let arguments = (ins
Tosa_Tensor4D:$input,
4DTensorOf<[Tosa_Weight]>:$filter,
TosaTensorRankOf<[Tosa_Weight], [4]>:$filter,
Tosa_Tensor1D:$bias,

Tosa_IntArrayAttr4:$out_pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttrUpto4:$out_shape,
Expand Down Expand Up @@ -641,12 +636,12 @@ def Tosa_LogicalAndOp : Tosa_ElementwiseOp<"logical_and", [
}];

let arguments = (ins
I1Tensor:$input1,
I1Tensor:$input2
Tosa_I1Tensor:$input1,
Tosa_I1Tensor:$input2
);

let results = (outs
I1Tensor:$z
Tosa_I1Tensor:$z
);
}

Expand Down Expand Up @@ -708,12 +703,12 @@ def Tosa_LogicalOrOp : Tosa_ElementwiseOp<"logical_or", [
}];

let arguments = (ins
I1Tensor:$input1,
I1Tensor:$input2
Tosa_I1Tensor:$input1,
Tosa_I1Tensor:$input2
);

let results = (outs
I1Tensor:$z
Tosa_I1Tensor:$z
);
}

Expand All @@ -731,12 +726,12 @@ def Tosa_LogicalXorOp : Tosa_ElementwiseOp<"logical_xor", [
}];

let arguments = (ins
I1Tensor:$input1,
I1Tensor:$input2
Tosa_I1Tensor:$input1,
Tosa_I1Tensor:$input2
);

let results = (outs
I1Tensor:$z
Tosa_I1Tensor:$z
);
}

Expand Down Expand Up @@ -1085,11 +1080,11 @@ def Tosa_LogicalNotOp : Tosa_ElementwiseOp<"logical_not",
}];

let arguments = (ins
I1Tensor:$input1
Tosa_I1Tensor:$input1
);

let results = (outs
I1Tensor:$output
Tosa_I1Tensor:$output
);
}

Expand Down Expand Up @@ -1208,7 +1203,7 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
}];

let arguments = (ins
I1Tensor:$pred,
Tosa_I1Tensor:$pred,
Tosa_Tensor:$on_true,
Tosa_Tensor:$on_false
);
Expand Down Expand Up @@ -1249,7 +1244,7 @@ def Tosa_EqualOp : Tosa_ElementwiseOp<"equal", [
);

let results = (outs
I1Tensor:$output
Tosa_I1Tensor:$output
);

let extraClassDeclaration = [{
Expand Down Expand Up @@ -1277,7 +1272,7 @@ def Tosa_GreaterOp : Tosa_ElementwiseOp<"greater", [SameOperandsElementType]> {
);

let results = (outs
I1Tensor:$output
Tosa_I1Tensor:$output
);

let hasFolder = 1;
Expand All @@ -1300,7 +1295,7 @@ def Tosa_GreaterEqualOp : Tosa_ElementwiseOp<"greater_equal",
);

let results = (outs
I1Tensor:$output
Tosa_I1Tensor:$output
);

let hasFolder = 1;
Expand Down Expand Up @@ -1721,15 +1716,15 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",

let arguments = (ins
Tosa_Tensor:$input1,
Tosa_Int32Or64Tensor:$perms
Tosa_Int32Tensor:$perms
);

let results = (
outs Tosa_Tensor:$output
);

let extraClassDeclaration = [{
LogicalResult getConstantPerms(llvm::SmallVector<int64_t> &perms);
LogicalResult getConstantPerms(llvm::SmallVector<int32_t> &perms);
}];

let hasCanonicalizer = 1;
Expand All @@ -1755,7 +1750,7 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather"> {

let arguments = (ins
Tosa_Tensor3D:$values,
2DTensorOf<[Tosa_Int32]>:$indices
TosaTensorRankOf<[Tosa_Int32], [2]>:$indices
);

let results = (outs
Expand All @@ -1776,7 +1771,7 @@ def Tosa_ScatterOp : Tosa_InferShapedTypeOp<"scatter"> {

let arguments = (ins
Tosa_Tensor3D:$values_in,
2DTensorOf<[Tosa_Int32]>:$indices,
TosaTensorRankOf<[Tosa_Int32], [2]>:$indices,
Tosa_Tensor3D:$input
);

Expand Down Expand Up @@ -1947,10 +1942,11 @@ def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure,
);

let results = (outs
TensorOf<[AnyTypeOf<[Tosa_AnyNumber]>]>:$output
TosaTensorOf<[AnyTypeOf<[Tosa_AnyNumber]>]>:$output
);

let hasFolder = 1;
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2054,7 +2050,7 @@ def Tosa_IfOp : Tosa_Op<"cond_if",
}];

let arguments = (ins
I1Tensor:$cond,
Tosa_I1Tensor:$cond,
Variadic<Tosa_Tensor>:$inputs
);

Expand Down
61 changes: 43 additions & 18 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -82,58 +82,83 @@ def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
Tosa_QuantizedInt, AnyFloat]>;

//===----------------------------------------------------------------------===//
// TOSA Tensor Conformance
//===----------------------------------------------------------------------===//

def HasNo0Dimensions : And<[
IsRankedTensorTypePred,
CPred<"::llvm::all_of(::llvm::cast<::mlir::RankedTensorType>($_self).getShape(), [](auto v) { return v != 0; })">]>;

class TosaTensorOf<
list<Type> allowedTypes, string summary = "tosa-conformant tensor">
: TensorOf<allowedTypes, [Or<[HasNo0Dimensions, IsUnrankedTensorTypePred]>], summary>;

class TosaRankedTensorOf<
list<Type> allowedTypes, list<Pred> preds = [], string summary = "tosa-conformant ranked tensor">
: RankedTensorOf<allowedTypes, !listconcat([HasNo0Dimensions], preds), summary>;

class TosaUnrankedTensorOf<list<Type> allowedTypes, list<Pred> preds = [], string summary = "tosa-conformant unranked tensor">
: UnrankedTensorOf<allowedTypes, preds, summary>;

class TosaTensorRankOf<list<Type> allowedTypes, list<int> ranks>
: TosaRankedTensorOf<allowedTypes,
[HasAnyRankOfPred<ranks>],
!interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensor">;

//===----------------------------------------------------------------------===//
// Tensor types
//===----------------------------------------------------------------------===//

def Tosa_Int32Tensor : TensorOf<[Tosa_Int32]>;
def Tosa_Int32Or64Tensor : TensorOf<[Tosa_Int32Or64]>;
def Tosa_I1Tensor : TosaTensorOf<[I1]>;
def Tosa_Int32Tensor : TosaTensorOf<[Tosa_Int32]>;
def Tosa_Int32Or64Tensor :TosaTensorOf<[Tosa_Int32Or64]>;

def Tosa_FloatTensor : TensorOf<[AnyFloat]>;
def Tosa_FloatTensor : TosaTensorOf<[AnyFloat]>;

// Either ranked or unranked tensor of TOSA supported element types.
def Tosa_Tensor : TensorOf<[Tosa_AnyNumber]>;
def Tosa_Tensor : TosaTensorOf<[Tosa_AnyNumber]>;

// Must be ranked but no further constraints
def Tosa_RankedTensor : RankedTensorOf<[Tosa_AnyNumber]>;
def Tosa_RankedTensor : TosaRankedTensorOf<[Tosa_AnyNumber]>;

// Any tensor element type allowed in Tosa ops.
def Tosa_ElementType : Type<Or<[Tosa_Int.predicate, Tosa_QuantizedInt.predicate,
AnyFloat.predicate]>, "tosa.dtype">;

class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> :
AnyTypeOf<[TensorOf<allowedTypes>, NoneType], description>;
AnyTypeOf<[TosaTensorOf<allowedTypes>, NoneType], description>;

//===----------------------------------------------------------------------===//
// Tensor types with constrained ranks.
//===----------------------------------------------------------------------===//

// Rank-0 (scalar) tensor
def Tosa_ScalarTensor : TensorRankOf<[Tosa_AnyNumber], [0]>;
def Tosa_ScalarTensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;

// We include unranked tensors as a supported type for all possible tosa
// Tensors as unranked does not guarantee invalid. If unranked tensors exist
// they should be shape propagate used Tosa's shape inference pass and verified
// to not include any remaining unranked tensors.
def Tosa_UnrankedTensor : UnrankedTensorOf<[Tosa_AnyNumber]>;
def Tosa_UnrankedTensor : TosaUnrankedTensorOf<[Tosa_AnyNumber]>;

def Tosa_Tensor1D : AnyTypeOf<[Tosa_UnrankedTensor, 1DTensorOf<[Tosa_AnyNumber]>], "1-d tensor", "::mlir::TensorType">;
def Tosa_Tensor2D : AnyTypeOf<[Tosa_UnrankedTensor, 2DTensorOf<[Tosa_AnyNumber]>], "2-d tensor", "::mlir::TensorType">;
def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, 3DTensorOf<[Tosa_AnyNumber]>], "3-d tensor", "::mlir::TensorType">;
def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, 4DTensorOf<[Tosa_AnyNumber]>], "4-d tensor", "::mlir::TensorType">;
def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [5]>], "5-d tensor", "::mlir::TensorType">;
def Tosa_Tensor1D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1]>], "1-d tosa-conformant tensor", "::mlir::TensorType">;
def Tosa_Tensor2D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [2]>], "2-d tosa-conformant tensor", "::mlir::TensorType">;
def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [3]>], "3-d tosa-conformant tensor", "::mlir::TensorType">;
def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [4]>], "4-d tosa-conformant tensor", "::mlir::TensorType">;
def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [5]>], "5-d tosa-conformant tensor", "::mlir::TensorType">;

// Ranked tensors up to given rank.
def Tosa_Tensor1Dto4D : AnyTypeOf<[
Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>]>;
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>]>;
def Tosa_Tensor1Dto6D : AnyTypeOf<[
Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [1,2,3,4,5,6]>]>;
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1,2,3,4,5,6]>]>;

def Tosa_TensorUpto4D : AnyTypeOf<[
Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4]>]>;
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4]>]>;

def Tosa_Int32TensorUpto4D : AnyTypeOf<[
Tosa_UnrankedTensor, TensorRankOf<[Tosa_Int32], [0,1,2,3,4]>]>;
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32], [0,1,2,3,4]>]>;

//===----------------------------------------------------------------------===//
// Generic scalar, vector, or tensor of a particular type.
Expand All @@ -142,7 +167,7 @@ def Tosa_Int32TensorUpto4D : AnyTypeOf<[
class Tosa_TypeLike<list<Type> types, string description = ""> : TypeConstraint<Or<[
AnyTypeOf<types>.predicate,
VectorOf<types>.predicate,
TensorOf<types>.predicate]>,
TosaTensorOf<types>.predicate]>,
description>;

def Tosa_IntLike : Tosa_TypeLike<[Tosa_Int], "signless-integer-like">;
Expand Down
13 changes: 13 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,19 @@ TosaOp CreateOpAndInferShape(PatternRewriter &rewriter, Location loc,
return CreateOpAndInferShape<TosaOp>(builder, resultTy, args...);
}

// Apply an int32_t permutation to some input, that should be of the same
// size as perms. Perms should contain some permutation of 0 - perms.size() - 1.
template <typename T>
SmallVector<T> applyTOSAPermutation(ArrayRef<T> input,
ArrayRef<int32_t> perms) {
SmallVector<T> permuted;
size_t N = input.size();
permuted.resize_for_overwrite(N);
for (size_t i = 0; i < N; i++)
permuted[i] = input[perms[i]];
return permuted;
}

} // namespace tosa
} // namespace mlir

Expand Down
Loading

0 comments on commit a54efdb

Please sign in to comment.