diff --git a/shark_turbine/kernel/wave/constraints.py b/shark_turbine/kernel/wave/constraints.py index 4d98c9f1..ab8d913a 100644 --- a/shark_turbine/kernel/wave/constraints.py +++ b/shark_turbine/kernel/wave/constraints.py @@ -79,11 +79,11 @@ def threads_per_block(self) -> tuple[int]: @property def linearized_thread_id(self) -> IndexExpr: thread_ids = [THREAD_0, THREAD_1, THREAD_2] - threads_per_block = ( - [1] - + [self.threads_per_block[0]] - + [self.threads_per_block[0] * self.threads_per_block[1]] - ) + threads_per_block = [ + 1, + self.threads_per_block[0], + self.threads_per_block[0] * self.threads_per_block[1], + ] return sum([x * y for x, y in zip(thread_ids, threads_per_block)]) def apply(self, mma_index: int) -> IndexSequence: @@ -91,23 +91,23 @@ def apply(self, mma_index: int) -> IndexSequence: match self.mma_type: # (M x K, N x K) -> M x N case MMAType.F32_16x16x16_F16: - offset = { - 0: Piecewise( + offset = [ + Piecewise( (lane % 16, ~self.ACC), (4 * floor(lane / 16), self.ACC) ), # M - 1: lane % 16, # N - 2: 4 * floor(lane / 16), # K - } - size = { - 0: Piecewise((0, ~self.ACC), (4, self.ACC)), # M - 1: 0, # N - 2: 4, # K - } - stride = { - 0: Piecewise((1, ~self.ACC), (16, self.ACC)), # M - 1: 1, # N - 2: 1, # K - } + lane % 16, # N + 4 * floor(lane / 16), # K + ] + size = [ + Piecewise((1, ~self.ACC), (4, self.ACC)), # M + 1, # N + 4, # K + ] + stride = [ + Piecewise((1, ~self.ACC), (16, self.ACC)), # M + 1, # N + 1, # K + ] return IndexSequence( offset[mma_index], size[mma_index], stride[mma_index] ) diff --git a/shark_turbine/kernel/wave/expansion.py b/shark_turbine/kernel/wave/expansion.py index f744f5a4..42351efa 100644 --- a/shark_turbine/kernel/wave/expansion.py +++ b/shark_turbine/kernel/wave/expansion.py @@ -141,25 +141,36 @@ def set_node_indices( def compute_index(node: fx.Node) -> bool: custom = get_custom(node) - custom.index = {} + custom.index = {dim: None for dim in custom.indexing_dims} for dim in custom.indexing_dims: for constraint in constraints: - if ( + mma_check = ( + isinstance(constraint, HardwareConstraint) + and dim in mma_index + and isinstance(custom, MMA) + ) + + constraint_check = ( not isinstance(constraint, HardwareConstraint) - and dim != constraint.dim - ): + and dim == constraint.dim + ) + + if (not mma_check) and (not constraint_check): continue - if not custom.index: - custom.index = { - dim: IndexSequence(0, 1) for dim in custom.indexing_dims - } + + if custom.index[dim] is None: + custom.index[dim] = IndexSequence(0, 0) + if isinstance(constraint, HardwareConstraint): - if dim in mma_index and isinstance(custom, MMA): - custom.index[dim] += constraint.apply(mma_index[dim]) - elif dim == constraint.dim: - custom.index[dim] += constraint.apply() - if custom.index: - setattr(custom.fx_node, "index", custom.index) + # Thread-level constraint specifies size and stride. + index_seq: IndexSequence = constraint.apply(mma_index[dim]) + custom.index[dim].size = index_seq.size + custom.index[dim].stride = index_seq.stride + else: + index_seq: IndexSequence = constraint.apply() + custom.index[dim].start += index_seq.start + + setattr(custom.fx_node, "index", custom.index) return False trace.walk(compute_index) diff --git a/shark_turbine/kernel/wave/indexing.py b/shark_turbine/kernel/wave/indexing.py index a664a75a..037a3215 100644 --- a/shark_turbine/kernel/wave/indexing.py +++ b/shark_turbine/kernel/wave/indexing.py @@ -13,26 +13,18 @@ class IndexSequence: size: IndexExpr | int stride: Optional[IndexExpr | int] = 1 - def __add__(self, other: Any) -> Any: - if isinstance(other, IndexSequence): - return IndexSequence( - self.start + other.start, - self.size * other.size, - self.stride * other.stride, - ) - else: - raise NotImplementedError("IndexSequence addition not implemented!") + def _subs( + self, value: int | IndexExpr, map: dict[IndexSymbol, int] + ) -> int | IndexExpr: + new_value = value + if isinstance(value, IndexExpr): + new_value = value.subs(map) + return new_value def subs(self, map: dict[IndexSymbol, int]): - start = self.start - if isinstance(self.start, IndexExpr): - start = self.start.subs(map) - size = self.size - if isinstance(self.size, IndexExpr): - size = self.size.subs(map) - stride = self.stride - if isinstance(self.stride, IndexExpr): - stride = self.stride.subs(map) + start = self._subs(self.start, map) + size = self._subs(self.size, map) + stride = self._subs(self.stride, map) return IndexSequence(start, size, stride) def __repr__(self):