-
Notifications
You must be signed in to change notification settings - Fork 24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for varying vector shapes #247
Conversation
181b93e
to
3a69637
Compare
This PR adds support for expanding operators with varying vector shapes, specifically for the MMA case where either the same dimension has different vector shapes in different mmas or if different instructions are being used. The idea is to insert a reshape operator whenever such a shape mismatch is discovered. The reshape operator lowers to an extract or concatenate operation, depending on the context. Signed-off-by: Harsh Menon <harsh@nod-labs.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Super awesome stuff Harsh! this is exciting! I have a few Qs tho :)
Signed-off-by: Harsh Menon <harsh@nod-labs.com>
Signed-off-by: Harsh Menon <harsh@nod-labs.com>
iree/turbine/kernel/wave/utils.py
Outdated
for dst in mma_nodes: | ||
if src == dst: | ||
continue | ||
custom_dst = get_custom(dst) | ||
lhs_slice = capture_backward_slice(custom_dst.lhs) | ||
rhs_slice = capture_backward_slice(custom_dst.rhs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we remove the double nested loop, and instead find the mma op inside lhs_slice/rhs_slice?
I think we can add a stop right after is_mma, since we'd only want to find the closest MMA in the use def chain.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes we can, but then we would have to iterate through all the operators in the slice to see if any of them has an mma as an argument. That would end up with a triple-nested-loop. Unless you are thinking of something else?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
okay I think I have a cleaner version now. let me know what you think.
new_args.append(new_node.fx_node) | ||
|
||
reshape.update_arg("args", new_args) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to do anything special to ensure this ordering of expansion is correct?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think so, but do you have an example in mind?
Signed-off-by: Harsh Menon <harsh@nod-labs.com>
Signed-off-by: Harsh Menon <harsh@nod-labs.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks awesome, merge please!
This PR adds support for expanding operators
with varying vector shapes, specifically for the
MMA case where either the same dimension has different vector shapes in different mmas or if different instructions are being used.
The idea is to insert a reshape operator whenever such a shape mismatch is discovered. The reshape operator lowers to an extract or concatenate operation, depending on the context.