[TKW] Fix types, shapes and propagate resolved indexing #177
Conversation
Force-pushed from 25c6d23 to e03e3ce
Currently broken on igemm because the ExtractSlice used in the iGEMM semantics needs to follow upstream. We need to implement an extract/indexing op for the semantics required in local thread reduction.
Force-pushed from d3f6edb to ab70a52
Signed-off-by: Stanley Winata <stanley.winata@amd.com>
Force-pushed from ab70a52 to 08d43e0
Signed-off-by: Stanley Winata <stanley.winata@amd.com>
iree/turbine/kernel/ops/wave_ops.py (outdated)

# Typically only the fastest dim has a non-unit size,
# but if all dims are unit-sized, pick the fastest/last one.
all_unit_dims = lambda index: all(x.size == 1 for x in index.values())
Do we need this check? Can we merge this with the non_unit_dim check? If len(non_unit_dim) == 0, use the fastest dim.
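A minimal sketch of the merge the reviewer suggests, folding the all-unit-dims check into the non-unit-dim pass. The helper name `select_reduction_dim` and the `IndexSeq` stand-in are hypothetical, assuming only that each index value exposes a `.size` attribute, as the lambda above implies:

```python
from collections import namedtuple

# Hypothetical stand-in for an index-sequence entry; only .size matters here.
IndexSeq = namedtuple("IndexSeq", ["size"])


def select_reduction_dim(index: dict) -> str:
    # Collect dims with non-unit size in one pass; when there are none
    # (all unit dims), fall back to the fastest/last dim, so no separate
    # all_unit_dims check is needed.
    non_unit_dims = [dim for dim, seq in index.items() if seq.size != 1]
    if len(non_unit_dims) == 0:
        return list(index)[-1]
    return non_unit_dims[-1]


select_reduction_dim({"M": IndexSeq(1), "K": IndexSeq(4)})   # -> "K"
select_reduction_dim({"M": IndexSeq(1), "N": IndexSeq(1)})   # -> "N" (fastest dim)
```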
done! I think :)
This looks pretty good. Just a few comments.
lgtm! just some minor nits
This PR adds support for accumulating on non-induction variables. For the most part our stack already supports this, but we need to fix the reduction's symbolic shape, which used to be just the input shape, to be the input shape minus the reduction dim. Without this, our thread shape analysis cannot handle broadcasting of the reduced op properly.
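The shape fix described above can be sketched as follows. This is an illustration only, not the PR's actual code; the helper name `reduction_symbolic_shape` is hypothetical, and symbolic dims are modeled as plain strings:

```python
def reduction_symbolic_shape(input_shape: tuple, reduction_dim: str) -> tuple:
    # The reduction's symbolic shape is the input shape minus the
    # reduction dim, rather than the input shape itself.
    return tuple(dim for dim in input_shape if dim != reduction_dim)


reduction_symbolic_shape(("M", "N", "K"), "K")  # -> ("M", "N")
```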
To support the above, we also introduce an Extract op, in addition to the existing ExtractSlice, to fix the shape types. Specifically, ExtractSlice follows upstream semantics, where we only slice and do not reduce any dimensions. For the ReduceOp case, specifically the local reduction, we want an op with reducing semantics on the fastest dimension. Additionally, we add propagation of our thread-shape resolution. This helps in the test we added, which is a broadcast-sub followed by an exp2.
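The distinction between the two ops can be illustrated with a NumPy analogy (illustration only; the actual ops operate on TKW registers, not arrays): ExtractSlice-style slicing preserves the rank, while Extract-style indexing reduces the fastest dimension:

```python
import numpy as np

a = np.arange(8).reshape(2, 4)

# ExtractSlice-like semantics (upstream): slice, no dimension is reduced.
slice_out = a[:, 3:4]
print(slice_out.shape)    # (2, 1) -- rank preserved

# Extract-like semantics (local reduction): index the fastest dim,
# reducing it away.
extract_out = a[:, 3]
print(extract_out.shape)  # (2,) -- rank reduced
```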