Add support for batch matmuls
This PR adds support for batch dimensions.
A batch dimension is specified in the vector_shapes
dict with a shape of 0, together with a corresponding
workgroup constraint. This indicates that the batch
dim should be distributed among workgroups, but not
among threads.

Signed-off-by: Harsh Menon <harsh@nod-labs.com>
harsh-nod committed Oct 14, 2024
1 parent bacfdcd commit cb92697
Showing 8 changed files with 566 additions and 61 deletions.
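
Concretely, the new behavior lets a kernel declare a batch dimension by pairing a workgroup constraint with a vector shape of 0. A minimal sketch of such a constraint list follows; the tkw/tkl import spelling, constraint constructors, and the MMA type are assumptions modeled on how wave kernels in this repository are typically set up, not part of this diff:

import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw

# Symbolic dims: batch, rows, cols, reduction; and their block sizes.
B, M, N, K = tkl.sym.B, tkl.sym.M, tkl.sym.N, tkl.sym.K
BLOCK_B, BLOCK_M, BLOCK_N = tkl.sym.BLOCK_B, tkl.sym.BLOCK_M, tkl.sym.BLOCK_N

constraints = [
    tkw.WorkgroupConstraint(M, BLOCK_M, 0),
    tkw.WorkgroupConstraint(N, BLOCK_N, 1),
    # The batch dim gets its own workgroup dimension ...
    tkw.WorkgroupConstraint(B, BLOCK_B, 2),
    tkw.WaveConstraint(M, BLOCK_M / 2),
    tkw.WaveConstraint(N, BLOCK_N / 2),
    # ... and a vector shape of 0, meaning it is distributed across
    # workgroups only and never vectorized or indexed per thread.
    tkw.HardwareConstraint(
        threads_per_wave=64,
        waves_per_block=(2, 2, 1),
        mma_type=tkw.MMAType.F32_16x16x16_F16,
        vector_shapes={B: 0},
    ),
]
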
10 changes: 6 additions & 4 deletions iree/turbine/kernel/wave/constraints.py
@@ -52,6 +52,9 @@ class HardwareConstraint(Constraint):
these situations, the user can specify the vector shape they
want to tile to by specifying the vector shapes dictionary
which maps a tensor dimension to its corresponding tile size.
Both mma constraints and vector shapes can be specified, but
the mapping from symbols to shapes should be injective.
"""

threads_per_wave: int
@@ -116,21 +119,20 @@ def compute_access_pattern_using_vector_shapes(
elements_per_thread: int | IndexSymbol,
stride: int,
) -> IndexSequence:
if dim not in self.vector_shapes:
raise ValueError(f"No vector shape specified for dimension {dim}")
thread_id = self.get_thread_id_from_workgroup_dim(workgroup_dim)
return IndexSequence(
thread_id * elements_per_thread, elements_per_thread, stride
)

def apply(
self,
constraint_index: int,
dim: IndexSymbol,
constraint_index: int,
elements_per_thread: int | IndexSymbol,
stride: int,
is_mma_dim: bool,
) -> IndexSequence:
if self.vector_shapes is not None:
if not is_mma_dim:
return self.compute_access_pattern_using_vector_shapes(
dim, constraint_index, elements_per_thread, stride
)
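
For a non-MMA dimension, apply therefore reduces to a contiguous chunk per thread: start at thread_id * elements_per_thread, read elements_per_thread elements with the given stride. A self-contained illustration, with plain tuples standing in for IndexSequence(start, size, stride):

def access_pattern(thread_id: int, elements_per_thread: int, stride: int = 1):
    # Mirrors compute_access_pattern_using_vector_shapes: each thread
    # covers `elements_per_thread` contiguous elements.
    return (thread_id * elements_per_thread, elements_per_thread, stride)

# Thread 0 covers elements [0, 4), thread 1 covers [4, 8), and so on.
assert access_pattern(0, 4) == (0, 4, 1)
assert access_pattern(1, 4) == (4, 4, 1)
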
74 changes: 41 additions & 33 deletions iree/turbine/kernel/wave/expansion.py
@@ -19,7 +19,12 @@
from .._support.indexing import IndexingContext, IndexSequence
from ...support.logging import get_logger
from .._support.tracing import CapturedTrace
from .utils import get_mma_dimensional_mapping, specialize_index_sequence
from .utils import (
get_mma_dimensional_mapping,
specialize_index_sequence,
get_hardware_constraint,
get_workgroup_constraints,
)
from ..lang.global_symbols import *

logger = get_logger("turbine.wave.expansion")
@@ -163,7 +168,7 @@ def set_node_index(
dimensions M, N and K, but allow each mapping to be piecewise conditioned
on the operand.
"""
hardware_constraint = [c for c in constraints if isinstance(c, HardwareConstraint)]
hardware_constraint = [get_hardware_constraint(constraints)]
workgroup_constraints = {
c.dim: c for c in constraints if isinstance(c, WorkgroupConstraint)
}
@@ -174,31 +179,28 @@
sorted_constraints = hardware_constraint + other_constraints

index = {}
# The semantics of elements_per_thread are that it represents the number of
# elements that are loaded contiguously from memory.
elements_per_thread = getattr(custom, "elements_per_thread", None)

for dim in custom.indexing_dims:
index_seq = None
for constraint in sorted_constraints:
mma_check = isinstance(constraint, HardwareConstraint) and dim in mma_index

vector_check = (
isinstance(constraint, HardwareConstraint)
and constraint.vector_shapes is not None
and hasattr(custom, "elements_per_thread")
)

constraint_check = (
not isinstance(constraint, HardwareConstraint) and dim == constraint.dim
)

if (not (mma_check or vector_check)) and (not constraint_check):
continue

if isinstance(constraint, HardwareConstraint):

# The semantics of elements_per_thread are that it represents the number of
# elements that are loaded contiguously from memory.
elements_per_thread = getattr(custom, "elements_per_thread", None)
constraint_index, elements_per_thread, stride = (
(
inputs = None
if dim in mma_index:
inputs = (mma_index[dim], elements_per_thread, None)
else:
# Assumes vector shapes are associated with workgroup dims.
assert (
dim in workgroup_constraints
), f"Dimension {dim} not found in workgroup constraints"
assert (
dim in constraint.vector_shapes
), f"Dimension {dim} not found in vector shapes"
if constraint.vector_shapes[dim] == 0:
continue
inputs = (
workgroup_constraints[dim].workgroup_dim,
(
1
@@ -213,16 +215,16 @@ def set_node_index(
custom.indexing_dims, constraint.vector_shapes, dim
),
)
if constraint.vector_shapes is not None
else (mma_index[dim], elements_per_thread, None)
)
index_seq = constraint.apply(
constraint_index, dim, elements_per_thread, stride
)
if mma_index:
if elements_per_thread is None:
# Here we end up with a situation where there will be no thread level
# dependence in the dimensional index.
# TODO: Evaluate if this is a valid case.
continue
index_seq = constraint.apply(dim, *inputs, dim in mma_index)
if dim in mma_index:
index_seq = specialize_index_sequence(index_seq, mma_slices, custom)

else:
elif constraint.dim == dim:
if index_seq is None:
index_seq = constraint.apply()
else:
@@ -250,7 +252,9 @@ def expand_graph(
dim_scaling = constraints_or_scaling
node_index_setter = lambda *args: None
else:
mma_index, mma_slices = get_mma_dimensional_mapping(trace)
mma_index, mma_slices = get_mma_dimensional_mapping(
trace, get_hardware_constraint(constraints_or_scaling)
)
dim_scaling, dim_tile_size = get_dim_scaling(constraints_or_scaling, mma_index)
node_index_setter = partial(
set_node_index, constraints_or_scaling, mma_index, mma_slices, dim_tile_size
@@ -527,13 +531,17 @@ def get_dim_scaling(
f"Attempting to determine vector shape for unmapped dimension {constraint.dim}"
)

if mma_indices:
if mma_indices and constraint.dim in mma_indices:
vector_size = hardware_constraints[0].mma_matrix_shapes[
mma_indices[constraint.dim]
]
else:
vector_size = hardware_constraints[0].vector_shapes[constraint.dim]

# No dim scaling for dims with 0 vector size.
if vector_size == 0:
continue

wave_count = 1
if isinstance(constraint, WorkgroupConstraint):
wave_count = hardware_constraints[0].waves_per_block[
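
The hardware-constraint branch of set_node_index now makes a three-way choice per dimension: MMA dims use the MMA dimensional mapping, dims with a vector shape of 0 (the batch case) get no thread-level index at all, and the remaining dims are indexed along their workgroup dimension. A condensed paraphrase of that decision (stride and elements-per-thread details elided; not the full implementation):

def pick_index_inputs(dim, mma_index, workgroup_constraints, vector_shapes,
                      elements_per_thread):
    if dim in mma_index:
        # MMA dims: index comes from the MMA dimensional mapping.
        return (mma_index[dim], elements_per_thread, None)
    # Non-MMA dims must have both a workgroup constraint and a vector shape.
    assert dim in workgroup_constraints, f"{dim} not in workgroup constraints"
    assert dim in vector_shapes, f"{dim} not in vector shapes"
    if vector_shapes[dim] == 0:
        # Batch-style dims: distributed by workgroup only, no thread index.
        return None
    # Otherwise index along the matching workgroup dimension; the real code
    # also computes a per-dimension stride here.
    return (workgroup_constraints[dim].workgroup_dim, elements_per_thread, None)
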
20 changes: 10 additions & 10 deletions iree/turbine/kernel/wave/index_sequence_analysis.py
@@ -20,15 +20,16 @@


def get_vector_shape(
trace: CapturedTrace,
hardware_constraint: HardwareConstraint,
symbolic_shape: list[IndexSymbol],
) -> list[int]:
mma_indices, _ = get_mma_dimensional_mapping(trace)
return [
get_hardware_vector_size(dim, hardware_constraint, mma_indices)
for dim in symbolic_shape
assert all(
dim in hardware_constraint.vector_shapes for dim in symbolic_shape
), "Missing vector shape in hardware constraint"
vector_shapes = [
max(hardware_constraint.vector_shapes[dim], 1) for dim in symbolic_shape
]
return vector_shapes


def partition_strided_operators(trace: CapturedTrace, constraints: list[Constraint]):
@@ -46,7 +47,7 @@ def has_strided_access(node: fx.Node) -> bool:
read more than a single element.
"""
custom = get_custom(node)
if isinstance(custom, Write) and len(custom.register_type.symbolic_shape) == 2:
if isinstance(custom, Write):
strides = [
simplify_index(custom.register_index[dim]).stride
for dim in custom.register_index
@@ -69,13 +70,12 @@ def has_strided_access(node: fx.Node) -> bool:
for operator in strided_operators:
custom = get_custom(operator)
simplified_index = {
dim: simplify_index(custom.register_index[dim]) for dim in custom.index
dim: simplify_index(custom.register_index.get(dim, custom.index[dim]))
for dim in custom.index
}

max_stride = int(max(simplified_index[dim].stride for dim in simplified_index))
shape = get_vector_shape(
trace, hw_constraint, custom.register_type.symbolic_shape
)
shape = get_vector_shape(hw_constraint, custom.register_type.symbolic_shape)
elements_per_thread = subs_idxc(custom.elements_per_thread)
with custom.graph.inserting_before(operator):
for i in range(elements_per_thread):
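
Note the max(..., 1) clamp in get_vector_shape: a batch dimension registered with a vector shape of 0 still contributes a unit dimension to the resulting vector shape. A quick illustration with plain dicts standing in for the hardware constraint:

vector_shapes = {"B": 0, "M": 16, "N": 16}   # B is the batch dim
symbolic_shape = ["B", "M", "N"]
shape = [max(vector_shapes[dim], 1) for dim in symbolic_shape]
assert shape == [1, 16, 16]
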
15 changes: 12 additions & 3 deletions iree/turbine/kernel/wave/iree_utils.py
@@ -10,14 +10,18 @@
from ...support.conversions import TORCH_DTYPE_TO_MLIR_TYPE_ASM


def get_mmt_asm(lhs_type: str, rhs_type: str, acc_type: str) -> str:
def get_mmt_asm(
lhs_type: str, rhs_type: str, acc_type: str, batch: bool = False
) -> str:
acc_dtype = acc_type.split("x")[-1]
operator = "batch_matmul_transpose_b" if batch else "matmul_transpose_b"
func_name = "bmmt" if batch else "mmt"
matmul_function = f"""
func.func @mmt(%lhs: tensor<{lhs_type}>, %rhs: tensor<{rhs_type}>) -> tensor<{acc_type}> {{
func.func @{func_name}(%lhs: tensor<{lhs_type}>, %rhs: tensor<{rhs_type}>) -> tensor<{acc_type}> {{
%c0 = arith.constant 0.0 : {acc_dtype}
%init = tensor.empty() : tensor<{acc_type}>
%inital_result = linalg.fill ins(%c0 : {acc_dtype}) outs(%init : tensor<{acc_type}>) -> tensor<{acc_type}>
%result = linalg.matmul_transpose_b ins(%lhs, %rhs: tensor<{lhs_type}>, tensor<{rhs_type}>)
%result = linalg.{operator} ins(%lhs, %rhs: tensor<{lhs_type}>, tensor<{rhs_type}>)
outs(%inital_result: tensor<{acc_type}>) -> tensor<{acc_type}>
return %result : tensor<{acc_type}>
}}"""
@@ -70,6 +74,11 @@ def generate_iree_ref(
rhs_type = get_type_str(kernel_inputs[1].shape, kernel_inputs[1].dtype)
acc_type = get_type_str(kernel_outputs[0].shape, kernel_outputs[0].dtype)
asm = get_mmt_asm(lhs_type, rhs_type, acc_type)
elif kernel_type == "bmmt":
lhs_type = get_type_str(kernel_inputs[0].shape, kernel_inputs[0].dtype)
rhs_type = get_type_str(kernel_inputs[1].shape, kernel_inputs[1].dtype)
acc_type = get_type_str(kernel_outputs[0].shape, kernel_outputs[0].dtype)
asm = get_mmt_asm(lhs_type, rhs_type, acc_type, batch=True)
elif kernel_type.startswith(conv_str):
lhs_type = get_type_str(kernel_inputs[0].shape, kernel_inputs[0].dtype)
rhs_type = get_type_str(kernel_inputs[1].shape, kernel_inputs[1].dtype)
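
With batch=True, get_mmt_asm now emits a linalg.batch_matmul_transpose_b wrapped in a @bmmt function; the call site only adds the flag. A sketch of the expected output (the tensor shapes are illustrative, and the import path is taken from this file's location):

from iree.turbine.kernel.wave.iree_utils import get_mmt_asm

asm = get_mmt_asm(
    lhs_type="2x64x128xf16",
    rhs_type="2x256x128xf16",
    acc_type="2x64x256xf32",
    batch=True,
)
# Batch mode switches both the function name and the linalg op.
assert "func.func @bmmt" in asm
assert "linalg.batch_matmul_transpose_b" in asm
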
29 changes: 26 additions & 3 deletions iree/turbine/kernel/wave/utils.py
@@ -192,6 +192,7 @@ def simplify_index(index: IndexExpr) -> IndexExpr:

def get_mma_dimensional_mapping(
trace: CapturedTrace,
hardware_constraint: HardwareConstraint,
) -> tuple[dict[IndexSymbol, int], dict[IndexSymbol, list[fx.Node]]]:
"""
Given a trace, determine the MMA dimensional mapping for all the
@@ -200,6 +201,9 @@ def get_mma_dimensional_mapping(
where a_reg has shape UxV, b has shape SxV and acc has shape UxS,
we map U to the MMA M dimension (0), S to the MMA N dimension (1) and
V to the MMA K dimension (2).
Also update the vector shapes in the hardware constraint based on the
discovered MMA dimensions.
"""

def is_mma(node):
@@ -217,6 +221,13 @@ def is_mma(node):
mapping[m] = 0
mapping[n] = 1
mapping[k] = 2
# Update vector shapes in hardware constraint.
M, N, K = hardware_constraint.mma_matrix_shapes
if not hardware_constraint.vector_shapes:
hardware_constraint.vector_shapes = {}
hardware_constraint.vector_shapes[m] = M
hardware_constraint.vector_shapes[n] = N
hardware_constraint.vector_shapes[k] = K

return mapping, capture_mma_slices([get_custom(x) for x in mma_nodes])
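
In other words, once get_mma_dimensional_mapping has run, the hardware constraint's vector_shapes always carries entries for the three MMA dims, merged with anything the user supplied (such as a batch dim set to 0). A small dict-level illustration, assuming a 16x16x16 MMA:

user_vector_shapes = {"B": 0}        # user-specified: batch dim, workgroup-only
M, N, K = (16, 16, 16)               # mma_matrix_shapes of the selected MMA type
user_vector_shapes.update({"M": M, "N": N, "K": K})
assert user_vector_shapes == {"B": 0, "M": 16, "N": 16, "K": 16}
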

@@ -509,9 +520,7 @@ def get_inputs(
inputs.append(local_reduction.init_args[iter_arg_idx])
elif isinstance(custom, GetResult):
reduction = get_custom(custom.value)
assert isinstance(
get_custom(reduction), Reduction
), "GetResult must be used by a Reduction"
assert isinstance(reduction, Reduction), "GetResult must be used by a Reduction"
# Map get result to output
reduction_subgraph = reduction.graph.subgraphs[reduction.subgraph_name]
inputs.append(reduction.outputs(reduction_subgraph)[custom.res_idx])
@@ -651,6 +660,20 @@ def get_tiling_constraint(
raise ValueError(f"Could not find tiling constraint for reduction {reduction}")


def get_hardware_constraint(constraints: list[Constraint]) -> HardwareConstraint:
for constraint in constraints:
if isinstance(constraint, HardwareConstraint):
return constraint
else:
raise ValueError(f"Could not find hardware constraint in {constraints}")


def get_workgroup_constraints(
constraints: list[Constraint],
) -> list[WorkgroupConstraint]:
return [x for x in constraints if isinstance(x, WorkgroupConstraint)]


def replace_uses_in(users: dict[fx.Node, list[CustomOp]], old: CustomOp, new: fx.Node):
"""
Replace all uses of `old` with `new` in the list of users.
