[MLIR] Add fast path for lowering scatter #1214

Open
wants to merge 16 commits into main
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
@@ -294,6 +294,10 @@
* Registers the func dialect as a requirement for running the scatter lowering pass.
[(#1216)](https://github.com/PennyLaneAI/catalyst/pull/1216)

* Fixes #1153: a performance issue with `vmap` whose root cause was in the
lowering of the scatter operation.
[(#1214)](https://github.com/PennyLaneAI/catalyst/pull/1214)

<h3>Internal changes</h3>

* Remove deprecated pennylane code across the frontend.
16 changes: 6 additions & 10 deletions frontend/catalyst/api_extensions/function_maps.py
@@ -226,7 +226,10 @@ def __call__(self, *args, **kwargs):
fn_args = tree_unflatten(args_tree, fn_args_flat)

# Trace 'fn' once to get its output shape without executing it
init_result = self.fn(*fn_args, **kwargs)
_, shape = jax.make_jaxpr(self.fn, return_shape=True)(*fn_args, **kwargs)
shapes, init_result_tree = tree_flatten(shape)
init_result_flat = [jnp.zeros(shape=shape.shape, dtype=shape.dtype) for shape in shapes]
init_result = tree_unflatten(init_result_tree, init_result_flat)

# Check the validity of the output w.r.t. out_axes
out_axes_deep_struct = tree_structure(self.out_axes, is_leaf=lambda x: x is None)
@@ -238,8 +241,6 @@ def __call__(self, *args, **kwargs):
f"{out_axes_deep_struct} axis specifiers and {init_result_deep_struct} results."
)

init_result_flat, init_result_tree = tree_flatten(init_result)

num_axes_out = len(init_result_flat)

if isinstance(self.out_axes, int):
@@ -255,16 +256,11 @@ def __call__(self, *args, **kwargs):
# in the flatten format with respect to the 'init_result' shape
batched_result_list = []
for j in range(num_axes_out):
out_shape = (
(batch_size,)
if not init_result_flat[j].shape
else (batch_size, *init_result_flat[j].shape)
)
out_shape = (batch_size, *init_result_flat[j].shape)
batched_result_list.append(jnp.zeros(shape=out_shape, dtype=init_result_flat[j].dtype))
batched_result_list[j] = batched_result_list[j].at[0].set(init_result_flat[j])

# Apply mapping batched_args[1:] ---> fn(args)
@for_loop(1, batch_size, 1)
@for_loop(0, batch_size, 1)
def loop_fn(i, batched_result_list):
fn_args_flat = args_flat
for loc in batch_loc:
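For context on the hunk above: `jax.make_jaxpr(fn, return_shape=True)` traces the function abstractly and returns `jax.ShapeDtypeStruct` leaves, so the placeholder result can be built from zeros without ever executing `fn`; that is also why the batching loop can now start at index 0 instead of 1. A minimal standalone sketch (the function `fn` here is hypothetical, not the PR's code):

```python
import jax
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten

def fn(x):  # hypothetical example function
    return {"out": jnp.sin(x), "norm": jnp.sum(x**2)}

# Abstract trace: no real computation happens, we only recover output shapes/dtypes.
_, shape = jax.make_jaxpr(fn, return_shape=True)(jnp.ones((3,)))
shapes, result_tree = tree_flatten(shape)             # leaves are jax.ShapeDtypeStruct
zeros = [jnp.zeros(s.shape, s.dtype) for s in shapes]
placeholder = tree_unflatten(result_tree, zeros)      # same pytree structure as fn's output
print(placeholder)   # pytree of zeros with the same shapes/dtypes as fn's output
```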
1 change: 0 additions & 1 deletion frontend/test/lit/test_mitigation.py
@@ -40,7 +40,6 @@ def circuit():
# CHECK: mitigation.zne @one_shot_wrapper(%c_0) folding( global) numFolds(%6 : tensor<2xi64>) : (tensor<5xi1>) -> tensor<2xf64>

# CHECK: func.func private @one_shot_wrapper(%arg0: tensor<5xi1>) -> tensor<f64>
# CHECK: catalyst.launch_kernel @module_circuit::@circuit() : () -> tensor<f64>
# CHECK: scf.for
# CHECK: catalyst.launch_kernel @module_circuit::@circuit() : () -> tensor<f64>
print(mcm_method_with_zne.mlir)
1 change: 1 addition & 0 deletions mlir/include/Catalyst/Transforms/Passes.td
@@ -74,6 +74,7 @@ def ScatterLoweringPass : Pass<"scatter-lowering"> {
"mlir::func::FuncDialect",
"index::IndexDialect",
"mhlo::MhloDialect",
"tensor::TensorDialect",
"scf::SCFDialect"
];

295 changes: 295 additions & 0 deletions mlir/lib/Catalyst/Transforms/ScatterPatterns.cpp
@@ -32,9 +32,304 @@ namespace catalyst {
struct ScatterOpRewritePattern : public mlir::OpRewritePattern<mhlo::ScatterOp> {
using mlir::OpRewritePattern<mhlo::ScatterOp>::OpRewritePattern;

mlir::LogicalResult onlyOneInputUpdateAndResult(mhlo::ScatterOp op) const
{
// Assumption 1: only one input, one update, and one result
// * size(inputs) == 1
// * size(updates) == 1
// * size(results) == 1
// All ScatterOp ops with N inputs, N updates, and N results can be split
// into N ScatterOp ops with 1 input, 1 update and 1 result.
// This simplifies the analysis of the update_computation.

// From:
// C5: 0 < size(inputs) = size(updates) = N
// C24: element_type(results[i]) = Ei for all i in [0,N).
// It follows that:
// 0 < size(inputs) = size(updates) = size(results) = N
return op.getResults().size() == 1 ? success() : failure();
}

mlir::LogicalResult isAssignment(mhlo::ScatterOp op) const
{
// From:
// C23: update_computation has type
// (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ...,
// tensor<EN-1>) , where is_promotable(element_type(inputs[i]), Ei)
//
// And from the description of the scatter semantics:
// updated_values = update_computation(results...[result_index], updates_converted)
//
// It follows that:
// the update_computation function is guaranteed to have exactly two parameters
// and one result: the first parameter corresponds to the result value at
// result_index, and the second to the single converted update value.
// This means that if the only operation inside the update_computation
// function is returning the second argument, then we are just assigning the update
// value to the result.
Region &region = op.getUpdateComputation();
Block &block = region.front();
bool oneOperation = block.begin() == --block.end();
if (!oneOperation) {
return failure();
}

mhlo::ReturnOp returnOp = dyn_cast<mhlo::ReturnOp>(block.getTerminator());
if (!returnOp) {
return failure();
}

return returnOp.getResults().front() == block.getArgument(1) ? success() : failure();
}
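For intuition only (not part of the patch): at the JAX level, an indexed assignment like `x.at[i].set(u)` lowers to a scatter whose update_computation simply returns the update, which is exactly the shape this predicate accepts, whereas `x.at[i].add(u)` produces a combining computation and falls back to the generic lowering. A small illustrative sketch:

```python
import jax
import jax.numpy as jnp

x = jnp.zeros((5, 3))
u = jnp.ones(3)

# Pure assignment: the scatter's update_computation just returns the new value.
print(jax.make_jaxpr(lambda x, u: x.at[2].set(u))(x, u))   # uses the `scatter` primitive

# Accumulation: the update_computation combines old and new values, so the
# assignment fast path does not apply.
print(jax.make_jaxpr(lambda x, u: x.at[2].add(u))(x, u))   # uses `scatter-add`
```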

mlir::LogicalResult noBatching(mhlo::ScatterOp op) const
{
// Ok, now that we know it is an assignment, we need to determine
// where exactly we are assigning and what we are assigning.
// Let's first look at what we are assigning.
// It needs to be a plain slice, with no preprocessing of any kind.
// What kind of preprocessing exists?
// * Batching for the input
// * Batching for the indices
//
// From:
// (C13) 0 <= input_batching_dims < rank(inputs[0])
// (C17) size(input_batching_dims) == size(scatter_indices_batching_dims)
// It follows that if there are no input_batching_dims and no
// scatter_indices_batching_dims, then no batching preprocessing takes place.
// TODO: This will always return success until we update our version of mlir-hlo.
// It looks like we are using an old version where getInputBatchingDims was not yet available.
// See here:
// https://github.com/tensorflow/mlir-hlo/commit/5ac7c579c52ef02b13c29886a98672c2ade7c9b0
return success();
// Until then, keep this code commented:
// auto scatterDimNumbers = op.getScatterDimensionNumbers();
// return scatterDimNumbers.getInputBatchingDims().empty() ? success() : failure();
}

mlir::LogicalResult singleFullSlices(mhlo::ScatterOp op) const
{
// From:
// More formally, for all update_index in index_space(updates[0]):
// * update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims]
// * update_scatter_index = update_index[update_scatter_dims...]
// We want update_scatter_index to be empty. This would mean that
// scatter_indices points to a single location in the input tensor and the
// corresponding update value is a full window that is inserted at that location,
// i.e. we have a single update.
auto update = op.getUpdates().front();
// And we need to make sure that all of its axes are in the update_window_dims.
// From:
// (C7) is_unique(update_window_dims) and is_sorted(update_window_dims)
// it suffices to check that rank(update) == size(update_window_dims):
auto updateTy = cast<RankedTensorType>(update.getType());
auto scatterDimNumbers = op.getScatterDimensionNumbers();
return updateTy.getRank() == scatterDimNumbers.getUpdateWindowDims().size() ? success()
: failure();
}

mlir::LogicalResult canBeDoneWithSingleTensorInsertSlice(mhlo::ScatterOp op) const
{
return cast<RankedTensorType>(op.getScatterIndices().getType()).getRank() == 1 ? success()
: failure();
}
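A hypothetical frontend-level illustration of the two shape predicates above: assigning one full row yields an update whose rank equals size(update_window_dims) and a rank-1 scatter_indices (a fast-path candidate), while assigning several rows at once does not; the `unique_indices`/`indices_are_sorted` flags checked further below can also be set explicitly on the JAX side.

```python
import jax
import jax.numpy as jnp

x = jnp.zeros((5, 3))

# One full row: the update has shape (3,) and scatter_indices has rank 1, so the
# scatter is shaped like a single tensor.insert_slice. The explicit flags also
# satisfy the unique/sorted check.
single = lambda x: x.at[2].set(jnp.arange(3.0),
                               unique_indices=True, indices_are_sorted=True)

# Several rows at once: the update has shape (2, 3) but only one update_window_dim,
# so rank(update) != size(update_window_dims) and the generic lowering is used.
several = lambda x: x.at[jnp.array([1, 4])].set(jnp.ones((2, 3)))

print(jax.make_jaxpr(single)(x))
print(jax.make_jaxpr(several)(x))
```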

mlir::LogicalResult lowerToTensorInsertSlice(mhlo::ScatterOp op,
mlir::PatternRewriter &rewriter) const
{
// mhlo::ScatterOp is exactly the same as stablehlo::ScatterOp
// See https://www.tensorflow.org/mlir/hlo_ops#mhloscatter_mhloscatterop
// and https://github.com/openxla/stablehlo/blob/main/docs/spec.md#scatter
//
// From https://github.com/openxla/stablehlo/blob/main/docs/spec.md#scatter:
//
// Semantics
//
// Produces results tensors which are equal to inputs tensors
// except that several slices specified by scatter_indices
// are updated with the values updates using update_computation.
//
// These simple semantics are obscured a bit by too many other details.
//
// Let's make some simplifying assumptions

// Add checks for the supported case (assumptions: unique indices and
// sorted indices).
if (!op.getUniqueIndices() || !op.getIndicesAreSorted()) {
op.emitError() << "Indices are not unique and/or not sorted, unique boolean: "
<< op.getUniqueIndices()
<< ", sorted boolean: " << op.getIndicesAreSorted();
return failure();
}

// size(%result) == size(%update) == size(%input) == 1
if (failed(this->onlyOneInputUpdateAndResult(op))) {
return failure();
}
auto result = op.getResults().front();
auto input = op.getInputs().front();
auto update = op.getUpdates().front();
auto scatterIndices = op.getScatterIndices();

// update_function =
// ^bb0(%arg0: T, %arg1: T):
// stablehlo.return %arg1 : T
// })
if (failed(this->isAssignment(op))) {
return failure();
}

// input_batching_dims = []
// scatter_indices_batching_dims = []
if (failed(this->noBatching(op))) {
return failure();
}

// rank(%update) == size(update_window_dims)
// => we are inserting the whole %update into a dimension of %input
if (failed(this->singleFullSlices(op))) {
return failure();
}

// Now, where are we going to insert this full slice?
// scatter_indices is typed as a tensor of integer type,
// so in general we would need a loop over the scatter_indices.
// Here we assume that scatter_indices is a tensor of rank 1;
// if that were not the case, we would need to create a loop:
// rank(%scatter_indices) == 1
if (failed(this->canBeDoneWithSingleTensorInsertSlice(op))) {
return failure();
}

auto resultTy = cast<RankedTensorType>(result.getType());
auto inputTy = cast<RankedTensorType>(input.getType());
auto updateTy = cast<RankedTensorType>(update.getType());
auto resultShape = resultTy.getShape();
auto inputShape = inputTy.getShape();
auto updateShape = updateTy.getShape();
auto scatterIndicesTy = cast<RankedTensorType>(scatterIndices.getType());
// (C24) shape(%result) == shape(%input)

auto scatterDimNumbers = op.getScatterDimensionNumbers();
auto updateWindowDims = scatterDimNumbers.getUpdateWindowDims();
auto insertedWindowDims = scatterDimNumbers.getInsertedWindowDims();
auto scatterDimsToOperandDims = scatterDimNumbers.getScatterDimsToOperandDims();
auto indexVectorDim = scatterDimNumbers.getIndexVectorDim();

if (indexVectorDim != scatterIndicesTy.getRank() - 1) {
// TODO: I think indexVectorDim > 0 implies a loop of insert_slices.
return failure();
}
// Because we required above that
// rank(%scatter_indices) == 1
// => indexVectorDim == 0

// But we still have a couple more attributes that we need to understand
// in order to correctly build:
// %result = tensor.insert_slice %update into
//     %input[%scatter_indices], [size(%update)...], [1 x rank(%update)]
// namely:
// * scatter_dims_to_operand_dims
// * index_vector_dim
// * inserted_window_dims
//
// Annotated description of scatter semantics:
// https://github.com/openxla/stablehlo/blob/main/docs/spec.md#scatter
//
// More formally, for all update_index in index_space(updates[0]):
// // In our case len(updates) = 1
// // so change this to index_space(update)
// // index_space is defined here:
// https://github.com/openxla/stablehlo/blob/main/docs/spec.md#shape-computations
//
// update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims]
// // rank(%update) == size(update_window_dims)
// // => update_scatter_dims = []
//
// update_scatter_index = update_index[update_scatter_dims...]
// // update_scatter_index = update_index
//
// update_window_index = update_index[update_window_dims...]
// // rank(%update) == size(update_window_dims)
// // => update_window_index = update_index
//
// full_window_index = [wi0, ..., 0, ..., wiN] where wi are individual elements in
// update_window_index, and 0 is inserted at indices from inserted_window_dims and
// input_batching_dims
// // rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims) +
// size(input_batching_dims)

// the offset relates to inserted_window_dims
//
// rewriter.create<tensor::InsertSliceOp>(loc, update, input, dyn_off, dyn_siz, dyn_str,
// static_off, static_siz, static_str);
SmallVector<Value> dynOffsets, dynSizes, dynStrides;
SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;

// TODO: Close, but not 100% sure. Verify for correctness...
for (int i = 0, inputDim = 0, updateDim = 0; i < inputShape.size(); i++) {
if (llvm::is_contained(insertedWindowDims, i)) {
int scatterDimIndex = scatterDimsToOperandDims[inputDim];
Value scatterDimVal =
rewriter.create<index::ConstantOp>(op.getLoc(), scatterDimIndex);
auto extractOp =
rewriter.create<tensor::ExtractOp>(op.getLoc(), scatterIndices, scatterDimVal)
.getResult();
auto indexCastOp =
rewriter
.create<arith::IndexCastOp>(op.getLoc(), rewriter.getIndexType(), extractOp)
.getResult();
dynOffsets.push_back(indexCastOp);
staticOffsets.push_back(ShapedType::kDynamic);
staticSizes.push_back(1);
}
else if (updateDim == inputDim) {
int scatterDimIndex = scatterDimsToOperandDims[inputDim];
Value scatterDimVal =
rewriter.create<index::ConstantOp>(op.getLoc(), scatterDimIndex);
auto extractOp =
rewriter.create<tensor::ExtractOp>(op.getLoc(), scatterIndices, scatterDimVal)
.getResult();
auto indexCastOp =
rewriter
.create<arith::IndexCastOp>(op.getLoc(), rewriter.getIndexType(), extractOp)
.getResult();
dynOffsets.push_back(indexCastOp);
staticOffsets.push_back(ShapedType::kDynamic);
staticSizes.push_back(updateShape[updateDim]);
updateDim++;
} else {
staticOffsets.push_back(0);
staticSizes.push_back(updateShape[updateDim]);
updateDim++;
}
inputDim++;
staticStrides.push_back(1);
}

rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(op, update, input, dynOffsets, dynSizes,
dynStrides, staticOffsets, staticSizes,
staticStrides);

return success();
}
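As a sanity check of what the rewrite emits in the simplest case (a sketch in NumPy terms, not generated code): for a (5, 3) input, a scatter index of [2], and a full-row update, the single tensor.insert_slice has offsets [%idx, 0], sizes [1, 3], and unit strides, i.e. one slice assignment.

```python
import numpy as np

x = np.zeros((5, 3))
u = np.arange(3.0)
idx = 2  # stands in for the dynamic offset extracted from %scatter_indices at runtime

# tensor.insert_slice %u into %x[idx, 0] [1, 3] [1, 1] ~ one strided slice assignment
# into a copy of the input (the op produces a new tensor).
result = x.copy()
result[idx:idx + 1, 0:3] = u
```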

mlir::LogicalResult matchAndRewrite(mhlo::ScatterOp op,
mlir::PatternRewriter &rewriter) const override
{
// Fast path
if (succeeded(this->lowerToTensorInsertSlice(op, rewriter))) {
return success();
}

if (failed(onlyOneInputUpdateAndResult(op))) {
// Bail out, otherwise the generic lowering below would segfault.
op.emitError() << "Only a single input, update, and result is supported.";
return failure();
}

// Compute the operation hash in case there is more than one scatter op and they have
// different update functions.
auto opHash = OperationEquivalence::computeHash(op);
1 change: 1 addition & 0 deletions mlir/lib/Catalyst/Transforms/scatter_lowering.cpp
@@ -24,6 +24,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
