Skip to content

Commit

Permalink
Add SDXL conv shapes, add tool to plot roofline percentages
Browse files Browse the repository at this point in the history
Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
  • Loading branch information
Max191 committed Oct 8, 2024
1 parent 1aa0004 commit 50d6245
Show file tree
Hide file tree
Showing 5 changed files with 177 additions and 20 deletions.
2 changes: 1 addition & 1 deletion common_tools/kernel_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class KernelStats:
@staticmethod
def get_csv_header() -> list[str]:
return (
["Name"] + IsaStats.get_csv_header() + ConfiguredMlirStats.get_csv_header()
["name"] + IsaStats.get_csv_header() + ConfiguredMlirStats.get_csv_header()
)

def get_values(self):
Expand Down
80 changes: 80 additions & 0 deletions common_tools/plot_roofline_percents.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import argparse
import pandas as pd
import matplotlib.pyplot as plt

def plot_roofline_vs_column(kernel_stat_path, benchmark_stat_path, out_path, param_name, boxplot):
kernel_df = pd.read_csv(kernel_stat_path)
benchmark_df = pd.read_csv(benchmark_stat_path)
if param_name not in kernel_df.columns and param_name not in benchmark_df.columns:
print(f"`{param_name}` column not found in {kernel_stat_path} or {benchmark_stat_path}.\n")
return False
if "roofline_percent" not in benchmark_df.columns:
print(f"`roofline_percent` column not found in {benchmark_stat_path}.\n")
return False
if "name" not in benchmark_df.columns or "name" not in kernel_df.columns:
print(f"`name` column not found in {kernel_stat_path} and {benchmark_stat_path}.\n")
return False
df = kernel_df.merge(benchmark_df, on="name")
if boxplot:
axes = df[[param_name, "roofline_percent"]].boxplot(
by=param_name,
figsize=(12,12)
)
else:
axes = df.plot(
param_name,
"roofline_percent",
kind="scatter",
figsize=(12,12)
)
plt.xlabel(param_name)
plt.ylabel("roofline_percent")
plt.savefig(out_path, dpi=300, bbox_inches='tight')
plt.close()
return True


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Plotting tool to correlate kernel parameters with roofline percentages."
)
parser.add_argument(
"--kernel_stats_csv",
help="The path to the input csv containing kernel metrics.",
type=str,
default=None
)
parser.add_argument(
"--benchmark_csv",
help="The path to the input csv containing benchmarks.",
type=str,
default=None
)
parser.add_argument(
"--out_path",
help="The path to save the resulting plot image.",
type=str,
default=None
)
parser.add_argument(
"--parameter",
help="The name of the column with the parameter to use as the x-axis.",
type=str,
default=None
)
parser.add_argument(
"--boxplot",
help="Use a boxplot graph, with one boxplot per parameter value.",
action=argparse.BooleanOptionalAction,
type=bool,
default=False
)
args = parser.parse_args()

succeeded = plot_roofline_vs_column(
args.kernel_stats_csv, args.benchmark_csv, args.out_path, args.parameter, args.boxplot
)
if succeeded:
print(f"Plot saved to {args.out_path}\n")
else:
print(f"Failed to generate plot.\n")
49 changes: 36 additions & 13 deletions convbench/shark_conv.py → convbench/conv_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
import sys
from utils import *
from conv_utils import *
from problems import get_conv_configs
from problems import get_conv_configs, get_conv_test_configs


def compile_conv(tag, config, kernel_dir, vmfb_dir):
mlir_file, vmfb_file = compile_conv_config(config, kernel_dir, vmfb_dir)
return (tag, config, mlir_file, vmfb_file)
def compile_conv(tag, config, kernel_dir, vmfb_dir, extra_compiler_args):
mlir_file, vmfb_file, dump_path = compile_conv_config(config, kernel_dir, vmfb_dir, extra_compiler_args)
return (tag, config, mlir_file, vmfb_file, dump_path)


if __name__ == "__main__":
Expand All @@ -26,6 +26,12 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir):
type=str.upper,
help="Set the logging level",
)
parser.add_argument(
"--Xiree_compile",
nargs='+',
default=[],
help="Extra command line arguments passed to the IREE compiler. The flags need to be specified without the `--` or `-`."
)
parser.add_argument(
"--roofline",
help="Comma seperated csv file list to generate roofline plot with",
Expand All @@ -43,7 +49,8 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir):
roofline(args.roofline, args.plot, args.batch, args.dtype, args.model)
sys.exit()

configs = get_conv_configs()
configs = get_conv_test_configs()
# configs = get_conv_configs()
print(f"Generated {len(configs)} conv configs.")

num_cpus = max(1, cpu_count() - 20)
Expand All @@ -58,16 +65,17 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir):
kernel_dir.mkdir(parents=True, exist_ok=True)
vmfb_dir.mkdir(parents=True, exist_ok=True)

extra_compiler_args = ['--' + x for x in list(args.Xiree_compile)]
args = itertools.starmap(
lambda tag, config: (tag, config, kernel_dir, vmfb_dir), configs
lambda tag, config: (tag, config, kernel_dir, vmfb_dir, extra_compiler_args), configs
)
with Pool(num_cpus) as pool:
compilation_results = list(tqdm(pool.starmap(compile_conv, list(args))))

error_count = 0
for tag, config, mlir_file, vmfb_file in compilation_results:
for tag, config, mlir_file, vmfb_file, dump_path in compilation_results:
if vmfb_file:
vmfb_dict[vmfb_file] = (tag, config)
vmfb_dict[vmfb_file] = (tag, config, dump_path)
else:
error_count += 1
print(
Expand All @@ -84,7 +92,7 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir):
os.makedirs(csv_dir)

for vmfb_filename, value in vmfb_dict.items():
tag, config = value
tag, config, dump_path = value
name = config.get_name()

image_shape = config.get_img_shape()
Expand All @@ -101,17 +109,28 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir):
"--benchmark_repetitions=3",
]

print(f"Running {vmfb_filename}...")
# iree benchmark kernels
ret_value, cmd_out, cmd_stderr = run_iree_command(exec_args)
ok = ret_value == 0
benchmark_gemm_mean_time_ms = bench_summary_process(ret_value, cmd_out)
benchmark_gemm_mean_time_us = benchmark_gemm_mean_time_ms * 1000
benchmark_conv_mean_time_ms = bench_summary_process(ret_value, cmd_out)
benchmark_conv_mean_time_us = benchmark_conv_mean_time_ms * 1000

flops = config.get_flops()
byte_count = config.get_byte_count()

arithmetic_intensity = flops / byte_count
tflops_per_second = (flops / 1e12) / (benchmark_gemm_mean_time_us / 1e6)
tflops_per_second = (flops / 1e12) / (benchmark_conv_mean_time_us / 1e6)

# Compute percentage of the roofline.
tflops_map = {
"f32": 653.7,
"f16": 1307.4,
"bf16": 1307.4,
"f8E4M3FNUZ": 2614.9,
"i8": 2614.9,
}
roofline_tflops = tflops_map[config.input_dtype]

results.append(
(
Expand All @@ -128,9 +147,11 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir):
config.S,
config.input_dtype,
config.output_dtype,
round(benchmark_gemm_mean_time_us, 4),
round(benchmark_conv_mean_time_us, 4),
round(arithmetic_intensity, 4),
round(tflops_per_second, 4),
roofline_tflops,
round(tflops_per_second / roofline_tflops, 4),
ok,
)
)
Expand All @@ -153,6 +174,8 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir):
"mean_microseconds",
"arithmetic_intensity",
"tflops",
"roofline_tflops",
"roofline_percent",
"ok",
]

Expand Down
10 changes: 6 additions & 4 deletions convbench/conv_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,12 @@ def generate_mlir(config: ConvConfig):


def compile_conv_config(
config: ConvConfig, kernel_dir: Path, vmfb_dir: Path
config: ConvConfig, kernel_dir: Path, vmfb_dir: Path, extra_compiler_args: list[str]
) -> tuple[Path, Optional[Path]]:
mlir_file = kernel_dir / (config.get_name() + ".mlir")
vmfb_file = vmfb_dir / (config.get_name() + ".vmfb")
dump_file = kernel_dir / (config.get_name() + ".stderr.mlir")
files_path = vmfb_dir / config.get_name()

# Generate mlir content
mlir_content = generate_mlir(config)
Expand All @@ -188,7 +189,8 @@ def compile_conv_config(
"--iree-hal-target-device=hip",
# Device: MI300x
"--iree-hip-target=gfx942",
]
f"--iree-hal-dump-executable-files-to={files_path}",
] + extra_compiler_args

print(" ".join(exec_args))

Expand All @@ -203,6 +205,6 @@ def compile_conv_config(
print(f"Failed to compile {mlir_file}. Error dumped in {error_file}")
with open(error_file, "w") as f:
f.write(stderr.decode("utf-8"))
return mlir_file, None
return mlir_file, None, None

return mlir_file, vmfb_file
return mlir_file, vmfb_file, files_path
56 changes: 54 additions & 2 deletions convbench/problems.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,38 @@
from conv_utils import ConvConfig


def unet_sweep(op: str, input_dtype: str, output_dtype: str) -> list[ConvConfig]:
configs = []
for B in [1, 2, 4]:
configs.append(ConvConfig(B, 128, 128, 16, 3, 3, 320, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 320, 3, 3, 320, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 320, 3, 3, 320, 2, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 320, 3, 3, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 640, 3, 3, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 320, 1, 1, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 640, 3, 3, 640, 2, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 640, 3, 3, 1280, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 1280, 3, 3, 1280, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 640, 1, 1, 1280, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 2560, 3, 3, 1280, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 2560, 1, 1, 1280, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 1920, 3, 3, 1280, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 1920, 1, 1, 1280, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 1280, 3, 3, 1280, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 1920, 3, 3, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 1920, 1, 1, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 1280, 3, 3, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 1280, 1, 1, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 960, 3, 3, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 960, 1, 1, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 640, 3, 3, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 960, 3, 3, 320, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 960, 1, 1, 320, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 640, 3, 3, 320, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 640, 1, 1, 320, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 320, 3, 3, 16, 1, op, input_dtype, output_dtype))
return configs

def resnet_sweep(op: str, input_dtype: str, output_dtype: str) -> list[ConvConfig]:
configs = []
for B in [1, 2, 4, 8, 16, 32, 48]:
Expand All @@ -19,9 +51,29 @@ def resnet_sweep(op: str, input_dtype: str, output_dtype: str) -> list[ConvConfi

def get_conv_configs() -> list[tuple[str, ConvConfig]]:
configs: list[tuple[str, ConvConfig]] = []
resnet_configs = resnet_sweep("conv_2d_nchw_fchw", "f32", "f32")

# Resnet
resnet_configs = []
resnet_configs += resnet_sweep("conv_2d_nhwc_hwcf", "f16", "f32")
resnet_configs += resnet_sweep("conv_2d_nhwc_hwcf_q", "i8", "i32")
configs += [("resnet", x) for x in resnet_configs]

# Unet
unet_configs = []
unet_configs += unet_sweep("conv_2d_nhwc_hwcf", "f16", "f32")
unet_configs += unet_sweep("conv_2d_nhwc_hwcf_q", "i8", "i32")
configs += [("unet", x) for x in unet_configs]

return configs

# Test function to run only a few chosen shapes
def get_conv_test_configs() -> list[tuple[str, ConvConfig]]:
configs: list[tuple[str, ConvConfig]] = []

unet_configs = []
unet_configs.append(ConvConfig(1,128,128,16,3,3,320,1, "conv_2d_nhwc_hwcf_q", "i8", "i32"))
unet_configs.append(ConvConfig(1,32,32,640,1,1,1280,1, "conv_2d_nhwc_hwcf_q", "i8", "i32"))

configs += [("resnet_sweep", x) for x in resnet_configs]
configs += [("unet", x) for x in unet_configs]

return configs

0 comments on commit 50d6245

Please sign in to comment.