Skip to content

Commit

Permalink
feat: Tosa folders for bitwise_and, bitwise_or, greater_equal and log.
Browse files Browse the repository at this point in the history
  • Loading branch information
ttjost committed Jul 28, 2023
1 parent f36f404 commit 4f03561
Show file tree
Hide file tree
Showing 8 changed files with 865 additions and 0 deletions.
175 changes: 175 additions & 0 deletions mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1090,6 +1090,174 @@ struct TosaFoldConstantErf
}
};

struct TosaFoldConstantLog
: public TosaFoldConstantUnaryElementwise<TosaFoldConstantLog, LogOp> {
using TosaFoldConstantUnaryElementwise<
TosaFoldConstantLog, LogOp>::TosaFoldConstantUnaryElementwise;

DenseElementsAttr computeFloat(DenseElementsAttr values,
PatternRewriter &rewriter, TosaOp op) const {
return applyElementWise<APFloat, APFloat, FloatType>(
values,
[](const APFloat &val, FloatType) {
auto res = APFloat(std::log(val.convertToFloat()));
bool lostPrecision;
res.convert(val.getSemantics(), APFloat::rmNearestTiesToEven,
&lostPrecision);
return res;
},
cast<FloatType>(values.getElementType()));
}

bool isSupportedElementType(Type type) const {
// convertToFloat uses F32, so we specify the supported types to make sure
// to properly handle F64 if needed in the future.
return type.isBF16() || type.isF16() || type.isF32();
}
};

struct TosaFoldConstantBitwiseAnd
: public TosaFoldConstantBinary<TosaFoldConstantBitwiseAnd, BitwiseAndOp> {
using TosaFoldConstantBinary<TosaFoldConstantBitwiseAnd,
BitwiseAndOp>::TosaFoldConstantBinary;

DenseElementsAttr computeInteger(DenseElementsAttr lhsValues,
DenseElementsAttr rhsValues,
PatternRewriter &rewriter,
BitwiseAndOp op) const {
return applyElementWise<APInt, APInt>(
lhsValues, rhsValues, op.getType(),
[](const APInt &lhs, const APInt &rhs) { return lhs & rhs; });
}
};

struct TosaFoldConstantBitwiseOr
: public TosaFoldConstantBinary<TosaFoldConstantBitwiseOr, BitwiseOrOp> {
using TosaFoldConstantBinary<TosaFoldConstantBitwiseOr,
BitwiseOrOp>::TosaFoldConstantBinary;

DenseElementsAttr computeInteger(DenseElementsAttr lhsValues,
DenseElementsAttr rhsValues,
PatternRewriter &rewriter,
BitwiseOrOp op) const {
return applyElementWise<APInt, APInt>(
lhsValues, rhsValues, op.getType(),
[](const APInt &lhs, const APInt &rhs) { return lhs | rhs; });
}
};

struct TosaFoldConstantGreaterEqual
: public TosaFoldConstantBinary<TosaFoldConstantGreaterEqual,
GreaterEqualOp> {
using TosaFoldConstantBinary<TosaFoldConstantGreaterEqual,
GreaterEqualOp>::TosaFoldConstantBinary;

DenseElementsAttr computeInteger(DenseElementsAttr lhsValues,
DenseElementsAttr rhsValues,
PatternRewriter &rewriter,
GreaterEqualOp op) const {
return applyElementWise<APInt, APInt>(
lhsValues, rhsValues, op.getType(),
[](const APInt &first, const APInt &second) {
return APInt(1, first.sge(second));
});
}

DenseElementsAttr computeFloat(DenseElementsAttr lhsValues,
DenseElementsAttr rhsValues,
PatternRewriter &rewriter,
GreaterEqualOp op) const {
return applyElementWise<APFloat, APInt>(
lhsValues, rhsValues, op.getType(),
[](const APFloat &first, const APFloat &second) {
return APInt(1, first >= second);
});
}
};

struct TosaFoldConstantEqual
: public TosaFoldConstantBinary<TosaFoldConstantEqual, EqualOp> {
using TosaFoldConstantBinary<TosaFoldConstantEqual,
EqualOp>::TosaFoldConstantBinary;

DenseElementsAttr computeInteger(DenseElementsAttr lhsValues,
DenseElementsAttr rhsValues,
PatternRewriter &rewriter,
EqualOp op) const {
return applyElementWise<APInt, APInt>(
lhsValues, rhsValues, op.getType(),
[](const APInt &first, const APInt &second) {
return APInt(1, first.eq(second));
});
}

DenseElementsAttr computeFloat(DenseElementsAttr lhsValues,
DenseElementsAttr rhsValues,
PatternRewriter &rewriter, EqualOp op) const {
return applyElementWise<APFloat, APInt>(
lhsValues, rhsValues, op.getType(),
[](const APFloat &first, const APFloat &second) {
return APInt(1, first == second);
});
}
};

struct TosaFoldConstantMinimum
: public TosaFoldConstantBinary<TosaFoldConstantMinimum, MinimumOp> {
using TosaFoldConstantBinary<TosaFoldConstantMinimum,
MinimumOp>::TosaFoldConstantBinary;

DenseElementsAttr computeInteger(DenseElementsAttr lhsValues,
DenseElementsAttr rhsValues,
PatternRewriter &rewriter,
MinimumOp op) const {
return applyElementWise<APInt, APInt>(
lhsValues, rhsValues, op.getType(),
[](const APInt &first, const APInt &second) {
return first.slt(second) ? first : second;
});
}

DenseElementsAttr computeFloat(DenseElementsAttr lhsValues,
DenseElementsAttr rhsValues,
PatternRewriter &rewriter,
MinimumOp op) const {
return applyElementWise<APFloat, APFloat>(
lhsValues, rhsValues, op.getType(),
[](const APFloat &first, const APFloat &second) {
return first < second ? first : second;
});
}
};

struct TosaFoldConstantMaximum
: public TosaFoldConstantBinary<TosaFoldConstantMaximum, MaximumOp> {
using TosaFoldConstantBinary<TosaFoldConstantMaximum,
MaximumOp>::TosaFoldConstantBinary;

DenseElementsAttr computeInteger(DenseElementsAttr lhsValues,
DenseElementsAttr rhsValues,
PatternRewriter &rewriter,
MaximumOp op) const {
return applyElementWise<APInt, APInt>(
lhsValues, rhsValues, op.getType(),
[](const APInt &first, const APInt &second) {
return first.sgt(second) ? first : second;
});
}

DenseElementsAttr computeFloat(DenseElementsAttr lhsValues,
DenseElementsAttr rhsValues,
PatternRewriter &rewriter,
MaximumOp op) const {
return applyElementWise<APFloat, APFloat>(
lhsValues, rhsValues, op.getType(),
[](const APFloat &first, const APFloat &second) {
return first > second ? first : second;
});
}
};

} // namespace

void mlir::tosa::populateTosaFoldConstantPatterns(
Expand All @@ -1113,4 +1281,11 @@ void mlir::tosa::populateTosaFoldConstantPatterns(
patterns.add<TosaFoldConstantBitwiseNot>(ctx, foldSplatOrSingleUseOnly);
patterns.add<TosaFoldConstantCeil>(ctx, foldSplatOrSingleUseOnly);
patterns.add<TosaFoldConstantErf>(ctx, foldSplatOrSingleUseOnly);
patterns.add<TosaFoldConstantLog>(ctx, foldSplatOrSingleUseOnly);
patterns.add<TosaFoldConstantBitwiseAnd>(ctx, foldSplatOrSingleUseOnly);
patterns.add<TosaFoldConstantBitwiseOr>(ctx, foldSplatOrSingleUseOnly);
patterns.add<TosaFoldConstantGreaterEqual>(ctx, foldSplatOrSingleUseOnly);
patterns.add<TosaFoldConstantEqual>(ctx, foldSplatOrSingleUseOnly);
patterns.add<TosaFoldConstantMinimum>(ctx, foldSplatOrSingleUseOnly);
patterns.add<TosaFoldConstantMaximum>(ctx, foldSplatOrSingleUseOnly);
}
74 changes: 74 additions & 0 deletions mlir/test/Dialect/Tosa/constant-bitwise-and.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold %s | FileCheck %s

// CHECK-LABEL: @bitwise_and_fold_single_valued
func.func @bitwise_and_fold_single_valued() -> tensor<i32> {
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}-65536
// CHECK-NOT: tosa.bitwise_and
// CHECK: return [[RES]]
%0 = "tosa.const"() {value = dense<0xFFFFFFFF> : tensor<i32>} : () -> tensor<i32>
%1 = "tosa.const"() {value = dense<0xFFFF0000> : tensor<i32>} : () -> tensor<i32>
%2 = "tosa.bitwise_and"(%0, %1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
return %2 : tensor<i32>
}

// CHECK-LABEL: @bitwise_and_fold_splat
func.func @bitwise_and_fold_splat() -> tensor<12x7xi32> {
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}65535
// CHECK-NOT: tosa.bitwise_and
// CHECK: return [[RES]]
%0 = "tosa.const"() {value = dense<0xFFFFFFFF> : tensor<12x7xi32>} : () -> tensor<12x7xi32>
%1 = "tosa.const"() {value = dense<0x0000FFFF> : tensor<12x7xi32>} : () -> tensor<12x7xi32>
%2 = "tosa.bitwise_and"(%0, %1) : (tensor<12x7xi32>, tensor<12x7xi32>) -> tensor<12x7xi32>
return %2 : tensor<12x7xi32>
}

// CHECK-LABEL: @bitwise_and_no_fold
// The folding optimization works only intra-procedurally, so we won't be able
// to fold anything here
func.func @bitwise_and_no_fold(%arg0: tensor<?x?xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi32> {
// CHECK: tosa.bitwise_and
// CHECK-NEXT: return
%0 = "tosa.bitwise_and"(%arg0, %arg1) : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
return %0 : tensor<?x?xi32>
}

// CHECK-LABEL: @bitwise_and_fold
func.func @bitwise_and_fold() -> tensor<2x6xi32> {
// CHECK: [[RES:]] ={{.*}}tosa.const
// CHECK-SAME{LITERAL}: [[-1, -2, -3, -4, -5, -6],
// CHECK-SAME{LITERAL}: [1, 2, 3, 4, 5, 6]]
// CHECK-NOT: tosa.bitwise_and
// CHECK: return [[RES]]
%0 = "tosa.const"() { value = dense<
[[0xFFFFFFFF, 0xFFFFFFFE, 0xFFFFFFFD,
0xFFFFFFFC, 0xFFFFFFFB, 0xFFFFFFFA],
[1, 2, 3, 4, 5, 6]]>
: tensor<2x6xi32>
} : () -> tensor<2x6xi32>
%1 = "tosa.const"() { value = dense<
[[0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF],
[0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF]]>
: tensor<2x6xi32>
} : () -> tensor<2x6xi32>
%2 = "tosa.bitwise_and"(%0, %1) : (tensor<2x6xi32>, tensor<2x6xi32>) -> tensor<2x6xi32>
return %2 : tensor<2x6xi32>
}

// CHECK-LABEL: @bitwise_and_of_const_sparse
// Sparse tensors are currently not supported
func.func @bitwise_and_of_const_sparse() -> tensor<32xi8> {
// CHECK: tosa.const
// CHECK: tosa.bitwise_and
%0 = "tosa.const"() { value = sparse<
[[0], [3], [11], [17], [20], [23], [25], [30], [31]],
[0, 1, 2, 3, 4, 0xFF, 0xFE, 0xFD, 0xFC]>
: tensor<32xi8> } : () -> tensor<32xi8>
%1 = "tosa.const"() { value = sparse<
[[0], [3], [11], [17], [20], [23], [25], [30], [31]],
[0, 1, 2, 3, 4, 0xFF, 0xFE, 0xFD, 0xFC]>
: tensor<32xi8> } : () -> tensor<32xi8>
%2 = "tosa.bitwise_and"(%0, %1) : (tensor<32xi8>, tensor<32xi8>) -> tensor<32xi8>
return %2 : tensor<32xi8>
}
73 changes: 73 additions & 0 deletions mlir/test/Dialect/Tosa/constant-bitwise-or.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold %s | FileCheck %s

// CHECK-LABEL: @bitwise_or_fold_single_valued
func.func @bitwise_or_fold_single_valued() -> tensor<i32> {
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}-1
// CHECK-NOT: tosa.bitwise_or
// CHECK: return [[RES]]
%0 = "tosa.const"() {value = dense<0xFFFFFFFF> : tensor<i32>} : () -> tensor<i32>
%1 = "tosa.const"() {value = dense<0xFFFF0000> : tensor<i32>} : () -> tensor<i32>
%2 = "tosa.bitwise_or"(%0, %1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
return %2 : tensor<i32>
}

// CHECK-LABEL: @bitwise_or_fold_splat
func.func @bitwise_or_fold_splat() -> tensor<12x7xi32> {
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}-1
// CHECK-NOT: tosa.bitwise_or
// CHECK: return [[RES]]
%0 = "tosa.const"() {value = dense<0xFFFFFFFF> : tensor<12x7xi32>} : () -> tensor<12x7xi32>
%1 = "tosa.const"() {value = dense<0x0000FFFF> : tensor<12x7xi32>} : () -> tensor<12x7xi32>
%2 = "tosa.bitwise_or"(%0, %1) : (tensor<12x7xi32>, tensor<12x7xi32>) -> tensor<12x7xi32>
return %2 : tensor<12x7xi32>
}

// CHECK-LABEL: @bitwise_or_no_fold
// The folding optimization works only intra-procedurally, so we won't be able
// to fold anything here
func.func @bitwise_or_no_fold(%arg0: tensor<?x?xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi32> {
// CHECK: tosa.bitwise_or
// CHECK-NEXT: return
%0 = "tosa.bitwise_or"(%arg0, %arg1) : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
return %0 : tensor<?x?xi32>
}

// CHECK-LABEL: @bitwise_or_fold
func.func @bitwise_or_fold() -> tensor<2x6xi32> {
// CHECK: [[RES:]] ={{.*}}tosa.const
// CHECK-SAME{LITERAL}: [[-1, -1, -1, -1, -1, -1],
// CHECK-SAME{LITERAL}: [1, 3, 3, 5, 5, 7]]
// CHECK-NOT: tosa.bitwise_or
// CHECK: return [[RES]]
%0 = "tosa.const"() { value = dense<
[[0xFFFFFFFF, 0xFFFFFFFE, 0xFFFFFFFD,
0xFFFFFFFC, 0xFFFFFFFB, 0xFFFFFFFA],
[1, 2, 3, 4, 5, 6]]>
: tensor<2x6xi32>
} : () -> tensor<2x6xi32>
%1 = "tosa.const"() { value = dense<
[[0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF],
[1, 1, 1, 1, 1, 1]]>
: tensor<2x6xi32>
} : () -> tensor<2x6xi32>
%2 = "tosa.bitwise_or"(%0, %1) : (tensor<2x6xi32>, tensor<2x6xi32>) -> tensor<2x6xi32>
return %2 : tensor<2x6xi32>
}

// CHECK-LABEL: @bitwise_or_of_const_sparse
// Sparse tensors are currently not supported
func.func @bitwise_or_of_const_sparse() -> tensor<32xi8> {
// CHECK: tosa.const
// CHECK: tosa.bitwise_or
%0 = "tosa.const"() { value = sparse<
[[0], [3], [11], [17], [20], [23], [25], [30], [31]],
[0, 1, 2, 3, 4, 0xFF, 0xFE, 0xFD, 0xFC]>
: tensor<32xi8> } : () -> tensor<32xi8>
%1 = "tosa.const"() { value = sparse<
[[0], [3], [11], [17], [20], [23], [25], [30], [31]],
[0, 1, 2, 3, 4, 0xFF, 0xFE, 0xFD, 0xFC]>
: tensor<32xi8> } : () -> tensor<32xi8>
%2 = "tosa.bitwise_or"(%0, %1) : (tensor<32xi8>, tensor<32xi8>) -> tensor<32xi8>
return %2 : tensor<32xi8>
}
Loading

0 comments on commit 4f03561

Please sign in to comment.