From d8394caa5bafb41de33a6ded98ddcc61a34b7c5d Mon Sep 17 00:00:00 2001
From: Harsh Menon
Date: Fri, 25 Oct 2024 11:31:32 -0700
Subject: [PATCH] Modify the shuffle op to use information from the index sequence

This PR modifies the shuffle op global reduction so that it determines the
number of threads and the stride between them from the reduction source's
index.

Signed-off-by: Harsh Menon
---
 .../kernel/wave/decompose_reduce_ops.py      | 81 +++++++++++++++++--
 .../kernel/wave/index_sequence_analysis.py   | 21 +++--
 lit_tests/kernel/wave/codegen.py             |  2 +-
 lit_tests/kernel/wave/expansion.py           | 12 +--
 4 files changed, 96 insertions(+), 20 deletions(-)

diff --git a/iree/turbine/kernel/wave/decompose_reduce_ops.py b/iree/turbine/kernel/wave/decompose_reduce_ops.py
index bf972b75..e1dc9bc7 100644
--- a/iree/turbine/kernel/wave/decompose_reduce_ops.py
+++ b/iree/turbine/kernel/wave/decompose_reduce_ops.py
@@ -22,15 +22,62 @@
     Extract,
     Reduction,
 )
+from ..lang.global_symbols import *
 
-from .utils import DCE, subs_idxc, all_equal
+from .utils import DCE, subs_idxc, all_equal, delinearize_index
 import torch.fx as fx
 import math
 from typing import Callable
+from functools import lru_cache
 
 TKW_COMBINER = {"sum": Add, "max": Maximum}
 
 
+def determine_shuffle_config(
+    index: dict[IndexSymbol, IndexSequence],
+    reduction_dim: IndexSymbol,
+    vector_shapes: dict[IndexSymbol, int],
+    subgroup_size: int,
+    induction_vars: list[IndexSymbol],
+):
+    """
+    This function determines the cluster size and stride for a given index.
+    The cluster size specifies the number of threads that participate in a shuffle.
+    The cluster stride specifies the stride between those threads. To determine
+    the stride, we evaluate the start of the index sequence for every thread id
+    in the subgroup and find which threads own consecutive chunks of the
+    reduction dimension.
+    """
+    access_pattern = index[reduction_dim]
+    elements_per_thread = access_pattern.size
+    cluster_size = vector_shapes[reduction_dim] // elements_per_thread
+
+    # Since we are only concerned with what happens within a subgroup,
+    # we can ignore the TID_1 and TID_2 components of the index. We can
+    # also ignore the GPR_NUM since we can assume we are only dealing with the
+    # same GPR_NUM. We ignore the workgroup indices and induction variables as well.
+    # Finally, we substitute in all variables that are known in the indexing context.
+    ignore = [
+        THREAD_1,
+        THREAD_2,
+        GPR_NUM,
+        WORKGROUP_0,
+        WORKGROUP_1,
+        WORKGROUP_2,
+    ] + induction_vars
+    offset = access_pattern.start.subs({k: 0 for k in ignore})
+    offset = subs_idxc(offset)
+    offset_table = [offset.subs({THREAD_0: i}) for i in range(subgroup_size)]
+    # Determine the thread ids participating in the shuffle.
+    thread_ids = []
+    for i in range(cluster_size):
+        thread_ids.append(offset_table.index(i * elements_per_thread))
+
+    cluster_stride = [x - y for x, y in zip(thread_ids[1:], thread_ids[:-1])]
+    assert all_equal(cluster_stride), "Cluster stride must be equal across threads."
+    return cluster_size, cluster_stride[0]
+
+
 def get_graph_node(custom: CustomOp, graph: fx.Graph):
     custom.add_to_graph(graph)
     custom = custom.fx_node
@@ -58,15 +105,20 @@ def emit_local_reduction(
 
 
 def emit_global_reduction(
-    binary_fn: Callable, src: fx.Node, graph: fx.Graph, subgroup_size: int
+    binary_fn: Callable,
+    src: fx.Node,
+    graph: fx.Graph,
+    subgroup_size: int,
+    cluster_size: int,
+    cluster_stride: int,
 ) -> fx.Node:
     init = src
-    num_steps = int(math.log2(float(subgroup_size)))
-    for i in range(num_steps):
-        shuffle_offset = 2**i
-        shuffle_val = ShuffleOp(init, shuffle_offset, subgroup_size)
+    num_steps = int(math.log2(float(cluster_size)))
+    for _ in range(num_steps):
+        shuffle_val = ShuffleOp(init, cluster_stride, subgroup_size)
         shuffle_node = get_graph_node(shuffle_val, graph)
         init = get_graph_node(binary_fn(init, shuffle_node), graph)
+        cluster_stride <<= 1
     return init
 
 
@@ -99,6 +151,9 @@ def decompose_reduce_ops(
        for c in constraints
        if isinstance(c, TilingConstraint) or isinstance(c, WorkgroupConstraint)
    }
+    induction_vars = [
+        c.induction_var for c in constraints if isinstance(c, TilingConstraint)
+    ]
    subgroup_size = hardware_constraint.threads_per_wave
    for node in reduce_nodes:
        custom = get_custom(node)
@@ -143,8 +198,20 @@ def decompose_reduce_ops(
            )

            # Global Reduce
+            cluster_size, cluster_stride = determine_shuffle_config(
+                reduction_src[0].index,
+                reduction_dim,
+                node.vector_shapes,
+                subgroup_size,
+                induction_vars,
+            )
            global_reduction = emit_global_reduction(
-                binary_fn, local_reduction, custom.graph, subgroup_size
+                binary_fn,
+                local_reduction,
+                custom.graph,
+                subgroup_size,
+                cluster_size,
+                cluster_stride,
            )

            # Local Accumulator Reduce
diff --git a/iree/turbine/kernel/wave/index_sequence_analysis.py b/iree/turbine/kernel/wave/index_sequence_analysis.py
index 7f417665..8449b5bd 100644
--- a/iree/turbine/kernel/wave/index_sequence_analysis.py
+++ b/iree/turbine/kernel/wave/index_sequence_analysis.py
@@ -23,6 +23,7 @@
     get_hardware_constraint,
     subs_idxc,
     specialize_index_sequence,
+    capture_backward_slice,
 )
 import torch.fx as fx
 import numpy as np
@@ -229,6 +230,11 @@ def set_node_index(
     dimensions M, N and K, but allow each mapping to be piecewise conditioned
     on the operand.
     """
+    custom = get_custom(node)
+    anchor = custom.anchor
+    if isinstance(custom, (Reduction, Placeholder)) and not isinstance(custom, IterArg):
+        return
+
     hardware_constraint = [get_hardware_constraint(constraints)]
     workgroup_constraints = {
         c.dim: c for c in constraints if isinstance(c, WorkgroupConstraint)
@@ -242,13 +248,17 @@ def set_node_index(
     index = {}
     # The semantics of elements_per_thread are that it represents the number of
     # elements that are loaded contiguously from memory.
-    custom = get_custom(node)
-    anchor = custom.anchor
     elements_per_thread = getattr(custom, "elements_per_thread", None)
-
-    if isinstance(custom, (Reduction, Placeholder)) and not isinstance(custom, IterArg):
-        return
+    # For elementwise operations that do not have an elements_per_thread attribute,
+    # look back through the backward slice to see if an appropriate value can be found.
+    if elements_per_thread is None:
+        backward_slice = capture_backward_slice(node)
+        for bwd_node in backward_slice:
+            custom_node = get_custom(bwd_node)
+            elements_per_thread = getattr(custom_node, "elements_per_thread", None)
+            if elements_per_thread:
+                break
 
     for dim in custom.indexing_dims:
         index_seq = None
diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py
index 14e7c713..ceff8128 100644
--- a/lit_tests/kernel/wave/codegen.py
+++ b/lit_tests/kernel/wave/codegen.py
@@ -990,7 +990,7 @@ def repeat(
     # CHECK: %[[LOOP:.+]]:2 = scf.for %[[ITER:.+]] = %[[C0_IDX]] to %[[C4_IDX]] step %[[C1_IDX]]
     # CHECK-SAME: iter_args(%[[ACC0:.+]] = %{{.*}}, %[[ACC1:.+]] = {{.*}})
     # CHECK-COUNT-2: vector.load{{.*}} memref<32x20xf16, #gpu.address_space>, vector<4xf16>
-    # CHECK-COUNT-6: gpu.shuffle xor
+    # CHECK-COUNT-2: gpu.shuffle xor
     # CHECK: %[[MAX:.+]] = arith.maximumf %[[ACC0]], %{{.*}}
     # CHECK: %[[MMA:.+]] = amdgpu.mfma %{{.*}} * %{{.*}} + %[[ACC1]]
     # CHECK: scf.yield %[[MAX]], %[[MMA]] : vector<1xf16>, vector<4xf32>
diff --git a/lit_tests/kernel/wave/expansion.py b/lit_tests/kernel/wave/expansion.py
index de8a3b40..9f4fcf0a 100644
--- a/lit_tests/kernel/wave/expansion.py
+++ b/lit_tests/kernel/wave/expansion.py
@@ -711,12 +711,12 @@ def py_arithmetic_different_dims():
     # CHECK-NEXT: placeholder(_name=c
     # CHECK-NEXT: read(memory=a, elements_per_thread=4, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M, N: $T1*BLOCK_N/4 + 4*$T1 + $WG1*BLOCK_N : 4 : 1}
     # CHECK-NEXT: read(memory=a, elements_per_thread=4, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/4 + 4*$T1 + $WG1*BLOCK_N : 4 : 1}
-    # CHECK-NEXT: add(lhs=read_0_0_0, rhs=read_0_0_0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N}
-    # CHECK-NEXT: add(lhs=read_1_0_0, rhs=read_1_0_0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N}
-    # CHECK-NEXT: sub(lhs=add_0_0_0, rhs=read_0_0_0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N}
-    # CHECK-NEXT: sub(lhs=add_1_0_0, rhs=read_1_0_0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N
-    # CHECK-NEXT: neg(arg=sub_0_0_0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N}
-    # CHECK-NEXT: neg(arg=sub_1_0_0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N}
+    # CHECK-NEXT: add(lhs=read_0_0_0, rhs=read_0_0_0, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M, N: $T1*BLOCK_N/4 + 4*$T1 + $WG1*BLOCK_N : 4 : 1}
+    # CHECK-NEXT: add(lhs=read_1_0_0, rhs=read_1_0_0, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/4 + 4*$T1 + $WG1*BLOCK_N : 4 : 1}
+    # CHECK-NEXT: sub(lhs=add_0_0_0, rhs=read_0_0_0, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M, N: $T1*BLOCK_N/4 + 4*$T1 + $WG1*BLOCK_N : 4 : 1}
+    # CHECK-NEXT: sub(lhs=add_1_0_0, rhs=read_1_0_0, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/4 + 4*$T1 + $WG1*BLOCK_N : 4 : 1}
+    # CHECK-NEXT: neg(arg=sub_0_0_0, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M, N: $T1*BLOCK_N/4 + 4*$T1 + $WG1*BLOCK_N : 4 : 1}
+    # CHECK-NEXT: neg(arg=sub_1_0_0, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/4 + 4*$T1 + $WG1*BLOCK_N : 4 : 1}
     # CHECK-NEXT: write(register_=neg_0_0_0, memory=c, elements_per_thread=4, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M, K: 4*$T2 + $WG2*BLOCK_K : 4 : 1}
     # CHECK-NEXT: write(register_=neg_1_0_0, memory=c, elements_per_thread=4, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M + 16, K: 4*$T2 + $WG2*BLOCK_K + 16 : 4 : 1}
     # CHECK-NEXT: write(register_=neg_1_0_0, memory=c, elements_per_thread=4, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M + 16, K: 4*$T2 + $WG2*BLOCK_K : 4 : 1}
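
For reference, below is a minimal, self-contained Python sketch of the two pieces this patch connects: the offset-table logic determine_shuffle_config uses to derive (cluster_size, cluster_stride), and the xor-shuffle butterfly emitted by emit_global_reduction. It works on plain integers rather than the sympy-based IndexSequence used by the actual pass, and the helper names here (shuffle_config, emulate_shuffle_reduction, start_offset) are illustrative only; they are not part of the iree.turbine API.

import math


def shuffle_config(start_offset, vector_shape, elements_per_thread, subgroup_size):
    # Derive (cluster_size, cluster_stride) from a per-lane start offset,
    # mirroring the offset-table logic in determine_shuffle_config.
    cluster_size = vector_shape // elements_per_thread
    # Start offset along the reduction dimension for every lane in the subgroup.
    offset_table = [start_offset(lane) for lane in range(subgroup_size)]
    # The lanes that own consecutive chunks of the reduction dimension.
    thread_ids = [
        offset_table.index(i * elements_per_thread) for i in range(cluster_size)
    ]
    strides = {b - a for a, b in zip(thread_ids, thread_ids[1:])}
    assert len(strides) == 1, "Cluster stride must be equal across threads."
    return cluster_size, strides.pop()


def emulate_shuffle_reduction(lane_values, cluster_size, cluster_stride, combine=max):
    # Model the log2(cluster_size) xor-shuffle steps emitted by emit_global_reduction.
    lanes = list(lane_values)
    stride = cluster_stride
    for _ in range(int(math.log2(cluster_size))):
        # gpu.shuffle xor: each lane combines its value with lane (i ^ stride).
        lanes = [combine(value, lanes[i ^ stride]) for i, value in enumerate(lanes)]
        stride <<= 1
    return lanes


if __name__ == "__main__":
    # Hypothetical layout: 64-wide wave, vector_shapes[K] = 16, 4 elements per
    # thread, with lanes 0..3 holding consecutive K chunks (offsets 0, 4, 8, 12).
    size, stride = shuffle_config(lambda lane: (lane % 4) * 4, 16, 4, 64)
    assert (size, stride) == (4, 1)
    reduced = emulate_shuffle_reduction(list(range(64)), size, stride)
    assert reduced[0] == 3  # lane 0 now holds the max over lanes 0..3

    # Interleaved layout: groups of 16 lanes share a K chunk, so the shuffle
    # stride must be 16 rather than 1, which is the case this patch handles by
    # reading the stride off the reduction source's index.
    size, stride = shuffle_config(lambda lane: (lane // 16) * 4, 16, 4, 64)
    assert (size, stride) == (4, 16)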