Skip to content

Commit

Permalink
[LinalgExt] Better doc for FP8 attention clamping. (iree-org#18301)
Browse files Browse the repository at this point in the history
Signed-off-by: Stanley Winata <stanley.winata@amd.com>
  • Loading branch information
raikonenfnu authored Aug 20, 2024
1 parent 5ba9a89 commit 5beb9ad
Showing 1 changed file with 5 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,15 @@ static Value truncateFloat(OpBuilder &builder, Location loc, AffineMap inputMap,
APFloat::getLargest(dstTy.getFloatSemantics(), /*Negative=*/false)
.convertToDouble();

// Truncate to the `fp8` range so avoid nan values.
// Clamp input to dstTy(usually `fp8`) MAX value to prevent NaNs.
// We do not clamp for `-MAX` because this function meant to only be
// used by attention's exp2 who's value is always > 0.
Value mx = builder.create<arith::ConstantOp>(
loc, builder.getFloatAttr(srcTy, mxDbl));
Value sel0 = b.create<arith::MinimumFOp>(loc, mx, args[0]);
Value clamped = b.create<arith::MinimumFOp>(loc, mx, args[0]);

// Convert scale to the same datatype as input.
Value trunc = convertScalarToDtype(b, loc, sel0, dstTy,
Value trunc = convertScalarToDtype(b, loc, clamped, dstTy,
/*isUnsignedCast=*/false);
b.create<linalg::YieldOp>(loc, trunc);
});
Expand Down

0 comments on commit 5beb9ad

Please sign in to comment.