Skip to content

Commit

Permalink
Update based on Stan's comments
Browse files Browse the repository at this point in the history
- Removed add operation on index sequences
- General cleanup / refactor

Signed-off-by: Harsh Menon <harsh@nod-labs.com>
  • Loading branch information
harsh-nod committed Jul 31, 2024
1 parent ff5a81d commit 9982ca0
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 52 deletions.
40 changes: 20 additions & 20 deletions shark_turbine/kernel/wave/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,35 +79,35 @@ def threads_per_block(self) -> tuple[int]:
# Diff hunk for `linearized_thread_id`: the property multiplies each thread id
# (THREAD_0, THREAD_1, THREAD_2) by a per-dimension factor and sums the products.
# Indentation was lost in this rendering, and the removed and added versions of
# `threads_per_block` appear back to back without +/- markers.
@property
def linearized_thread_id(self) -> IndexExpr:
thread_ids = [THREAD_0, THREAD_1, THREAD_2]
# Removed side: factor list built by concatenating three one-element lists.
threads_per_block = (
[1]
+ [self.threads_per_block[0]]
+ [self.threads_per_block[0] * self.threads_per_block[1]]
)
# Added side: the same three factors written as a single list literal.
threads_per_block = [
1,
self.threads_per_block[0],
self.threads_per_block[0] * self.threads_per_block[1],
]
# Pairs each thread id with its factor (THREAD_0 is scaled by 1) and sums.
return sum([x * y for x, y in zip(thread_ids, threads_per_block)])

# Diff hunk for `apply`: for MMAType.F32_16x16x16_F16 it builds per-dimension
# offset / size / stride tables and returns the IndexSequence selected by
# `mma_index` (0 = M, 1 = N, 2 = K, per the inline comments). Only this one
# case is visible; the page collapses whatever follows the hunk.
# Removed (dict-based) and added (list-based) lines are interleaved below
# without +/- markers, so brackets do not balance when read top to bottom.
def apply(self, mma_index: int) -> IndexSequence:
lane = self.linearized_thread_id
match self.mma_type:
# (M x K, N x K) -> M x N
case MMAType.F32_16x16x16_F16:
# Removed side: opening of the dict-based offset table keyed by 0/1/2.
offset = {
0: Piecewise(
# Added side: the same table opened as a list, indexed by position.
offset = [
Piecewise(
# Unchanged context shared by both sides: the M offset depends on self.ACC.
(lane % 16, ~self.ACC), (4 * floor(lane / 16), self.ACC)
), # M
# Removed side: remainder of the dict-based tables. Here the size entries
# for M (when not ACC) and for N are 0.
1: lane % 16, # N
2: 4 * floor(lane / 16), # K
}
size = {
0: Piecewise((0, ~self.ACC), (4, self.ACC)), # M
1: 0, # N
2: 4, # K
}
stride = {
0: Piecewise((1, ~self.ACC), (16, self.ACC)), # M
1: 1, # N
2: 1, # K
}
# Added side: remainder of the list-based tables. Offsets and strides carry
# the same expressions as before; the size entries for M (when not ACC) and
# for N change from 0 to 1.
lane % 16, # N
4 * floor(lane / 16), # K
]
size = [
Piecewise((1, ~self.ACC), (4, self.ACC)), # M
1, # N
4, # K
]
stride = [
Piecewise((1, ~self.ACC), (16, self.ACC)), # M
1, # N
1, # K
]
# Unchanged context: integer keys 0/1/2 index the dicts and the lists alike.
return IndexSequence(
offset[mma_index], size[mma_index], stride[mma_index]
)
Expand Down
39 changes: 25 additions & 14 deletions shark_turbine/kernel/wave/expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,25 +141,36 @@ def set_node_indices(

# Diff hunk for the nested helper `compute_index`. It fills `custom.index`
# per indexing dimension using `constraints` and `mma_index` from the enclosing
# scope (not shown in this hunk) and stores the result on the fx node.
# Removed and added lines are interleaved without +/- markers and indentation
# was lost in this rendering.
def compute_index(node: fx.Node) -> bool:
custom = get_custom(node)
# Removed side: start from an empty dict.
custom.index = {}
# Added side: pre-seed every indexing dim with a None placeholder.
custom.index = {dim: None for dim in custom.indexing_dims}
for dim in custom.indexing_dims:
for constraint in constraints:
# Removed side: opening of the old skip condition.
if (
# Added side: true when a hardware constraint meets an MMA node on a dim
# that has an entry in mma_index.
mma_check = (
isinstance(constraint, HardwareConstraint)
and dim in mma_index
and isinstance(custom, MMA)
)

# Added side: true when a non-hardware constraint targets this dim. The
# `not isinstance(...)` line is shared with the removed condition.
constraint_check = (
not isinstance(constraint, HardwareConstraint)
# Removed side: old condition skipped non-hardware constraints on other dims.
and dim != constraint.dim
):
# Added side: comparison flipped to select the matching dim.
and dim == constraint.dim
)

# Added side: skip pairs that neither check accepts (`continue` is shared).
if (not mma_check) and (not constraint_check):
continue
# Removed side: lazily initialise every dim to IndexSequence(0, 1).
if not custom.index:
custom.index = {
dim: IndexSequence(0, 1) for dim in custom.indexing_dims
}

# Added side: the first accepted constraint replaces the None placeholder.
if custom.index[dim] is None:
custom.index[dim] = IndexSequence(0, 0)

if isinstance(constraint, HardwareConstraint):
# Removed side: accumulate through IndexSequence `+=`, and attach the index
# to the fx node only when the dict is non-empty.
if dim in mma_index and isinstance(custom, MMA):
custom.index[dim] += constraint.apply(mma_index[dim])
elif dim == constraint.dim:
custom.index[dim] += constraint.apply()
if custom.index:
setattr(custom.fx_node, "index", custom.index)
# Added side: hardware constraints overwrite size and stride; all other
# accepted constraints add their start to the running start.
# NOTE(review): index_seq.start returned by the hardware constraint is not
# used in this branch - confirm that dropping it is intended.
# Thread-level constraint specifies size and stride.
index_seq: IndexSequence = constraint.apply(mma_index[dim])
custom.index[dim].size = index_seq.size
custom.index[dim].stride = index_seq.stride
else:
index_seq: IndexSequence = constraint.apply()
custom.index[dim].start += index_seq.start

# Added side: the index dict is now attached unconditionally.
# NOTE(review): dims that no constraint accepted still map to None here -
# confirm consumers of `index` handle that.
setattr(custom.fx_node, "index", custom.index)
return False

trace.walk(compute_index)
Expand Down
28 changes: 10 additions & 18 deletions shark_turbine/kernel/wave/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,18 @@ class IndexSequence:
size: IndexExpr | int
stride: Optional[IndexExpr | int] = 1

# Removed side of the diff: this whole method is deleted by the commit (the
# commit message lists "Removed add operation on index sequences").
# It combined two IndexSequences by adding their starts and multiplying their
# sizes and strides; any other operand raised NotImplementedError.
def __add__(self, other: Any) -> Any:
if isinstance(other, IndexSequence):
return IndexSequence(
self.start + other.start,
self.size * other.size,
self.stride * other.stride,
)
else:
raise NotImplementedError("IndexSequence addition not implemented!")
def _subs(
    self, value: int | IndexExpr, map: dict[IndexSymbol, int]
) -> int | IndexExpr:
    """Substitute ``map`` into ``value`` when it is an ``IndexExpr``.

    Any value that is not an ``IndexExpr`` (for example a plain ``int``) is
    returned unchanged.
    """
    # Guard clause: only IndexExpr values go through `.subs`.
    if not isinstance(value, IndexExpr):
        return value
    return value.subs(map)

# Diff hunk for `subs`: returns a new IndexSequence whose start, size and
# stride have `map` substituted in. Removed and added lines are interleaved
# without +/- markers.
def subs(self, map: dict[IndexSymbol, int]):
# Removed side: each field is copied, then substituted inline only when it
# is an IndexExpr.
start = self.start
if isinstance(self.start, IndexExpr):
start = self.start.subs(map)
size = self.size
if isinstance(self.size, IndexExpr):
size = self.size.subs(map)
stride = self.stride
if isinstance(self.stride, IndexExpr):
stride = self.stride.subs(map)
# Added side: the same per-field logic delegated to the `_subs` helper.
start = self._subs(self.start, map)
size = self._subs(self.size, map)
stride = self._subs(self.stride, map)
return IndexSequence(start, size, stride)

def __repr__(self):
Expand Down

0 comments on commit 9982ca0

Please sign in to comment.