diff --git a/scripts/amd/lit_tests/generate_accelerate_matmul_tests.py b/scripts/amd/lit_tests/generate_accelerate_matmul_tests.py
new file mode 100755
index 000000000000..edc66770f4fd
--- /dev/null
+++ b/scripts/amd/lit_tests/generate_accelerate_matmul_tests.py
@@ -0,0 +1,164 @@
+import argparse
+import sys
+
+# M N K a_ty b_ty c_ty
+configs = [[32, 32, 32, "f16", "f16", "f32"],
+           [32, 32, 32, "bf16", "bf16", "f32"],
+           [32, 32, 32, "f32", "f32", "f32"],
+           [32, 32, 32, "i8", "i8", "i32"],
+           [32, 32, 32, "f8E4M3FNUZ", "f8E4M3FNUZ", "f32"],
+           [32, 32, 32, "f8E4M3FNUZ", "f8E5M2FNUZ", "f32"],
+           [32, 32, 32, "f8E5M2FNUZ", "f8E4M3FNUZ", "f32"],
+           [32, 32, 32, "f8E5M2FNUZ", "f8E5M2FNUZ", "f32"],
+
+           [16, 16, 32, "f16", "f16", "f32"],
+           [16, 16, 32, "bf16", "bf16", "f32"],
+           [16, 16, 32, "f32", "f32", "f32"],
+           [16, 16, 32, "i8", "i8", "i32"],
+           [16, 16, 32, "f8E4M3FNUZ", "f8E4M3FNUZ", "f32"],
+           [16, 16, 32, "f8E4M3FNUZ", "f8E5M2FNUZ", "f32"],
+           [16, 16, 32, "f8E5M2FNUZ", "f8E4M3FNUZ", "f32"],
+           [16, 16, 32, "f8E5M2FNUZ", "f8E5M2FNUZ", "f32"],
+
+           [4, 4, 64, "f16", "f16", "f32"],
+           [4, 4, 64, "bf16", "bf16", "f32"],
+           [4, 4, 64, "f32", "f32", "f32"],
+           [4, 4, 64, "i8", "i8", "i32"],
+           [4, 4, 64, "f8E4M3FNUZ", "f8E4M3FNUZ", "f32"],
+           [4, 4, 64, "f8E4M3FNUZ", "f8E5M2FNUZ", "f32"],
+           [4, 4, 64, "f8E5M2FNUZ", "f8E4M3FNUZ", "f32"],
+           [4, 4, 64, "f8E5M2FNUZ", "f8E5M2FNUZ", "f32"],
+
+           [64, 4, 4, "f16", "f16", "f32"],
+           [64, 4, 4, "bf16", "bf16", "f32"],
+           [64, 4, 4, "f32", "f32", "f32"],
+           [64, 4, 4, "i8", "i8", "i32"],
+           [64, 4, 4, "f8E4M3FNUZ", "f8E4M3FNUZ", "f32"],
+           [64, 4, 4, "f8E4M3FNUZ", "f8E5M2FNUZ", "f32"],
+           [64, 4, 4, "f8E5M2FNUZ", "f8E4M3FNUZ", "f32"],
+           [64, 4, 4, "f8E5M2FNUZ", "f8E5M2FNUZ", "f32"],
+
+           [4, 64, 4, "f16", "f16", "f32"],
+           [4, 64, 4, "bf16", "bf16", "f32"],
+           [4, 64, 4, "f32", "f32", "f32"],
+           [4, 64, 4, "i8", "i8", "i32"],
+           [4, 64, 4, "f8E4M3FNUZ", "f8E4M3FNUZ", "f32"],
+           [4, 64, 4, "f8E4M3FNUZ", "f8E5M2FNUZ", "f32"],
+           [4, 64, 4, "f8E5M2FNUZ", "f8E4M3FNUZ", "f32"],
+           [4, 64, 4, "f8E5M2FNUZ", "f8E5M2FNUZ", "f32"]
+          ]
+
+def generate(cdna_version, output_file):
+    arch_names = {0: "", 1: "gfx908", 2: "gfx90a", 3: "gfx940"}
+    arch_name = arch_names[cdna_version]
+    print(f"// This file is generated: $ python3 {' '.join(sys.argv)}", file=output_file)
+    print(f"// RUN: (! triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul=arch-generation-name={arch_name} --mlir-pass-pipeline-crash-reproducer=%t 2>/dev/null) | FileCheck --check-prefixes=CHECK %s", file=output_file)
+
+    for cfg_id in range(len(configs)):
+        cfg = configs[cfg_id]
+
+        cfg_name = "_".join([str(item) for item in cfg])
+
+        M, N, K, a_ty, b_ty, c_ty = cfg
+        if "i" in c_ty:
+            cst_val = "0"
+        else:
+            cst_val = "0.000000e+00"
+
+        supported = True
+        if cdna_version < 3 and ("f8" in a_ty or "f8" in b_ty):
+            supported = False
+
+        # expected mfma instruction shape (instrShape) for the given tensor sizes
+        if M >= 32 and N >= 32:
+            m_dim = 32
+            n_dim = 32
+        elif M >= 16 and N >= 16:
+            m_dim = 16
+            n_dim = 16
+        elif M >= 64 and N < 16:
+            m_dim = 64
+            n_dim = 4
+        elif M < 16 and N >= 64:
+            m_dim = 4
+            n_dim = 64
+        elif M < 16 and N < 16:
+            m_dim = 4
+            n_dim = 4
+        if ("f8" in a_ty or "f8" in b_ty) and min(m_dim, n_dim) == 4:
+            supported = False
+
+        # expected kWidth of the mfma dot operand layout, per matrix core version and operand type
+        if cdna_version == 1:
+            if a_ty == "f16":
+                k_width = 4
+            if a_ty == "bf16":
+                k_width = 2
+            if a_ty == "i8":
+                k_width = 4
+            if a_ty == "f32":
+                k_width = 1
+        if cdna_version == 2:
+            if a_ty == "f16":
+                k_width = 4
+            if a_ty == "bf16":
+                k_width = 4
+            if a_ty == "i8":
+                k_width = 4
+            if a_ty == "f32":
+                k_width = 1
+        if cdna_version == 3:
+            if "f8" in a_ty:
+                k_width = 8
+            if a_ty == "f16":
+                k_width = 4
+            if a_ty == "bf16":
+                k_width = 4
+            if a_ty == "i8":
+                if min(m_dim, n_dim) == 4:
+                    k_width = 4
+                else:
+                    k_width = 8
+            if a_ty == "f32":
+                k_width = 1
+
+        if supported:
+            mfma_check = f"// CHECK: #mfma = #triton_gpu.mfma<{{version = {cdna_version}.0, warpsPerCTA = [1, 1], instrShape = [{m_dim}, {n_dim}], isTransposed = false}}>"
+            label_check = f"// CHECK: convert_dot_{cfg_name}"
+            checks = f"""// CHECK: triton_gpu.convert_layout {{{{.*}}}} : (tensor<{{{{.*}}}}, #blocked>) -> tensor<{{{{.*}}}}, #mfma>
+// CHECK: triton_gpu.convert_layout {{{{.*}}}} : (tensor<{{{{.*}}}}, #triton_gpu.dot_op<{{opIdx = 0, parent = #blocked}}>>) -> tensor<{{{{.*}}}}, #triton_gpu.dot_op<{{opIdx = 0, parent = #mfma, kWidth = {k_width}}}>>
+// CHECK: triton_gpu.convert_layout {{{{.*}}}} : (tensor<{{{{.*}}}}, #triton_gpu.dot_op<{{opIdx = 1, parent = #blocked}}>>) -> tensor<{{{{.*}}}}, #triton_gpu.dot_op<{{opIdx = 1, parent = #mfma, kWidth = {k_width}}}>>"""
+        else:
+            mfma_check = ""
+            label_check = f"// CHECK-NOT: convert_dot_{cfg_name}"
+            checks = ""
+
+        case_text = f'''
+!a_ty = {a_ty}
+!b_ty = {b_ty}
+!c_ty = {c_ty}
+#blocked = #triton_gpu.blocked<{{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}}>
+#dot_operand_a = #triton_gpu.dot_op<{{opIdx=0, parent=#blocked}}>
+#dot_operand_b = #triton_gpu.dot_op<{{opIdx=1, parent=#blocked}}>
+module attributes {{"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32}} {{
+{mfma_check}
+{label_check}
+  tt.func @convert_dot_{cfg_name}(%a: tensor<{M}x{K}x!a_ty, #dot_operand_a>, %b: tensor<{K}x{N}x!b_ty, #dot_operand_b>) -> tensor<{M}x{N}x!c_ty, #blocked> {{
+    %cst_c = arith.constant dense<{cst_val}> : tensor<{M}x{N}x!c_ty, #blocked>
+{checks}
+    %D = tt.dot %a, %b, %cst_c {{allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false}} : tensor<{M}x{K}x!a_ty, #dot_operand_a> * tensor<{K}x{N}x!b_ty, #dot_operand_b> -> tensor<{M}x{N}x!c_ty, #blocked>
+    tt.return %D: tensor<{M}x{N}x!c_ty, #blocked>
+  }}
+}}
+
+'''
+        if cfg_id == len(configs) - 1:
+            print(case_text, end="", file=output_file)
+        else:
+            print(case_text, end="// -----\n", file=output_file)
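+
+# Example invocation (the output path below is only illustrative, not part of this patch):
+#   python3 generate_accelerate_matmul_tests.py 3 <generated_lit_test>.mlir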
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("cdna_version", type=int)
+    parser.add_argument("output_file", type=str)
+    args = parser.parse_args()
+    with open(args.output_file, "w") as f:
+        generate(cdna_version=args.cdna_version, output_file=f)
diff --git a/scripts/amd/lit_tests/generate_mfma_variants.py b/scripts/amd/lit_tests/generate_mfma_variants.py
new file mode 100755
index 000000000000..00e09327b036
--- /dev/null
+++ b/scripts/amd/lit_tests/generate_mfma_variants.py
@@ -0,0 +1,135 @@
+import argparse
+import sys
+
+# matrix core version, mfma instruction name
+configs = [(3, "mfma_f32_32x32x16_fp8_fp8"),
+           (3, "mfma_f32_32x32x16_fp8_bf8"),
+           (3, "mfma_f32_32x32x16_bf8_fp8"),
+           (3, "mfma_f32_32x32x16_bf8_bf8"),
+           (2, "mfma_f32_32x32x8f16"),
+           (1, "mfma_f32_32x32x4bf16"),
+           (2, "mfma_f32_32x32x8bf16_1k"),
+           (2, "mfma_f32_32x32x2f32"),
+           (2, "mfma_i32_32x32x8i8"),
+           (3, "mfma_i32_32x32x16_i8"),
+           (3, "mfma_f32_16x16x32_fp8_fp8"),
+           (3, "mfma_f32_16x16x32_fp8_bf8"),
+           (3, "mfma_f32_16x16x32_bf8_fp8"),
+           (3, "mfma_f32_16x16x32_bf8_bf8"),
+           (2, "mfma_f32_16x16x16f16"),
+           (1, "mfma_f32_16x16x8bf16"),
+           (2, "mfma_f32_16x16x16bf16_1k"),
+           (2, "mfma_f32_16x16x4f32"),
+           (2, "mfma_i32_16x16x16i8"),
+           (3, "mfma_i32_16x16x32_i8"),
+           (2, "mfma_f32_4x4x4f16"),
+           (1, "mfma_f32_4x4x2bf16"),
+           (2, "mfma_f32_4x4x4bf16_1k"),
+           (2, "mfma_f32_4x4x1f32"),
+           (2, "mfma_i32_4x4x4i8")]
+
+def generate(output_file):
+    print(f'// This file is generated: $ python3 {" ".join(sys.argv)}', file=output_file)
+    print('// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm="target=rocdl" 2>/dev/null | FileCheck --check-prefixes=CHECK,GCN %s', file=output_file)
+
+    for cfg_id in range(len(configs)):
+        matrix_core_version = configs[cfg_id][0]
+        cfg = configs[cfg_id][1]
+        parts = cfg.split("_")
+        if parts[-1] == "1k":
+            parts = parts[:-1]
+        shape = parts[2].split("x")
+        for ty in ["bf8", "f8", "bf16", "f16", "f32", "i8"]:
+            if ty in shape[-1]:
+                shape[-1] = shape[-1][:-len(ty)]
+                parts += [ty]
+                break
+        for i in range(len(shape)):
+            shape[i] = int(shape[i])
+
+        # shape
+        non_k_dim = shape[0]
+        k_width = shape[2]
+        if non_k_dim == 32:
+            k_width //= 2
+        if non_k_dim == 16:
+            k_width //= 4
+
+        # types
+        b_ty = parts[-1]
+        if b_ty in ["fp8", "bf8"]:
+            a_ty = parts[-2]
+        else:
+            a_ty = b_ty
+        c_ty = parts[1]
+
+        mlir_type_names = {
+            "fp8": "f8E4M3FNUZ",
+            "bf8": "f8E5M2FNUZ",
+            "f16": "f16",
+            "bf16": "bf16",
+            "f32": "f32",
+            "i8": "i8",
+            "i32": "i32"}
+        a_ty = mlir_type_names[a_ty]
+        b_ty = mlir_type_names[b_ty]
+        c_ty = mlir_type_names[c_ty]
+
+        # misc
+        if "i" in c_ty:
+            cst_val = "0"
+        else:
+            cst_val = "0.000000e+00"
+
+        # repeats
+        if non_k_dim == 32:
+            M = 128
+            N = 32
+            K = 256
+        if non_k_dim == 16:
+            M = 128
+            N = 32
+            K = 256
+        if non_k_dim == 4:
+            M = 128
+            N = 32
+            K = 256
+
+        num_subgroups = 1
+        if non_k_dim == 4:
+            num_subgroups = 16
+        # number of mfma intrinsics expected for a single tt.dot; used by the GCN-COUNT check below
+        num_reps = (M // non_k_dim) * (N // non_k_dim) * (K // (shape[2] * num_subgroups))
+
+        # mlir operation name
+        mlir_op_name = "rocdl." + cfg.replace("_", ".")
+        case_text = f'''
+!a_ty = {a_ty}
+!b_ty = {b_ty}
+!c_ty = {c_ty}
+#k_width = {k_width}
+#non_k_dim = {non_k_dim}
+#mfmaVersion = {matrix_core_version}
+#mfma = #triton_gpu.mfma<{{versionMajor = #mfmaVersion, warpsPerCTA=[1,1], instrShape = [#non_k_dim, #non_k_dim], isTranspose=false}}>
+#dot_operand_a = #triton_gpu.dot_op<{{opIdx=0, parent=#mfma, kWidth = #k_width}}>
+#dot_operand_b = #triton_gpu.dot_op<{{opIdx=1, parent=#mfma, kWidth = #k_width}}>
+module attributes {{"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32}} {{
+  // CHECK-LABEL: convert_dot_{cfg}
+  tt.func @convert_dot_{cfg}(%a: tensor<{M}x{K}x!a_ty, #dot_operand_a>, %b: tensor<{K}x{N}x!b_ty, #dot_operand_b>) {{
+    %cst_c = arith.constant dense<{cst_val}> : tensor<{M}x{N}x!c_ty, #mfma>
+    // GCN-COUNT-{num_reps}: {mlir_op_name}
+    %D = tt.dot %a, %b, %cst_c {{allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false}} : tensor<{M}x{K}x!a_ty, #dot_operand_a> * tensor<{K}x{N}x!b_ty, #dot_operand_b> -> tensor<{M}x{N}x!c_ty, #mfma>
+    tt.return
+  }}
+}}
+
+'''
+        if cfg_id == len(configs) - 1:
+            print(case_text, end="", file=output_file)
+        else:
+            print(case_text, end="// -----\n", file=output_file)
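+
+# Example invocation (this matches the invocation recorded in the header of the generated file):
+#   python3 generate_mfma_variants.py ../../../test/Conversion/AMDGPU/mfma_variants.mlir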
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("output_file", type=str)
+    args = parser.parse_args()
+    with open(args.output_file, "w") as f:
+        generate(output_file=f)
diff --git a/test/Conversion/AMDGPU/mfma_variants.mlir b/test/Conversion/AMDGPU/mfma_variants.mlir
index e40fcbfb9ccd..46e777a3d194 100644
--- a/test/Conversion/AMDGPU/mfma_variants.mlir
+++ b/test/Conversion/AMDGPU/mfma_variants.mlir
@@ -1,3 +1,4 @@
+// This file is generated: $ python3 generate_mfma_variants.py ../../../test/Conversion/AMDGPU/mfma_variants.mlir
 // RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm="target=rocdl" 2>/dev/null | FileCheck --check-prefixes=CHECK,GCN %s
 
 !a_ty = f8E4M3FNUZ
@@ -90,7 +91,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 #k_width = 4
 #non_k_dim = 32
 #mfmaVersion = 2
-#mfma = #triton_gpu.mfma<{versionMajor = #mfmaVersion , warpsPerCTA=[1,1], instrShape = [#non_k_dim, #non_k_dim], isTranspose=false}>
+#mfma = #triton_gpu.mfma<{versionMajor = #mfmaVersion, warpsPerCTA=[1,1], instrShape = [#non_k_dim, #non_k_dim], isTranspose=false}>
 #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -111,7 +112,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 #k_width = 2
 #non_k_dim = 32
 #mfmaVersion = 1
-#mfma = #triton_gpu.mfma<{versionMajor = 1, warpsPerCTA=[1,1], instrShape = [#non_k_dim, #non_k_dim], isTranspose=false}>
+#mfma = #triton_gpu.mfma<{versionMajor = #mfmaVersion, warpsPerCTA=[1,1], instrShape = [#non_k_dim, #non_k_dim], isTranspose=false}>
 #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -132,7 +133,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 #k_width = 4
 #non_k_dim = 32
 #mfmaVersion = 2
-#mfma = #triton_gpu.mfma<{versionMajor = #mfmaVersion , warpsPerCTA=[1,1], instrShape = [#non_k_dim, #non_k_dim], isTranspose=false}>
+#mfma = #triton_gpu.mfma<{versionMajor = #mfmaVersion, warpsPerCTA=[1,1], instrShape = [#non_k_dim, #non_k_dim], isTranspose=false}>
 #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -153,7 +154,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 #k_width = 1
 #non_k_dim = 32
 #mfmaVersion = 2
-#mfma = #triton_gpu.mfma<{versionMajor = #mfmaVersion , warpsPerCTA=[1,1], instrShape = [#non_k_dim, #non_k_dim], isTranspose=false}>
+#mfma = #triton_gpu.mfma<{versionMajor = #mfmaVersion, warpsPerCTA=[1,1], instrShape = [#non_k_dim, #non_k_dim], isTranspose=false}>
 #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
@@ -174,7 +175,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 #k_width = 4
 #non_k_dim = 32
 #mfmaVersion = 2
-#mfma = #triton_gpu.mfma<{versionMajor = #mfmaVersion , warpsPerCTA=[1,1], instrShape = [#non_k_dim, #non_k_dim], isTranspose=false}>
+#mfma = #triton_gpu.mfma<{versionMajor = #mfmaVersion, warpsPerCTA=[1,1], instrShape = [#non_k_dim, #non_k_dim], isTranspose=false}>
 #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mfma, kWidth = #k_width}>
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mfma, kWidth = #k_width}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {