From a5a7625e0e007f52b3ec138534e0b61f49b95c03 Mon Sep 17 00:00:00 2001 From: Max Dawkins Date: Mon, 7 Oct 2024 15:47:47 -0500 Subject: [PATCH 1/4] Add SDXL conv shapes, add tool to plot roofline percentages Signed-off-by: Max Dawkins --- .github/workflows/run_bench.yml | 2 +- common_tools/kernel_stats.py | 2 +- common_tools/plot_roofline_percents.py | 80 ++++++++++++++++++++++ convbench/{shark_conv.py => conv_bench.py} | 49 +++++++++---- convbench/conv_utils.py | 10 +-- convbench/problems.py | 56 ++++++++++++++- 6 files changed, 178 insertions(+), 21 deletions(-) create mode 100644 common_tools/plot_roofline_percents.py rename convbench/{shark_conv.py => conv_bench.py} (71%) diff --git a/.github/workflows/run_bench.yml b/.github/workflows/run_bench.yml index dfda099..0f45bd2 100644 --- a/.github/workflows/run_bench.yml +++ b/.github/workflows/run_bench.yml @@ -35,7 +35,7 @@ jobs: - name: Convolutions run: | source bench_venv/bin/activate - python convbench/shark_conv.py + python convbench/conv_bench.py - name: Attention run: | diff --git a/common_tools/kernel_stats.py b/common_tools/kernel_stats.py index 98a1c5c..7e11c3d 100644 --- a/common_tools/kernel_stats.py +++ b/common_tools/kernel_stats.py @@ -98,7 +98,7 @@ class KernelStats: @staticmethod def get_csv_header() -> list[str]: return ( - ["Name"] + IsaStats.get_csv_header() + ConfiguredMlirStats.get_csv_header() + ["name"] + IsaStats.get_csv_header() + ConfiguredMlirStats.get_csv_header() ) def get_values(self): diff --git a/common_tools/plot_roofline_percents.py b/common_tools/plot_roofline_percents.py new file mode 100644 index 0000000..927a24f --- /dev/null +++ b/common_tools/plot_roofline_percents.py @@ -0,0 +1,80 @@ +import argparse +import pandas as pd +import matplotlib.pyplot as plt + +def plot_roofline_vs_column(kernel_stat_path, benchmark_stat_path, out_path, param_name, boxplot): + kernel_df = pd.read_csv(kernel_stat_path) + benchmark_df = pd.read_csv(benchmark_stat_path) + if param_name not in kernel_df.columns and param_name not in benchmark_df.columns: + print(f"`{param_name}` column not found in {kernel_stat_path} or {benchmark_stat_path}.\n") + return False + if "roofline_percent" not in benchmark_df.columns: + print(f"`roofline_percent` column not found in {benchmark_stat_path}.\n") + return False + if "name" not in benchmark_df.columns or "name" not in kernel_df.columns: + print(f"`name` column not found in {kernel_stat_path} and {benchmark_stat_path}.\n") + return False + df = kernel_df.merge(benchmark_df, on="name") + if boxplot: + axes = df[[param_name, "roofline_percent"]].boxplot( + by=param_name, + figsize=(12,12) + ) + else: + axes = df.plot( + param_name, + "roofline_percent", + kind="scatter", + figsize=(12,12) + ) + plt.xlabel(param_name) + plt.ylabel("roofline_percent") + plt.savefig(out_path, dpi=300, bbox_inches='tight') + plt.close() + return True + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Plotting tool to correlate kernel parameters with roofline percentages." + ) + parser.add_argument( + "--kernel_stats_csv", + help="The path to the input csv containing kernel metrics.", + type=str, + default=None + ) + parser.add_argument( + "--benchmark_csv", + help="The path to the input csv containing benchmarks.", + type=str, + default=None + ) + parser.add_argument( + "--out_path", + help="The path to save the resulting plot image.", + type=str, + default=None + ) + parser.add_argument( + "--parameter", + help="The name of the column with the parameter to use as the x-axis.", + type=str, + default=None + ) + parser.add_argument( + "--boxplot", + help="Use a boxplot graph, with one boxplot per parameter value.", + action=argparse.BooleanOptionalAction, + type=bool, + default=False + ) + args = parser.parse_args() + + succeeded = plot_roofline_vs_column( + args.kernel_stats_csv, args.benchmark_csv, args.out_path, args.parameter, args.boxplot + ) + if succeeded: + print(f"Plot saved to {args.out_path}\n") + else: + print(f"Failed to generate plot.\n") diff --git a/convbench/shark_conv.py b/convbench/conv_bench.py similarity index 71% rename from convbench/shark_conv.py rename to convbench/conv_bench.py index 8f797c1..3adac44 100644 --- a/convbench/shark_conv.py +++ b/convbench/conv_bench.py @@ -9,12 +9,12 @@ import sys from utils import * from conv_utils import * -from problems import get_conv_configs +from problems import get_conv_configs, get_conv_test_configs -def compile_conv(tag, config, kernel_dir, vmfb_dir): - mlir_file, vmfb_file = compile_conv_config(config, kernel_dir, vmfb_dir) - return (tag, config, mlir_file, vmfb_file) +def compile_conv(tag, config, kernel_dir, vmfb_dir, extra_compiler_args): + mlir_file, vmfb_file, dump_path = compile_conv_config(config, kernel_dir, vmfb_dir, extra_compiler_args) + return (tag, config, mlir_file, vmfb_file, dump_path) if __name__ == "__main__": @@ -27,6 +27,12 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir): help="Set the logging level", ) parser.add_argument("--device", help="The IREE device to execute benchmarks on", type=str, default="hip") + parser.add_argument( + "--Xiree_compile", + nargs='+', + default=[], + help="Extra command line arguments passed to the IREE compiler. The flags need to be specified without the `--` or `-`." + ) parser.add_argument( "--roofline", help="Comma seperated csv file list to generate roofline plot with", @@ -44,7 +50,8 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir): roofline(args.roofline, args.plot, args.batch, args.dtype, args.model) sys.exit() - configs = get_conv_configs() + configs = get_conv_test_configs() + # configs = get_conv_configs() print(f"Generated {len(configs)} conv configs.") num_cpus = max(1, cpu_count() - 20) @@ -60,16 +67,17 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir): vmfb_dir.mkdir(parents=True, exist_ok=True) device = args.device + extra_compiler_args = ['--' + x for x in list(args.Xiree_compile)] compile_args = itertools.starmap( - lambda tag, config: (tag, config, kernel_dir, vmfb_dir), configs + lambda tag, config: (tag, config, kernel_dir, vmfb_dir, extra_compiler_args), configs ) with Pool(num_cpus) as pool: compilation_results = list(tqdm(pool.starmap(compile_conv, list(compile_args)))) error_count = 0 - for tag, config, mlir_file, vmfb_file in compilation_results: + for tag, config, mlir_file, vmfb_file, dump_path in compilation_results: if vmfb_file: - vmfb_dict[vmfb_file] = (tag, config) + vmfb_dict[vmfb_file] = (tag, config, dump_path) else: error_count += 1 print( @@ -86,7 +94,7 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir): os.makedirs(csv_dir) for vmfb_filename, value in vmfb_dict.items(): - tag, config = value + tag, config, dump_path = value name = config.get_name() image_shape = config.get_img_shape() @@ -103,17 +111,28 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir): "--benchmark_repetitions=3", ] + print(f"Running {vmfb_filename}...") # iree benchmark kernels ret_value, cmd_out, cmd_stderr = run_iree_command(exec_args) ok = ret_value == 0 - benchmark_gemm_mean_time_ms = bench_summary_process(ret_value, cmd_out) - benchmark_gemm_mean_time_us = benchmark_gemm_mean_time_ms * 1000 + benchmark_conv_mean_time_ms = bench_summary_process(ret_value, cmd_out) + benchmark_conv_mean_time_us = benchmark_conv_mean_time_ms * 1000 flops = config.get_flops() byte_count = config.get_byte_count() arithmetic_intensity = flops / byte_count - tflops_per_second = (flops / 1e12) / (benchmark_gemm_mean_time_us / 1e6) + tflops_per_second = (flops / 1e12) / (benchmark_conv_mean_time_us / 1e6) + + # Compute percentage of the roofline. + tflops_map = { + "f32": 653.7, + "f16": 1307.4, + "bf16": 1307.4, + "f8E4M3FNUZ": 2614.9, + "i8": 2614.9, + } + roofline_tflops = tflops_map[config.input_dtype] results.append( ( @@ -130,9 +149,11 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir): config.S, config.input_dtype, config.output_dtype, - round(benchmark_gemm_mean_time_us, 4), + round(benchmark_conv_mean_time_us, 4), round(arithmetic_intensity, 4), round(tflops_per_second, 4), + roofline_tflops, + round(tflops_per_second / roofline_tflops, 4), ok, ) ) @@ -155,6 +176,8 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir): "mean_microseconds", "arithmetic_intensity", "tflops", + "roofline_tflops", + "roofline_percent", "ok", ] diff --git a/convbench/conv_utils.py b/convbench/conv_utils.py index c8c8ad4..4f48d89 100644 --- a/convbench/conv_utils.py +++ b/convbench/conv_utils.py @@ -162,11 +162,12 @@ def generate_mlir(config: ConvConfig): def compile_conv_config( - config: ConvConfig, kernel_dir: Path, vmfb_dir: Path + config: ConvConfig, kernel_dir: Path, vmfb_dir: Path, extra_compiler_args: list[str] ) -> tuple[Path, Optional[Path]]: mlir_file = kernel_dir / (config.get_name() + ".mlir") vmfb_file = vmfb_dir / (config.get_name() + ".vmfb") dump_file = kernel_dir / (config.get_name() + ".stderr.mlir") + files_path = vmfb_dir / config.get_name() # Generate mlir content mlir_content = generate_mlir(config) @@ -188,7 +189,8 @@ def compile_conv_config( "--iree-hal-target-device=hip", # Device: MI300x "--iree-hip-target=gfx942", - ] + f"--iree-hal-dump-executable-files-to={files_path}", + ] + extra_compiler_args print(" ".join(exec_args)) @@ -203,6 +205,6 @@ def compile_conv_config( print(f"Failed to compile {mlir_file}. Error dumped in {error_file}") with open(error_file, "w") as f: f.write(stderr.decode("utf-8")) - return mlir_file, None + return mlir_file, None, None - return mlir_file, vmfb_file + return mlir_file, vmfb_file, files_path diff --git a/convbench/problems.py b/convbench/problems.py index 6434c52..9408a0f 100644 --- a/convbench/problems.py +++ b/convbench/problems.py @@ -1,6 +1,38 @@ from conv_utils import ConvConfig +def unet_sweep(op: str, input_dtype: str, output_dtype: str) -> list[ConvConfig]: + configs = [] + for B in [1, 2, 4]: + configs.append(ConvConfig(B, 128, 128, 16, 3, 3, 320, 1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 128, 128, 320, 3, 3, 320, 1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 64, 64, 320, 3, 3, 320, 2, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 64, 64, 320, 3, 3, 640, 1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 64, 64, 640, 3, 3, 640, 1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 64, 64, 320, 1, 1, 640, 1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 32, 32, 640, 3, 3, 640, 2, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 32, 32, 640, 3, 3, 1280, 1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 32, 32, 1280, 3, 3, 1280, 1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 32, 32, 640, 1, 1, 1280, 1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 32, 32, 2560, 3, 3, 1280, 1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 32, 32, 2560, 1, 1, 1280, 1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 32, 32, 1920, 3, 3, 1280, 1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 32, 32, 1920, 1, 1, 1280, 1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 64, 64, 1280, 3, 3, 1280, 1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 64, 64, 1920, 3, 3, 640, 1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 64, 64, 1920, 1, 1, 640, 1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 64, 64, 1280, 3, 3, 640, 1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 64, 64, 1280, 1, 1, 640, 1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 64, 64, 960, 3, 3, 640, 1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 64, 64, 960, 1, 1, 640, 1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 128, 128, 640, 3, 3, 640, 1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 128, 128, 960, 3, 3, 320, 1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 128, 128, 960, 1, 1, 320, 1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 128, 128, 640, 3, 3, 320, 1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 128, 128, 640, 1, 1, 320, 1, op, input_dtype, output_dtype)) + configs.append(ConvConfig(B, 128, 128, 320, 3, 3, 16, 1, op, input_dtype, output_dtype)) + return configs + def resnet_sweep(op: str, input_dtype: str, output_dtype: str) -> list[ConvConfig]: configs = [] for B in [1, 2, 4, 8, 16, 32, 48]: @@ -19,9 +51,29 @@ def resnet_sweep(op: str, input_dtype: str, output_dtype: str) -> list[ConvConfi def get_conv_configs() -> list[tuple[str, ConvConfig]]: configs: list[tuple[str, ConvConfig]] = [] - resnet_configs = resnet_sweep("conv_2d_nchw_fchw", "f32", "f32") + + # Resnet + resnet_configs = [] + resnet_configs += resnet_sweep("conv_2d_nhwc_hwcf", "f16", "f32") resnet_configs += resnet_sweep("conv_2d_nhwc_hwcf_q", "i8", "i32") + configs += [("resnet", x) for x in resnet_configs] + + # Unet + unet_configs = [] + unet_configs += unet_sweep("conv_2d_nhwc_hwcf", "f16", "f32") + unet_configs += unet_sweep("conv_2d_nhwc_hwcf_q", "i8", "i32") + configs += [("unet", x) for x in unet_configs] + + return configs + +# Test function to run only a few chosen shapes +def get_conv_test_configs() -> list[tuple[str, ConvConfig]]: + configs: list[tuple[str, ConvConfig]] = [] + + unet_configs = [] + unet_configs.append(ConvConfig(1,128,128,16,3,3,320,1, "conv_2d_nhwc_hwcf_q", "i8", "i32")) + unet_configs.append(ConvConfig(1,32,32,640,1,1,1280,1, "conv_2d_nhwc_hwcf_q", "i8", "i32")) - configs += [("resnet_sweep", x) for x in resnet_configs] + configs += [("unet", x) for x in unet_configs] return configs From f1796ecd64d4f5e40df8f7d965c33da81253a28a Mon Sep 17 00:00:00 2001 From: Max Dawkins Date: Wed, 9 Oct 2024 09:16:15 -0500 Subject: [PATCH 2/4] add todo Signed-off-by: Max Dawkins --- convbench/conv_bench.py | 1 + 1 file changed, 1 insertion(+) diff --git a/convbench/conv_bench.py b/convbench/conv_bench.py index 3adac44..4ad0216 100644 --- a/convbench/conv_bench.py +++ b/convbench/conv_bench.py @@ -125,6 +125,7 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir, extra_compiler_args): tflops_per_second = (flops / 1e12) / (benchmark_conv_mean_time_us / 1e6) # Compute percentage of the roofline. + # TODO: Make this target specific and move to common utils. tflops_map = { "f32": 653.7, "f16": 1307.4, From d6a72a78d6cbe6509bc0f7c66203f0ab84133f5d Mon Sep 17 00:00:00 2001 From: Max Dawkins Date: Wed, 9 Oct 2024 09:18:47 -0500 Subject: [PATCH 3/4] update file name Signed-off-by: Max Dawkins --- .github/workflows/run_bench.yml | 14 +++++++------- README.md | 4 ++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/.github/workflows/run_bench.yml b/.github/workflows/run_bench.yml index 0f45bd2..46c576c 100644 --- a/.github/workflows/run_bench.yml +++ b/.github/workflows/run_bench.yml @@ -55,13 +55,13 @@ jobs: - name: Roofline Plots run: | source bench_venv/bin/activate - python convbench/shark_conv.py --roofline results/iree_conv.csv --plot results/iree_conv_i8.png --dtype i8 - python convbench/shark_conv.py --roofline results/iree_conv.csv --plot results/iree_conv_f32.png --dtype f32 - python convbench/shark_conv.py --roofline results/iree_attention.csv --plot results/iree_attention_fp16.png --dtype f16 - python convbench/shark_conv.py --roofline results/iree_attention.csv --plot results/iree_attention_fp8.png --dtype f8E4M3FNUZ - python convbench/shark_conv.py --roofline results/iree_gemm.csv --plot results/iree_gemm.png - python convbench/shark_conv.py --roofline results/iree_gemm_tk.csv --plot results/iree_gemm_tk.png - python convbench/shark_conv.py --roofline results/iree_gemm.csv,results/iree_gemm_tk.csv,results/iree_attention.csv,results/iree_conv.csv --plot results/combined.png + python convbench/conv_bench.py --roofline results/iree_conv.csv --plot results/iree_conv_i8.png --dtype i8 + python convbench/conv_bench.py --roofline results/iree_conv.csv --plot results/iree_conv_f32.png --dtype f32 + python convbench/conv_bench.py --roofline results/iree_attention.csv --plot results/iree_attention_fp16.png --dtype f16 + python convbench/conv_bench.py --roofline results/iree_attention.csv --plot results/iree_attention_fp8.png --dtype f8E4M3FNUZ + python convbench/conv_bench.py --roofline results/iree_gemm.csv --plot results/iree_gemm.png + python convbench/conv_bench.py --roofline results/iree_gemm_tk.csv --plot results/iree_gemm_tk.png + python convbench/conv_bench.py --roofline results/iree_gemm.csv,results/iree_gemm_tk.csv,results/iree_attention.csv,results/iree_conv.csv --plot results/combined.png - name: Upload benchmark results uses: actions/upload-artifact@v4 diff --git a/README.md b/README.md index 386b36e..a9e2f9d 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ Refer to the respective problems.py file in the folder to see which shapes are b ### Convolution Benchmarking ``` -python convbench/shark_conv.py +python convbench/conv_bench.py ``` ### GEMM Benchmarking @@ -50,7 +50,7 @@ python attentionbench/attention_bench.py If you want to generate a roofline plot, you can call any of the suites for now with the --roofline option (provide a commma seperated list if you want to generate for multiple benchmarks combined): ``` -python convbench/shark_conv.py --roofline results/iree_conv.csv,results/iree_attention.csv --plot results/attn_conv.png +python convbench/conv_bench.py --roofline results/iree_conv.csv,results/iree_attention.csv --plot results/attn_conv.png ``` If you want to generate a roofline plot for a certain data type, model, or batch size you can do: From c89b6fb7b2d28570d7770e73d336699f2ca3cb18 Mon Sep 17 00:00:00 2001 From: Max Dawkins Date: Fri, 18 Oct 2024 14:19:10 -0500 Subject: [PATCH 4/4] fix CI, update problem set Signed-off-by: Max Dawkins --- .github/workflows/run_bench.yml | 2 +- common_tools/utils/bench_utils.py | 2 +- convbench/conv_bench.py | 4 ++-- convbench/problems.py | 23 +++++++++++++++++------ 4 files changed, 21 insertions(+), 10 deletions(-) diff --git a/.github/workflows/run_bench.yml b/.github/workflows/run_bench.yml index 46c576c..ebe07f0 100644 --- a/.github/workflows/run_bench.yml +++ b/.github/workflows/run_bench.yml @@ -56,7 +56,7 @@ jobs: run: | source bench_venv/bin/activate python convbench/conv_bench.py --roofline results/iree_conv.csv --plot results/iree_conv_i8.png --dtype i8 - python convbench/conv_bench.py --roofline results/iree_conv.csv --plot results/iree_conv_f32.png --dtype f32 + python convbench/conv_bench.py --roofline results/iree_conv.csv --plot results/iree_conv_f16.png --dtype f16 python convbench/conv_bench.py --roofline results/iree_attention.csv --plot results/iree_attention_fp16.png --dtype f16 python convbench/conv_bench.py --roofline results/iree_attention.csv --plot results/iree_attention_fp8.png --dtype f8E4M3FNUZ python convbench/conv_bench.py --roofline results/iree_gemm.csv --plot results/iree_gemm.png diff --git a/common_tools/utils/bench_utils.py b/common_tools/utils/bench_utils.py index f23d13a..926f700 100644 --- a/common_tools/utils/bench_utils.py +++ b/common_tools/utils/bench_utils.py @@ -144,7 +144,7 @@ def roofline(results=None, out=None, batch=None, dtype=None, model=None, **kwarg with open(result_file.strip(), mode='r') as csvfile: reader = csv.DictReader(csvfile) for row in reader: - row = {k: float(v) if k in ['index', 'mean_microseconds', 'arithmetic_intensity', 'tflops'] else v for k, v in row.items()} + row = {k: float(v) if k in ['index', 'mean_microseconds', 'arithmetic_intensity', 'tflops', 'roofline_tflops', 'roofline_percent'] else v for k, v in row.items()} row['ok'] = True if 'ok' not in row else row['ok'] == 'True' data.append(row) if batch: diff --git a/convbench/conv_bench.py b/convbench/conv_bench.py index 4ad0216..91d7a37 100644 --- a/convbench/conv_bench.py +++ b/convbench/conv_bench.py @@ -50,8 +50,8 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir, extra_compiler_args): roofline(args.roofline, args.plot, args.batch, args.dtype, args.model) sys.exit() - configs = get_conv_test_configs() - # configs = get_conv_configs() + # configs = get_conv_test_configs() + configs = get_conv_configs() print(f"Generated {len(configs)} conv configs.") num_cpus = max(1, cpu_count() - 20) diff --git a/convbench/problems.py b/convbench/problems.py index 9408a0f..a61272a 100644 --- a/convbench/problems.py +++ b/convbench/problems.py @@ -3,7 +3,7 @@ def unet_sweep(op: str, input_dtype: str, output_dtype: str) -> list[ConvConfig]: configs = [] - for B in [1, 2, 4]: + for B in [1, 2, 4, 8]: configs.append(ConvConfig(B, 128, 128, 16, 3, 3, 320, 1, op, input_dtype, output_dtype)) configs.append(ConvConfig(B, 128, 128, 320, 3, 3, 320, 1, op, input_dtype, output_dtype)) configs.append(ConvConfig(B, 64, 64, 320, 3, 3, 320, 2, op, input_dtype, output_dtype)) @@ -35,7 +35,7 @@ def unet_sweep(op: str, input_dtype: str, output_dtype: str) -> list[ConvConfig] def resnet_sweep(op: str, input_dtype: str, output_dtype: str) -> list[ConvConfig]: configs = [] - for B in [1, 2, 4, 8, 16, 32, 48]: + for B in [1, 2, 4, 8]: configs.append(ConvConfig(B, 112, 112, 64, 7, 7, 3, 2, op, input_dtype, output_dtype)) configs.append(ConvConfig(B, 56, 56, 64, 3, 3, 64, 1, op, input_dtype, output_dtype)) configs.append(ConvConfig(B, 28, 28, 128, 3, 3, 128, 2, op, input_dtype, output_dtype)) @@ -55,13 +55,17 @@ def get_conv_configs() -> list[tuple[str, ConvConfig]]: # Resnet resnet_configs = [] resnet_configs += resnet_sweep("conv_2d_nhwc_hwcf", "f16", "f32") - resnet_configs += resnet_sweep("conv_2d_nhwc_hwcf_q", "i8", "i32") + resnet_configs += resnet_sweep("conv_2d_nhwc_hwcf", "i8", "i32") + resnet_configs += resnet_sweep("conv_2d_nchw_fchw", "f16", "f32") + resnet_configs += resnet_sweep("conv_2d_nchw_fchw", "i8", "i32") configs += [("resnet", x) for x in resnet_configs] # Unet unet_configs = [] unet_configs += unet_sweep("conv_2d_nhwc_hwcf", "f16", "f32") - unet_configs += unet_sweep("conv_2d_nhwc_hwcf_q", "i8", "i32") + unet_configs += unet_sweep("conv_2d_nhwc_hwcf", "i8", "i32") + unet_configs += unet_sweep("conv_2d_nchw_fchw", "f16", "f32") + unet_configs += unet_sweep("conv_2d_nchw_fchw", "i8", "i32") configs += [("unet", x) for x in unet_configs] return configs @@ -70,9 +74,16 @@ def get_conv_configs() -> list[tuple[str, ConvConfig]]: def get_conv_test_configs() -> list[tuple[str, ConvConfig]]: configs: list[tuple[str, ConvConfig]] = [] + resnet_configs = [] + # resnet_configs += resnet_sweep("conv_2d_nhwc_hwcf", "f16", "f32") + # resnet_configs += resnet_sweep("conv_2d_nhwc_hwcf", "i8", "i32") + # resnet_configs += resnet_sweep("conv_2d_nchw_fchw", "f16", "f32") + # resnet_configs += resnet_sweep("conv_2d_nchw_fchw", "i8", "i32") + configs += [("resnet", x) for x in resnet_configs] + unet_configs = [] - unet_configs.append(ConvConfig(1,128,128,16,3,3,320,1, "conv_2d_nhwc_hwcf_q", "i8", "i32")) - unet_configs.append(ConvConfig(1,32,32,640,1,1,1280,1, "conv_2d_nhwc_hwcf_q", "i8", "i32")) + # unet_configs.append(ConvConfig(1,128,128,16,3,3,320,1, "conv_2d_nhwc_hwcf_q", "i8", "i32")) + # unet_configs.append(ConvConfig(1,32,32,640,1,1,1280,1, "conv_2d_nhwc_hwcf_q", "i8", "i32")) configs += [("unet", x) for x in unet_configs]