diff --git a/gemmbench/problems.py b/gemmbench/problems.py index 385a343..ab2dc7d 100644 --- a/gemmbench/problems.py +++ b/gemmbench/problems.py @@ -670,6 +670,15 @@ def is_compute_bound(M: int, N: int, K: int, dtype: str) -> bool: (4096, 5120, 640), ] +SQUARE = [ + (128, 128, 128), + (256, 256, 256), + (512, 512, 512), + (1024, 1024, 1024), + (2048, 2048, 2048), + (4096, 4096, 4096), + (8192, 8192, 8192), +] def llama13bmatvec(dtype: str) -> list[GemmConfig]: configs = [] @@ -953,6 +962,24 @@ def unet(dtype: str) -> list[GemmConfig]: return configs +def square(dtype: str) -> list[GemmConfig]: + configs = [] + for m, n, k in SQUARE: + configs.append( + GemmConfig( + m, + n, + k, + "N", + "T", + dtype, + get_default_accumulator_element_type(dtype), + get_default_result_element_type(dtype), + ) + ) + return configs + + def get_gemm_configs() -> list[tuple[str, GemmConfig]]: llama13bmatvec_configs: list[GemmConfig] = [] llama13bmatvec_configs += llama13bmatvec("f16") @@ -982,6 +1009,8 @@ def get_gemm_configs() -> list[tuple[str, GemmConfig]]: unet_configs += unet("f16") unet_configs += unet("bf16") + square_configs: list[GemmConfig] = square("f16") + all_configs: list[tuple[str, GemmConfig]] = [] all_configs += [("llama13bmatvec", x) for x in llama13bmatvec_configs] all_configs += [("llama70bmatvec", x) for x in llama70bmatvec_configs] @@ -991,6 +1020,7 @@ def get_gemm_configs() -> list[tuple[str, GemmConfig]]: all_configs += [("llama70bmemory", x) for x in llama70bmemory_configs] all_configs += [("compute", x) for x in compute_configs] all_configs += [("unet", x) for x in unet_configs] + all_configs += [("square", x) for x in square_configs] all_configs += [("tk", x) for x in tk_default_configs] return all_configs