Skip to content

Commit

Permalink
Merge with fixes of 10a57f3
Browse files Browse the repository at this point in the history
  • Loading branch information
cferry-AMD committed Aug 19, 2024
2 parents 4ff51c6 + 10a57f3 commit d4c1ac3
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 5 deletions.
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Math/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ void populateExpandFloorFPattern(RewritePatternSet &patterns);
void populateExpandCeilFPattern(RewritePatternSet &patterns);
void populateExpandExp2FPattern(RewritePatternSet &patterns);
void populateExpandPowFPattern(RewritePatternSet &patterns);
void populateExpandFPowIPattern(RewritePatternSet &patterns);
void populateExpandRoundFPattern(RewritePatternSet &patterns);
void populateExpandRoundEvenPattern(RewritePatternSet &patterns);
void populateExpandRsqrtPattern(RewritePatternSet &patterns);
Expand Down
87 changes: 82 additions & 5 deletions mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
//===- ExpandTanh.cpp - Code to perform expanding tanh op -----------------===//
//===- ExpandPatterns.cpp - Code to expand various math operations. -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements expansion of tanh op.
// This file implements expansion of various math operations.
//
//===----------------------------------------------------------------------===//

Expand All @@ -23,9 +23,14 @@
using namespace mlir;

/// Create a float constant.
static Value createFloatConst(Location loc, Type type, double value,
static Value createFloatConst(Location loc, Type type, APFloat value,
OpBuilder &b) {
auto attr = b.getFloatAttr(getElementTypeOrSelf(type), value);
bool losesInfo = false;
auto eltType = getElementTypeOrSelf(type);
// Convert double to the given `FloatType` with round-to-nearest-ties-to-even.
value.convert(cast<FloatType>(eltType).getFloatSemantics(),
APFloat::rmNearestTiesToEven, &losesInfo);
auto attr = b.getFloatAttr(eltType, value);
if (auto shapedTy = dyn_cast<ShapedType>(type)) {
return b.create<arith::ConstantOp>(loc,
DenseElementsAttr::get(shapedTy, attr));
Expand All @@ -34,7 +39,12 @@ static Value createFloatConst(Location loc, Type type, double value,
return b.create<arith::ConstantOp>(loc, attr);
}

/// Create a float constant.
static Value createFloatConst(Location loc, Type type, double value,
OpBuilder &b) {
return createFloatConst(loc, type, APFloat(value), b);
}

/// Create an integer constant.
static Value createIntConst(Location loc, Type type, int64_t value,
OpBuilder &b) {
auto attr = b.getIntegerAttr(getElementTypeOrSelf(type), value);
Expand Down Expand Up @@ -202,6 +212,69 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
rewriter.replaceOp(op, ret);
return success();
}

// Convert `math.fpowi` to a series of `arith.mulf` operations.
// If the power is negative, we divide one by the result.
// If both the base and power are zero, the result is 1.
static LogicalResult convertFPowICstOp(math::FPowIOp op,
PatternRewriter &rewriter) {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value base = op.getOperand(0);
Value power = op.getOperand(1);
Type baseType = base.getType();

Attribute cstAttr;
if (!matchPattern(power, m_Constant(&cstAttr)))
return failure();

APInt value;
if (!matchPattern(cstAttr, m_ConstantInt(&value)))
return failure();

int64_t powerInt = value.getSExtValue();
bool isNegative = powerInt < 0;
int64_t absPower = std::abs(powerInt);
Value one = createFloatConst(op->getLoc(), baseType, 1.00, rewriter);
Value res = createFloatConst(op->getLoc(), baseType, 1.00, rewriter);

while (absPower > 0) {
if (absPower & 1)
res = b.create<arith::MulFOp>(baseType, base, res);
absPower >>= 1;
base = b.create<arith::MulFOp>(baseType, base, base);
}

// Make sure not to introduce UB in case of negative power.
if (isNegative) {
auto &sem = dyn_cast<mlir::FloatType>(getElementTypeOrSelf(baseType))
.getFloatSemantics();
Value zero =
createFloatConst(op->getLoc(), baseType,
APFloat::getZero(sem, /*Negative=*/false), rewriter);
Value negZero =
createFloatConst(op->getLoc(), baseType,
APFloat::getZero(sem, /*Negative=*/true), rewriter);
Value posInfinity =
createFloatConst(op->getLoc(), baseType,
APFloat::getInf(sem, /*Negative=*/false), rewriter);
Value negInfinity =
createFloatConst(op->getLoc(), baseType,
APFloat::getInf(sem, /*Negative=*/true), rewriter);
Value zeroEqCheck =
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, zero);
Value negZeroEqCheck =
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, negZero);
res = b.create<arith::DivFOp>(baseType, one, res);
res =
b.create<arith::SelectOp>(op->getLoc(), zeroEqCheck, posInfinity, res);
res = b.create<arith::SelectOp>(op->getLoc(), negZeroEqCheck, negInfinity,
res);
}

rewriter.replaceOp(op, res);
return success();
}

// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Expand Down Expand Up @@ -534,6 +607,10 @@ void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
patterns.add(convertPowfOp);
}

void mlir::populateExpandFPowIPattern(RewritePatternSet &patterns) {
patterns.add(convertFPowICstOp);
}

void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {
patterns.add(convertRoundOp);
}
Expand Down
99 changes: 99 additions & 0 deletions mlir/test/Dialect/Math/expand-math.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,105 @@ func.func @roundeven16(%arg: f16) -> f16 {

// -----

// CHECK-LABEL: func.func @math_fpowi_neg_odd_power
func.func @math_fpowi_neg_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
%1 = arith.constant dense<-3> : tensor<8xi64>
%2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
return %2 : tensor<8xf32>
}
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
// CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
// CHECK-DAG: %[[CSTNEG0:.*]] = arith.constant dense<-0.000000e+00> : tensor<8xf32>
// CHECK-DAG: %[[CSTINF:.*]] = arith.constant dense<0x7F800000> : tensor<8xf32>
// CHECK-DAG: %[[CSTNEGINF:.*]] = arith.constant dense<0xFF800000> : tensor<8xf32>
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
// CHECK: %[[CUBE:.*]] = arith.mulf %[[SQ]], %[[ARG0]] : tensor<8xf32>
// CHECK: %[[CMP0:.*]] = arith.cmpf oeq, %[[CUBE]], %[[CST0]] : tensor<8xf32>
// CHECK: %[[CMPNEG0:.*]] = arith.cmpf oeq, %[[CUBE]], %[[CSTNEG0]] : tensor<8xf32>
// CHECK: %[[INV:.*]] = arith.divf %[[CST1]], %[[CUBE]] : tensor<8xf32>
// CHECK: %[[UB1:.*]] = arith.select %[[CMP0]], %[[CSTINF]], %[[INV]] : tensor<8xi1>, tensor<8xf32>
// CHECK: %[[UB2:.*]] = arith.select %[[CMPNEG0]], %[[CSTNEGINF]], %[[UB1]] : tensor<8xi1>, tensor<8xf32>
// CHECK: return %[[UB2]] : tensor<8xf32>

// -----

// CHECK-LABEL: func.func @math_fpowi_neg_even_power
func.func @math_fpowi_neg_even_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
%1 = arith.constant dense<-4> : tensor<8xi64>
%2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
return %2 : tensor<8xf32>
}
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
// CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
// CHECK-DAG: %[[CSTNEG0:.*]] = arith.constant dense<-0.000000e+00> : tensor<8xf32>
// CHECK-DAG: %[[CSTINF:.*]] = arith.constant dense<0x7F800000> : tensor<8xf32>
// CHECK-DAG: %[[CSTNEGINF:.*]] = arith.constant dense<0xFF800000> : tensor<8xf32>
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
// CHECK: %[[PW4:.*]] = arith.mulf %[[SQ]], %[[SQ]] : tensor<8xf32>
// CHECK: %[[CMP0:.*]] = arith.cmpf oeq, %[[PW4]], %[[CST0]] : tensor<8xf32>
// CHECK: %[[CMPNEG0:.*]] = arith.cmpf oeq, %[[PW4]], %[[CSTNEG0]] : tensor<8xf32>
// CHECK: %[[INV:.*]] = arith.divf %[[CST1]], %[[PW4]] : tensor<8xf32>
// CHECK: %[[UB1:.*]] = arith.select %[[CMP0]], %[[CSTINF]], %[[INV]] : tensor<8xi1>, tensor<8xf32>
// CHECK: %[[UB2:.*]] = arith.select %[[CMPNEG0]], %[[CSTNEGINF]], %[[UB1]] : tensor<8xi1>, tensor<8xf32>
// CHECK: return %[[UB2]] : tensor<8xf32>

// -----

// CHECK-LABEL: func.func @math_fpowi_pos_odd_power
func.func @math_fpowi_pos_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
%1 = arith.constant dense<5> : tensor<8xi64>
%2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
return %2 : tensor<8xf32>
}
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
// CHECK: %[[PW4:.*]] = arith.mulf %[[SQ]], %[[SQ]] : tensor<8xf32>
// CHECK: %[[PW5:.*]] = arith.mulf %[[PW4]], %[[ARG0]] : tensor<8xf32>
// CHECK: return %[[PW5]] : tensor<8xf32>

// -----

// CHECK-LABEL: func.func @math_fpowi_pos_even_power
func.func @math_fpowi_pos_even_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
%1 = arith.constant dense<4> : tensor<8xi64>
%2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
return %2 : tensor<8xf32>
}
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
// CHECK: %[[PW4:.*]] = arith.mulf %[[SQ]], %[[SQ]] : tensor<8xf32>
// CHECK: return %[[PW4]] : tensor<8xf32>

// -----

// CHECK-LABEL: func.func @math_fpowi_even_scalar
func.func @math_fpowi_even_scalar(%0 : f32) -> f32 {
%pow = arith.constant 2 : i64
%2 = math.fpowi %0, %pow : f32, i64
return %2 : f32
}
// CHECK-SAME: (%[[ARG0:.*]]: f32) -> f32 {
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : f32
// CHECK: return %[[SQ]] : f32

// -----

// CHECK-LABEL: func.func @math_fpowi_scalar_zero
func.func @math_fpowi_scalar_zero(%0 : f32) -> f32 {
%pow = arith.constant 0 : i64
%2 = math.fpowi %0, %pow : f32, i64
return %2 : f32
}
// CHECK-SAME: (%[[ARG0:.*]]: f32) -> f32 {
// CHECK: %[[RET:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: return %[[RET]] : f32

// -----

// -----

// CHECK-LABEL: func.func @rsqrt
// CHECK-SAME: (%[[ARG:.*]]: f16)
// CHECK-SAME: -> f16
Expand Down
1 change: 1 addition & 0 deletions mlir/test/lib/Dialect/Math/TestExpandMath.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ void TestExpandMathPass::runOnOperation() {
populateExpandFloorFPattern(patterns);
populateExpandCeilFPattern(patterns);
populateExpandPowFPattern(patterns);
populateExpandFPowIPattern(patterns);
populateExpandRoundFPattern(patterns);
populateExpandRoundEvenPattern(patterns);
populateExpandRsqrtPattern(patterns);
Expand Down

0 comments on commit d4c1ac3

Please sign in to comment.