diff --git a/iree/turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py
index 1d225010..05963a4b 100644
--- a/iree/turbine/kernel/ops/wave_ops.py
+++ b/iree/turbine/kernel/ops/wave_ops.py
@@ -611,7 +611,7 @@ def type(self) -> Memory:
             raise ValueError(
                 "BinaryPyOp requires lhs and rhs shape to be at least broadcastable."
             )
-        broadcasted_type = lhs_type if lhs_dim_set > rhs_dim_set else rhstype
+        broadcasted_type = lhs_type if lhs_dim_set > rhs_dim_set else rhs_type
         return broadcasted_type
 
 
@@ -1300,7 +1300,12 @@ def type(self) -> Memory:
         from ..wave.utils import all_equal
 
         src_types = [get_custom(arg).type for arg in self.arg]
-        if not all_equal(src_types):
+        ref_shape = src_types[0].symbolic_shape
+        ref_dtype = src_types[0].dtype
+        if not all(
+            src_type.symbolic_shape == ref_shape and src_type.dtype == ref_dtype
+            for src_type in src_types
+        ):
             raise NotImplementedError(
                 "NYI: Only support case where all inputs to ReduceOp to have same type."
             )
@@ -1322,6 +1327,10 @@ def num_reduction_dims(self) -> int:
         else:
             return 1
 
+    @property
+    def reduction_dim(self) -> IndexSymbol:
+        return self.dim
+
 
 # TODO: Add support for more shuffle types.
 @define_op("shuffle")
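The ReduceOp.type hunk above replaces an object-level all_equal over the source types with an explicit shape-and-dtype comparison. A minimal standalone sketch of the failure mode this sidesteps, using a hypothetical stand-in class rather than the turbine type system:

from dataclasses import dataclass


# eq=False mimics type objects that fall back to identity comparison.
@dataclass(eq=False)
class FakeRegisterType:
    symbolic_shape: tuple
    dtype: str


def all_same_shape_and_dtype(src_types: list) -> bool:
    # Structural check, mirroring the new ReduceOp.type logic above.
    ref_shape = src_types[0].symbolic_shape
    ref_dtype = src_types[0].dtype
    return all(
        t.symbolic_shape == ref_shape and t.dtype == ref_dtype for t in src_types
    )


a = FakeRegisterType(("M", "K"), "f16")
b = FakeRegisterType(("M", "K"), "f16")
assert a != b  # identity equality rejects structurally identical types
assert all_same_shape_and_dtype([a, b])  # the structural check accepts them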
diff --git a/iree/turbine/kernel/wave/expansion.py b/iree/turbine/kernel/wave/expansion.py
index b003016c..6ba9b52d 100644
--- a/iree/turbine/kernel/wave/expansion.py
+++ b/iree/turbine/kernel/wave/expansion.py
@@ -316,6 +316,7 @@ def _expand_reduction(
         dims = {dim: val for dim, val in zip(dim_scaling.keys(), dim_vals)}
         if not isinstance(return_vals, Sequence):
             return_vals = [return_vals]
+        # Proceed with expansion inside the reduction
         for arg_idx, arg in enumerate(return_vals):
             arg = get_custom(arg)
             # Add GetResult nodes for the corresponding dimensions
@@ -327,33 +328,48 @@ def _expand_reduction(
                 (reduction, get_indexed_dims(dims, expand_dims), arg_idx)
             ] = new_node
 
-            # Proceed with expansion inside the reduction
-            new_output_args.append(
-                _expand_node(
-                    arg,
-                    trace,
-                    dims,
-                    get_node_dim_scaling(arg),
-                    context,
-                    get_node_dim_scaling,
-                    res_idx,
-                )
+            expanded_output = _expand_node(
+                arg,
+                trace,
+                dims,
+                get_node_dim_scaling(arg),
+                context,
+                get_node_dim_scaling,
+                res_idx,
             )
+            # The check below is needed to skip over induction variables that
+            # do not carry all the dims of the ReductionOp. For example, a
+            # reduction op whose induction variables have types
+            # (max, mma) -> [M], [M, N]
+            # has indexing dims [M, N]. The first induction variable, however,
+            # does not expand in the N dim:
+            # M:0, N:0 expand(max) -> max_0_0_0
+            # M:0, N:1 expand(max) -> max_0_0_0
+            # so without this check the duplicate would be appended to
+            # `new_output_args` again.
 
+            # TODO: Handle expansion of induction variables with "non-complete"
+            # dims by checking the indexing_dims of each induction variable.
+            if expanded_output in new_output_args:
+                continue
+            new_output_args.append(expanded_output)
 
         # Proceed with expansion outside the reduction
         for init_arg in reduction.init_args:
             custom_init_arg = get_custom(init_arg)
-            new_init_args.append(
-                _expand_node(
-                    custom_init_arg,
-                    trace,
-                    dims,
-                    get_node_dim_scaling(custom_init_arg),
-                    context,
-                    get_node_dim_scaling,
-                    res_idx,
-                )
+            expanded_init_arg = _expand_node(
+                custom_init_arg,
+                trace,
+                dims,
+                get_node_dim_scaling(custom_init_arg),
+                context,
+                get_node_dim_scaling,
+                res_idx,
             )
+            # TODO: Handle expansion of induction variables with "non-complete"
+            # dims by checking the indexing_dims of each induction variable.
+            if expanded_init_arg in new_init_args:
+                continue
+            new_init_args.append(expanded_init_arg)
 
         # Update init_args and return values
         reduction.update_arg(
@@ -553,6 +569,54 @@ def get_dim_scaling(
     return dim_scaling
 
 
+def _expand_mma_tiled_reduction(
+    mma: MMA,
+    trace: CapturedTrace,
+    dim_query: dict[IndexSymbol, int],
+    dim_scaling: dict[IndexSymbol, int],
+    context: ExpandedNodeMap,
+    get_node_dim_scaling: Callable[[fx.Node], dict[IndexSymbol, int]],
+    res_idx: int,
+) -> CustomOp:
+    latest_reduced_op = mma
+    # The initial nodes are expanded in the first dimension, so we start from 1
+    for scale_idx in range(1, dim_scaling[mma.reduction_dim]):
+        dim_query[mma.reduction_dim] = scale_idx
+        # Temporarily replace the loop carried arg here to avoid
+        # duplicated expansion. Otherwise we have the following situation:
+        # Suppose we have:
+        #   mma_0_0_0(..., acc_0_0_0)
+        #   mma_0_0_1(..., mma_0_0_0)
+        # Expanding mma_0_0_1 to mma_0_0_2 will trigger expansion of its arg
+        # mma_0_0_0 in dims 0_0_2 as well, effectively duplicating the new node.
+        # To avoid this we temporarily replace the use of it with a dummy
+        # placeholder which will not trigger further expansion.
+        dummy = Placeholder("dummy").add_to_graph(latest_reduced_op.graph)
+        dummy.type = None
+
+        saved_acc = latest_reduced_op.acc
+        latest_reduced_op.update_arg("acc", dummy)
+        new_node = _expand_node(
+            latest_reduced_op,
+            trace,
+            dim_query,
+            dim_scaling,
+            context,
+            get_node_dim_scaling,
+            res_idx,
+        )
+
+        # The node is always cloned and hence will never be equal to
+        # latest_reduced_op.
+        assert new_node != latest_reduced_op
+        # Update MMA_{t} to accumulate on MMA_{t-1}, then save the current
+        # MMA_{t} to the outputs for use in the next iteration.
+        latest_reduced_op.update_arg("acc", saved_acc)
+        new_node.update_arg("acc", latest_reduced_op)
+        latest_reduced_op.graph.erase_node(dummy)
+        latest_reduced_op = new_node
+    return latest_reduced_op
+
+
 def _handle_reduction_dim(
     reduction: Reduction,
     output: Output,
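To make the accumulator rewiring in _expand_mma_tiled_reduction easier to follow, here is a minimal sketch of the chaining it produces, with plain dicts standing in for fx nodes (all names are illustrative, not the turbine API). Each clone along the reduction dim accumulates on the previous one:

def chain_tiled_mma(mma: dict, scaling: int) -> dict:
    latest = mma
    # The initial node covers index 0, so cloning starts from 1.
    for scale_idx in range(1, scaling):
        clone = {"name": f"{mma['name']}_{scale_idx}", "acc": None}
        # Rewire the clone to accumulate on the previous result, mirroring
        # new_node.update_arg("acc", latest_reduced_op) above.
        clone["acc"] = latest
        latest = clone
    return latest


last = chain_tiled_mma({"name": "mma", "acc": "acc_init"}, scaling=4)
# Walk the chain back to the original node: mma_3 -> mma_2 -> mma_1 -> mma.
names = []
node = last
while isinstance(node, dict):
    names.append(node["name"])
    node = node["acc"]
assert names == ["mma_3", "mma_2", "mma_1", "mma"]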
@@ -566,39 +630,53 @@ def _handle_reduction_dim(
     # TODO: Register iter args with the reduction initially so accessing them is easier
     iter_args: list[CustomOp] = []
     reduction_subgraph = trace.get_subgraph(reduction.subgraph_name)
+
+    # TODO: Handle case where MMAs/ReduceOps do not have Output as direct consumer.
+    def get_output_index(custom: CustomOp):
+        output_users = [
+            get_custom(user)
+            for user in custom.fx_node.users
+            if isinstance(get_custom(user), Output)
+        ]
+        if len(output_users) != 1:
+            raise NotImplementedError(
+                "NYI: Currently only handle direct and 1:1 MMA -> Output case."
+            )
+        return output_users[0].return_vals[0].index(custom.fx_node)
+
+    # Collect the MMA and ReduceOp nodes whose reduction axis matches the
+    # parent ReductionOp's.
+    reduction_root_ops = []
     for node in (get_custom(fx_node) for fx_node in reduction_subgraph.nodes):
-        if isinstance(node, IterArg):
-            iter_args.append(node)
+        if isinstance(node, (MMA, ReduceOp)) and reduction.axis == node.reduction_dim:
+            reduction_root_ops.append(node)
 
     new_outputs = list(reduction.outputs(trace.get_subgraph(reduction.subgraph_name)))
     # Users of the loop carried nodes will be duplicated
-    for idx, carried_node in enumerate(iter_args):
-        # The initial nodes are expanded in the first dimension, so we start from 1
-        dim_scaling = get_node_dim_scaling(carried_node)
-        for scale_idx in range(1, dim_scaling[reduction.axis]):
-            for user in carried_node.users:
-                if isinstance(user, Output):
-                    continue
-
-                dims = dict(user.fx_node.expanded_dims)
-                dims[reduction.axis] = scale_idx
-                # Temporarily replace the loop carried arg here to avoid
-                # duplicated expansion. Otherwise we have the following situation:
-                # Suppose we have:
-                # mma_0_0_0(..., acc_0_0_0)
-                # mma_0_0_1(..., mma_0_0_0)
-                # Expanding mma_0_0_1 to mma_0_0_2 will trigger expansion of its arg
-                # mma_0_0_0 in dims 0_0_2 as well, effectively duplicating the new node.
-                # To avoid this we temporarily replace the use of it with a dummy
-                # placeholder which will not trigger further expansion.
-                index = user.get_node_arg_index(carried_node)
-                dummy = Placeholder("dummy").add_to_graph(user.graph)
-                dummy.type = None
-
-                saved_arg = user.node_args[index]
-                user.update_arg(index, dummy)
-                new_node = _expand_node(
-                    user,
+    for root_op in reduction_root_ops:
+        dim_scaling = get_node_dim_scaling(root_op)
+        dims = dict(root_op.fx_node.expanded_dims)
+        latest_reduced_op = root_op
+        op_output_index = get_output_index(root_op)
+        if isinstance(root_op, MMA):
+            latest_reduced_op = _expand_mma_tiled_reduction(
+                root_op,
+                trace,
+                dims,
+                dim_scaling,
+                context,
+                get_node_dim_scaling,
+                res_idx,
+            )
+        elif isinstance(root_op, ReduceOp):
+            original_src = latest_reduced_op.arg
+            # The initial nodes are expanded in the first dimension, so we start from 1
+            for scale_idx in range(1, dim_scaling[reduction.axis]):
+                dims[root_op.reduction_dim] = scale_idx
+                current_src = latest_reduced_op.arg
+                if not isinstance(current_src, Sequence):
+                    current_src = [current_src]
+                expanded_src = _expand_node(
+                    get_custom(original_src),
                     trace,
                     dims,
                     dim_scaling,
@@ -606,13 +684,11 @@ def _handle_reduction_dim(
                     get_node_dim_scaling,
                     res_idx,
                 )
-
-                # This expansion always happens, user should never be reused
-                assert new_node != user
-                user.update_arg(index, saved_arg)
-                new_node.update_arg(index, user)
-                user.graph.erase_node(dummy)
-                carried_node = user
-
-                new_outputs[idx] = new_node.fx_node
-
+                current_src.append(expanded_src.fx_node)
+                latest_reduced_op.update_arg("arg", current_src)
+            new_outputs[op_output_index] = latest_reduced_op.fx_node
+        init_dims = root_op.fx_node.expanded_dims
+        context[
+            (root_op, get_indexed_dims(init_dims, root_op), res_idx)
+        ] = latest_reduced_op
     output.update_arg("return_vals", new_outputs)
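The ReduceOp branch above grows the reduce's source list instead of chaining clones, so a single reduce consumes every tile of the reduction dim. A rough standalone model of that bookkeeping under the read_{m}_{k} naming scheme the lit tests below check (plain strings stand in for fx nodes):

def expand_reduce_sources(original_src: list, dim_scaling: int) -> list:
    # Start from the index-0 expansion, mirroring `current_src` above.
    current_src = list(original_src)
    for scale_idx in range(1, dim_scaling):
        # _expand_node would clone the source for this reduction index;
        # here we just retag the name, e.g. read_0_0 -> read_0_1.
        expanded = [f"{src.rsplit('_', 1)[0]}_{scale_idx}" for src in original_src]
        current_src.extend(expanded)
    return current_src


# With 8 expansions of one source, the reduce sees all partial reads,
# matching the `max(arg=[read_0_0, ..., read_0_7], init=acc_0_0` check.
assert expand_reduce_sources(["read_0_0"], dim_scaling=8) == [
    f"read_0_{i}" for i in range(8)
]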
diff --git a/lit_tests/kernel/wave/expansion.py b/lit_tests/kernel/wave/expansion.py
index 9f4fcf0a..e3b5f058 100644
--- a/lit_tests/kernel/wave/expansion.py
+++ b/lit_tests/kernel/wave/expansion.py
@@ -550,6 +550,120 @@ def test_batched_gemm():
     # CHECK-NEXT: -----
 
 
+@tkw.wave_trace_only()
+def gemm_non_direct_acc(
+    a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16],
+    b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16],
+    c: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32],
+):
+    c_reg = tkl.Register[M, N, tkl.f32](0.0)
+
+    @tkw.reduction(K, init_args=[c_reg])
+    def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
+        a_reg = tkw.read(a, elements_per_thread=4)
+        b_reg = tkw.read(b, elements_per_thread=4)
+        new_acc = tkw.exp2(a_reg) + acc
+        acc = tkw.mma(a_reg, b_reg, new_acc)
+        return acc
+
+    tkw.write(repeat, c, elements_per_thread=4)
+
+
+@run_test
+def test_gemm_non_direct_acc():
+    constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
+    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
+    constraints += [tkw.TilingConstraint(K, BLOCK_K, ARGK)]
+    constraints += [tkw.WaveConstraint(M, BLOCK_M / 2, THREAD_0 / 64)]
+    constraints += [tkw.WaveConstraint(N, BLOCK_N / 2, THREAD_1)]
+    constraints += [
+        tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(2, 2, 1))
+    ]
+    with tk.gen.TestLaunchContext(
+        {
+            BLOCK_M: 64,
+            BLOCK_N: 64,
+            BLOCK_K: 32,
+        }
+    ):
+        graph = gemm_non_direct_acc()
+        IndexingContext.current().finalize()
+        set_node_indices(graph, constraints)
+        expand_graph(graph, constraints)
+        set_post_expansion_indices(graph, constraints)
+        print_trace(graph)
+        # CHECK: %add_0_0_0
+        # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.add](args = (%exp2_0_0_0, %acc_0_0_0), kwargs = {})
+        # CHECK: %add_1_1_0
+        # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.add](args = (%exp2_1_0_0, %acc_1_1_0), kwargs = {})
+        # CHECK: %add_1_0_0
+        # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.add](args = (%exp2_1_0_0, %acc_1_0_0), kwargs = {})
+        # CHECK: %add_0_1_0
+        # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.add](args = (%exp2_0_0_0, %acc_0_1_0), kwargs = {})
+        # CHECK: %mma_0_0_0
+        # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_0_0_0, %read_0_0_0, %add_0_0_0), kwargs = {})
+        # CHECK: %mma_0_0_1
+        # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_0_0_1, %read_0_0_1, %mma_0_0_0), kwargs = {})
+        # CHECK: %mma_1_1_0
+        # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_1_0_0, %read_0_1_0, %add_1_1_0), kwargs = {})
+        # CHECK: %mma_1_1_1
+        # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_1_0_1, %read_0_1_1, %mma_1_1_0), kwargs = {})
+        # CHECK: %mma_1_0_0
+        # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_1_0_0, %read_0_0_0, %add_1_0_0), kwargs = {})
+        # CHECK: %mma_1_0_1
+        # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_1_0_1, %read_0_0_1, %mma_1_0_0), kwargs = {})
+        # CHECK: %mma_0_1_0
+        # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_0_0_0, %read_0_1_0, %add_0_1_0), kwargs = {})
+        # CHECK: %mma_0_1_1
+        # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_0_0_1, %read_0_1_1, %mma_0_1_0), kwargs = {})
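The CHECK lines in these tests rely on the expansion naming convention: each expanded copy of a node is suffixed with one index per dim, here (M, N, K). A tiny illustrative helper (not turbine code) showing that encoding:

def expanded_name(base: str, dim_query: dict[str, int]) -> str:
    # An mma expanded at M=1, N=0, K=1 prints as mma_1_0_1 in the trace.
    return base + "".join(f"_{idx}" for idx in dim_query.values())


assert expanded_name("mma", {"M": 1, "N": 0, "K": 1}) == "mma_1_0_1"
assert expanded_name("read", {"M": 0, "K": 3}) == "read_0_3"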
+
+
+@tkw.wave_trace_only()
+def tiled_max(
+    a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16],
+    c: tkl.Memory[M, ADDRESS_SPACE, tkl.f16],
+):
+    init_max = tkl.Register[M, tkl.f16](-1e6)
+
+    @tkw.reduction(K, init_args=[init_max])
+    def repeat(acc: tkl.Register[M, tkl.f16]) -> tkl.Register[M, tkl.f16]:
+        a_reg = tkw.read(a, elements_per_thread=4)
+        partial_max = tkw.max(a_reg, acc, dim=K)
+        return partial_max
+
+    tkw.write(repeat, c, elements_per_thread=4)
+
+
+@run_test
+def test_tiled_max():
+    constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
+    constraints += [tkw.TilingConstraint(K, BLOCK_K, ARGK)]
+    constraints += [tkw.WaveConstraint(M, BLOCK_M / 2, THREAD_0 / 64)]
+    constraints += [
+        tkw.HardwareConstraint(
+            threads_per_wave=64,
+            waves_per_block=(2, 1, 1),
+            vector_shapes={M: 16, K: 4},
+        )
+    ]
+    with tk.gen.TestLaunchContext(
+        {
+            BLOCK_M: 64,
+            BLOCK_K: 32,
+        }
+    ):
+        graph = tiled_max()
+        IndexingContext.current().finalize()
+        set_node_indices(graph, constraints)
+        expand_graph(graph, constraints)
+        set_post_expansion_indices(graph, constraints)
+        print_trace(graph)
+        # CHECK: max(arg=[read_0_0, read_0_1, read_0_2, read_0_3, read_0_4, read_0_5, read_0_6, read_0_7], init=acc_0_0
+        # CHECK: max(arg=[read_1_0, read_1_1, read_1_2, read_1_3, read_1_4, read_1_5, read_1_6, read_1_7], init=acc_1_0
+        # CHECK: output(return_vals=([max_0_0, max_1_0],))
+        # CHECK-NEXT: -----
+
+
 @run_test
 def test_gemm_reduction_expansion_only():
     # Note: This does not implement an actual gemm computation but reuses the
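As a sanity reference for tiled_max, a NumPy sketch (assumed semantics, not code generated by the kernel) of the running row-wise max that the K-tiled reduction expresses; float32 is used here purely for illustration:

import numpy as np

M_, K_, tile = 16, 32, 4
a = np.random.rand(M_, K_).astype(np.float32)

acc = np.full((M_,), -1e6, dtype=np.float32)  # mirrors init_max = -1e6
for k0 in range(0, K_, tile):
    # Each iteration reduces one K tile together with the carried acc,
    # matching max(arg=[read_*...], init=acc_*) in the expanded trace.
    acc = np.maximum(acc, a[:, k0 : k0 + tile].max(axis=1))

assert np.allclose(acc, a.max(axis=1))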