Modify the shuffle op to use information from the index sequence (#241)
This PR modifies the shuffle op global reduction so that it determines
the number of threads and the stride between them from the reduction
source's index.

Signed-off-by: Harsh Menon <harsh@nod-labs.com>
harsh-nod authored Oct 28, 2024
1 parent bf7b686 commit 32a47b2
Showing 4 changed files with 96 additions and 20 deletions.
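For context, the change derives the number of threads that participate in the shuffle (cluster size) and the distance between them (cluster stride) from the reduction source's index. A small standalone Python sketch of that idea follows; the thread_offset function is a hypothetical stand-in for the symbolic start expression of the index sequence, and none of the snippet is part of the commit itself.

# Illustrative sketch only (not part of this commit): derive cluster size and
# stride from a per-thread offset table, as determine_shuffle_config below does
# symbolically via the indexing context.
subgroup_size = 64          # threads per wave (assumed)
vector_shape = 16           # reduction-dim elements handled per wave (assumed)
elements_per_thread = 8     # contiguous elements each thread reduces locally (assumed)

def thread_offset(tid):
    # Hypothetical start offset of a thread into the reduction dimension.
    return (tid % 2) * elements_per_thread

cluster_size = vector_shape // elements_per_thread   # 2 threads shuffle together
offset_table = [thread_offset(t) for t in range(subgroup_size)]
# Threads participating in the shuffle are those holding offsets 0, EPT, 2*EPT, ...
thread_ids = [offset_table.index(i * elements_per_thread) for i in range(cluster_size)]
strides = [b - a for a, b in zip(thread_ids, thread_ids[1:])]
assert len(set(strides)) == 1, "stride must be uniform across the cluster"
cluster_stride = strides[0]
print(cluster_size, cluster_stride)                  # -> 2 1

With these two values, the global reduction only needs log2(cluster_size) shuffle steps instead of log2(subgroup_size), which is exactly what the diff below changes.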
80 changes: 73 additions & 7 deletions iree/turbine/kernel/wave/decompose_reduce_ops.py
@@ -11,7 +11,7 @@
TilingConstraint,
)
from .._support.tracing import CapturedTrace
from .._support.indexing import IndexingContext, IndexSequence, IndexSymbol, IndexExpr
from .._support.indexing import IndexSequence, IndexSymbol, IndexExpr
from ..ops.wave_ops import (
get_custom,
Add,
@@ -22,6 +22,7 @@
Extract,
Reduction,
)
from ..lang.global_symbols import *

from .utils import DCE, subs_idxc, all_equal
import torch.fx as fx
@@ -31,6 +32,51 @@
TKW_COMBINER = {"sum": Add, "max": Maximum}


def determine_shuffle_config(
index: dict[IndexSymbol, IndexSequence],
reduction_dim: IndexSymbol,
vector_shapes: dict[IndexSymbol, int],
subgroup_size: int,
induction_vars: list[IndexSymbol],
):
"""
This function determines the cluster size and stride for a given index.
The cluster size specifies the number of threads that participate in a shuffle.
The cluster stride specifies the stride between the threads. In order to
determine the cluster stride, we do a binary search on the start value of the
index sequence.
"""
access_pattern = index[reduction_dim]
elements_per_thread = access_pattern.size
cluster_size = vector_shapes[reduction_dim] // elements_per_thread

# Since we are only concerned with what happens within a subgroup,
# we can ignore the TID_1 and TID_2 components of the index. We can
# also ignore the GPR_NUM since we can assume we are only dealing with the
# same GPR_NUM. We ignore the workgroup indices and induction variables as well.
# Finally, we substitute in all variables that are known in the indexing context.
ignore = [
THREAD_1,
THREAD_2,
GPR_NUM,
WORKGROUP_0,
WORKGROUP_1,
WORKGROUP_2,
] + induction_vars
offset = access_pattern.start.subs({k: 0 for k in ignore})
offset = subs_idxc(offset)
offset_table = [offset.subs({THREAD_0: i}) for i in range(subgroup_size)]
# Determine the thread ids participating in the shuffle.
thread_ids = []
for i in range(cluster_size):
thread_ids.append(offset_table.index(i * elements_per_thread))

cluster_stride = [x - y for x, y in zip(thread_ids[1:], thread_ids[:-1])]
assert all_equal(cluster_stride), "Cluster stride must be equal across threads."
return cluster_size, cluster_stride[0]


def get_graph_node(custom: CustomOp, graph: fx.Graph):
custom.add_to_graph(graph)
custom = custom.fx_node
@@ -58,15 +104,20 @@ def emit_local_reduction(


def emit_global_reduction(
binary_fn: Callable, src: fx.Node, graph: fx.Graph, subgroup_size: int
binary_fn: Callable,
src: fx.Node,
graph: fx.Graph,
subgroup_size: int,
cluster_size: int,
cluster_stride: int,
) -> fx.Node:
init = src
num_steps = int(math.log2(float(subgroup_size)))
for i in range(num_steps):
shuffle_offset = 2**i
shuffle_val = ShuffleOp(init, shuffle_offset, subgroup_size)
num_steps = int(math.log2(float(cluster_size)))
for _ in range(num_steps):
shuffle_val = ShuffleOp(init, cluster_stride, subgroup_size)
shuffle_node = get_graph_node(shuffle_val, graph)
init = get_graph_node(binary_fn(init, shuffle_node), graph)
cluster_stride <<= 1
return init


@@ -99,6 +150,9 @@ def decompose_reduce_ops(
for c in constraints
if isinstance(c, TilingConstraint) or isinstance(c, WorkgroupConstraint)
}
induction_vars = [
c.induction_var for c in constraints if isinstance(c, TilingConstraint)
]
subgroup_size = hardware_constraint.threads_per_wave
for node in reduce_nodes:
custom = get_custom(node)
@@ -143,8 +197,20 @@
)

# Global Reduce
cluster_size, cluster_stride = determine_shuffle_config(
reduction_src[0].index,
reduction_dim,
node.vector_shapes,
subgroup_size,
induction_vars,
)
global_reduction = emit_global_reduction(
binary_fn, local_reduction, custom.graph, subgroup_size
binary_fn,
local_reduction,
custom.graph,
subgroup_size,
cluster_size,
cluster_stride,
)

# Local Accumulator Reduce
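To see what the revised emit_global_reduction loop computes, here is a rough CPU-side simulation of the shuffle-based reduction, written under the assumption that ShuffleOp lowers to gpu.shuffle xor with the given offset (the updated lit test below checks for gpu.shuffle xor). The lanes list and combine function are illustrative values, not taken from the commit.

import math

def simulate_cluster_reduction(lanes, cluster_size, cluster_stride, combine):
    # Mirror of the new loop shape: log2(cluster_size) steps, stride doubling.
    values = list(lanes)
    stride = cluster_stride
    for _ in range(int(math.log2(cluster_size))):
        # Every lane combines its value with the lane at (lane ^ stride),
        # which is what an xor shuffle exchanges.
        values = [combine(v, values[i ^ stride]) for i, v in enumerate(values)]
        stride <<= 1
    return values

# Example: 8 lanes, clusters of 4 adjacent lanes (stride 1), reduced with max.
lanes = [3, 1, 4, 1, 5, 9, 2, 6]
print(simulate_cluster_reduction(lanes, cluster_size=4, cluster_stride=1, combine=max))
# -> [4, 4, 4, 4, 9, 9, 9, 9]: lanes 0-3 hold max(3, 1, 4, 1), lanes 4-7 hold max(5, 9, 2, 6).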
22 changes: 16 additions & 6 deletions iree/turbine/kernel/wave/index_sequence_analysis.py
@@ -23,6 +23,7 @@
get_hardware_constraint,
subs_idxc,
specialize_index_sequence,
capture_backward_slice,
)
import torch.fx as fx
import numpy as np
@@ -238,6 +239,11 @@ def set_node_index(
dimensions M, N and K, but allow each mapping to be piecewise conditioned
on the operand.
"""
custom = get_custom(node)
anchor = custom.anchor
if isinstance(custom, (Reduction, Placeholder)) and not isinstance(custom, IterArg):
return

hardware_constraint = [get_hardware_constraint(constraints)]
workgroup_constraints = {
c.dim: c for c in constraints if isinstance(c, WorkgroupConstraint)
@@ -251,13 +257,17 @@
index = {}
# The semantics of elements_per_thread are that it represents the number of
# elements that are loaded contiguously from memory.
custom = get_custom(node)
anchor = custom.anchor

elements_per_thread = getattr(custom, "elements_per_thread", None)

if isinstance(custom, (Reduction, Placeholder)) and not isinstance(custom, IterArg):
return
# For elementwise operations that do not have an elements per thread attribute,
# look back to the backward slice to see if they can find an appropriate value.
# TODO: Remove this once set_node_index is integrated with thread_shape_analysis.
if elements_per_thread is None:
backward_slice = capture_backward_slice(node)
for bwd_node in backward_slice:
custom_node = get_custom(bwd_node)
elements_per_thread = getattr(custom_node, "elements_per_thread", None)
if elements_per_thread:
break

for dim in custom.indexing_dims:
index_seq = None
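The index_sequence_analysis.py change above lets an elementwise op without its own elements_per_thread attribute inherit one from its backward slice. A toy sketch of that lookup follows; the Node class and backward_slice helper are hypothetical stand-ins for the real fx nodes and capture_backward_slice, purely for illustration.

from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class Node:
    # Hypothetical stand-in for an fx node wrapped by get_custom().
    name: str
    elements_per_thread: Optional[int] = None
    inputs: List["Node"] = field(default_factory=list)

def backward_slice(node):
    # All transitive producers of `node`, nearest first (illustrative only).
    seen, order, stack = set(), [], list(node.inputs)
    while stack:
        n = stack.pop(0)
        if id(n) not in seen:
            seen.add(id(n))
            order.append(n)
            stack.extend(n.inputs)
    return order

def resolve_elements_per_thread(node):
    if node.elements_per_thread is not None:
        return node.elements_per_thread
    for producer in backward_slice(node):
        if producer.elements_per_thread is not None:
            return producer.elements_per_thread
    return None

read = Node("read", elements_per_thread=4)
add = Node("add", inputs=[read, read])      # elementwise op without the attribute
print(resolve_elements_per_thread(add))     # -> 4, inherited from the read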
2 changes: 1 addition & 1 deletion lit_tests/kernel/wave/codegen.py
@@ -990,7 +990,7 @@ def repeat(
# CHECK: %[[LOOP:.+]]:2 = scf.for %[[ITER:.+]] = %[[C0_IDX]] to %[[C4_IDX]] step %[[C1_IDX]]
# CHECK-SAME: iter_args(%[[ACC0:.+]] = %{{.*}}, %[[ACC1:.+]] = {{.*}})
# CHECK-COUNT-2: vector.load{{.*}} memref<32x20xf16, #gpu.address_space<workgroup>>, vector<4xf16>
# CHECK-COUNT-6: gpu.shuffle xor
# CHECK-COUNT-2: gpu.shuffle xor
# CHECK: %[[MAX:.+]] = arith.maximumf %[[ACC0]], %{{.*}}
# CHECK: %[[MMA:.+]] = amdgpu.mfma %{{.*}} * %{{.*}} + %[[ACC1]]
# CHECK: scf.yield %[[MAX]], %[[MMA]] : vector<1xf16>, vector<4xf32>
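The smaller shuffle count in this lit test follows from the new loop bound. With a 64-wide wave the old code always emitted log2(64) = 6 shuffles; a cluster of 4 threads, which is what the remaining two gpu.shuffle ops imply, needs only log2(4) = 2. The concrete cluster size of 4 is inferred from the test, not stated in the diff.

import math

subgroup_size = 64   # threads_per_wave used by the old loop bound
cluster_size = 4     # implied by the two remaining shuffles (assumption)
print(int(math.log2(subgroup_size)))   # 6 -> old CHECK-COUNT-6
print(int(math.log2(cluster_size)))    # 2 -> new CHECK-COUNT-2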
12 changes: 6 additions & 6 deletions lit_tests/kernel/wave/expansion.py
@@ -711,12 +711,12 @@ def py_arithmetic_different_dims():
# CHECK-NEXT: placeholder(_name=c
# CHECK-NEXT: read(memory=a, elements_per_thread=4, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M, N: $T1*BLOCK_N/4 + 4*$T1 + $WG1*BLOCK_N : 4 : 1}
# CHECK-NEXT: read(memory=a, elements_per_thread=4, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/4 + 4*$T1 + $WG1*BLOCK_N : 4 : 1}
# CHECK-NEXT: add(lhs=read_0_0_0, rhs=read_0_0_0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N}
# CHECK-NEXT: add(lhs=read_1_0_0, rhs=read_1_0_0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N}
# CHECK-NEXT: sub(lhs=add_0_0_0, rhs=read_0_0_0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N}
# CHECK-NEXT: sub(lhs=add_1_0_0, rhs=read_1_0_0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N
# CHECK-NEXT: neg(arg=sub_0_0_0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N}
# CHECK-NEXT: neg(arg=sub_1_0_0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/4 + $WG1*BLOCK_N}
# CHECK-NEXT: add(lhs=read_0_0_0, rhs=read_0_0_0, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M, N: $T1*BLOCK_N/4 + 4*$T1 + $WG1*BLOCK_N : 4 : 1}
# CHECK-NEXT: add(lhs=read_1_0_0, rhs=read_1_0_0, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/4 + 4*$T1 + $WG1*BLOCK_N : 4 : 1}
# CHECK-NEXT: sub(lhs=add_0_0_0, rhs=read_0_0_0, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M, N: $T1*BLOCK_N/4 + 4*$T1 + $WG1*BLOCK_N : 4 : 1}
# CHECK-NEXT: sub(lhs=add_1_0_0, rhs=read_1_0_0, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/4 + 4*$T1 + $WG1*BLOCK_N : 4 : 1}
# CHECK-NEXT: neg(arg=sub_0_0_0, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M, N: $T1*BLOCK_N/4 + 4*$T1 + $WG1*BLOCK_N : 4 : 1}
# CHECK-NEXT: neg(arg=sub_1_0_0, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/4 + 4*$T1 + $WG1*BLOCK_N : 4 : 1}
# CHECK-NEXT: write(register_=neg_0_0_0, memory=c, elements_per_thread=4, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M, K: 4*$T2 + $WG2*BLOCK_K : 4 : 1}
# CHECK-NEXT: write(register_=neg_1_0_0, memory=c, elements_per_thread=4, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M + 16, K: 4*$T2 + $WG2*BLOCK_K + 16 : 4 : 1}
# CHECK-NEXT: write(register_=neg_1_0_0, memory=c, elements_per_thread=4, index={M: $T0*BLOCK_M/128 + $T0 + $WG0*BLOCK_M + 16, K: 4*$T2 + $WG2*BLOCK_K : 4 : 1}
