From ff5a81d9cec42659ff6ef4cd3be1df070fabe1c2 Mon Sep 17 00:00:00 2001 From: Harsh Menon Date: Tue, 30 Jul 2024 17:32:21 -0700 Subject: [PATCH] Add indexing to nodes This primary purpose of this PR is to annotate nodes with their access patterns based on the workgroup, tiling and MMA constraints. This is accomplished prior to expansion and propagates through expansion to the expanded nodes. Signed-off-by: Harsh Menon --- lit_tests/kernel/wave/expansion.py | 199 +++++++++++++++++- shark_turbine/kernel/ops/wave_ops.py | 50 ++++- shark_turbine/kernel/wave/constraints.py | 72 +++++-- .../kernel/wave/distribution_symbols.py | 9 + shark_turbine/kernel/wave/expansion.py | 53 ++++- shark_turbine/kernel/wave/indexing.py | 43 ++++ shark_turbine/kernel/wave/wave.py | 46 +++- tests/kernel/wave/constraints_test.py | 2 +- tests/kernel/wave/wave_gemm_test.py | 2 +- 9 files changed, 448 insertions(+), 28 deletions(-) create mode 100644 shark_turbine/kernel/wave/distribution_symbols.py create mode 100644 shark_turbine/kernel/wave/indexing.py diff --git a/lit_tests/kernel/wave/expansion.py b/lit_tests/kernel/wave/expansion.py index ee56c30a..c8cc8111 100644 --- a/lit_tests/kernel/wave/expansion.py +++ b/lit_tests/kernel/wave/expansion.py @@ -9,6 +9,7 @@ from shark_turbine.kernel.wave.expansion import expand_graph from shark_turbine.kernel._support.tracing import CapturedTrace from shark_turbine.kernel._support.indexing import IndexingContext +from shark_turbine.kernel.ops.wave_ops import get_custom def run(func: Callable[[], None]) -> Callable[[], None]: @@ -23,12 +24,14 @@ def run(func: Callable[[], None]) -> Callable[[], None]: def print_trace(trace: CapturedTrace): """ Prints all subgraphs of a trace starting with the root graph. - The graphs are printed first in the torch printing format and then using - our custom node format. + The graphs are printed first in the torch printing format and + then using our custom node format. """ # The root graph is at the back so we print the subgraphs in reverse order for subgraph in reversed(list(trace.region_graph.subgraphs.values())): print(subgraph) + for node in subgraph.nodes: + print(get_custom(node)) # Input sizes @@ -44,6 +47,9 @@ def print_trace(trace: CapturedTrace): # Address space (for GPU, shared(1) or global(0)) ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE +# Induction variable for dimension K +ARGK = tkl.sym.ARGK + @tkw.wave_trace_only() def read_write_same_size( @@ -94,6 +100,28 @@ def test_read_write_equal_sizes(): # CHECK-SAME: (%read_1_0, %c, 4) # CHECK-NEXT: %write_0_1 # CHECK-SAME: (%read_0_1, %c, 4) + # CHECK-NEXT: return + + # Custom format: + # CHECK-NEXT: placeholder(_name=a + # CHECK-NEXT: placeholder(_name=c + # CHECK-NEXT: read(memory=a + # CHECK-SAME: index={M: BLOCK_M*WG0, N: BLOCK_N*WG1} + # CHECK-NEXT: read(memory=a + # CHECK-SAME: index={M: BLOCK_M*WG0, N: BLOCK_N*WG1} + # CHECK-NEXT: read(memory=a + # CHECK-SAME: index={M: BLOCK_M*WG0, N: BLOCK_N*WG1} + # CHECK-NEXT: read(memory=a + # CHECK-SAME: index={M: BLOCK_M*WG0, N: BLOCK_N*WG1} + # CHECK-NEXT: write(register_=read_0_0 + # CHECK-SAME: index={M: BLOCK_M*WG0, N: BLOCK_N*WG1} + # CHECK-NEXT: write(register_=read_1_1 + # CHECK-SAME: index={M: BLOCK_M*WG0, N: BLOCK_N*WG1} + # CHECK-NEXT: write(register_=read_1_0 + # CHECK-SAME: index={M: BLOCK_M*WG0, N: BLOCK_N*WG1} + # CHECK-NEXT: write(register_=read_0_1 + # CHECK-SAME: index={M: BLOCK_M*WG0, N: BLOCK_N*WG1} + # CHECK-NEXT: output # CHECK: ----- @@ -111,7 +139,7 @@ def read_write_different_dims( def test_read_write(): constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] - constraints += [tkw.TilingConstraint(K, BLOCK_K)] + constraints += [tkw.TilingConstraint(K, BLOCK_K, ARGK)] constraints += [ tkw.HardwareConstraint( threads_per_wave=64, @@ -144,6 +172,24 @@ def test_read_write(): # CHECK-SAME: (%read_1_0_0, %c, 4) # CHECK-NEXT: %write_0_0_1 # CHECK-SAME: (%read_0_0_0, %c, 4) + # CHECK-NEXT: return None + + # Custom format: + # CHECK-NEXT: placeholder(_name=a + # CHECK-NEXT: placeholder(_name=c + # CHECK-NEXT: read(memory=a + # CHECK-SAME: index={M: BLOCK_M*WG0, N: BLOCK_N*WG1} + # CHECK-NEXT: read(memory=a + # CHECK-SAME: index={M: BLOCK_M*WG0, N: BLOCK_N*WG1} + # CHECK-NEXT: write(register_=read_0_0_0 + # CHECK-SAME: index={M: BLOCK_M*WG0, K: ARGK*BLOCK_K} + # CHECK-NEXT: write(register_=read_1_0_0 + # CHECK-SAME: index={M: BLOCK_M*WG0, K: ARGK*BLOCK_K} + # CHECK-NEXT: write(register_=read_1_0_0 + # CHECK-SAME: index={M: BLOCK_M*WG0, K: ARGK*BLOCK_K} + # CHECK-NEXT: write(register_=read_0_0_0 + # CHECK-SAME: index={M: BLOCK_M*WG0, K: ARGK*BLOCK_K} + # CHECK-NEXT: output # CHECK: ----- @@ -170,7 +216,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: def test_gemm(): constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] - constraints += [tkw.TilingConstraint(K, BLOCK_K)] + constraints += [tkw.TilingConstraint(K, BLOCK_K, ARGK)] constraints += [ tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(1, 1, 1)) ] @@ -207,6 +253,30 @@ def test_gemm(): # CHECK-SAME: (%getresult_1_0_0, %c, 4) # CHECK-NEXT: %write_0_1_0 # CHECK-SAME: (%getresult_0_1_0, %c, 4) + # CHECK-NEXT: return None + + # Custom format: + # CHECK-NEXT: placeholder(_name=a + # CHECK-NEXT: placeholder(_name=b + # CHECK-NEXT: placeholder(_name=c + # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: BLOCK_M*WG0, N: BLOCK_N*WG1}) + # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: BLOCK_M*WG0, N: BLOCK_N*WG1}) + # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: BLOCK_M*WG0, N: BLOCK_N*WG1}) + # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: BLOCK_M*WG0, N: BLOCK_N*WG1}) + # CHECK-NEXT: reduction(axis=K, init_args=[register_0_0_0, register_0_1_0, register_1_0_0, register_1_1_0], subgraph_name=region_0, implicit_captures=[a, b], index={M: BLOCK_M*WG0, N: BLOCK_N*WG1}) + # CHECK-NEXT: get_result(value=reduction, res_idx=3) + # CHECK-NEXT: get_result(value=reduction, res_idx=2) + # CHECK-NEXT: get_result(value=reduction, res_idx=1) + # CHECK-NEXT: get_result(value=reduction, res_idx=0) + # CHECK-NEXT: write(register_=getresult_0_0_0 + # CHECK-SAME: index={M: BLOCK_M*WG0, N: BLOCK_N*WG1} + # CHECK-NEXT: write(register_=getresult_1_1_0 + # CHECK-SAME: index={M: BLOCK_M*WG0, N: BLOCK_N*WG1} + # CHECK-NEXT: write(register_=getresult_1_0_0 + # CHECK-SAME: index={M: BLOCK_M*WG0, N: BLOCK_N*WG1} + # CHECK-NEXT: write(register_=getresult_0_1_0 + # CHECK-SAME: index={M: BLOCK_M*WG0, N: BLOCK_N*WG1} + # CHECK-NEXT: output # Reduction subgraph: @@ -253,6 +323,47 @@ def test_gemm(): # CHECK-SAME: (%read_0_0_1, %read_0_1_1, %mma_0_1_0) # CHECK-NEXT: return [mma_0_0_1, mma_1_1_1, mma_1_0_1, mma_0_1_1] + # Custom format: + # CHECK-NEXT: placeholder(_name=acc_0_0_0 + # CHECK-NEXT: placeholder(_name=acc_1_1_0 + # CHECK-NEXT: placeholder(_name=acc_1_0_0 + # CHECK-NEXT: placeholder(_name=acc_0_1_0 + # CHECK-NEXT: placeholder(_name=a + # CHECK-NEXT: read(memory=a, elements_per_thread=4, index={M: BLOCK_M*WG0, K: ARGK*BLOCK_K}) + # CHECK-NEXT: read(memory=a, elements_per_thread=4, index={M: BLOCK_M*WG0, K: ARGK*BLOCK_K}) + # CHECK-NEXT: read(memory=a, elements_per_thread=4, index={M: BLOCK_M*WG0, K: ARGK*BLOCK_K}) + # CHECK-NEXT: read(memory=a, elements_per_thread=4, index={M: BLOCK_M*WG0, K: ARGK*BLOCK_K}) + # CHECK-NEXT: placeholder(_name=b + # CHECK-NEXT: read(memory=b, elements_per_thread=4, index={N: BLOCK_N*WG1, K: ARGK*BLOCK_K}) + # CHECK-NEXT: read(memory=b, elements_per_thread=4, index={N: BLOCK_N*WG1, K: ARGK*BLOCK_K}) + # CHECK-NEXT: read(memory=b, elements_per_thread=4, index={N: BLOCK_N*WG1, K: ARGK*BLOCK_K}) + # CHECK-NEXT: read(memory=b, elements_per_thread=4, index={N: BLOCK_N*WG1, K: ARGK*BLOCK_K}) + # CHECK-NEXT: mma(lhs=read_0_0_0 (index = [BLOCK_M*WG0 + Mod(T0, 16), ARGK*BLOCK_K + 16*T1 + 16*T2 + 4*floor(T0/16) : 4 : 1]) + # CHECK-SAME: rhs=read_0_0_0 (index = [BLOCK_N*WG1 + Mod(T0, 16), ARGK*BLOCK_K + 16*T1 + 16*T2 + 4*floor(T0/16) : 4 : 1]) + # CHECK-SAME: acc=acc_0_0_0 (index = [BLOCK_M*WG0 + 16*T1 + 16*T2 + 4*floor(T0/16) : 4 : 16, BLOCK_N*WG1 + Mod(T0, 16)])) + # CHECK-NEXT: mma(lhs=read_0_0_1 (index = [BLOCK_M*WG0 + Mod(T0, 16), ARGK*BLOCK_K + 16*T1 + 16*T2 + 4*floor(T0/16) : 4 : 1]) + # CHECK-SAME: rhs=read_0_0_1 (index = [BLOCK_N*WG1 + Mod(T0, 16), ARGK*BLOCK_K + 16*T1 + 16*T2 + 4*floor(T0/16) : 4 : 1]) + # CHECK-SAME: acc=mma_0_0_0 (index = None)) + # CHECK-NEXT: mma(lhs=read_1_0_0 (index = [BLOCK_M*WG0 + Mod(T0, 16), ARGK*BLOCK_K + 16*T1 + 16*T2 + 4*floor(T0/16) : 4 : 1]) + # CHECK-SAME: rhs=read_0_1_0 (index = [BLOCK_N*WG1 + Mod(T0, 16), ARGK*BLOCK_K + 16*T1 + 16*T2 + 4*floor(T0/16) : 4 : 1]) + # CHECK-SAME: acc=acc_1_1_0 (index = [BLOCK_M*WG0 + 16*T1 + 16*T2 + 4*floor(T0/16) : 4 : 16, BLOCK_N*WG1 + Mod(T0, 16)])) + # CHECK-NEXT: mma(lhs=read_1_0_1 (index = [BLOCK_M*WG0 + Mod(T0, 16), ARGK*BLOCK_K + 16*T1 + 16*T2 + 4*floor(T0/16) : 4 : 1]) + # CHECK-SAME: rhs=read_0_1_1 (index = [BLOCK_N*WG1 + Mod(T0, 16), ARGK*BLOCK_K + 16*T1 + 16*T2 + 4*floor(T0/16) : 4 : 1]) + # CHECK-SAME: acc=mma_1_1_0 (index = None)) + # CHECK-NEXT: mma(lhs=read_1_0_0 (index = [BLOCK_M*WG0 + Mod(T0, 16), ARGK*BLOCK_K + 16*T1 + 16*T2 + 4*floor(T0/16) : 4 : 1]) + # CHECK-SAME: rhs=read_0_0_0 (index = [BLOCK_N*WG1 + Mod(T0, 16), ARGK*BLOCK_K + 16*T1 + 16*T2 + 4*floor(T0/16) : 4 : 1]) + # CHECK-SAME: acc=acc_1_0_0 (index = [BLOCK_M*WG0 + 16*T1 + 16*T2 + 4*floor(T0/16) : 4 : 16, BLOCK_N*WG1 + Mod(T0, 16)])) + # CHECK-NEXT: mma(lhs=read_1_0_1 (index = [BLOCK_M*WG0 + Mod(T0, 16), ARGK*BLOCK_K + 16*T1 + 16*T2 + 4*floor(T0/16) : 4 : 1]) + # CHECK-SAME: rhs=read_0_0_1 (index = [BLOCK_N*WG1 + Mod(T0, 16), ARGK*BLOCK_K + 16*T1 + 16*T2 + 4*floor(T0/16) : 4 : 1]) + # CHECK-SAME: acc=mma_1_0_0 (index = None)) + # CHECK-NEXT: mma(lhs=read_0_0_0 (index = [BLOCK_M*WG0 + Mod(T0, 16), ARGK*BLOCK_K + 16*T1 + 16*T2 + 4*floor(T0/16) : 4 : 1]) + # CHECK-SAME: rhs=read_0_1_0 (index = [BLOCK_N*WG1 + Mod(T0, 16), ARGK*BLOCK_K + 16*T1 + 16*T2 + 4*floor(T0/16) : 4 : 1]) + # CHECK-SAME: acc=acc_0_1_0 (index = [BLOCK_M*WG0 + 16*T1 + 16*T2 + 4*floor(T0/16) : 4 : 16, BLOCK_N*WG1 + Mod(T0, 16)])) + # CHECK-NEXT: mma(lhs=read_0_0_1 (index = [BLOCK_M*WG0 + Mod(T0, 16), ARGK*BLOCK_K + 16*T1 + 16*T2 + 4*floor(T0/16) : 4 : 1]) + # CHECK-SAME: rhs=read_0_1_1 (index = [BLOCK_N*WG1 + Mod(T0, 16), ARGK*BLOCK_K + 16*T1 + 16*T2 + 4*floor(T0/16) : 4 : 1]) + # CHECK-SAME: acc=mma_0_1_0 (index = None)) + # CHECK-NEXT: output(return_vals=([mma_0_0_1, mma_1_1_1, mma_1_0_1, mma_0_1_1],)) + # CHECK-NEXT: ----- @@ -262,7 +373,7 @@ def test_gemm_reduction_expansion_only(): # gemm kernel to test the expansion of the reduction subgraph. constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] - constraints += [tkw.TilingConstraint(K, BLOCK_K)] + constraints += [tkw.TilingConstraint(K, BLOCK_K, ARGK)] constraints += [ tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(1, 1, 1)) ] @@ -290,6 +401,23 @@ def test_gemm_reduction_expansion_only(): # CHECK-SAME: (%getresult_0_0_0, %c, 4) # CHECK-NEXT: %write_0_1_0 # CHECK-SAME: (%getresult_0_1_0, %c, 4) + # CHECK-NEXT: return None + + # Custom format: + # CHECK-NEXT: placeholder(_name=a + # CHECK-NEXT: placeholder(_name=b + # CHECK-NEXT: placeholder(_name=c + # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: BLOCK_M*WG0, N: BLOCK_N*WG1}) + # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: BLOCK_M*WG0, N: BLOCK_N*WG1}) + # CHECK-NEXT: reduction(axis=K, init_args=[register_0_0_0, register_0_1_0] + # CHECK-SAME: index={M: BLOCK_M*WG0, N: BLOCK_N*WG1} + # CHECK-NEXT: get_result(value=reduction, res_idx=1) + # CHECK-NEXT: get_result(value=reduction, res_idx=0) + # CHECK-NEXT: write(register_=getresult_0_0_0 + # CHECK-SAME: index={M: BLOCK_M*WG0, N: BLOCK_N*WG1}) + # CHECK-NEXT: write(register_=getresult_0_1_0 + # CHECK-SAME: index={M: BLOCK_M*WG0, N: BLOCK_N*WG1}) + # CHECK-NEXT: output(return_vals=(None,)) # Reduction subgraph: @@ -343,6 +471,50 @@ def test_gemm_reduction_expansion_only(): # CHECK-NEXT: return [mma_0_0_3, mma_0_1_3] + # Custom format: + + # CHECK-NEXT: placeholder(_name=acc_0_0_0 + # CHECK-NEXT: placeholder(_name=acc_0_1_0 + # CHECK-NEXT: placeholder(_name=a + # CHECK-NEXT: read(memory=a, elements_per_thread=4, index={M: BLOCK_M*WG0, K: ARGK*BLOCK_K}) + # CHECK-NEXT: read(memory=a, elements_per_thread=4, index={M: BLOCK_M*WG0, K: ARGK*BLOCK_K}) + # CHECK-NEXT: read(memory=a, elements_per_thread=4, index={M: BLOCK_M*WG0, K: ARGK*BLOCK_K}) + # CHECK-NEXT: read(memory=a, elements_per_thread=4, index={M: BLOCK_M*WG0, K: ARGK*BLOCK_K}) + # CHECK-NEXT: placeholder(_name=b + # CHECK-NEXT: read(memory=b, elements_per_thread=4, index={N: BLOCK_N*WG1, K: ARGK*BLOCK_K}) + # CHECK-NEXT: read(memory=b, elements_per_thread=4, index={N: BLOCK_N*WG1, K: ARGK*BLOCK_K}) + # CHECK-NEXT: read(memory=b, elements_per_thread=4, index={N: BLOCK_N*WG1, K: ARGK*BLOCK_K}) + # CHECK-NEXT: read(memory=b, elements_per_thread=4, index={N: BLOCK_N*WG1, K: ARGK*BLOCK_K}) + # CHECK-NEXT: read(memory=b, elements_per_thread=4, index={N: BLOCK_N*WG1, K: ARGK*BLOCK_K}) + # CHECK-NEXT: read(memory=b, elements_per_thread=4, index={N: BLOCK_N*WG1, K: ARGK*BLOCK_K}) + # CHECK-NEXT: read(memory=b, elements_per_thread=4, index={N: BLOCK_N*WG1, K: ARGK*BLOCK_K}) + # CHECK-NEXT: read(memory=b, elements_per_thread=4, index={N: BLOCK_N*WG1, K: ARGK*BLOCK_K}) + # CHECK-NEXT: mma(lhs=read_0_0_0 (index = [BLOCK_M*WG0 + Mod(T0, 16), ARGK*BLOCK_K + 16*T1 + 16*T2 + 4*floor(T0/16) : 4 : 1]) + # CHECK-SAME: rhs=read_0_0_0 (index = [BLOCK_N*WG1 + Mod(T0, 16), ARGK*BLOCK_K + 16*T1 + 16*T2 + 4*floor(T0/16) : 4 : 1]) + # CHECK-SAME: acc=acc_0_0_0 (index = [BLOCK_M*WG0 + 16*T1 + 16*T2 + 4*floor(T0/16) : 4 : 16, BLOCK_N*WG1 + Mod(T0, 16)])) + # CHECK-NEXT: mma(lhs=read_0_0_1 (index = [BLOCK_M*WG0 + Mod(T0, 16), ARGK*BLOCK_K + 16*T1 + 16*T2 + 4*floor(T0/16) : 4 : 1]) + # CHECK-SAME: rhs=read_0_0_1 (index = [BLOCK_N*WG1 + Mod(T0, 16), ARGK*BLOCK_K + 16*T1 + 16*T2 + 4*floor(T0/16) : 4 : 1]) + # CHECK-SAME: acc=mma_0_0_0 (index = None)) + # CHECK-NEXT: mma(lhs=read_0_0_2 (index = [BLOCK_M*WG0 + Mod(T0, 16), ARGK*BLOCK_K + 16*T1 + 16*T2 + 4*floor(T0/16) : 4 : 1]) + # CHECK-SAME: rhs=read_0_0_2 (index = [BLOCK_N*WG1 + Mod(T0, 16), ARGK*BLOCK_K + 16*T1 + 16*T2 + 4*floor(T0/16) : 4 : 1]) + # CHECK-SAME: acc=mma_0_0_1 (index = None)) + # CHECK-NEXT: mma(lhs=read_0_0_3 (index = [BLOCK_M*WG0 + Mod(T0, 16), ARGK*BLOCK_K + 16*T1 + 16*T2 + 4*floor(T0/16) : 4 : 1]) + # CHECK-SAME: rhs=read_0_0_3 (index = [BLOCK_N*WG1 + Mod(T0, 16), ARGK*BLOCK_K + 16*T1 + 16*T2 + 4*floor(T0/16) : 4 : 1]) + # CHECK-SAME: acc=mma_0_0_2 (index = None)) + # CHECK-NEXT: mma(lhs=read_0_0_0 (index = [BLOCK_M*WG0 + Mod(T0, 16), ARGK*BLOCK_K + 16*T1 + 16*T2 + 4*floor(T0/16) : 4 : 1]) + # CHECK-SAME: rhs=read_0_1_0 (index = [BLOCK_N*WG1 + Mod(T0, 16), ARGK*BLOCK_K + 16*T1 + 16*T2 + 4*floor(T0/16) : 4 : 1]) + # CHECK-SAME: acc=acc_0_1_0 (index = [BLOCK_M*WG0 + 16*T1 + 16*T2 + 4*floor(T0/16) : 4 : 16, BLOCK_N*WG1 + Mod(T0, 16)])) + # CHECK-NEXT: mma(lhs=read_0_0_1 (index = [BLOCK_M*WG0 + Mod(T0, 16), ARGK*BLOCK_K + 16*T1 + 16*T2 + 4*floor(T0/16) : 4 : 1]) + # CHECK-SAME: rhs=read_0_1_1 (index = [BLOCK_N*WG1 + Mod(T0, 16), ARGK*BLOCK_K + 16*T1 + 16*T2 + 4*floor(T0/16) : 4 : 1]) + # CHECK-SAME: acc=mma_0_1_0 (index = None)) + # CHECK-NEXT: mma(lhs=read_0_0_2 (index = [BLOCK_M*WG0 + Mod(T0, 16), ARGK*BLOCK_K + 16*T1 + 16*T2 + 4*floor(T0/16) : 4 : 1]) + # CHECK-SAME: rhs=read_0_1_2 (index = [BLOCK_N*WG1 + Mod(T0, 16), ARGK*BLOCK_K + 16*T1 + 16*T2 + 4*floor(T0/16) : 4 : 1]) + # CHECK-SAME: acc=mma_0_1_1 (index = None)) + # CHECK-NEXT: mma(lhs=read_0_0_3 (index = [BLOCK_M*WG0 + Mod(T0, 16), ARGK*BLOCK_K + 16*T1 + 16*T2 + 4*floor(T0/16) : 4 : 1]) + # CHECK-SAME: rhs=read_0_1_3 (index = [BLOCK_N*WG1 + Mod(T0, 16), ARGK*BLOCK_K + 16*T1 + 16*T2 + 4*floor(T0/16) : 4 : 1]) + # CHECK-SAME: acc=mma_0_1_2 (index = None)) + # CHECK-NEXT: output(return_vals=([mma_0_0_3, mma_0_1_3],)) + # CHECK-NEXT: ----- @@ -406,6 +578,23 @@ def py_arithmetic_different_dims(): # CHECK-SAME: (%neg_1_0_0, %c, 4) # CHECK-NEXT: %write_0_0_1 # CHECK-SAME: (%neg_0_0_0, %c, 4) + # CHECK-NEXT: return None + + # Custom format: + # CHECK-NEXT: placeholder(_name=a + # CHECK-NEXT: placeholder(_name=c + # CHECK-NEXT: read(memory=a, elements_per_thread=4, index={M: BLOCK_M*WG0, N: BLOCK_N*WG1}) + # CHECK-NEXT: read(memory=a, elements_per_thread=4, index={M: BLOCK_M*WG0, N: BLOCK_N*WG1}) + # CHECK-NEXT: add(lhs=read_0_0_0, rhs=read_0_0_0, index={M: BLOCK_M*WG0, N: BLOCK_N*WG1}) + # CHECK-NEXT: add(lhs=read_1_0_0, rhs=read_1_0_0, index={M: BLOCK_M*WG0, N: BLOCK_N*WG1}) + # CHECK-NEXT: sub(lhs=add_0_0_0, rhs=read_0_0_0, index={M: BLOCK_M*WG0, N: BLOCK_N*WG1}) + # CHECK-NEXT: sub(lhs=add_1_0_0, rhs=read_1_0_0, index={M: BLOCK_M*WG0, N: BLOCK_N*WG1}) + # CHECK-NEXT: neg(arg=sub_0_0_0, index={M: BLOCK_M*WG0, N: BLOCK_N*WG1}) + # CHECK-NEXT: neg(arg=sub_1_0_0, index={M: BLOCK_M*WG0, N: BLOCK_N*WG1}) + # CHECK-NEXT: write(register_=neg_0_0_0, memory=c, elements_per_thread=4, index={M: BLOCK_M*WG0, K: BLOCK_K*WG2}) + # CHECK-NEXT: write(register_=neg_1_0_0, memory=c, elements_per_thread=4, index={M: BLOCK_M*WG0, K: BLOCK_K*WG2}) + # CHECK-NEXT: write(register_=neg_1_0_0, memory=c, elements_per_thread=4, index={M: BLOCK_M*WG0, K: BLOCK_K*WG2}) + # CHECK-NEXT: write(register_=neg_0_0_0, memory=c, elements_per_thread=4, index={M: BLOCK_M*WG0, K: BLOCK_K*WG2}) # CHECK: ----- diff --git a/shark_turbine/kernel/ops/wave_ops.py b/shark_turbine/kernel/ops/wave_ops.py index 5bd8418d..4c9f3555 100644 --- a/shark_turbine/kernel/ops/wave_ops.py +++ b/shark_turbine/kernel/ops/wave_ops.py @@ -22,6 +22,7 @@ from .._support.dtype import DataType from .._support.regions import RegionGraph from .base import OpDispatcher +import shark_turbine.kernel.lang as tkl T = TypeVar("T", bound=Type[Any]) AccT = TypeVar("AccT") @@ -171,13 +172,15 @@ class CustomOp(ABC): fx_node: Optional[fx.Node] = field(default=None, init=False) tkw_op_name: str = field(default="unknown", init=False) _tracing_function: Optional[Callable[..., Any]] = field(default=None, init=False) - index: Optional[IndexExpr] = field(default=None, init=False) + index: Optional[dict[IndexSymbol, IndexSequence]] = field(default=None, init=False) @classmethod def from_fx_node(cls: Type[CustomOpT], node: fx.Node) -> CustomOpT: instance = cls(*node.args) instance.fx_node = node instance.graph = node.graph + if hasattr(node, "index"): + instance.index = node.index return instance def __post_init__(self): @@ -198,7 +201,14 @@ def __str__(self) -> str: def custom_string(self, value_map: dict[str, str]) -> str: # print all variables of the node apart from graph and fx_node - vars_list = [f"{key}={value}" for key, value in vars(self).items()][:-2] + ignore_list = ["fx_node", "graph"] + if self.index is None: + ignore_list += ["index"] + vars_list = [ + f"{key}={value}" + for key, value in vars(self).items() + if key not in ignore_list + ] vars_str = ", ".join(vars_list) return f"{self.tkw_op_name}({vars_str})" @@ -213,6 +223,7 @@ def add_to_graph(self, region_graph: RegionGraph) -> fx.Node: ) self.fx_node.tkw_op = self.__class__ self.fx_node.tkw_op_name = self.tkw_op_name + self.fx_node.index = None return self.fx_node def _add_proxy_to_graph(self, region_graph: RegionGraph): @@ -259,6 +270,7 @@ def copy( graph.inserting_after(self.fx_node) new_node = graph.node_copy(self.fx_node) new_node.tkw_op = self + new_node.index = self.fx_node.index if new_name: new_node.name = new_name return get_custom(new_node) @@ -512,6 +524,40 @@ def rhs_type(self) -> Memory: def acc_type(self) -> Memory: return get_custom(self.acc).type + def operand_index( + self, operand_map: dict[IndexSymbol, int], shape: list[IndexExpr] + ) -> list[IndexSequence]: + indices: list[IndexSequence] = [] + for dim in shape: + indices.append(self.index[dim].subs(operand_map)) + return indices + + @property + def lhs_index(self) -> list[IndexSequence]: + operand_map = {tkl.sym.MMA_LHS: 1, tkl.sym.MMA_RHS: 0, tkl.sym.MMA_ACC: 0} + return self.operand_index(operand_map, self.lhs_type.symbolic_shape) + + @property + def rhs_index(self) -> list[IndexSequence]: + operand_map = {tkl.sym.MMA_LHS: 0, tkl.sym.MMA_RHS: 1, tkl.sym.MMA_ACC: 0} + return self.operand_index(operand_map, self.rhs_type.symbolic_shape) + + @property + def acc_index(self) -> list[IndexSequence]: + operand_map = {tkl.sym.MMA_LHS: 0, tkl.sym.MMA_RHS: 0, tkl.sym.MMA_ACC: 1} + if self.acc.type is None: + return None + return self.operand_index(operand_map, self.acc_type.symbolic_shape) + + def custom_string(self, value_map: dict[str, str]) -> str: + if self.index is None: + return super().custom_string(value_map) + custom_str = f"{self.tkw_op_name}(" + custom_str += f"lhs={self.lhs} (index = {self.lhs_index}), " + custom_str += f"rhs={self.rhs} (index = {self.rhs_index}), " + custom_str += f"acc={self.acc} (index = {self.acc_index}))" + return custom_str + @define_op("read") @dataclass diff --git a/shark_turbine/kernel/wave/constraints.py b/shark_turbine/kernel/wave/constraints.py index 82794de8..4d98c9f1 100644 --- a/shark_turbine/kernel/wave/constraints.py +++ b/shark_turbine/kernel/wave/constraints.py @@ -3,9 +3,11 @@ from enum import Enum from typing import Optional import shark_turbine.kernel.lang as tkl -from sympy import ceiling +from sympy import ceiling, Piecewise, floor from .._support.indexing import IndexExpr, IndexSymbol +from .indexing import IndexSequence +from .distribution_symbols import * class MMAType(Enum): @@ -25,8 +27,8 @@ class Constraint(ABC): """ @abstractmethod - def apply(self) -> IndexExpr: - """Apply the constraint and get the resulting index expression.""" + def apply(self) -> IndexSequence: + """Apply the constraint and get the resulting index sequence.""" ... @@ -52,8 +54,13 @@ class HardwareConstraint(Constraint): mma_type: Optional[MMAType] = MMAType.F32_16x16x16_F16 vector_shapes: Optional[dict[IndexSymbol, int]] = None + def __post_init__(self): + self.LHS = tkl.sym.MMA_LHS + self.RHS = tkl.sym.MMA_RHS + self.ACC = tkl.sym.MMA_ACC + @property - def mma_matrix_shapes(self): + def mma_matrix_shapes(self) -> tuple[int]: # TODO: Eventually the shapes and indices should be provided by a tool match self.mma_type: case MMAType.F32_16x16x16_F16: @@ -63,8 +70,47 @@ def mma_matrix_shapes(self): case _: return () - def apply(self) -> IndexExpr: - raise NotImplementedError("Not yet implemented") + @property + def threads_per_block(self) -> tuple[int]: + return ( + self.waves_per_block[0] * self.threads_per_wave, + ) + self.waves_per_block[1:] + + @property + def linearized_thread_id(self) -> IndexExpr: + thread_ids = [THREAD_0, THREAD_1, THREAD_2] + threads_per_block = ( + [1] + + [self.threads_per_block[0]] + + [self.threads_per_block[0] * self.threads_per_block[1]] + ) + return sum([x * y for x, y in zip(thread_ids, threads_per_block)]) + + def apply(self, mma_index: int) -> IndexSequence: + lane = self.linearized_thread_id + match self.mma_type: + # (M x K, N x K) -> M x N + case MMAType.F32_16x16x16_F16: + offset = { + 0: Piecewise( + (lane % 16, ~self.ACC), (4 * floor(lane / 16), self.ACC) + ), # M + 1: lane % 16, # N + 2: 4 * floor(lane / 16), # K + } + size = { + 0: Piecewise((0, ~self.ACC), (4, self.ACC)), # M + 1: 0, # N + 2: 4, # K + } + stride = { + 0: Piecewise((1, ~self.ACC), (16, self.ACC)), # M + 1: 1, # N + 2: 1, # K + } + return IndexSequence( + offset[mma_index], size[mma_index], stride[mma_index] + ) @dataclass @@ -81,17 +127,17 @@ class WorkgroupConstraint(Constraint): tile_size: IndexExpr workgroup_dim: int - def apply(self) -> IndexExpr: + def apply(self) -> IndexSequence: match self.workgroup_dim: case 0: - wg_dim = tkl.sym.WG0 + wg_dim = WORKGROUP_0 case 1: - wg_dim = tkl.sym.WG1 + wg_dim = WORKGROUP_1 case 2: - wg_dim = tkl.sym.WG2 + wg_dim = WORKGROUP_2 case _: raise ValueError("Invalid workgroup dimension. Expected 0, 1 or 2.") - return wg_dim * self.tile_size + return IndexSequence(wg_dim * self.tile_size, 1) def get_grid_shape(wg_constraints: list[WorkgroupConstraint]) -> list[IndexExpr]: @@ -130,9 +176,9 @@ def iterations(self) -> IndexExpr: """ return ceiling(self.dim / self.tile_size) - def apply(self) -> IndexExpr: + def apply(self) -> IndexSequence: if self.induction_var is None: raise ValueError( "Index is being computed without setting induction variable" ) - return self.induction_var * self.tile_size + return IndexSequence(self.induction_var * self.tile_size, 1) diff --git a/shark_turbine/kernel/wave/distribution_symbols.py b/shark_turbine/kernel/wave/distribution_symbols.py new file mode 100644 index 00000000..a508373a --- /dev/null +++ b/shark_turbine/kernel/wave/distribution_symbols.py @@ -0,0 +1,9 @@ +import shark_turbine.kernel.lang as tkl + +WORKGROUP_0 = tkl.sym.WG0 +WORKGROUP_1 = tkl.sym.WG1 +WORKGROUP_2 = tkl.sym.WG2 + +THREAD_0 = tkl.sym.T0 +THREAD_1 = tkl.sym.T1 +THREAD_2 = tkl.sym.T2 diff --git a/shark_turbine/kernel/wave/expansion.py b/shark_turbine/kernel/wave/expansion.py index 42438b49..f744f5a4 100644 --- a/shark_turbine/kernel/wave/expansion.py +++ b/shark_turbine/kernel/wave/expansion.py @@ -12,6 +12,7 @@ from .._support.indexing import IndexingContext from ...support.logging import get_logger from .._support.tracing import CapturedTrace +from .indexing import IndexSequence logger = get_logger("turbine.wave.expansion") # This represents a mapping of a node + indexing into the dimensions to the @@ -87,7 +88,7 @@ def is_expandable(arg: Any) -> bool: return isinstance(arg, CustomOp) -def get_mma_dimensional_mapping(trace: CapturedTrace) -> dict[fx.Node, int]: +def get_mma_dimensional_mapping(trace: CapturedTrace) -> dict[IndexSymbol, int]: """ Given a trace, determine the MMA dimensional mapping for all the MMA operations in the graph. For example, if we have @@ -103,7 +104,7 @@ def is_mma(node: fx.Node) -> bool: return True mma_nodes = trace.walk(is_mma) - mapping: dict[fx.Node, int] = {} + mapping: dict[IndexSymbol, int] = {} for node in mma_nodes: custom: MMA = get_custom(node) m, n = custom.acc_type.symbolic_shape[-2:] @@ -120,6 +121,50 @@ def is_mma(node: fx.Node) -> bool: return mapping +def set_node_indices( + trace: CapturedTrace, + constraints: Sequence[Constraint], + mma_index: dict[IndexSymbol, int], +): + """ + Set the indices of the nodes based on the user constraints. In certain + operators (like read, write), there is only a single index associated + with the node (the index to read from, the index to write to). But for + other operators like mma, each operand reads from a different index. + + Rather than maintain operand specific indices for operators, we maintain + dimension specific indices for each operator. So for an mma operator that + has a signature of (MxK, NxK) -> MxN, we maintain only 3 mappings for + dimensions M, N and K, but allow each mapping to be piecewise conditioned + on the operand. + """ + + def compute_index(node: fx.Node) -> bool: + custom = get_custom(node) + custom.index = {} + for dim in custom.indexing_dims: + for constraint in constraints: + if ( + not isinstance(constraint, HardwareConstraint) + and dim != constraint.dim + ): + continue + if not custom.index: + custom.index = { + dim: IndexSequence(0, 1) for dim in custom.indexing_dims + } + if isinstance(constraint, HardwareConstraint): + if dim in mma_index and isinstance(custom, MMA): + custom.index[dim] += constraint.apply(mma_index[dim]) + elif dim == constraint.dim: + custom.index[dim] += constraint.apply() + if custom.index: + setattr(custom.fx_node, "index", custom.index) + return False + + trace.walk(compute_index) + + def expand_graph( trace: CapturedTrace, constraints_or_scaling: Sequence[Constraint] | dict[IndexSymbol, int], @@ -132,6 +177,7 @@ def expand_graph( dim_scaling = constraints_or_scaling else: mma_index = get_mma_dimensional_mapping(trace) + set_node_indices(trace, constraints_or_scaling, mma_index) dim_scaling = get_dim_scaling(constraints_or_scaling, mma_index) # Start from the back and expand in the corresponding indexing dimensions of a node @@ -142,6 +188,7 @@ def expand_graph( for node in ( get_custom(fx_node) for fx_node in reversed(list(trace.get_root_graph().nodes)) ): + # Expansion begins at the leaf nodes if node.__class__ not in leaf_nodes: continue @@ -276,7 +323,7 @@ def get_expanded_name(node: CustomOp, dims: dict[IndexSymbol, int]) -> str: def get_dim_scaling( - constraints: Sequence[Constraint], mma_indices: dict[IndexExpr, int] + constraints: Sequence[Constraint], mma_indices: dict[IndexSymbol, int] ) -> dict[IndexSymbol, int]: """Get the number of expansions for the dimensions based on the constraints.""" dim_scaling: dict[IndexSymbol, int] = {} diff --git a/shark_turbine/kernel/wave/indexing.py b/shark_turbine/kernel/wave/indexing.py new file mode 100644 index 00000000..a664a75a --- /dev/null +++ b/shark_turbine/kernel/wave/indexing.py @@ -0,0 +1,43 @@ +from dataclasses import dataclass +from typing import Optional, Any +from ...support.logging import get_logger +from .._support.indexing import IndexExpr, IndexSymbol +import sympy + +logger = get_logger("turbine.wave.indexing") + + +@dataclass +class IndexSequence: + start: IndexExpr | int + size: IndexExpr | int + stride: Optional[IndexExpr | int] = 1 + + def __add__(self, other: Any) -> Any: + if isinstance(other, IndexSequence): + return IndexSequence( + self.start + other.start, + self.size * other.size, + self.stride * other.stride, + ) + else: + raise NotImplementedError("IndexSequence addition not implemented!") + + def subs(self, map: dict[IndexSymbol, int]): + start = self.start + if isinstance(self.start, IndexExpr): + start = self.start.subs(map) + size = self.size + if isinstance(self.size, IndexExpr): + size = self.size.subs(map) + stride = self.stride + if isinstance(self.stride, IndexExpr): + stride = self.stride.subs(map) + return IndexSequence(start, size, stride) + + def __repr__(self): + if isinstance(self.size, sympy.Integer): + self.size = int(self.size) + if isinstance(self.size, int) and self.size <= 1: + return f"{self.start}" + return f"{self.start} : {self.size} : {self.stride}" diff --git a/shark_turbine/kernel/wave/wave.py b/shark_turbine/kernel/wave/wave.py index cf1470f3..daff4827 100644 --- a/shark_turbine/kernel/wave/wave.py +++ b/shark_turbine/kernel/wave/wave.py @@ -1,4 +1,5 @@ from typing import Any, Callable, Optional +import torch.fx as fx import inspect from ..compiler import builder, dispatch_codegen, kernel_codegen @@ -6,6 +7,7 @@ from .codegen import WaveEmitter from .constraints import ( Constraint, + TilingConstraint, WorkgroupConstraint, get_grid_shape, ) @@ -13,7 +15,9 @@ from .expansion import expand_graph from ..lang import Grid from ..ops import wave_ops -from .._support.indexing import IndexingContext +from ..ops.wave_ops import Reduction, CustomOp, get_custom +from .._support.indexing import IndexingContext, IndexExpr +import shark_turbine.kernel.lang as tkl from .._support.tracing import ( CapturedTrace, CompiledContext, @@ -49,6 +53,7 @@ def __init__( super().__init__(eager_function) self.constraints = constraints if constraints else [] + self.induction_vars: dict[CustomOp, IndexExpr] = {} self._name = name self._f = eager_function self._sig = inspect.signature(eager_function) @@ -63,6 +68,14 @@ def workgroup_constraints(self) -> list[WorkgroupConstraint]: if isinstance(constraint, WorkgroupConstraint) ] + @property + def tiling_constraints(self) -> list[TilingConstraint]: + return [ + constraint + for constraint in self.constraints + if isinstance(constraint, TilingConstraint) + ] + def _trace(self) -> CapturedTrace: region_graph = KernelRegionGraph() with CompiledContext(region_graph, grid_type=self.grid_type) as context: @@ -83,6 +96,28 @@ def _trace(self) -> CapturedTrace: return trace + def create_induction_vars(self, trace: CapturedTrace) -> None: + """ + Creates induction variables for all the reductions in the graph + and associates tiling constraints all the reduction dimensions + with the appropriate induction variables. + + """ + + def is_reduction(node: fx.Node): + custom = get_custom(node) + if isinstance(custom, Reduction): + return True + return False + + reduction_nodes = trace.walk(is_reduction) + for node in reduction_nodes: + custom = get_custom(node) + self.induction_vars[custom] = tkl.IndexSymbol("ARG" + custom.axis.name) + for tiling_constraint in self.tiling_constraints: + if tiling_constraint.dim == custom.axis: + tiling_constraint.induction_var = self.induction_vars[custom] + def _trace_and_get_kernel_signature( self, args, @@ -93,6 +128,8 @@ def _trace_and_get_kernel_signature( # Trace the function. graph = self._trace() + self.create_induction_vars(graph) + idxc = IndexingContext.current() idxc.finalize() @@ -100,8 +137,11 @@ def _trace_and_get_kernel_signature( expand_graph(graph, self.constraints) kernel_sig = kernel_codegen.KernelSignature() - # Fixed values for now, will be determined through constraints - self.grid_type.dims = [32, 32] # Will be determined by constraints + self.grid_type.dims = [1, 1, 1] + for constraint in self.workgroup_constraints: + self.grid_type.dims[constraint.workgroup_dim] = ( + constraint.dim // constraint.tile_size + ).subs(idxc.subs) grid = self.grid_type mb = builder.ModuleBuilder(context=context, module_op=module_op) diff --git a/tests/kernel/wave/constraints_test.py b/tests/kernel/wave/constraints_test.py index 8c2ea8d0..76dc6028 100644 --- a/tests/kernel/wave/constraints_test.py +++ b/tests/kernel/wave/constraints_test.py @@ -44,7 +44,7 @@ def testTilingConstraint(self): assert constraints[0].iterations() == ceiling(M / BLOCK_M) assert constraints[1].iterations() == ceiling(N / BLOCK_N) - assert constraints[1].apply() == I * BLOCK_N + assert constraints[1].apply().start == I * BLOCK_N with pytest.raises( ValueError, diff --git a/tests/kernel/wave/wave_gemm_test.py b/tests/kernel/wave/wave_gemm_test.py index 69290917..80f45e7e 100644 --- a/tests/kernel/wave/wave_gemm_test.py +++ b/tests/kernel/wave/wave_gemm_test.py @@ -27,7 +27,7 @@ def testGemm(self): # Expose user-constraints constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] - constraints += [tkw.WorkgroupConstraint(K, BLOCK_K, 2)] + constraints += [tkw.TilingConstraint(K, BLOCK_K)] constraints += [ tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(1, 1, 1))