Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
kuhar committed Oct 7, 2024
1 parent 48a258b commit 9f34a73
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 24 deletions.
2 changes: 1 addition & 1 deletion gemmbench/gemm_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args,
tflops_per_second = (flops / 1e12) / (benchmark_gemm_mean_time_us / 1e6)

results.append((
index, tag, name, vmfb_hash, config.M, config.N, config.K, config.dtype, config.tA, config.tB,
index, tag, name, vmfb_hash, config.M, config.N, config.K, config.operand_element_type, config.tA, config.tB,
round(benchmark_gemm_mean_time_us, 4),
round(arithmetic_intensity, 4),
round(tflops_per_second, 4),
Expand Down
42 changes: 20 additions & 22 deletions gemmbench/gemm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,30 +29,26 @@ def get_name(self) -> str:

def get_inp1(self) -> str:
if self.tA == "T":
inp1 = f"{self.K}x{self.M}x{self.operand_element_type}"
else:
inp1 = f"{self.M}x{self.K}x{self.operand_element_type}"
return inp1
return f"{self.K}x{self.M}x{self.operand_element_type}"
return f"{self.M}x{self.K}x{self.operand_element_type}"

def get_inp2(self) -> str:
if self.tB == "T":
inp2 = f"{self.N}x{self.K}x{self.operand_element_type}"
else:
inp2 = f"{self.K}x{self.N}x{self.operand_element_type}"
return inp2
return f"{self.N}x{self.K}x{self.operand_element_type}"
return f"{self.K}x{self.N}x{self.operand_element_type}"

def get_byte_count(self) -> int:
operand_element_type_bits_map = {
"f32": 32,
"f16": 16,
"bf16": 16,
"f8E4M3FNUZ": 8,
"i8": 8,
"i32": 32,
dtype_to_bytes = {
"f32": 4,
"f16": 2,
"bf16": 2,
"f8E4M3FNUZ": 1,
"i8": 1,
"i32": 4,
}
bytes_per_element = operand_element_type_bits_map[self.operand_element_type] // 8
element_count = self.M * self.K + self.N * self.K + self.M * self.N
byte_count = element_count * bytes_per_element
operand_bytes_per_element = dtype_to_bytes[self.operand_element_type]
result_bytes_per_element = dtype_to_bytes[self.result_element_type]
byte_count = (self.M * self.K + self.N * self.K) * operand_bytes_per_element + (self.M * self.N) * result_bytes_per_element
return byte_count

def get_flops(self) -> int:
Expand All @@ -66,6 +62,8 @@ def generate_mlir(config: GemmConfig):
operand_element_type = config.operand_element_type
acc_element_type = config.accumulator_element_type
result_element_type = config.result_element_type
assert not operand_element_type.startswith('i'), "Integer types not supported yet"

tA = config.tA
tB = config.tB
mlir_template_A = f"""
Expand All @@ -77,7 +75,7 @@ def generate_mlir(config: GemmConfig):
%2 = linalg.matmul_transpose_a ins(%arg0, %arg1 : tensor<{K}x{M}x{operand_element_type}>, tensor<{K}x{N}x{operand_element_type}>)
outs(%1 : tensor<{M}x{N}x{acc_element_type}>)
-> tensor<{M}x{N}x{acc_element_type}>
%3 = arith.truncf %2 : tensor<{M}x{N}x{acc_element_type}> -> tensor<{M}x{N}x{result_element_type}>
%3 = arith.truncf %2 : tensor<{M}x{N}x{acc_element_type}> to tensor<{M}x{N}x{result_element_type}>
return %3 : tensor<{M}x{N}x{result_element_type}>
}}
}}
Expand All @@ -92,7 +90,7 @@ def generate_mlir(config: GemmConfig):
%2 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<{M}x{K}x{operand_element_type}>, tensor<{N}x{K}x{operand_element_type}>)
outs(%1 : tensor<{M}x{N}x{acc_element_type}>)
-> tensor<{M}x{N}x{acc_element_type}>
%3 = arith.truncf %2 : tensor<{M}x{N}x{acc_element_type}> -> tensor<{M}x{N}x{result_element_type}>
%3 = arith.truncf %2 : tensor<{M}x{N}x{acc_element_type}> to tensor<{M}x{N}x{result_element_type}>
return %3 : tensor<{M}x{N}x{result_element_type}>
}}
}}
Expand All @@ -105,9 +103,9 @@ def generate_mlir(config: GemmConfig):
%0 = tensor.empty() : tensor<{M}x{N}x{acc_element_type}>
%1 = linalg.fill ins(%cst : {acc_element_type}) outs(%0 : tensor<{M}x{N}x{acc_element_type}>) -> tensor<{M}x{N}x{acc_element_type}>
%2 = linalg.matmul ins(%arg0, %arg1 : tensor<{M}x{K}x{operand_element_type}>, tensor<{K}x{N}x{operand_element_type}>)
outs(%1 : tensor<{M}x{N}x{operand_element_type}>)
outs(%1 : tensor<{M}x{N}x{acc_element_type}>)
-> tensor<{M}x{N}x{acc_element_type}>
%3 = arith.truncf %2 : tensor<{M}x{N}x{acc_element_type}> -> tensor<{M}x{N}x{result_element_type}>
%3 = arith.truncf %2 : tensor<{M}x{N}x{acc_element_type}> to tensor<{M}x{N}x{result_element_type}>
return %3 : tensor<{M}x{N}x{result_element_type}>
}}
}}
Expand Down
2 changes: 1 addition & 1 deletion gemmbench/problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -1011,7 +1011,7 @@ def get_matching_configs(
tag_re = re.compile(tag_regex)
matching_configs: list[tuple[str, GemmConfig]] = []
for tag, config in tagged_configs:
if config.dtype not in dtypes:
if config.operand_element_type not in dtypes:
continue
if f"{config.tA}{config.tB}" not in variants:
continue
Expand Down

0 comments on commit 9f34a73

Please sign in to comment.