diff --git a/gemmbench/problems.py b/gemmbench/problems.py index 43d233c..26ee76e 100644 --- a/gemmbench/problems.py +++ b/gemmbench/problems.py @@ -811,17 +811,6 @@ def test(dtype: str) -> list[GemmConfig]: configs.append(GemmConfig(M, N, K, "N", "N", dtype)) return configs -def tk_default(dtype: str) -> list[GemmConfig]: - """TK Shapes.""" - configs = [] - M, N, K = 1024, 5120, 640 - configs.append(GemmConfig(M, N, K, "N", "T", dtype)) - M, N, K = 2048, 10240, 1280 - configs.append(GemmConfig(M, N, K, "N", "T", dtype)) - M, N, K = 4096, 20480, 2560 - configs.append(GemmConfig(M, N, K, "N", "T", dtype)) - return configs - def tk_unet(dtype: str) -> list[GemmConfig]: """UNET Shapes for TK.""" configs = [] @@ -871,7 +860,6 @@ def get_gemm_configs() -> list[tuple[str, GemmConfig]]: llama70bskinny_configs += llama70bskinnybf16("bf16") gpt4compute_configs = gpt4compute("f16") llama70bmemory_configs = llama70bmemory("bf16") - tk_default_configs = tk_default("f16") compute_configs = compute("f16") compute_configs += compute("bf16") unet_configs = unet("f16") @@ -885,16 +873,27 @@ def get_gemm_configs() -> list[tuple[str, GemmConfig]]: configs += [("llama70bmemory", x) for x in llama70bmemory_configs] configs += [("compute", x) for x in compute_configs] configs += [("unet", x) for x in unet_configs] - configs += [("tk", x) for x in tk_default_configs] return configs def get_tk_gemm_configs() -> list[tuple[str, GemmConfig]]: configs: list[tuple[str, GemmConfig]] = [] - tk_default_configs = tk_default("f16") - tk_unet_configs = tk_unet("f16") + llama13bmatvec_configs = llama13bmatvec("f16") + llama70bmatvec_configs = llama70bmatvec("f16") + llama13bskinny_configs = llama13bskinny("f16") + llama70bskinny_configs = llama70bskinny("f16") + gpt4compute_configs = gpt4compute("f16") + llama70bmemory_configs = llama70bmemory("bf16") + compute_configs = compute("f16") + unet_configs = unet("f16") - configs += [("tk", x) for x in tk_default_configs] - configs += [("unet", x) for x in tk_unet_configs] + configs += [("llama13bmatvec", x) for x in llama13bmatvec_configs] + configs += [("llama70bmatvec", x) for x in llama70bmatvec_configs] + configs += [("llama13bskinny", x) for x in llama13bskinny_configs] + configs += [("llama70bskinny", x) for x in llama70bskinny_configs] + configs += [("gpt4compute", x) for x in gpt4compute_configs] + configs += [("llama70bmemory", x) for x in llama70bmemory_configs] + configs += [("compute", x) for x in compute_configs] + configs += [("unet", x) for x in unet_configs] return configs