Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement constant folding for tosa.slice #388

Merged
merged 2 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 128 additions & 0 deletions mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1688,6 +1688,133 @@ struct TosaFoldConstantPad : public TosaFoldConstantBase<tosa::PadOp> {
}
};

/// Fills `outputValues` with the elements of `inputValues` selected by a
/// slice that begins at `startValues` and has the shape of `outputType`.
///
/// `outputValues` must already be sized to the number of output elements; the
/// loop overwrites every slot and never grows the vector. `inputValues` only
/// needs to support `operator[]` with a linear offset.
///
/// NOTE(review): `offsetToIndex` / `indexToOffset` are helpers defined
/// elsewhere in this file; they are presumed to convert between a linear
/// offset and a per-dimension index for the given shape - confirm there.
template <typename BaseType, typename RangeT>
void sliceArray(ShapedType inputType, RangeT inputValues,
                llvm::ArrayRef<int64_t> startValues, ShapedType outputType,
                SmallVector<BaseType> &outputValues) {

  const int64_t numDims = inputType.getRank();
  auto srcShape = inputType.getShape();
  auto dstShape = outputType.getShape();

  // Implements the logic from
  // https://www.mlplatform.org/tosa/tosa_spec.html#_slice
  const size_t numOutputs = outputValues.size();
  for (size_t dstOffset = 0; dstOffset < numOutputs; ++dstOffset) {
    // Position of this element inside the output tensor.
    auto coords = offsetToIndex(dstShape, dstOffset);

    // Shift by the slice start to obtain the position inside the input.
    for (int64_t dim = 0; dim < numDims; ++dim)
      coords[dim] += startValues[dim];

    outputValues[dstOffset] = inputValues[indexToOffset(srcShape, coords)];
  }
}

/// Slices `attr` by iterating its elements as `BaseType` (e.g. `bool`,
/// `APInt`, `APFloat`) and builds the result attribute from typed values.
///
/// \param attr       Constant holding the input elements.
/// \param inputType  Shaped type of the input.
/// \param start      Per-dimension start index of the slice.
/// \param outputType Shaped type of the slice result.
/// \returns A dense attribute of type `outputType` with the sliced elements.
template <typename BaseType>
DenseElementsAttr sliceType(ElementsAttr attr, ShapedType inputType,
                            llvm::ArrayRef<int64_t> start,
                            ShapedType outputType) {

  auto inputValues = attr.getValues<BaseType>();

  SmallVector<BaseType> outputValues;
  const int64_t numOutputElements = outputType.getNumElements();
  if (numOutputElements > 0) {
    // The vector is seeded with copies of the first input element rather than
    // value-initialised, so element types without a default constructor can be
    // used. Only do so when there is something to produce: for an empty
    // result the input may be empty as well, and dereferencing its begin
    // iterator would then read past the end.
    outputValues.assign(numOutputElements, *std::begin(inputValues));
    sliceArray<BaseType>(inputType, inputValues, start, outputType,
                         outputValues);
  }

  return DenseElementsAttr::get(outputType,
                                llvm::ArrayRef<BaseType>(outputValues));
}

/// Slices `attr` by working directly on its raw element storage, viewed as an
/// array of `BaseType`, and rebuilds the result from the raw output bytes.
///
/// `attr` must be a `DenseIntOrFPElementsAttr` (the `cast` asserts otherwise).
/// NOTE(review): `getNonSplatRawData` suggests the attribute must not be a
/// splat; the caller in this file handles splats before reaching here.
template <typename BaseType>
DenseElementsAttr sliceTypeRaw(ElementsAttr attr, ShapedType inputType,
                               llvm::ArrayRef<int64_t> start,
                               ShapedType outputType) {

  auto denseInput = cast<DenseIntOrFPElementsAttr>(attr);
  ArrayRef<BaseType> rawInput = denseInput.getNonSplatRawData<BaseType>();

  // Every slot is written by sliceArray, so skip value-initialisation.
  SmallVector<BaseType> sliced;
  sliced.resize_for_overwrite(outputType.getNumElements());
  sliceArray<BaseType>(inputType, rawInput, start, outputType, sliced);

  // Hand the result back as an untyped byte buffer.
  const char *bytes = reinterpret_cast<const char *>(sliced.data());
  const size_t numBytes = sliced.size() * sizeof(BaseType);
  return DenseElementsAttr::getFromRawBuffer(outputType,
                                             ArrayRef<char>(bytes, numBytes));
}

/// Computes the constant result of slicing `inputValues`.
///
/// Dispatches on the element type of `inputType` to either a raw-storage copy
/// (`sliceTypeRaw`) or a typed element-by-element copy (`sliceType`).
///
/// \param inputType   Shaped type of the input constant.
/// \param inputValues Elements of the input constant.
/// \param start       Per-dimension start index of the slice.
/// \param outputType  Shaped type of the slice result.
/// \returns A dense attribute of type `outputType` holding the sliced values.
DenseElementsAttr slice(ShapedType inputType, ElementsAttr inputValues,
                        llvm::ArrayRef<int64_t> start, ShapedType outputType) {

  auto baseType = inputType.getElementType();

  // Any slice of a splat is a splat of the same value.
  if (inputValues.isSplat()) {
    // Index elements are read as APInt like integers; they must not reach the
    // APFloat accessor below.
    if (baseType.isIntOrIndex())
      return DenseElementsAttr::get(outputType,
                                    inputValues.getSplatValue<APInt>());
    return DenseElementsAttr::get(outputType,
                                  inputValues.getSplatValue<APFloat>());
  }

  // Handle possible integer types
  if (auto intType = dyn_cast<IntegerType>(baseType)) {
    switch (intType.getWidth()) {
    case 1:
      // i1 has special alignment which is not handled by sliceTypeRaw.
      return sliceType<bool>(inputValues, inputType, start, outputType);
    case 8:
      return sliceTypeRaw<uint8_t>(inputValues, inputType, start, outputType);
    case 16:
      return sliceTypeRaw<uint16_t>(inputValues, inputType, start, outputType);
    case 32:
      return sliceTypeRaw<uint32_t>(inputValues, inputType, start, outputType);
    case 64:
      return sliceTypeRaw<uint64_t>(inputValues, inputType, start, outputType);
    default:
      return sliceType<APInt>(inputValues, inputType, start, outputType);
    }
  }

  // Index is not an IntegerType, but its elements are still integers; use the
  // generic APInt path instead of falling through to the float handling.
  if (baseType.isIndex())
    return sliceType<APInt>(inputValues, inputType, start, outputType);

  // Handle possible float types. The raw paths only move bits around, so the
  // unsigned integer type of matching width is used as the carrier.
  if (baseType.isF32()) {
    return sliceTypeRaw<uint32_t>(inputValues, inputType, start, outputType);
  }
  if (baseType.isF64()) {
    return sliceTypeRaw<uint64_t>(inputValues, inputType, start, outputType);
  }
  if (baseType.isBF16() || baseType.isF16()) {
    return sliceTypeRaw<uint16_t>(inputValues, inputType, start, outputType);
  }
  return sliceType<APFloat>(inputValues, inputType, start, outputType);
}

/// Folds a `tosa.slice` whose input is a constant into a `tosa.const` that
/// holds only the sliced elements.
///
/// The pattern does not apply (returns `failure()`) when:
///  - the element type is not an integer, index or float type;
///  - the input is not a constant;
///  - the input constant has more than one user and
///    `foldSplatOrSingleUseOnly` is set;
///  - the input or result type is not statically shaped;
///  - `start` together with the result shape does not fit inside the input.
struct TosaFoldConstantSlice : public TosaFoldConstantBase<tosa::SliceOp> {
  using TosaFoldConstantBase::TosaFoldConstantBase;

  LogicalResult matchAndRewrite(tosa::SliceOp op,
                                PatternRewriter &rewriter) const override {
    auto outputType = cast<ShapedType>(op.getType());
    // TOSA doesn't support quantized types.
    if (!outputType.getElementType().isIntOrIndexOrFloat())
      return failure();

    auto start = op.getStart();
    auto input = op.getInput();
    ElementsAttr inputValues;
    if (!matchPattern(input, m_Constant(&inputValues)))
      return failure();

    // Only fold op with multiple users if foldSplatOrSingleUseOnly is false.
    if (!llvm::hasSingleElement(input.getDefiningOp()->getUsers()) &&
        foldSplatOrSingleUseOnly)
      return failure();

    // Building the folded constant needs concrete element counts, so both
    // shapes have to be fully static (this also rejects unranked types).
    ShapedType inputType = input.getType();
    if (!inputType.hasStaticShape() || !outputType.hasStaticShape())
      return failure();

    // Make sure the requested window lies inside the input so that the
    // element copy can never read out of bounds, even on malformed IR.
    auto inputShape = inputType.getShape();
    auto outputShape = outputType.getShape();
    const size_t rank = inputShape.size();
    if (outputShape.size() != rank || start.size() != rank)
      return failure();
    for (size_t dim = 0; dim < rank; ++dim) {
      if (start[dim] < 0 || start[dim] > inputShape[dim] ||
          outputShape[dim] > inputShape[dim] - start[dim])
        return failure();
    }

    auto resultAttr = slice(inputType, inputValues, start, outputType);
    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputType, resultAttr);

    return success();
  }
};

template <typename BaseType, typename RangeT>
void tileArray(ShapedType inputType, RangeT inputValues, ShapedType outputType,
SmallVector<BaseType> &outputValues) {
Expand Down Expand Up @@ -1991,6 +2118,7 @@ void mlir::tosa::populateTosaFoldConstantPatterns(
patterns.add<TosaFoldConstantMinimum>(ctx, options.foldSplatOrSingleUseOnly);
patterns.add<TosaFoldConstantMaximum>(ctx, options.foldSplatOrSingleUseOnly);
patterns.add<TosaFoldConstantPad>(ctx, options.foldSplatOrSingleUseOnly);
patterns.add<TosaFoldConstantSlice>(ctx, options.foldSplatOrSingleUseOnly);
patterns.add<TosaFoldConstantMatMul>(ctx, options.foldSplatOrSingleUseOnly);
if (options.enableTileFolding)
patterns.add<TosaFoldConstantTile>(ctx, options.foldSplatOrSingleUseOnly);
Expand Down
13 changes: 13 additions & 0 deletions mlir/test/Dialect/Tosa/constant-slice-multi-user.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold="fold-splat-or-single-use-only=0" %s | FileCheck %s
// RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold="fold-splat-or-single-use-only=1" %s | FileCheck %s --check-prefix=ONLY-SINGLE-USE-CHECK

// The sliced constant %0 has two users: the tosa.slice and the return.
// With fold-splat-or-single-use-only=0 the slice is folded, so both the
// original 3x3 constant and the folded 3x2 constant are expected.
// With fold-splat-or-single-use-only=1 the tosa.slice must be left in place.
// CHECK-LABEL: @slice_bf16
func.func @slice_bf16() -> (tensor<3x3xbf16>, tensor<3x2xbf16>) {
// CHECK-DAG: "tosa.const"() <{value = dense<{{\[\[}}3.000000e+00, 4.000000e+00, 5.000000e+00], [6.000000e+00, 7.000000e+00, 8.000000e+00], [9.000000e+00, 1.000000e+01, 1.100000e+01]]>
// CHECK-DAG: "tosa.const"() <{value = dense<{{\[\[}}4.000000e+00, 5.000000e+00], [7.000000e+00, 8.000000e+00], [1.000000e+01, 1.100000e+01]]>
// ONLY-SINGLE-USE-CHECK: tosa.slice
%0 = "tosa.const"() {value = dense<[[3.0, 4.0, 5.0], [6.0, 7.0, 8.0], [9.0, 10.0, 11.0]]> : tensor<3x3xbf16>} : () -> tensor<3x3xbf16>
%1 = "tosa.slice"(%0){size = array<i64: 3, 2>, start = array<i64: 0, 1>} : (tensor<3x3xbf16>) -> tensor<3x2xbf16>
return %0, %1 : tensor<3x3xbf16>, tensor<3x2xbf16>
}

172 changes: 172 additions & 0 deletions mlir/test/Dialect/Tosa/constant-slice.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
// RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold %s | FileCheck %s

// i8: take the single element at [0, 0] of a 2x2 constant.
// CHECK-LABEL: @slice_int8
func.func @slice_int8() -> (tensor<1x1xi8>) {
// CHECK: "tosa.const"() <{value = dense<3>
%0 = "tosa.const"() {value = dense<[[3, 4], [5, 6]]> : tensor<2x2xi8>} : () -> tensor<2x2xi8>
%1 = "tosa.slice"(%0){size = array<i64: 1, 1>, start = array<i64: 0, 0>} : (tensor<2x2xi8>) -> tensor<1x1xi8>
return %1 : tensor<1x1xi8>
}

// i16: take column 0 of a 2x2 constant.
// CHECK-LABEL: @slice_int16
func.func @slice_int16() -> (tensor<2x1xi16>) {
// CHECK: "tosa.const"() <{value = dense<{{\[\[}}3], [5]]>
%0 = "tosa.const"() {value = dense<[[3, 4], [5, 6]]> : tensor<2x2xi16>} : () -> tensor<2x2xi16>
%1 = "tosa.slice"(%0){size = array<i64: 2, 1>, start = array<i64: 0, 0>} : (tensor<2x2xi16>) -> tensor<2x1xi16>
return %1 : tensor<2x1xi16>
}

// i32: take column 1 of a 2x2 constant (non-zero start in the last dim).
// CHECK-LABEL: @slice_int32
func.func @slice_int32() -> (tensor<2x1xi32>) {
// CHECK: "tosa.const"() <{value = dense<{{\[\[}}4], [6]]>
%0 = "tosa.const"() {value = dense<[[3, 4], [5, 6]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
%1 = "tosa.slice"(%0){size = array<i64: 2, 1>, start = array<i64: 0, 1>} : (tensor<2x2xi32>) -> tensor<2x1xi32>
return %1 : tensor<2x1xi32>
}

// i32: take column 0 of a 3x3 constant.
// CHECK-LABEL: @slice_int32_default_value
func.func @slice_int32_default_value() -> (tensor<3x1xi32>) {
// CHECK: "tosa.const"() <{value = dense<{{\[\[}}3], [6], [9]]>
%0 = "tosa.const"() {value = dense<[[3, 4, 5], [6, 7, 8], [9, 10, 11]]> : tensor<3x3xi32>} : () -> tensor<3x3xi32>
%1 = "tosa.slice"(%0){size = array<i64: 3, 1>, start = array<i64: 0, 0>} : (tensor<3x3xi32>) -> tensor<3x1xi32>
return %1 : tensor<3x1xi32>
}

// bf16: take columns 1 and 2 of a 3x3 constant.
// CHECK-LABEL: @slice_bf16_default_value
func.func @slice_bf16_default_value() -> (tensor<3x2xbf16>) {
// CHECK: "tosa.const"() <{value = dense<{{\[\[}}4.000000e+00, 5.000000e+00], [7.000000e+00, 8.000000e+00], [1.000000e+01, 1.100000e+01]]>
%0 = "tosa.const"() {value = dense<[[3.0, 4.0, 5.0], [6.0, 7.0, 8.0], [9.0, 10.0, 11.0]]> : tensor<3x3xbf16>} : () -> tensor<3x3xbf16>
%1 = "tosa.slice"(%0){size = array<i64: 3, 2>, start = array<i64: 0, 1>} : (tensor<3x3xbf16>) -> tensor<3x2xbf16>
return %1 : tensor<3x2xbf16>
}

// -----

// Following tests are all done with the following tensor, and different configurations:
// [[[1.0 , 2.25 , 3.50 , 4.75],
// [ 5.0 , 6.25 , 7.50 , 8.75]],
// [[ 13.0 , 14.25 , 15.50 , 16.75 ],
// [ 17.0 , 18.25 , 19.50 , 20.75]],
// [[-1.0 , -2.25 , -3.50 , -4.75],
// [ -5.0 , -6.25 , -7.50 , -8.75]],
// [[ -13.0 , -14.25 , -15.50 , -16.75 ],
// [ -17.0 , -18.25 , -19.50 , -20.75]]]

// bf16: keep only index 0 of dim 1 (start = [0, 0, 0], size = [4, 1, 4]).
// Should produce
// 1.0, 2.25, 3.50, 4.75,
// 13.0, 14.25, 15.50, 16.75,
// -1.0, -2.25, -3.50, -4.75,
// -13.0, -14.25, -15.50, -16.75
func.func @slice_bf16_dim_1_start_zero() -> (tensor<4x1x4xbf16>) {
// CHECK-LABEL: @slice_bf16_dim_1_start_zero
// CHECK: 1.000000e+00, 2.250000e+00, 3.500000e+00, 4.750000e+00
// CHECK-SAME: 1.300000e+01, 1.425000e+01, 1.550000e+01, 1.675000e+01
// CHECK-SAME: -1.000000e+00, -2.250000e+00, -3.500000e+00, -4.750000e+00
// CHECK-SAME: -1.300000e+01, -1.425000e+01, -1.550000e+01, -1.675000e+01
%0 = "tosa.const"() {value = dense<[[[1.0, 2.25, 3.50, 4.75], [ 5.0, 6.25, 7.50, 8.75]], [[ 13.0, 14.25, 15.50, 16.75 ], [ 17.0, 18.25, 19.50, 20.75]], [[-1.0, -2.25, -3.50, -4.75], [ -5.0, -6.25, -7.50, -8.75]], [[ -13.0, -14.25, -15.50, -16.75 ], [ -17.0, -18.25, -19.50, -20.75]]]> : tensor<4x2x4xbf16>} : () -> tensor<4x2x4xbf16>
%1 = "tosa.slice"(%0){size = array<i64: 4, 1, 4>, start = array<i64: 0, 0, 0>} : (tensor<4x2x4xbf16>) -> tensor<4x1x4xbf16>
return %1 : tensor<4x1x4xbf16>
}

// f16: keep only index 0 of dim 1 (start = [0, 0, 0], size = [4, 1, 4]).
// Should produce
// 1.0, 2.25, 3.50, 4.75,
// 13.0, 14.25, 15.50, 16.75,
// -1.0, -2.25, -3.50, -4.75,
// -13.0, -14.25, -15.50, -16.75
func.func @slice_f16_dim_1_start_zero() -> (tensor<4x1x4xf16>) {
// CHECK-LABEL: @slice_f16_dim_1_start_zero
// CHECK: 1.000000e+00, 2.250000e+00, 3.500000e+00, 4.750000e+00
// CHECK-SAME: 1.300000e+01, 1.425000e+01, 1.550000e+01, 1.675000e+01
// CHECK-SAME: -1.000000e+00, -2.250000e+00, -3.500000e+00, -4.750000e+00
// CHECK-SAME: -1.300000e+01, -1.425000e+01, -1.550000e+01, -1.675000e+01
%0 = "tosa.const"() {value = dense<[[[1.0, 2.25, 3.50, 4.75], [ 5.0, 6.25, 7.50, 8.75]], [[ 13.0, 14.25, 15.50, 16.75 ], [ 17.0, 18.25, 19.50, 20.75]], [[-1.0, -2.25, -3.50, -4.75], [ -5.0, -6.25, -7.50, -8.75]], [[ -13.0, -14.25, -15.50, -16.75 ], [ -17.0, -18.25, -19.50, -20.75]]]> : tensor<4x2x4xf16>} : () -> tensor<4x2x4xf16>
%1 = "tosa.slice"(%0){size = array<i64: 4, 1, 4>, start = array<i64: 0, 0, 0>} : (tensor<4x2x4xf16>) -> tensor<4x1x4xf16>
return %1 : tensor<4x1x4xf16>
}

// bf16: keep only index 1 of dim 1 (start = [0, 1, 0], size = [4, 1, 4]).
// Should produce
// 5.0, 6.25, 7.50, 8.75
// 17.0, 18.25, 19.50, 20.75
// -5.0, -6.25, -7.50, -8.75
// -17.0, -18.25, -19.50, -20.75
func.func @slice_bf16_start_dim_1_start_one() -> (tensor<4x1x4xbf16>) {
// CHECK-LABEL: @slice_bf16_start_dim_1_start_one
// CHECK: 5.000000e+00, 6.250000e+00, 7.500000e+00, 8.750000e+00
// CHECK-SAME: 1.700000e+01, 1.825000e+01, 1.950000e+01, 2.075000e+01
// CHECK-SAME: -5.000000e+00, -6.250000e+00, -7.500000e+00, -8.750000e+00
// CHECK-SAME: -1.700000e+01, -1.825000e+01, -1.950000e+01, -2.075000e+01
%0 = "tosa.const"() {value = dense<[[[1.0, 2.25, 3.50, 4.75], [ 5.0, 6.25, 7.50, 8.75]], [[ 13.0, 14.25, 15.50, 16.75 ], [ 17.0, 18.25, 19.50, 20.75]], [[-1.0, -2.25, -3.50, -4.75], [ -5.0, -6.25, -7.50, -8.75]], [[ -13.0, -14.25, -15.50, -16.75 ], [ -17.0, -18.25, -19.50, -20.75]]]> : tensor<4x2x4xbf16>} : () -> tensor<4x2x4xbf16>
%1 = "tosa.slice"(%0){size = array<i64: 4, 1, 4>, start = array<i64: 0, 1, 0>} : (tensor<4x2x4xbf16>) -> tensor<4x1x4xbf16>
return %1 : tensor<4x1x4xbf16>
}

// f16: keep only index 1 of dim 1 (start = [0, 1, 0], size = [4, 1, 4]).
// Should produce
// 5.0, 6.25, 7.50, 8.75
// 17.0, 18.25, 19.50, 20.75
// -5.0, -6.25, -7.50, -8.75
// -17.0, -18.25, -19.50, -20.75
func.func @slice_f16_start_dim_1_start_one() -> (tensor<4x1x4xf16>) {
// CHECK-LABEL: @slice_f16_start_dim_1_start_one
// CHECK: 5.000000e+00, 6.250000e+00, 7.500000e+00, 8.750000e+00
// CHECK-SAME: 1.700000e+01, 1.825000e+01, 1.950000e+01, 2.075000e+01
// CHECK-SAME: -5.000000e+00, -6.250000e+00, -7.500000e+00, -8.750000e+00
// CHECK-SAME: -1.700000e+01, -1.825000e+01, -1.950000e+01, -2.075000e+01
%0 = "tosa.const"() {value = dense<[[[1.0, 2.25, 3.50, 4.75], [ 5.0, 6.25, 7.50, 8.75]], [[ 13.0, 14.25, 15.50, 16.75 ], [ 17.0, 18.25, 19.50, 20.75]], [[-1.0, -2.25, -3.50, -4.75], [ -5.0, -6.25, -7.50, -8.75]], [[ -13.0, -14.25, -15.50, -16.75 ], [ -17.0, -18.25, -19.50, -20.75]]]> : tensor<4x2x4xf16>} : () -> tensor<4x2x4xf16>
%1 = "tosa.slice"(%0){size = array<i64: 4, 1, 4>, start = array<i64: 0, 1, 0>} : (tensor<4x2x4xf16>) -> tensor<4x1x4xf16>
return %1 : tensor<4x1x4xf16>
}

// bf16: shrink every dim at once from the origin
// (start = [0, 0, 0], size = [3, 1, 3]).
// Should produce
// 1.0, 2.25, 3.50
// 13.0, 14.25, 15.50
// -1.0, -2.25, -3.50
func.func @slice_bf16_start_zero_multiple_dims() -> (tensor<3x1x3xbf16>) {
// CHECK-LABEL: @slice_bf16_start_zero_multiple_dims
// CHECK: 1.000000e+00, 2.250000e+00, 3.500000e+00
// CHECK-SAME: 1.300000e+01, 1.425000e+01, 1.550000e+01
// CHECK-SAME: -1.000000e+00, -2.250000e+00, -3.500000e+00
%0 = "tosa.const"() {value = dense<[[[1.0, 2.25, 3.50, 4.75], [ 5.0, 6.25, 7.50, 8.75]], [[ 13.0, 14.25, 15.50, 16.75 ], [ 17.0, 18.25, 19.50, 20.75]], [[-1.0, -2.25, -3.50, -4.75], [ -5.0, -6.25, -7.50, -8.75]], [[ -13.0, -14.25, -15.50, -16.75 ], [ -17.0, -18.25, -19.50, -20.75]]]> : tensor<4x2x4xbf16>} : () -> tensor<4x2x4xbf16>
%1 = "tosa.slice"(%0){size = array<i64: 3, 1, 3>, start = array<i64: 0, 0, 0>} : (tensor<4x2x4xbf16>) -> tensor<3x1x3xbf16>
return %1 : tensor<3x1x3xbf16>
}

// f16: shrink every dim at once from the origin
// (start = [0, 0, 0], size = [3, 1, 3]).
// Should produce
// 1.0, 2.25, 3.50
// 13.0, 14.25, 15.50
// -1.0, -2.25, -3.50
func.func @slice_f16_start_zero_multiple_dims() -> (tensor<3x1x3xf16>) {
// CHECK-LABEL: @slice_f16_start_zero_multiple_dims
// CHECK: 1.000000e+00, 2.250000e+00, 3.500000e+00
// CHECK-SAME: 1.300000e+01, 1.425000e+01, 1.550000e+01
// CHECK-SAME: -1.000000e+00, -2.250000e+00, -3.500000e+00
%0 = "tosa.const"() {value = dense<[[[1.0, 2.25, 3.50, 4.75], [ 5.0, 6.25, 7.50, 8.75]], [[ 13.0, 14.25, 15.50, 16.75 ], [ 17.0, 18.25, 19.50, 20.75]], [[-1.0, -2.25, -3.50, -4.75], [ -5.0, -6.25, -7.50, -8.75]], [[ -13.0, -14.25, -15.50, -16.75 ], [ -17.0, -18.25, -19.50, -20.75]]]> : tensor<4x2x4xf16>} : () -> tensor<4x2x4xf16>
%1 = "tosa.slice"(%0){size = array<i64: 3, 1, 3>, start = array<i64: 0, 0, 0>} : (tensor<4x2x4xf16>) -> tensor<3x1x3xf16>
return %1 : tensor<3x1x3xf16>
}

// bf16: shrink every dim at once with a non-zero start in each
// (start = [1, 1, 1], size = [3, 1, 3]).
// Produces
// 18.25, 19.50, 20.75
// -6.25, -7.50, -8.75
// -18.25, -19.50, -20.75
func.func @slice_bf16_start_non_zero_multiple_dims() -> (tensor<3x1x3xbf16>) {
// CHECK-LABEL: @slice_bf16_start_non_zero_multiple_dims
// CHECK: 1.825000e+01, 1.950000e+01, 2.075000e+01
// CHECK-SAME: -6.250000e+00, -7.500000e+00, -8.750000e+00
// CHECK-SAME: -1.825000e+01, -1.950000e+01, -2.075000e+01
%0 = "tosa.const"() {value = dense<[[[1.0, 2.25, 3.50, 4.75], [ 5.0, 6.25, 7.50, 8.75]], [[ 13.0, 14.25, 15.50, 16.75 ], [ 17.0, 18.25, 19.50, 20.75]], [[-1.0, -2.25, -3.50, -4.75], [ -5.0, -6.25, -7.50, -8.75]], [[ -13.0, -14.25, -15.50, -16.75 ], [ -17.0, -18.25, -19.50, -20.75]]]> : tensor<4x2x4xbf16>} : () -> tensor<4x2x4xbf16>
%1 = "tosa.slice"(%0){size = array<i64: 3, 1, 3>, start = array<i64: 1, 1, 1>} : (tensor<4x2x4xbf16>) -> tensor<3x1x3xbf16>
return %1 : tensor<3x1x3xbf16>
}

// f16: shrink every dim at once with a non-zero start in each
// (start = [1, 1, 1], size = [3, 1, 3]).
// Produces
// 18.25, 19.50, 20.75
// -6.25, -7.50, -8.75
// -18.25, -19.50, -20.75
func.func @slice_f16_start_non_zero_multiple_dims() -> (tensor<3x1x3xf16>) {
// CHECK-LABEL: @slice_f16_start_non_zero_multiple_dims
// CHECK: 1.825000e+01, 1.950000e+01, 2.075000e+01
// CHECK-SAME: -6.250000e+00, -7.500000e+00, -8.750000e+00
// CHECK-SAME: -1.825000e+01, -1.950000e+01, -2.075000e+01
%0 = "tosa.const"() {value = dense<[[[1.0, 2.25, 3.50, 4.75], [ 5.0, 6.25, 7.50, 8.75]], [[ 13.0, 14.25, 15.50, 16.75 ], [ 17.0, 18.25, 19.50, 20.75]], [[-1.0, -2.25, -3.50, -4.75], [ -5.0, -6.25, -7.50, -8.75]], [[ -13.0, -14.25, -15.50, -16.75 ], [ -17.0, -18.25, -19.50, -20.75]]]> : tensor<4x2x4xf16>} : () -> tensor<4x2x4xf16>
%1 = "tosa.slice"(%0){size = array<i64: 3, 1, 3>, start = array<i64: 1, 1, 1>} : (tensor<4x2x4xf16>) -> tensor<3x1x3xf16>
return %1 : tensor<3x1x3xf16>
}