Skip to content

Commit

Permalink
Move globals to separate file and prefix with $ (#63)
Browse files Browse the repository at this point in the history
This PR adds a dollar prefix to all private globals to avoid namespace
collisions with user-defined variables.

Signed-off-by: Harsh Menon <harsh@nod-labs.com>
  • Loading branch information
harsh-nod authored Aug 1, 2024
1 parent 4876123 commit cf66a4c
Show file tree
Hide file tree
Showing 9 changed files with 133 additions and 135 deletions.
196 changes: 98 additions & 98 deletions lit_tests/kernel/wave/expansion.py

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions lit_tests/kernel/wave/promotion.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import shark_turbine.kernel.wave as tkw
from shark_turbine.kernel.wave.promotion import promote_node
from shark_turbine.kernel.wave.hoisting import hoist_allocs
from shark_turbine.kernel.wave.address_spaces import *
from shark_turbine.kernel.lang.global_symbols import *
from shark_turbine.kernel._support.tracing import CapturedTrace
from shark_turbine.kernel._support.indexing import IndexingContext
from shark_turbine.kernel.ops.wave_ops import *
Expand Down Expand Up @@ -92,7 +92,7 @@ def test_read_write_equal_sizes():
# CHECK-NEXT: %read
# CHECK-SAME: (%a, 4)
# CHECK-NEXT: %allocate
# CHECK-SAME: ((M, N), f16, SHARED_ADDRESS_SPACE)
# CHECK-SAME: ((M, N), f16, $SHARED_ADDRESS_SPACE)
# CHECK-NEXT: %write_1
# CHECK-SAME: (%read, %allocate, 4)
# CHECK-NEXT: %read_1
Expand Down Expand Up @@ -150,9 +150,9 @@ def test_gemm():
# CHECK-NEXT: %c
# CHECK-NEXT: %register
# CHECK-NEXT: %allocate
# CHECK-SAME: ((M, K), f16, SHARED_ADDRESS_SPACE)
# CHECK-SAME: ((M, K), f16, $SHARED_ADDRESS_SPACE)
# CHECK-NEXT: %allocate_1
# CHECK-SAME: ((N, K), f16, SHARED_ADDRESS_SPACE)
# CHECK-SAME: ((N, K), f16, $SHARED_ADDRESS_SPACE)
# CHECK-NEXT: reduction
# CHECK-NEXT: %write
# CHECK-SAME: (%reduction, %c, 4)
Expand Down
21 changes: 21 additions & 0 deletions shark_turbine/kernel/lang/global_symbols.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from .._support.indexing import index_symbol

# Global symbols used throughout the code.
#
# Every symbol name is prefixed with "$" so these private globals cannot
# collide with user-defined variables of the same name (e.g. a kernel
# author's own `WORKGROUP_0` or `MMA_LHS` symbol).

# Address spaces: distinguish global (device) memory from workgroup-shared
# memory when allocating or promoting buffers.
GLOBAL_ADDRESS_SPACE = index_symbol("$GLOBAL_ADDRESS_SPACE")
SHARED_ADDRESS_SPACE = index_symbol("$SHARED_ADDRESS_SPACE")

# Distribution symbols: workgroup ids along each grid dimension.
WORKGROUP_0 = index_symbol("$WG0")
WORKGROUP_1 = index_symbol("$WG1")
WORKGROUP_2 = index_symbol("$WG2")

# Distribution symbols: thread ids within a workgroup.
THREAD_0 = index_symbol("$T0")
THREAD_1 = index_symbol("$T1")
THREAD_2 = index_symbol("$T2")

# MMA symbols: placeholders for the lhs / rhs / accumulator operands of a
# matrix-multiply-accumulate op, used when building per-operand index maps.
MMA_LHS = index_symbol("$MMA_LHS")
MMA_RHS = index_symbol("$MMA_RHS")
MMA_ACC = index_symbol("$MMA_ACC")
8 changes: 4 additions & 4 deletions shark_turbine/kernel/ops/wave_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from .._support.dtype import DataType
from .._support.regions import RegionGraph
from .base import OpDispatcher
import shark_turbine.kernel.lang as tkl
from ..lang.global_symbols import MMA_ACC, MMA_LHS, MMA_RHS

T = TypeVar("T", bound=Type[Any])
AccT = TypeVar("AccT")
Expand Down Expand Up @@ -534,17 +534,17 @@ def operand_index(

@property
def lhs_index(self) -> list[IndexSequence]:
operand_map = {tkl.sym.MMA_LHS: 1, tkl.sym.MMA_RHS: 0, tkl.sym.MMA_ACC: 0}
operand_map = {MMA_LHS: 1, MMA_RHS: 0, MMA_ACC: 0}
return self.operand_index(operand_map, self.lhs_type.symbolic_shape)

@property
def rhs_index(self) -> list[IndexSequence]:
operand_map = {tkl.sym.MMA_LHS: 0, tkl.sym.MMA_RHS: 1, tkl.sym.MMA_ACC: 0}
operand_map = {MMA_LHS: 0, MMA_RHS: 1, MMA_ACC: 0}
return self.operand_index(operand_map, self.rhs_type.symbolic_shape)

@property
def acc_index(self) -> list[IndexSequence]:
operand_map = {tkl.sym.MMA_LHS: 0, tkl.sym.MMA_RHS: 0, tkl.sym.MMA_ACC: 1}
operand_map = {MMA_LHS: 0, MMA_RHS: 0, MMA_ACC: 1}
if self.acc.type is None:
return None
return self.operand_index(operand_map, self.acc_type.symbolic_shape)
Expand Down
5 changes: 0 additions & 5 deletions shark_turbine/kernel/wave/address_spaces.py

This file was deleted.

14 changes: 4 additions & 10 deletions shark_turbine/kernel/wave/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
from dataclasses import dataclass
from enum import Enum
from typing import Optional
import shark_turbine.kernel.lang as tkl
from sympy import ceiling, Piecewise, floor

from .._support.indexing import IndexExpr, IndexSymbol
from .indexing import IndexSequence
from .distribution_symbols import *
from ..lang.global_symbols import *


class MMAType(Enum):
Expand Down Expand Up @@ -54,11 +53,6 @@ class HardwareConstraint(Constraint):
mma_type: Optional[MMAType] = MMAType.F32_16x16x16_F16
vector_shapes: Optional[dict[IndexSymbol, int]] = None

def __post_init__(self):
self.LHS = tkl.sym.MMA_LHS
self.RHS = tkl.sym.MMA_RHS
self.ACC = tkl.sym.MMA_ACC

@property
def mma_matrix_shapes(self) -> tuple[int]:
# TODO: Eventually the shapes and indices should be provided by a tool
Expand Down Expand Up @@ -93,18 +87,18 @@ def apply(self, mma_index: int) -> IndexSequence:
case MMAType.F32_16x16x16_F16:
offset = [
Piecewise(
(lane % 16, ~self.ACC), (4 * floor(lane / 16), self.ACC)
(lane % 16, ~MMA_ACC), (4 * floor(lane / 16), MMA_ACC)
), # M
lane % 16, # N
4 * floor(lane / 16), # K
]
size = [
Piecewise((1, ~self.ACC), (4, self.ACC)), # M
Piecewise((1, ~MMA_ACC), (4, MMA_ACC)), # M
1, # N
4, # K
]
stride = [
Piecewise((1, ~self.ACC), (16, self.ACC)), # M
Piecewise((1, ~MMA_ACC), (16, MMA_ACC)), # M
1, # N
1, # K
]
Expand Down
9 changes: 0 additions & 9 deletions shark_turbine/kernel/wave/distribution_symbols.py

This file was deleted.

2 changes: 1 addition & 1 deletion shark_turbine/kernel/wave/hoisting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from shark_turbine.kernel._support.tracing import CapturedTrace
import torch.fx as fx
from ..ops.wave_ops import *
from .address_spaces import *
from ..lang.global_symbols import *

logger = get_logger("turbine.wave.hoisting")

Expand Down
5 changes: 1 addition & 4 deletions shark_turbine/kernel/wave/promotion.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
from ...support.logging import get_logger
from shark_turbine.kernel._support.tracing import CapturedTrace
import torch.fx as fx
from ..ops.wave_ops import *
from .address_spaces import *
import shark_turbine.kernel.lang as tkl
from ..lang.global_symbols import *

logger = get_logger("turbine.wave.promotion")

Expand Down

0 comments on commit cf66a4c

Please sign in to comment.