From 2fd510f45598a50adc558891f7e806e7db59d60c Mon Sep 17 00:00:00 2001
From: harsh-nod
Date: Mon, 7 Oct 2024 18:16:23 -0700
Subject: [PATCH] Add support for scheduling barriers (#185)

This PR adds ops for scheduling barriers and scheduling group barriers.
Scheduling group barriers are placed after every cycle of the pipelined
kernel when use_scheduling_barriers is enabled.

Signed-off-by: Harsh Menon
---
 iree/turbine/kernel/ops/wave_ops.py | 27 +++++++++
 iree/turbine/kernel/wave/codegen.py | 36 +++++++++++
 .../kernel/wave/scheduling/__init__.py | 1 +
 .../wave/scheduling/loop_reconstruction.py | 27 +++++----
 .../kernel/wave/scheduling/resources.py | 60 ++++++++++++++++++-
 .../kernel/wave/scheduling/schedule.py | 16 ++++-
 iree/turbine/kernel/wave/wave.py | 12 +++-
 lit_tests/kernel/wave/codegen.py | 7 ++-
 lit_tests/kernel/wave/scheduling.py | 12 +++-
 tests/kernel/wave/wave_gemm_test.py | 3 +
 10 files changed, 181 insertions(+), 20 deletions(-)

diff --git a/iree/turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py
index 47c6a5ec..731c8678 100644
--- a/iree/turbine/kernel/ops/wave_ops.py
+++ b/iree/turbine/kernel/ops/wave_ops.py
@@ -27,6 +27,7 @@
 
 if TYPE_CHECKING:
     from ..wave.constraints import Constraint
+    from ..wave.scheduling.resources import Operation
 
 T = TypeVar("T", bound=Type[Any])
 AccT = TypeVar("AccT")
@@ -723,6 +724,32 @@ def is_barrier_between(self, src: fx.Node, dst: fx.Node) -> bool:
         return found_dst
 
 
+@define_op("scheduling_barrier")
+@dataclass
+class SchedulingBarrier(CustomOp):
+    """
+    Represents a scheduling barrier in the graph.
+    Takes in a list of operations that are allowed to cross
+    the barrier.
+    """
+
+    operations: list[Operation]
+
+
+@define_op("scheduling_group_barrier")
+@dataclass
+class SchedulingGroupBarrier(CustomOp):
+    """
+    Represents a scheduling group barrier in the graph.
+    The scheduling group barrier defines scheduling groups.
+    Each scheduling group contains different instructions in a specific order.
+    The sync_id identifies scheduling groups that need to be aware of each other.
+    """
+
+    instructions: dict[Operation, int]
+    sync_id: int
+
+
 @define_op("register")
 @dataclass
 class NewRegister(CustomOp):
diff --git a/iree/turbine/kernel/wave/codegen.py b/iree/turbine/kernel/wave/codegen.py
index 233a571d..bc6e54ed 100644
--- a/iree/turbine/kernel/wave/codegen.py
+++ b/iree/turbine/kernel/wave/codegen.py
@@ -39,6 +39,7 @@
     stream_d,
     scf_d,
     vector_d,
+    llvm_d,
 )
 
 from iree.turbine.aot.support.ir_utils import _is_float_type, _is_integer_like_type
@@ -60,6 +61,8 @@
     shared_memory_barrier,
     extract_slice,
     CustomOp,
+    scheduling_barrier,
+    scheduling_group_barrier,
 )
 from ..lang.wave_types import IndexMapping, IndexSymbol
 from ..compiler.base import CodegenError, ValidationError, NDEBUG
@@ -84,6 +87,7 @@
 
 # Indexing imports.
from .._support.indexing import IndexingContext, IndexExpr, IndexSequence +from .scheduling.resources import get_scheduling_mask @dataclass @@ -1071,6 +1075,38 @@ def handle_shared_memory_barrier(emitter: WaveEmitter, node: fx.Node): amdgpu_d.lds_barrier() +@handle_op(scheduling_barrier) +def handle_scheduling_barrier(emitter: WaveEmitter, node: fx.Node): + try: + operations = node.args[0] + except ValueError as e: + raise ValidationError("Malformed arguments") from e + mask = 0 + for operation in operations: + mask |= get_scheduling_mask(operation) + + mask = arith_d.constant(IntegerType.get_signless(32), mask) + llvm_d.call_intrinsic(None, "llvm.amdgcn.sched.barrier", [mask]) + + +@handle_op(scheduling_group_barrier) +def handle_scheduling_group_barrier(emitter: WaveEmitter, node: fx.Node): + try: + instructions, sync_id = node.args + except ValueError as e: + raise ValidationError("Malformed arguments") from e + sync_id = arith_d.constant(IntegerType.get_signless(32), sync_id) + for instruction, counts in instructions.items(): + mask = get_scheduling_mask(instruction) + if mask is None: + continue + mask = arith_d.constant(IntegerType.get_signless(32), mask) + counts = arith_d.constant(IntegerType.get_signless(32), counts) + llvm_d.call_intrinsic( + None, "llvm.amdgcn.sched.group.barrier", [mask, counts, sync_id] + ) + + ############################################################################### # Slicing ops ############################################################################### diff --git a/iree/turbine/kernel/wave/scheduling/__init__.py b/iree/turbine/kernel/wave/scheduling/__init__.py index 19879f4b..65b7ec28 100644 --- a/iree/turbine/kernel/wave/scheduling/__init__.py +++ b/iree/turbine/kernel/wave/scheduling/__init__.py @@ -5,3 +5,4 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from .schedule import * +from .resources import * diff --git a/iree/turbine/kernel/wave/scheduling/loop_reconstruction.py b/iree/turbine/kernel/wave/scheduling/loop_reconstruction.py index 52f205b1..db456ec9 100644 --- a/iree/turbine/kernel/wave/scheduling/loop_reconstruction.py +++ b/iree/turbine/kernel/wave/scheduling/loop_reconstruction.py @@ -1,32 +1,24 @@ -from ..constraints import Constraint, TilingConstraint +from ..constraints import Constraint from ..._support.indexing import IndexSymbol from ..._support.tracing import CapturedTrace from ...ops.wave_ops import ( Reduction, IterArg, Placeholder, - Allocate, Output, - Write, GetResult, get_custom, + SchedulingGroupBarrier, ) from .modulo_scheduling import ModuloScheduler from ..utils import ( - graph_copy, - erase_graph, get_induction_variable, replace_uses_in, ) -from ..utils import subs_idxc import torch.fx as fx -import math -from collections import deque +from collections import deque, defaultdict from ..visualization import visualize_mapped_graphs, visualize_graph from ....support.logging import get_logger -from ...lang.global_symbols import SHARED_ADDRESS_SPACE -import random -from typing import Optional from .loop_reconstruction_utils import ( ArgumentContext, create_fill_stage_schedule, @@ -35,6 +27,7 @@ partition_graph_by_stage, interleave_instructions, ) +from .resources import get_custom_operation_type from enum import Enum logger = get_logger("turbine.wave.scheduling.loop_reconstruction") @@ -56,6 +49,7 @@ def add_nodes_by_schedule( current_induction_variables: list[int], rotating_registers: dict[fx.Node, list[fx.Node]], pipelining_stage: PipelineStage = PipelineStage.KERNEL, + use_scheduling_barriers: bool = 
False, ): """ Interleave the instructions in the partitioned graph by stage @@ -63,8 +57,6 @@ def add_nodes_by_schedule( per stage starting at the provided start times and indices. """ fill_or_drain = pipelining_stage in [PipelineStage.PROLOGUE, PipelineStage.EPILOGUE] - fill = pipelining_stage == PipelineStage.PROLOGUE - drain = pipelining_stage == PipelineStage.EPILOGUE for cycle in range(initiation_interval): logger.debug(f"Cycle: {cycle}") @@ -79,6 +71,7 @@ def add_nodes_by_schedule( interleaved_instructions.append((iteration, stage, node)) interleave_instructions(interleaved_instructions) + instructions = defaultdict(int) for iteration, stage, node in interleaved_instructions: logger.debug(f"Node: {node}, Stage: {stage}, Iteration: {iteration}") custom_node = get_custom(node) @@ -97,6 +90,7 @@ def add_nodes_by_schedule( else x ), ) + instructions[get_custom_operation_type(new_node)] += 1 # Update the argument context. arg_context[(iteration, stage, node)] = new_node.fx_node logger.debug( @@ -140,6 +134,9 @@ def add_nodes_by_schedule( arg_context.result_to_init_arg[node], new_node.fx_node ) + if pipelining_stage == PipelineStage.KERNEL and use_scheduling_barriers: + SchedulingGroupBarrier(instructions, 0).add_to_graph(reduction_graph) + def push_placeholders( implicit_captures: list[fx.Node], @@ -306,6 +303,7 @@ def construct_kernel( new_induction_variables: list[int], node_map: dict[fx.Node, fx.Node], visualize: bool = False, + use_scheduling_barriers: bool = False, ) -> tuple[Reduction, fx.Graph]: """ Construct the kernel of the pipelined loop. @@ -367,6 +365,7 @@ def construct_kernel( new_induction_variables, new_rotating_registers, PipelineStage.KERNEL, + use_scheduling_barriers, ) # Create output node (last node in the graph). @@ -491,6 +490,7 @@ def construct_pipelined_loop( node_map: dict[fx.Node, fx.Node], max_induction_variable: int, visualize: bool = False, + use_scheduling_barriers: bool = False, ) -> fx.Node: """ Given a graph annotated with scheduling parameters, construct a pipelined loop @@ -524,6 +524,7 @@ def construct_pipelined_loop( [induction_variable + i for i in range(scheduler.num_stages)], node_map, visualize, + use_scheduling_barriers, ) trace.add_subgraph( get_custom(pipelined_reduction).subgraph_name, pipelined_reduction_graph diff --git a/iree/turbine/kernel/wave/scheduling/resources.py b/iree/turbine/kernel/wave/scheduling/resources.py index 13e80687..346833f2 100644 --- a/iree/turbine/kernel/wave/scheduling/resources.py +++ b/iree/turbine/kernel/wave/scheduling/resources.py @@ -6,7 +6,15 @@ from ...lang.global_symbols import * from ..utils import subs_idxc -from ...ops.wave_ops import Read, Write, MMA, IterArg, Output, get_custom +from ...ops.wave_ops import ( + Read, + Write, + MMA, + IterArg, + Output, + get_custom, + CustomOp, +) import torch.fx as fx from enum import Enum import numpy as np @@ -24,6 +32,9 @@ class Operation(Enum): READ_GLOBAL = "read_global" WRITE_GLOBAL = "write_global" MMA = "mma" + ALU = "alu" + VALU = "valu" + SALU = "salu" NOOP = "noop" @@ -49,6 +60,29 @@ class Operation(Enum): } +def get_custom_operation_type(custom: CustomOp) -> Operation: + if isinstance(custom, Read): + return ( + Operation.READ_GLOBAL + if custom.memory_type.address_space == GLOBAL_ADDRESS_SPACE + else Operation.READ_SHARED + ) + elif isinstance(custom, Write): + return ( + Operation.WRITE_GLOBAL + if custom.memory_type.address_space == GLOBAL_ADDRESS_SPACE + else Operation.WRITE_SHARED + ) + elif isinstance(custom, MMA): + return Operation.MMA + elif 
isinstance(custom, IterArg):
+        return Operation.NOOP
+    elif isinstance(custom, Output):
+        return Operation.NOOP
+    else:
+        return None
+
+
 def annotate_resource_usage(
     graph: fx.Graph,
 ) -> tuple[set[fx.Node], list[fx.Node], fx.Node]:
@@ -79,3 +113,27 @@
         else:
             ignore_nodes.add(node)
     return ignore_nodes, iter_args, output
+
+
+def get_scheduling_mask(operation: Operation) -> int:
+    """
+    Returns the scheduling mask for the given operation.
+    """
+    match operation:
+        case Operation.READ_GLOBAL:
+            return int("0x20", 0)
+        case Operation.WRITE_GLOBAL:
+            return int("0x40", 0)
+        case Operation.READ_SHARED:
+            return int("0x100", 0)
+        case Operation.WRITE_SHARED:
+            return int("0x200", 0)
+        case Operation.MMA:
+            return int("0x8", 0)
+        case Operation.ALU:
+            return int("0x1", 0)
+        case Operation.VALU:
+            return int("0x2", 0)
+        case Operation.SALU:
+            return int("0x4", 0)
+    return None
diff --git a/iree/turbine/kernel/wave/scheduling/schedule.py b/iree/turbine/kernel/wave/scheduling/schedule.py
index e2b3a88e..9cf6eb19 100644
--- a/iree/turbine/kernel/wave/scheduling/schedule.py
+++ b/iree/turbine/kernel/wave/scheduling/schedule.py
@@ -24,7 +24,10 @@ def visualize_scheduling_graph(edges: list[Edge]):
 
 
 def schedule_reduction(
-    reduction: Reduction, trace: CapturedTrace, constraints: list[Constraint]
+    reduction: Reduction,
+    trace: CapturedTrace,
+    constraints: list[Constraint],
+    use_scheduling_barriers: bool = False,
 ):
     """
     Clones the reduction graph and does the following:
@@ -93,13 +96,18 @@
         node_map,
         max_induction_variable,
         visualize,
+        use_scheduling_barriers,
     )
 
     # Update new reduction count.
     new_reduction.count = max_induction_variable - (scheduler.num_stages - 1)
 
 
-def schedule_graph(trace: CapturedTrace, constraints: list[Constraint]):
+def schedule_graph(
+    trace: CapturedTrace,
+    constraints: list[Constraint],
+    use_scheduling_barriers: bool = False,
+):
     """
     Given a graph, pipelines the reductions in the graph.
     """
@@ -112,4 +120,6 @@ def is_reduction(node: fx.Node) -> bool:
         return
 
     for reduction_node in reduction_nodes:
-        schedule_reduction(get_custom(reduction_node), trace, constraints)
+        schedule_reduction(
+            get_custom(reduction_node), trace, constraints, use_scheduling_barriers
+        )
diff --git a/iree/turbine/kernel/wave/wave.py b/iree/turbine/kernel/wave/wave.py
index 21485ed1..177d9867 100644
--- a/iree/turbine/kernel/wave/wave.py
+++ b/iree/turbine/kernel/wave/wave.py
@@ -243,8 +243,18 @@ def _trace_and_get_kernel_signature(
         decompose_reduce_ops(graph, self.constraints, idxc.subs)
 
         # Schedule the reduction ops.
+        # Scheduling should always be used with use_scheduling_barriers=True,
+        # as this is the only way we can ensure that LLVM enforces our desired schedule.
+        # However, due to a bug in LLVM, you will need to patch your local LLVM repo
+        # with the following commit: https://github.com/kerbowa/llvm-project/commit/ee52732cddae42deed2e3387a83b20ec05860b4e
+        # Specifically:
+        # git remote add sched_fixes https://github.com/kerbowa/llvm-project.git
+        # git fetch sched_fixes
+        # git cherry-pick ee52732cddae42deed2e3387a83b20ec05860b4e
+        # [Manually resolve conflicts consistent with the commit]
         if kwargs.get("schedule", False):
-            schedule_graph(graph, self.constraints)
+            use_scheduling_barriers = kwargs.get("use_scheduling_barriers", False)
+            schedule_graph(graph, self.constraints, use_scheduling_barriers)
 
         # Add shared memory barriers.
add_shared_memory_barriers(graph) diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py index 65db9a78..b9fc81ed 100644 --- a/lit_tests/kernel/wave/codegen.py +++ b/lit_tests/kernel/wave/codegen.py @@ -730,6 +730,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: }, canonicalize=True, schedule=True, + use_scheduling_barriers=True, ): a = torch.randn(64, 32, dtype=torch.float16) b = torch.randn(128, 32, dtype=torch.float16) @@ -747,10 +748,14 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # CHECK-COUNT-1: scf.for # CHECK-COUNT-4: amdgpu.mfma # CHECK-COUNT-1: amdgpu.lds_barrier - # CHECK-COUNT-10: vector.load + # CHECK-COUNT-6: vector.load + # CHECK-COUNT-3: llvm.call_intrinsic "llvm.amdgcn.sched.group.barrier" + # CHECK-COUNT-4: vector.load + # CHECK-COUNT-1: llvm.call_intrinsic "llvm.amdgcn.sched.group.barrier" # CHECK-COUNT-4: amdgpu.mfma # CHECK-COUNT-1: amdgpu.lds_barrier # CHECK-COUNT-2: vector.store + # CHECK-COUNT-2: llvm.call_intrinsic "llvm.amdgcn.sched.group.barrier" # CHECK-COUNT-1: scf.yield # CHECK-COUNT-4: amdgpu.mfma # CHECK-COUNT-1: amdgpu.lds_barrier diff --git a/lit_tests/kernel/wave/scheduling.py b/lit_tests/kernel/wave/scheduling.py index aefad516..55de465e 100644 --- a/lit_tests/kernel/wave/scheduling.py +++ b/lit_tests/kernel/wave/scheduling.py @@ -93,7 +93,7 @@ def test_gemm_pipelined(): expand_graph(trace, constraints) minimize_global_loads(trace, constraints) apply_shared_memory_indexing_corrections(trace, constraints) - schedule_graph(trace, constraints) + schedule_graph(trace, constraints, True) print_subgraph(trace, "pipelined_reduction", False) # CHECK: %acc_0_0_0 @@ -113,28 +113,38 @@ def test_gemm_pipelined(): # CHECK-NEXT: %read_shared_0_0_1 # CHECK-NEXT: %read_4 # CHECK-NEXT: %read_5 + # CHECK-NEXT: %scheduling_group_barrier + # CHECK-SAME: ({Operation.MMA: 1, Operation.READ_SHARED: 2, Operation.READ_GLOBAL: 2}, 0) # CHECK-NEXT: %read_shared_1_0_0 # CHECK-NEXT: %read_shared_1_0_1 # CHECK-NEXT: %mma_0_0_0 # CHECK-SAME: (%read_shared_0_0_0, %read_shared_0_0_1, %acc_0_0_0) # CHECK-NEXT: %mma_0_1_0 # CHECK-SAME: (%read_shared_0_0_0, %rotating_reg_3, %acc_0_1_0) + # CHECK-NEXT: %scheduling_group_barrier + # CHECK-SAME: ({Operation.READ_SHARED: 2, Operation.MMA: 2}, 0) # CHECK-NEXT: %mma_0_0_1 # CHECK-SAME: (%rotating_reg_0, %rotating_reg_2, %mma_0_0_0) # CHECK-NEXT: %mma_1_0_0 # CHECK-SAME: (%read_shared_1_0_0, %read_shared_0_0_1, %acc_1_0_0) # CHECK-NEXT: %write_2 # CHECK-NEXT: %write_3 + # CHECK-NEXT: %scheduling_group_barrier + # CHECK-SAME: ({Operation.MMA: 2, Operation.WRITE_SHARED: 2}, 0) # CHECK-NEXT: %mma_1_0_1 # CHECK-SAME: (%read_shared_1_0_1, %rotating_reg_2, %mma_1_0_0) # CHECK-NEXT: %mma_0_1_1 # CHECK-SAME: (%rotating_reg_0, %rotating_reg_5, %mma_0_1_0) # CHECK-NEXT: %read_shared_0_1_0 # CHECK-NEXT: %read_shared_0_1_1 + # CHECK-NEXT: %scheduling_group_barrier + # CHECK-SAME: ({Operation.MMA: 2, Operation.READ_SHARED: 2}, 0) # CHECK-NEXT: %mma_1_1_0 # CHECK-SAME: (%read_shared_1_0_0, %rotating_reg_3, %mma_1_1_1) # CHECK-NEXT: %read_shared_0_0_2 # CHECK-NEXT: %read_shared_0_0_3 + # CHECK-NEXT: %scheduling_group_barrier + # CHECK-SAME: ({Operation.MMA: 1, Operation.READ_SHARED: 2}, 0) # CHECK-NEXT: [mma_0_0_1, mma_0_1_1, mma_1_0_1, mma_1_1_1, read_shared_0_0_2, read_shared_1_0_1, read_shared_0_0_3, read_shared_0_1_0, rotating_reg_5, read_shared_0_1_1, mma_1_1_0] print_subgraph(trace, "region_1", False) diff --git a/tests/kernel/wave/wave_gemm_test.py 
b/tests/kernel/wave/wave_gemm_test.py index e9487a64..78d26e3b 100644 --- a/tests/kernel/wave/wave_gemm_test.py +++ b/tests/kernel/wave/wave_gemm_test.py @@ -21,6 +21,8 @@ require_e2e = pytest.mark.skipif(not _run_e2e, reason="e2e tests are disabled") # Whether to dump the generated MLIR module. test_dump_generated_mlir = int(os.environ.get("WAVE_DUMP_MLIR", 0)) +# Whether to use scheduling group barriers (needs LLVM fix). +enable_scheduling_barriers = int(os.environ.get("WAVE_USE_SCHED_BARRIERS", 0)) default_test_shapes = [(1024, 5120, 640), (2048, 10240, 1280), (4096, 20480, 2560)] @@ -143,6 +145,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: run_bench=run_bench, run_config=config, schedule=enable_scheduling, + use_scheduling_barriers=enable_scheduling_barriers, ): a = torch.randn(shape[0], shape[2], dtype=torch.float16) b = torch.randn(shape[1], shape[2], dtype=torch.float16)
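
Note (not part of the patch): the standalone sketch below illustrates how the per-cycle
instruction counts gathered in add_nodes_by_schedule map onto the arguments of the
llvm.amdgcn.sched.group.barrier intrinsic emitted by handle_scheduling_group_barrier.
The mask values mirror get_scheduling_mask() in resources.py, and the example cycle
({Operation.MMA: 1, Operation.READ_SHARED: 2, Operation.READ_GLOBAL: 2}, sync_id 0)
is taken from the scheduling lit test above. Function and variable names here are
illustrative only and do not exist in the patch.

# sched_group_barrier_sketch.py
from enum import Enum


class Operation(Enum):
    READ_GLOBAL = "read_global"
    WRITE_GLOBAL = "write_global"
    READ_SHARED = "read_shared"
    WRITE_SHARED = "write_shared"
    MMA = "mma"
    NOOP = "noop"


# Same mask values as get_scheduling_mask() in resources.py.
SCHED_MASKS = {
    Operation.READ_GLOBAL: 0x20,
    Operation.WRITE_GLOBAL: 0x40,
    Operation.READ_SHARED: 0x100,
    Operation.WRITE_SHARED: 0x200,
    Operation.MMA: 0x8,
}


def lower_group_barrier(
    instructions: dict[Operation, int], sync_id: int
) -> list[tuple[int, int, int]]:
    """Return the (mask, count, sync_id) triples that would be emitted as
    llvm.amdgcn.sched.group.barrier calls for one cycle of the schedule."""
    calls = []
    for op, count in instructions.items():
        mask = SCHED_MASKS.get(op)
        if mask is None:  # NOOP has no scheduling mask and is skipped.
            continue
        calls.append((mask, count, sync_id))
    return calls


if __name__ == "__main__":
    # One cycle of the pipelined GEMM kernel from the scheduling lit test.
    cycle = {Operation.MMA: 1, Operation.READ_SHARED: 2, Operation.READ_GLOBAL: 2}
    for mask, count, sync in lower_group_barrier(cycle, sync_id=0):
        print(f"llvm.amdgcn.sched.group.barrier({hex(mask)}, {count}, {sync})")

Callers opt in by passing use_scheduling_barriers=True alongside schedule=True in the
kernel launch kwargs; the flag is threaded through schedule_graph into the pipelined
kernel construction, as exercised by the GEMM test above.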