# Elementwise and broadcast tests for the wave simulator (#56).
#
# Per the originating commit, the simulator uses ``torch.tensor``, so all
# elementwise ops and broadcasting are expected to work out of the box; the
# tests below pin that behaviour.
#
# NOTE(review): ``torch``, ``tkl``, ``tkw`` and ``wave_sim`` are not imported in
# the visible hunk -- presumably they are imported at the top of
# tests/kernel/wave/wave_sim_test.py; confirm before relocating this code.
from numpy.testing import assert_allclose


def _elementwise_constraints(M, N):
    """Build the constraint list shared by the elementwise/broadcast tests.

    Every test in this group used to spell out the same three constraints
    inline; they are built here once so the tests only differ in their kernel
    and data.

    Args:
        M: Symbol passed as the first argument of the first
            ``WorkgroupConstraint``.
        N: Symbol passed as the first argument of the second
            ``WorkgroupConstraint``.

    Returns:
        A new list holding ``WorkgroupConstraint(M, BLOCK_M, 0)``,
        ``WorkgroupConstraint(N, BLOCK_N, 1)`` and a ``HardwareConstraint``
        with 64 threads per wave and ``(1, 1, 1)`` waves per block.
    """
    # Workgroup tile sizes
    BLOCK_M = tkl.sym.BLOCK_M
    BLOCK_N = tkl.sym.BLOCK_N

    # Expose user-constraints
    constraints: list[tkw.Constraint] = [
        tkw.WorkgroupConstraint(M, BLOCK_M, 0),
        tkw.WorkgroupConstraint(N, BLOCK_N, 1),
        tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(1, 1, 1)),
    ]
    return constraints


def test_eltwise():
    """Add two (128, 256) float32 tensors via the simulator; expect ``a + b``."""
    # Input sizes
    M = tkl.sym.M
    N = tkl.sym.N
    # Address space (for GPU, shared(1) or global(0))
    ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
    # Other hyperparameters
    LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD
    STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD

    @wave_sim(_elementwise_constraints(M, N))
    def eltwise(
        a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32],
        b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32],
        c: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32],
    ):
        a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD)
        b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD)
        tkw.write(a_reg + b_reg, c, elements_per_thread=STORE_ELEMS_PER_THREAD)

    a = torch.randn(128, 256, dtype=torch.float32)
    b = torch.randn(128, 256, dtype=torch.float32)
    c = torch.zeros(128, 256, dtype=torch.float32)
    eltwise(a, b, c)
    assert_allclose(c, a + b)


def test_broadcast_1():
    """Add a (256,) tensor to a (128, 256) tensor; expect torch's ``a + b``."""
    # Input sizes
    M = tkl.sym.M
    N = tkl.sym.N
    # Address space (for GPU, shared(1) or global(0))
    ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
    # Other hyperparameters
    LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD
    STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD

    @wave_sim(_elementwise_constraints(M, N))
    def eltwise(
        a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32],
        # ``b`` is declared over N only, unlike ``a`` and ``c``.
        b: tkl.Memory[N, ADDRESS_SPACE, tkl.f32],
        c: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32],
    ):
        a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD)
        b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD)
        tkw.write(a_reg + b_reg, c, elements_per_thread=STORE_ELEMS_PER_THREAD)

    a = torch.randn(128, 256, dtype=torch.float32)
    b = torch.randn(256, dtype=torch.float32)
    c = torch.zeros(128, 256, dtype=torch.float32)
    eltwise(a, b, c)
    assert_allclose(c, a + b)


def test_broadcast_2():
    """Write a (256,) tensor into a (128, 256) output; expect it broadcast."""
    # Input sizes
    M = tkl.sym.M
    N = tkl.sym.N
    # Address space (for GPU, shared(1) or global(0))
    ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
    # Other hyperparameters
    LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD
    STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD

    @wave_sim(_elementwise_constraints(M, N))
    def eltwise(
        b: tkl.Memory[N, ADDRESS_SPACE, tkl.f32],
        c: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32],
    ):
        b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD)
        tkw.write(b_reg, c, elements_per_thread=STORE_ELEMS_PER_THREAD)

    b = torch.randn(256, dtype=torch.float32)
    c = torch.zeros(128, 256, dtype=torch.float32)
    eltwise(b, c)
    # Adding zeros of the output shape yields the reference broadcast of ``b``.
    assert_allclose(c, b + torch.zeros(128, 256, dtype=torch.float32))


def test_broadcast_3():
    """Write element 0 of the read value into a (128, 256) output."""
    # Input sizes
    M = tkl.sym.M
    N = tkl.sym.N
    # Address space (for GPU, shared(1) or global(0))
    ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
    # Other hyperparameters
    LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD
    STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD

    @wave_sim(_elementwise_constraints(M, N))
    def eltwise(
        b: tkl.Memory[N, ADDRESS_SPACE, tkl.f32],
        c: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32],
    ):
        # Index the read result so that a single element is written out.
        b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD)[0]
        tkw.write(b_reg, c, elements_per_thread=STORE_ELEMS_PER_THREAD)

    b = torch.randn(256, dtype=torch.float32)
    c = torch.zeros(128, 256, dtype=torch.float32)
    eltwise(b, c)
    # Every entry of ``c`` is expected to equal ``b[0]``.
    assert_allclose(c, b[0] + torch.zeros(128, 256, dtype=torch.float32))


# NOTE: ``test_gemm`` follows in the target file. Only its first lines are
# visible here (as trailing diff context), so it is intentionally not
# reproduced.