Skip to content

Commit

Permalink
Add description of vector_shapes keyword in hardware constraint
Browse files Browse the repository at this point in the history
This PR describes the vector_shapes keyword and updates its
type to IndexSymbol -> int, as it is a map from tensor dimensions
to tile sizes.

Signed-off-by: Harsh Menon <harsh@nod-labs.com>
  • Loading branch information
harsh-nod committed Jul 29, 2024
1 parent baf58be commit 2e1589f
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion shark_turbine/kernel/wave/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,17 @@ class HardwareConstraint(Constraint):
we want all mma operations in the microkernel to be
mapped to a hardware mma instruction of shape (16x16x16).
This translates to a hardware specific index constraint.
Not all computation graphs have mma operators in them. In
these situations, the user can specify the vector shape they
want to tile to by specifying the vector shapes dictionary
which maps a tensor dimension to its corresponding tile size.
"""

threads_per_wave: int
waves_per_block: Optional[tuple[int, int, int]] = None
mma_type: Optional[MMAType] = MMAType.F32_16x16x16_F16
vector_shapes: Optional[dict[IndexExpr, int]] = None
vector_shapes: Optional[dict[IndexSymbol, int]] = None

@property
def mma_matrix_shapes(self):
Expand Down

0 comments on commit 2e1589f

Please sign in to comment.