[MFMA][Test] Add scripts generating mfma related lit tests
This PR adds two scripts that generate MFMA-related lit tests, and regenerates test/Conversion/AMDGPU/mfma_variants.mlir with them.
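Each script records its own invocation in the header of the file it generates, so the tests can be reproduced exactly. A typical regeneration, run from scripts/amd/lit_tests (the accelerate-matmul output path below is illustrative, not from this commit):

    $ python3 generate_mfma_variants.py ../../../test/Conversion/AMDGPU/mfma_variants.mlir
    $ python3 generate_accelerate_matmul_tests.py 3 accelerate-matmul-gfx940.mlir

The first positional argument of generate_accelerate_matmul_tests.py selects the CDNA generation (1 = gfx908, 2 = gfx90a, 3 = gfx940).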
binarman committed Jan 18, 2024
1 parent e703321 commit e6b907a
Showing 3 changed files with 305 additions and 5 deletions.
164 changes: 164 additions & 0 deletions scripts/amd/lit_tests/generate_accelerate_matmul_tests.py
@@ -0,0 +1,164 @@
import argparse
import sys

# M   N   K   a_ty          b_ty          c_ty
configs = [[32, 32, 32, "f16", "f16", "f32"],
           [32, 32, 32, "bf16", "bf16", "f32"],
           [32, 32, 32, "f32", "f32", "f32"],
           [32, 32, 32, "i8", "i8", "i32"],
           [32, 32, 32, "f8E4M3FNUZ", "f8E4M3FNUZ", "f32"],
           [32, 32, 32, "f8E4M3FNUZ", "f8E5M2FNUZ", "f32"],
           [32, 32, 32, "f8E5M2FNUZ", "f8E4M3FNUZ", "f32"],
           [32, 32, 32, "f8E5M2FNUZ", "f8E5M2FNUZ", "f32"],

           [16, 16, 32, "f16", "f16", "f32"],
           [16, 16, 32, "bf16", "bf16", "f32"],
           [16, 16, 32, "f32", "f32", "f32"],
           [16, 16, 32, "i8", "i8", "i32"],
           [16, 16, 32, "f8E4M3FNUZ", "f8E4M3FNUZ", "f32"],
           [16, 16, 32, "f8E4M3FNUZ", "f8E5M2FNUZ", "f32"],
           [16, 16, 32, "f8E5M2FNUZ", "f8E4M3FNUZ", "f32"],
           [16, 16, 32, "f8E5M2FNUZ", "f8E5M2FNUZ", "f32"],

           [4, 4, 64, "f16", "f16", "f32"],
           [4, 4, 64, "bf16", "bf16", "f32"],
           [4, 4, 64, "f32", "f32", "f32"],
           [4, 4, 64, "i8", "i8", "i32"],
           [4, 4, 64, "f8E4M3FNUZ", "f8E4M3FNUZ", "f32"],
           [4, 4, 64, "f8E4M3FNUZ", "f8E5M2FNUZ", "f32"],
           [4, 4, 64, "f8E5M2FNUZ", "f8E4M3FNUZ", "f32"],
           [4, 4, 64, "f8E5M2FNUZ", "f8E5M2FNUZ", "f32"],

           [64, 4, 4, "f16", "f16", "f32"],
           [64, 4, 4, "bf16", "bf16", "f32"],
           [64, 4, 4, "f32", "f32", "f32"],
           [64, 4, 4, "i8", "i8", "i32"],
           [64, 4, 4, "f8E4M3FNUZ", "f8E4M3FNUZ", "f32"],
           [64, 4, 4, "f8E4M3FNUZ", "f8E5M2FNUZ", "f32"],
           [64, 4, 4, "f8E5M2FNUZ", "f8E4M3FNUZ", "f32"],
           [64, 4, 4, "f8E5M2FNUZ", "f8E5M2FNUZ", "f32"],

           [4, 64, 4, "f16", "f16", "f32"],
           [4, 64, 4, "bf16", "bf16", "f32"],
           [4, 64, 4, "f32", "f32", "f32"],
           [4, 64, 4, "i8", "i8", "i32"],
           [4, 64, 4, "f8E4M3FNUZ", "f8E4M3FNUZ", "f32"],
           [4, 64, 4, "f8E4M3FNUZ", "f8E5M2FNUZ", "f32"],
           [4, 64, 4, "f8E5M2FNUZ", "f8E4M3FNUZ", "f32"],
           [4, 64, 4, "f8E5M2FNUZ", "f8E5M2FNUZ", "f32"]]
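# Each row above expands to one lit case; the fields also form the test name,
# e.g. [32, 32, 32, "f16", "f16", "f32"] -> @convert_dot_32_32_32_f16_f16_f32.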

def generate(cdna_version, output_file):
    arch_names = {0: "", 1: "gfx908", 2: "gfx90a", 3: "gfx940"}
    arch_name = arch_names[cdna_version]
    print(f"// This file is generated: $ python3 {' '.join(sys.argv)}", file=output_file)
    # Unsupported cases make the pass emit errors, so triton-opt is expected to
    # fail overall; the RUN line inverts its exit code with (! ...).
    print(f"// RUN: (! triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul=arch-generation-name={arch_name} --mlir-pass-pipeline-crash-reproducer=%t 2>/dev/null) | FileCheck --check-prefixes=CHECK %s", file=output_file)

    for cfg_id in range(len(configs)):
        cfg = configs[cfg_id]

        cfg_name = "_".join([str(item) for item in cfg])

        M, N, K, a_ty, b_ty, c_ty = cfg
        if "i" in c_ty:
            cst_val = "0"
        else:
            cst_val = "0.000000e+00"

        # fp8 dot operands are only supported on CDNA3 (gfx940).
        supported = True
        if cdna_version < 3 and ("f8" in a_ty or "f8" in b_ty):
            supported = False

        # Pick the largest MFMA instruction shape that fits the output tile.
        if M >= 32 and N >= 32:
            m_dim = 32
            n_dim = 32
        elif M >= 16 and N >= 16:
            m_dim = 16
            n_dim = 16
        elif M >= 64 and N < 16:
            m_dim = 64
            n_dim = 4
        elif M < 16 and N >= 64:
            m_dim = 4
            n_dim = 64
        elif M < 16 and N < 16:
            m_dim = 4
            n_dim = 4
        # fp8 is not available in the 4-wide instruction variants.
        if ("f8" in a_ty or "f8" in b_ty) and min(m_dim, n_dim) == 4:
            supported = False

        # kWidth: number of contiguous K elements each thread holds per operand.
        if cdna_version == 1:
            if a_ty == "f16":
                k_width = 4
            if a_ty == "bf16":
                k_width = 2
            if a_ty == "i8":
                k_width = 4
            if a_ty == "f32":
                k_width = 1
        if cdna_version == 2:
            if a_ty == "f16":
                k_width = 4
            if a_ty == "bf16":
                k_width = 4
            if a_ty == "i8":
                k_width = 4
            if a_ty == "f32":
                k_width = 1
        if cdna_version == 3:
            if "f8" in a_ty:
                k_width = 8
            if a_ty == "f16":
                k_width = 4
            if a_ty == "bf16":
                k_width = 4
            if a_ty == "i8":
                if min(m_dim, n_dim) == 4:
                    k_width = 4
                else:
                    k_width = 8
            if a_ty == "f32":
                k_width = 1

        if supported:
            mfma_check = f"// CHECK: #mfma = #triton_gpu.mfma<{{version = {cdna_version}.0, warpsPerCTA = [1, 1], instrShape = [{m_dim}, {n_dim}], isTransposed = false}}>"
            label_check = f"// CHECK: convert_dot_{cfg_name}"
            checks = f"""// CHECK: triton_gpu.convert_layout {{{{.*}}}} : (tensor<{{{{.*}}}}, #blocked>) -> tensor<{{{{.*}}}}, #mfma>
// CHECK: triton_gpu.convert_layout {{{{.*}}}} : (tensor<{{{{.*}}}}, #triton_gpu.dot_op<{{opIdx = 0, parent = #blocked}}>>) -> tensor<{{{{.*}}}}, #triton_gpu.dot_op<{{opIdx = 0, parent = #mfma, kWidth = {k_width}}}>>
// CHECK: triton_gpu.convert_layout {{{{.*}}}} : (tensor<{{{{.*}}}}, #triton_gpu.dot_op<{{opIdx = 1, parent = #blocked}}>>) -> tensor<{{{{.*}}}}, #triton_gpu.dot_op<{{opIdx = 1, parent = #mfma, kWidth = {k_width}}}>>"""
        else:
            mfma_check = ""
            label_check = f"// CHECK-NOT: convert_dot_{cfg_name}"
            checks = ""

        case_text = f'''
!a_ty = {a_ty}
!b_ty = {b_ty}
!c_ty = {c_ty}
#blocked = #triton_gpu.blocked<{{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}}>
#dot_operand_a = #triton_gpu.dot_op<{{opIdx=0, parent=#blocked}}>
#dot_operand_b = #triton_gpu.dot_op<{{opIdx=1, parent=#blocked}}>
module attributes {{"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32}} {{
{mfma_check}
{label_check}
tt.func @convert_dot_{cfg_name}(%a: tensor<{M}x{K}x!a_ty, #dot_operand_a>, %b: tensor<{K}x{N}x!b_ty, #dot_operand_b>) -> tensor<{M}x{N}x!c_ty, #blocked> {{
%cst_c = arith.constant dense<{cst_val}> : tensor<{M}x{N}x!c_ty, #blocked>
{checks}
%D = tt.dot %a, %b, %cst_c {{allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false}} : tensor<{M}x{K}x!a_ty, #dot_operand_a> * tensor<{K}x{N}x!b_ty, #dot_operand_b> -> tensor<{M}x{N}x!c_ty, #blocked>
tt.return %D: tensor<{M}x{N}x!c_ty, #blocked>
}}
}}
'''
        # "// -----" separates cases for -split-input-file.
        if cfg_id == len(configs) - 1:
            print(case_text, end="", file=output_file)
        else:
            print(case_text, end="// -----\n", file=output_file)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("cdna_version", type=int)
    parser.add_argument("output_file", type=str)
    args = parser.parse_args()
    with open(args.output_file, "w") as f:
        generate(cdna_version=args.cdna_version, output_file=f)
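A worked example of the selection logic above (a sketch mirroring generate(), not part of the script):

# For cdna_version = 3 and config [32, 32, 32, "f8E4M3FNUZ", "f8E4M3FNUZ", "f32"]:
#   M >= 32 and N >= 32    -> instrShape = [32, 32]
#   "f8" operand on CDNA3  -> k_width = 8
# so the case gets positive CHECK lines with kWidth = 8; with cdna_version < 3
# the same config is unsupported and only "// CHECK-NOT: convert_dot_..." is emitted.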
135 changes: 135 additions & 0 deletions scripts/amd/lit_tests/generate_mfma_variants.py
@@ -0,0 +1,135 @@
import argparse
import sys

# matrix core version, MFMA instruction name
configs = [(3, "mfma_f32_32x32x16_fp8_fp8"),
           (3, "mfma_f32_32x32x16_fp8_bf8"),
           (3, "mfma_f32_32x32x16_bf8_fp8"),
           (3, "mfma_f32_32x32x16_bf8_bf8"),
           (2, "mfma_f32_32x32x8f16"),
           (1, "mfma_f32_32x32x4bf16"),
           (2, "mfma_f32_32x32x8bf16_1k"),
           (2, "mfma_f32_32x32x2f32"),
           (2, "mfma_i32_32x32x8i8"),
           (3, "mfma_i32_32x32x16_i8"),
           (3, "mfma_f32_16x16x32_fp8_fp8"),
           (3, "mfma_f32_16x16x32_fp8_bf8"),
           (3, "mfma_f32_16x16x32_bf8_fp8"),
           (3, "mfma_f32_16x16x32_bf8_bf8"),
           (2, "mfma_f32_16x16x16f16"),
           (1, "mfma_f32_16x16x8bf16"),
           (2, "mfma_f32_16x16x16bf16_1k"),
           (2, "mfma_f32_16x16x4f32"),
           (2, "mfma_i32_16x16x16i8"),
           (3, "mfma_i32_16x16x32_i8"),
           (2, "mfma_f32_4x4x4f16"),
           (1, "mfma_f32_4x4x2bf16"),
           (2, "mfma_f32_4x4x4bf16_1k"),
           (2, "mfma_f32_4x4x1f32"),
           (2, "mfma_i32_4x4x4i8")]
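# Example decomposition performed in generate() below (illustration only):
#   "mfma_f32_32x32x8f16"       -> shape [32, 32, 8],  a_ty = b_ty = f16, c_ty = f32
#   "mfma_f32_32x32x16_fp8_bf8" -> shape [32, 32, 16], a_ty = fp8, b_ty = bf8
# A trailing "_1k" (gfx90a bf16 variants) is stripped before parsing.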

def generate(output_file):
    print(f'// This file is generated: $ python3 {" ".join(sys.argv)}', file=output_file)
    print('// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm="target=rocdl" 2>/dev/null | FileCheck --check-prefixes=CHECK,GCN %s', file=output_file)

    for cfg_id in range(len(configs)):
        matrix_core_version = configs[cfg_id][0]
        cfg = configs[cfg_id][1]

        # Split the instruction name: drop a trailing "_1k", then peel the
        # element type off the shape component, e.g.
        # "mfma_f32_32x32x8f16" -> ["mfma", "f32", "32x32x8", "f16"].
        parts = cfg.split("_")
        if parts[-1] == "1k":
            parts = parts[:-1]
        shape = parts[2].split("x")
        for ty in ["bf8", "f8", "bf16", "f16", "f32", "i8"]:
            if ty in shape[-1]:
                shape[-1] = shape[-1][:-len(ty)]
                parts += [ty]
                break
        for i in range(len(shape)):
            shape[i] = int(shape[i])

        # shape
        non_k_dim = shape[0]
        k_width = shape[2]
        if non_k_dim == 32:
            k_width //= 2
        if non_k_dim == 16:
            k_width //= 4

        # types
        b_ty = parts[-1]
        if b_ty in ["fp8", "bf8"]:
            a_ty = parts[-2]
        else:
            a_ty = b_ty
        c_ty = parts[1]

        mlir_type_names = {
            "fp8": "f8E4M3FNUZ",
            "bf8": "f8E5M2FNUZ",
            "f16": "f16",
            "bf16": "bf16",
            "f32": "f32",
            "i8": "i8",
            "i32": "i32"}
        a_ty = mlir_type_names[a_ty]
        b_ty = mlir_type_names[b_ty]
        c_ty = mlir_type_names[c_ty]

        # misc
        if "i" in c_ty:
            cst_val = "0"
        else:
            cst_val = "0.000000e+00"

        # repeats: the same tensor sizes are used for every instruction shape
        M = 128
        N = 32
        K = 256

        num_subgroups = 1
        if non_k_dim == 4:
            # 4x4 MFMA variants process 16 independent subtiles per instruction.
            num_subgroups = 16
        num_reps = (M // non_k_dim) * (N // non_k_dim) * (K // (shape[2] * num_subgroups))

        # MLIR operation name, e.g. "rocdl.mfma.f32.32x32x8f16"
        mlir_op_name = "rocdl." + cfg.replace("_", ".")
        case_text = f'''
!a_ty = {a_ty}
!b_ty = {b_ty}
!c_ty = {c_ty}
#k_width = {k_width}
#non_k_dim = {non_k_dim}
#mfmaVersion = {matrix_core_version}
#mfma = #triton_gpu.mfma<{{versionMajor = #mfmaVersion, warpsPerCTA=[1,1], instrShape = [#non_k_dim, #non_k_dim], isTranspose=false}}>
#dot_operand_a = #triton_gpu.dot_op<{{opIdx=0, parent=#mfma, kWidth = #k_width}}>
#dot_operand_b = #triton_gpu.dot_op<{{opIdx=1, parent=#mfma, kWidth = #k_width}}>
module attributes {{"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32}} {{
// CHECK-LABEL: convert_dot_{cfg}
tt.func @convert_dot_{cfg}(%a: tensor<{M}x{K}x!a_ty, #dot_operand_a>, %b: tensor<{K}x{N}x!b_ty, #dot_operand_b>) {{
%cst_c = arith.constant dense<{cst_val}> : tensor<{M}x{N}x!c_ty, #mfma>
// GCN-COUNT-{num_reps}: {mlir_op_name}
%D = tt.dot %a, %b, %cst_c {{allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false}} : tensor<{M}x{K}x!a_ty, #dot_operand_a> * tensor<{K}x{N}x!b_ty, #dot_operand_b> -> tensor<{M}x{N}x!c_ty, #mfma>
tt.return
}}
}}
'''
        if cfg_id == len(configs) - 1:
            print(case_text, end="", file=output_file)
        else:
            print(case_text, end="// -----\n", file=output_file)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("output_file", type=str)
    args = parser.parse_args()
    with open(args.output_file, "w") as f:
        generate(output_file=f)
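A quick sanity check of the GCN-COUNT arithmetic (a sketch reproducing the num_reps formula above; not part of the script):

# mfma_f32_32x32x8f16: non_k_dim = 32, shape[2] = 8, num_subgroups = 1
M, N, K = 128, 32, 256
num_reps = (M // 32) * (N // 32) * (K // (8 * 1))  # 4 * 1 * 32
assert num_reps == 128  # the generated test expects GCN-COUNT-128: rocdl.mfma.f32.32x32x8f16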
11 changes: 6 additions & 5 deletions test/Conversion/AMDGPU/mfma_variants.mlir
@@ -1,3 +1,4 @@
+// This file is generated: $ python3 generate_mfma_variants.py ../../../test/Conversion/AMDGPU/mfma_variants.mlir
// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm="target=rocdl" 2>/dev/null | FileCheck --check-prefixes=CHECK,GCN %s

!a_ty = f8E4M3FNUZ
@@ -90,7 +91,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
#k_width = 4
#non_k_dim = 32
#mfmaVersion = 2
-#mfma = #triton_gpu.mfma<{versionMajor = #mfmaVersion , warpsPerCTA=[1,1], instrShape = [#non_k_dim, #non_k_dim], isTranspose=false}>
+#mfma = #triton_gpu.mfma<{versionMajor = #mfmaVersion, warpsPerCTA=[1,1], instrShape = [#non_k_dim, #non_k_dim], isTranspose=false}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -111,7 +112,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
#k_width = 2
#non_k_dim = 32
#mfmaVersion = 1
-#mfma = #triton_gpu.mfma<{versionMajor = 1, warpsPerCTA=[1,1], instrShape = [#non_k_dim, #non_k_dim], isTranspose=false}>
+#mfma = #triton_gpu.mfma<{versionMajor = #mfmaVersion, warpsPerCTA=[1,1], instrShape = [#non_k_dim, #non_k_dim], isTranspose=false}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -132,7 +133,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
#k_width = 4
#non_k_dim = 32
#mfmaVersion = 2
-#mfma = #triton_gpu.mfma<{versionMajor = #mfmaVersion , warpsPerCTA=[1,1], instrShape = [#non_k_dim, #non_k_dim], isTranspose=false}>
+#mfma = #triton_gpu.mfma<{versionMajor = #mfmaVersion, warpsPerCTA=[1,1], instrShape = [#non_k_dim, #non_k_dim], isTranspose=false}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -153,7 +154,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
#k_width = 1
#non_k_dim = 32
#mfmaVersion = 2
-#mfma = #triton_gpu.mfma<{versionMajor = #mfmaVersion , warpsPerCTA=[1,1], instrShape = [#non_k_dim, #non_k_dim], isTranspose=false}>
+#mfma = #triton_gpu.mfma<{versionMajor = #mfmaVersion, warpsPerCTA=[1,1], instrShape = [#non_k_dim, #non_k_dim], isTranspose=false}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -174,7 +175,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
#k_width = 4
#non_k_dim = 32
#mfmaVersion = 2
-#mfma = #triton_gpu.mfma<{versionMajor = #mfmaVersion , warpsPerCTA=[1,1], instrShape = [#non_k_dim, #non_k_dim], isTranspose=false}>
+#mfma = #triton_gpu.mfma<{versionMajor = #mfmaVersion, warpsPerCTA=[1,1], instrShape = [#non_k_dim, #non_k_dim], isTranspose=false}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {