Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for varying vector shapes #247

Merged
merged 5 commits into from
Oct 31, 2024
Merged

Conversation

harsh-nod
Copy link
Contributor

This PR adds support for expanding operators
with varying vector shapes, specifically for the
MMA case where either the same dimension has different vector shapes in different mmas or if different instructions are being used.

The idea is to insert a reshape operator whenever such a shape mismatch is discovered. The reshape operator lowers to an extract or concatenate operation, depending on the context.

This PR adds support for expanding operators
with varying vector shapes, specifically for the
MMA case where either the same dimension has different
vector shapes in different mmas or if different instructions
are being used.

The idea is to insert a reshape operator whenever such a shape
mismatch is discovered.  The reshape operator lowers to an extract or
concatenate operation, depending on the context.

Signed-off-by: Harsh Menon <harsh@nod-labs.com>
Copy link
Contributor

@raikonenfnu raikonenfnu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Super awesome stuff Harsh! this is exciting! I have a few Qs tho :)

Signed-off-by: Harsh Menon <harsh@nod-labs.com>
Signed-off-by: Harsh Menon <harsh@nod-labs.com>
Comment on lines 267 to 272
for dst in mma_nodes:
if src == dst:
continue
custom_dst = get_custom(dst)
lhs_slice = capture_backward_slice(custom_dst.lhs)
rhs_slice = capture_backward_slice(custom_dst.rhs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we remove the double nested loop, and instead find the mma op inside lhs_slice/rhs_slice?
I think we can add a stop right after is_mma, since we'd only want to find the closest MMA in the use def chain.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes we can, but then we would have to iterate through all the operators in the slice to see if any of them has an mma as an argument. That would end up with a triple-nested-loop. Unless you are thinking of something else?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay I think I have a cleaner version now. let me know what you think.

Comment on lines +586 to +589
new_args.append(new_node.fx_node)

reshape.update_arg("args", new_args)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to do anything special to ensure this ordering of expansion is correct?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so, but do you have an example in mind?

Signed-off-by: Harsh Menon <harsh@nod-labs.com>
Signed-off-by: Harsh Menon <harsh@nod-labs.com>
Copy link
Contributor

@raikonenfnu raikonenfnu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks awesome, merge please!

@harsh-nod harsh-nod merged commit 8febe6a into iree-org:main Oct 31, 2024
7 of 8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants