[MFMA][Test] Add scripts generating mfma related lit tests
This PR adds scripts for lit test generation.
Showing 7 changed files with 329 additions and 5 deletions.
# LIT test generator scripts

### generate_accelerate_matmul_tests.py

This script generates CDNA-related tests for the AccelerateAMDMatmul pass.
There are three generations of the CDNA architecture, so three invocations are needed to generate all tests:

```bash
python3 generate_accelerate_matmul_tests.py 1 ../../../test/TritonGPU/accelerate-matmul-cdna1.mlir
python3 generate_accelerate_matmul_tests.py 2 ../../../test/TritonGPU/accelerate-matmul-cdna2.mlir
python3 generate_accelerate_matmul_tests.py 3 ../../../test/TritonGPU/accelerate-matmul-cdna3.mlir
```
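The three invocations differ only in the generation argument, so a small driver loop can regenerate all of them at once. A hypothetical convenience sketch (not part of this PR; assumes it is run from this directory):

```python
import subprocess

# Hypothetical helper: regenerate one test file per CDNA generation.
for version in (1, 2, 3):
    subprocess.run(
        ["python3", "generate_accelerate_matmul_tests.py", str(version),
         f"../../../test/TritonGPU/accelerate-matmul-cdna{version}.mlir"],
        check=True,
    )
```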
### generate_mfma_variants.py

This script generates CDNA-related tests for the TritonGPU to LLVM conversion:

```bash
python3 generate_mfma_variants.py ../../../test/Conversion/AMDGPU/mfma_variants.mlir
```
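Both generators overwrite their output file and stamp the generating command on the first line. A hypothetical freshness check (the helper `is_up_to_date` and the temp-file approach are illustrative, not part of this PR) could regenerate into a temporary file and compare everything after that stamp:

```python
import subprocess
import tempfile

def is_up_to_date(script, args, generated_file):
    # Regenerate into a temporary file and compare with the checked-in file.
    # The first line is skipped: it embeds the full generating command,
    # including the output path, and so differs between runs.
    with tempfile.NamedTemporaryFile(mode="w+", suffix=".mlir") as tmp:
        subprocess.run(["python3", script, *args, tmp.name], check=True)
        tmp.seek(0)
        with open(generated_file) as f:
            return f.readlines()[1:] == tmp.readlines()[1:]

print(is_up_to_date("generate_mfma_variants.py", [],
                    "../../../test/Conversion/AMDGPU/mfma_variants.mlir"))
```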
scripts/amd/lit_tests/generate_accelerate_matmul_tests.py (164 additions, 0 deletions)
```python
import argparse
import sys

# M, N, K, a_ty, b_ty, c_ty
configs = [[32, 32, 32, "f16", "f16", "f32"],
           [32, 32, 32, "bf16", "bf16", "f32"],
           [32, 32, 32, "f32", "f32", "f32"],
           [32, 32, 32, "i8", "i8", "i32"],
           [32, 32, 32, "f8E4M3FNUZ", "f8E4M3FNUZ", "f32"],
           [32, 32, 32, "f8E4M3FNUZ", "f8E5M2FNUZ", "f32"],
           [32, 32, 32, "f8E5M2FNUZ", "f8E4M3FNUZ", "f32"],
           [32, 32, 32, "f8E5M2FNUZ", "f8E5M2FNUZ", "f32"],

           [16, 16, 32, "f16", "f16", "f32"],
           [16, 16, 32, "bf16", "bf16", "f32"],
           [16, 16, 32, "f32", "f32", "f32"],
           [16, 16, 32, "i8", "i8", "i32"],
           [16, 16, 32, "f8E4M3FNUZ", "f8E4M3FNUZ", "f32"],
           [16, 16, 32, "f8E4M3FNUZ", "f8E5M2FNUZ", "f32"],
           [16, 16, 32, "f8E5M2FNUZ", "f8E4M3FNUZ", "f32"],
           [16, 16, 32, "f8E5M2FNUZ", "f8E5M2FNUZ", "f32"],

           [4, 4, 64, "f16", "f16", "f32"],
           [4, 4, 64, "bf16", "bf16", "f32"],
           [4, 4, 64, "f32", "f32", "f32"],
           [4, 4, 64, "i8", "i8", "i32"],
           [4, 4, 64, "f8E4M3FNUZ", "f8E4M3FNUZ", "f32"],
           [4, 4, 64, "f8E4M3FNUZ", "f8E5M2FNUZ", "f32"],
           [4, 4, 64, "f8E5M2FNUZ", "f8E4M3FNUZ", "f32"],
           [4, 4, 64, "f8E5M2FNUZ", "f8E5M2FNUZ", "f32"],

           [64, 4, 4, "f16", "f16", "f32"],
           [64, 4, 4, "bf16", "bf16", "f32"],
           [64, 4, 4, "f32", "f32", "f32"],
           [64, 4, 4, "i8", "i8", "i32"],
           [64, 4, 4, "f8E4M3FNUZ", "f8E4M3FNUZ", "f32"],
           [64, 4, 4, "f8E4M3FNUZ", "f8E5M2FNUZ", "f32"],
           [64, 4, 4, "f8E5M2FNUZ", "f8E4M3FNUZ", "f32"],
           [64, 4, 4, "f8E5M2FNUZ", "f8E5M2FNUZ", "f32"],

           [4, 64, 4, "f16", "f16", "f32"],
           [4, 64, 4, "bf16", "bf16", "f32"],
           [4, 64, 4, "f32", "f32", "f32"],
           [4, 64, 4, "i8", "i8", "i32"],
           [4, 64, 4, "f8E4M3FNUZ", "f8E4M3FNUZ", "f32"],
           [4, 64, 4, "f8E4M3FNUZ", "f8E5M2FNUZ", "f32"],
           [4, 64, 4, "f8E5M2FNUZ", "f8E4M3FNUZ", "f32"],
           [4, 64, 4, "f8E5M2FNUZ", "f8E5M2FNUZ", "f32"]
           ]


def generate(cdna_version, output_file):
    arch_names = {0: "", 1: "gfx908", 2: "gfx90a", 3: "gfx940"}
    arch_name = arch_names[cdna_version]
    print(f"// This file is generated: $ python3 {' '.join(sys.argv)}", file=output_file)
    print(f"// RUN: (! triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul=arch-generation-name={arch_name} --mlir-pass-pipeline-crash-reproducer=%t 2>/dev/null) | FileCheck --check-prefixes=CHECK %s", file=output_file)

    for cfg_id in range(len(configs)):
        cfg = configs[cfg_id]

        cfg_name = "_".join([str(item) for item in cfg])

        M, N, K, a_ty, b_ty, c_ty = cfg
        if "i" in c_ty:
            cst_val = "0"
        else:
            cst_val = "0.000000e+00"

        # fp8 operands are only supported starting from CDNA3.
        supported = True
        if cdna_version < 3 and ("f8" in a_ty or "f8" in b_ty):
            supported = False

        # Pick the mfma instruction shape expected for the given tile size.
        if M >= 32 and N >= 32:
            m_dim = 32
            n_dim = 32
        elif M >= 16 and N >= 16:
            m_dim = 16
            n_dim = 16
        elif M >= 64 and N < 16:
            m_dim = 64
            n_dim = 4
        elif M < 16 and N >= 64:
            m_dim = 4
            n_dim = 64
        elif M < 16 and N < 16:
            m_dim = 4
            n_dim = 4
        if ("f8" in a_ty or "f8" in b_ty) and min(m_dim, n_dim) == 4:
            supported = False

        if cdna_version == 1:
            if a_ty == "f16":
                k_width = 4
            if a_ty == "bf16":
                k_width = 2
            if a_ty == "i8":
                k_width = 4
            if a_ty == "f32":
                k_width = 1
        if cdna_version == 2:
            if a_ty == "f16":
                k_width = 4
            if a_ty == "bf16":
                k_width = 4
            if a_ty == "i8":
                k_width = 4
            if a_ty == "f32":
                k_width = 1
        if cdna_version == 3:
            if "f8" in a_ty:
                k_width = 8
            if a_ty == "f16":
                k_width = 4
            if a_ty == "bf16":
                k_width = 4
            if a_ty == "i8":
                if min(m_dim, n_dim) == 4:
                    k_width = 4
                else:
                    k_width = 8
            if a_ty == "f32":
                k_width = 1
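        # Note (assumption): kWidth is the number of K-dimension elements each
        # thread holds for one dot operand; the values above mirror the operand
        # layouts the pass is expected to select per generation and data type.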

        if supported:
            mfma_check = f"// CHECK: #mfma = #triton_gpu.mfma<{{versionMajor = {cdna_version}, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [{m_dim}, {n_dim}], isTransposed = false}}>"
            label_check = f"// CHECK: convert_dot_{cfg_name}"
            checks = f"""// CHECK: triton_gpu.convert_layout {{{{.*}}}} : (tensor<{{{{.*}}}}, #blocked>) -> tensor<{{{{.*}}}}, #mfma>
// CHECK: triton_gpu.convert_layout {{{{.*}}}} : (tensor<{{{{.*}}}}, #triton_gpu.dot_op<{{opIdx = 0, parent = #blocked}}>>) -> tensor<{{{{.*}}}}, #triton_gpu.dot_op<{{opIdx = 0, parent = #mfma, kWidth = {k_width}}}>>
// CHECK: triton_gpu.convert_layout {{{{.*}}}} : (tensor<{{{{.*}}}}, #triton_gpu.dot_op<{{opIdx = 1, parent = #blocked}}>>) -> tensor<{{{{.*}}}}, #triton_gpu.dot_op<{{opIdx = 1, parent = #mfma, kWidth = {k_width}}}>>"""
        else:
            mfma_check = ""
            label_check = f"// CHECK-NOT: convert_dot_{cfg_name}"
            checks = ""

        case_text = f'''
!a_ty = {a_ty}
!b_ty = {b_ty}
!c_ty = {c_ty}
#blocked = #triton_gpu.blocked<{{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}}>
#dot_operand_a = #triton_gpu.dot_op<{{opIdx=0, parent=#blocked}}>
#dot_operand_b = #triton_gpu.dot_op<{{opIdx=1, parent=#blocked}}>
module attributes {{"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32}} {{
{mfma_check}
{label_check}
  tt.func @convert_dot_{cfg_name}(%a: tensor<{M}x{K}x!a_ty, #dot_operand_a>, %b: tensor<{K}x{N}x!b_ty, #dot_operand_b>) -> tensor<{M}x{N}x!c_ty, #blocked> {{
    %cst_c = arith.constant dense<{cst_val}> : tensor<{M}x{N}x!c_ty, #blocked>
{checks}
    %D = tt.dot %a, %b, %cst_c {{allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false}} : tensor<{M}x{K}x!a_ty, #dot_operand_a> * tensor<{K}x{N}x!b_ty, #dot_operand_b> -> tensor<{M}x{N}x!c_ty, #blocked>
    tt.return %D: tensor<{M}x{N}x!c_ty, #blocked>
  }}
}}
'''
        # Separate test cases with the lit split marker, except after the last.
        if cfg_id == len(configs) - 1:
            print(case_text, end="", file=output_file)
        else:
            print(case_text, end="// -----\n", file=output_file)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("cdna_version", type=int)
    parser.add_argument("output_file", type=str)
    args = parser.parse_args()
    with open(args.output_file, "w") as f:
        generate(cdna_version=args.cdna_version, output_file=f)
```
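For a quick smoke test, the generator can also be driven in-process. A minimal sketch (assuming the script is importable as `generate_accelerate_matmul_tests`):

```python
import io
from generate_accelerate_matmul_tests import generate

# Render the CDNA2 variant in memory; line 0 is the "generated by" stamp,
# line 1 is the lit RUN line, which should target gfx90a.
buf = io.StringIO()
generate(cdna_version=2, output_file=buf)
assert "arch-generation-name=gfx90a" in buf.getvalue().splitlines()[1]
```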
scripts/amd/lit_tests/generate_mfma_variants.py (135 additions, 0 deletions)
```python
import argparse
import sys

# (matrix core version, mfma instruction name)
configs = [(3, "mfma_f32_32x32x16_fp8_fp8"),
           (3, "mfma_f32_32x32x16_fp8_bf8"),
           (3, "mfma_f32_32x32x16_bf8_fp8"),
           (3, "mfma_f32_32x32x16_bf8_bf8"),
           (2, "mfma_f32_32x32x8f16"),
           (1, "mfma_f32_32x32x4bf16"),
           (2, "mfma_f32_32x32x8bf16_1k"),
           (2, "mfma_f32_32x32x2f32"),
           (2, "mfma_i32_32x32x8i8"),
           (3, "mfma_i32_32x32x16_i8"),
           (3, "mfma_f32_16x16x32_fp8_fp8"),
           (3, "mfma_f32_16x16x32_fp8_bf8"),
           (3, "mfma_f32_16x16x32_bf8_fp8"),
           (3, "mfma_f32_16x16x32_bf8_bf8"),
           (2, "mfma_f32_16x16x16f16"),
           (1, "mfma_f32_16x16x8bf16"),
           (2, "mfma_f32_16x16x16bf16_1k"),
           (2, "mfma_f32_16x16x4f32"),
           (2, "mfma_i32_16x16x16i8"),
           (3, "mfma_i32_16x16x32_i8"),
           (2, "mfma_f32_4x4x4f16"),
           (1, "mfma_f32_4x4x2bf16"),
           (2, "mfma_f32_4x4x4bf16_1k"),
           (2, "mfma_f32_4x4x1f32"),
           (2, "mfma_i32_4x4x4i8")]

def generate(output_file):
    print(f'// This file is generated: $ python3 {" ".join(sys.argv)}', file=output_file)
    print('// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm="target=rocdl" 2>/dev/null | FileCheck --check-prefixes=CHECK,GCN %s', file=output_file)

    for cfg_id in range(len(configs)):
        matrix_core_version = configs[cfg_id][0]
        cfg = configs[cfg_id][1]

        # Parse the instruction name into shape and operand types.
        parts = cfg.split("_")
        if parts[-1] == "1k":
            parts = parts[:-1]
        shape = parts[2].split("x")
        for ty in ["bf8", "f8", "bf16", "f16", "f32", "i8"]:
            if ty in shape[-1]:
                shape[-1] = shape[-1][:-len(ty)]
                parts += [ty]
                break
        for i in range(len(shape)):
            shape[i] = int(shape[i])

        # shape
        non_k_dim = shape[0]
        k_width = shape[2]
        if non_k_dim == 32:
            k_width //= 2
        if non_k_dim == 16:
            k_width //= 4

        # types
        b_ty = parts[-1]
        if b_ty in ["fp8", "bf8"]:
            a_ty = parts[-2]
        else:
            a_ty = b_ty
        c_ty = parts[1]

        mlir_type_names = {
            "fp8": "f8E4M3FNUZ",
            "bf8": "f8E5M2FNUZ",
            "f16": "f16",
            "bf16": "bf16",
            "f32": "f32",
            "i8": "i8",
            "i32": "i32"}
        a_ty = mlir_type_names[a_ty]
        b_ty = mlir_type_names[b_ty]
        c_ty = mlir_type_names[c_ty]

        # misc
        if "i" in c_ty:
            cst_val = "0"
        else:
            cst_val = "0.000000e+00"

        # repeats (tensor sizes are currently identical for all shapes)
        if non_k_dim == 32:
            M = 128
            N = 32
            K = 256
        if non_k_dim == 16:
            M = 128
            N = 32
            K = 256
        if non_k_dim == 4:
            M = 128
            N = 32
            K = 256

        num_subgroups = 1
        if non_k_dim == 4:
            num_subgroups = 16
        num_reps = (M // non_k_dim) * (N // non_k_dim) * (K // (shape[2] * num_subgroups))
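        # Note (assumption): num_reps is the expected number of mfma
        # instructions in the lowered IR, one per M-tile x N-tile x K-step;
        # the 4x4 variants pack 16 independent 4x4 blocks into a single
        # instruction, hence num_subgroups.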

        # mlir operation name
        mlir_op_name = "rocdl." + cfg.replace("_", ".")

        case_text = f'''
!a_ty = {a_ty}
!b_ty = {b_ty}
!c_ty = {c_ty}
#k_width = {k_width}
#non_k_dim = {non_k_dim}
#mfmaVersion = {matrix_core_version}
#mfma = #triton_gpu.mfma<{{versionMajor = #mfmaVersion, warpsPerCTA=[1,1], instrShape = [#non_k_dim, #non_k_dim], isTransposed=false}}>
#dot_operand_a = #triton_gpu.dot_op<{{opIdx=0, parent=#mfma, kWidth = #k_width}}>
#dot_operand_b = #triton_gpu.dot_op<{{opIdx=1, parent=#mfma, kWidth = #k_width}}>
module attributes {{"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32}} {{
// CHECK-LABEL: convert_dot_{cfg}
  tt.func @convert_dot_{cfg}(%a: tensor<{M}x{K}x!a_ty, #dot_operand_a>, %b: tensor<{K}x{N}x!b_ty, #dot_operand_b>) {{
    %cst_c = arith.constant dense<{cst_val}> : tensor<{M}x{N}x!c_ty, #mfma>
// GCN-COUNT-{num_reps}: {mlir_op_name}
    %D = tt.dot %a, %b, %cst_c {{allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false}} : tensor<{M}x{K}x!a_ty, #dot_operand_a> * tensor<{K}x{N}x!b_ty, #dot_operand_b> -> tensor<{M}x{N}x!c_ty, #mfma>
    tt.return
  }}
}}
'''
        # Separate test cases with the lit split marker, except after the last.
        if cfg_id == len(configs) - 1:
            print(case_text, end="", file=output_file)
        else:
            print(case_text, end="// -----\n", file=output_file)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("output_file", type=str)
    args = parser.parse_args()
    with open(args.output_file, "w") as f:
        generate(output_file=f)
```
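As with the first script, the output can be checked in-process. A small sketch (assuming the script is importable as `generate_mfma_variants`) verifying the expected instruction count for one variant:

```python
import io
from generate_mfma_variants import generate

buf = io.StringIO()
generate(output_file=buf)
# mfma_f32_32x32x8f16 on a 128x32x256 dot should lower to
# (128/32) * (32/32) * (256/8) = 128 rocdl ops.
assert "// GCN-COUNT-128: rocdl.mfma.f32.32x32x8f16" in buf.getvalue()
```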