[MLIR] Add fast path for lowering scatter #1214
Open
erick-xanadu wants to merge 16 commits into main from eochoa/2024-10-16/vmap
+503 −11
Conversation
erick-xanadu changed the title from [wip] Fixing scatter to [MLIR] Add fast path for lowering scatter on Oct 18, 2024
Codecov Report: All modified and coverable lines are covered by tests ✅

Coverage Diff:

|          | main   | #1214  | +/- |
|----------|--------|--------|-----|
| Coverage | 97.96% | 97.96% |     |
| Files    | 77     | 77     |     |
| Lines    | 11244  | 11245  | +1  |
| Branches | 967    | 967    |     |
| Hits     | 11015  | 11016  | +1  |
| Misses   | 180    | 180    |     |
| Partials | 49     | 49     |     |

View full report in Codecov by Sentry.
Context: The semantics of `mhlo.scatter` and `stablehlo.scatter` are the same. Currently we have a lowering from `mhlo.scatter` to upstream MLIR dialects (a mixture of the `func` and `scf` dialects). The current implementation lowers the `mhlo.scatter` operation into a loop and moves values element by element into the result tensor using `tensor.insert`.

There are some special cases where it is possible to use `tensor.insert_slice` instead of `tensor.insert`. The difference between the two is that `tensor.insert` inserts a scalar element into a tensor, while `tensor.insert_slice` may insert a whole tensor into a larger tensor.

This has implications for performance. While both lowerings of scatter should be equivalent, `tensor.insert_slice` is lowered to `memref.subview` and `memref.copy`, which in turn lowers to a single `memcpy`. This is the root cause of the performance issues described in #1153.

Description of the Change: This PR adds an optimized lowering of `mhlo.scatter` to `tensor.insert_slice` in cases where we can detect that this lowering preserves the same semantics. The detection can be generalized in the future. It currently makes the following checks:
1. `unique_indices` and `indices_are_sorted` are both true. I believe the semantics leave the other cases undefined to keep open the possibility of implementing `stablehlo.scatter` in parallel.
2. A single `%input`, `%result`, and `%update` tensor: the operands `%input` and `%update` are variadic. To generalize the lowering to `tensor.insert_slice`, one could run an earlier pass that canonicalizes the `scatter` operation into a series of `scatter` operations, each with a single `%input`, `%result`, and `%update` tensor.
3. `update_computation` is an assignment of the `%update` tensor: the `stablehlo.scatter` operation is more general than a simple assignment. It allows the `%input` tensor to be updated with `%update` tensor values by running the function `update_computation` on the corresponding elements. We restrict `update_computation` to be the assignment of the `%update` values. This restriction may be relaxed by first computing a temporary tensor that holds the result of the `update_computation`, replacing the uses of `update` with this new tensor, and then using the assignment function above.
4. No batching: our current version of MLIR does not support the batching attributes in this operation. Generalizing this requires more investigation.
5. A single full slice: this means that we assign the whole `update` tensor to the `input` tensor, not just a subset of it. We could generalize this by using dynamic sizes when generating the `tensor.insert_slice` operation.
6. `rank(%scatter_indices) == 1` and `indexVectorDim == scatterIndicesTy.getRank() - 1`: this implies that the scatter indices are valid coordinates and do not need to be treated as tensors. Generalizing this would mean looping over the number of valid indices (depending on the shape of the scatter indices) and generating a single `tensor.insert_slice` operation for each iteration.

Benefits: Performance.
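For concreteness, here is a hedged sketch of the two lowerings discussed above. The shapes, SSA names, and loop bounds are illustrative only and are not taken from the actual pass:

```mlir
// Generic lowering (illustrative): copy the update element by element
// into the destination tensor with tensor.insert inside scf.for loops.
%generic = scf.for %i = %c0 to %c4 step %c1
    iter_args(%acc0 = %input) -> (tensor<16x4xf64>) {
  %row = scf.for %j = %c0 to %c4 step %c1
      iter_args(%acc1 = %acc0) -> (tensor<16x4xf64>) {
    %e = tensor.extract %update[%i, %j] : tensor<4x4xf64>
    %r = tensor.insert %e into %acc1[%i, %j] : tensor<16x4xf64>
    scf.yield %r : tensor<16x4xf64>
  }
  scf.yield %row : tensor<16x4xf64>
}

// Fast path (illustrative): a single tensor.insert_slice, which can
// bufferize to memref.subview + memref.copy, i.e. one memcpy.
%fast = tensor.insert_slice %update into %input[0, 0] [4, 4] [1, 1]
    : tensor<4x4xf64> into tensor<16x4xf64>
```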
Possible Drawbacks: I would feel more comfortable having more time to upstream this to StableHLO. That way, we could get a review from the StableHLO team to make sure the semantics are correct, since this is a somewhat complex operation.
Related GitHub Issues: Fixes #1153
[sc-76025]
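As an illustration of checks 1, 3, and 6 above, a hypothetical `stablehlo.scatter` that would qualify for the fast path might look like the following. The shapes, element types, and exact attribute spelling are my assumptions and are not taken from this PR:

```mlir
// update_computation just returns the update value, so the scatter is a
// pure assignment; unique_indices and indices_are_sorted are both true,
// and %scatter_indices has rank 1 with index_vector_dim == rank - 1 == 0.
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
^bb0(%old: tensor<f64>, %new: tensor<f64>):
  "stablehlo.return"(%new) : (tensor<f64>) -> ()
}) {
  indices_are_sorted = true,
  unique_indices = true,
  scatter_dimension_numbers = #stablehlo.scatter<
    update_window_dims = [0, 1],
    inserted_window_dims = [],
    scatter_dims_to_operand_dims = [0, 1],
    index_vector_dim = 0
  >
} : (tensor<16x4xf64>, tensor<2xi32>, tensor<4x4xf64>) -> tensor<16x4xf64>
```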