From 7a6f821cfb76be1041677fdd14bae073cab69ac5 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Mon, 7 Oct 2024 19:01:50 -0400 Subject: [PATCH 1/5] Fix accumulator and result types --- gemmbench/gemm_utils.py | 82 +++++++++++-------- gemmbench/problems.py | 169 ++++++++++++++++++++++++++++++++-------- 2 files changed, 186 insertions(+), 65 deletions(-) diff --git a/gemmbench/gemm_utils.py b/gemmbench/gemm_utils.py index 03c4228..30aff7f 100644 --- a/gemmbench/gemm_utils.py +++ b/gemmbench/gemm_utils.py @@ -15,10 +15,12 @@ class GemmConfig: K: int tA: str tB: str - dtype: str + operand_element_type: str + accumulator_element_type: str + result_element_type: str def get_name(self) -> str: - name = f"gemm_{self.M}_{self.N}_{self.K}_{self.dtype}" + name = f"gemm_{self.M}_{self.N}_{self.K}_{self.operand_element_type}_{self.accumulator_element_type}" if self.tA == "T": name += "_tA" elif self.tB == "T": @@ -27,20 +29,20 @@ def get_name(self) -> str: def get_inp1(self) -> str: if self.tA == "T": - inp1 = f"{self.K}x{self.M}x{self.dtype}" + inp1 = f"{self.K}x{self.M}x{self.operand_element_type}" else: - inp1 = f"{self.M}x{self.K}x{self.dtype}" + inp1 = f"{self.M}x{self.K}x{self.operand_element_type}" return inp1 def get_inp2(self) -> str: if self.tB == "T": - inp2 = f"{self.N}x{self.K}x{self.dtype}" + inp2 = f"{self.N}x{self.K}x{self.operand_element_type}" else: - inp2 = f"{self.K}x{self.N}x{self.dtype}" + inp2 = f"{self.K}x{self.N}x{self.operand_element_type}" return inp2 def get_byte_count(self) -> int: - dtype_bits_map = { + operand_element_type_bits_map = { "f32": 32, "f16": 16, "bf16": 16, @@ -48,7 +50,7 @@ def get_byte_count(self) -> int: "i8": 8, "i32": 32, } - bytes_per_element = dtype_bits_map[self.dtype] // 8 + bytes_per_element = operand_element_type_bits_map[self.operand_element_type] // 8 element_count = self.M * self.K + self.N * self.K + self.M * self.N byte_count = element_count * bytes_per_element return byte_count @@ -61,40 +63,52 @@ def generate_mlir(config: GemmConfig): K = config.K M = config.M N = config.N - dtype = config.dtype + operand_element_type = config.operand_element_type + acc_element_type = config.accumulator_element_type + result_element_type = config.result_element_type tA = config.tA tB = config.tB mlir_template_A = f""" module {{ - func.func @main(%arg0: tensor<{K}x{M}x{dtype}>, %arg1: tensor<{K}x{N}x{dtype}>) -> tensor<{M}x{N}x{dtype}> {{ - %cst = arith.constant 0.000000e+00 : {dtype} - %0 = tensor.empty() : tensor<{M}x{N}x{dtype}> - %1 = linalg.fill ins(%cst : {dtype}) outs(%0 : tensor<{M}x{N}x{dtype}>) -> tensor<{M}x{N}x{dtype}> - %2 = linalg.matmul_transpose_a ins(%arg0, %arg1 : tensor<{K}x{M}x{dtype}>, tensor<{K}x{N}x{dtype}>) outs(%1 : tensor<{M}x{N}x{dtype}>) -> tensor<{M}x{N}x{dtype}> - return %2 : tensor<{M}x{N}x{dtype}> + func.func @main(%arg0: tensor<{K}x{M}x{operand_element_type}>, %arg1: tensor<{K}x{N}x{operand_element_type}>) -> tensor<{M}x{N}x{result_element_type}> {{ + %cst = arith.constant 0.000000e+00 : {acc_element_type} + %0 = tensor.empty() : tensor<{M}x{N}x{acc_element_type}> + %1 = linalg.fill ins(%cst : {acc_element_type}) outs(%0 : tensor<{M}x{N}x{acc_element_type}>) -> tensor<{M}x{N}x{acc_element_type}> + %2 = linalg.matmul_transpose_a ins(%arg0, %arg1 : tensor<{K}x{M}x{operand_element_type}>, tensor<{K}x{N}x{operand_element_type}>) + outs(%1 : tensor<{M}x{N}x{acc_element_type}>) + -> tensor<{M}x{N}x{acc_element_type}> + %3 = arith.truncf %2 : tensor<{M}x{N}x{acc_element_type}> -> tensor<{M}x{N}x{result_element_type}> + return %3 : 
tensor<{M}x{N}x{result_element_type}> }} }} """ mlir_template_B = f""" module {{ - func.func @main(%arg0: tensor<{M}x{K}x{dtype}>, %arg1: tensor<{N}x{K}x{dtype}>) -> tensor<{M}x{N}x{dtype}> {{ - %cst = arith.constant 0.000000e+00 : {dtype} - %0 = tensor.empty() : tensor<{M}x{N}x{dtype}> - %1 = linalg.fill ins(%cst : {dtype}) outs(%0 : tensor<{M}x{N}x{dtype}>) -> tensor<{M}x{N}x{dtype}> - %2 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<{M}x{K}x{dtype}>, tensor<{N}x{K}x{dtype}>) outs(%1 : tensor<{M}x{N}x{dtype}>) -> tensor<{M}x{N}x{dtype}> - return %2 : tensor<{M}x{N}x{dtype}> + func.func @main(%arg0: tensor<{M}x{K}x{operand_element_type}>, %arg1: tensor<{N}x{K}x{operand_element_type}>) -> tensor<{M}x{N}x{result_element_type}> {{ + %cst = arith.constant 0.000000e+00 : {acc_element_type} + %0 = tensor.empty() : tensor<{M}x{N}x{acc_element_type}> + %1 = linalg.fill ins(%cst : {acc_element_type}) outs(%0 : tensor<{M}x{N}x{acc_element_type}>) -> tensor<{M}x{N}x{acc_element_type}> + %2 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<{M}x{K}x{operand_element_type}>, tensor<{N}x{K}x{operand_element_type}>) + outs(%1 : tensor<{M}x{N}x{acc_element_type}>) + -> tensor<{M}x{N}x{acc_element_type}> + %3 = arith.truncf %2 : tensor<{M}x{N}x{acc_element_type}> -> tensor<{M}x{N}x{result_element_type}> + return %3 : tensor<{M}x{N}x{result_element_type}> }} }} """ - mlir_template = f"""module {{ - func.func @main(%arg0: tensor<{M}x{K}x{dtype}>, %arg1: tensor<{K}x{N}x{dtype}>) -> tensor<{M}x{N}x{dtype}> {{ - %cst = arith.constant 0.000000e+00 : {dtype} - %0 = tensor.empty() : tensor<{M}x{N}x{dtype}> - %1 = linalg.fill ins(%cst : {dtype}) outs(%0 : tensor<{M}x{N}x{dtype}>) -> tensor<{M}x{N}x{dtype}> - %2 = linalg.matmul ins(%arg0, %arg1 : tensor<{M}x{K}x{dtype}>, tensor<{K}x{N}x{dtype}>) outs(%1 : tensor<{M}x{N}x{dtype}>) -> tensor<{M}x{N}x{dtype}> - return %2 : tensor<{M}x{N}x{dtype}> + mlir_template = f""" +module {{ + func.func @main(%arg0: tensor<{M}x{K}x{operand_element_type}>, %arg1: tensor<{K}x{N}x{operand_element_type}>) -> tensor<{M}x{N}x{result_element_type}> {{ + %cst = arith.constant 0.000000e+00 : {acc_element_type} + %0 = tensor.empty() : tensor<{M}x{N}x{acc_element_type}> + %1 = linalg.fill ins(%cst : {acc_element_type}) outs(%0 : tensor<{M}x{N}x{acc_element_type}>) -> tensor<{M}x{N}x{acc_element_type}> + %2 = linalg.matmul ins(%arg0, %arg1 : tensor<{M}x{K}x{operand_element_type}>, tensor<{K}x{N}x{operand_element_type}>) + outs(%1 : tensor<{M}x{N}x{operand_element_type}>) + -> tensor<{M}x{N}x{acc_element_type}> + %3 = arith.truncf %2 : tensor<{M}x{N}x{acc_element_type}> -> tensor<{M}x{N}x{result_element_type}> + return %3 : tensor<{M}x{N}x{result_element_type}> }} }} """ @@ -158,12 +172,12 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # repeat represents the results of the loop tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD) - + shape = [config.M, config.N, config.K] - dtype_map = { + operand_element_type_map = { "f16": torch.float16, } - dtype = dtype_map[config.dtype] + operand_element_type = operand_element_type_map[config.operand_element_type] hyperparams = { ADDRESS_SPACE: SHARED_ADDRESS_SPACE, @@ -180,9 +194,9 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: with tk.gen.TestLaunchContext( hyperparams, canonicalize=True, run=True, run_config=config ): - a = torch.randn(shape[0], shape[2], dtype=dtype) - b = torch.randn(shape[1], shape[2], dtype=dtype) - c = torch.zeros(shape[0], shape[1], 
dtype=torch.float32) + a = torch.randn(shape[0], shape[2], operand_element_type=operand_element_type) + b = torch.randn(shape[1], shape[2], operand_element_type=operand_element_type) + c = torch.zeros(shape[0], shape[1], operand_element_type=torch.float32) mb = gemm(a, b, c) return mb.module_op.get_asm() diff --git a/gemmbench/problems.py b/gemmbench/problems.py index c0cc8c5..e7fd018 100644 --- a/gemmbench/problems.py +++ b/gemmbench/problems.py @@ -9,11 +9,27 @@ import re -def is_compute_bound(M, N, K, bpe): +def num_bytes(dtype: str) -> int: + return {"f16": 2, "bf16": 2, "f32": 4, "i8": 1, "i32": 4}[dtype] + + +def get_default_accumulator_element_type(operand_element_type: str) -> str: + return {"f16": "f32", "bf16": "f32", "f32": "f32", "i8": "i32", "i32": "i32"}[ + operand_element_type + ] + + +def get_default_result_element_type(operand_element_type: str) -> str: + return operand_element_type + + +def is_compute_bound(M: int, N: int, K: int, dtype: str) -> bool: """Is this GEMM compute (or memory) bound?""" magic_ratio = 64 flops = 2 * M * N * K - bytes = bpe * (M * K + K * N + M * N) + elem_type_bytes = num_bytes(dtype) + result_bytes = num_bytes(get_default_result_element_type(dtype)) + bytes = elem_type_bytes * (M * K + K * N) + result_bytes * (M * N) return flops > magic_ratio * bytes @@ -654,19 +670,24 @@ def is_compute_bound(M, N, K, bpe): (4096, 5120, 640), ] + def llama13bmatvec(dtype: str) -> list[GemmConfig]: configs = [] """LLAMA 13b, single batch, FP16.""" for m, n, k, model, gcount in LLAMA: if n == 1 and model == "13b": - configs.append(GemmConfig( - m, - n, - k, - "T", - "N", - dtype - )) + configs.append( + GemmConfig( + m, + n, + k, + "T", + "N", + dtype, + get_default_accumulator_element_type(dtype), + get_default_result_element_type(dtype), + ) + ) return configs @@ -681,7 +702,9 @@ def llama13bmatvecbf16(dtype: str) -> list[GemmConfig]: k, "T", "N", - dtype + dtype, + get_default_accumulator_element_type(dtype), + get_default_result_element_type(dtype), )) return configs @@ -697,7 +720,9 @@ def llama70bmatvec(dtype: str) -> list[GemmConfig]: k, "T", "N", - dtype + dtype, + get_default_accumulator_element_type(dtype), + get_default_result_element_type(dtype), )) return configs @@ -713,7 +738,9 @@ def llama70bmatvecbf16(dtype: str) -> list[GemmConfig]: k, "T", "N", - dtype + dtype, + get_default_accumulator_element_type(dtype), + get_default_result_element_type(dtype), )) return configs @@ -730,7 +757,9 @@ def llama13bskinny(dtype: str) -> list[GemmConfig]: k, "T", "N", - dtype + dtype, + get_default_accumulator_element_type(dtype), + get_default_result_element_type(dtype), )) return configs @@ -747,7 +776,9 @@ def llama13bskinnybf16(dtype: str) -> list[GemmConfig]: k, "T", "N", - dtype + dtype, + get_default_accumulator_element_type(dtype), + get_default_result_element_type(dtype), )) return configs @@ -764,7 +795,9 @@ def llama70bskinny(dtype: str) -> list[GemmConfig]: k, "T", "N", - dtype + dtype, + get_default_accumulator_element_type(dtype), + get_default_result_element_type(dtype), )) return configs @@ -781,7 +814,9 @@ def llama70bskinnybf16(dtype: str) -> list[GemmConfig]: k, "T", "N", - dtype + dtype, + get_default_accumulator_element_type(dtype), + get_default_result_element_type(dtype), )) return configs @@ -790,8 +825,17 @@ def gpt4memory(dtype: str) -> list[GemmConfig]: """GPT4 memory bound GEMMs; FP16.""" configs = [] for m, n, k in GPT4: - hgemm = GemmConfig(m, n, k, "N", "N", dtype) - if not is_compute_bound(m, n, k, 2): + hgemm = GemmConfig( + m, + 
n, + k, + "N", + "N", + dtype, + get_default_accumulator_element_type(dtype), + get_default_result_element_type(dtype), + ) + if not is_compute_bound(m, n, k, dtype): configs.append(hgemm) return configs @@ -800,28 +844,51 @@ def gpt4compute(dtype: str) -> list[GemmConfig]: """GPT4 compute bound GEMMs; FP16.""" configs = [] for m, n, k in GPT4: - hgemm = GemmConfig(m, n, k, "N", "N", dtype) - if is_compute_bound(m, n, k, 2): + hgemm = GemmConfig( + m, + n, + k, + "N", + "N", + dtype, + get_default_accumulator_element_type(dtype), + get_default_result_element_type(dtype), + ) + if is_compute_bound(m, n, k, dtype): configs.append(hgemm) return configs def tk_default(dtype: str) -> list[GemmConfig]: """TK Shapes.""" + acc_type = get_default_accumulator_element_type(dtype) + res_type = get_default_result_element_type(dtype) configs = [] M, N, K = 1024, 5120, 640 - configs.append(GemmConfig(M, N, K, "N", "T", dtype)) + configs.append(GemmConfig(M, N, K, "N", "T", dtype, acc_type, res_type)) M, N, K = 2048, 10240, 1280 - configs.append(GemmConfig(M, N, K, "N", "T", dtype)) + configs.append(GemmConfig(M, N, K, "N", "T", dtype, acc_type, res_type)) M, N, K = 4096, 20480, 2560 - configs.append(GemmConfig(M, N, K, "N", "T", dtype)) + configs.append(GemmConfig(M, N, K, "N", "T", dtype, acc_type, res_type)) return configs + def tk_unet(dtype: str) -> list[GemmConfig]: """UNET Shapes for TK.""" configs = [] for m, n, k in UNET: - configs.append(GemmConfig(m, n, k, "N", "T", dtype)) + configs.append( + GemmConfig( + m, + n, + k, + "N", + "T", + dtype, + get_default_accumulator_element_type(dtype), + get_default_result_element_type(dtype), + ) + ) return configs @@ -829,25 +896,59 @@ def llama70bmemory(dtype: str) -> list[GemmConfig]: """LLAMA 70b memory bound GEMMs; NT; BF16.""" configs = [] for n in [1280, 3584, 7168]: - configs.append(GemmConfig(2, n, 8192, "N", "T", dtype)) + configs.append( + GemmConfig( + 2, + n, + 8192, + "N", + "T", + dtype, + get_default_accumulator_element_type(dtype), + get_default_result_element_type(dtype), + ) + ) return configs def compute(dtype: str) -> list[GemmConfig]: """Compute bound GEMMs.""" - #for dtype in ["fp16", "bf16", "fp8"]: configs = [] for tA, tB in [("N", "N"), ("N", "T"), ("T", "N")]: - configs.append(GemmConfig(4096, 4096, 8192, tA, tB, dtype)) + configs.append( + GemmConfig( + 4096, + 4096, + 8192, + tA, + tB, + dtype, + get_default_accumulator_element_type(dtype), + get_default_result_element_type(dtype), + ) + ) return configs + def unet(dtype: str) -> list[GemmConfig]: configs = [] for tA, tB in [("N", "N"), ("N", "T")]: for m, n, k in UNET: - configs.append(GemmConfig(m, n, k, tA, tB, dtype)) + configs.append( + GemmConfig( + m, + n, + k, + tA, + tB, + dtype, + get_default_accumulator_element_type(dtype), + get_default_result_element_type(dtype), + ) + ) return configs + def get_gemm_configs() -> list[tuple[str, GemmConfig]]: llama13bmatvec_configs: list[GemmConfig] = [] llama13bmatvec_configs += llama13bmatvec("f16") @@ -890,6 +991,7 @@ def get_gemm_configs() -> list[tuple[str, GemmConfig]]: return all_configs + def get_tk_gemm_configs() -> list[tuple[str, GemmConfig]]: configs: list[tuple[str, GemmConfig]] = [] tk_default_configs = tk_default("f16") @@ -899,8 +1001,13 @@ def get_tk_gemm_configs() -> list[tuple[str, GemmConfig]]: configs += [("unet", x) for x in tk_unet_configs] return configs -def get_matching_configs(tagged_configs: list[tuple[str, GemmConfig]], - dtypes: list[str], variants: list[str], tag_regex: str) -> list[tuple[str, 
GemmConfig]]: + +def get_matching_configs( + tagged_configs: list[tuple[str, GemmConfig]], + dtypes: list[str], + variants: list[str], + tag_regex: str, +) -> list[tuple[str, GemmConfig]]: tag_re = re.compile(tag_regex) matching_configs: list[tuple[str, GemmConfig]] = [] for tag, config in tagged_configs: From 48a258b5a8317a9db7bacbe58ac5d1fd5fef1d4c Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Mon, 7 Oct 2024 19:36:48 -0400 Subject: [PATCH 2/5] wip --- gemmbench/gemm_bench.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gemmbench/gemm_bench.py b/gemmbench/gemm_bench.py index 56c6dbf..d41d326 100644 --- a/gemmbench/gemm_bench.py +++ b/gemmbench/gemm_bench.py @@ -107,7 +107,7 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args, target = args.target extra_compiler_args = list(args.Xiree_compile) dump_dir = args.dump_dir - + args = itertools.starmap( lambda tag, config: (tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args, tk, dump_dir), configs ) From 9f34a7377a27683983bb74f626a99a6dad563f13 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Mon, 7 Oct 2024 18:54:08 -0500 Subject: [PATCH 3/5] fixup --- gemmbench/gemm_bench.py | 2 +- gemmbench/gemm_utils.py | 42 ++++++++++++++++++++--------------------- gemmbench/problems.py | 2 +- 3 files changed, 22 insertions(+), 24 deletions(-) diff --git a/gemmbench/gemm_bench.py b/gemmbench/gemm_bench.py index d41d326..c3afe14 100644 --- a/gemmbench/gemm_bench.py +++ b/gemmbench/gemm_bench.py @@ -171,7 +171,7 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args, tflops_per_second = (flops / 1e12) / (benchmark_gemm_mean_time_us / 1e6) results.append(( - index, tag, name, vmfb_hash, config.M, config.N, config.K, config.dtype, config.tA, config.tB, + index, tag, name, vmfb_hash, config.M, config.N, config.K, config.operand_element_type, config.tA, config.tB, round(benchmark_gemm_mean_time_us, 4), round(arithmetic_intensity, 4), round(tflops_per_second, 4), diff --git a/gemmbench/gemm_utils.py b/gemmbench/gemm_utils.py index 30aff7f..348a5ca 100644 --- a/gemmbench/gemm_utils.py +++ b/gemmbench/gemm_utils.py @@ -29,30 +29,26 @@ def get_name(self) -> str: def get_inp1(self) -> str: if self.tA == "T": - inp1 = f"{self.K}x{self.M}x{self.operand_element_type}" - else: - inp1 = f"{self.M}x{self.K}x{self.operand_element_type}" - return inp1 + return f"{self.K}x{self.M}x{self.operand_element_type}" + return f"{self.M}x{self.K}x{self.operand_element_type}" def get_inp2(self) -> str: if self.tB == "T": - inp2 = f"{self.N}x{self.K}x{self.operand_element_type}" - else: - inp2 = f"{self.K}x{self.N}x{self.operand_element_type}" - return inp2 + return f"{self.N}x{self.K}x{self.operand_element_type}" + return f"{self.K}x{self.N}x{self.operand_element_type}" def get_byte_count(self) -> int: - operand_element_type_bits_map = { - "f32": 32, - "f16": 16, - "bf16": 16, - "f8E4M3FNUZ": 8, - "i8": 8, - "i32": 32, + dtype_to_bytes = { + "f32": 4, + "f16": 2, + "bf16": 2, + "f8E4M3FNUZ": 1, + "i8": 1, + "i32": 4, } - bytes_per_element = operand_element_type_bits_map[self.operand_element_type] // 8 - element_count = self.M * self.K + self.N * self.K + self.M * self.N - byte_count = element_count * bytes_per_element + operand_bytes_per_element = dtype_to_bytes[self.operand_element_type] + result_bytes_per_element = dtype_to_bytes[self.result_element_type] + byte_count = (self.M * self.K + self.N * self.K) * operand_bytes_per_element + (self.M * self.N) * result_bytes_per_element 
return byte_count def get_flops(self) -> int: @@ -66,6 +62,8 @@ def generate_mlir(config: GemmConfig): operand_element_type = config.operand_element_type acc_element_type = config.accumulator_element_type result_element_type = config.result_element_type + assert not operand_element_type.startswith('i'), "Integer types not supported yet" + tA = config.tA tB = config.tB mlir_template_A = f""" @@ -77,7 +75,7 @@ def generate_mlir(config: GemmConfig): %2 = linalg.matmul_transpose_a ins(%arg0, %arg1 : tensor<{K}x{M}x{operand_element_type}>, tensor<{K}x{N}x{operand_element_type}>) outs(%1 : tensor<{M}x{N}x{acc_element_type}>) -> tensor<{M}x{N}x{acc_element_type}> - %3 = arith.truncf %2 : tensor<{M}x{N}x{acc_element_type}> -> tensor<{M}x{N}x{result_element_type}> + %3 = arith.truncf %2 : tensor<{M}x{N}x{acc_element_type}> to tensor<{M}x{N}x{result_element_type}> return %3 : tensor<{M}x{N}x{result_element_type}> }} }} @@ -92,7 +90,7 @@ def generate_mlir(config: GemmConfig): %2 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<{M}x{K}x{operand_element_type}>, tensor<{N}x{K}x{operand_element_type}>) outs(%1 : tensor<{M}x{N}x{acc_element_type}>) -> tensor<{M}x{N}x{acc_element_type}> - %3 = arith.truncf %2 : tensor<{M}x{N}x{acc_element_type}> -> tensor<{M}x{N}x{result_element_type}> + %3 = arith.truncf %2 : tensor<{M}x{N}x{acc_element_type}> to tensor<{M}x{N}x{result_element_type}> return %3 : tensor<{M}x{N}x{result_element_type}> }} }} @@ -105,9 +103,9 @@ def generate_mlir(config: GemmConfig): %0 = tensor.empty() : tensor<{M}x{N}x{acc_element_type}> %1 = linalg.fill ins(%cst : {acc_element_type}) outs(%0 : tensor<{M}x{N}x{acc_element_type}>) -> tensor<{M}x{N}x{acc_element_type}> %2 = linalg.matmul ins(%arg0, %arg1 : tensor<{M}x{K}x{operand_element_type}>, tensor<{K}x{N}x{operand_element_type}>) - outs(%1 : tensor<{M}x{N}x{operand_element_type}>) + outs(%1 : tensor<{M}x{N}x{acc_element_type}>) -> tensor<{M}x{N}x{acc_element_type}> - %3 = arith.truncf %2 : tensor<{M}x{N}x{acc_element_type}> -> tensor<{M}x{N}x{result_element_type}> + %3 = arith.truncf %2 : tensor<{M}x{N}x{acc_element_type}> to tensor<{M}x{N}x{result_element_type}> return %3 : tensor<{M}x{N}x{result_element_type}> }} }} diff --git a/gemmbench/problems.py b/gemmbench/problems.py index e7fd018..69dbe47 100644 --- a/gemmbench/problems.py +++ b/gemmbench/problems.py @@ -1011,7 +1011,7 @@ def get_matching_configs( tag_re = re.compile(tag_regex) matching_configs: list[tuple[str, GemmConfig]] = [] for tag, config in tagged_configs: - if config.dtype not in dtypes: + if config.operand_element_type not in dtypes: continue if f"{config.tA}{config.tB}" not in variants: continue From fda11c9e767e1274e4f317061c23d6a10fd59c4b Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Mon, 7 Oct 2024 20:08:48 -0500 Subject: [PATCH 4/5] Add tk assertion --- gemmbench/gemm_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/gemmbench/gemm_utils.py b/gemmbench/gemm_utils.py index 348a5ca..7f5a865 100644 --- a/gemmbench/gemm_utils.py +++ b/gemmbench/gemm_utils.py @@ -116,7 +116,10 @@ def generate_mlir(config: GemmConfig): return mlir_template_B return mlir_template + def generate_tk_mlir(config: GemmConfig): + assert config.operand_element_type == 'f16', "Unsupported problem" + assert config.accumulator_element_type == 'f32', "Unsupported problem" # Input sizes M = tkl.sym.M N = tkl.sym.N From d1e8c7aa17ca11144a0a46fc552b47b5b09b1aed Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Mon, 7 Oct 2024 22:56:37 -0500 Subject: [PATCH 5/5] Fix --- 
gemmbench/gemm_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gemmbench/gemm_utils.py b/gemmbench/gemm_utils.py index 7f5a865..7dbbdcd 100644 --- a/gemmbench/gemm_utils.py +++ b/gemmbench/gemm_utils.py @@ -195,9 +195,9 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: with tk.gen.TestLaunchContext( hyperparams, canonicalize=True, run=True, run_config=config ): - a = torch.randn(shape[0], shape[2], operand_element_type=operand_element_type) - b = torch.randn(shape[1], shape[2], operand_element_type=operand_element_type) - c = torch.zeros(shape[0], shape[1], operand_element_type=torch.float32) + a = torch.randn(shape[0], shape[2], dtype=operand_element_type) + b = torch.randn(shape[1], shape[2], dtype=operand_element_type) + c = torch.zeros(shape[0], shape[1], dtype=torch.float32) mb = gemm(a, b, c) return mb.module_op.get_asm()
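
Note for reviewers (not part of the patches above): patch 1 changes is_compute_bound() in problems.py to charge the A and B reads at the operand element type and the C write at the result element type, instead of taking a single bytes-per-element argument. Below is a minimal standalone sketch of that heuristic under the defaults the series introduces (f16/bf16/f32 accumulate in f32, i8/i32 in i32, result element type defaults to the operand type). num_bytes() mirrors the helper added in the patch; the MAGIC_RATIO module constant and the __main__ driver are illustrative only, with example shapes taken from the "compute" and llama70bmemory suites.

# Standalone sketch of the updated roofline heuristic in problems.py.
MAGIC_RATIO = 64  # flop/byte threshold used by is_compute_bound()

def num_bytes(dtype: str) -> int:
    return {"f16": 2, "bf16": 2, "f32": 4, "i8": 1, "i32": 4}[dtype]

def is_compute_bound(M: int, N: int, K: int, dtype: str) -> bool:
    """Compute bound when the flop/byte ratio exceeds MAGIC_RATIO.

    A and B are read in the operand element type; C is written in the
    result element type, which this patch series defaults to the operand type.
    """
    flops = 2 * M * N * K
    operand_bytes = num_bytes(dtype)
    result_bytes = num_bytes(dtype)  # default result type == operand type
    total_bytes = operand_bytes * (M * K + K * N) + result_bytes * (M * N)
    return flops > MAGIC_RATIO * total_bytes

if __name__ == "__main__":
    # 4096x4096x8192 (the "compute" suite shape) is compute bound in f16;
    # the skinny 2x1280x8192 llama70bmemory shape is memory bound.
    print(is_compute_bound(4096, 4096, 8192, "f16"))  # True
    print(is_compute_bound(2, 1280, 8192, "f16"))     # False

With these defaults the two example shapes land on opposite sides of the threshold, which is what splits the GPT4 problem list into gpt4memory and gpt4compute in the patched code.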