Add support for scheduling barriers (#185)
This PR adds ops for scheduling barriers and scheduling group barriers. These are placed after every cycle in the kernel.

Signed-off-by: Harsh Menon <harsh@nod-labs.com>
harsh-nod authored Oct 8, 2024
1 parent 21801d2 commit 2fd510f
Showing 10 changed files with 181 additions and 20 deletions.
27 changes: 27 additions & 0 deletions iree/turbine/kernel/ops/wave_ops.py
@@ -27,6 +27,7 @@

if TYPE_CHECKING:
from ..wave.constraints import Constraint
from ..wave.scheduling.resources import Operation

T = TypeVar("T", bound=Type[Any])
AccT = TypeVar("AccT")
@@ -723,6 +724,32 @@ def is_barrier_between(self, src: fx.Node, dst: fx.Node) -> bool:
return found_dst


@define_op("scheduling_barrier")
@dataclass
class SchedulingBarrier(CustomOp):
"""
Represents a scheduling barrier in the graph.
Takes in a list of operations that are allowed to cross
the barrier.
"""

operations: list[Operation]


@define_op("scheduling_group_barrier")
@dataclass
class SchedulingGroupBarrier(CustomOp):
"""
Represents a scheduling group barrier in the graph.
The scheduling group barrier defines scheduling groups.
Each scheduling group contains different instructions in a specific order.
The sync_id identifies scheduling groups that need to be aware of each other.
"""

instructions: dict[Operation, int]
sync_id: int


@define_op("register")
@dataclass
class NewRegister(CustomOp):
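The simpler of the two new ops just lists which instruction kinds the backend scheduler may move across it. Below is a minimal sketch of constructing one and adding it to a graph; the wrapper function name and the choice of allowed operations are illustrative assumptions, not part of this commit:

import torch.fx as fx

from iree.turbine.kernel.ops.wave_ops import SchedulingBarrier
from iree.turbine.kernel.wave.scheduling.resources import Operation


def insert_mma_only_barrier(graph: fx.Graph) -> None:
    # Illustrative helper: only MMA instructions may be reordered across
    # this point; all other instruction kinds stay on their side of it.
    SchedulingBarrier([Operation.MMA]).add_to_graph(graph)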
36 changes: 36 additions & 0 deletions iree/turbine/kernel/wave/codegen.py
@@ -39,6 +39,7 @@
stream_d,
scf_d,
vector_d,
llvm_d,
)
from iree.turbine.aot.support.ir_utils import _is_float_type, _is_integer_like_type

@@ -60,6 +61,8 @@
shared_memory_barrier,
extract_slice,
CustomOp,
scheduling_barrier,
scheduling_group_barrier,
)
from ..lang.wave_types import IndexMapping, IndexSymbol
from ..compiler.base import CodegenError, ValidationError, NDEBUG
@@ -84,6 +87,7 @@

# Indexing imports.
from .._support.indexing import IndexingContext, IndexExpr, IndexSequence
from .scheduling.resources import get_scheduling_mask


@dataclass
@@ -1071,6 +1075,38 @@ def handle_shared_memory_barrier(emitter: WaveEmitter, node: fx.Node):
amdgpu_d.lds_barrier()


@handle_op(scheduling_barrier)
def handle_scheduling_barrier(emitter: WaveEmitter, node: fx.Node):
    try:
        operations = node.args[0]
    except ValueError as e:
        raise ValidationError("Malformed arguments") from e
    mask = 0
    for operation in operations:
        mask |= get_scheduling_mask(operation)

    mask = arith_d.constant(IntegerType.get_signless(32), mask)
    llvm_d.call_intrinsic(None, "llvm.amdgcn.sched.barrier", [mask])


@handle_op(scheduling_group_barrier)
def handle_scheduling_group_barrier(emitter: WaveEmitter, node: fx.Node):
    try:
        instructions, sync_id = node.args
    except ValueError as e:
        raise ValidationError("Malformed arguments") from e
    sync_id = arith_d.constant(IntegerType.get_signless(32), sync_id)
    for instruction, counts in instructions.items():
        mask = get_scheduling_mask(instruction)
        if mask is None:
            continue
        mask = arith_d.constant(IntegerType.get_signless(32), mask)
        counts = arith_d.constant(IntegerType.get_signless(32), counts)
        llvm_d.call_intrinsic(
            None, "llvm.amdgcn.sched.group.barrier", [mask, counts, sync_id]
        )


###############################################################################
# Slicing ops
###############################################################################
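To see what handle_scheduling_group_barrier ends up emitting, the pure-Python sketch below walks a sample instruction dict and prints the (mask, size, sync_id) triple of each llvm.amdgcn.sched.group.barrier call. The sample counts are assumptions; the real handler builds arith constants and calls llvm_d.call_intrinsic instead of printing:

from iree.turbine.kernel.wave.scheduling.resources import Operation, get_scheduling_mask

# Sample per-cycle instruction mix; in practice this dict is built by
# add_nodes_by_schedule from the scheduled graph.
instructions = {Operation.READ_SHARED: 2, Operation.MMA: 4}
sync_id = 0

for instruction, count in instructions.items():
    mask = get_scheduling_mask(instruction)
    if mask is None:
        continue  # unknown instruction kinds are skipped, as in the handler
    print(f"sched.group.barrier mask={hex(mask)} size={count} sync_id={sync_id}")

# Expected output, given the masks defined in resources.py:
#   sched.group.barrier mask=0x100 size=2 sync_id=0
#   sched.group.barrier mask=0x8 size=4 sync_id=0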
1 change: 1 addition & 0 deletions iree/turbine/kernel/wave/scheduling/__init__.py
@@ -5,3 +5,4 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .schedule import *
from .resources import *
27 changes: 14 additions & 13 deletions iree/turbine/kernel/wave/scheduling/loop_reconstruction.py
@@ -1,32 +1,24 @@
from ..constraints import Constraint, TilingConstraint
from ..constraints import Constraint
from ..._support.indexing import IndexSymbol
from ..._support.tracing import CapturedTrace
from ...ops.wave_ops import (
Reduction,
IterArg,
Placeholder,
Allocate,
Output,
Write,
GetResult,
get_custom,
SchedulingGroupBarrier,
)
from .modulo_scheduling import ModuloScheduler
from ..utils import (
graph_copy,
erase_graph,
get_induction_variable,
replace_uses_in,
)
from ..utils import subs_idxc
import torch.fx as fx
import math
from collections import deque
from collections import deque, defaultdict
from ..visualization import visualize_mapped_graphs, visualize_graph
from ....support.logging import get_logger
from ...lang.global_symbols import SHARED_ADDRESS_SPACE
import random
from typing import Optional
from .loop_reconstruction_utils import (
ArgumentContext,
create_fill_stage_schedule,
@@ -35,6 +27,7 @@
partition_graph_by_stage,
interleave_instructions,
)
from .resources import get_custom_operation_type
from enum import Enum

logger = get_logger("turbine.wave.scheduling.loop_reconstruction")
@@ -56,15 +49,14 @@ def add_nodes_by_schedule(
current_induction_variables: list[int],
rotating_registers: dict[fx.Node, list[fx.Node]],
pipelining_stage: PipelineStage = PipelineStage.KERNEL,
use_scheduling_barriers: bool = False,
):
"""
Interleave the instructions in the partitioned graph by stage
for a single initiation interval, updating the argument maps
per stage starting at the provided start times and indices.
"""
fill_or_drain = pipelining_stage in [PipelineStage.PROLOGUE, PipelineStage.EPILOGUE]
fill = pipelining_stage == PipelineStage.PROLOGUE
drain = pipelining_stage == PipelineStage.EPILOGUE

for cycle in range(initiation_interval):
logger.debug(f"Cycle: {cycle}")
@@ -79,6 +71,7 @@
interleaved_instructions.append((iteration, stage, node))
interleave_instructions(interleaved_instructions)

instructions = defaultdict(int)
for iteration, stage, node in interleaved_instructions:
logger.debug(f"Node: {node}, Stage: {stage}, Iteration: {iteration}")
custom_node = get_custom(node)
@@ -97,6 +90,7 @@
else x
),
)
instructions[get_custom_operation_type(new_node)] += 1
# Update the argument context.
arg_context[(iteration, stage, node)] = new_node.fx_node
logger.debug(
@@ -140,6 +134,9 @@ def add_nodes_by_schedule(
arg_context.result_to_init_arg[node], new_node.fx_node
)

if pipelining_stage == PipelineStage.KERNEL and use_scheduling_barriers:
SchedulingGroupBarrier(instructions, 0).add_to_graph(reduction_graph)


def push_placeholders(
implicit_captures: list[fx.Node],
@@ -306,6 +303,7 @@ def construct_kernel(
new_induction_variables: list[int],
node_map: dict[fx.Node, fx.Node],
visualize: bool = False,
use_scheduling_barriers: bool = False,
) -> tuple[Reduction, fx.Graph]:
"""
Construct the kernel of the pipelined loop.
@@ -367,6 +365,7 @@
new_induction_variables,
new_rotating_registers,
PipelineStage.KERNEL,
use_scheduling_barriers,
)

# Create output node (last node in the graph).
@@ -491,6 +490,7 @@ def construct_pipelined_loop(
node_map: dict[fx.Node, fx.Node],
max_induction_variable: int,
visualize: bool = False,
use_scheduling_barriers: bool = False,
) -> fx.Node:
"""
Given a graph annotated with scheduling parameters, construct a pipelined loop
@@ -524,6 +524,7 @@
[induction_variable + i for i in range(scheduler.num_stages)],
node_map,
visualize,
use_scheduling_barriers,
)
trace.add_subgraph(
get_custom(pipelined_reduction).subgraph_name, pipelined_reduction_graph
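The heart of the loop_reconstruction change is the per-cycle bookkeeping: count the instruction kinds emitted in a cycle, then pin that mix in place with one SchedulingGroupBarrier. The sketch below is a simplified restatement under the assumption that nodes_in_cycle holds the fx nodes emitted for one cycle; the real code also threads the argument context and rotating registers:

from collections import defaultdict

import torch.fx as fx

from iree.turbine.kernel.ops.wave_ops import SchedulingGroupBarrier, get_custom
from iree.turbine.kernel.wave.scheduling.resources import get_custom_operation_type


def emit_cycle_barrier(nodes_in_cycle: list[fx.Node], reduction_graph: fx.Graph) -> None:
    # Count how many instructions of each Operation type this cycle produced.
    instructions = defaultdict(int)
    for node in nodes_in_cycle:
        instructions[get_custom_operation_type(get_custom(node))] += 1
    # One group barrier per cycle, all groups sharing sync_id 0.
    SchedulingGroupBarrier(instructions, 0).add_to_graph(reduction_graph)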
60 changes: 59 additions & 1 deletion iree/turbine/kernel/wave/scheduling/resources.py
@@ -6,7 +6,15 @@

from ...lang.global_symbols import *
from ..utils import subs_idxc
from ...ops.wave_ops import Read, Write, MMA, IterArg, Output, get_custom
from ...ops.wave_ops import (
Read,
Write,
MMA,
IterArg,
Output,
get_custom,
CustomOp,
)
import torch.fx as fx
from enum import Enum
import numpy as np
@@ -24,6 +32,9 @@ class Operation(Enum):
READ_GLOBAL = "read_global"
WRITE_GLOBAL = "write_global"
MMA = "mma"
ALU = "alu"
VALU = "valu"
SALU = "salu"
NOOP = "noop"


@@ -49,6 +60,29 @@ class Operation(Enum):
}


def get_custom_operation_type(custom: CustomOp) -> Operation:
    if isinstance(custom, Read):
        return (
            Operation.READ_GLOBAL
            if custom.memory_type.address_space == GLOBAL_ADDRESS_SPACE
            else Operation.READ_SHARED
        )
    elif isinstance(custom, Write):
        return (
            Operation.WRITE_GLOBAL
            if custom.memory_type.address_space == GLOBAL_ADDRESS_SPACE
            else Operation.WRITE_SHARED
        )
    elif isinstance(custom, MMA):
        return Operation.MMA
    elif isinstance(custom, IterArg):
        return Operation.NOOP
    elif isinstance(custom, Output):
        return Operation.NOOP
    else:
        return None


def annotate_resource_usage(
graph: fx.Graph,
) -> tuple[set[fx.Node], list[fx.Node], fx.Node]:
@@ -79,3 +113,27 @@ def annotate_resource_usage(
else:
ignore_nodes.add(node)
return ignore_nodes, iter_args, output


def get_scheduling_mask(operation: Operation) -> int:
    """
    Returns the scheduling mask for the given operation.
    """
    match operation:
        case Operation.READ_GLOBAL:
            return int("0x20", 0)
        case Operation.WRITE_GLOBAL:
            return int("0x40", 0)
        case Operation.READ_SHARED:
            return int("0x100", 0)
        case Operation.WRITE_SHARED:
            return int("0x200", 0)
        case Operation.MMA:
            return int("0x8", 0)
        case Operation.ALU:
            return int("0x1", 0)
        case Operation.VALU:
            return int("0x2", 0)
        case Operation.SALU:
            return int("0x4", 0)
    return None
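handle_scheduling_barrier in codegen.py simply ORs these per-operation masks into the single mask passed to llvm.amdgcn.sched.barrier, so the combined value can be computed directly. A small check using the values above; the particular set of allowed operations is illustrative:

from iree.turbine.kernel.wave.scheduling.resources import Operation, get_scheduling_mask

# Operations allowed to cross a scheduling barrier.
allowed = [Operation.READ_GLOBAL, Operation.MMA]

mask = 0
for op in allowed:
    mask |= get_scheduling_mask(op)

assert mask == 0x28  # 0x20 (READ_GLOBAL) | 0x8 (MMA)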
16 changes: 13 additions & 3 deletions iree/turbine/kernel/wave/scheduling/schedule.py
@@ -24,7 +24,10 @@ def visualize_scheduling_graph(edges: list[Edge]):


def schedule_reduction(
reduction: Reduction, trace: CapturedTrace, constraints: list[Constraint]
reduction: Reduction,
trace: CapturedTrace,
constraints: list[Constraint],
use_scheduling_barriers: bool = False,
):
"""
Clones the reduction graph and does the following:
@@ -93,13 +96,18 @@
node_map,
max_induction_variable,
visualize,
use_scheduling_barriers,
)

# Update new reduction count.
new_reduction.count = max_induction_variable - (scheduler.num_stages - 1)


def schedule_graph(trace: CapturedTrace, constraints: list[Constraint]):
def schedule_graph(
trace: CapturedTrace,
constraints: list[Constraint],
use_scheduling_barriers: bool = False,
):
"""
Given a graph, pipelines the reductions in the graph.
"""
@@ -112,4 +120,6 @@ def is_reduction(node: fx.Node) -> bool:
return

for reduction_node in reduction_nodes:
schedule_reduction(get_custom(reduction_node), trace, constraints)
schedule_reduction(
get_custom(reduction_node), trace, constraints, use_scheduling_barriers
)
12 changes: 11 additions & 1 deletion iree/turbine/kernel/wave/wave.py
@@ -243,8 +243,18 @@ def _trace_and_get_kernel_signature(
decompose_reduce_ops(graph, self.constraints, idxc.subs)

# Schedule the reduction ops.
# Scheduling should always be used with use_scheduling_barriers=True,
# as this is the only way we can ensure that LLVM enforces our desired schedule.
# However, due to a bug in LLVM, you will need to patch your local LLVM repo
# with the following PR: https://github.com/kerbowa/llvm-project/commit/ee52732cddae42deed2e3387a83b20ec05860b4e
# Specifically:
# git remote add sched_fixes https://github.com/kerbowa/llvm-project.git
# git fetch sched_fixes
# git cherry-pick ee52732cddae42deed2e3387a83b20ec05860b4e
# [Manually resolve conflicts consistent with the PR]
if kwargs.get("schedule", False):
schedule_graph(graph, self.constraints)
use_scheduling_barriers = kwargs.get("use_scheduling_barriers", False)
schedule_graph(graph, self.constraints, use_scheduling_barriers)

# Add shared memory barriers.
add_shared_memory_barriers(graph)
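Downstream, the feature is driven by the two kwargs read in _trace_and_get_kernel_signature. A thin wrapper showing the call into the scheduler; the wrapper itself is illustrative, while in wave.py the same call is made with values pulled from kwargs:

from iree.turbine.kernel._support.tracing import CapturedTrace
from iree.turbine.kernel.wave.constraints import Constraint
from iree.turbine.kernel.wave.scheduling import schedule_graph


def schedule_with_barriers(trace: CapturedTrace, constraints: list[Constraint]) -> None:
    # Pipelines every Reduction in the trace and, because
    # use_scheduling_barriers is set, inserts a SchedulingGroupBarrier
    # after each cycle of the pipelined kernel.
    schedule_graph(trace, constraints, use_scheduling_barriers=True)

The updated lit test below exercises exactly this path by passing schedule=True and use_scheduling_barriers=True.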
7 changes: 6 additions & 1 deletion lit_tests/kernel/wave/codegen.py
@@ -730,6 +730,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
},
canonicalize=True,
schedule=True,
use_scheduling_barriers=True,
):
a = torch.randn(64, 32, dtype=torch.float16)
b = torch.randn(128, 32, dtype=torch.float16)
@@ -747,10 +748,14 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
# CHECK-COUNT-1: scf.for
# CHECK-COUNT-4: amdgpu.mfma
# CHECK-COUNT-1: amdgpu.lds_barrier
# CHECK-COUNT-10: vector.load
# CHECK-COUNT-6: vector.load
# CHECK-COUNT-3: llvm.call_intrinsic "llvm.amdgcn.sched.group.barrier"
# CHECK-COUNT-4: vector.load
# CHECK-COUNT-1: llvm.call_intrinsic "llvm.amdgcn.sched.group.barrier"
# CHECK-COUNT-4: amdgpu.mfma
# CHECK-COUNT-1: amdgpu.lds_barrier
# CHECK-COUNT-2: vector.store
# CHECK-COUNT-2: llvm.call_intrinsic "llvm.amdgcn.sched.group.barrier"
# CHECK-COUNT-1: scf.yield
# CHECK-COUNT-4: amdgpu.mfma
# CHECK-COUNT-1: amdgpu.lds_barrier