Skip to content

Commit

Permalink
Address PR comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
ttjost committed Jul 25, 2023
1 parent 56e56b0 commit 6df4548
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 12 deletions.
39 changes: 27 additions & 12 deletions mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,13 @@ struct TosaFoldConstantUnaryElementwise : public TosaFoldConstantBase<TosaOp> {
"tensor has a single user");
}

TensorType opType = dyn_cast<TensorType>(op.getType());
if (opType == nullptr ||
!static_cast<const BaseClass *>(this)->isSupportedElementType(
opType.getElementType())) {
return rewriter.notifyMatchFailure(op, "Type is not supported.");
}

DenseElementsAttr newTensor = static_cast<const BaseClass *>(this)->compute(
inputValues, rewriter, op);
if (!newTensor) {
Expand Down Expand Up @@ -394,6 +401,9 @@ struct TosaFoldConstantUnaryElementwise : public TosaFoldConstantBase<TosaOp> {
PatternRewriter &rewriter, TosaOp op) const {
return {};
}

/// Return true if the \p elementType is supported by the folder.
bool isSupportedElementType(Type type) const { return true; }
};

template<typename BaseClass, typename TosaOp>
Expand Down Expand Up @@ -559,6 +569,10 @@ struct TosaFoldConstantRSQRT
return applyElementWise<APFloat, APFloat, FloatType>(
values, &computeRSQRT, cast<FloatType>(values.getElementType()));
}

bool isSupportedElementType(Type type) const {
return type.isBF16() || type.isF16() || type.isF32();
}
};

struct TosaFoldConstantLogicalNot
Expand Down Expand Up @@ -618,6 +632,10 @@ struct TosaFoldConstantPow
return applyElementWise<APFloat, APFloat>(lhsValues, rhsValues,
op.getType(), computePower);
}

bool isSupportedElementType(Type type) const {
return type.isBF16() || type.isF16() || type.isF32();
}
};

struct TosaFoldConstantMul
Expand Down Expand Up @@ -1023,10 +1041,7 @@ struct TosaFoldConstantBitwiseNot
DenseElementsAttr computeInteger(DenseElementsAttr values,
PatternRewriter &rewriter, TosaOp op) const {
return applyElementWise<APInt, APInt, IntegerType>(
values,
[](const APInt &val, IntegerType) {
return APInt(val.getBitWidth(), ~val.getZExtValue());
},
values, [](const APInt &val, IntegerType) { return ~val; },
cast<IntegerType>(values.getElementType()));
}
};
Expand All @@ -1041,12 +1056,8 @@ struct TosaFoldConstantCeil
return applyElementWise<APFloat, APFloat, FloatType>(
values,
[](const APFloat &val, FloatType) {
// Compute ceil (APFloat unfortunately does not provide this function,
// such that we need to unpack here)
auto res = APFloat(std::ceil(val.convertToFloat()));
bool lostPrecision;
res.convert(val.getSemantics(), APFloat::rmNearestTiesToEven,
&lostPrecision);
auto res = val;
res.roundToIntegral(llvm::RoundingMode::TowardPositive);
return res;
},
cast<FloatType>(values.getElementType()));
Expand All @@ -1063,8 +1074,6 @@ struct TosaFoldConstantErf
return applyElementWise<APFloat, APFloat, FloatType>(
values,
[](const APFloat &val, FloatType) {
// Compute ceil (APFloat unfortunately does not provide this function,
// such that we need to unpack here)
auto res = APFloat(std::erf(val.convertToFloat()));
bool lostPrecision;
res.convert(val.getSemantics(), APFloat::rmNearestTiesToEven,
Expand All @@ -1073,6 +1082,12 @@ struct TosaFoldConstantErf
},
cast<FloatType>(values.getElementType()));
}

bool isSupportedElementType(Type type) const {
// Note: For now, we only support BF16 and F32 as std::erf may
// have an impact on the accuracy of the returned value.
return type.isBF16() || type.isF32();
}
};

} // namespace
Expand Down
20 changes: 20 additions & 0 deletions mlir/test/Dialect/Tosa/constant-ceil.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,26 @@ func.func @ceil_fold_splat() -> tensor<12x7xf32> {
return %1 : tensor<12x7xf32>
}

// CHECK-LABEL: @ceil_fold_bf16
func.func @ceil_fold_bf16() -> tensor<12x7xbf16> {
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}5.000000e+00
// CHECK-NOT: tosa.ceil
// CHECK: return [[RES]]
%0 = "tosa.const"() {value = dense<4.2> : tensor<12x7xbf16>} : () -> tensor<12x7xbf16>
%1 = "tosa.ceil"(%0) : (tensor<12x7xbf16>) -> tensor<12x7xbf16>
return %1 : tensor<12x7xbf16>
}

// CHECK-LABEL: @ceil_fold_f16
func.func @ceil_fold_f16() -> tensor<12x7xf16> {
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}5.000000e+00
// CHECK-NOT: tosa.ceil
// CHECK: return [[RES]]
%0 = "tosa.const"() {value = dense<4.2> : tensor<12x7xf16>} : () -> tensor<12x7xf16>
%1 = "tosa.ceil"(%0) : (tensor<12x7xf16>) -> tensor<12x7xf16>
return %1 : tensor<12x7xf16>
}

// CHECK-LABEL: @ceil_nan
func.func @ceil_nan() -> tensor<f32> {
// 0x7FC00000 is the value for NAN
Expand Down
20 changes: 20 additions & 0 deletions mlir/test/Dialect/Tosa/constant-erf.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,16 @@ func.func @erf_fold_splat() -> tensor<12x7xf32> {
return %1 : tensor<12x7xf32>
}

// CHECK-LABEL: @erf_fold_bf16
func.func @erf_fold_bf16() -> tensor<12x7xbf16> {
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}7.031250e-02
// CHECK-NOT: tosa.erf
// CHECK: return [[RES]]
%0 = "tosa.const"() {value = dense<0.0625> : tensor<12x7xbf16>} : () -> tensor<12x7xbf16>
%1 = "tosa.erf"(%0) : (tensor<12x7xbf16>) -> tensor<12x7xbf16>
return %1 : tensor<12x7xbf16>
}

// CHECK-LABEL: @erf_zero
func.func @erf_zero() -> tensor<f32> {
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0.000000e+00
Expand Down Expand Up @@ -101,6 +111,16 @@ func.func @erf_no_fold(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
return %0 : tensor<?x?xf32>
}

// CHECK-LABEL: @erf_no_fold_f16
func.func @erf_no_fold_f16() -> tensor<12x7xf16> {
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}6.250000e-02
// CHECK: tosa.erf
// CHECK: return [[RES]]
%0 = "tosa.const"() {value = dense<6.250000e-02> : tensor<12x7xf16>} : () -> tensor<12x7xf16>
%1 = "tosa.erf"(%0) : (tensor<12x7xf16>) -> tensor<12x7xf16>
return %1 : tensor<12x7xf16>
}

// CHECK-LABEL: @erf_fold
func.func @erf_fold() -> tensor<4x6xf32> {
// CHECK: [[RES:]] ={{.*}}tosa.const
Expand Down

0 comments on commit 6df4548

Please sign in to comment.