Add support for varying vector shapes (#247)
harsh-nod authored Oct 31, 2024
1 parent 2b45c0f commit 8febe6a
Showing 9 changed files with 819 additions and 7 deletions.
29 changes: 29 additions & 0 deletions iree/turbine/kernel/ops/wave_ops.py
@@ -138,6 +138,10 @@ def permute(src: "Register", target_shape: Sequence[IndexExpr]) -> "Register":
...


def reshape(inputs: Sequence["Register"]) -> "Register":
...


def define_op(op_name: str) -> Callable[[T], T]:
def decorator(cls: T) -> T:
cls.tkw_op_name = op_name
@@ -1400,3 +1404,28 @@ def type(self) -> Register:
self.target_shape
), f"Target shape {self.target_shape} must be a permutation of source shape {src_type.symbolic_shape}"
return Register[*self.target_shape, src_type.dtype]


def _to_sequence(input: Any | Sequence[Any]) -> Sequence[Any]:
return input if isinstance(input, Sequence) else (input,)


@define_op("reshape")
@dataclass
class Reshape(CustomOp, ABC):
"""
Represents a reshape operation that reshapes
vectors along the same dimension.
"""

args: fx.Node | Sequence[fx.Node]
target_vector_shape: dict[IndexSymbol, int]

@property
def indexing_dims(self) -> list[IndexExpr]:
return get_custom(_to_sequence(self.args)[0]).indexing_dims

@property
def type(self) -> Register:
return get_custom(_to_sequence(self.args)[0]).type
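As a minimal sketch of how a Reshape node is created on the compiler side (mirroring the insertion logic added to wave/utils.py later in this diff; the wrap_in_reshape helper is illustrative, and the import path assumes the package layout matches the file paths shown here):

from iree.turbine.kernel.ops.wave_ops import Reshape, get_custom

def wrap_in_reshape(arg, producer_vector_shapes, consumer_vector_shapes):
    # `arg` is a CustomOp whose vectors were produced with producer_vector_shapes;
    # the Reshape re-partitions them so they match the consuming op's vector shapes.
    node = Reshape(arg.fx_node, producer_vector_shapes).add_to_graph(arg.graph)
    custom = get_custom(node)
    custom.vector_shapes = consumer_vector_shapes
    return node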
53 changes: 53 additions & 0 deletions iree/turbine/kernel/wave/codegen.py
@@ -67,6 +67,7 @@
scheduling_group_barrier,
cast,
permute,
reshape,
)
from ..lang.wave_types import IndexMapping, IndexSymbol
from ..compiler.base import CodegenError, ValidationError, NDEBUG
@@ -1310,3 +1311,55 @@ def handle_permute(emitter: WaveEmitter, node: fx.Node):
raise ValidationError("Malformed arguments") from e
vector_src = cast_py_value(emitter, register)
emitter.bind_node_proxy(node, vector_src)


@handle_op(reshape)
def handle_reshape(emitter: WaveEmitter, node: fx.Node):
try:
args, target_vector_shapes = node.args
except ValueError as e:
raise ValidationError("Malformed arguments") from e
custom = get_custom(node)
innermost_dim = custom.type.symbolic_shape[-1]
offset = custom.expanded_dims[innermost_dim]

# Determine whether to extract or combine.
if len(args) > 1:
concatenated = None
for i, sub_arg in enumerate(args):
vector = cast_vector(emitter, sub_arg)
shape = vector.type.shape[0]
if concatenated is None:
element_type = vector.type.element_type
vector_type = VectorType.get([shape * len(args)], element_type)
concatenated = arith_d.ConstantOp(
vector_type,
DenseElementsAttr.get_splat(
vector_type, get_constant_attr(0, element_type)
),
).result
concatenated = vector_d.insert_strided_slice(
vector, concatenated, [i * shape], [1]
)
emitter.bind_node_proxy(node, IRProxyValue(concatenated))
return

# Extract the appropriate slice. The offset is obtained from expanded_dims
# and so corresponds to the dim_query during expansion. To obtain the
# actual offset, we need to multiply by the size. The size is obtained by
# computing the number of partitions using the source and target vector shapes
# and dividing the incoming vector shape by the number of partitions.
num_partitions = (
target_vector_shapes[innermost_dim] // custom.vector_shapes[innermost_dim]
)
vector = cast_vector(emitter, args[0])
size = vector.type.shape[0] // num_partitions
result_type = VectorType.get([size], vector.type.element_type)
slice = vector_d.extract_strided_slice(
result_type,
vector,
[offset * size],
[size],
[1],
)
emitter.bind_node_proxy(node, IRProxyValue(slice))
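As a rough plain-Python illustration of the slicing arithmetic in handle_reshape above (the values are made up; the real handler builds MLIR vector ops):

# Splitting: a producer vector built with a larger vector shape is divided into
# num_partitions slices, and expanded_dims picks which slice to extract.
source_vector = list(range(8))                  # incoming per-thread vector of length 8
target_shape, current_shape = 8, 4              # producer vs. this node's innermost vector shape
offset = 1                                      # expanded_dims value for the innermost dim

num_partitions = target_shape // current_shape  # 2
size = len(source_vector) // num_partitions     # 4
print(source_vector[offset * size:(offset + 1) * size])  # [4, 5, 6, 7]

# Combining goes the other way: several smaller vectors are inserted at
# strided offsets into one larger vector.
parts = [[0, 1, 2, 3], [4, 5, 6, 7]]
print([x for part in parts for x in part])      # [0, 1, 2, 3, 4, 5, 6, 7]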
92 changes: 92 additions & 0 deletions iree/turbine/kernel/wave/expansion.py
@@ -250,6 +250,20 @@ def _expand_node(
new_node.expanded_dims = restricted_dims
new_node.fx_node.name = get_expanded_name(node, restricted_dims)

# For reshapes, we need more explicit control over how the arguments are expanded.
if isinstance(new_node, Reshape):
_expand_reshape(
new_node,
trace,
dim_query,
dim_scaling,
context,
get_node_dim_scaling,
res_idx,
)
context[(node, get_indexed_dims(restricted_dims, node), res_idx)] = new_node
return new_node

# Proceed with expansion of the arguments
for i, arg in node.node_args.items():
arg_list = arg
@@ -496,6 +510,84 @@ def _expand_mma_reduction(
return new_node


def _expand_reshape(
reshape: Reshape,
trace: CapturedTrace,
dim_query: dict[IndexSymbol, int],
dim_scaling: dict[IndexSymbol, int],
context: ExpandedNodeMap,
get_node_dim_scaling: Callable[[fx.Node], dict[IndexSymbol, int]],
res_idx: int,
) -> CustomOp:
"""
When expanding a reshape, we have to expand the arguments of the reshape and then concatenate them together
for the expanded node. Say we have a node with indexing dims = [M, N], vector shapes m=8, n=2, and
a reshape that wants to map it to m=4, n=4. We start by expanding the node:
node: {m = 0, n = 0}
arg: {m = 0, n = 0}
arg: {m = 0, n = 1}
node: {m = 1, n = 0}
arg: {m = 0, n = 0}
arg: {m = 0, n = 1}
node: {m = 2, n = 0}
arg: {m = 1, n = 0}
arg: {m = 1, n = 1}
node: {m = 3, n = 0}
arg: {m = 1, n = 0}
arg: {m = 1, n = 1}
...
In general, for the (m = i, n = j) expansion of the reshape node, we expand the arguments of the reshape node
using the following recipe:
- if m_src < m_dst => we have a one-to-many mapping from source to destination,
so we expand the arguments along m = i // (m_dst / m_src) and expand the argument only once.
- if m_src > m_dst => we have a many-to-one mapping from source to destination,
so we expand the arguments along m = i * (m_src / m_dst), ... and expand the argument m_src / m_dst times.
In situations where the argument has been expanded along the same dimension, we reuse the expanded node
by making use of the context.
"""

dim_combinations = {}
for dim, value in dim_query.items():
if dim not in reshape.target_vector_shape:
continue
if reshape.vector_shapes[dim] < reshape.target_vector_shape[dim]:
scale_factor = (
reshape.target_vector_shape[dim] // reshape.vector_shapes[dim]
)
dim_combinations[dim] = [value // scale_factor]
else:
scale_factor = (
reshape.vector_shapes[dim] // reshape.target_vector_shape[dim]
)
begin = value * scale_factor
dim_combinations[dim] = list(range(begin, begin + scale_factor))
reshape_dim_combinations = list(itertools.product(*dim_combinations.values()))

new_args = []
for i, arg_dim_query in enumerate(reshape_dim_combinations):
arg_dim_query = {
dim: val for dim, val in zip(dim_combinations.keys(), arg_dim_query)
}
if isinstance(reshape.args, Sequence):
custom_arg = get_custom(reshape.args[i])
else:
custom_arg = get_custom(reshape.args)
new_node = _expand_node(
custom_arg,
trace,
arg_dim_query,
get_node_dim_scaling(custom_arg.fx_node),
context,
get_node_dim_scaling,
res_idx,
)
new_args.append(new_node.fx_node)

reshape.update_arg("args", new_args)


def get_expanded_name(node: CustomOp, dims: dict[IndexSymbol, int]) -> str:
"""Returns the name of a node with the dimensions appended."""

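A plain-Python sketch of the recipe in the _expand_reshape docstring above, using its example (the reshape node uses m=4, n=4 while its arguments were produced with m=8, n=2); the dictionary values are illustrative:

import itertools

vector_shapes = {"M": 4, "N": 4}        # the reshape node's own vector shapes
target_vector_shape = {"M": 8, "N": 2}  # vector shapes of the source arguments
dim_query = {"M": 2, "N": 0}            # the (m=2, n=0) expansion of the reshape

dim_combinations = {}
for dim, value in dim_query.items():
    if vector_shapes[dim] < target_vector_shape[dim]:
        scale = target_vector_shape[dim] // vector_shapes[dim]
        dim_combinations[dim] = [value // scale]                  # expand the argument once
    else:
        scale = vector_shapes[dim] // target_vector_shape[dim]
        dim_combinations[dim] = list(range(value * scale, (value + 1) * scale))

arg_queries = [dict(zip(dim_combinations, combo))
               for combo in itertools.product(*dim_combinations.values())]
print(arg_queries)  # [{'M': 1, 'N': 0}, {'M': 1, 'N': 1}], matching the docstring example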
7 changes: 5 additions & 2 deletions iree/turbine/kernel/wave/index_sequence_analysis.py
@@ -5,13 +5,16 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from ..ops.wave_ops import (
Allocate,
Write,
ExtractSlice,
get_custom,
Reduction,
MMA,
Placeholder,
IterArg,
CustomOp,
Reshape,
)
from .constraints import Constraint, HardwareConstraint, WorkgroupConstraint
from .._support.tracing import CapturedTrace, IndexingContext
@@ -193,8 +196,8 @@ def set_vector_shapes(
an MMA slice as well as the anchor node.
"""
custom = get_custom(node)
# MMA & Reduction nodes already have their vector shapes set.
if isinstance(custom, (MMA, Reduction)):
# MMA, Reduction & Reshape nodes already have their vector shapes set.
if isinstance(custom, (MMA, Reduction, Reshape)):
return
# Add vector shapes from constraints to all ops. These are global constraints.
custom.vector_shapes = {}
44 changes: 44 additions & 0 deletions iree/turbine/kernel/wave/iree_utils.py
@@ -41,6 +41,44 @@ def get_chain_mmt_asm(
}}"""


def get_chain_mmt_f8_asm(
query_type: str, key_type: str, value_type: str, output_type: str
) -> str:
B, M, K1, input_dtype = query_type.split("x")
B, K2, K1, input_dtype = key_type.split("x")
B, N, K2, input_dtype = value_type.split("x")
B, N, M, output_dtype = output_type.split("x")
f8_dtype = "f8E4M3FNUZ"
intermediate_output_type = f"{B}x{K2}x{M}x{output_dtype}"
intermediate_cast_type = f"{B}x{K2}x{M}x{f8_dtype}"
transposed_cast_type = f"{B}x{M}x{K2}x{f8_dtype}"
transposed_output_type = f"{B}x{M}x{N}x{output_dtype}"
query_f8_type = "x".join([B, M, K1, f8_dtype])
key_f8_type = "x".join([B, K2, K1, f8_dtype])
value_f8_type = "x".join([B, N, K2, f8_dtype])
return f"""
func.func @chain_mmt_f8(%query: tensor<{query_type}>, %key: tensor<{key_type}>, %value: tensor<{value_type}>) -> tensor<{output_type}> {{
%c0 = arith.constant 0.0 : f32
%init = tensor.empty() : tensor<{intermediate_output_type}>
%query_f8 = arith.truncf %query : tensor<{query_type}> to tensor<{query_f8_type}>
%key_f8 = arith.truncf %key : tensor<{key_type}> to tensor<{key_f8_type}>
%inital_result = linalg.fill ins(%c0 : f32) outs(%init : tensor<{intermediate_output_type}>) -> tensor<{intermediate_output_type}>
%result = linalg.batch_matmul_transpose_b ins(%key_f8, %query_f8 : tensor<{key_f8_type}>, tensor<{query_f8_type}>)
outs(%inital_result : tensor<{intermediate_output_type}>) -> tensor<{intermediate_output_type}>
%trunc = arith.truncf %result : tensor<{intermediate_output_type}> to tensor<{intermediate_cast_type}>
%init2 = tensor.empty() : tensor<{transposed_cast_type}>
%transpose = linalg.transpose ins(%trunc: tensor<{intermediate_cast_type}>) outs(%init2: tensor<{transposed_cast_type}>) permutation=[0, 2, 1]
%init3 = tensor.empty() : tensor<{transposed_output_type}>
%inital_result3 = linalg.fill ins(%c0 : f32) outs(%init3 : tensor<{transposed_output_type}>) -> tensor<{transposed_output_type}>
%value_f8 = arith.truncf %value : tensor<{value_type}> to tensor<{value_f8_type}>
%result2 = linalg.batch_matmul_transpose_b ins(%transpose, %value_f8: tensor<{transposed_cast_type}>, tensor<{value_f8_type}>)
outs(%inital_result3 : tensor<{transposed_output_type}>) -> tensor<{transposed_output_type}>
%init4 = tensor.empty() : tensor<{output_type}>
%transpose2 = linalg.transpose ins(%result2: tensor<{transposed_output_type}>) outs(%init4: tensor<{output_type}>) permutation=[0, 2, 1]
return %transpose2 : tensor<{output_type}>
}}"""


def get_mmt_asm(
lhs_type: str,
rhs_type: str,
@@ -141,6 +179,12 @@ def generate_iree_ref(
value_type = get_type_str(kernel_inputs[2].shape, kernel_inputs[2].dtype)
output_type = get_type_str(kernel_outputs[0].shape, kernel_outputs[0].dtype)
asm = get_chain_mmt_asm(query_type, key_type, value_type, output_type)
elif kernel_type == "chain_mmt_f8":
query_type = get_type_str(kernel_inputs[0].shape, kernel_inputs[0].dtype)
key_type = get_type_str(kernel_inputs[1].shape, kernel_inputs[1].dtype)
value_type = get_type_str(kernel_inputs[2].shape, kernel_inputs[2].dtype)
output_type = get_type_str(kernel_outputs[0].shape, kernel_outputs[0].dtype)
asm = get_chain_mmt_f8_asm(query_type, key_type, value_type, output_type)
elif kernel_type.startswith(conv_str):
lhs_type = get_type_str(kernel_inputs[0].shape, kernel_inputs[0].dtype)
rhs_type = get_type_str(kernel_inputs[1].shape, kernel_inputs[1].dtype)
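For readers less familiar with the linalg named ops, a rough numpy analogue of what the chain_mmt_f8 reference above computes (float32 stands in for f8E4M3FNUZ since numpy has no f8 dtype; the dimension sizes are made up):

import numpy as np

# Two chained batched matmuls with transposed-B semantics,
# as in linalg.batch_matmul_transpose_b.
B, M, N, K1, K2 = 2, 16, 16, 32, 32
query = np.random.rand(B, M, K1).astype(np.float32)
key = np.random.rand(B, K2, K1).astype(np.float32)
value = np.random.rand(B, N, K2).astype(np.float32)

first = np.einsum("bkc,bmc->bkm", key, query)                         # (B, K2, M)
second = np.einsum("bmk,bnk->bmn", first.transpose(0, 2, 1), value)   # (B, M, N)
output = second.transpose(0, 2, 1)                                    # (B, N, M), the output_type layout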
59 changes: 58 additions & 1 deletion iree/turbine/kernel/wave/utils.py
@@ -12,7 +12,7 @@
UnitAttr,
Value,
)
from typing import Optional, Callable, Any, List, Tuple
from typing import Optional, Callable, Any, List, Tuple, Sequence
from .._support.tracing import CapturedTrace
from .._support.indexing import IndexExpr, IndexingContext, IndexSymbol, IndexSequence
from ..lang.global_symbols import *
@@ -25,6 +25,7 @@
Reduction,
GetResult,
IterArg,
Reshape,
)
from .constraints import (
Constraint,
@@ -192,6 +193,20 @@ def simplify_index(index: IndexExpr) -> IndexExpr:
return subs_idxc(index.subs(mapping))


def is_reshape_needed(
node: CustomOp,
node_vector_shapes: dict[IndexSymbol, int],
vector_shapes: dict[IndexSymbol, int],
) -> bool:
for dim in node.type.symbolic_shape:
if dim not in vector_shapes:
# Ignore nodes that are not used in both mmas.
return False
if node_vector_shapes[dim] != vector_shapes[dim]:
return True
return False


def get_mma_dimensional_mapping(
trace: CapturedTrace,
hardware_constraint: HardwareConstraint,
@@ -243,6 +258,48 @@ def is_mma(node):
reduction.anchor = custom

mma_slices = {get_custom(x): capture_mma_slices(get_custom(x)) for x in mma_nodes}

# Determine if any reshapes are required. Reshapes are added for
# chained matmuls when the vector shapes of the operands in one matmul
# differ from those in another matmul. The mma_slices contain all the ops
# in the backward slice of the lhs and rhs up to a previous mma (if one exists).
# So we check the previous node of the first operator in the slice to see
# if it is an MMA and, if so, check whether a reshape is required.
def add_reshape_if_needed(mma: MMA, prev_mma: MMA):
with mma.graph.inserting_before(mma.fx_node):
for i, arg in mma.node_args.items():
if is_reshape_needed(arg, mma.vector_shapes, prev_mma.vector_shapes):
reshape = Reshape(arg.fx_node, prev_mma.vector_shapes).add_to_graph(
custom.graph
)
custom_reshape = get_custom(reshape)
custom_reshape.vector_shapes = custom.vector_shapes
custom_reshape.anchor = custom
custom.update_arg(i, reshape)

def find_mma_in_slice(node: CustomOp) -> Optional[MMA]:
"""
Find the closest mma by iterating through the backward slice of a node
in reverse.
"""
slice = list(capture_backward_slice(node))
for arg in reversed(slice):
prev_mma = get_custom(arg)
if isinstance(prev_mma, MMA):
return prev_mma
return None

# Look in the backward slices of both the LHS and RHS to find
# mmas. If found, add reshapes if necessary.
for mma in mma_nodes:
custom_mma = get_custom(mma)
prev_mma = find_mma_in_slice(custom_mma.lhs)
if prev_mma:
add_reshape_if_needed(custom_mma, prev_mma)
prev_mma = find_mma_in_slice(custom_mma.rhs)
if prev_mma:
add_reshape_if_needed(custom_mma, prev_mma)

return mapping, mma_slices


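A plain-Python sketch of the reshape-insertion decision above; the function mirrors is_reshape_needed, and the dimension names and vector-shape values are made up:

def reshape_needed(symbolic_shape, node_vector_shapes, anchor_vector_shapes):
    # An operand needs a reshape when a dimension it shares with the anchor MMA
    # has a different per-thread vector shape.
    for dim in symbolic_shape:
        if dim not in anchor_vector_shapes:
            return False                      # dim not used in both mmas
        if node_vector_shapes[dim] != anchor_vector_shapes[dim]:
            return True
    return False

producer_mma = {"M": 16, "K2": 8}   # vector shapes of the first matmul (assumed)
consumer_mma = {"M": 16, "K2": 4}   # vector shapes of the chained matmul (assumed)
print(reshape_needed(["M", "K2"], consumer_mma, producer_mma))  # True -> insert a Reshape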