Add support for varying vector shapes #247

Merged · 5 commits · Oct 31, 2024
Changes from 3 commits
29 changes: 29 additions & 0 deletions iree/turbine/kernel/ops/wave_ops.py
@@ -138,6 +138,10 @@ def permute(src: "Register", target_shape: Sequence[IndexExpr]) -> "Register":
...


def reshape(inputs: Sequence["Register"]) -> "Register":
...


def define_op(op_name: str) -> Callable[[T], T]:
def decorator(cls: T) -> T:
cls.tkw_op_name = op_name
@@ -1400,3 +1404,28 @@ def type(self) -> Register:
self.target_shape
), f"Target shape {self.target_shape} must be a permutation of source shape {src_type.symbolic_shape}"
return Register[*self.target_shape, src_type.dtype]


def _to_sequence(input: Any | Sequence[Any]) -> Sequence[Any]:
return input if isinstance(input, Sequence) else (input,)


@define_op("reshape")
@dataclass
class Reshape(CustomOp, ABC):
"""
Represents a reshape operation that reshapes
vectors along the same dimension.

"""

args: fx.Node | Sequence[fx.Node]
target_vector_shape: dict[IndexSymbol, int]

@property
def indexing_dims(self) -> list[IndexExpr]:
return get_custom(_to_sequence(self.args)[0]).indexing_dims

@property
def type(self) -> Register:
return get_custom(_to_sequence(self.args)[0]).type
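The op itself only records its arguments and the target vector shape; the actual combining or slicing happens later in codegen (see handle_reshape below). As a rough, plain-Python illustration of the two roles a reshape can play, with made-up values and no dependence on the wave API:

# Standalone sketch: an element-level view of what a reshape resolves to.
# Either several per-expansion vectors are combined into one, or a slice of a
# larger vector is extracted. This mirrors insert_strided_slice /
# extract_strided_slice only conceptually, not the real MLIR emission.
def combine(parts: list[list[int]]) -> list[int]:
    out: list[int] = []
    for part in parts:
        out.extend(part)          # append each small vector end to end
    return out

def extract(vec: list[int], offset: int, size: int) -> list[int]:
    return vec[offset * size:(offset + 1) * size]   # take one contiguous slice

print(combine([[0, 1], [2, 3]]))                 # [0, 1, 2, 3]
print(extract([0, 1, 2, 3, 4, 5, 6, 7], 1, 4))   # [4, 5, 6, 7]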
53 changes: 53 additions & 0 deletions iree/turbine/kernel/wave/codegen.py
@@ -67,6 +67,7 @@
scheduling_group_barrier,
cast,
permute,
reshape,
)
from ..lang.wave_types import IndexMapping, IndexSymbol
from ..compiler.base import CodegenError, ValidationError, NDEBUG
@@ -1310,3 +1311,55 @@ def handle_permute(emitter: WaveEmitter, node: fx.Node):
raise ValidationError("Malformed arguments") from e
vector_src = cast_py_value(emitter, register)
emitter.bind_node_proxy(node, vector_src)


@handle_op(reshape)
def handle_reshape(emitter: WaveEmitter, node: fx.Node):
try:
args, target_vector_shapes = node.args
except ValueError as e:
raise ValidationError("Malformed arguments") from e
custom = get_custom(node)
innermost_dim = custom.type.symbolic_shape[-1]
offset = custom.expanded_dims[innermost_dim]

# Determine whether to extract or combine.
if len(args) > 1:
concatenated = None
for i, sub_arg in enumerate(args):
vector = cast_vector(emitter, sub_arg)
shape = vector.type.shape[0]
if concatenated is None:
element_type = vector.type.element_type
vector_type = VectorType.get([shape * len(args)], element_type)
concatenated = arith_d.ConstantOp(
vector_type,
DenseElementsAttr.get_splat(
vector_type, get_constant_attr(0, element_type)
),
).result
concatenated = vector_d.insert_strided_slice(
vector, concatenated, [i * shape], [1]
)
emitter.bind_node_proxy(node, IRProxyValue(concatenated))
return

# Extract the appropriate slice. The offset is obtained from the expanded_dim
# and so corresponds to the dim_query during expansion. To obtain the
# actual offset, we need to multiply by the size. The size is obtained by
# computing the number of partitions using the source and target vector shapes
# and dividing the incoming vector shape by the number of partitions.
num_partitions = (
target_vector_shapes[innermost_dim] // custom.vector_shapes[innermost_dim]
)
vector = cast_vector(emitter, args[0])
size = vector.type.shape[0] // num_partitions
result_type = VectorType.get([size], vector.type.element_type)
slice = vector_d.extract_strided_slice(
result_type,
vector,
[offset * size],
[size],
[1],
)
emitter.bind_node_proxy(node, IRProxyValue(slice))
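To make the extract path above concrete, the offset/size arithmetic can be replayed in isolation. This is a standalone sketch with illustrative numbers, independent of the emitter and the MLIR types:

# Suppose the argument was expanded under a vector shape of 32 along the
# innermost dim, the reshape node itself carries a vector shape of 16, and the
# incoming SSA value is an 8-element vector. (Made-up numbers for illustration.)
target_vector_shape = 32    # target_vector_shapes[innermost_dim]
node_vector_shape = 16      # custom.vector_shapes[innermost_dim]
incoming_vector_len = 8     # vector.type.shape[0]
offset = 1                  # custom.expanded_dims[innermost_dim] (the dim_query)

num_partitions = target_vector_shape // node_vector_shape   # 2 partitions
size = incoming_vector_len // num_partitions                 # 4 elements per slice
start = offset * size                                        # slice begins at element 4
print(f"extract_strided_slice(offsets=[{start}], sizes=[{size}], strides=[1])")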
92 changes: 92 additions & 0 deletions iree/turbine/kernel/wave/expansion.py
@@ -250,6 +250,20 @@ def _expand_node(
new_node.expanded_dims = restricted_dims
new_node.fx_node.name = get_expanded_name(node, restricted_dims)

# For reshapes, we need more explicit control over how the arguments are expanded.
if isinstance(new_node, Reshape):
_expand_reshape(
new_node,
trace,
dim_query,
dim_scaling,
context,
get_node_dim_scaling,
res_idx,
)
context[(node, get_indexed_dims(restricted_dims, node), res_idx)] = new_node
return new_node

# Proceed with expansion of the arguments
for i, arg in node.node_args.items():
arg_list = arg
@@ -496,6 +510,84 @@ def _expand_mma_reduction(
return new_node


def _expand_reshape(
reshape: Reshape,
trace: CapturedTrace,
dim_query: dict[IndexSymbol, int],
dim_scaling: dict[IndexSymbol, int],
context: ExpandedNodeMap,
get_node_dim_scaling: Callable[[fx.Node], dict[IndexSymbol, int]],
res_idx: int,
) -> CustomOp:
"""
When expanding a reshape, we have to expand the arguments of the reshape and then concatenate them together
for the expanded node. Say we have a node with indexing dims = [M, N] with vector shapes m=8, n=2 and
the reshape wants to map it to m=4, n=4. So we start by expanding the node
node: {m = 0, n = 0}
arg: {m = 0, n = 0}
arg: {m = 0, n = 1}
node: {m = 1, n = 0}
arg: {m = 0, n = 0}
arg: {m = 0, n = 1}
node: {m = 2, n = 0}
arg: {m = 1, n = 0}
arg: {m = 1, n = 1}
node: {m = 3, n = 0}
arg: {m = 1, n = 0}
arg: {m = 1, n = 1}
...
In general, for the (m = i, n = j) expansion of the reshape node, we expand the arguments of the
reshape node using the following recipe, where m_src is the vector shape the argument was expanded
under and m_dst is the vector shape assigned to the reshape:
- if m_src > m_dst, we have a one to many mapping from source to destination,
  so we expand the argument only once, along m = i // (m_src / m_dst).
- if m_src < m_dst, we have a many to one mapping from source to destination,
  so we expand the argument m_dst / m_src times, along m = i * (m_dst / m_src), ..., i * (m_dst / m_src) + m_dst / m_src - 1.

In situations where the argument has been expanded along the same dimension, we reuse the expanded node
by making use of the context.
"""

dim_combinations = {}
for dim, value in dim_query.items():
if dim not in reshape.target_vector_shape:
continue
if reshape.vector_shapes[dim] < reshape.target_vector_shape[dim]:
scale_factor = (
reshape.target_vector_shape[dim] // reshape.vector_shapes[dim]
)
dim_combinations[dim] = [value // scale_factor]
else:
scale_factor = (
reshape.vector_shapes[dim] // reshape.target_vector_shape[dim]
)
begin = value * scale_factor
dim_combinations[dim] = list(range(begin, begin + scale_factor))
reshape_dim_combinations = list(itertools.product(*dim_combinations.values()))

new_args = []
for i, arg_dim_query in enumerate(reshape_dim_combinations):
arg_dim_query = {
dim: val for dim, val in zip(dim_combinations.keys(), arg_dim_query)
}
if isinstance(reshape.args, Sequence):
custom_arg = get_custom(reshape.args[i])
else:
custom_arg = get_custom(reshape.args)
new_node = _expand_node(
custom_arg,
trace,
arg_dim_query,
get_node_dim_scaling(custom_arg.fx_node),
context,
get_node_dim_scaling,
res_idx,
)
new_args.append(new_node.fx_node)

reshape.update_arg("args", new_args)

Comment on lines +586 to +589

Contributor:
Do we need to do anything special to ensure this ordering of expansion is correct?

Contributor Author:
I don't think so, but do you have an example in mind?
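To sanity-check the recipe in the docstring of _expand_reshape above (argument shapes m=8, n=2; reshape shapes m=4, n=4), the combination logic can be replayed in isolation. This is a standalone sketch with plain strings standing in for IndexSymbols, not part of the diff:

import itertools

arg_shapes = {"M": 8, "N": 2}   # shapes the argument was expanded under (target_vector_shape)
out_shapes = {"M": 4, "N": 4}   # shapes assigned to the reshape node (vector_shapes)

def arg_queries(dim_query: dict[str, int]) -> list[dict[str, int]]:
    combos = {}
    for dim, value in dim_query.items():
        if out_shapes[dim] < arg_shapes[dim]:
            scale = arg_shapes[dim] // out_shapes[dim]
            combos[dim] = [value // scale]                # argument expanded once
        else:
            scale = out_shapes[dim] // arg_shapes[dim]
            combos[dim] = list(range(value * scale, value * scale + scale))
    return [dict(zip(combos, c)) for c in itertools.product(*combos.values())]

print(arg_queries({"M": 2, "N": 0}))
# [{'M': 1, 'N': 0}, {'M': 1, 'N': 1}] -- matches the docstring's node {m = 2, n = 0}.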


def get_expanded_name(node: CustomOp, dims: dict[IndexSymbol, int]) -> str:
"""Returns the name of a node with the dimensions appended."""

7 changes: 5 additions & 2 deletions iree/turbine/kernel/wave/index_sequence_analysis.py
@@ -5,13 +5,16 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from ..ops.wave_ops import (
Allocate,
Write,
ExtractSlice,
get_custom,
Reduction,
MMA,
Placeholder,
IterArg,
CustomOp,
Reshape,
)
from .constraints import Constraint, HardwareConstraint, WorkgroupConstraint
from .._support.tracing import CapturedTrace, IndexingContext
@@ -193,8 +196,8 @@ def set_vector_shapes(
an MMA slice as well as the anchor node.
"""
custom = get_custom(node)
# MMA & Reduction nodes already have their vector shapes set.
if isinstance(custom, (MMA, Reduction)):
# MMA, Reduction & Reshape nodes already have their vector shapes set.
if isinstance(custom, (MMA, Reduction, Reshape)):
return
# Add vector shapes from constraints to all ops. These are global constraints.
custom.vector_shapes = {}
44 changes: 44 additions & 0 deletions iree/turbine/kernel/wave/iree_utils.py
@@ -41,6 +41,44 @@ def get_chain_mmt_asm(
}}"""


def get_chain_mmt_f8_asm(
query_type: str, key_type: str, value_type: str, output_type: str
) -> str:
B, M, K1, input_dtype = query_type.split("x")
B, K2, K1, input_dtype = key_type.split("x")
B, N, K2, input_dtype = value_type.split("x")
B, N, M, output_dtype = output_type.split("x")
f8_dtype = "f8E4M3FNUZ"
intermediate_output_type = f"{B}x{K2}x{M}x{output_dtype}"
intermediate_cast_type = f"{B}x{K2}x{M}x{f8_dtype}"
transposed_cast_type = f"{B}x{M}x{K2}x{f8_dtype}"
transposed_output_type = f"{B}x{M}x{N}x{output_dtype}"
query_f8_type = "x".join([B, M, K1, f8_dtype])
key_f8_type = "x".join([B, K2, K1, f8_dtype])
value_f8_type = "x".join([B, N, K2, f8_dtype])
return f"""
func.func @chain_mmt_f8(%query: tensor<{query_type}>, %key: tensor<{key_type}>, %value: tensor<{value_type}>) -> tensor<{output_type}> {{
%c0 = arith.constant 0.0 : f32
%init = tensor.empty() : tensor<{intermediate_output_type}>
%query_f8 = arith.truncf %query : tensor<{query_type}> to tensor<{query_f8_type}>
%key_f8 = arith.truncf %key : tensor<{key_type}> to tensor<{key_f8_type}>
%inital_result = linalg.fill ins(%c0 : f32) outs(%init : tensor<{intermediate_output_type}>) -> tensor<{intermediate_output_type}>
%result = linalg.batch_matmul_transpose_b ins(%key_f8, %query_f8 : tensor<{key_f8_type}>, tensor<{query_f8_type}>)
outs(%inital_result : tensor<{intermediate_output_type}>) -> tensor<{intermediate_output_type}>
%trunc = arith.truncf %result : tensor<{intermediate_output_type}> to tensor<{intermediate_cast_type}>
%init2 = tensor.empty() : tensor<{transposed_cast_type}>
%transpose = linalg.transpose ins(%trunc: tensor<{intermediate_cast_type}>) outs(%init2: tensor<{transposed_cast_type}>) permutation=[0, 2, 1]
%init3 = tensor.empty() : tensor<{transposed_output_type}>
%inital_result3 = linalg.fill ins(%c0 : f32) outs(%init3 : tensor<{transposed_output_type}>) -> tensor<{transposed_output_type}>
%value_f8 = arith.truncf %value : tensor<{value_type}> to tensor<{value_f8_type}>
%result2 = linalg.batch_matmul_transpose_b ins(%transpose, %value_f8: tensor<{transposed_cast_type}>, tensor<{value_f8_type}>)
outs(%inital_result3 : tensor<{transposed_output_type}>) -> tensor<{transposed_output_type}>
%init4 = tensor.empty() : tensor<{output_type}>
%transpose2 = linalg.transpose ins(%result2: tensor<{transposed_output_type}>) outs(%init4: tensor<{output_type}>) permutation=[0, 2, 1]
return %transpose2 : tensor<{output_type}>
}}"""


def get_mmt_asm(
lhs_type: str,
rhs_type: str,
@@ -141,6 +179,12 @@ def generate_iree_ref(
value_type = get_type_str(kernel_inputs[2].shape, kernel_inputs[2].dtype)
output_type = get_type_str(kernel_outputs[0].shape, kernel_outputs[0].dtype)
asm = get_chain_mmt_asm(query_type, key_type, value_type, output_type)
elif kernel_type == "chain_mmt_f8":
query_type = get_type_str(kernel_inputs[0].shape, kernel_inputs[0].dtype)
key_type = get_type_str(kernel_inputs[1].shape, kernel_inputs[1].dtype)
value_type = get_type_str(kernel_inputs[2].shape, kernel_inputs[2].dtype)
output_type = get_type_str(kernel_outputs[0].shape, kernel_outputs[0].dtype)
asm = get_chain_mmt_f8_asm(query_type, key_type, value_type, output_type)
elif kernel_type.startswith(conv_str):
lhs_type = get_type_str(kernel_inputs[0].shape, kernel_inputs[0].dtype)
rhs_type = get_type_str(kernel_inputs[1].shape, kernel_inputs[1].dtype)
43 changes: 42 additions & 1 deletion iree/turbine/kernel/wave/utils.py
@@ -12,7 +12,7 @@
UnitAttr,
Value,
)
from typing import Optional, Callable, Any, List, Tuple
from typing import Optional, Callable, Any, List, Tuple, Sequence
from .._support.tracing import CapturedTrace
from .._support.indexing import IndexExpr, IndexingContext, IndexSymbol, IndexSequence
from ..lang.global_symbols import *
@@ -25,6 +25,7 @@
Reduction,
GetResult,
IterArg,
Reshape,
)
from .constraints import (
Constraint,
@@ -192,6 +193,20 @@ def simplify_index(index: IndexExpr) -> IndexExpr:
return subs_idxc(index.subs(mapping))


def is_reshape_needed(
node: CustomOp,
node_vector_shapes: dict[IndexSymbol, int],
vector_shapes: dict[IndexSymbol, int],
) -> bool:
for dim in node.type.symbolic_shape:
if dim not in vector_shapes:
# Ignore nodes that are not used in both mmas.
return False
if node_vector_shapes[dim] != vector_shapes[dim]:
return True
return False


def get_mma_dimensional_mapping(
trace: CapturedTrace,
hardware_constraint: HardwareConstraint,
@@ -243,6 +258,32 @@ def is_mma(node):
reduction.anchor = custom

mma_slices = {get_custom(x): capture_mma_slices(get_custom(x)) for x in mma_nodes}

# Determine if any reshapes are required. Reshapes are added for
# chained matmuls when the vector shapes of the operands in one matmul
# differ from those in another matmul.
for src in mma_nodes:
custom_src = get_custom(src)
for dst in mma_nodes:
if src == dst:
continue
custom_dst = get_custom(dst)
lhs_slice = capture_backward_slice(custom_dst.lhs)
rhs_slice = capture_backward_slice(custom_dst.rhs)
Contributor:
Can we remove the double nested loop, and instead find the mma op inside lhs_slice/rhs_slice? I think we can add a stop right after is_mma, since we'd only want to find the closest MMA in the use-def chain.

Contributor Author:
Yes we can, but then we would have to iterate through all the operators in the slice to see if any of them has an mma as an argument. That would end up with a triple-nested loop. Unless you are thinking of something else?

Contributor Author:
Okay, I think I have a cleaner version now. Let me know what you think.

if src in lhs_slice or src in rhs_slice:
with custom_dst.graph.inserting_before(dst):
for i, arg in custom_dst.node_args.items():
if is_reshape_needed(
arg, custom_dst.vector_shapes, custom_src.vector_shapes
):
reshape = Reshape(
arg.fx_node, custom_src.vector_shapes
).add_to_graph(custom.graph)
custom_reshape = get_custom(reshape)
custom_reshape.vector_shapes = custom.vector_shapes
custom_reshape.anchor = custom
custom.update_arg(i, reshape)

return mapping, mma_slices
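To see when a reshape actually gets inserted here, the is_reshape_needed check can be mimicked with plain strings for dimensions and made-up vector shapes. A standalone sketch, not taken from a real kernel:

# Plain-Python stand-in for is_reshape_needed: dims are strings instead of
# IndexSymbols, and the shapes are illustrative only.
def reshape_needed(symbolic_shape: list[str],
                   node_shapes: dict[str, int],
                   anchor_shapes: dict[str, int]) -> bool:
    for dim in symbolic_shape:
        if dim not in anchor_shapes:
            return False          # dim not shared by both mmas -> no reshape
        if node_shapes[dim] != anchor_shapes[dim]:
            return True           # same dim, different vector shape -> reshape
    return False

# Hypothetical chained matmul: one MMA tiles K2 by 32, the other by 16.
print(reshape_needed(["M", "K2"], {"M": 16, "K2": 32}, {"M": 16, "K2": 16}))  # True
print(reshape_needed(["M", "K2"], {"M": 16, "K2": 16}, {"M": 16, "K2": 16}))  # False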

