From 9f34a7377a27683983bb74f626a99a6dad563f13 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Mon, 7 Oct 2024 18:54:08 -0500 Subject: [PATCH] fixup --- gemmbench/gemm_bench.py | 2 +- gemmbench/gemm_utils.py | 42 ++++++++++++++++++++--------------------- gemmbench/problems.py | 2 +- 3 files changed, 22 insertions(+), 24 deletions(-) diff --git a/gemmbench/gemm_bench.py b/gemmbench/gemm_bench.py index d41d326..c3afe14 100644 --- a/gemmbench/gemm_bench.py +++ b/gemmbench/gemm_bench.py @@ -171,7 +171,7 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args, tflops_per_second = (flops / 1e12) / (benchmark_gemm_mean_time_us / 1e6) results.append(( - index, tag, name, vmfb_hash, config.M, config.N, config.K, config.dtype, config.tA, config.tB, + index, tag, name, vmfb_hash, config.M, config.N, config.K, config.operand_element_type, config.tA, config.tB, round(benchmark_gemm_mean_time_us, 4), round(arithmetic_intensity, 4), round(tflops_per_second, 4), diff --git a/gemmbench/gemm_utils.py b/gemmbench/gemm_utils.py index 30aff7f..348a5ca 100644 --- a/gemmbench/gemm_utils.py +++ b/gemmbench/gemm_utils.py @@ -29,30 +29,26 @@ def get_name(self) -> str: def get_inp1(self) -> str: if self.tA == "T": - inp1 = f"{self.K}x{self.M}x{self.operand_element_type}" - else: - inp1 = f"{self.M}x{self.K}x{self.operand_element_type}" - return inp1 + return f"{self.K}x{self.M}x{self.operand_element_type}" + return f"{self.M}x{self.K}x{self.operand_element_type}" def get_inp2(self) -> str: if self.tB == "T": - inp2 = f"{self.N}x{self.K}x{self.operand_element_type}" - else: - inp2 = f"{self.K}x{self.N}x{self.operand_element_type}" - return inp2 + return f"{self.N}x{self.K}x{self.operand_element_type}" + return f"{self.K}x{self.N}x{self.operand_element_type}" def get_byte_count(self) -> int: - operand_element_type_bits_map = { - "f32": 32, - "f16": 16, - "bf16": 16, - "f8E4M3FNUZ": 8, - "i8": 8, - "i32": 32, + dtype_to_bytes = { + "f32": 4, + "f16": 2, + "bf16": 2, + "f8E4M3FNUZ": 1, + "i8": 1, + "i32": 4, } - bytes_per_element = operand_element_type_bits_map[self.operand_element_type] // 8 - element_count = self.M * self.K + self.N * self.K + self.M * self.N - byte_count = element_count * bytes_per_element + operand_bytes_per_element = dtype_to_bytes[self.operand_element_type] + result_bytes_per_element = dtype_to_bytes[self.result_element_type] + byte_count = (self.M * self.K + self.N * self.K) * operand_bytes_per_element + (self.M * self.N) * result_bytes_per_element return byte_count def get_flops(self) -> int: @@ -66,6 +62,8 @@ def generate_mlir(config: GemmConfig): operand_element_type = config.operand_element_type acc_element_type = config.accumulator_element_type result_element_type = config.result_element_type + assert not operand_element_type.startswith('i'), "Integer types not supported yet" + tA = config.tA tB = config.tB mlir_template_A = f""" @@ -77,7 +75,7 @@ def generate_mlir(config: GemmConfig): %2 = linalg.matmul_transpose_a ins(%arg0, %arg1 : tensor<{K}x{M}x{operand_element_type}>, tensor<{K}x{N}x{operand_element_type}>) outs(%1 : tensor<{M}x{N}x{acc_element_type}>) -> tensor<{M}x{N}x{acc_element_type}> - %3 = arith.truncf %2 : tensor<{M}x{N}x{acc_element_type}> -> tensor<{M}x{N}x{result_element_type}> + %3 = arith.truncf %2 : tensor<{M}x{N}x{acc_element_type}> to tensor<{M}x{N}x{result_element_type}> return %3 : tensor<{M}x{N}x{result_element_type}> }} }} @@ -92,7 +90,7 @@ def generate_mlir(config: GemmConfig): %2 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<{M}x{K}x{operand_element_type}>, tensor<{N}x{K}x{operand_element_type}>) outs(%1 : tensor<{M}x{N}x{acc_element_type}>) -> tensor<{M}x{N}x{acc_element_type}> - %3 = arith.truncf %2 : tensor<{M}x{N}x{acc_element_type}> -> tensor<{M}x{N}x{result_element_type}> + %3 = arith.truncf %2 : tensor<{M}x{N}x{acc_element_type}> to tensor<{M}x{N}x{result_element_type}> return %3 : tensor<{M}x{N}x{result_element_type}> }} }} @@ -105,9 +103,9 @@ def generate_mlir(config: GemmConfig): %0 = tensor.empty() : tensor<{M}x{N}x{acc_element_type}> %1 = linalg.fill ins(%cst : {acc_element_type}) outs(%0 : tensor<{M}x{N}x{acc_element_type}>) -> tensor<{M}x{N}x{acc_element_type}> %2 = linalg.matmul ins(%arg0, %arg1 : tensor<{M}x{K}x{operand_element_type}>, tensor<{K}x{N}x{operand_element_type}>) - outs(%1 : tensor<{M}x{N}x{operand_element_type}>) + outs(%1 : tensor<{M}x{N}x{acc_element_type}>) -> tensor<{M}x{N}x{acc_element_type}> - %3 = arith.truncf %2 : tensor<{M}x{N}x{acc_element_type}> -> tensor<{M}x{N}x{result_element_type}> + %3 = arith.truncf %2 : tensor<{M}x{N}x{acc_element_type}> to tensor<{M}x{N}x{result_element_type}> return %3 : tensor<{M}x{N}x{result_element_type}> }} }} diff --git a/gemmbench/problems.py b/gemmbench/problems.py index e7fd018..69dbe47 100644 --- a/gemmbench/problems.py +++ b/gemmbench/problems.py @@ -1011,7 +1011,7 @@ def get_matching_configs( tag_re = re.compile(tag_regex) matching_configs: list[tuple[str, GemmConfig]] = [] for tag, config in tagged_configs: - if config.dtype not in dtypes: + if config.operand_element_type not in dtypes: continue if f"{config.tA}{config.tB}" not in variants: continue