
[MLIR] Add fast path for lowering scatter #1214

Open · wants to merge 16 commits into main
Conversation

@erick-xanadu (Contributor) commented Oct 17, 2024

Context: The semantics of mhlo.scatter and stablehlo.scatter are the same. Currently we have a lowering from mhlo.scatter to upstream MLIR dialects (a mixture of the func and scf dialects). The current implementation lowers the mhlo.scatter operation into a loop that moves values element by element into the result tensor using tensor.insert.
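For illustration, here is a minimal sketch of what such an element-by-element lowering looks like; it is not the exact IR this pass emits, and the function name, shapes, and index handling are hypothetical:

    func.func @scatter_row_slow(%result: tensor<4x8xf64>, %update: tensor<8xf64>,
                                %row: index) -> tensor<4x8xf64> {
      %c0 = arith.constant 0 : index
      %c1 = arith.constant 1 : index
      %c8 = arith.constant 8 : index
      // Copy the update into the result one element at a time.
      %out = scf.for %j = %c0 to %c8 step %c1
          iter_args(%acc = %result) -> (tensor<4x8xf64>) {
        %v = tensor.extract %update[%j] : tensor<8xf64>
        %acc2 = tensor.insert %v into %acc[%row, %j] : tensor<4x8xf64>
        scf.yield %acc2 : tensor<4x8xf64>
      }
      return %out : tensor<4x8xf64>
    }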

There are some special cases where it is possible to use tensor.insert_slice instead of tensor.insert. The difference is that tensor.insert inserts a scalar element into a tensor, while tensor.insert_slice can insert a whole tensor into a larger tensor.
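As a sketch of the contrast (the shapes and operands are illustrative fragments, not IR from this PR):

    // tensor.insert writes a single scalar element into a tensor.
    %r1 = tensor.insert %scalar into %dest[%i, %j] : tensor<4x8xf64>
    // tensor.insert_slice writes an entire tensor into a larger tensor
    // (here a rank-reducing insert of a full row).
    %r2 = tensor.insert_slice %update into %dest[%i, 0] [1, 8] [1, 1]
        : tensor<8xf64> into tensor<4x8xf64>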

This has performance implications. While both lowerings of scatter are semantically equivalent, tensor.insert_slice is lowered to memref.subview and memref.copy, which in turn lower to a single memcpy. The element-by-element lowering is the root cause of the performance issues described in #1153.
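For instance, after bufferization the tensor.insert_slice above would become something like the following sketch (the strided layout is illustrative):

    // A subview describes the destination region without copying anything.
    %view = memref.subview %dest[%i, 0] [1, 8] [1, 1]
        : memref<4x8xf64> to memref<8xf64, strided<[1], offset: ?>>
    // One bulk copy of the whole slice, instead of a per-element loop.
    memref.copy %update, %view
        : memref<8xf64> to memref<8xf64, strided<[1], offset: ?>>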

Description of the Change: This PR adds an optimized lowering of mhlo.scatter to tensor.insert_slice in cases where we can detect that this lowering preserves the same semantics.

This detection can be generalized in the future. It currently makes the following checks (a combined example satisfying all of them is sketched after the list):

  1. unique_indices and indices_are_sorted are both true: The semantics state the following:

If indices_are_sorted is true then the implementation can assume that scatter_indices are sorted with respect to scatter_dims_to_operand_dims, otherwise the behavior is undefined. More formally, for all i1 < i2 from indices(result), full_start_index(i1) <= full_start_index(i2).

If unique_indices is true then the implementation can assume that all result_index indices being scattered to are unique. If unique_indices is true but the indices being scattered to are not unique then the behavior is undefined.

I believe these are left undefined to keep open the possibility of implementing stablehlo.scatter in parallel.

  2. Only one %input, %result, and %update tensor: The operands %input and %update are variadic. To generalize the lowering to tensor.insert_slice, one could run an earlier pass that canonicalizes the scatter operation into a series of scatter operations, each with a single %input, %result, and %update tensor.
  3. Restricting update_computation to assignment of the %update tensor: The stablehlo.scatter operation is more general than a simple assignment. It allows the %input tensor to be updated with %update tensor values by running the function update_computation on the corresponding elements. We restrict update_computation to be the assignment of the %update values, i.e.,
({
      ^bb0(%input_element: tensor<T>, %update_element: tensor<T>):
        mhlo.return %update_element : tensor<T>
      })

This restriction may be relaxed by first computing a temporary tensor that holds the result of update_computation, replacing the uses of %update with this new tensor, and then using the assignment function above.
  4. No batching: Our current version of MLIR does not support the batching attributes on this operation. Generalizing this requires further investigation.
  5. Single full slice: This means that we assign the whole %update tensor to the %input tensor, not just a subset of it. We could generalize this by using dynamic sizes when generating the tensor.insert_slice operation.
  6. rank(%scatter_indices) == 1 and indexVectorDim == scatterIndicesTy.getRank() - 1: This implies that the scatter indices form a single valid coordinate and do not need to be treated as tensors. Generalizing this would imply looping over the number of valid indices (depending on the shape of the scatter indices) and generating one tensor.insert_slice operation per iteration.
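Putting the checks together, here is a hypothetical mhlo.scatter that satisfies all of them (the shapes and attribute values are illustrative, not taken from this PR's tests), followed by the kind of fast-path IR the new lowering can emit:

    // One input/update pair, identity update_computation, full-slice update,
    // rank-1 scatter indices with index_vector_dim == rank - 1 == 0.
    %res = "mhlo.scatter"(%input, %scatter_indices, %update) ({
      ^bb0(%input_element: tensor<f64>, %update_element: tensor<f64>):
        mhlo.return %update_element : tensor<f64>
    }) {
      scatter_dimension_numbers = #mhlo.scatter<
        update_window_dims = [0],
        inserted_window_dims = [0],
        scatter_dims_to_operand_dims = [0],
        index_vector_dim = 0>,
      indices_are_sorted = true,
      unique_indices = true
    } : (tensor<4x8xf64>, tensor<1xi32>, tensor<8xf64>) -> tensor<4x8xf64>

    // Fast path: read the single coordinate and emit one tensor.insert_slice
    // instead of an 8-iteration element-by-element loop.
    %c0 = arith.constant 0 : index
    %row_i32 = tensor.extract %scatter_indices[%c0] : tensor<1xi32>
    %row = arith.index_cast %row_i32 : i32 to index
    %res_fast = tensor.insert_slice %update into %input[%row, 0] [1, 8] [1, 1]
        : tensor<8xf64> into tensor<4x8xf64>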

Benefits: Performance. Scatter operations that match the fast path are lowered to a single bulk copy instead of an element-by-element loop.

Possible Drawbacks: I would feel more comfortable having more time to upstream this to StableHLO. That way, we could get a review from the StableHLO team to make sure the semantics are correct, since this is a somewhat complex operation.

Related GitHub Issues: Fixes #1153

[sc-76025]

@erick-xanadu erick-xanadu changed the title [wip] Fixing scatter [MLIR] Add fast path for lowering scatter Oct 18, 2024
@erick-xanadu erick-xanadu marked this pull request as ready for review October 18, 2024 19:45
@erick-xanadu erick-xanadu marked this pull request as draft October 18, 2024 20:55
@erick-xanadu erick-xanadu marked this pull request as ready for review October 18, 2024 22:32
codecov bot commented Oct 18, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 97.96%. Comparing base (7c5b828) to head (2cfa649).

Additional details and impacted files
@@           Coverage Diff           @@
##             main    #1214   +/-   ##
=======================================
  Coverage   97.96%   97.96%           
=======================================
  Files          77       77           
  Lines       11244    11245    +1     
  Branches      967      967           
=======================================
+ Hits        11015    11016    +1     
  Misses        180      180           
  Partials       49       49           

☔ View full report in Codecov by Sentry.
