diff --git a/iree/turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py index de953373..e68ad86c 100644 --- a/iree/turbine/kernel/ops/wave_ops.py +++ b/iree/turbine/kernel/ops/wave_ops.py @@ -498,6 +498,18 @@ def scheduling_parameters(self, value: Any): raise ValueError("Scheduling parameters must be a dict") self.fx_node.scheduling_parameters = value + @property + def expanded_dims(self) -> dict[IndexSymbol, int]: + if hasattr(self.fx_node, "expanded_dims"): + return self.fx_node.expanded_dims + return None + + @expanded_dims.setter + def expanded_dims(self, value: dict[IndexSymbol, int]): + if not isinstance(value, dict): + raise ValueError("Expanded dims must be a dict") + self.fx_node.expanded_dims = value + def post_expansion(self, constraints: list["Constraint"]) -> None: """ Hook for post-expansion operations. This is called after the arguments @@ -1060,7 +1072,11 @@ def index(self) -> dict[IndexSymbol, IndexSequence]: custom = get_custom(self.value) if custom.index is None: return None - assert isinstance(custom.index, list) and self.res_idx < len(custom.index) + if not isinstance(custom, Reduction): + return custom.index + assert isinstance(custom.index, list) and self.res_idx < len( + custom.indexing_dims + ) return custom.index[self.res_idx] @index.setter diff --git a/iree/turbine/kernel/wave/expansion.py b/iree/turbine/kernel/wave/expansion.py index e124ac50..f83b8fad 100644 --- a/iree/turbine/kernel/wave/expansion.py +++ b/iree/turbine/kernel/wave/expansion.py @@ -109,137 +109,6 @@ def is_expandable(arg: Any) -> bool: return isinstance(arg, CustomOp) -def is_contiguous_dim( - dim: IndexSymbol, symbolic_shape: list[IndexSymbol], vector_shapes: list[int] -) -> bool: - """ - Checks if the given dimension is stored contiguously in memory. This happens if - the dimension is the last one in the symbolic shape or all dimensions after it - are unit dimensions. - """ - is_innermost_dim = dim == symbolic_shape[-1] - dim_index = symbolic_shape.index(dim) - static_shape = [vector_shapes[dim] for dim in symbolic_shape] - all_unit_dims = all(dim == 1 for dim in static_shape[dim_index + 1 :]) - return is_innermost_dim or all_unit_dims - - -def compute_stride( - symbolic_shape: tuple[IndexSymbol, ...], - vector_shapes: dict[IndexSymbol, int], - target_dim: IndexSymbol, -) -> int: - """ - Compute the stride for a given dimension based on the vector shapes. - The stride is the product of the vector shapes of all dimensions that are - not the given dimension. - """ - stride = 1 - for dim in reversed(symbolic_shape): - if dim == target_dim: - break - assert dim in vector_shapes, f"Dimension {dim} not found in vector shapes" - stride *= vector_shapes[dim] - - try: - stride = int(stride) - except Exception as e: - logger.error(e) - return stride - - -def set_node_index( - constraints: Sequence[Constraint], - mma_index: dict[IndexSymbol, int], - mma_slices: dict[IndexSymbol, list[fx.Node]], - dim_tile_size: dict[IndexSymbol, int], - custom: CustomOp, - dim_scaling: dict[IndexSymbol, int], -): - """ - Set the index of the node based on the user constraints. In certain - operators (like read, write), there is only a single index associated - with the node (the index to read from, the index to write to). But for - other operators like mma, each operand reads from a different index. - - Rather than maintain operand specific indices for operators, we maintain - dimension specific indices for each operator. 
So for an mma operator that - has a signature of (MxK, NxK) -> MxN, we maintain only 3 mappings for - dimensions M, N and K, but allow each mapping to be piecewise conditioned - on the operand. - """ - hardware_constraint = [get_hardware_constraint(constraints)] - workgroup_constraints = { - c.dim: c for c in constraints if isinstance(c, WorkgroupConstraint) - } - other_constraints = [ - c for c in constraints if not isinstance(c, HardwareConstraint) - ] - # Apply hardware constraint first since it dictates the stride and size. - sorted_constraints = hardware_constraint + other_constraints - - index = {} - # The semantics of elements_per_thread are that it represents the number of - # elements that are loaded contiguously from memory. - elements_per_thread = getattr(custom, "elements_per_thread", None) - - for dim in custom.indexing_dims: - index_seq = None - for constraint in sorted_constraints: - if isinstance(constraint, HardwareConstraint): - inputs = None - if dim in mma_index: - inputs = (mma_index[dim], elements_per_thread, None) - else: - # Assumes vector shapes are associated with workgroup dims. - assert ( - dim in workgroup_constraints - ), f"Dimension {dim} not found in workgroup constraints" - assert ( - dim in constraint.vector_shapes - ), f"Dimension {dim} not found in vector shapes" - if constraint.vector_shapes[dim] == 0: - continue - inputs = ( - workgroup_constraints[dim].workgroup_dim, - ( - 1 - if not is_contiguous_dim( - dim, - custom.indexing_dims, - constraint.vector_shapes, - ) - else elements_per_thread - ), - compute_stride( - custom.indexing_dims, constraint.vector_shapes, dim - ), - ) - if elements_per_thread is None: - # Here we end up with a situation where there will be no thread level - # dependence in the dimensional index. - # TODO: Evaluate if this is a valid case. 
- continue - index_seq = constraint.apply(dim, *inputs, dim in mma_index) - if dim in mma_index: - index_seq = specialize_index_sequence(index_seq, mma_slices, custom) - - elif constraint.dim == dim: - if index_seq is None: - index_seq = constraint.apply() - else: - index_seq.start += constraint.apply().start - - if index_seq is not None: - if dim in dim_scaling and dim in dim_tile_size: - index_seq.start += dim_scaling[dim] * dim_tile_size[dim] - index.update({dim: index_seq}) - else: - index.update({dim: IndexSequence(0, 1, 1)}) - - setattr(custom.fx_node, "index", index) - - def expand_graph( trace: CapturedTrace, constraints_or_scaling: Sequence[Constraint] | dict[IndexSymbol, int], @@ -250,15 +119,8 @@ def expand_graph( """ if isinstance(constraints_or_scaling, dict): dim_scaling = constraints_or_scaling - node_index_setter = lambda *args: None else: - mma_index, mma_slices = get_mma_dimensional_mapping( - trace, get_hardware_constraint(constraints_or_scaling) - ) - dim_scaling, dim_tile_size = get_dim_scaling(constraints_or_scaling, mma_index) - node_index_setter = partial( - set_node_index, constraints_or_scaling, mma_index, mma_slices, dim_tile_size - ) + dim_scaling = get_dim_scaling(constraints_or_scaling) # Start from the back and expand in the corresponding indexing dimensions of a node # Then proceed to the operands @@ -294,7 +156,6 @@ def expand_graph( trace, expand_dims, dim_scaling, - node_index_setter, expansion_context, ) @@ -304,7 +165,6 @@ def _expand_node( trace: CapturedTrace, dim_query: dict[IndexSymbol, int], dim_scaling: dict[IndexSymbol, int], - node_index_setter: Callable[[CustomOp, dict[IndexSymbol, int]], None], context: ExpandedNodeMap, res_idx: int = 0, ) -> CustomOp: @@ -318,7 +178,6 @@ def _expand_node( trace, dim_query, dim_scaling, - node_index_setter, context, res_idx, ).fx_node @@ -329,9 +188,7 @@ def _expand_node( logger.debug(f"Already expanded node: {node} in {dim_query}") return context[(node, get_indexed_dims(dim_query, node), res_idx)] elif isinstance(node, Reduction): - return _expand_reduction( - node, trace, dim_query, dim_scaling, node_index_setter, context, res_idx - ) + return _expand_reduction(node, trace, dim_query, dim_scaling, context, res_idx) elif isinstance(node, Getitem): res_idx = node.res_idx elif isinstance(node, GetResult) and not isinstance(node, Getitem): @@ -368,11 +225,8 @@ def _expand_node( if isinstance(node, IterArg): _expand_node.last_expanded_iter_arg = new_node.fx_node - new_node.fx_node.expanded_dims = restricted_dims + new_node.expanded_dims = restricted_dims new_node.fx_node.name = get_expanded_name(node, restricted_dims) - node_index_setter(new_node, restricted_dims) - - constraints = node_index_setter.args[0] # Proceed with expansion of the arguments for i, arg in node.node_args.items(): @@ -382,14 +236,11 @@ def _expand_node( trace, restricted_dims, dim_scaling, - node_index_setter, context, res_idx, ) new_node.update_arg(i, new_arg) - new_node.post_expansion(constraints) - context[(node, get_indexed_dims(restricted_dims, node), res_idx)] = new_node return new_node @@ -399,7 +250,6 @@ def _expand_reduction( trace: CapturedTrace, dim_query: dict[IndexSymbol, int], dim_scaling: dict[IndexSymbol, int], - node_index_setter: Callable[[CustomOp, dict[IndexSymbol, int]], None], context: ExpandedNodeMap, res_idx: int = 0, ) -> CustomOp: @@ -442,7 +292,12 @@ def _expand_reduction( # Proceed with expansion inside the reduction new_output_args.append( _expand_node( - arg, trace, dims, dim_scaling, node_index_setter, context, 
res_idx + arg, + trace, + dims, + dim_scaling, + context, + res_idx, ) ) @@ -454,7 +309,6 @@ def _expand_reduction( trace, dims, dim_scaling, - node_index_setter, context, res_idx, ) @@ -470,7 +324,6 @@ def _expand_reduction( output, trace, dim_scaling, - node_index_setter, context, res_idx, ) @@ -502,12 +355,9 @@ def _contains(elem, container): return elem in container -def get_dim_scaling( - constraints: Sequence[Constraint], mma_indices: dict[IndexSymbol, int] -) -> tuple[dict[IndexSymbol, int]]: +def get_dim_scaling(constraints: Sequence[Constraint]) -> dict[IndexSymbol, int]: """Get the number of expansions for the dimensions based on the constraints.""" dim_scaling: dict[IndexSymbol, int] = {} - dim_tile_size: dict[IndexSymbol, int] = {} hardware_constraints: list[HardwareConstraint] = [ constraint for constraint in constraints @@ -523,20 +373,12 @@ def get_dim_scaling( ): hw_cons = hardware_constraints[0] tile_size = idxc.get_static_value(constraint.tile_size) - if not ( - _contains(constraint.dim, mma_indices) - or _contains(constraint.dim, hw_cons.vector_shapes) - ): + if not _contains(constraint.dim, hw_cons.vector_shapes): raise ValueError( f"Attempting to determine vector shape for unmapped dimension {constraint.dim}" ) - if mma_indices and constraint.dim in mma_indices: - vector_size = hardware_constraints[0].mma_matrix_shapes[ - mma_indices[constraint.dim] - ] - else: - vector_size = hardware_constraints[0].vector_shapes[constraint.dim] + vector_size = hw_cons.vector_shapes[constraint.dim] # No dim scaling for dims with 0 vector size. if vector_size == 0: @@ -544,9 +386,7 @@ def get_dim_scaling( wave_count = 1 if isinstance(constraint, WorkgroupConstraint): - wave_count = hardware_constraints[0].waves_per_block[ - constraint.workgroup_dim - ] + wave_count = hw_cons.waves_per_block[constraint.workgroup_dim] if tile_size is None or wave_count is None or vector_size is None: raise ValueError( "Tile size, wave count and vector size must be statically known" @@ -559,8 +399,7 @@ def get_dim_scaling( "Tile size must be divisible by wave count and vector size" ) dim_scaling[constraint.dim] = tile_size // wave_count // vector_size - dim_tile_size[constraint.dim] = vector_size - return (dim_scaling, dim_tile_size) + return dim_scaling def _handle_reduction_dim( @@ -568,7 +407,6 @@ def _handle_reduction_dim( output: Output, trace: CapturedTrace, dim_scaling: dict[IndexSymbol, int], - node_index_setter: Callable[[CustomOp, dict[IndexSymbol, int]], None], context: ExpandedNodeMap, res_idx: int, ): @@ -589,7 +427,7 @@ def _handle_reduction_dim( if isinstance(user, Output): continue - dims = user.fx_node.expanded_dims + dims = dict(user.fx_node.expanded_dims) dims[reduction.axis] = scale_idx # Temporarily replace the loop carried arg here to avoid # duplicated expansion. Otherwise we have the following situation: @@ -611,7 +449,6 @@ def _handle_reduction_dim( trace, dims, dim_scaling, - node_index_setter, context, res_idx, ) diff --git a/iree/turbine/kernel/wave/index_sequence_analysis.py b/iree/turbine/kernel/wave/index_sequence_analysis.py index 305ea5a4..91fdf530 100644 --- a/iree/turbine/kernel/wave/index_sequence_analysis.py +++ b/iree/turbine/kernel/wave/index_sequence_analysis.py @@ -4,19 +4,25 @@ # See https://llvm.org/LICENSE.txt for license information. 
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from ..ops.wave_ops import Write, ExtractSlice, get_custom
-from .constraints import Constraint, HardwareConstraint
+from ..ops.wave_ops import Write, ExtractSlice, get_custom, Reduction
+from .constraints import Constraint, HardwareConstraint, WorkgroupConstraint
 from .._support.tracing import CapturedTrace, IndexingContext
 from .._support.indexing import IndexSymbol, IndexSequence
 from ..lang.global_symbols import *
 from .utils import (
     simplify_index,
     get_mma_dimensional_mapping,
-    get_hardware_vector_size,
+    get_hardware_constraint,
     subs_idxc,
+    specialize_index_sequence,
 )
 import torch.fx as fx
 import numpy as np
+from functools import partial
+from typing import Sequence
+from ...support.logging import get_logger
+
+logger = get_logger("turbine.wave.index_sequence_analysis")
 
 
 def get_vector_shape(
@@ -94,3 +100,159 @@ def has_strided_access(node: fx.Node) -> bool:
             for j, dim in enumerate(custom.register_type.symbolic_shape)
         }
         custom.graph.erase_node(operator)
+
+
+def set_node_indices(trace: CapturedTrace, constraints: list[Constraint]):
+    mma_index, mma_slices = get_mma_dimensional_mapping(
+        trace, get_hardware_constraint(constraints)
+    )
+    trace.walk(partial(set_node_index, constraints, mma_index, mma_slices))
+
+
+def compute_stride(
+    symbolic_shape: tuple[IndexSymbol, ...],
+    vector_shapes: dict[IndexSymbol, int],
+    target_dim: IndexSymbol,
+) -> int:
+    """
+    Compute the stride for a given dimension based on the vector shapes.
+    The stride is the product of the vector shapes of all dimensions that
+    come after the given dimension, i.e. the faster-varying dimensions.
+    """
+    stride = 1
+    for dim in reversed(symbolic_shape):
+        if dim == target_dim:
+            break
+        assert dim in vector_shapes, f"Dimension {dim} not found in vector shapes"
+        stride *= vector_shapes[dim]
+
+    try:
+        stride = int(stride)
+    except Exception as e:
+        logger.error(e)
+    return stride
+
+
+def is_contiguous_dim(
+    dim: IndexSymbol, symbolic_shape: list[IndexSymbol], vector_shapes: dict[IndexSymbol, int]
+) -> bool:
+    """
+    Checks if the given dimension is stored contiguously in memory. This happens if
+    the dimension is the last one in the symbolic shape or all dimensions after it
+    are unit dimensions.
+    """
+    is_innermost_dim = dim == symbolic_shape[-1]
+    dim_index = symbolic_shape.index(dim)
+    static_shape = [vector_shapes[dim] for dim in symbolic_shape]
+    all_unit_dims = all(dim == 1 for dim in static_shape[dim_index + 1 :])
+    return is_innermost_dim or all_unit_dims
+
+
+def set_node_index(
+    constraints: Sequence[Constraint],
+    mma_index: dict[IndexSymbol, int],
+    mma_slices: dict[IndexSymbol, list[fx.Node]],
+    node: fx.Node,
+):
+    """
+    Set the index of the node based on the user constraints. In certain
+    operators (like read, write), there is only a single index associated
+    with the node (the index to read from, the index to write to). But for
+    other operators like mma, each operand reads from a different index.
+
+    Rather than maintain operand-specific indices for operators, we maintain
+    dimension-specific indices for each operator. So for an mma operator that
+    has a signature of (MxK, NxK) -> MxN, we maintain only 3 mappings for
+    dimensions M, N and K, but allow each mapping to be piecewise conditioned
+    on the operand.
+ """ + hardware_constraint = [get_hardware_constraint(constraints)] + workgroup_constraints = { + c.dim: c for c in constraints if isinstance(c, WorkgroupConstraint) + } + other_constraints = [ + c for c in constraints if not isinstance(c, HardwareConstraint) + ] + # Apply hardware constraint first since it dictates the stride and size. + sorted_constraints = hardware_constraint + other_constraints + + index = {} + # The semantics of elements_per_thread are that it represents the number of + # elements that are loaded contiguously from memory. + custom = get_custom(node) + + elements_per_thread = getattr(custom, "elements_per_thread", None) + + if isinstance(custom, Reduction): + return + + for dim in custom.indexing_dims: + index_seq = None + for constraint in sorted_constraints: + if isinstance(constraint, HardwareConstraint): + inputs = None + if dim in mma_index: + inputs = (mma_index[dim], elements_per_thread, None) + else: + # Assumes vector shapes are associated with workgroup dims. + if dim not in workgroup_constraints: + continue + assert ( + dim in constraint.vector_shapes + ), f"Dimension {dim} not found in vector shapes" + if constraint.vector_shapes[dim] == 0: + continue + inputs = ( + workgroup_constraints[dim].workgroup_dim, + ( + 1 + if not is_contiguous_dim( + dim, + custom.indexing_dims, + constraint.vector_shapes, + ) + else elements_per_thread + ), + compute_stride( + custom.indexing_dims, constraint.vector_shapes, dim + ), + ) + if elements_per_thread is None: + # Here we end up with a situation where there will be no thread level + # dependence in the dimensional index. + # TODO: Evaluate if this is a valid case. + continue + index_seq = constraint.apply(dim, *inputs, dim in mma_index) + if dim in mma_index: + index_seq = specialize_index_sequence(index_seq, mma_slices, custom) + + elif constraint.dim == dim: + if index_seq is None: + index_seq = constraint.apply() + else: + index_seq.start += constraint.apply().start + + if index_seq is not None: + index.update({dim: index_seq}) + else: + index.update({dim: IndexSequence(0, 1, 1)}) + + custom.index = index + + +def set_post_expansion_indices(trace: CapturedTrace, constraints: list[Constraint]): + """ + Add offsets to the indices based on the expanded dims. + """ + hw_cons = get_hardware_constraint(constraints) + + def apply_offset(node: fx.Node): + custom = get_custom(node) + if custom.expanded_dims is None: + return False + for dim, scale in custom.expanded_dims.items(): + if dim in custom.index: + custom.index[dim].start += scale * hw_cons.vector_shapes[dim] + return False + + trace.walk(apply_offset) diff --git a/iree/turbine/kernel/wave/wave.py b/iree/turbine/kernel/wave/wave.py index 177d9867..0492a185 100644 --- a/iree/turbine/kernel/wave/wave.py +++ b/iree/turbine/kernel/wave/wave.py @@ -37,7 +37,11 @@ from ..lang.global_symbols import * from ..ops import wave_ops from ..ops.wave_ops import Reduction, CustomOp, get_custom -from .index_sequence_analysis import partition_strided_operators +from .index_sequence_analysis import ( + partition_strided_operators, + set_node_indices, + set_post_expansion_indices, +) from .shared_memory_indexing import apply_shared_memory_indexing_corrections from .thread_shape_analysis import determine_thread_shapes from .scheduling.schedule import schedule_graph @@ -221,9 +225,15 @@ def _trace_and_get_kernel_signature( promote_placeholders(graph, self.constraints) hoist_allocs(graph) + # Set indices. 
+ set_node_indices(graph, self.constraints) + # Expansion expand_graph(graph, self.constraints) + # Set post expansion indices. + set_post_expansion_indices(graph, self.constraints) + # Clean up chains of GetResults remove_chained_getresult(graph) diff --git a/lit_tests/kernel/wave/barriers.py b/lit_tests/kernel/wave/barriers.py index 14eb2e60..4b302e87 100644 --- a/lit_tests/kernel/wave/barriers.py +++ b/lit_tests/kernel/wave/barriers.py @@ -6,6 +6,10 @@ import iree.turbine.kernel as tk import iree.turbine.kernel.lang as tkl import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.wave.index_sequence_analysis import ( + set_node_indices, + set_post_expansion_indices, +) from iree.turbine.kernel.wave.promotion import promote_node, promote_placeholders from iree.turbine.kernel.wave.barriers import add_shared_memory_barriers from iree.turbine.kernel.wave.hoisting import hoist_allocs @@ -83,7 +87,9 @@ def test_read_write_equal_sizes(): read_node = get_read_nodes(graph)[0] IndexingContext.current().finalize() promote_node(read_node, SHARED_ADDRESS_SPACE, constraints) + set_node_indices(trace, constraints) expand_graph(trace, constraints) + set_post_expansion_indices(trace, constraints) tweak_index(graph) add_shared_memory_barriers(trace) print_trace(trace, False) @@ -169,7 +175,9 @@ def test_gemm(): for read_node in read_nodes: promote_node(read_node, SHARED_ADDRESS_SPACE, constraints) hoist_allocs(trace) + set_node_indices(trace, constraints) expand_graph(trace, constraints) + set_post_expansion_indices(trace, constraints) tweak_index(graph) add_shared_memory_barriers(trace) print_trace(trace, False) diff --git a/lit_tests/kernel/wave/expansion.py b/lit_tests/kernel/wave/expansion.py index 369cf24e..10a11931 100644 --- a/lit_tests/kernel/wave/expansion.py +++ b/lit_tests/kernel/wave/expansion.py @@ -6,6 +6,10 @@ import iree.turbine.kernel.lang as tkl import iree.turbine.kernel.wave as tkw from iree.turbine.kernel.wave.expansion import expand_graph +from iree.turbine.kernel.wave.index_sequence_analysis import ( + set_node_indices, + set_post_expansion_indices, +) from iree.turbine.kernel._support.indexing import IndexingContext from iree.turbine.kernel.lang.global_symbols import * from iree.turbine.kernel.wave.utils import run_test, print_trace @@ -62,7 +66,9 @@ def test_read_write_equal_sizes(): ): graph = read_write_same_size() IndexingContext.current().finalize() + set_node_indices(graph, constraints) expand_graph(graph, constraints) + set_post_expansion_indices(graph, constraints) print_trace(graph) # CHECK: %a # CHECK-NEXT: %c @@ -141,7 +147,9 @@ def test_read_write(): ): graph = read_write_different_dims() IndexingContext.current().finalize() + set_node_indices(graph, constraints) expand_graph(graph, constraints) + set_post_expansion_indices(graph, constraints) print_trace(graph) # CHECK: %a # CHECK-NEXT: %c @@ -216,7 +224,9 @@ def test_gemm(): ): graph = gemm() IndexingContext.current().finalize() + set_node_indices(graph, constraints) expand_graph(graph, constraints) + set_post_expansion_indices(graph, constraints) print_trace(graph) # Root graph: # CHECK: %a @@ -250,7 +260,7 @@ def test_gemm(): # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16}) # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}) # CHECK-NEXT: register(shape=(M, N), 
dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16}) - # CHECK-NEXT: reduction(axis=K, init_args=[register_0_0_0, register_0_1_0, register_1_0_0, register_1_1_0], subgraph_name=region_0, implicit_captures=[a, b]) + # CHECK-NEXT: reduction(axis=K, init_args=[register_0_0_0, register_0_1_0, register_1_0_0, register_1_1_0], subgraph_name=region_0, implicit_captures=[a, b] # CHECK-NEXT: get_result(value=reduction, res_idx=3) # CHECK-NEXT: get_result(value=reduction, res_idx=2) # CHECK-NEXT: get_result(value=reduction, res_idx=1) @@ -400,7 +410,9 @@ def test_batched_gemm(): ): graph = batched_gemm() IndexingContext.current().finalize() + set_node_indices(graph, constraints) expand_graph(graph, constraints) + set_post_expansion_indices(graph, constraints) print_trace(graph) # Root graph: # CHECK: %a @@ -434,7 +446,7 @@ def test_batched_gemm(): # CHECK-NEXT: register(shape=(B, M, N), dtype=f32, value=0.0, index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16}) # CHECK-NEXT: register(shape=(B, M, N), dtype=f32, value=0.0, index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}) # CHECK-NEXT: register(shape=(B, M, N), dtype=f32, value=0.0, index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16}) - # CHECK-NEXT: reduction(axis=K, init_args=[register_0_0_0, register_0_1_0, register_1_0_0, register_1_1_0], subgraph_name=region_0, implicit_captures=[a, b]) + # CHECK-NEXT: reduction(axis=K, init_args=[register_0_0_0, register_0_1_0, register_1_0_0, register_1_1_0], subgraph_name=region_0, implicit_captures=[a, b] # CHECK-NEXT: get_result(value=reduction, res_idx=3) # CHECK-NEXT: get_result(value=reduction, res_idx=2) # CHECK-NEXT: get_result(value=reduction, res_idx=1) @@ -559,7 +571,9 @@ def test_gemm_reduction_expansion_only(): ): graph = gemm() IndexingContext.current().finalize() + set_node_indices(graph, constraints) expand_graph(graph, constraints) + set_post_expansion_indices(graph, constraints) print_trace(graph) # Root graph: # CHECK: %a @@ -660,7 +674,9 @@ def py_arithmetic_different_dims(): ): graph = py_arithmetic_different_dims() IndexingContext.current().finalize() + set_node_indices(graph, constraints) expand_graph(graph, constraints) + set_post_expansion_indices(graph, constraints) print_trace(graph) # CHECK: %a # CHECK-NEXT: %c diff --git a/lit_tests/kernel/wave/index_sequence_analysis.py b/lit_tests/kernel/wave/index_sequence_analysis.py index f0149b70..16dc03cb 100644 --- a/lit_tests/kernel/wave/index_sequence_analysis.py +++ b/lit_tests/kernel/wave/index_sequence_analysis.py @@ -19,6 +19,8 @@ ) from iree.turbine.kernel.wave.index_sequence_analysis import ( partition_strided_operators, + set_node_indices, + set_post_expansion_indices, ) @@ -84,7 +86,9 @@ def test_gemm(): IndexingContext.current().finalize() promote_placeholders(trace, constraints) hoist_allocs(trace) + set_node_indices(trace, constraints) expand_graph(trace, constraints) + set_post_expansion_indices(trace, constraints) minimize_global_loads(trace, constraints) apply_shared_memory_indexing_corrections(trace, constraints) partition_strided_operators(trace, constraints) diff --git a/lit_tests/kernel/wave/minimize_global_loads.py b/lit_tests/kernel/wave/minimize_global_loads.py index 329a9ccf..21d15ce8 100644 --- 
a/lit_tests/kernel/wave/minimize_global_loads.py +++ b/lit_tests/kernel/wave/minimize_global_loads.py @@ -20,6 +20,10 @@ from iree.turbine.kernel.wave.shared_memory_indexing import ( apply_shared_memory_indexing_corrections, ) +from iree.turbine.kernel.wave.index_sequence_analysis import ( + set_node_indices, + set_post_expansion_indices, +) # Input sizes @@ -85,7 +89,9 @@ def test_gemm(): IndexingContext.current().finalize() promote_placeholders(trace, constraints) hoist_allocs(trace) + set_node_indices(trace, constraints) expand_graph(trace, constraints) + set_post_expansion_indices(trace, constraints) if visualize: visualize_graph(trace.get_subgraph("region_0"), "before.png") minimize_global_loads(trace, constraints) diff --git a/lit_tests/kernel/wave/scheduling.py b/lit_tests/kernel/wave/scheduling.py index 55de465e..513437fd 100644 --- a/lit_tests/kernel/wave/scheduling.py +++ b/lit_tests/kernel/wave/scheduling.py @@ -18,6 +18,10 @@ apply_shared_memory_indexing_corrections, ) from iree.turbine.kernel.wave.scheduling.schedule import schedule_graph +from iree.turbine.kernel.wave.index_sequence_analysis import ( + set_node_indices, + set_post_expansion_indices, +) # Input sizes @@ -90,7 +94,9 @@ def test_gemm_pipelined(): IndexingContext.current().finalize() promote_placeholders(trace, constraints) hoist_allocs(trace) + set_node_indices(trace, constraints) expand_graph(trace, constraints) + set_post_expansion_indices(trace, constraints) minimize_global_loads(trace, constraints) apply_shared_memory_indexing_corrections(trace, constraints) schedule_graph(trace, constraints, True)
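
A note on the new `expanded_dims` property in wave_ops.py: it is pure bookkeeping. Expansion stamps each duplicated op with the unroll offsets it was created at, and later passes (`set_post_expansion_indices`, `_handle_reduction_dim`) read them back off the `fx.Node`. A minimal stand-alone sketch of that contract, with a `SimpleNamespace` standing in for the real `torch.fx.Node`:

```python
from types import SimpleNamespace

class Op:
    """Mimic of the CustomOp property pair added in wave_ops.py."""

    def __init__(self):
        self.fx_node = SimpleNamespace()  # stand-in for torch.fx.Node

    @property
    def expanded_dims(self):
        # None means "this node was never duplicated by expansion".
        return getattr(self.fx_node, "expanded_dims", None)

    @expanded_dims.setter
    def expanded_dims(self, value):
        if not isinstance(value, dict):
            raise ValueError("Expanded dims must be a dict")
        self.fx_node.expanded_dims = value

op = Op()
assert op.expanded_dims is None
op.expanded_dims = {"M": 1, "N": 0}  # this is the (M=1, N=0) copy of the op
assert op.expanded_dims == {"M": 1, "N": 0}
```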
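`set_node_indices` is a thin driver: it computes the MMA dimensional mapping once, then walks every node with `set_node_index`, whose `node` parameter is deliberately last so `functools.partial` can pre-bind the shared analysis state. The same pattern in miniature (hypothetical `walk`/`set_index`, not the turbine API):

```python
from functools import partial

def set_index(state, node):
    # Per-node callback; `state` was bound once by the driver.
    node["index"] = state["base"] + node["offset"]

def walk(nodes, callback):
    for node in nodes:
        callback(node)

nodes = [{"offset": 0}, {"offset": 4}]
walk(nodes, partial(set_index, {"base": 16}))
assert [n["index"] for n in nodes] == [16, 20]
```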
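`compute_stride` and `is_contiguous_dim` move from expansion.py unchanged in behavior: the stride of a dimension is the product of the vector shapes of the dimensions inner to it, and a dimension counts as contiguous if nothing non-unit sits after it. A worked example with plain strings standing in for `IndexSymbol` (a sketch, not the library API):

```python
def compute_stride(symbolic_shape, vector_shapes, target_dim):
    # Product of the vector shapes of the dims after target_dim
    # (the faster-varying dims).
    stride = 1
    for dim in reversed(symbolic_shape):
        if dim == target_dim:
            break
        stride *= vector_shapes[dim]
    return stride

def is_contiguous_dim(dim, symbolic_shape, vector_shapes):
    trailing = symbolic_shape[symbolic_shape.index(dim) + 1 :]
    return dim == symbolic_shape[-1] or all(
        vector_shapes[d] == 1 for d in trailing
    )

shape, vec = ("M", "K"), {"M": 16, "K": 32}
assert compute_stride(shape, vec, "M") == 32  # M steps over one K-vector
assert compute_stride(shape, vec, "K") == 1   # K is innermost
assert is_contiguous_dim("K", shape, vec)
assert not is_contiguous_dim("M", shape, vec)
```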
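`get_dim_scaling` keeps the same scaling rule but now returns only the expansion counts; the companion `dim_tile_size` dict is gone because the start-offset adjustment moved to `set_post_expansion_indices` (see the next note). The arithmetic, with assumed sample numbers:

```python
# Assumed sample values for one workgroup-constrained dimension.
tile_size = 64    # static tile size from the constraint
wave_count = 2    # waves_per_block along this dim (defaults to 1 otherwise)
vector_size = 16  # hardware vector shape for this dim

# Mirrors the divisibility check in get_dim_scaling.
assert tile_size % wave_count == 0 and tile_size % vector_size == 0
dim_scaling = tile_size // wave_count // vector_size
assert dim_scaling == 2  # each op is expanded twice along this dim
```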
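`set_post_expansion_indices` is the other half of that move: after expansion, every copy of a node shifts the start of its index by its `expanded_dims` offset times the hardware vector shape of that dimension, mirroring the deleted `dim_scaling[dim] * dim_tile_size[dim]` adjustment that used to run inside `set_node_index`. A sketch with plain dicts in place of `IndexSequence` (assumed values):

```python
vector_shapes = {"M": 16, "N": 16, "K": 32}  # from the HardwareConstraint
index_start = {"M": 0, "N": 0, "K": 0}       # starts set before expansion
expanded_dims = {"M": 1, "N": 0}             # this node is the (1, 0) copy

# Mirrors apply_offset in set_post_expansion_indices.
for dim, offset in expanded_dims.items():
    if dim in index_start:
        index_start[dim] += offset * vector_shapes[dim]

assert index_start == {"M": 16, "N": 0, "K": 0}
```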
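The net pipeline change, visible in wave.py and repeated across all the lit tests: indexing is split out of expansion into a pass before it and a pass after it, so `expand_graph` is now purely structural. Schematically (the imports are the ones the tests add; the driver function itself is hypothetical):

```python
from iree.turbine.kernel.wave.expansion import expand_graph
from iree.turbine.kernel.wave.index_sequence_analysis import (
    set_node_indices,
    set_post_expansion_indices,
)

def lower(trace, constraints):
    set_node_indices(trace, constraints)            # 1. indices from constraints
    expand_graph(trace, constraints)                # 2. structural unrolling only
    set_post_expansion_indices(trace, constraints)  # 3. per-copy start offsets
```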