[TKW] Fix types, shapes and propagate resolved indexing #177

Merged
merged 3 commits on Oct 17, 2024
55 changes: 53 additions & 2 deletions iree/turbine/kernel/ops/wave_ops.py
@@ -45,6 +45,13 @@ def allocate(
...


def extract(
register: "Register",
offsets: tuple[IndexExpr],
) -> "Register":
...


def extract_slice(
register: "Register",
offsets: tuple[IndexExpr],
@@ -1057,6 +1064,46 @@ def index(self, value: dict[IndexSymbol, IndexSequence]):
CustomOp.index.fset(self, value)


@define_op("extract")
@dataclass
class Extract(CustomOp):
"""
Op Rationale:

Extract is an op used to represent extracting a scalar
from TKW's 1-D vector at the specified index.

This can also be viewed as indexing/slicing on the fastest
dimension. Hence, the semantics of this op treat it as a
reduction on the indexed/fastest dimension.
"""

register_: fx.Proxy
offset: IndexExpr | int

@property
def type(self) -> "Register":
# Intuition: we are extracting an element from the fastest dim,
# so we reduce (remove) the fastest dim.
src_type = get_custom(self.register_).type
# Return itself if just 0-D/1-D symbolic.
if len(src_type.symbolic_shape) <= 1:
return src_type

# The fastest dim is typically the last dimension; if a non-unit
# dim exists, that non-unit dim is the fastest dim.
non_unit_dim = [k for k, v in self.register_.index.items() if v.size != 1]
if len(non_unit_dim) > 1:
raise NotImplementedError(
f"NYI: Extract only support 1 non-unit dim, but found: {len(non_unit_dim)}"
)
dst_shape = list(src_type.symbolic_shape)
dim_to_remove = dst_shape[-1] if not non_unit_dim else non_unit_dim[0]
dst_shape.remove(dim_to_remove)
dst_type = Register[*dst_shape, src_type.dtype]
return dst_type
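For intuition, a minimal self-contained sketch of the shape rule that Extract.type implements (hypothetical helper and toy inputs, not the PR's API): the fastest, i.e. non-unit, dim is dropped from the symbolic shape.

# Standalone sketch of Extract's result-shape rule (hypothetical helper).
def extract_result_shape(symbolic_shape: list[str], index_sizes: dict[str, int]) -> list[str]:
    # 0-D/1-D shapes are returned unchanged.
    if len(symbolic_shape) <= 1:
        return list(symbolic_shape)
    non_unit = [dim for dim, size in index_sizes.items() if size != 1]
    if len(non_unit) > 1:
        raise NotImplementedError(f"NYI: only 1 non-unit dim supported, found {len(non_unit)}")
    # Drop the fastest dim: the non-unit dim if present, otherwise the last dim.
    dim_to_remove = non_unit[0] if non_unit else symbolic_shape[-1]
    result = list(symbolic_shape)
    result.remove(dim_to_remove)
    return result

# Example: an (M, N) register whose index has size 2 along N reduces to (M,).
assert extract_result_shape(["M", "N"], {"M": 1, "N": 2}) == ["M"]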


@define_op("extract_slice")
@dataclass
class ExtractSlice(CustomOp):
@@ -1116,12 +1163,16 @@ class ReduceOp(CustomOp, ABC):

@property
def indexing_dims(self) -> list[IndexSymbol]:
return get_custom(self.arg).indexing_dims
src_indexing = get_custom(self.arg).indexing_dims
dst_indexing = [dim for dim in src_indexing if dim != self.dim]
return dst_indexing

@property
def type(self) -> Memory:
src_type = get_custom(self.arg).type
return src_type
reduced_dims = [dims for dims in src_type.symbolic_shape if dims != self.dim]
dst_type = Register[*reduced_dims, src_type.dtype]
return dst_type

@property
def num_reduction_dims(self) -> int:
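The ReduceOp change above makes indexing_dims and type drop the reduction dim instead of passing the source through unchanged. A minimal sketch of that rule, with symbolic shapes modeled as plain lists (hypothetical helper, not the PR's API):

# Standalone sketch of the ReduceOp result-shape rule (hypothetical helper).
def reduce_result_shape(src_shape: list[str], reduction_dim: str) -> list[str]:
    # The reduced dim disappears from the result's symbolic shape and indexing dims.
    return [dim for dim in src_shape if dim != reduction_dim]

# Example: summing an (M, N) register over N yields an (M,) register.
assert reduce_result_shape(["M", "N"], "N") == ["M"]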
21 changes: 21 additions & 0 deletions iree/turbine/kernel/wave/codegen.py
@@ -59,6 +59,7 @@
get_result,
allocate,
shared_memory_barrier,
extract,
extract_slice,
CustomOp,
scheduling_barrier,
@@ -1123,6 +1124,26 @@ def handle_scheduling_group_barrier(emitter: WaveEmitter, node: fx.Node):
###############################################################################


@handle_op(extract)
def handle_extract(emitter: WaveEmitter, node: fx.Node):
try:
register, offset = node.args
except ValueError as e:
raise ValidationError("Malformed arguments") from e
assert isinstance(offset, list) and len(offset) == 1
extract_vector = cast_vector(emitter, register)
result_type = VectorType.get([1], extract_vector.type.element_type)
element = vector_d.extract_strided_slice(
result_type,
extract_vector,
offset,
[1],
[1],
)

emitter.bind_node_proxy(node, IRProxyValue(element))


@handle_op(extract_slice)
def handle_extract_slice(emitter: WaveEmitter, node: fx.Node):
try:
8 changes: 4 additions & 4 deletions iree/turbine/kernel/wave/decompose_reduce_ops.py
@@ -19,7 +19,7 @@
ReduceOp,
ShuffleOp,
CustomOp,
ExtractSlice,
Extract,
Reduction,
)

@@ -40,9 +40,9 @@ def get_graph_node(custom: CustomOp, graph: fx.Graph):
def emit_local_reduction(
binary_fn: Callable, src: fx.Node, graph: fx.Graph, local_reduction_size: int
) -> fx.Node:
init = get_graph_node(ExtractSlice(src, [0], [1], [1]), graph)
init = get_graph_node(Extract(src, [0]), graph)
for i in range(1, local_reduction_size):
cur_slice = get_graph_node(ExtractSlice(src, [i], [1], [1]), graph)
cur_slice = get_graph_node(Extract(src, [i]), graph)
init = get_graph_node(binary_fn(init, cur_slice), graph)
return init
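As a plain-Python analogue (toy scalars rather than fx nodes, not the PR's API), the Extract-based decomposition above unrolls the local reduction into a left-to-right chain of binary_fn applications:

import operator

# Plain-Python analogue of the Extract-based local reduction above.
def local_reduce(binary_fn, src: list[float], local_reduction_size: int) -> float:
    init = src[0]                       # Extract(src, [0])
    for i in range(1, local_reduction_size):
        init = binary_fn(init, src[i])  # binary_fn(init, Extract(src, [i]))
    return init

# Example: a local add-reduction over 4 elements.
assert local_reduce(operator.add, [1.0, 2.0, 3.0, 4.0], 4) == 10.0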

@@ -100,7 +100,7 @@ def decompose_reduce_ops(
)

# Local Reduce
if reduction_dim is not custom.type.symbolic_shape[-1]:
if reduction_dim is not get_custom(custom.arg).type.symbolic_shape[-1]:
raise NotImplementedError(
"Only implemented reduction on fastest dimension."
)
56 changes: 38 additions & 18 deletions iree/turbine/kernel/wave/thread_shape_analysis.py
@@ -13,6 +13,10 @@

logger = get_logger("turbine.wave.thread_shape_analysis")

#################################################################
# Index/Symbol and Thread size helper fn and data structure
#################################################################


@dataclass(order=True)
class DimSize:
@@ -43,6 +47,23 @@ def set_index_size(custom: CustomOp, target_dim_sizes: list[DimSize]):
custom.index[target.dim].size = target.size


#################################################################
# Anchor Indices and Conflict resolution helpers
#################################################################

anchorOpTypes = (Read, Write, MMA, ReduceOp)
noHandleTypes = (Placeholder, Output, ExtractSlice, Allocate)
nonPropagatableTypes = anchorOpTypes + noHandleTypes


def is_anchor_op(node: fx.Node):
return isinstance(get_custom(node), anchorOpTypes)


def propagatable_op(node: fx.Node):
return not isinstance(get_custom(node), nonPropagatableTypes)


def handle_binaryop_conflict(custom_node: CustomOp):
# Analyze if we can resolve conflict with broadcast.
lhs = get_custom(custom_node.lhs)
@@ -59,21 +80,32 @@ def handle_binaryop_conflict(custom_node: CustomOp):
broadcast = Broadcast(broadcast_src.fx_node, dst_op.type)
with custom_node.graph.inserting_before(custom_node.fx_node):
broadcast.add_to_graph(custom_node.graph)
setattr(broadcast.fx_node, "index", dst_op.index)
custom_node.index = dst_op.index
custom_node.update_arg(broadcast_idx, broadcast.fx_node)
return True
propagated_resolutions = capture_forward_slice(broadcast.fx_node, propagatable_op)
for node in propagated_resolutions:
get_custom(node).index = dst_op.index
return propagated_resolutions


# Returns True iff all conflicts are handled successfully.
def handle_conflicts(conflicted_ops: set[CustomOp]):
cumulative_resolved = set()
for conflict in conflicted_ops:
custom = get_custom(conflict)
if isinstance(custom, BinaryPyOp):
handle_binaryop_conflict(custom)
resolved_ops = handle_binaryop_conflict(custom)
cumulative_resolved = cumulative_resolved.union(resolved_ops)
else:
return False
return True
continue
# Use issuperset because the cumulative resolved set also includes resolution
# helper ops such as broadcast.
all_conflicts_resolved = cumulative_resolved.issuperset(conflicted_ops)
return all_conflicts_resolved
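A small standalone illustration (toy string values, not the PR's node types) of why handle_conflicts uses issuperset rather than equality: the propagated forward slice also contains the broadcast helper inserted during resolution.

# Ops that had conflicting thread shapes.
conflicted_ops = {"sub", "exp2"}
# Resolutions returned by the conflict handler: the conflicted ops plus the
# broadcast op inserted to resolve them.
cumulative_resolved = {"broadcast", "sub", "exp2"}

# Equality would fail because of the extra helper op; a superset check succeeds.
assert cumulative_resolved != conflicted_ops
assert cumulative_resolved.issuperset(conflicted_ops)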


###############################################################################
# Main pass
###############################################################################


def determine_thread_shapes(trace: CapturedTrace):
@@ -118,18 +150,6 @@ def determine_thread_shapes(trace: CapturedTrace):
thread_sizes_to_ops[frozenset({IndexSize(index=K, size=4), IndexSize(index=N, size=1)})] = set(rhs, ...)

"""

# Anchor ops are ops whose thread shapes are predetermined.
anchorOpTypes = (Read, Write, MMA, ReduceOp)
noHandleTypes = (Placeholder, Output, ExtractSlice, Allocate)
nonPropagatableTypes = anchorOpTypes + noHandleTypes

def is_anchor_op(node: fx.Node):
return isinstance(get_custom(node), anchorOpTypes)

def propagatable_op(node: fx.Node):
return not isinstance(get_custom(node), nonPropagatableTypes)

anchor_ops = trace.walk(is_anchor_op)
thread_size_to_ops: dict[frozenset[DimSize], set[CustomOp]] = {}
for anchor_op in anchor_ops:
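A toy sketch of the bucketing structure the docstring above describes, with stand-in types and op names (hypothetical, not the PR's API): anchor ops are grouped by the frozenset of their per-dim thread sizes so each bucket can be assigned one resolved thread shape.

from dataclasses import dataclass

# Stand-in for DimSize (hypothetical).
@dataclass(frozen=True)
class DimSizeSketch:
    index: str
    size: int

thread_sizes_to_ops: dict[frozenset, set] = {}
key = frozenset({DimSizeSketch("K", 4), DimSizeSketch("N", 1)})
thread_sizes_to_ops.setdefault(key, set()).update({"rhs_read", "mma"})
assert thread_sizes_to_ops[key] == {"rhs_read", "mma"}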
76 changes: 76 additions & 0 deletions lit_tests/kernel/wave/codegen.py
@@ -1438,6 +1438,82 @@ def repeat(
# CHECK: scf.yield %[[ACC_MAX_0]], %[[ACC_SUM_0]], %[[ACC_MAX_1]], %[[ACC_SUM_1]]


# This test is used to ensure:
# 1. ReduceOp has correct symbolic shape for thread shape analysis.
# 2. We can propagate the resolved indexing from broadcast (in this case, from sub to exp2).
@run_test
def test_reduce_propagate_broadcast():
constraints: list[tkw.Constraint] = [
tkw.HardwareConstraint(
threads_per_wave=64,
waves_per_block=(1, 1, 1),
vector_shapes={M: 1, N: BLOCK_N},
)
]
constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)]
constraints += [tkw.WorkgroupConstraint(N, N, 0)]
constraints += [tkw.TilingConstraint(N, BLOCK_N)]
constraints += [tkw.WaveConstraint(M, BLOCK_M)]
constraints += [tkw.WaveConstraint(N, BLOCK_N)]

@tkw.wave(constraints)
def test(
a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32],
c: tkl.Memory[M, ADDRESS_SPACE, tkl.f32],
):
init_max = tkl.Register[M, tkl.f32](-1e6)
init_sum = tkl.Register[M, tkl.f32](0)

@tkw.reduction(N, init_args=[init_max, init_sum])
def repeat(
partial_max: tkl.Register[M, tkl.f32],
partial_sum: tkl.Register[M, tkl.f32],
) -> tkl.Register[M, tkl.f32]:
src = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD)
m_src = tkw.max(src, partial_max, dim=N)
exp_d = tkw.exp2(src - m_src)
sum_d = tkw.sum(exp_d, partial_sum, dim=N)
return m_src, sum_d

res_max, res_sum = repeat
tkw.write(res_sum, c, elements_per_thread=1)

config = {"backend": "rocm", "device": "hip", "target": "gfx942"}

shape = (256, 1024)
a = torch.randn(shape, dtype=torch.float32)
c = torch.zeros((shape[0],), dtype=torch.float32)
with tk.gen.TestLaunchContext(
{
M: shape[0],
N: shape[1],
BLOCK_M: 1,
BLOCK_N: 128,
LOAD_ELEMS_PER_THREAD: 2,
ADDRESS_SPACE: tkl.AddressSpace.GLOBAL_MEMORY.value,
},
canonicalize=True,
run=False,
run_config=config,
):
print(test(a, c).module_op)
# CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
# CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
# CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
# CHECK: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<1xf32>
# CHECK: scf.for %{{.*}} = %[[C0]] to %[[C8]] step %[[C1]]
# CHECK-COUNT-7: arith.maximumf
# CHECK: %[[ACC_MAX:.+]] = arith.maximumf
# CHECK: %[[EXTRACT:.+]] = vector.extract %[[ACC_MAX]][0] : f32 from vector<1xf32>
# CHECK: %[[BROADCAST:.+]] = vector.splat %[[EXTRACT]] : vector<2xf32>
# CHECK: %[[SUBF:.+]] = arith.subf %{{.+}}, %[[BROADCAST]] : vector<2xf32>
# CHECK: %[[EXP2:.+]] = math.exp2 %[[SUBF]] : vector<2xf32>
# CHECK: %[[EXP2_SLICE_0:.+]] = vector.extract_strided_slice %[[EXP2]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf32> to vector<1xf32>
# CHECK: %[[EXP2_SLICE_1:.+]] = vector.extract_strided_slice %[[EXP2]] {offsets = [1], sizes = [1], strides = [1]} : vector<2xf32> to vector<1xf32>
# CHECK: arith.addf %[[EXP2_SLICE_0]], %[[EXP2_SLICE_1]]
# CHECK-COUNT-6: gpu.shuffle xor


@run_test
def test_broadcast_add():
constraints: list[tkw.Constraint] = [