From 5beb9ad73626beb8fd5d0e769f0923161b76ab27 Mon Sep 17 00:00:00 2001 From: Stanley Winata <68087699+raikonenfnu@users.noreply.github.com> Date: Tue, 20 Aug 2024 09:51:52 -0700 Subject: [PATCH] [LinalgExt] Better doc for FP8 attention clamping. (#18301) Signed-off-by: Stanley Winata --- .../LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp index a9669a1fe719..cdedecffcebe 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp @@ -95,13 +95,15 @@ static Value truncateFloat(OpBuilder &builder, Location loc, AffineMap inputMap, APFloat::getLargest(dstTy.getFloatSemantics(), /*Negative=*/false) .convertToDouble(); - // Truncate to the `fp8` range so avoid nan values. + // Clamp input to dstTy(usually `fp8`) MAX value to prevent NaNs. + // We do not clamp for `-MAX` because this function meant to only be + // used by attention's exp2 who's value is always > 0. Value mx = builder.create( loc, builder.getFloatAttr(srcTy, mxDbl)); - Value sel0 = b.create(loc, mx, args[0]); + Value clamped = b.create(loc, mx, args[0]); // Convert scale to the same datatype as input. - Value trunc = convertScalarToDtype(b, loc, sel0, dstTy, + Value trunc = convertScalarToDtype(b, loc, clamped, dstTy, /*isUnsignedCast=*/false); b.create(loc, trunc); });