diff --git a/iree/turbine/kernel/wave/constraints.py b/iree/turbine/kernel/wave/constraints.py index a6eb008f..1ad28ba8 100644 --- a/iree/turbine/kernel/wave/constraints.py +++ b/iree/turbine/kernel/wave/constraints.py @@ -52,6 +52,9 @@ class HardwareConstraint(Constraint): these situations, the user can specify the vector shape they want to tile to by specifying the vector shapes dictionary which maps a tensor dimension to its corresponding tile size. + + Both mma constraints and vector shapes can be specified, but + the mapping from symbols to shapes should be injective. """ threads_per_wave: int @@ -116,8 +119,6 @@ def compute_access_pattern_using_vector_shapes( elements_per_thread: int | IndexSymbol, stride: int, ) -> IndexSequence: - if dim not in self.vector_shapes: - raise ValueError(f"No vector shape specified for dimension {dim}") thread_id = self.get_thread_id_from_workgroup_dim(workgroup_dim) return IndexSequence( thread_id * elements_per_thread, elements_per_thread, stride @@ -125,12 +126,13 @@ def compute_access_pattern_using_vector_shapes( def apply( self, - constraint_index: int, dim: IndexSymbol, + constraint_index: int, elements_per_thread: int | IndexSymbol, stride: int, + is_mma_dim: bool, ) -> IndexSequence: - if self.vector_shapes is not None: + if not is_mma_dim: return self.compute_access_pattern_using_vector_shapes( dim, constraint_index, elements_per_thread, stride ) diff --git a/iree/turbine/kernel/wave/expansion.py b/iree/turbine/kernel/wave/expansion.py index 75cc5551..e124ac50 100644 --- a/iree/turbine/kernel/wave/expansion.py +++ b/iree/turbine/kernel/wave/expansion.py @@ -19,7 +19,12 @@ from .._support.indexing import IndexingContext, IndexSequence from ...support.logging import get_logger from .._support.tracing import CapturedTrace -from .utils import get_mma_dimensional_mapping, specialize_index_sequence +from .utils import ( + get_mma_dimensional_mapping, + specialize_index_sequence, + get_hardware_constraint, + get_workgroup_constraints, +) from ..lang.global_symbols import * logger = get_logger("turbine.wave.expansion") @@ -163,7 +168,7 @@ def set_node_index( dimensions M, N and K, but allow each mapping to be piecewise conditioned on the operand. """ - hardware_constraint = [c for c in constraints if isinstance(c, HardwareConstraint)] + hardware_constraint = [get_hardware_constraint(constraints)] workgroup_constraints = { c.dim: c for c in constraints if isinstance(c, WorkgroupConstraint) } @@ -174,31 +179,28 @@ def set_node_index( sorted_constraints = hardware_constraint + other_constraints index = {} + # The semantics of elements_per_thread are that it represents the number of + # elements that are loaded contiguously from memory. + elements_per_thread = getattr(custom, "elements_per_thread", None) + for dim in custom.indexing_dims: index_seq = None for constraint in sorted_constraints: - mma_check = isinstance(constraint, HardwareConstraint) and dim in mma_index - - vector_check = ( - isinstance(constraint, HardwareConstraint) - and constraint.vector_shapes is not None - and hasattr(custom, "elements_per_thread") - ) - - constraint_check = ( - not isinstance(constraint, HardwareConstraint) and dim == constraint.dim - ) - - if (not (mma_check or vector_check)) and (not constraint_check): - continue - if isinstance(constraint, HardwareConstraint): - - # The semantics of elements_per_thread are that it represents the number of - # elements that are loaded contiguously from memory. 
- elements_per_thread = getattr(custom, "elements_per_thread", None) - constraint_index, elements_per_thread, stride = ( - ( + inputs = None + if dim in mma_index: + inputs = (mma_index[dim], elements_per_thread, None) + else: + # Assumes vector shapes are associated with workgroup dims. + assert ( + dim in workgroup_constraints + ), f"Dimension {dim} not found in workgroup constraints" + assert ( + dim in constraint.vector_shapes + ), f"Dimension {dim} not found in vector shapes" + if constraint.vector_shapes[dim] == 0: + continue + inputs = ( workgroup_constraints[dim].workgroup_dim, ( 1 @@ -213,16 +215,16 @@ def set_node_index( custom.indexing_dims, constraint.vector_shapes, dim ), ) - if constraint.vector_shapes is not None - else (mma_index[dim], elements_per_thread, None) - ) - index_seq = constraint.apply( - constraint_index, dim, elements_per_thread, stride - ) - if mma_index: + if elements_per_thread is None: + # Here we end up with a situation where there will be no thread level + # dependence in the dimensional index. + # TODO: Evaluate if this is a valid case. + continue + index_seq = constraint.apply(dim, *inputs, dim in mma_index) + if dim in mma_index: index_seq = specialize_index_sequence(index_seq, mma_slices, custom) - else: + elif constraint.dim == dim: if index_seq is None: index_seq = constraint.apply() else: @@ -250,7 +252,9 @@ def expand_graph( dim_scaling = constraints_or_scaling node_index_setter = lambda *args: None else: - mma_index, mma_slices = get_mma_dimensional_mapping(trace) + mma_index, mma_slices = get_mma_dimensional_mapping( + trace, get_hardware_constraint(constraints_or_scaling) + ) dim_scaling, dim_tile_size = get_dim_scaling(constraints_or_scaling, mma_index) node_index_setter = partial( set_node_index, constraints_or_scaling, mma_index, mma_slices, dim_tile_size @@ -527,13 +531,17 @@ def get_dim_scaling( f"Attempting to determine vector shape for unmapped dimension {constraint.dim}" ) - if mma_indices: + if mma_indices and constraint.dim in mma_indices: vector_size = hardware_constraints[0].mma_matrix_shapes[ mma_indices[constraint.dim] ] else: vector_size = hardware_constraints[0].vector_shapes[constraint.dim] + # No dim scaling for dims with 0 vector size. + if vector_size == 0: + continue + wave_count = 1 if isinstance(constraint, WorkgroupConstraint): wave_count = hardware_constraints[0].waves_per_block[ diff --git a/iree/turbine/kernel/wave/index_sequence_analysis.py b/iree/turbine/kernel/wave/index_sequence_analysis.py index b9212f01..305ea5a4 100644 --- a/iree/turbine/kernel/wave/index_sequence_analysis.py +++ b/iree/turbine/kernel/wave/index_sequence_analysis.py @@ -20,15 +20,16 @@ def get_vector_shape( - trace: CapturedTrace, hardware_constraint: HardwareConstraint, symbolic_shape: list[IndexSymbol], ) -> list[int]: - mma_indices, _ = get_mma_dimensional_mapping(trace) - return [ - get_hardware_vector_size(dim, hardware_constraint, mma_indices) - for dim in symbolic_shape + assert all( + dim in hardware_constraint.vector_shapes for dim in symbolic_shape + ), "Missing vector shape in hardware constraint" + vector_shapes = [ + max(hardware_constraint.vector_shapes[dim], 1) for dim in symbolic_shape ] + return vector_shapes def partition_strided_operators(trace: CapturedTrace, constraints: list[Constraint]): @@ -46,7 +47,7 @@ def has_strided_access(node: fx.Node) -> bool: read more than a single element. 
""" custom = get_custom(node) - if isinstance(custom, Write) and len(custom.register_type.symbolic_shape) == 2: + if isinstance(custom, Write): strides = [ simplify_index(custom.register_index[dim]).stride for dim in custom.register_index @@ -69,13 +70,12 @@ def has_strided_access(node: fx.Node) -> bool: for operator in strided_operators: custom = get_custom(operator) simplified_index = { - dim: simplify_index(custom.register_index[dim]) for dim in custom.index + dim: simplify_index(custom.register_index.get(dim, custom.index[dim])) + for dim in custom.index } max_stride = int(max(simplified_index[dim].stride for dim in simplified_index)) - shape = get_vector_shape( - trace, hw_constraint, custom.register_type.symbolic_shape - ) + shape = get_vector_shape(hw_constraint, custom.register_type.symbolic_shape) elements_per_thread = subs_idxc(custom.elements_per_thread) with custom.graph.inserting_before(operator): for i in range(elements_per_thread): diff --git a/iree/turbine/kernel/wave/iree_utils.py b/iree/turbine/kernel/wave/iree_utils.py index 39f67404..3a7f5c69 100644 --- a/iree/turbine/kernel/wave/iree_utils.py +++ b/iree/turbine/kernel/wave/iree_utils.py @@ -10,14 +10,18 @@ from ...support.conversions import TORCH_DTYPE_TO_MLIR_TYPE_ASM -def get_mmt_asm(lhs_type: str, rhs_type: str, acc_type: str) -> str: +def get_mmt_asm( + lhs_type: str, rhs_type: str, acc_type: str, batch: bool = False +) -> str: acc_dtype = acc_type.split("x")[-1] + operator = "batch_matmul_transpose_b" if batch else "matmul_transpose_b" + func_name = "bmmt" if batch else "mmt" matmul_function = f""" - func.func @mmt(%lhs: tensor<{lhs_type}>, %rhs: tensor<{rhs_type}>) -> tensor<{acc_type}> {{ + func.func @{func_name}(%lhs: tensor<{lhs_type}>, %rhs: tensor<{rhs_type}>) -> tensor<{acc_type}> {{ %c0 = arith.constant 0.0 : {acc_dtype} %init = tensor.empty() : tensor<{acc_type}> %inital_result = linalg.fill ins(%c0 : {acc_dtype}) outs(%init : tensor<{acc_type}>) -> tensor<{acc_type}> - %result = linalg.matmul_transpose_b ins(%lhs, %rhs: tensor<{lhs_type}>, tensor<{rhs_type}>) + %result = linalg.{operator} ins(%lhs, %rhs: tensor<{lhs_type}>, tensor<{rhs_type}>) outs(%inital_result: tensor<{acc_type}>) -> tensor<{acc_type}> return %result : tensor<{acc_type}> }}""" @@ -70,6 +74,11 @@ def generate_iree_ref( rhs_type = get_type_str(kernel_inputs[1].shape, kernel_inputs[1].dtype) acc_type = get_type_str(kernel_outputs[0].shape, kernel_outputs[0].dtype) asm = get_mmt_asm(lhs_type, rhs_type, acc_type) + elif kernel_type == "bmmt": + lhs_type = get_type_str(kernel_inputs[0].shape, kernel_inputs[0].dtype) + rhs_type = get_type_str(kernel_inputs[1].shape, kernel_inputs[1].dtype) + acc_type = get_type_str(kernel_outputs[0].shape, kernel_outputs[0].dtype) + asm = get_mmt_asm(lhs_type, rhs_type, acc_type, batch=True) elif kernel_type.startswith(conv_str): lhs_type = get_type_str(kernel_inputs[0].shape, kernel_inputs[0].dtype) rhs_type = get_type_str(kernel_inputs[1].shape, kernel_inputs[1].dtype) diff --git a/iree/turbine/kernel/wave/utils.py b/iree/turbine/kernel/wave/utils.py index 11b3cbe2..da11e66b 100644 --- a/iree/turbine/kernel/wave/utils.py +++ b/iree/turbine/kernel/wave/utils.py @@ -192,6 +192,7 @@ def simplify_index(index: IndexExpr) -> IndexExpr: def get_mma_dimensional_mapping( trace: CapturedTrace, + hardware_constraint: HardwareConstraint, ) -> tuple[dict[IndexSymbol, int], dict[IndexSymbol, list[fx.Node]]]: """ Given a trace, determine the MMA dimensional mapping for all the @@ -200,6 +201,9 @@ def 
get_mma_dimensional_mapping( where a_reg has shape UxV, b has shape SxV and acc has shape UxS, we map U to the MMA M dimension (0), S to the MMA N dimension (1) and V to the MMA K dimension (2). + + Also update the vector shapes in the hardware constraint based on the + discovered MMA dimensions. """ def is_mma(node): @@ -217,6 +221,13 @@ def is_mma(node): mapping[m] = 0 mapping[n] = 1 mapping[k] = 2 + # Update vector shapes in hardware constraint. + M, N, K = hardware_constraint.mma_matrix_shapes + if not hardware_constraint.vector_shapes: + hardware_constraint.vector_shapes = {} + hardware_constraint.vector_shapes[m] = M + hardware_constraint.vector_shapes[n] = N + hardware_constraint.vector_shapes[k] = K return mapping, capture_mma_slices([get_custom(x) for x in mma_nodes]) @@ -509,9 +520,7 @@ def get_inputs( inputs.append(local_reduction.init_args[iter_arg_idx]) elif isinstance(custom, GetResult): reduction = get_custom(custom.value) - assert isinstance( - get_custom(reduction), Reduction - ), "GetResult must be used by a Reduction" + assert isinstance(reduction, Reduction), "GetResult must be used by a Reduction" # Map get result to output reduction_subgraph = reduction.graph.subgraphs[reduction.subgraph_name] inputs.append(reduction.outputs(reduction_subgraph)[custom.res_idx]) @@ -651,6 +660,20 @@ def get_tiling_constraint( raise ValueError(f"Could not find tiling constraint for reduction {reduction}") +def get_hardware_constraint(constraints: list[Constraint]) -> HardwareConstraint: + for constraint in constraints: + if isinstance(constraint, HardwareConstraint): + return constraint + else: + raise ValueError(f"Could not find hardware constraint in {constraints}") + + +def get_workgroup_constraints( + constraints: list[Constraint], +) -> list[WorkgroupConstraint]: + return [x for x in constraints if isinstance(x, WorkgroupConstraint)] + + def replace_uses_in(users: dict[fx.Node, list[CustomOp]], old: CustomOp, new: fx.Node): """ Replace all uses of `old` with `new` in the list of users. 
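To make the new control flow concrete, here is a minimal sketch (not part of the diff) of how the helpers added above compose with the batched-GEMM constraints used in the tests below. The tkl/tkw import aliases and the resulting vector-shape values are assumptions based on those tests, not code from this patch.

# Illustrative sketch only; aliases assumed to match the lit tests.
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
from iree.turbine.kernel.wave.utils import (
    get_hardware_constraint,
    get_mma_dimensional_mapping,
)

B, M, N, K = tkl.sym.B, tkl.sym.M, tkl.sym.N, tkl.sym.K
BLOCK_B, BLOCK_M, BLOCK_N = tkl.sym.BLOCK_B, tkl.sym.BLOCK_M, tkl.sym.BLOCK_N

constraints = [
    tkw.WorkgroupConstraint(M, BLOCK_M, 0),
    tkw.WorkgroupConstraint(N, BLOCK_N, 1),
    tkw.WorkgroupConstraint(B, BLOCK_B, 2),
    tkw.HardwareConstraint(
        threads_per_wave=64,
        waves_per_block=(2, 2, 1),
        mma_type=tkw.MMAType.F32_16x16x16_F16,
        # Batch dimension: a vector shape of 0 means no per-thread expansion.
        vector_shapes={B: 0},
    ),
]

# Pull the single hardware constraint out of the list (raises ValueError if absent).
hw = get_hardware_constraint(constraints)

# After tracing a kernel into `trace`, the MMA mapping pass now also fills in the
# M/N/K entries of hw.vector_shapes from the selected MMA shape (assumed values):
#   mma_index, mma_slices = get_mma_dimensional_mapping(trace, hw)
#   hw.vector_shapes == {B: 0, M: 16, N: 16, K: 16}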
diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py index b9fc81ed..4378044a 100644 --- a/lit_tests/kernel/wave/codegen.py +++ b/lit_tests/kernel/wave/codegen.py @@ -12,9 +12,11 @@ M = tkl.sym.M N = tkl.sym.N K = tkl.sym.K +B = tkl.sym.B BLOCK_M = tkl.sym.BLOCK_M BLOCK_N = tkl.sym.BLOCK_N BLOCK_K = tkl.sym.BLOCK_K +BLOCK_B = tkl.sym.BLOCK_B LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEM_PER_THREAD STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEM_PER_THREAD ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE @@ -674,6 +676,163 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # CHECK: return +@run_test +def test_batched_gemm(): + constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)] + constraints += [tkw.TilingConstraint(K, BLOCK_K)] + constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] + + constraints += [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(2, 2, 1), + mma_type=tkw.MMAType.F32_16x16x16_F16, + vector_shapes={B: 0}, + ) + ] + + @tkw.wave(constraints) + def batched_gemm( + a: tkl.Memory[B, M, K, ADDRESS_SPACE, tkl.f16], + b: tkl.Memory[B, N, K, ADDRESS_SPACE, tkl.f16], + c: tkl.Memory[B, M, N, ADDRESS_SPACE_0, tkl.f32], + ): + c_reg = tkl.Register[B, M, N, tkl.f32](0.0) + + @tkw.reduction(K, init_args=[c_reg]) + def repeat( + acc: tkl.Register[B, M, N, tkl.f32] + ) -> tkl.Register[B, M, N, tkl.f32]: + a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD) + b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD) + acc = tkw.mma(a_reg, b_reg, acc) + return acc + + tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD) + + with tk.gen.TestLaunchContext( + { + B: 12, + M: 64, + N: 128, + K: 64, + BLOCK_M: 32, + BLOCK_N: 32, + BLOCK_K: 16, + BLOCK_B: 1, + LOAD_ELEMS_PER_THREAD: 4, + STORE_ELEMS_PER_THREAD: 4, + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE, + }, + canonicalize=True, + ): + a = torch.randn(12, 64, 32, dtype=torch.float16) + b = torch.randn(12, 128, 32, dtype=torch.float16) + c = torch.zeros(12, 64, 128, dtype=torch.float32) + print(batched_gemm(a, b, c).module_op) + + # CHECK: func.func @batched_gemm(%[[ARG0:[a-zA-Z0-9_]+]]: !stream.binding, %[[ARG1:[a-zA-Z0-9_]+]]: + # CHECK-SAME: !stream.binding, %[[ARG2:[a-zA-Z0-9_]+]]: !stream.binding) attributes {translation_info = + # CHECK-SAME: #[[TRANSLATION:.+]]} { + # CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index + # CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index + # CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index + # CHECK-DAG: %[[C64:.+]] = arith.constant 64 : index + # CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index + # CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index + # CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index + # CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index + # CHECK-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<4xf32> + # CHECK: %[[WORKGROUP_ID_0:.+]] = stream.dispatch.workgroup.id[0] : index + # CHECK: %[[WORKGROUP_ID_1:.+]] = stream.dispatch.workgroup.id[1] : index + # CHECK: %[[WORKGROUP_ID_2:.+]] = stream.dispatch.workgroup.id[2] : index + # CHECK-DAG: %[[THREAD_ID_X:.+]] = gpu.thread_id x + # CHECK-DAG: %[[THREAD_ID_Y:.+]] = gpu.thread_id y + # CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1x32x20xf16, #[[GPU:.+]].address_space> + # CHECK: %[[ALLOC_0:.+]] = memref.alloc() : 
memref<1x32x20xf16, #[[GPU]].address_space> + # CHECK: %[[D0:.+]] = stream.binding.subspan %[[ARG0]][%[[C0]]] : !stream.binding -> memref<12x64x64xf16, + # CHECK-SAME: strided<[4096, 64, 1], offset: ?>> + # CHECK: %[[D1:.+]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<12x128x64xf16, + # CHECK-SAME: strided<[8192, 64, 1], offset: ?>> + # CHECK: %[[D2:.+]] = arith.divsi %[[THREAD_ID_X]], %[[C64]] : index + # CHECK: %[[D3:.+]] = arith.muli %[[D2]], %[[C16]] : index + # CHECK: %[[D4:.+]] = arith.muli %[[WORKGROUP_ID_0]], %[[C32]] : index + # CHECK: %[[D5:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C16]] : index + # CHECK: %[[D6:.+]] = arith.addi %[[D5]], %[[D4]] : index + # CHECK: %[[D7:.+]] = arith.addi %[[D6]], %[[D3]] : index + # CHECK: %[[D8:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C64]] : index + # CHECK: %[[D9:.+]] = arith.divsi %[[D8]], %[[C16]] : index + # CHECK: %[[D10:.+]] = arith.muli %[[D9]], %[[C4]] : index + # CHECK: %[[D11:.+]] = arith.addi %[[D5]], %[[D3]] : index + # CHECK: %[[D12:.+]] = arith.muli %[[THREAD_ID_Y]], %[[C16]] : index + # CHECK: %[[D13:.+]] = arith.muli %[[WORKGROUP_ID_1]], %[[C32]] : index + # CHECK: %[[D14:.+]] = arith.addi %[[D5]], %[[D13]] : index + # CHECK: %[[D15:.+]] = arith.addi %[[D14]], %[[D12]] : index + # CHECK: %[[D16:.+]] = arith.addi %[[D5]], %[[D12]] : index + # CHECK: %[[D17:.+]] = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C4]] step %[[C1]] + # CHECK-SAME: iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[CST]]) -> (vector<4xf32>) { + # CHECK: %[[D39:.+]] = arith.muli %[[ARG3]], %[[C16]] : index + # CHECK: %[[D40:.+]] = arith.addi %[[D39]], %[[D10]] : index + # CHECK: %[[D41:.+]] = vector.load %[[D0]][%[[WORKGROUP_ID_2]], %[[D7]], %[[D40]]] : memref<12x64x64xf16, + # CHECK-SAME: strided<[4096, 64, 1], offset: ?>>, vector<4xf16> + # CHECK: vector.store %[[D41]], %[[ALLOC]][%[[C0]], %[[D11]], %[[D10]]] : memref<1x32x20xf16, + # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> + # CHECK: amdgpu.lds_barrier + # CHECK: %[[D42:.+]] = vector.load %[[ALLOC]][%[[C0]], %[[D11]], %[[D10]]] : memref<1x32x20xf16, + # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> + # CHECK: %[[D43:.+]] = vector.load %[[D1]][%[[WORKGROUP_ID_2]], %[[D15]], %[[D40]]] : memref<12x128x64xf16, + # CHECK-SAME: strided<[8192, 64, 1], offset: ?>>, vector<4xf16> + # CHECK: amdgpu.lds_barrier + # CHECK: vector.store %[[D43]], %[[ALLOC_0]][%[[C0]], %[[D16]], %[[D10]]] : memref<1x32x20xf16, + # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> + # CHECK: amdgpu.lds_barrier + # CHECK: %[[D44:.+]] = vector.load %[[ALLOC_0]][%[[C0]], %[[D16]], %[[D10]]] : memref<1x32x20xf16, + # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> + # CHECK: %[[D45:.+]] = amdgpu.mfma %[[D42]] * %[[D44]] + %[[ARG4]] {blocks = 1 : i32, k = 16 : i32, m = 16 + # CHECK-SAME: : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> + # CHECK: scf.yield %[[D45]] : vector<4xf32> + # CHECK: } + # CHECK: %[[D18:.+]] = vector.extract_strided_slice %[[D17]] {offsets = [0], sizes = [1], strides = [1]} : + # CHECK-SAME: vector<4xf32> to vector<1xf32> + # CHECK: %[[D19:.+]] = stream.binding.subspan %[[ARG2]][%[[C0]]] : !stream.binding -> memref<12x64x128xf32, + # CHECK-SAME: strided<[8192, 128, 1], offset: ?>> + # CHECK: %[[D20:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C64]] : index + # CHECK: %[[D21:.+]] = arith.divsi %[[D20]], %[[C16]] : index + # CHECK: %[[D22:.+]] = arith.muli %[[D21]], %[[C4]] : index + # CHECK: %[[D23:.+]] = arith.divsi %[[THREAD_ID_X]], %[[C64]] : index + 
# CHECK: %[[D24:.+]] = arith.muli %[[D23]], %[[C16]] : index + # CHECK: %[[D25:.+]] = arith.muli %[[WORKGROUP_ID_0]], %[[C32]] : index + # CHECK: %[[D26:.+]] = arith.addi %[[D25]], %[[D24]] : index + # CHECK: %[[D27:.+]] = arith.addi %[[D26]], %[[D22]] : index + # CHECK: %[[D28:.+]] = arith.muli %[[THREAD_ID_Y]], %[[C16]] : index + # CHECK: %[[D29:.+]] = arith.muli %[[WORKGROUP_ID_1]], %[[C32]] : index + # CHECK: %[[D30:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C16]] : index + # CHECK: %[[D31:.+]] = arith.addi %[[D30]], %[[D29]] : index + # CHECK: %[[D32:.+]] = arith.addi %[[D31]], %[[D28]] : index + # CHECK: vector.store %[[D18]], %[[D19]][%[[WORKGROUP_ID_2]], %[[D27]], %[[D32]]] : memref<12x64x128xf32, + # CHECK-SAME: strided<[8192, 128, 1], offset: ?>>, vector<1xf32> + # CHECK: %[[D33:.+]] = vector.extract_strided_slice %[[D17]] {offsets = [1], sizes = [1], strides = [1]} : + # CHECK-SAME: vector<4xf32> to vector<1xf32> + # CHECK: %[[D34:.+]] = arith.addi %[[D27]], %[[C1]] : index + # CHECK: vector.store %[[D33]], %[[D19]][%[[WORKGROUP_ID_2]], %[[D34]], %[[D32]]] : memref<12x64x128xf32, + # CHECK-SAME: strided<[8192, 128, 1], offset: ?>>, vector<1xf32> + # CHECK: %[[D35:.+]] = vector.extract_strided_slice %[[D17]] {offsets = [2], sizes = [1], strides = [1]} : + # CHECK-SAME: vector<4xf32> to vector<1xf32> + # CHECK: %[[D36:.+]] = arith.addi %[[D27]], %[[C2]] : index + # CHECK: vector.store %[[D35]], %[[D19]][%[[WORKGROUP_ID_2]], %[[D36]], %[[D32]]] : memref<12x64x128xf32, + # CHECK-SAME: strided<[8192, 128, 1], offset: ?>>, vector<1xf32> + # CHECK: %[[D37:.+]] = vector.extract_strided_slice %[[D17]] {offsets = [3], sizes = [1], strides = [1]} : + # CHECK-SAME: vector<4xf32> to vector<1xf32> + # CHECK: %[[D38:.+]] = arith.addi %[[D27]], %[[C3]] : index + # CHECK: vector.store %[[D37]], %[[D19]][%[[WORKGROUP_ID_2]], %[[D38]], %[[D32]]] : memref<12x64x128xf32, + # CHECK-SAME: strided<[8192, 128, 1], offset: ?>>, vector<1xf32> + # CHECK: return + + @run_test def test_gemm_pipelined(): constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] diff --git a/lit_tests/kernel/wave/expansion.py b/lit_tests/kernel/wave/expansion.py index efcdd582..369cf24e 100644 --- a/lit_tests/kernel/wave/expansion.py +++ b/lit_tests/kernel/wave/expansion.py @@ -9,17 +9,20 @@ from iree.turbine.kernel._support.indexing import IndexingContext from iree.turbine.kernel.lang.global_symbols import * from iree.turbine.kernel.wave.utils import run_test, print_trace +from iree.turbine.kernel.wave.constraints import MMAType import sympy # Input sizes M = tkl.sym.M N = tkl.sym.N K = tkl.sym.K +B = tkl.sym.B # Workgroup tile sizes BLOCK_M = tkl.sym.BLOCK_M BLOCK_N = tkl.sym.BLOCK_N BLOCK_K = tkl.sym.BLOCK_K +BLOCK_B = tkl.sym.BLOCK_B # Address space (for GPU, shared(1) or global(0)) ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE @@ -351,6 +354,190 @@ def test_gemm(): # CHECK-NEXT: ----- +@tkw.wave_trace_only() +def batched_gemm( + a: tkl.Memory[B, M, K, ADDRESS_SPACE, tkl.f16], + b: tkl.Memory[B, N, K, ADDRESS_SPACE, tkl.f16], + c: tkl.Memory[B, M, N, ADDRESS_SPACE, tkl.f32], +): + c_reg = tkl.Register[B, M, N, tkl.f32](0.0) + + @tkw.reduction(K, init_args=[c_reg]) + def repeat(acc: tkl.Register[B, M, N, tkl.f32]) -> tkl.Register[B, M, N, tkl.f32]: + a_reg = tkw.read(a, elements_per_thread=4) + b_reg = tkw.read(b, elements_per_thread=4) + acc = tkw.mma(a_reg, b_reg, acc) + return acc + + tkw.write(repeat, c, elements_per_thread=4) + + +@run_test +def test_batched_gemm(): + constraints: list[tkw.Constraint] = 
[tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)] + constraints += [tkw.TilingConstraint(K, BLOCK_K, ARGK)] + constraints += [tkw.WaveConstraint(M, BLOCK_M / 2, THREAD_0 / 64)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / 2, THREAD_1)] + # Since the MMA shapes only cover M, N and K, we specify the canonical shape for + # the batch dimension in the vector_shapes. + constraints += [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(2, 2, 1), + vector_shapes={B: 0}, + mma_type=MMAType.F32_16x16x16_F16, + ) + ] + with tk.gen.TestLaunchContext( + { + BLOCK_M: 64, + BLOCK_N: 64, + BLOCK_K: 32, + BLOCK_B: 1, + } + ): + graph = batched_gemm() + IndexingContext.current().finalize() + expand_graph(graph, constraints) + print_trace(graph) + # Root graph: + # CHECK: %a + # CHECK-NEXT: %b + # CHECK-NEXT: %c + # CHECK-NEXT: %register_0_0_0 + # CHECK-NEXT: %register_1_1_0 + # CHECK-NEXT: %register_1_0_0 + # CHECK-NEXT: %register_0_1_0 + # CHECK-NEXT: %reduction + # CHECK-SAME: %register_0_0_0, %register_0_1_0, %register_1_0_0, %register_1_1_0 + # CHECK-NEXT: %getresult_1_1_0 + # CHECK-NEXT: %getresult_1_0_0 + # CHECK-NEXT: %getresult_0_1_0 + # CHECK-NEXT: %getresult_0_0_0 + # CHECK-NEXT: %write_0_0_0 + # CHECK-SAME: (%getresult_0_0_0, %c, 4, None) + # CHECK-NEXT: %write_1_1_0 + # CHECK-SAME: (%getresult_1_1_0, %c, 4, None) + # CHECK-NEXT: %write_1_0_0 + # CHECK-SAME: (%getresult_1_0_0, %c, 4, None) + # CHECK-NEXT: %write_0_1_0 + # CHECK-SAME: (%getresult_0_1_0, %c, 4, None) + # CHECK-NEXT: return None + + # Custom format: + # CHECK-NEXT: placeholder(_name=a + # CHECK-NEXT: placeholder(_name=b + # CHECK-NEXT: placeholder(_name=c + # CHECK-NEXT: register(shape=(B, M, N), dtype=f32, value=0.0, index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}) + # CHECK-NEXT: register(shape=(B, M, N), dtype=f32, value=0.0, index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16}) + # CHECK-NEXT: register(shape=(B, M, N), dtype=f32, value=0.0, index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}) + # CHECK-NEXT: register(shape=(B, M, N), dtype=f32, value=0.0, index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16}) + # CHECK-NEXT: reduction(axis=K, init_args=[register_0_0_0, register_0_1_0, register_1_0_0, register_1_1_0], subgraph_name=region_0, implicit_captures=[a, b]) + # CHECK-NEXT: get_result(value=reduction, res_idx=3) + # CHECK-NEXT: get_result(value=reduction, res_idx=2) + # CHECK-NEXT: get_result(value=reduction, res_idx=1) + # CHECK-NEXT: get_result(value=reduction, res_idx=0) + # CHECK-NEXT: write(register_=getresult_0_0_0 + # CHECK-SAME: index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)} + # CHECK-NEXT: write(register_=getresult_1_1_0 + # CHECK-SAME: index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16} + # CHECK-NEXT: write(register_=getresult_1_0_0 + # CHECK-SAME: index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)} + # CHECK-NEXT: write(register_=getresult_0_1_0 + # 
CHECK-SAME: index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16} + # CHECK-NEXT: output + + # Reduction subgraph: + + # CHECK: %acc_0_0_0 + # CHECK-NEXT: %acc_0_1_0 + # CHECK-NEXT: %acc_1_0_0 + # CHECK-NEXT: %acc_1_1_0 + + # CHECK-NEXT: %a + # CHECK-NEXT: %read_0_0_0 + # CHECK-SAME: (%a, 4, None, None) + # CHECK-NEXT: %read_0_0_1 + # CHECK-SAME: (%a, 4, None, None) + # CHECK-NEXT: %read_1_0_0 + # CHECK-SAME: (%a, 4, None, None) + # CHECK-NEXT: %read_1_0_1 + # CHECK-SAME: (%a, 4, None, None) + + # CHECK-NEXT: %b + # CHECK-NEXT: %read_0_0_0 + # CHECK-SAME: (%b, 4, None, None) + # CHECK-NEXT: %read_0_0_1 + # CHECK-SAME: (%b, 4, None, None) + # CHECK-NEXT: %read_0_1_0 + # CHECK-SAME: (%b, 4, None, None) + # CHECK-NEXT: %read_0_1_1 + # CHECK-SAME: (%b, 4, None, None) + + # CHECK-NEXT: %mma_0_0_0 + # CHECK-SAME: (%read_0_0_0, %read_0_0_0, %acc_0_0_0) + # CHECK-NEXT: %mma_0_0_1 + # CHECK-SAME: (%read_0_0_1, %read_0_0_1, %mma_0_0_0) + # CHECK-NEXT: %mma_1_1_0 + # CHECK-SAME: (%read_1_0_0, %read_0_1_0, %acc_1_1_0) + # CHECK-NEXT: %mma_1_1_1 + # CHECK-SAME: (%read_1_0_1, %read_0_1_1, %mma_1_1_0) + # CHECK-NEXT: %mma_1_0_0 + # CHECK-SAME: (%read_1_0_0, %read_0_0_0, %acc_1_0_0) + # CHECK-NEXT: %mma_1_0_1 + # CHECK-SAME: (%read_1_0_1, %read_0_0_1, %mma_1_0_0) + # CHECK-NEXT: %mma_0_1_0 + # CHECK-SAME: (%read_0_0_0, %read_0_1_0, %acc_0_1_0) + # CHECK-NEXT: %mma_0_1_1 + # CHECK-SAME: (%read_0_0_1, %read_0_1_1, %mma_0_1_0) + # CHECK-NEXT: return [mma_0_0_1, mma_0_1_1, mma_1_0_1, mma_1_1_1] + + # Custom format: + # CHECK-NEXT: placeholder(_name=acc_0_0_0 + # CHECK-NEXT: placeholder(_name=acc_0_1_0 + # CHECK-NEXT: placeholder(_name=acc_1_0_0 + # CHECK-NEXT: placeholder(_name=acc_1_1_0 + # CHECK-NEXT: placeholder(_name=a + # CHECK-NEXT: read(memory=a, elements_per_thread=4, index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) + # CHECK-NEXT: read(memory=a, elements_per_thread=4, index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) + # CHECK-NEXT: read(memory=a, elements_per_thread=4, index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) + # CHECK-NEXT: read(memory=a, elements_per_thread=4, index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) + # CHECK-NEXT: placeholder(_name=b + # CHECK-NEXT: read(memory=b, elements_per_thread=4, index={B: $WG2*BLOCK_B, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16), K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) + # CHECK-NEXT: read(memory=b, elements_per_thread=4, index={B: $WG2*BLOCK_B, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16), K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) + # CHECK-NEXT: read(memory=b, elements_per_thread=4, index={B: $WG2*BLOCK_B, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) + # CHECK-NEXT: read(memory=b, elements_per_thread=4, index={B: $WG2*BLOCK_B, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) + # CHECK-NEXT: mma(lhs=read_0_0_0 (index = {B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) + # CHECK-SAME: rhs=read_0_0_0 (index = {B: 
$WG2*BLOCK_B, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16), K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) + # CHECK-SAME: acc=acc_0_0_0 (index = {B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)})) + # CHECK-NEXT: mma(lhs=read_0_0_1 (index = {B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) + # CHECK-SAME: rhs=read_0_0_1 (index = {B: $WG2*BLOCK_B, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16), K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) + # CHECK-SAME: acc=mma_0_0_0 (index = {B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)})) + # CHECK-NEXT: mma(lhs=read_1_0_0 (index = {B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) + # CHECK-SAME: rhs=read_0_1_0 (index = {B: $WG2*BLOCK_B, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) + # CHECK-SAME: acc=acc_1_1_0 (index = {B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16})) + # CHECK-NEXT: mma(lhs=read_1_0_1 (index = {B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) + # CHECK-SAME: rhs=read_0_1_1 (index = {B: $WG2*BLOCK_B, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) + # CHECK-SAME: acc=mma_1_1_0 (index = {B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16})) + # CHECK-NEXT: mma(lhs=read_1_0_0 (index = {B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) + # CHECK-SAME: rhs=read_0_0_0 (index = {B: $WG2*BLOCK_B, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16), K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) + # CHECK-SAME: acc=acc_1_0_0 (index = {B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)})) + # CHECK-NEXT: mma(lhs=read_1_0_1 (index = {B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) + # CHECK-SAME: rhs=read_0_0_1 (index = {B: $WG2*BLOCK_B, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16), K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) + # CHECK-SAME: acc=mma_1_0_0 (index = {B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)})) + # CHECK-NEXT: mma(lhs=read_0_0_0 (index = {B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) + # CHECK-SAME: rhs=read_0_1_0 (index = {B: $WG2*BLOCK_B, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) + # CHECK-SAME: acc=acc_0_1_0 (index = {B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16})) + # CHECK-NEXT: mma(lhs=read_0_0_1 (index = {B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 
16), K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) + # CHECK-SAME: rhs=read_0_1_1 (index = {B: $WG2*BLOCK_B, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) + # CHECK-SAME: acc=mma_0_1_0 (index = {B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16})) + # CHECK-NEXT: output(return_vals=([mma_0_0_1, mma_0_1_1, mma_1_0_1, mma_1_1_1],)) + + # CHECK-NEXT: ----- + + @run_test def test_gemm_reduction_expansion_only(): # Note: This does not implement an actual gemm computation but reuses the diff --git a/tests/kernel/wave/wave_gemm_test.py b/tests/kernel/wave/wave_gemm_test.py index 78d26e3b..0f1a356b 100644 --- a/tests/kernel/wave/wave_gemm_test.py +++ b/tests/kernel/wave/wave_gemm_test.py @@ -24,15 +24,17 @@ # Whether to use scheduling group barriers (needs LLVM fix). enable_scheduling_barriers = int(os.environ.get("WAVE_USE_SCHED_BARRIERS", 0)) -default_test_shapes = [(1024, 5120, 640), (2048, 10240, 1280), (4096, 20480, 2560)] - +# Add test shapes for validation and performance testing. perf_test = lambda *a: pytest.param(*a, marks=pytest.mark.perf_only) - -default_test_shapes += [ - perf_test((1024, 5120, 640)), - perf_test((2048, 10240, 1280)), - perf_test((4096, 20480, 2560)), +default_test_shapes = {} +default_test_shapes["test_gemm"] = [ + (1024, 5120, 640), + (2048, 10240, 1280), + (4096, 20480, 2560), ] +default_test_shapes["test_gemm"] += [perf_test(x) for x in default_test_shapes] +default_test_shapes["test_batched_gemm"] = [(8, 256, 128, 192), (32, 1024, 512, 768)] + user_specified_test_shapes = "" @@ -46,7 +48,7 @@ def get_test_shapes(test_name: str) -> list[tuple[int]]: if test_name in user_specified_test_shapes: return user_specified_test_shapes[test_name] - return default_test_shapes + return default_test_shapes[test_name] @require_e2e @@ -165,3 +167,118 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: iree_ref = torch.zeros(shape[0], shape[1], dtype=torch.float32) generate_iree_ref("mmt", [a, b], [iree_ref], config, run_bench=run_bench) assert_close(c, iree_ref) + + +@require_e2e +@pytest.mark.parametrize("shape", get_test_shapes("test_batched_gemm")) +@pytest.mark.parametrize("enable_scheduling", [False, True]) +def testBatchedGemm(shape: tuple[int], enable_scheduling: bool, request): + run_bench = request.config.getoption("--runperf") + dump_perf = request.config.getoption("--dump-perf-files-path") + # Input sizes + B = tkl.sym.B + M = tkl.sym.M + N = tkl.sym.N + K = tkl.sym.K + # Workgroup tile sizes + BLOCK_B = tkl.sym.BLOCK_B + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + BLOCK_K = tkl.sym.BLOCK_K + # Address space (for GPU, shared(1) or global(0)) + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + # Other hyperparameters + LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD + STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD + + # Expose user-constraints + constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)] + constraints += [tkw.TilingConstraint(K, BLOCK_K)] + constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] + + constraints += [ + tkw.HardwareConstraint( + threads_per_wave=64, waves_per_block=(2, 2, 1), vector_shapes={B: 0} + ) + ] + + @tkw.wave(constraints) + def 
batched_gemm( + a: tkl.Memory[B, M, K, ADDRESS_SPACE, tkl.f16], + b: tkl.Memory[B, N, K, ADDRESS_SPACE, tkl.f16], + c: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32], + ): + c_reg = tkl.Register[B, M, N, tkl.f32](0.0) + + @tkw.reduction(K, init_args=[c_reg]) + def repeat( + acc: tkl.Register[B, M, N, tkl.f32] + ) -> tkl.Register[B, M, N, tkl.f32]: + a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD) + b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD) + acc = tkw.mma(a_reg, b_reg, acc) + return acc + + tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD) + + hyperparams = { + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + LOAD_ELEMS_PER_THREAD: 4, + STORE_ELEMS_PER_THREAD: 4, + BLOCK_B: 1, + BLOCK_M: 64, + BLOCK_N: 64, + BLOCK_K: 32, + B: shape[0], + M: shape[1], + N: shape[2], + K: shape[3], + READ_SHARED_DELAY: 1, + WRITE_SHARED_DELAY: 1, + READ_GLOBAL_DELAY: 2, + WRITE_GLOBAL_DELAY: 2, + MMA_DELAY: 1, + SHARED_MEMORY_UNITS: 4, + GLOBAL_MEMORY_UNITS: 4, + MMA_UNITS: 4, + } + config = {"backend": "rocm", "device": "hip", "target": "gfx942"} + if run_bench: + config["benchmark_batch_size"] = 10 + config["benchmark_repetitions"] = 3 + if dump_perf is not None: + perf_filename = request.node.name + ".json" + config["benchmark_results_file"] = os.path.join( + dump_perf, "tk_" + perf_filename + ) + + with tk.gen.TestLaunchContext( + hyperparams, + canonicalize=True, + run=True, + run_bench=run_bench, + run_config=config, + schedule=enable_scheduling, + use_scheduling_barriers=enable_scheduling_barriers, + ): + a = torch.randn(shape[0], shape[1], shape[3], dtype=torch.float16) + b = torch.randn(shape[0], shape[2], shape[3], dtype=torch.float16) + c = torch.zeros(shape[0], shape[1], shape[2], dtype=torch.float32) + mb = batched_gemm(a, b, c) + + if test_dump_generated_mlir: + filename = f"wave_batched_gemm_{'x'.join(map(str, shape))}.mlir" + with open(filename, "w") as f: + f.write(mb.module_op.get_asm()) + + if run_bench: + if dump_perf is not None: + config["benchmark_results_file"] = os.path.join( + dump_perf, "iree_" + perf_filename + ) + iree_ref = torch.zeros(shape[0], shape[1], shape[2], dtype=torch.float32) + generate_iree_ref("bmmt", [a, b], [iree_ref], config, run_bench=run_bench) + assert_close(c, iree_ref)
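As a quick sanity check of the new batched reference path outside pytest, the following sketch exercises get_mmt_asm with batch=True; the tensor shapes follow the first test_batched_gemm tuple (8, 256, 128, 192) with the f16 inputs and f32 accumulator used above, and the assertions only check substrings emitted by the template in this patch.

# Hypothetical standalone check of the "bmmt" reference assembly added above.
from iree.turbine.kernel.wave.iree_utils import get_mmt_asm

lhs_type = "8x256x192xf16"  # B x M x K
rhs_type = "8x128x192xf16"  # B x N x K
acc_type = "8x256x128xf32"  # B x M x N

asm = get_mmt_asm(lhs_type, rhs_type, acc_type, batch=True)
assert "func.func @bmmt" in asm
assert "linalg.batch_matmul_transpose_b" in asm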