Skip to content

Commit

Permalink
Add SDXL conv shapes, extra iree flags option, tool to plot roofline …
Browse files Browse the repository at this point in the history
…percentages (#19)

- Adds the SDXL convolution shapes to convbench
 - Adds the option to pass Xiree_compile flags in convbench
 - Adds percentage of roofline to the collected conv benchmark metrics
- Adds a tool to plot roofline percents against kernel parameters given
the benchmarks and kernel stats
- Renames `shark_conv.py` to `conv_bench.py` to match gemm and attention
formats

---------

Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
  • Loading branch information
Max191 authored Oct 28, 2024
1 parent 982eb72 commit 4621947
Show file tree
Hide file tree
Showing 8 changed files with 201 additions and 32 deletions.
16 changes: 8 additions & 8 deletions .github/workflows/run_bench.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ jobs:
- name: Convolutions
run: |
source bench_venv/bin/activate
python convbench/shark_conv.py
python convbench/conv_bench.py
- name: Attention
run: |
Expand All @@ -55,13 +55,13 @@ jobs:
- name: Roofline Plots
run: |
source bench_venv/bin/activate
python convbench/shark_conv.py --roofline results/iree_conv.csv --plot results/iree_conv_i8.png --dtype i8
python convbench/shark_conv.py --roofline results/iree_conv.csv --plot results/iree_conv_f32.png --dtype f32
python convbench/shark_conv.py --roofline results/iree_attention.csv --plot results/iree_attention_fp16.png --dtype f16
python convbench/shark_conv.py --roofline results/iree_attention.csv --plot results/iree_attention_fp8.png --dtype f8E4M3FNUZ
python convbench/shark_conv.py --roofline results/iree_gemm.csv --plot results/iree_gemm.png
python convbench/shark_conv.py --roofline results/iree_gemm_tk.csv --plot results/iree_gemm_tk.png
python convbench/shark_conv.py --roofline results/iree_gemm.csv,results/iree_gemm_tk.csv,results/iree_attention.csv,results/iree_conv.csv --plot results/combined.png
python convbench/conv_bench.py --roofline results/iree_conv.csv --plot results/iree_conv_i8.png --dtype i8
python convbench/conv_bench.py --roofline results/iree_conv.csv --plot results/iree_conv_f16.png --dtype f16
python convbench/conv_bench.py --roofline results/iree_attention.csv --plot results/iree_attention_fp16.png --dtype f16
python convbench/conv_bench.py --roofline results/iree_attention.csv --plot results/iree_attention_fp8.png --dtype f8E4M3FNUZ
python convbench/conv_bench.py --roofline results/iree_gemm.csv --plot results/iree_gemm.png
python convbench/conv_bench.py --roofline results/iree_gemm_tk.csv --plot results/iree_gemm_tk.png
python convbench/conv_bench.py --roofline results/iree_gemm.csv,results/iree_gemm_tk.csv,results/iree_attention.csv,results/iree_conv.csv --plot results/combined.png
- name: Upload benchmark results
uses: actions/upload-artifact@v4
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ Refer to the respective problems.py file in the folder to see which shapes are b
### Convolution Benchmarking

```
python convbench/shark_conv.py
python convbench/conv_bench.py
```

### GEMM Benchmarking
Expand All @@ -50,7 +50,7 @@ python attentionbench/attention_bench.py
If you want to generate a roofline plot, you can call any of the suites for now with the --roofline option (provide a commma seperated list if you want to generate for multiple benchmarks combined):

```
python convbench/shark_conv.py --roofline results/iree_conv.csv,results/iree_attention.csv --plot results/attn_conv.png
python convbench/conv_bench.py --roofline results/iree_conv.csv,results/iree_attention.csv --plot results/attn_conv.png
```

If you want to generate a roofline plot for a certain data type, model, or batch size you can do:
Expand Down
2 changes: 1 addition & 1 deletion common_tools/kernel_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class KernelStats:
@staticmethod
def get_csv_header() -> list[str]:
return (
["Name"] + IsaStats.get_csv_header() + ConfiguredMlirStats.get_csv_header()
["name"] + IsaStats.get_csv_header() + ConfiguredMlirStats.get_csv_header()
)

def get_values(self):
Expand Down
80 changes: 80 additions & 0 deletions common_tools/plot_roofline_percents.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import argparse
import pandas as pd
import matplotlib.pyplot as plt

def plot_roofline_vs_column(kernel_stat_path, benchmark_stat_path, out_path, param_name, boxplot):
kernel_df = pd.read_csv(kernel_stat_path)
benchmark_df = pd.read_csv(benchmark_stat_path)
if param_name not in kernel_df.columns and param_name not in benchmark_df.columns:
print(f"`{param_name}` column not found in {kernel_stat_path} or {benchmark_stat_path}.\n")
return False
if "roofline_percent" not in benchmark_df.columns:
print(f"`roofline_percent` column not found in {benchmark_stat_path}.\n")
return False
if "name" not in benchmark_df.columns or "name" not in kernel_df.columns:
print(f"`name` column not found in {kernel_stat_path} and {benchmark_stat_path}.\n")
return False
df = kernel_df.merge(benchmark_df, on="name")
if boxplot:
axes = df[[param_name, "roofline_percent"]].boxplot(
by=param_name,
figsize=(12,12)
)
else:
axes = df.plot(
param_name,
"roofline_percent",
kind="scatter",
figsize=(12,12)
)
plt.xlabel(param_name)
plt.ylabel("roofline_percent")
plt.savefig(out_path, dpi=300, bbox_inches='tight')
plt.close()
return True


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Plotting tool to correlate kernel parameters with roofline percentages."
)
parser.add_argument(
"--kernel_stats_csv",
help="The path to the input csv containing kernel metrics.",
type=str,
default=None
)
parser.add_argument(
"--benchmark_csv",
help="The path to the input csv containing benchmarks.",
type=str,
default=None
)
parser.add_argument(
"--out_path",
help="The path to save the resulting plot image.",
type=str,
default=None
)
parser.add_argument(
"--parameter",
help="The name of the column with the parameter to use as the x-axis.",
type=str,
default=None
)
parser.add_argument(
"--boxplot",
help="Use a boxplot graph, with one boxplot per parameter value.",
action=argparse.BooleanOptionalAction,
type=bool,
default=False
)
args = parser.parse_args()

succeeded = plot_roofline_vs_column(
args.kernel_stats_csv, args.benchmark_csv, args.out_path, args.parameter, args.boxplot
)
if succeeded:
print(f"Plot saved to {args.out_path}\n")
else:
print(f"Failed to generate plot.\n")
2 changes: 1 addition & 1 deletion common_tools/utils/bench_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def roofline(results=None, out=None, batch=None, dtype=None, model=None, **kwarg
with open(result_file.strip(), mode='r') as csvfile:
reader = csv.DictReader(csvfile)
for row in reader:
row = {k: float(v) if k in ['index', 'mean_microseconds', 'arithmetic_intensity', 'tflops'] else v for k, v in row.items()}
row = {k: float(v) if k in ['index', 'mean_microseconds', 'arithmetic_intensity', 'tflops', 'roofline_tflops', 'roofline_percent'] else v for k, v in row.items()}
row['ok'] = True if 'ok' not in row else row['ok'] == 'True'
data.append(row)
if batch:
Expand Down
48 changes: 36 additions & 12 deletions convbench/shark_conv.py → convbench/conv_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
import sys
from utils import *
from conv_utils import *
from problems import get_conv_configs
from problems import get_conv_configs, get_conv_test_configs


def compile_conv(tag, config, kernel_dir, vmfb_dir):
mlir_file, vmfb_file = compile_conv_config(config, kernel_dir, vmfb_dir)
return (tag, config, mlir_file, vmfb_file)
def compile_conv(tag, config, kernel_dir, vmfb_dir, extra_compiler_args):
mlir_file, vmfb_file, dump_path = compile_conv_config(config, kernel_dir, vmfb_dir, extra_compiler_args)
return (tag, config, mlir_file, vmfb_file, dump_path)


if __name__ == "__main__":
Expand All @@ -27,6 +27,12 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir):
help="Set the logging level",
)
parser.add_argument("--device", help="The IREE device to execute benchmarks on", type=str, default="hip")
parser.add_argument(
"--Xiree_compile",
nargs='+',
default=[],
help="Extra command line arguments passed to the IREE compiler. The flags need to be specified without the `--` or `-`."
)
parser.add_argument(
"--roofline",
help="Comma seperated csv file list to generate roofline plot with",
Expand All @@ -44,6 +50,7 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir):
roofline(args.roofline, args.plot, args.batch, args.dtype, args.model)
sys.exit()

# configs = get_conv_test_configs()
configs = get_conv_configs()
print(f"Generated {len(configs)} conv configs.")

Expand All @@ -60,16 +67,17 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir):
vmfb_dir.mkdir(parents=True, exist_ok=True)
device = args.device

extra_compiler_args = ['--' + x for x in list(args.Xiree_compile)]
compile_args = itertools.starmap(
lambda tag, config: (tag, config, kernel_dir, vmfb_dir), configs
lambda tag, config: (tag, config, kernel_dir, vmfb_dir, extra_compiler_args), configs
)
with Pool(num_cpus) as pool:
compilation_results = list(tqdm(pool.starmap(compile_conv, list(compile_args))))

error_count = 0
for tag, config, mlir_file, vmfb_file in compilation_results:
for tag, config, mlir_file, vmfb_file, dump_path in compilation_results:
if vmfb_file:
vmfb_dict[vmfb_file] = (tag, config)
vmfb_dict[vmfb_file] = (tag, config, dump_path)
else:
error_count += 1
print(
Expand All @@ -86,7 +94,7 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir):
os.makedirs(csv_dir)

for vmfb_filename, value in vmfb_dict.items():
tag, config = value
tag, config, dump_path = value
name = config.get_name()

image_shape = config.get_img_shape()
Expand All @@ -103,17 +111,29 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir):
"--benchmark_repetitions=3",
]

print(f"Running {vmfb_filename}...")
# iree benchmark kernels
ret_value, cmd_out, cmd_stderr = run_iree_command(exec_args)
ok = ret_value == 0
benchmark_gemm_mean_time_ms = bench_summary_process(ret_value, cmd_out)
benchmark_gemm_mean_time_us = benchmark_gemm_mean_time_ms * 1000
benchmark_conv_mean_time_ms = bench_summary_process(ret_value, cmd_out)
benchmark_conv_mean_time_us = benchmark_conv_mean_time_ms * 1000

flops = config.get_flops()
byte_count = config.get_byte_count()

arithmetic_intensity = flops / byte_count
tflops_per_second = (flops / 1e12) / (benchmark_gemm_mean_time_us / 1e6)
tflops_per_second = (flops / 1e12) / (benchmark_conv_mean_time_us / 1e6)

# Compute percentage of the roofline.
# TODO: Make this target specific and move to common utils.
tflops_map = {
"f32": 653.7,
"f16": 1307.4,
"bf16": 1307.4,
"f8E4M3FNUZ": 2614.9,
"i8": 2614.9,
}
roofline_tflops = tflops_map[config.input_dtype]

results.append(
(
Expand All @@ -130,9 +150,11 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir):
config.S,
config.input_dtype,
config.output_dtype,
round(benchmark_gemm_mean_time_us, 4),
round(benchmark_conv_mean_time_us, 4),
round(arithmetic_intensity, 4),
round(tflops_per_second, 4),
roofline_tflops,
round(tflops_per_second / roofline_tflops, 4),
ok,
)
)
Expand All @@ -155,6 +177,8 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir):
"mean_microseconds",
"arithmetic_intensity",
"tflops",
"roofline_tflops",
"roofline_percent",
"ok",
]

Expand Down
10 changes: 6 additions & 4 deletions convbench/conv_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,12 @@ def generate_mlir(config: ConvConfig):


def compile_conv_config(
config: ConvConfig, kernel_dir: Path, vmfb_dir: Path
config: ConvConfig, kernel_dir: Path, vmfb_dir: Path, extra_compiler_args: list[str]
) -> tuple[Path, Optional[Path]]:
mlir_file = kernel_dir / (config.get_name() + ".mlir")
vmfb_file = vmfb_dir / (config.get_name() + ".vmfb")
dump_file = kernel_dir / (config.get_name() + ".stderr.mlir")
files_path = vmfb_dir / config.get_name()

# Generate mlir content
mlir_content = generate_mlir(config)
Expand All @@ -188,7 +189,8 @@ def compile_conv_config(
"--iree-hal-target-device=hip",
# Device: MI300x
"--iree-hip-target=gfx942",
]
f"--iree-hal-dump-executable-files-to={files_path}",
] + extra_compiler_args

print(" ".join(exec_args))

Expand All @@ -203,6 +205,6 @@ def compile_conv_config(
print(f"Failed to compile {mlir_file}. Error dumped in {error_file}")
with open(error_file, "w") as f:
f.write(stderr.decode("utf-8"))
return mlir_file, None
return mlir_file, None, None

return mlir_file, vmfb_file
return mlir_file, vmfb_file, files_path
71 changes: 67 additions & 4 deletions convbench/problems.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,41 @@
from conv_utils import ConvConfig


def unet_sweep(op: str, input_dtype: str, output_dtype: str) -> list[ConvConfig]:
configs = []
for B in [1, 2, 4, 8]:
configs.append(ConvConfig(B, 128, 128, 16, 3, 3, 320, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 320, 3, 3, 320, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 320, 3, 3, 320, 2, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 320, 3, 3, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 640, 3, 3, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 320, 1, 1, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 640, 3, 3, 640, 2, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 640, 3, 3, 1280, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 1280, 3, 3, 1280, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 640, 1, 1, 1280, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 2560, 3, 3, 1280, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 2560, 1, 1, 1280, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 1920, 3, 3, 1280, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 32, 32, 1920, 1, 1, 1280, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 1280, 3, 3, 1280, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 1920, 3, 3, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 1920, 1, 1, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 1280, 3, 3, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 1280, 1, 1, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 960, 3, 3, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 64, 64, 960, 1, 1, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 640, 3, 3, 640, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 960, 3, 3, 320, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 960, 1, 1, 320, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 640, 3, 3, 320, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 640, 1, 1, 320, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 128, 128, 320, 3, 3, 16, 1, op, input_dtype, output_dtype))
return configs

def resnet_sweep(op: str, input_dtype: str, output_dtype: str) -> list[ConvConfig]:
configs = []
for B in [1, 2, 4, 8, 16, 32, 48]:
for B in [1, 2, 4, 8]:
configs.append(ConvConfig(B, 112, 112, 64, 7, 7, 3, 2, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 56, 56, 64, 3, 3, 64, 1, op, input_dtype, output_dtype))
configs.append(ConvConfig(B, 28, 28, 128, 3, 3, 128, 2, op, input_dtype, output_dtype))
Expand All @@ -19,9 +51,40 @@ def resnet_sweep(op: str, input_dtype: str, output_dtype: str) -> list[ConvConfi

def get_conv_configs() -> list[tuple[str, ConvConfig]]:
configs: list[tuple[str, ConvConfig]] = []
resnet_configs = resnet_sweep("conv_2d_nchw_fchw", "f32", "f32")
resnet_configs += resnet_sweep("conv_2d_nhwc_hwcf_q", "i8", "i32")

configs += [("resnet_sweep", x) for x in resnet_configs]
# Resnet
resnet_configs = []
resnet_configs += resnet_sweep("conv_2d_nhwc_hwcf", "f16", "f32")
resnet_configs += resnet_sweep("conv_2d_nhwc_hwcf", "i8", "i32")
resnet_configs += resnet_sweep("conv_2d_nchw_fchw", "f16", "f32")
resnet_configs += resnet_sweep("conv_2d_nchw_fchw", "i8", "i32")
configs += [("resnet", x) for x in resnet_configs]

# Unet
unet_configs = []
unet_configs += unet_sweep("conv_2d_nhwc_hwcf", "f16", "f32")
unet_configs += unet_sweep("conv_2d_nhwc_hwcf", "i8", "i32")
unet_configs += unet_sweep("conv_2d_nchw_fchw", "f16", "f32")
unet_configs += unet_sweep("conv_2d_nchw_fchw", "i8", "i32")
configs += [("unet", x) for x in unet_configs]

return configs

# Test function to run only a few chosen shapes
def get_conv_test_configs() -> list[tuple[str, ConvConfig]]:
configs: list[tuple[str, ConvConfig]] = []

resnet_configs = []
# resnet_configs += resnet_sweep("conv_2d_nhwc_hwcf", "f16", "f32")
# resnet_configs += resnet_sweep("conv_2d_nhwc_hwcf", "i8", "i32")
# resnet_configs += resnet_sweep("conv_2d_nchw_fchw", "f16", "f32")
# resnet_configs += resnet_sweep("conv_2d_nchw_fchw", "i8", "i32")
configs += [("resnet", x) for x in resnet_configs]

unet_configs = []
# unet_configs.append(ConvConfig(1,128,128,16,3,3,320,1, "conv_2d_nhwc_hwcf_q", "i8", "i32"))
# unet_configs.append(ConvConfig(1,32,32,640,1,1,1280,1, "conv_2d_nhwc_hwcf_q", "i8", "i32"))

configs += [("unet", x) for x in unet_configs]

return configs

0 comments on commit 4621947

Please sign in to comment.