
[TKW] Fix types, shapes and propagate resolved indexing #177

Merged 3 commits into iree-org:main on Oct 17, 2024

Conversation

@raikonenfnu (Contributor) commented Sep 27, 2024

This PR adds support for accumulating on non-induction variables. For the most part our stack already supports this, but we need to fix the reduction's symbolic shape, which used to be just the input shape, to be the input shape minus the reduction dim. Without this, our thread shape analysis cannot handle broadcasting of the reduced op properly.
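As a minimal sketch of the shape fix described above (hypothetical helper name, not the TKW API): a reduction's symbolic shape should be the input shape with the reduction dim dropped, rather than the input shape itself.

```python
# Hypothetical sketch: the reduction's symbolic shape is the input
# shape with the reduction dim removed.
def reduction_shape(input_shape: list[str], reduction_dim: str) -> list[str]:
    assert reduction_dim in input_shape
    return [d for d in input_shape if d != reduction_dim]

# Reducing [M, N] along N yields [M], not [M, N].
assert reduction_shape(["M", "N"], "N") == ["M"]
```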

To support the above, we also introduce an Extract op alongside the existing ExtractSlice to fix the shape types. Specifically, ExtractSlice follows upstream semantics, where we slice but do not reduce any dimensions. For the ReduceOp case, specifically the local reduction, we want an op with reducing semantics on the fastest dimension.

Additionally, we add propagation of the resolved thread-shape indexing. This is exercised by the added test, a broadcast-sub followed by an exp2.
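The distinction between the two ops can be illustrated with a hedged NumPy sketch (hypothetical function names, not the TKW implementation): ExtractSlice keeps the rank of its input, while Extract drops the sliced (fastest) dimension, which is the reducing semantic the local ReduceOp needs.

```python
import numpy as np

# Hypothetical illustration of the two semantics, not the TKW API.
def extract_slice(x, idx):
    # Slice along the fastest (last) dim but keep it: rank is preserved,
    # matching upstream ExtractSlice semantics.
    return x[..., idx:idx + 1]

def extract(x, idx):
    # Select along the fastest (last) dim and drop it: rank is reduced,
    # matching the semantics needed for local thread reduction.
    return x[..., idx]

x = np.zeros((4, 8))
assert extract_slice(x, 3).shape == (4, 1)
assert extract(x, 3).shape == (4,)
```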

@raikonenfnu raikonenfnu force-pushed the fixTypesUpstream branch 3 times, most recently from 25c6d23 to e03e3ce Compare September 27, 2024 21:34
@raikonenfnu (Contributor, Author) commented:

Currently broken on igemm, because the ExtractSlice used in the iGEMM semantics needs to follow upstream. We need to implement a separate extract/indexing op for the semantics required in local thread reduction.

@raikonenfnu raikonenfnu force-pushed the fixTypesUpstream branch 3 times, most recently from d3f6edb to ab70a52 Compare October 6, 2024 22:00
@raikonenfnu raikonenfnu changed the title [TKW] Fix types and shapes for accumulate on non IV [TKW] Fix types, shapes and propagate resolved indexing Oct 6, 2024
Signed-off-by: Stanley Winata <stanley.winata@amd.com>

# Typically only the fastest dim has a non-unit size,
# but if all dims are unit-sized, use the fastest/last one.
all_unit_dims = lambda index: all(x.size == 1 for x in index.values())
Contributor:
Do we need this check? Can we merge this with the non_unit_dims check? If len(non_unit_dims) == 0, use the fastest dim.

Contributor Author:
done! I think :)
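The merged check the reviewer suggests could look like the following sketch (hypothetical names and a stand-in index type, not the actual TKW code): collect the non-unit dims, and if there are none, fall back to the fastest/last dim.

```python
from dataclasses import dataclass

@dataclass
class IndexSequence:
    # Stand-in for the real per-dim index entry, which exposes a size.
    size: int

def select_thread_dim(index: dict) -> str:
    # Dims whose size is not 1; dict order is slowest -> fastest.
    non_unit_dims = [dim for dim, seq in index.items() if seq.size != 1]
    if not non_unit_dims:
        # All dims are unit-sized: use the fastest (last) dim.
        return list(index)[-1]
    # Otherwise pick the fastest non-unit dim.
    return non_unit_dims[-1]

assert select_thread_dim({"M": IndexSequence(1), "N": IndexSequence(4)}) == "N"
assert select_thread_dim({"M": IndexSequence(1), "N": IndexSequence(1)}) == "N"
```

This folds the standalone all_unit_dims lambda into the same pass over the index, as suggested.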

@harsh-nod (Contributor) left a comment:
This looks pretty good. Just a few comments.

@harsh-nod (Contributor) left a comment:
lgtm! just some minor nits

Signed-off-by: Stanley Winata <stanley.winata@amd.com>
@raikonenfnu raikonenfnu merged commit 594d580 into iree-org:main Oct 17, 2024
6 of 8 checks passed
2 participants