From d9e0b2448dcc0c777066d1159d0603fe438d4378 Mon Sep 17 00:00:00 2001
From: Stanley Winata
Date: Fri, 25 Oct 2024 16:34:51 -0700
Subject: [PATCH] add expansion of ReduceOp

Signed-off-by: Stanley Winata
---
 iree/turbine/kernel/ops/wave_ops.py   |   2 +-
 iree/turbine/kernel/wave/expansion.py |  41 ++++++++---
 lit_tests/kernel/wave/expansion.py    | 114 ++++++++++++++++++++++++++
 3 files changed, 146 insertions(+), 11 deletions(-)

diff --git a/iree/turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py
index b2224c3c..de9c3784 100644
--- a/iree/turbine/kernel/ops/wave_ops.py
+++ b/iree/turbine/kernel/ops/wave_ops.py
@@ -615,7 +615,7 @@ def type(self) -> Memory:
             raise ValueError(
                 "BinaryPyOp requires lhs and rhs shape to be at least broadcastable."
             )
-        broadcasted_type = lhs_type if lhs_dim_set > rhs_dim_set else rhstype
+        broadcasted_type = lhs_type if lhs_dim_set > rhs_dim_set else rhs_type
         return broadcasted_type
 
 
diff --git a/iree/turbine/kernel/wave/expansion.py b/iree/turbine/kernel/wave/expansion.py
index 517ea733..ee0a8b80 100644
--- a/iree/turbine/kernel/wave/expansion.py
+++ b/iree/turbine/kernel/wave/expansion.py
@@ -630,7 +630,7 @@ def get_output_index(custom: CustomOp):
     # Collect MMA and ReduceOp who's reduction axis matches parent ReductionOp.
     reduction_root_ops = []
     for node in (get_custom(fx_node) for fx_node in reduction_subgraph.nodes):
-        if isinstance(node, MMA) and reduction.axis == node.reduction_dim:
+        if isinstance(node, (MMA, ReduceOp)) and reduction.axis == node.reduction_dim:
             reduction_root_ops.append(node)
 
     new_outputs = list(reduction.outputs(trace.get_subgraph(reduction.subgraph_name)))
@@ -643,14 +643,35 @@ def get_output_index(custom: CustomOp):
         op_output_index = get_output_index(root_op)
         if dim_scaling[reduction.axis] <= 1:
             continue
-        latest_reduced_op = _expand_mma_tiled_reduction(
-            root_op,
-            trace,
-            dims,
-            dim_scaling,
-            context,
-            get_node_dim_scaling,
-            res_idx,
-        )
+        if isinstance(root_op, MMA):
+            latest_reduced_op = _expand_mma_tiled_reduction(
+                root_op,
+                trace,
+                dims,
+                dim_scaling,
+                context,
+                get_node_dim_scaling,
+                res_idx,
+            )
+        elif isinstance(root_op, ReduceOp):
+            # Accumulate one expanded source per tile of the reduction dim.
+            latest_reduced_op = root_op
+            original_src = latest_reduced_op.arg
+            for scale_idx in range(1, dim_scaling[reduction.axis]):
+                dims[root_op.reduction_dim] = scale_idx
+                current_src = latest_reduced_op.arg
+                if not isinstance(current_src, Sequence):
+                    current_src = [current_src]
+                expanded_src = _expand_node(
+                    get_custom(original_src),
+                    trace,
+                    dims,
+                    dim_scaling,
+                    context,
+                    get_node_dim_scaling,
+                    res_idx,
+                )
+                current_src.append(expanded_src.fx_node)
+                latest_reduced_op.update_arg("arg", current_src)
         new_outputs[op_output_index] = latest_reduced_op.fx_node
     output.update_arg("return_vals", new_outputs)
diff --git a/lit_tests/kernel/wave/expansion.py b/lit_tests/kernel/wave/expansion.py
index de8a3b40..442cb9c3 100644
--- a/lit_tests/kernel/wave/expansion.py
+++ b/lit_tests/kernel/wave/expansion.py
@@ -550,6 +550,120 @@ def test_batched_gemm():
     # CHECK-NEXT: -----
 
 
+@tkw.wave_trace_only()
+def gemm_non_direct_acc(
+    a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16],
+    b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16],
+    c: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32],
+):
+    c_reg = tkl.Register[M, N, tkl.f32](0.0)
+
+    @tkw.reduction(K, init_args=[c_reg])
+    def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
+        a_reg = tkw.read(a, elements_per_thread=4)
+        b_reg = tkw.read(b, elements_per_thread=4)
+        new_acc = tkw.exp2(a_reg) + acc
+        acc = tkw.mma(a_reg, b_reg, new_acc)
+        return acc
+
+    tkw.write(repeat, c, elements_per_thread=4)
+
+
+@run_test
+def test_gemm_non_direct_acc():
+    constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
+    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
+    constraints += [tkw.TilingConstraint(K, BLOCK_K, ARGK)]
+    constraints += [tkw.WaveConstraint(M, BLOCK_M / 2, THREAD_0 / 64)]
+    constraints += [tkw.WaveConstraint(N, BLOCK_N / 2, THREAD_1)]
+    constraints += [
+        tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(2, 2, 1))
+    ]
+    with tk.gen.TestLaunchContext(
+        {
+            BLOCK_M: 64,
+            BLOCK_N: 64,
+            BLOCK_K: 32,
+        }
+    ):
+        graph = gemm_non_direct_acc()
+        IndexingContext.current().finalize()
+        set_node_indices(graph, constraints)
+        expand_graph(graph, constraints)
+        set_post_expansion_indices(graph, constraints)
+        print_trace(graph)
+        # CHECK: %add_0_0_0
+        # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.add](args = (%exp2_0_0_0, %acc_0_0_0), kwargs = {})
+        # CHECK: %add_1_1_0
+        # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.add](args = (%exp2_1_0_0, %acc_1_1_0), kwargs = {})
+        # CHECK: %add_1_0_0
+        # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.add](args = (%exp2_1_0_0, %acc_1_0_0), kwargs = {})
+        # CHECK: %add_0_1_0
+        # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.add](args = (%exp2_0_0_0, %acc_0_1_0), kwargs = {})
+        # CHECK: %mma_0_0_0
+        # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_0_0_0, %read_0_0_0, %add_0_0_0), kwargs = {})
+        # CHECK: %mma_0_0_1
+        # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_0_0_1, %read_0_0_1, %mma_0_0_0), kwargs = {})
+        # CHECK: %mma_1_1_0
+        # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_1_0_0, %read_0_1_0, %add_1_1_0), kwargs = {})
+        # CHECK: %mma_1_1_1
+        # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_1_0_1, %read_0_1_1, %mma_1_1_0), kwargs = {})
+        # CHECK: %mma_1_0_0
+        # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_1_0_0, %read_0_0_0, %add_1_0_0), kwargs = {})
+        # CHECK: %mma_1_0_1
+        # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_1_0_1, %read_0_0_1, %mma_1_0_0), kwargs = {})
+        # CHECK: %mma_0_1_0
+        # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_0_0_0, %read_0_1_0, %add_0_1_0), kwargs = {})
+        # CHECK: %mma_0_1_1
+        # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_0_0_1, %read_0_1_1, %mma_0_1_0), kwargs = {})
+
+
+@tkw.wave_trace_only()
+def tiled_max(
+    a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16],
+    c: tkl.Memory[M, ADDRESS_SPACE, tkl.f16],
+):
+    init_max = tkl.Register[M, tkl.f16](-1e6)
+
+    @tkw.reduction(K, init_args=[init_max])
+    def repeat(acc: tkl.Register[M, tkl.f16]) -> tkl.Register[M, tkl.f16]:
+        a_reg = tkw.read(a, elements_per_thread=4)
+        partial_max = tkw.max(a_reg, acc, dim=K)
+        return partial_max
+
+    tkw.write(repeat, c, elements_per_thread=4)
+
+
+@run_test
+def test_tiled_max():
+    constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
+    constraints += [tkw.TilingConstraint(K, BLOCK_K, ARGK)]
+    constraints += [tkw.WaveConstraint(M, BLOCK_M / 2, THREAD_0 / 64)]
+    constraints += [
+        tkw.HardwareConstraint(
+            threads_per_wave=64,
+            waves_per_block=(2, 1, 1),
+            vector_shapes={M: 16, K: 4},
+        )
+    ]
+    with tk.gen.TestLaunchContext(
+        {
+            BLOCK_M: 64,
+            BLOCK_K: 32,
+        }
+    ):
+        graph = tiled_max()
+        IndexingContext.current().finalize()
+        set_node_indices(graph, constraints)
+        expand_graph(graph, constraints)
+        set_post_expansion_indices(graph, constraints)
+        print_trace(graph)
+        # CHECK: max(arg=[read_0_0, read_0_1, read_0_2, read_0_3, read_0_4, read_0_5, read_0_6, read_0_7], init=acc_0_0
+        # CHECK: max(arg=[read_1_0, read_1_1, read_1_2, read_1_3, read_1_4, read_1_5, read_1_6, read_1_7], init=acc_1_0
+        # CHECK: output(return_vals=([max_0_0, max_1_0],))
+        # CHECK-NEXT: -----
+
+
 @run_test
 def test_gemm_reduction_expansion_only():
     # Note: This does not implement an actual gemm computation but reuses the
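
Reviewer aid, not part of the patch: a minimal, self-contained sketch of the accumulation pattern the ReduceOp branch in the expansion.py hunk implements, namely appending one expanded copy of the reduce source per tile of the reduction dimension, which is what produces max(arg=[read_0_0, ..., read_0_7], init=acc_0_0) in the tiled_max lit test above. FakeReduce, expand_source, and expand_tiled_reduce are hypothetical stand-ins for wave's ReduceOp and _expand_node machinery, not real iree.turbine APIs, and the tile index is assumed to be encoded in the trailing suffix of the source name.

from dataclasses import dataclass


@dataclass
class FakeReduce:
    # Stand-in for a ReduceOp node: reduces every source named in `arg` into
    # `init` along `reduction_dim`.
    arg: "list[str] | str"
    init: str
    reduction_dim: str = "K"


def expand_source(src: str, scale_idx: int) -> str:
    # Stand-in for _expand_node: name the copy of `src` that belongs to tile
    # `scale_idx` of the reduction dimension.
    return f"{src.rsplit('_', 1)[0]}_{scale_idx}"


def expand_tiled_reduce(op: FakeReduce, dim_scaling: dict) -> FakeReduce:
    # Mirrors the patched loop: keep a handle to the original source, then
    # append one expanded source per remaining tile of the reduction dim.
    original_src = op.arg
    for scale_idx in range(1, dim_scaling[op.reduction_dim]):
        current_src = op.arg if isinstance(op.arg, list) else [op.arg]
        current_src.append(expand_source(original_src, scale_idx))
        op.arg = current_src
    return op


if __name__ == "__main__":
    op = expand_tiled_reduce(FakeReduce(arg="read_0_0", init="acc_0_0"), {"K": 8})
    print(op.arg)  # ['read_0_0', 'read_0_1', ..., 'read_0_7']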