Add elementwise and broadcast simulator tests (#56)
As I'm using `torch.tensor` in simulator, all elementwise ops and
broadcasting should work out of the box.
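
For reference, a minimal plain-PyTorch sketch (not part of this change) of the elementwise and broadcasting semantics the new tests rely on; the shapes are illustrative only:

import torch

a = torch.randn(128, 256)
b = torch.randn(128, 256)
v = torch.randn(256)

# Elementwise: same-shape tensors combine element by element.
assert torch.allclose(a + b, torch.add(a, b))

# Broadcasting: a (256,)-shaped vector is expanded across the leading dim of (128, 256).
assert (a + v).shape == (128, 256)
assert torch.allclose((a + v)[0], a[0] + v)

# A 0-d tensor (single element) broadcasts against any shape.
s = v[0]
assert torch.allclose(torch.zeros(128, 256) + s, torch.full((128, 256), s.item()))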

Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
Hardcode84 authored Jul 31, 2024
1 parent 32e62dd commit 5e28a8e
Showing 1 changed file with 146 additions and 0 deletions.
tests/kernel/wave/wave_sim_test.py (146 additions, 0 deletions)
@@ -6,6 +6,152 @@
from numpy.testing import assert_allclose


def test_eltwise():
    # Input sizes
    M = tkl.sym.M
    N = tkl.sym.N
    # Workgroup tile sizes
    BLOCK_M = tkl.sym.BLOCK_M
    BLOCK_N = tkl.sym.BLOCK_N
    # Address space (for GPU, shared(1) or global(0))
    ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
    # Other hyperparameters
    LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD
    STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD

    # Expose user-constraints
    constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]

    constraints += [
        tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(1, 1, 1))
    ]

    @wave_sim(constraints)
    def eltwise(
        a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32],
        b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32],
        c: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32],
    ):
        a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD)
        b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD)
        tkw.write(a_reg + b_reg, c, elements_per_thread=STORE_ELEMS_PER_THREAD)

    a = torch.randn(128, 256, dtype=torch.float32)
    b = torch.randn(128, 256, dtype=torch.float32)
    c = torch.zeros(128, 256, dtype=torch.float32)
    eltwise(a, b, c)
    assert_allclose(c, a + b)


def test_broadcast_1():
    # Input sizes
    M = tkl.sym.M
    N = tkl.sym.N
    # Workgroup tile sizes
    BLOCK_M = tkl.sym.BLOCK_M
    BLOCK_N = tkl.sym.BLOCK_N
    # Address space (for GPU, shared(1) or global(0))
    ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
    # Other hyperparameters
    LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD
    STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD

    # Expose user-constraints
    constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]

    constraints += [
        tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(1, 1, 1))
    ]

    @wave_sim(constraints)
    def eltwise(
        a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32],
        b: tkl.Memory[N, ADDRESS_SPACE, tkl.f32],
        c: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32],
    ):
        a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD)
        b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD)
        tkw.write(a_reg + b_reg, c, elements_per_thread=STORE_ELEMS_PER_THREAD)

    a = torch.randn(128, 256, dtype=torch.float32)
    b = torch.randn(256, dtype=torch.float32)
    c = torch.zeros(128, 256, dtype=torch.float32)
    eltwise(a, b, c)
    assert_allclose(c, a + b)


def test_broadcast_2():
    # Input sizes
    M = tkl.sym.M
    N = tkl.sym.N
    # Workgroup tile sizes
    BLOCK_M = tkl.sym.BLOCK_M
    BLOCK_N = tkl.sym.BLOCK_N
    # Address space (for GPU, shared(1) or global(0))
    ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
    # Other hyperparameters
    LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD
    STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD

    # Expose user-constraints
    constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]

    constraints += [
        tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(1, 1, 1))
    ]

    @wave_sim(constraints)
    def eltwise(
        b: tkl.Memory[N, ADDRESS_SPACE, tkl.f32],
        c: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32],
    ):
        b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD)
        tkw.write(b_reg, c, elements_per_thread=STORE_ELEMS_PER_THREAD)

    b = torch.randn(256, dtype=torch.float32)
    c = torch.zeros(128, 256, dtype=torch.float32)
    eltwise(b, c)
    assert_allclose(c, b + torch.zeros(128, 256, dtype=torch.float32))


def test_broadcast_3():
    # Input sizes
    M = tkl.sym.M
    N = tkl.sym.N
    # Workgroup tile sizes
    BLOCK_M = tkl.sym.BLOCK_M
    BLOCK_N = tkl.sym.BLOCK_N
    # Address space (for GPU, shared(1) or global(0))
    ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
    # Other hyperparameters
    LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD
    STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD

    # Expose user-constraints
    constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]

    constraints += [
        tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(1, 1, 1))
    ]

    @wave_sim(constraints)
    def eltwise(
        b: tkl.Memory[N, ADDRESS_SPACE, tkl.f32],
        c: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32],
    ):
        b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD)[0]
        tkw.write(b_reg, c, elements_per_thread=STORE_ELEMS_PER_THREAD)

    b = torch.randn(256, dtype=torch.float32)
    c = torch.zeros(128, 256, dtype=torch.float32)
    eltwise(b, c)
    assert_allclose(c, b[0] + torch.zeros(128, 256, dtype=torch.float32))


def test_gemm():
    # Input sizes
    M = tkl.sym.M