diff --git a/gemmbench/problems.py b/gemmbench/problems.py index ccdc8b2..53c6e41 100644 --- a/gemmbench/problems.py +++ b/gemmbench/problems.py @@ -905,12 +905,28 @@ def get_gemm_configs() -> list[tuple[str, GemmConfig]]: return all_configs def get_tk_gemm_configs() -> list[tuple[str, GemmConfig]]: + llama13bmatvec_configs = llama13bmatvec("f16") + llama70bmatvec_configs = llama70bmatvec("f16") + llama13bskinny_configs = llama13bskinny("f16") + llama70bskinny_configs = llama70bskinny("f16") + gpt4compute_configs = gpt4compute("f16") + compute_configs = compute("f16") + unet_configs = unet("f16") + configs: list[tuple[str, GemmConfig]] = [] - tk_default_configs = tk_default("f16") - tk_unet_configs = tk_unet("f16") + configs += [("llama13bmatvec", x) for x in llama13bmatvec_configs] + configs += [("llama70bmatvec", x) for x in llama70bmatvec_configs] + configs += [("llama13bskinny", x) for x in llama13bskinny_configs] + configs += [("llama70bskinny", x) for x in llama70bskinny_configs] + configs += [("gpt4compute", x) for x in gpt4compute_configs] + configs += [("compute", x) for x in compute_configs] + configs += [("unet", x) for x in unet_configs] + + # Convert all matmuls to transpose_b. + for _, config in configs: + config.tA = "N" + config.tB = "T" - configs += [("tk", x) for x in tk_default_configs] - configs += [("unet", x) for x in tk_unet_configs] return configs def get_matching_configs(tagged_configs: list[tuple[str, GemmConfig]],