GPU target parameters for data tiling. (iree-org#18839)
This replaces some constants that were hardcoded in
GPUMaterializeEncoding.cpp with actual GPU target parameters.

The logic in `getSwizzle` was doing wonky things with its own local
`const int targetPreferredLoadBitWidth = 128;`, using it in a helper
function that inferred interleaving dimensions. That all dated back to the
early days: it was effectively trying to infer which innermost
dimensions to skip to get at the first non-Internal dimension, which is
one more thing that we can fix now that we have
`TileSwizzle::Dim::Kind`. See `getInnermostNonInternalDimIdx`.
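
For the record, the gist of `getInnermostNonInternalDimIdx` is just to scan the
swizzle dimensions from the innermost outwards and stop at the first one whose
kind is not Internal. A minimal sketch of the idea (assuming a `kind` field on
`TileSwizzle::Dim`; the actual helper in this patch may differ in naming and
signature):

```cpp
// Sketch only: index of the innermost dimension that is not Kind::Internal,
// i.e. the first dimension that interleaving should not skip.
// Assumes `dims` is ordered from outermost to innermost.
static int64_t getInnermostNonInternalDimIdx(ArrayRef<TileSwizzle::Dim> dims) {
  for (int64_t i = static_cast<int64_t>(dims.size()) - 1; i >= 0; --i) {
    if (dims[i].kind != TileSwizzle::Dim::Kind::Internal)
      return i;
  }
  return 0; // All dimensions are Internal; fall back to the outermost one.
}
```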

The heuristic in `chooseDataTiledMMAAttr` becomes much more robust, and
is tested more extensively by `gpu_materialize_encoding.mlir`, now that we
can pass arbitrary parameters in ad-hoc `#iree_gpu.target` attributes;
see the test updates. It's unfortunately verbose (one screenful of MLIR
code for each test case) because each has to be a complete function with
`flow.dispatch` ops, but that's a separate problem.

---------

Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
bjacob authored Oct 21, 2024
1 parent 114a142 commit c08362a
Showing 6 changed files with 569 additions and 65 deletions.
@@ -18,7 +18,7 @@
// GFX942-SAME: mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>],
// GFX942-SAME: subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
// GFX942-SAME: max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
// GFX942-SAME: max_workgroup_counts = [2147483647, 2147483647, 2147483647]>,
// GFX942-SAME: max_workgroup_counts = [2147483647, 2147483647, 2147483647],
// MI300X: chip = <wgp_count = 304>>
// MI300A: chip = <wgp_count = 228>>

@@ -4,6 +4,7 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <cfloat>
#include "iree/compiler/Codegen/Common/EncodingUtils.h"
#include "iree/compiler/Codegen/Common/GPU/Passes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
@@ -54,6 +55,9 @@ static std::optional<IREE::GPU::DataTiledMMAAttr>
chooseDataTiledMMAAttr(TypeRange eTypes, IREE::GPU::TargetAttr target,
IREE::Encoding::EncodingAttr encoding) {
using namespace IREE::GPU;
if (!target) {
return std::nullopt;
}
MLIRContext *ctx = target.getContext();

//
@@ -85,56 +89,118 @@ chooseDataTiledMMAAttr(TypeRange eTypes, IREE::GPU::TargetAttr target,
// Step 2: Select the unrolling factors for the generic case where there is no
// narrow dimension.
//
// These hardcoded constants should become functions querying `target`.
//
// Target ISA preferred load instruction size, in bits.
const int kLoadInstructionBits = 128;
// Target ISA preferred number of subgroups per block to get full utilization.
const int kNumSubgroups = 4;
// Number of register space bits to use for accumulators. Should typically be
// between 50% and 80% of total available register space, as the accumulator
// tends to be larger than the A and B matrix tiles.
const int kMaxAccumulatorRegisterBits = 256 * 32;
IREE::GPU::TargetWgpAttr wgp = target.getWgp();
if (!wgp.getMaxLoadInstructionBits() || !wgp.getVgprSpaceBits() ||
!wgp.getSimdsPerWgp()) {
// Missing workgroup parameters: data tiling not supported on this target.
return std::nullopt;
}

auto sizeInBits = [](VectorType type) -> int {
return type.getElementTypeBitWidth() * type.getNumElements();
};

MMAAttr intrinsicMma = MMAAttr::get(ctx, *intrinsic);
auto [intrinsicA, intrinsicB, intrinsicC] = intrinsicMma.getABCVectorTypes();
// The unrollK factor serves to allow loads from the A and B matrices to use
// the target ISA's vector loads. For instance, if the ISA has 128-bit loads
// and each intrinsic consumes only 32 bits from A and B, then we want to set
// unrollK=4 to turn 4 separate 32-bit loads into one 128-bit load.
const int unrollK =
kLoadInstructionBits /
std::min(
intrinsicA.getElementTypeBitWidth() * intrinsicA.getNumElements(),
intrinsicB.getElementTypeBitWidth() * intrinsicB.getNumElements());
int intrinsicLoadBits =
std::min(sizeInBits(intrinsicA), sizeInBits(intrinsicB));
if (*wgp.getMaxLoadInstructionBits() % intrinsicLoadBits != 0) {
// Never seen that case: the ISA does not have a suitable load instruction
// to feed that intrinsic?!
return std::nullopt;
}
const int unrollK = *wgp.getMaxLoadInstructionBits() / intrinsicLoadBits;
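// For illustration (values assumed here, not taken from this patch): with a
// 16x16x16 F16 intrinsic whose per-thread A and B tiles are 4xf16, we get
// intrinsicLoadBits = 64; if *wgp.getMaxLoadInstructionBits() is 128, then
// unrollK = 128 / 64 = 2.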

// The total amount of unrolling along the M and N dimensions is normally
// limited only by the number of available registers, since larger M and N
// yields higher arithmetic intensity. Here, we do not yet distinguish between
// plain unrolling (more instructions on each thread) and
// unrolling-to-subgroups (more threads).
const int totalUnrollMN =
kMaxAccumulatorRegisterBits /
(intrinsicC.getElementTypeBitWidth() * intrinsicC.getNumElements());
const int totalUnrollM = static_cast<int>(
std::floor(std::sqrt(static_cast<float>(totalUnrollMN))));
const int totalUnrollN = totalUnrollMN / totalUnrollM;
// unrolling-to-subgroups (more threads), since expanding to more subgroups
// correspondingly divides the available register space between this many
// subgroups, making it cancel out of the equation here.
//
// We need to solve for two variables here, unroll_m and unroll_n, constrained
// by one quadratic equation expressing that the A, B and C tiles must fit in
// VGPR space. Since we have only 1 constraint for two variables, we
// self-impose a second constraint for now: that the unrolling shape should be
// square, i.e. unrollM == unrollN.
// TODO(#18850): that is suboptimal for narrow cases.
//
// Now we have only one variable, call it x, to solve for.

// The register space taken is:
// A-tile: x * unrollK * sizeInBits(intrinsicA)
// B-tile: x * unrollK * sizeInBits(intrinsicB)
// C-tile: x^2 * sizeInBits(intrinsicC)
// So the equation to solve is:
// x^2 * sizeInBits(intrinsicC)
// + x * unrollK * (sizeInBits(intrinsicA) + sizeInBits(intrinsicB))
// == wgp.getVgprSpaceBits()
float c2 = sizeInBits(intrinsicC);
float c1 = unrollK * (sizeInBits(intrinsicA) + sizeInBits(intrinsicB));
float c0 = -*wgp.getVgprSpaceBits(); // negative by construction.
// Now the equation to solve is: c2 * x^2 + c1 * x + c0 == 0.
float discriminant = c1 * c1 - 4 * c0 * c2; // positive, because c0 < 0.
// x = unique positive solution.
float x = (-c1 + std::sqrt(discriminant)) / (2 * c2);

#ifndef NDEBUG
// Self-check the quadratic solver. 10 epsilon is just a crude upper bound;
// in practice, cancellation results in check == 0 in current cases.
float check = c2 * x * x + c1 * x + c0;
assert(std::abs(check) < 10 * FLT_EPSILON * std::abs(c0));
#endif
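
// Worked example with assumed values (illustration only, not from this patch):
// sizeInBits(intrinsicA) = sizeInBits(intrinsicB) = 64,
// sizeInBits(intrinsicC) = 128, unrollK = 2, *wgp.getVgprSpaceBits() = 16384.
// Then c2 = 128, c1 = 2 * (64 + 64) = 256, c0 = -16384,
// discriminant = 256 * 256 + 4 * 128 * 16384 = 8454144,
// and x = (-256 + sqrt(8454144)) / (2 * 128) ≈ 10.4.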

// Now, looking geometrically at our unrolling space along the M and N
// dimensions, we solve the following problem in the (M,N)-plane: approximate
// a square of side length `x`, by a rectangle of side lengths `totalUnrollM`
// and `totalUnrollN`, under the constraints:
// 1. totalUnrollM * totalUnrollN <= x * x
// * Reason: by construction of x, any larger area would exceed the
// wgp.getVgprSpaceBits() budget.
// 2. totalUnrollM and totalUnrollN are powers of 2.
// * Reason: that is a self-imposed constraint for now to avoid prematurely
// entering excessive fine-tuning of unrolling factors. Also, since below
// we will put all the unroll-to-subgroups in the N dimension, that
// requires totalUnrollN to be a multiple of wgp.getSimdsPerWgp(),
// which is typically a power of 2, specifically 4.
// TODO(#18851): we will not always put all the unroll-to-subgroups on N.
// 3. totalUnrollN >= totalUnrollM.
// * Reason: Just like the previous constraint, that is also motivated by
// the code below currently putting all the unroll-to-subgroups in the N
// dimension, which requires a sufficiently large totalUnrollN.
// TODO(#18851): we will not always put all the unroll-to-subgroups on N.
//
// Set totalUnrollN = x rounded to the nearest power of two, breaking ties
// away from 0 per the specification of std::round.
int totalUnrollN = std::exp2(std::round(std::log2(x)));
// Based on above constraint 1:
float unroundedMaxTotalUnrollM = x * x / totalUnrollN;
int totalUnrollM = std::exp2(std::floor(std::log2(unroundedMaxTotalUnrollM)));
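// Continuing the assumed example above: x ≈ 10.4 gives
// totalUnrollN = exp2(round(log2(10.4))) = exp2(3) = 8, then
// x * x / totalUnrollN ≈ 13.4, so totalUnrollM = exp2(floor(log2(13.4))) = 8.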

// Now we introduce unroll-to-subgroups. It doesn't change the overall tile
// size, as it increases the number of subgroups but correspondingly decreases
// the number of registers available to each subgroup. In other words, the
// overall tile size determined above only needed to be concerned with the
// overall number of registers, not with how they are split between subgroups.
//
// For now for simplicity we put all the unroll-to-subgroups in the N
// dimension. That might be suboptimal, revisit later. That does simplify the
// below adjustments for narrow M/N, as we don't need to think about
// unroll-to-subgroups when making the narrowing adjustment.
// dimension. TODO(#18851): revisit that.
//
// That does simplify the below adjustments for narrow M/N, as we don't need
// to think about unroll-to-subgroups when making the narrowing adjustment.
int unrollMToSubgroups = 1;
int unrollNToSubgroups = kNumSubgroups;
int unrollNToSubgroups = *wgp.getSimdsPerWgp();
int unrollM = totalUnrollM / unrollMToSubgroups;
int unrollN = totalUnrollN / unrollNToSubgroups;
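// Continuing the assumed example, with *wgp.getSimdsPerWgp() = 4:
// unrollMToSubgroups = 1 and unrollNToSubgroups = 4, so unrollM = 8 / 1 = 8
// and unrollN = 8 / 4 = 2; the overall 8x8 (M x N) unrolling is split across
// 4 subgroups along N.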

//
// Step 3: Adjust the unrolling factors when there is a narrow dimension.
// TODO(#18850): dealing with narrow cases as a fix-up is suboptimal.
//
IREE::Encoding::MatmulNarrowDim narrowDim =
IREE::Encoding::getMatmulNarrowDim(encoding);