From 0f00c6d77e9a0ef19069fe7c021b89d36dc100df Mon Sep 17 00:00:00 2001
From: erman-gurses <99776114+erman-gurses@users.noreply.github.com>
Date: Thu, 3 Oct 2024 11:08:06 -0700
Subject: [PATCH] Add benchmark support for e2e tests (#183)

Signed-off-by: erman-gurses
---
 .github/workflows/ci.yaml           |  2 +-
 shark_turbine/kernel/wave/utils.py  | 27 ++++++++++++++++++++++++++-
 tests/kernel/wave/wave_e2e_test.py  |  8 ++++++++
 tests/kernel/wave/wave_gemm_test.py |  1 +
 4 files changed, 36 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index 5146f144..0796d30a 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -60,7 +60,7 @@ jobs:
         if: "contains(matrix.os, 'mi300') && !cancelled()"
         run: |
           export WAVE_RUN_E2E_TESTS=1
-          pytest -n 4 ./tests/kernel/wave/
+          pytest -n 4 --capture=tee-sys ./tests/kernel/wave/
 
     - name: Run LIT tests
       if: ${{ !cancelled() }}
diff --git a/shark_turbine/kernel/wave/utils.py b/shark_turbine/kernel/wave/utils.py
index 278666b0..9ea9adad 100644
--- a/shark_turbine/kernel/wave/utils.py
+++ b/shark_turbine/kernel/wave/utils.py
@@ -35,6 +35,9 @@
 import torch.fx as fx
 import shark_turbine.kernel.lang as tkl
+
+import tempfile
+from ...support.conversions import TORCH_DTYPE_TO_SIGNED_MLIR_TYPE_ASM
 from iree.compiler.dialects.transform import (
     interpreter as transform_interpreter,
     any_op_t,
 )
@@ -372,7 +375,29 @@ def compile_and_invoke(
         _invoke(ctx.vm_context, device, func, kernel_inputs, kernel_outputs)
 
     if run_bench:
-        inputs = [inp.numpy() for inp in kernel_inputs]
+        bench_with_constant_weights = config.get("bench_with_constant_weights", False)
+        tempfiles = []
+        inputs = []
+        if bench_with_constant_weights:
+            for inp in kernel_inputs:
+                inputs.append(
+                    "x".join(
+                        [str(x) for x in inp.shape]
+                        + [TORCH_DTYPE_TO_SIGNED_MLIR_TYPE_ASM[inp.dtype]]
+                    )
+                )
+        else:
+            for inp in kernel_inputs:
+                tf = tempfile.NamedTemporaryFile()
+                torch.save(inp, tf)
+                tempfiles.append(tf)
+                inputs.append("@" + tf.name)
+
+        benchmark_results = bench.benchmark_module(
+            mod,
+            entry_function=func_name,
+        )
+
         benchmark_results = bench.benchmark_module(
             mod,
             entry_function=func_name,
diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py
index 69ba718a..633c73c5 100644
--- a/tests/kernel/wave/wave_e2e_test.py
+++ b/tests/kernel/wave/wave_e2e_test.py
@@ -102,6 +102,7 @@ def test(
         },
         canonicalize=True,
         run=True,
+        run_bench=True,
         run_config=config,
     ):
         test(a, b)
@@ -214,6 +215,7 @@ def test(
         },
         canonicalize=True,
         run=True,
+        run_bench=True,
         run_config=config,
     ):
         test(a, b)
@@ -270,6 +272,7 @@ def test(
         },
         canonicalize=True,
         run=True,
+        run_bench=True,
         run_config=config,
     ):
         test(a, b)
@@ -326,6 +329,7 @@ def test(
         },
         canonicalize=True,
         run=True,
+        run_bench=True,
         run_config=config,
     ):
         test(a, b, c)
@@ -401,6 +405,7 @@ def repeat(
         },
         canonicalize=True,
         run=True,
+        run_bench=True,
         run_config=config,
     ):
         test(a, b, c)
@@ -505,6 +510,7 @@ def test(
         },
         canonicalize=True,
         run=True,
+        run_bench=True,
         run_config=config,
     ):
         test(a, b)
@@ -635,6 +641,7 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]:
         },
         canonicalize=True,
         run=True,
+        run_bench=True,
         run_config=config,
     ):
         gpu_func(x, we, out)
@@ -949,6 +956,7 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]:
         },
         canonicalize=True,
         run=True,
+        run_bench=True,
         run_config=config,
     ):
         conv(x, we, out)
diff --git a/tests/kernel/wave/wave_gemm_test.py b/tests/kernel/wave/wave_gemm_test.py
index 2386ebd9..63e51909 100644
--- a/tests/kernel/wave/wave_gemm_test.py
+++ b/tests/kernel/wave/wave_gemm_test.py
@@ -122,6 +122,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
         hyperparams,
         canonicalize=True,
         run=True,
+        run_bench=True,
         run_config=config,
         schedule=enable_scheduling,
     ):