Add flags to filter gemm configs #3

Merged 2 commits on Oct 2, 2024
Changes from all commits
37 changes: 26 additions & 11 deletions gemmbench/gemm_bench.py
@@ -15,7 +15,7 @@
import sys
from utils import *
from gemm_utils import *
from problems import get_gemm_configs, get_tk_gemm_configs
from problems import get_gemm_configs, get_tk_gemm_configs, get_matching_configs


def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args, tk):
@@ -40,33 +40,48 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args,
default=[],
help="Extra command line arguments passed to the IREE compiler. This can be specified multiple times to pass multiple arguments."
)
parser.add_argument(
"--dtypes", action='append', help="List of data types to benchmark. Defaults to all supported types."
)
parser.add_argument(
"--variants",
action='append',
help="List of matmul variants to benchmark. Default to all variants: NN, NT, TN, and TT."
)
parser.add_argument(
"--tag-regex",
help="Regular expression for allowed benchmark tags. Defaults to all tags allowed.",
default=".*"
)
parser.add_argument("--roofline", help="Comma separated csv file list to generate roofline plot with", default=None)
parser.add_argument("--plot", help="location to save plot", default=None)
parser.add_argument("--batch", help="roofline on certain batch", type=int, default=None)
parser.add_argument("--dtype", help="roofline on certain dtype", default=None)
parser.add_argument("--model", help="roofline on certain model", default=None)
parser.add_argument(
"--tk",
action="store_true",
default=False,
help="Option to run gemm kernels using Turbine Kernels",
help="Run gemm kernels using Turbine Kernels",
)

args = parser.parse_args()
# Handle default values here, since 'append' is not compatible with defaulted lists.
requested_dtypes = ["f16", "bf16"] if not args.dtypes else list(args.dtypes)
requested_variants = ["NN", "NT", "TN", "TT"] if not args.variants else list(args.variants)

logging.basicConfig(level=args.log_level)

if args.roofline:
roofline(args.roofline, args.plot, args.batch, args.dtype, args.model)
for dtype in requested_dtypes:
roofline(args.roofline, f"{args.plot}_{dtype}", args.batch, dtype, args.model)
sys.exit()

tk = args.tk
if tk:
configs = get_tk_gemm_configs()
else:
configs = get_gemm_configs()
configs = get_tk_gemm_configs() if tk else get_gemm_configs()
configs = get_matching_configs(configs, requested_dtypes, requested_variants, args.tag_regex)
print(f"Generated {len(configs)} gemm configs.")

num_cpus = max(1, cpu_count() - 20)
num_cpus = max(1, max(cpu_count() // 2, 1))
print(f"Using {num_cpus} CPUs for parallel processing.")

manager = Manager()
@@ -125,7 +140,7 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args,
]

if tk:
exec_args += ["--function=isolated_benchmark"]
exec_args += ["--function=isolated_benchmark"]
else:
exec_args += ["--function=main"]

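The new `--dtypes` and `--variants` flags use argparse's `append` action, which is why the defaults are resolved after `parse_args()` rather than in `add_argument`: with `action='append'`, a default list would be appended to instead of replaced. Below is a minimal, self-contained sketch of that pattern. It reuses the flag names from the diff but builds a throwaway parser, not the benchmark's real CLI.

```python
# Illustrative sketch only: a throwaway parser showing how append-style
# flags interact with defaults that are resolved after parsing.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--dtypes", action="append",
                    help="Data types to benchmark; may be repeated.")
parser.add_argument("--variants", action="append",
                    help="Matmul variants to benchmark (NN, NT, TN, TT); may be repeated.")
parser.add_argument("--tag-regex", default=".*",
                    help="Regular expression for allowed benchmark tags.")

args = parser.parse_args(["--dtypes", "f16", "--variants", "NT", "--variants", "TN"])

# With action='append', a defaulted list would be appended to rather than
# replaced, so the defaults are applied here instead of in add_argument.
requested_dtypes = ["f16", "bf16"] if not args.dtypes else list(args.dtypes)
requested_variants = ["NN", "NT", "TN", "TT"] if not args.variants else list(args.variants)

print(requested_dtypes)    # ['f16']
print(requested_variants)  # ['NT', 'TN']
```

The same `requested_dtypes` list also drives the roofline path in the diff: with the default dtypes, `--plot out` produces one plot per dtype, e.g. `out_f16` and `out_bf16`.
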
67 changes: 48 additions & 19 deletions gemmbench/problems.py
@@ -6,6 +6,9 @@

from gemm_utils import GemmConfig

import re


def is_compute_bound(M, N, K, bpe):
"""Is this GEMM compute (or memory) bound?"""
magic_ratio = 64
@@ -860,34 +863,46 @@ def unet(dtype: str) -> list[GemmConfig]:
return configs

def get_gemm_configs() -> list[tuple[str, GemmConfig]]:
configs: list[tuple[str, GemmConfig]] = []
llama13bmatvec_configs = llama13bmatvec("f16")
llama13bmatvec_configs: list[GemmConfig] = []
llama13bmatvec_configs += llama13bmatvec("f16")
llama13bmatvec_configs += llama13bmatvecbf16("bf16")
llama70bmatvec_configs = llama70bmatvec("f16")

llama70bmatvec_configs: list[GemmConfig] = []
llama70bmatvec_configs += llama70bmatvec("f16")
llama70bmatvec_configs += llama70bmatvecbf16("bf16")
llama13bskinny_configs = llama13bskinny("f16")

llama13bskinny_configs: list[GemmConfig] = []
llama13bskinny_configs += llama13bskinny("f16")
llama13bskinny_configs += llama13bskinnybf16("bf16")
llama70bskinny_configs = llama70bskinny("f16")

llama70bskinny_configs: list[GemmConfig] = []
llama70bskinny_configs += llama70bskinny("f16")
llama70bskinny_configs += llama70bskinnybf16("bf16")

gpt4compute_configs = gpt4compute("f16")
llama70bmemory_configs = llama70bmemory("bf16")
tk_default_configs = tk_default("f16")
compute_configs = compute("f16")

compute_configs: list[GemmConfig] = []
compute_configs += compute("f16")
compute_configs += compute("bf16")
unet_configs = unet("f16")

unet_configs: list[GemmConfig] = []
unet_configs += unet("f16")
unet_configs += unet("bf16")

configs += [("llama13bmatvec", x) for x in llama13bmatvec_configs]
configs += [("llama70bmatvec", x) for x in llama70bmatvec_configs]
configs += [("llama13bskinny", x) for x in llama13bskinny_configs]
configs += [("llama70bskinny", x) for x in llama70bskinny_configs]
configs += [("gpt4compute", x) for x in gpt4compute_configs]
configs += [("llama70bmemory", x) for x in llama70bmemory_configs]
configs += [("compute", x) for x in compute_configs]
configs += [("unet", x) for x in unet_configs]
configs += [("tk", x) for x in tk_default_configs]

return configs
all_configs: list[tuple[str, GemmConfig]] = []
all_configs += [("llama13bmatvec", x) for x in llama13bmatvec_configs]
all_configs += [("llama70bmatvec", x) for x in llama70bmatvec_configs]
all_configs += [("llama13bskinny", x) for x in llama13bskinny_configs]
all_configs += [("llama70bskinny", x) for x in llama70bskinny_configs]
all_configs += [("gpt4compute", x) for x in gpt4compute_configs]
all_configs += [("llama70bmemory", x) for x in llama70bmemory_configs]
all_configs += [("compute", x) for x in compute_configs]
all_configs += [("unet", x) for x in unet_configs]
all_configs += [("tk", x) for x in tk_default_configs]

return all_configs

def get_tk_gemm_configs() -> list[tuple[str, GemmConfig]]:
configs: list[tuple[str, GemmConfig]] = []
@@ -896,5 +911,19 @@ def get_tk_gemm_configs() -> list[tuple[str, GemmConfig]]:

configs += [("tk", x) for x in tk_default_configs]
configs += [("unet", x) for x in tk_unet_configs]

return configs

def get_matching_configs(tagged_configs: list[tuple[str, GemmConfig]],
dtypes: list[str], variants: list[str], tag_regex: str) -> list[tuple[str, GemmConfig]]:
tag_re = re.compile(tag_regex)
matching_configs: list[tuple[str, GemmConfig]] = []
for tag, config in tagged_configs:
if config.dtype not in dtypes:
continue
if f"{config.tA}{config.tB}" not in variants:
continue
if not tag_re.match(tag):
continue
matching_configs.append((tag, config))

return matching_configs
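
`get_matching_configs` filters the tagged configs on three axes: dtype membership, the variant string built from the transpose flags, and a regex match on the tag (`re.match` anchors at the start of the tag, so the default `.*` admits everything). The sketch below reproduces that filtering logic in a self-contained form; `FakeConfig` is a stand-in for `gemm_utils.GemmConfig`, which is only assumed here to expose `dtype`, `tA`, and `tB`.

```python
# Self-contained sketch of the filtering behaviour of get_matching_configs.
import re
from dataclasses import dataclass

@dataclass
class FakeConfig:
    # Stand-in for gemm_utils.GemmConfig; the real class has more fields.
    dtype: str
    tA: str  # "N" or "T"
    tB: str

def get_matching_configs(tagged_configs, dtypes, variants, tag_regex):
    tag_re = re.compile(tag_regex)
    matching = []
    for tag, config in tagged_configs:
        if config.dtype not in dtypes:
            continue
        if f"{config.tA}{config.tB}" not in variants:
            continue
        if not tag_re.match(tag):  # re.match only needs to match at the start of the tag
            continue
        matching.append((tag, config))
    return matching

tagged = [
    ("unet", FakeConfig("f16", "N", "T")),
    ("unet", FakeConfig("bf16", "N", "T")),
    ("llama13bmatvec", FakeConfig("f16", "T", "N")),
]

# Keep only f16 NT configs whose tag starts with "unet".
print(get_matching_configs(tagged, ["f16"], ["NT"], "unet"))
# -> [('unet', FakeConfig(dtype='f16', tA='N', tB='T'))]
```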