-
Notifications
You must be signed in to change notification settings - Fork 24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add indexing to nodes #59
Conversation
|
||
@property | ||
def acc_index(self) -> list[IndexSequence]: | ||
operand_map = {tkl.sym.MMA_LHS: 0, tkl.sym.MMA_RHS: 0, tkl.sym.MMA_ACC: 1} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since we are subbing IndexExpr.sub by an operand_map, would the symbolic expression already have the tkl.sym.MMA_LHS
, tkl.sym.MMA_RHS
, and/or tkl.sym.MMA_ACC
as part of the expression?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, so this what I was talking about in the meeting. For each operator, we specify a dimension-specific index. So for MMA, a separate index of M, N and K. Rather than partition these dimensional indices further into operand specific indices, I have them represented as a piecewise function where the conditions depend on MMA_{LHS/RHS/ACC} (In the current PR, just ACC). So in order to extract the operand specific parts, we just substitute the appropriate values as above. The advantage of this piecewise function approach is it allows you to see where the dimensional mapping bifurcates and for which operands and allows you to reason about "layout changes". (For example, you could ask questions like - what setting of LHS, RHS and ACC would make the indices of MMA_0 be the same as that of the LHS of MMA1?)
elif dim == constraint.dim: | ||
custom.index[dim] += constraint.apply() | ||
if custom.index: | ||
setattr(custom.fx_node, "index", custom.index) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this mean during handle_op, we cshould be able to just access this "index" attribute? do we still need to take the size from elem_per_thread, or would it be handled somewhere here as well?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, you should be able to just access the index attribute, but you will have to check if it is None. You will still need to get the size from elem_per_thread. The index sequence tells you how many you can load but that could be different from how many the user has requested to load.
continue | ||
if not custom.index: | ||
custom.index = { | ||
dim: IndexSequence(0, 1) for dim in custom.indexing_dims |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we expect the handling of the not custom index and loop over custom.indexing_dims
in this for loop? would it not make sense to handle this outside of for dim in custom.indexing_dims:
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My intention here was to only set the index attribute on operators that were affected by the constraints. The reason for this is to distinguish between the following 3 scenarios: an operator that has no index, an operator that has an index but the index is None and an operator that has an index that is not None. I dont think we need to distinguish between the first 2 options - so we can always set the index, but the index could be None. So with that I could move the initialization of custom_index outside the loop.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
This primary purpose of this PR is to annotate nodes with their access patterns based on the workgroup, tiling and MMA constraints. This is accomplished prior to expansion and propagates through expansion to the expanded nodes. Signed-off-by: Harsh Menon <harsh@nod-labs.com>
- Removed add operation on index sequences - General cleanup / refactor Signed-off-by: Harsh Menon <harsh@nod-labs.com>
This primary purpose of this PR is to annotate nodes with their access patterns based on the workgroup, tiling and MMA constraints. This is accomplished prior to expansion and propagates through expansion to the expanded nodes.