[LinalgExt] Reland QK scaling for attention decomp and further optimizations of it. (iree-org#18293)

This PR relands the original QK scaling and optimizes it further for
attention decomposition. Additionally, we re-order QK scaling to run
pre-exp2 so that we can use a linear offset instead of a multiplication
(`scale * exp2(y) == exp2(y + log2(scale))`). To further optimize, we move
the linear offset to before reduce<max>, since a shared offset cancels out
of the max subtraction (`oldMax - newMax == (oldMax + f8_linear_offset) -
(newMax + f8_linear_offset)`).
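
A minimal standalone check of the two identities above (plain C++ with
hypothetical example values; not the MLIR builder code in the diff):

```cpp
// Sanity-check: post-exp2 multiplication equals a pre-exp2 linear offset,
// and a shared offset cancels out of the max subtraction.
#include <cassert>
#include <cmath>

int main() {
  double scale = 0.125, y = 3.7;
  // scale * exp2(y) == exp2(y + log2(scale))
  assert(std::abs(scale * std::exp2(y) - std::exp2(y + std::log2(scale))) <
         1e-12);

  // oldMax - newMax == (oldMax + offset) - (newMax + offset)
  double oldMax = 1.5, newMax = 2.25, offset = 0.03;
  assert(std::abs((oldMax - newMax) -
                  ((oldMax + offset) - (newMax + offset))) < 1e-12);
  return 0;
}
```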

---------

Signed-off-by: Stanley Winata <stanley.winata@amd.com>
Co-authored-by: Rob Suderman <rob.suderman@gmail.com>
Co-authored-by: Kunwar Grover <groverkss@gmail.com>
3 people authored Aug 20, 2024
1 parent ab0d4c6 commit 5ba9a89
Showing 2 changed files with 54 additions and 81 deletions.
@@ -28,9 +28,10 @@ static llvm::cl::opt<float> clAttentionSoftmaxMax(
llvm::cl::desc("maximum expected value from attention softmax"),
llvm::cl::init(1.0));

static Value scaleValueInPlace(OpBuilder &builder, Location loc,
AffineMap inputMap, AffineMap scaleMap,
Value value, Value scale) {
template <typename T>
static Value elementwiseValueInPlace(OpBuilder &builder, Location loc,
AffineMap inputMap, AffineMap scaleMap,
Value value, Value scale) {
SmallVector<AffineMap> compressedMaps =
compressUnusedDims(SmallVector<AffineMap>{inputMap, scaleMap});
inputMap = compressedMaps[0];
@@ -46,7 +47,7 @@ static Value scaleValueInPlace(OpBuilder &builder, Location loc,
// Convert scale to the same datatype as input.
Value scale = convertScalarToDtype(b, loc, args[0], args[1].getType(),
/*isUnsignedCast=*/false);
Value result = b.create<arith::MulFOp>(loc, scale, args[1]);
Value result = b.create<T>(loc, scale, args[1]);
b.create<linalg::YieldOp>(loc, result);
});
return genericOp.getResult(0);
@@ -90,28 +91,17 @@ static Value truncateFloat(OpBuilder &builder, Location loc, AffineMap inputMap,
auto srcTy = cast<FloatType>(args[0].getType());
auto dstTy = cast<FloatType>(args[1].getType());

// We clamp to the min / max of the floating point representation
double mnDbl =
APFloat::getLargest(dstTy.getFloatSemantics(), /*Negative=*/true)
.convertToDouble();
double mxDbl =
APFloat::getLargest(dstTy.getFloatSemantics(), /*Negative=*/false)
.convertToDouble();

// Truncate to the `fp8` range to avoid nan values.
Value mn = builder.create<arith::ConstantOp>(
loc, builder.getFloatAttr(srcTy, mnDbl));
Value mx = builder.create<arith::ConstantOp>(
loc, builder.getFloatAttr(srcTy, mxDbl));
Value gt = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
args[0], mx);
Value lt = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT,
args[0], mn);
Value sel0 = b.create<arith::SelectOp>(loc, gt, mx, args[0]);
Value sel1 = b.create<arith::SelectOp>(loc, lt, mn, sel0);
Value sel0 = b.create<arith::MinimumFOp>(loc, mx, args[0]);

// Convert scale to the same datatype as input.
Value trunc = convertScalarToDtype(b, loc, sel1, dstTy,
Value trunc = convertScalarToDtype(b, loc, sel0, dstTy,
/*isUnsignedCast=*/false);
b.create<linalg::YieldOp>(loc, trunc);
});
@@ -285,7 +275,8 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
AffineMap qMap = getQueryMap();
AffineMap scaleMap = AffineMap::get(/*dimCount=*/qMap.getNumInputs(),
/*symbolCount=*/0, getContext());
query = scaleValueInPlace(b, loc, qMap, scaleMap, query, scale);
query = elementwiseValueInPlace<arith::MulFOp>(b, loc, qMap, scaleMap,
query, scale);
}

// ---- Matmul 1 ----
@@ -305,14 +296,28 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
Value s = b.create<linalg::FillOp>(loc, sZero, emptyS).getResult(0);
s = computeMatmul(b, loc, getQueryMap(), getKeyMap(), sMap, query, key, s);

// For low bit-depth types we perform post Q @ K scaling. This is to avoid
// losing numerical precision due to the low dynamic range of fp8 types when
// pre-applying the scaling.
if (qETy.getIntOrFloatBitWidth() <= 8) {
// For low bit-depth types we perform post Q @ K scaling. This is to avoid
// losing numerical precision due to the low dynamic range of fp8 types when
// pre-applying the scaling.
AffineMap sMap = b.getMultiDimIdentityMap(sSizes.size());
AffineMap scaleMap = AffineMap::get(/*dimCount=*/sMap.getNumInputs(),
/*symbolCount=*/0, getContext());
s = scaleValueInPlace(b, loc, sMap, scaleMap, s, scale);
s = elementwiseValueInPlace<arith::MulFOp>(b, loc, sMap, scaleMap, s,
scale);

// If we need to truncate to fp8 post-softmax, we apply a scaling to use the
// full fp8 range. We can do this with an offset, as post-`exp2` this equates
// to multiplying by a static value. We are able to do this because `max` and
// `sum` are scaled by the same value, so the end result is the same.
auto fpTy = cast<FloatType>(qETy);
double mx =
APFloat::getLargest(fpTy.getFloatSemantics(), /*Negative=*/false)
.convertToDouble();
Value offset = b.create<arith::ConstantOp>(
loc, b.getFloatAttr(elementType, clAttentionSoftmaxMax / mx));
s = elementwiseValueInPlace<arith::AddFOp>(b, loc, sMap, scaleMap, s,
offset);
}

// TODO: This decomposition should be in a separate op called
@@ -323,19 +328,20 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
AffineMap maxMap = getMaxMap();
Value newMax = reduce<arith::MaximumFOp>(b, loc, sMap, maxMap, s, oldMax);

// P = exp2(S - newMax)
// PMap = SMap
AffineMap pMap = sMap;
Value p = computeSubAndExp2(b, loc, maxMap, sMap, newMax, s);

// norm = exp2(oldMax - newMax)
// normMap = maxMap
AffineMap normMap = getMaxMap();
Value norm = computeSubAndExp2(b, loc, maxMap, normMap, newMax, oldMax);

// normSum = norm * oldSum
AffineMap sumMap = getSumMap();
Value normSum = scaleValueInPlace(b, loc, sumMap, normMap, oldSum, norm);
Value normSum = elementwiseValueInPlace<arith::MulFOp>(b, loc, sumMap,
normMap, oldSum, norm);

// P = exp2(S - newMax)
// PMap = SMap
AffineMap pMap = sMap;
Value p = computeSubAndExp2(b, loc, maxMap, sMap, newMax, s);

// newSum = normSum + rowSum(P)
Value newSum = reduce<arith::AddFOp>(b, loc, pMap, sumMap, p, normSum);
@@ -344,54 +350,19 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
AffineMap accMap = getOutputMap();

// ---- Scale and truncate LHS to match RHS ----
Value pScale;
auto pETy = getElementTypeOrSelf(p.getType());
if (pETy != vETy && isa<FloatType>(vETy)) {
if (vETy.getIntOrFloatBitWidth() <= 8) {
SmallVector<OpFoldResult> mSizes(
llvm::map_range(maxMap.getResults(), [&](AffineExpr dimExpr) {
return sizes[cast<AffineDimExpr>(dimExpr).getPosition()];
}));

auto fpTy = cast<FloatType>(vETy);
double largestDbl =
APFloat::getLargest(fpTy.getFloatSemantics(), /*Negative=*/false)
.convertToDouble();

// We normalize p from [0, max] to [0, fp8.max] to guarantee we
// use the full `fp8` range, then renormalize post Softmax@V matmul
// to correct.
pScale = b.create<arith::ConstantOp>(
loc, b.getFloatAttr(elementType, clAttentionSoftmaxMax / largestDbl));

// Compute the pre matmul scale to handle fp8 quantization:
Value pScaleInv = b.create<arith::ConstantOp>(
loc, b.getFloatAttr(elementType, largestDbl / clAttentionSoftmaxMax));

AffineMap scaleMap = AffineMap::get(/*dimCount=*/maxMap.getNumInputs(),
/*symbolCount=*/0, getContext());
p = scaleValueInPlace(b, loc, pMap, scaleMap, p, pScaleInv);
norm = scaleValueInPlace(b, loc, normMap, scaleMap, norm, pScaleInv);
}

Value convertP = b.create<tensor::EmptyOp>(loc, sSizes, vETy);
p = truncateFloat(b, loc, pMap, pMap, p, convertP);
}

Value newAcc = scaleValueInPlace(b, loc, accMap, normMap, oldAcc, norm);
Value newAcc = elementwiseValueInPlace<arith::MulFOp>(b, loc, accMap, normMap,
oldAcc, norm);

// ---- Matmul 2 ----

// newAcc = P @ V + newAcc
newAcc = computeMatmul(b, loc, pMap, getValueMap(), accMap, p, value, newAcc);

// Update for the FP8 dynamic scale:
if (pScale) {
AffineMap scaleMap = AffineMap::get(/*dimCount=*/maxMap.getNumInputs(),
/*symbolCount=*/0, getContext());
newAcc = scaleValueInPlace(b, loc, accMap, scaleMap, newAcc, pScale);
}

return SmallVector<Value>{newAcc, newMax, newSum};
}

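For reference, a scalar sketch of the update order the decomposition above now
emits (assumptions: one query row, a single V column, f64 accumulation, and no
fp8 truncation step; the names here are illustrative, not from the diff):

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// One online-attention step: scale S, add the fp8 linear offset, update the
// running max, renormalize the previous sum/accumulator, then exponentiate
// and accumulate. Mirrors the order in decomposeOperation; `offset` is 0
// outside the fp8 path.
void onlineAttentionStep(const std::vector<double> &s, // one row of Q @ K
                         const std::vector<double> &v, // matching V column
                         double scale, double offset, double &runningMax,
                         double &runningSum, double &acc) {
  std::vector<double> scaled(s.size());
  double newMax = runningMax;
  for (std::size_t i = 0; i < s.size(); ++i) {
    scaled[i] = s[i] * scale + offset; // S = S * scale + f8_linear_offset
    newMax = std::max(newMax, scaled[i]);
  }
  double norm = std::exp2(runningMax - newMax); // offset cancels here
  double newSum = runningSum * norm;            // normSum = norm * oldSum
  acc *= norm;                                  // newAcc = norm * oldAcc
  for (std::size_t i = 0; i < s.size(); ++i) {
    double p = std::exp2(scaled[i] - newMax); // P = exp2(S - newMax)
    newSum += p;                              // newSum = normSum + rowSum(P)
    acc += p * v[i];                          // newAcc = P @ V + newAcc
  }
  runningMax = newMax;
  runningSum = newSum;
}
```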
@@ -38,11 +38,6 @@ func.func @attention_f16(%query: tensor<192x1024x64xf16>,
// CHECK: linalg.generic
// CHECK: arith.maximumf
// CHECK: linalg.yield
// P = exp2(S - newMax)
// CHECK: linalg.generic
// CHECK: arith.subf
// CHECK: math.exp2
// CHECK: linalg.yield
// norm = exp2(oldMax - newMax)
// CHECK: linalg.generic
// CHECK: arith.subf
@@ -52,7 +47,12 @@ func.func @attention_f16(%query: tensor<192x1024x64xf16>,
// CHECK: linalg.generic
// CHECK: arith.mulf
// CHECK: linalg.yield
// newSum = normSum + rowMax(P)
// P = exp2(S - newMax)
// CHECK: linalg.generic
// CHECK: arith.subf
// CHECK: math.exp2
// CHECK: linalg.yield
// newSum = normSum + rowSum(P)
// CHECK: linalg.generic
// CHECK: arith.addf
// CHECK: linalg.yield
@@ -103,15 +103,15 @@ func.func @attention_f8(%query: tensor<192x1024x64xf8E4M3FNUZ>,
// S = S * scale
// CHECK: linalg.generic
// CHECK: arith.mulf
// CHECK-NEXT: linalg.yield
// S = S + F8_linear_offset
// CHECK: linalg.generic
// CHECK: arith.addf
// CHECK-NEXT: linalg.yield
// newMax = max(oldMax, rowMax(S))
// CHECK: linalg.generic
// CHECK: arith.maximumf
// CHECK: linalg.yield
// P = exp2(S - newMax)
// CHECK: linalg.generic
// CHECK: arith.subf
// CHECK: math.exp2
// CHECK: linalg.yield
// norm = exp2(oldMax - newMax)
// CHECK: linalg.generic
// CHECK: arith.subf
@@ -121,16 +121,18 @@ func.func @attention_f8(%query: tensor<192x1024x64xf8E4M3FNUZ>,
// CHECK: linalg.generic
// CHECK: arith.mulf
// CHECK: linalg.yield
// newSum = normSum + rowMax(P)
// P = exp2(S - newMax)
// CHECK: linalg.generic
// CHECK: arith.subf
// CHECK: math.exp2
// CHECK: linalg.yield
// newSum = normSum + rowSum(P)
// CHECK: linalg.generic
// CHECK: arith.addf
// CHECK: linalg.yield
// clamp = clamp(P)
// CHECK: linalg.generic
// CHECK: arith.cmpf ogt
// CHECK: arith.cmpf olt
// CHECK: arith.select
// CHECK: arith.select
// CHECK: arith.minimumf
// CHECK: arith.truncf
// newAcc = norm * oldAcc
// CHECK: linalg.generic
