diff --git a/iree/turbine/kernel/wave/decompose_reduce_ops.py b/iree/turbine/kernel/wave/decompose_reduce_ops.py
index bf972b75..dc3504e0 100644
--- a/iree/turbine/kernel/wave/decompose_reduce_ops.py
+++ b/iree/turbine/kernel/wave/decompose_reduce_ops.py
@@ -11,7 +11,7 @@
     TilingConstraint,
 )
 from .._support.tracing import CapturedTrace
-from .._support.indexing import IndexingContext, IndexSequence, IndexSymbol, IndexExpr
+from .._support.indexing import IndexSequence, IndexSymbol, IndexExpr
 from ..ops.wave_ops import (
     get_custom,
     Add,
@@ -22,6 +22,7 @@
     Extract,
     Reduction,
 )
+from ..lang.global_symbols import *
 from .utils import DCE, subs_idxc, all_equal
 import torch.fx as fx
 
@@ -31,6 +32,51 @@
 TKW_COMBINER = {"sum": Add, "max": Maximum}
 
 
+def determine_shuffle_config(
+    index: dict[IndexSymbol, IndexSequence],
+    reduction_dim: IndexSymbol,
+    vector_shapes: dict[IndexSymbol, int],
+    subgroup_size: int,
+    induction_vars: list[IndexSymbol],
+):
+    """
+    This function determines the cluster size and stride for a given index.
+    The cluster size specifies the number of threads that participate in a shuffle.
+    The cluster stride specifies the stride between the threads. In order to
+    determine the cluster stride, we do a binary search on the start value of the
+    index sequence.
+
+    """
+    access_pattern = index[reduction_dim]
+    elements_per_thread = access_pattern.size
+    cluster_size = vector_shapes[reduction_dim] // elements_per_thread
+
+    # Since we are only concerned with what happens within a subgroup,
+    # we can ignore the TID_1 and TID_2 components of the index. We can
+    # also ignore the GPR_NUM since we can assume we are only dealing with the
+    # same GPR_NUM. We ignore the workgroup indices and induction variables as well.
+    # Finally, we substitute in all variables that are known in the indexing context.
+    ignore = [
+        THREAD_1,
+        THREAD_2,
+        GPR_NUM,
+        WORKGROUP_0,
+        WORKGROUP_1,
+        WORKGROUP_2,
+    ] + induction_vars
+    offset = access_pattern.start.subs({k: 0 for k in ignore})
+    offset = subs_idxc(offset)
+    offset_table = [offset.subs({THREAD_0: i}) for i in range(subgroup_size)]
+    # Determine the thread ids participating in the shuffle.
+    thread_ids = []
+    for i in range(cluster_size):
+        thread_ids.append(offset_table.index(i * elements_per_thread))
+
+    cluster_stride = [x - y for x, y in zip(thread_ids[1:], thread_ids[:-1])]
+    assert all_equal(cluster_stride), f"Cluster stride must be equal across threads."
+    return cluster_size, cluster_stride[0]
+
+
 def get_graph_node(custom: CustomOp, graph: fx.Graph):
     custom.add_to_graph(graph)
     custom = custom.fx_node
@@ -58,15 +104,20 @@ def emit_local_reduction(
 
 
 def emit_global_reduction(
-    binary_fn: Callable, src: fx.Node, graph: fx.Graph, subgroup_size: int
+    binary_fn: Callable,
+    src: fx.Node,
+    graph: fx.Graph,
+    subgroup_size: int,
+    cluster_size: int,
+    cluster_stride: int,
 ) -> fx.Node:
     init = src
-    num_steps = int(math.log2(float(subgroup_size)))
-    for i in range(num_steps):
-        shuffle_offset = 2**i
-        shuffle_val = ShuffleOp(init, shuffle_offset, subgroup_size)
+    num_steps = int(math.log2(float(cluster_size)))
+    for _ in range(num_steps):
+        shuffle_val = ShuffleOp(init, cluster_stride, subgroup_size)
         shuffle_node = get_graph_node(shuffle_val, graph)
         init = get_graph_node(binary_fn(init, shuffle_node), graph)
+        cluster_stride <<= 1
     return init
 
 
@@ -99,6 +150,9 @@
         for c in constraints
         if isinstance(c, TilingConstraint) or isinstance(c, WorkgroupConstraint)
     }
+    induction_vars = [
+        c.induction_var for c in constraints if isinstance(c, TilingConstraint)
+    ]
     subgroup_size = hardware_constraint.threads_per_wave
     for node in reduce_nodes:
         custom = get_custom(node)
@@ -143,8 +197,20 @@
         )
 
         # Global Reduce
+        cluster_size, cluster_stride = determine_shuffle_config(
+            reduction_src[0].index,
+            reduction_dim,
+            node.vector_shapes,
+            subgroup_size,
+            induction_vars,
+        )
         global_reduction = emit_global_reduction(
-            binary_fn, local_reduction, custom.graph, subgroup_size
+            binary_fn,
+            local_reduction,
+            custom.graph,
+            subgroup_size,
+            cluster_size,
+            cluster_stride,
         )
 
         # Local Accumulator Reduce
diff --git a/iree/turbine/kernel/wave/index_sequence_analysis.py b/iree/turbine/kernel/wave/index_sequence_analysis.py
index ea1b32ae..140c0ac6 100644
--- a/iree/turbine/kernel/wave/index_sequence_analysis.py
+++ b/iree/turbine/kernel/wave/index_sequence_analysis.py
@@ -23,6 +23,7 @@
     get_hardware_constraint,
     subs_idxc,
     specialize_index_sequence,
+    capture_backward_slice,
 )
 import torch.fx as fx
 import numpy as np
@@ -238,6 +239,11 @@ dimensions M, N and K, but allow each mapping to be piecewise conditioned
     on the operand.
     """
+    custom = get_custom(node)
+    anchor = custom.anchor
+    if isinstance(custom, (Reduction, Placeholder)) and not isinstance(custom, IterArg):
+        return
+
     hardware_constraint = [get_hardware_constraint(constraints)]
     workgroup_constraints = {
         c.dim: c for c in constraints if isinstance(c, WorkgroupConstraint)
     }
@@ -251,13 +257,17 @@
     index = {}
     # The semantics of elements_per_thread are that it represents the number of
     # elements that are loaded contiguously from memory.
-    custom = get_custom(node)
-    anchor = custom.anchor
-
     elements_per_thread = getattr(custom, "elements_per_thread", None)
-
-    if isinstance(custom, (Reduction, Placeholder)) and not isinstance(custom, IterArg):
-        return
+    # For elementwise operations that do not have an elements per thread attribute,
+    # look back to the backward slice to see if they can find an appropriate value.
+    # TODO: Remove this once set_node_index is integrated with thread_shape_analysis.
+    if elements_per_thread is None:
+        backward_slice = capture_backward_slice(node)
+        for bwd_node in backward_slice:
+            custom_node = get_custom(bwd_node)
+            elements_per_thread = getattr(custom_node, "elements_per_thread", None)
+            if elements_per_thread:
+                break
 
     for dim in custom.indexing_dims:
         index_seq = None
diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py
index 14e7c713..ceff8128 100644
--- a/lit_tests/kernel/wave/codegen.py
+++ b/lit_tests/kernel/wave/codegen.py
@@ -990,7 +990,7 @@ def repeat(
     # CHECK: %[[LOOP:.+]]:2 = scf.for %[[ITER:.+]] = %[[C0_IDX]] to %[[C4_IDX]] step %[[C1_IDX]]
     # CHECK-SAME: iter_args(%[[ACC0:.+]] = %{{.*}}, %[[ACC1:.+]] = {{.*}})
     # CHECK-COUNT-2: vector.load{{.*}} memref<32x20xf16, #gpu.address_space>, vector<4xf16>
-    # CHECK-COUNT-6: gpu.shuffle xor
+    # CHECK-COUNT-2: gpu.shuffle xor
     # CHECK: %[[MAX:.+]] = arith.maximumf %[[ACC0]], %{{.*}}
     # CHECK: %[[MMA:.+]] = amdgpu.mfma %{{.*}} * %{{.*}} + %[[ACC1]]
     # CHECK: scf.yield %[[MAX]], %[[MMA]] : vector<1xf16>, vector<4xf32>
diff --git a/lit_tests/kernel/wave/expansion.py b/lit_tests/kernel/wave/expansion.py
index de8a3b40..9f4fcf0a 100644
--- a/lit_tests/kernel/wave/expansion.py
+++ b/lit_tests/kernel/wave/expansion.py
@@ -711,12 +711,12 @@ def py_arithmetic_different_dims():
     # CHECK-NEXT: placeholder(_name=c
     # CHECK-NEXT: read(memory=a, elements_per_thread=4, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M, N: $T1*BLOCK_N/4 + 4*$T1 + $WG1*BLOCK_N : 4 : 1}
     # CHECK-NEXT: read(memory=a, elements_per_thread=4, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/4 + 4*$T1 + $WG1*BLOCK_N : 4 : 1}
-    # CHECK-NEXT: add(lhs=read_0_0_0, rhs=read_0_0_0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N}
-    # CHECK-NEXT: add(lhs=read_1_0_0, rhs=read_1_0_0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N}
-    # CHECK-NEXT: sub(lhs=add_0_0_0, rhs=read_0_0_0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N}
-    # CHECK-NEXT: sub(lhs=add_1_0_0, rhs=read_1_0_0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N
-    # CHECK-NEXT: neg(arg=sub_0_0_0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N}
-    # CHECK-NEXT: neg(arg=sub_1_0_0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N}
+    # CHECK-NEXT: add(lhs=read_0_0_0, rhs=read_0_0_0, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M, N: $T1*BLOCK_N/4 + 4*$T1 + $WG1*BLOCK_N : 4 : 1}
+    # CHECK-NEXT: add(lhs=read_1_0_0, rhs=read_1_0_0, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/4 + 4*$T1 + $WG1*BLOCK_N : 4 : 1}
+    # CHECK-NEXT: sub(lhs=add_0_0_0, rhs=read_0_0_0, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M, N: $T1*BLOCK_N/4 + 4*$T1 + $WG1*BLOCK_N : 4 : 1}
+    # CHECK-NEXT: sub(lhs=add_1_0_0, rhs=read_1_0_0, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/4 + 4*$T1 + $WG1*BLOCK_N : 4 : 1}
+    # CHECK-NEXT: neg(arg=sub_0_0_0, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M, N: $T1*BLOCK_N/4 + 4*$T1 + $WG1*BLOCK_N : 4 : 1}
+    # CHECK-NEXT: neg(arg=sub_1_0_0, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/4 + 4*$T1 + $WG1*BLOCK_N : 4 : 1}
     # CHECK-NEXT: write(register_=neg_0_0_0, memory=c, elements_per_thread=4, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M, K: 4*$T2 + $WG2*BLOCK_K : 4 : 1}
     # CHECK-NEXT: write(register_=neg_1_0_0, memory=c, elements_per_thread=4, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M + 16, K: 4*$T2 + $WG2*BLOCK_K + 16 : 4 : 1}
     # CHECK-NEXT: write(register_=neg_1_0_0, memory=c, elements_per_thread=4, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M + 16, K: 4*$T2 + $WG2*BLOCK_K : 4 : 1}
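
Note on the shuffle pattern in this patch: emit_global_reduction now performs a butterfly reduction over only the cluster_size lanes that actually hold data for the reduction dimension (log2(cluster_size) xor-shuffle steps, with the stride doubling each step), instead of shuffling across the full subgroup. The codegen lit test change from CHECK-COUNT-6 to CHECK-COUNT-2 reflects this: 6 steps correspond to a full 64-lane wave, 2 steps to a cluster of 4 lanes. Below is a minimal Python sketch, not part of the patch, that mimics the lane-level effect; simulate_butterfly_reduce is a hypothetical helper, and it assumes ShuffleOp lowers to gpu.shuffle xor as the codegen test above checks.

import math

def simulate_butterfly_reduce(values, cluster_size, cluster_stride, combine=lambda a, b: a + b):
    # Each step, every lane combines its value with the lane at (lane ^ stride),
    # then the stride doubles, for log2(cluster_size) steps, mirroring the
    # ShuffleOp/binary_fn loop in emit_global_reduction.
    lanes = list(values)
    stride = cluster_stride
    for _ in range(int(math.log2(cluster_size))):
        lanes = [combine(v, lanes[i ^ stride]) for i, v in enumerate(lanes)]
        stride <<= 1
    return lanes

# 8 lanes, clusters of 4 adjacent lanes (cluster_stride=1): lanes 0-3 and 4-7
# each end up holding the sum of their own cluster.
print(simulate_butterfly_reduce(list(range(8)), cluster_size=4, cluster_stride=1))
# -> [6, 6, 6, 6, 22, 22, 22, 22]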