Fix accumulator and result types #14

Merged (5 commits) on Oct 8, 2024
Changes from 3 commits
4 changes: 2 additions & 2 deletions gemmbench/gemm_bench.py
@@ -107,7 +107,7 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args,
target = args.target
extra_compiler_args = list(args.Xiree_compile)
dump_dir = args.dump_dir

args = itertools.starmap(
lambda tag, config: (tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args, tk, dump_dir), configs
)
@@ -171,7 +171,7 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args,
tflops_per_second = (flops / 1e12) / (benchmark_gemm_mean_time_us / 1e6)

results.append((
- index, tag, name, vmfb_hash, config.M, config.N, config.K, config.dtype, config.tA, config.tB,
+ index, tag, name, vmfb_hash, config.M, config.N, config.K, config.operand_element_type, config.tA, config.tB,
round(benchmark_gemm_mean_time_us, 4),
round(arithmetic_intensity, 4),
round(tflops_per_second, 4),
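For context, the roofline numbers reported above follow directly from the config's flop and byte counts. A minimal sketch of that relationship, assuming get_flops() counts 2*M*N*K (the usual GEMM convention; its body is not shown in this diff) and using the updated get_byte_count() from gemm_utils.py below; roofline_metrics is a hypothetical helper name:

# Illustrative helper, not part of the PR.
def roofline_metrics(config, mean_time_us: float) -> tuple[float, float]:
    flops = config.get_flops()                 # assumed to be 2 * M * N * K
    byte_count = config.get_byte_count()       # A and B at operand type, C at result type
    arithmetic_intensity = flops / byte_count  # flops per byte moved
    tflops_per_second = (flops / 1e12) / (mean_time_us / 1e6)
    return arithmetic_intensity, tflops_per_second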
104 changes: 58 additions & 46 deletions gemmbench/gemm_utils.py
@@ -15,10 +15,12 @@ class GemmConfig:
K: int
tA: str
tB: str
- dtype: str
+ operand_element_type: str
+ accumulator_element_type: str
+ result_element_type: str

def get_name(self) -> str:
- name = f"gemm_{self.M}_{self.N}_{self.K}_{self.dtype}"
+ name = f"gemm_{self.M}_{self.N}_{self.K}_{self.operand_element_type}_{self.accumulator_element_type}"
if self.tA == "T":
name += "_tA"
elif self.tB == "T":
@@ -27,30 +29,26 @@ def get_name(self) -> str:

def get_inp1(self) -> str:
if self.tA == "T":
- inp1 = f"{self.K}x{self.M}x{self.dtype}"
- else:
- inp1 = f"{self.M}x{self.K}x{self.dtype}"
- return inp1
+ return f"{self.K}x{self.M}x{self.operand_element_type}"
+ return f"{self.M}x{self.K}x{self.operand_element_type}"

def get_inp2(self) -> str:
if self.tB == "T":
- inp2 = f"{self.N}x{self.K}x{self.dtype}"
- else:
- inp2 = f"{self.K}x{self.N}x{self.dtype}"
- return inp2
+ return f"{self.N}x{self.K}x{self.operand_element_type}"
+ return f"{self.K}x{self.N}x{self.operand_element_type}"

def get_byte_count(self) -> int:
dtype_bits_map = {
"f32": 32,
"f16": 16,
"bf16": 16,
"f8E4M3FNUZ": 8,
"i8": 8,
"i32": 32,
dtype_to_bytes = {
"f32": 4,
"f16": 2,
"bf16": 2,
"f8E4M3FNUZ": 1,
"i8": 1,
"i32": 4,
}
bytes_per_element = dtype_bits_map[self.dtype] // 8
element_count = self.M * self.K + self.N * self.K + self.M * self.N
byte_count = element_count * bytes_per_element
operand_bytes_per_element = dtype_to_bytes[self.operand_element_type]
result_bytes_per_element = dtype_to_bytes[self.result_element_type]
byte_count = (self.M * self.K + self.N * self.K) * operand_bytes_per_element + (self.M * self.N) * result_bytes_per_element
return byte_count
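To make the new fields concrete, here is a hypothetical configuration (not one from this PR) and the values the updated helpers would produce for it. This sketch assumes GemmConfig is a dataclass and that its remaining fields (M, N) are unchanged:

# Hypothetical example, assuming GemmConfig is a dataclass with unchanged M and N fields.
cfg = GemmConfig(
    M=4096, N=4096, K=8192, tA="N", tB="T",
    operand_element_type="f16",
    accumulator_element_type="f32",
    result_element_type="f16",
)
cfg.get_name()        # "gemm_4096_4096_8192_f16_f32_tB"
cfg.get_byte_count()  # (4096*8192 + 4096*8192) * 2 + 4096*4096 * 2 = 167_772_160 bytes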

def get_flops(self) -> int:
@@ -61,40 +59,54 @@ def generate_mlir(config: GemmConfig):
K = config.K
M = config.M
N = config.N
- dtype = config.dtype
+ operand_element_type = config.operand_element_type
+ acc_element_type = config.accumulator_element_type
+ result_element_type = config.result_element_type
+ assert not operand_element_type.startswith('i'), "Integer types not supported yet"
Contributor:

What's the reason integer types are not supported? (Would it not just work with an i8 operand, i32 accumulator, and i8 result type?)

Member (author):

Well, they are not supported elsewhere either -- I didn't want to plumb through int8 support in the same PR.
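For reference, a hypothetical integer variant of the template (explicitly out of scope for this PR) would mainly need an integer zero for the fill and arith.trunci instead of arith.truncf for the final narrowing; linalg.matmul sign-extends narrower integer operands into the accumulator type. A minimal sketch, assuming i8 operands, an i32 accumulator, and an i8 result:

# Hypothetical sketch, not part of this PR: i8 x i8 -> i32 accumulate -> i8 result.
int_mlir_template = f"""
module {{
  func.func @main(%arg0: tensor<{M}x{K}xi8>, %arg1: tensor<{K}x{N}xi8>) -> tensor<{M}x{N}xi8> {{
    %c0 = arith.constant 0 : i32
    %0 = tensor.empty() : tensor<{M}x{N}xi32>
    %1 = linalg.fill ins(%c0 : i32) outs(%0 : tensor<{M}x{N}xi32>) -> tensor<{M}x{N}xi32>
    %2 = linalg.matmul ins(%arg0, %arg1 : tensor<{M}x{K}xi8>, tensor<{K}x{N}xi8>)
        outs(%1 : tensor<{M}x{N}xi32>) -> tensor<{M}x{N}xi32>
    %3 = arith.trunci %2 : tensor<{M}x{N}xi32> to tensor<{M}x{N}xi8>
    return %3 : tensor<{M}x{N}xi8>
  }}
}}
"""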


tA = config.tA
tB = config.tB
mlir_template_A = f"""
module {{
- func.func @main(%arg0: tensor<{K}x{M}x{dtype}>, %arg1: tensor<{K}x{N}x{dtype}>) -> tensor<{M}x{N}x{dtype}> {{
- %cst = arith.constant 0.000000e+00 : {dtype}
- %0 = tensor.empty() : tensor<{M}x{N}x{dtype}>
- %1 = linalg.fill ins(%cst : {dtype}) outs(%0 : tensor<{M}x{N}x{dtype}>) -> tensor<{M}x{N}x{dtype}>
- %2 = linalg.matmul_transpose_a ins(%arg0, %arg1 : tensor<{K}x{M}x{dtype}>, tensor<{K}x{N}x{dtype}>) outs(%1 : tensor<{M}x{N}x{dtype}>) -> tensor<{M}x{N}x{dtype}>
- return %2 : tensor<{M}x{N}x{dtype}>
+ func.func @main(%arg0: tensor<{K}x{M}x{operand_element_type}>, %arg1: tensor<{K}x{N}x{operand_element_type}>) -> tensor<{M}x{N}x{result_element_type}> {{
+ %cst = arith.constant 0.000000e+00 : {acc_element_type}
+ %0 = tensor.empty() : tensor<{M}x{N}x{acc_element_type}>
+ %1 = linalg.fill ins(%cst : {acc_element_type}) outs(%0 : tensor<{M}x{N}x{acc_element_type}>) -> tensor<{M}x{N}x{acc_element_type}>
+ %2 = linalg.matmul_transpose_a ins(%arg0, %arg1 : tensor<{K}x{M}x{operand_element_type}>, tensor<{K}x{N}x{operand_element_type}>)
+ outs(%1 : tensor<{M}x{N}x{acc_element_type}>)
+ -> tensor<{M}x{N}x{acc_element_type}>
+ %3 = arith.truncf %2 : tensor<{M}x{N}x{acc_element_type}> to tensor<{M}x{N}x{result_element_type}>
+ return %3 : tensor<{M}x{N}x{result_element_type}>
}}
}}
"""

mlir_template_B = f"""
module {{
- func.func @main(%arg0: tensor<{M}x{K}x{dtype}>, %arg1: tensor<{N}x{K}x{dtype}>) -> tensor<{M}x{N}x{dtype}> {{
- %cst = arith.constant 0.000000e+00 : {dtype}
- %0 = tensor.empty() : tensor<{M}x{N}x{dtype}>
- %1 = linalg.fill ins(%cst : {dtype}) outs(%0 : tensor<{M}x{N}x{dtype}>) -> tensor<{M}x{N}x{dtype}>
- %2 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<{M}x{K}x{dtype}>, tensor<{N}x{K}x{dtype}>) outs(%1 : tensor<{M}x{N}x{dtype}>) -> tensor<{M}x{N}x{dtype}>
- return %2 : tensor<{M}x{N}x{dtype}>
+ func.func @main(%arg0: tensor<{M}x{K}x{operand_element_type}>, %arg1: tensor<{N}x{K}x{operand_element_type}>) -> tensor<{M}x{N}x{result_element_type}> {{
+ %cst = arith.constant 0.000000e+00 : {acc_element_type}
+ %0 = tensor.empty() : tensor<{M}x{N}x{acc_element_type}>
+ %1 = linalg.fill ins(%cst : {acc_element_type}) outs(%0 : tensor<{M}x{N}x{acc_element_type}>) -> tensor<{M}x{N}x{acc_element_type}>
+ %2 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<{M}x{K}x{operand_element_type}>, tensor<{N}x{K}x{operand_element_type}>)
+ outs(%1 : tensor<{M}x{N}x{acc_element_type}>)
+ -> tensor<{M}x{N}x{acc_element_type}>
+ %3 = arith.truncf %2 : tensor<{M}x{N}x{acc_element_type}> to tensor<{M}x{N}x{result_element_type}>
+ return %3 : tensor<{M}x{N}x{result_element_type}>
}}
}}
"""

- mlir_template = f"""module {{
- func.func @main(%arg0: tensor<{M}x{K}x{dtype}>, %arg1: tensor<{K}x{N}x{dtype}>) -> tensor<{M}x{N}x{dtype}> {{
- %cst = arith.constant 0.000000e+00 : {dtype}
- %0 = tensor.empty() : tensor<{M}x{N}x{dtype}>
- %1 = linalg.fill ins(%cst : {dtype}) outs(%0 : tensor<{M}x{N}x{dtype}>) -> tensor<{M}x{N}x{dtype}>
- %2 = linalg.matmul ins(%arg0, %arg1 : tensor<{M}x{K}x{dtype}>, tensor<{K}x{N}x{dtype}>) outs(%1 : tensor<{M}x{N}x{dtype}>) -> tensor<{M}x{N}x{dtype}>
- return %2 : tensor<{M}x{N}x{dtype}>
+ mlir_template = f"""
+ module {{
+ func.func @main(%arg0: tensor<{M}x{K}x{operand_element_type}>, %arg1: tensor<{K}x{N}x{operand_element_type}>) -> tensor<{M}x{N}x{result_element_type}> {{
+ %cst = arith.constant 0.000000e+00 : {acc_element_type}
+ %0 = tensor.empty() : tensor<{M}x{N}x{acc_element_type}>
+ %1 = linalg.fill ins(%cst : {acc_element_type}) outs(%0 : tensor<{M}x{N}x{acc_element_type}>) -> tensor<{M}x{N}x{acc_element_type}>
+ %2 = linalg.matmul ins(%arg0, %arg1 : tensor<{M}x{K}x{operand_element_type}>, tensor<{K}x{N}x{operand_element_type}>)
+ outs(%1 : tensor<{M}x{N}x{acc_element_type}>)
+ -> tensor<{M}x{N}x{acc_element_type}>
+ %3 = arith.truncf %2 : tensor<{M}x{N}x{acc_element_type}> to tensor<{M}x{N}x{result_element_type}>
+ return %3 : tensor<{M}x{N}x{result_element_type}>
}}
}}
"""
@@ -158,12 +170,12 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:

# repeat represents the results of the loop
tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD)

shape = [config.M, config.N, config.K]
- dtype_map = {
+ operand_element_type_map = {
"f16": torch.float16,
}
- dtype = dtype_map[config.dtype]
+ operand_element_type = operand_element_type_map[config.operand_element_type]

hyperparams = {
ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
@@ -180,9 +192,9 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
with tk.gen.TestLaunchContext(
hyperparams, canonicalize=True, run=True, run_config=config
):
- a = torch.randn(shape[0], shape[2], dtype=dtype)
- b = torch.randn(shape[1], shape[2], dtype=dtype)
- c = torch.zeros(shape[0], shape[1], dtype=torch.float32)
+ a = torch.randn(shape[0], shape[2], dtype=operand_element_type)
+ b = torch.randn(shape[1], shape[2], dtype=operand_element_type)
+ c = torch.zeros(shape[0], shape[1], dtype=torch.float32)
mb = gemm(a, b, c)

return mb.module_op.get_asm()