Skip to content

Commit

Permalink
Add more shapes to tk gemmbench
Browse files Browse the repository at this point in the history
This PR adds more shapes to tk gemm testing.
It also modifies them to be f16 and matmul_transpose_b.
  • Loading branch information
harsh-amd committed Oct 2, 2024
1 parent 4b1345b commit dddad66
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions gemmbench/problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,12 +905,28 @@ def get_gemm_configs() -> list[tuple[str, GemmConfig]]:
return all_configs

def get_tk_gemm_configs() -> list[tuple[str, GemmConfig]]:
llama13bmatvec_configs = llama13bmatvec("f16")
llama70bmatvec_configs = llama70bmatvec("f16")
llama13bskinny_configs = llama13bskinny("f16")
llama70bskinny_configs = llama70bskinny("f16")
gpt4compute_configs = gpt4compute("f16")
compute_configs = compute("f16")
unet_configs = unet("f16")

configs: list[tuple[str, GemmConfig]] = []
tk_default_configs = tk_default("f16")
tk_unet_configs = tk_unet("f16")
configs += [("llama13bmatvec", x) for x in llama13bmatvec_configs]
configs += [("llama70bmatvec", x) for x in llama70bmatvec_configs]
configs += [("llama13bskinny", x) for x in llama13bskinny_configs]
configs += [("llama70bskinny", x) for x in llama70bskinny_configs]
configs += [("gpt4compute", x) for x in gpt4compute_configs]
configs += [("compute", x) for x in compute_configs]
configs += [("unet", x) for x in unet_configs]

# Convert all matmuls to transpose_b.
for _, config in configs:
config.tA = "N"
config.tB = "T"

configs += [("tk", x) for x in tk_default_configs]
configs += [("unet", x) for x in tk_unet_configs]
return configs

def get_matching_configs(tagged_configs: list[tuple[str, GemmConfig]],
Expand Down

0 comments on commit dddad66

Please sign in to comment.