Skip to content

Commit

Permalink
emit error instead of segfault
Browse files Browse the repository at this point in the history
  • Loading branch information
erick-xanadu committed Oct 18, 2024
1 parent 9714f6e commit dd276c9
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 0 deletions.
24 changes: 24 additions & 0 deletions mlir/lib/Catalyst/Transforms/ScatterPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,33 @@ namespace catalyst {
struct ScatterOpRewritePattern : public mlir::OpRewritePattern<mhlo::ScatterOp> {
using mlir::OpRewritePattern<mhlo::ScatterOp>::OpRewritePattern;

mlir::LogicalResult onlyOneInputUpdateAndResult(mhlo::ScatterOp op) const
{
// Assumption 1: only one input, one update, and one result
// * size(inputs) == 1
// * size(updates) == 1
// * size(results) == 1
// All ScatterOp ops with N inputs, N updates, and N results can be split
// into N ScatterOp ops with 1 input, 1 update and 1 result.
// This simplifies the analysis of the update_computation.

// From:
// C5: 0 < size(inputs) = size(updates) = N
// C24: element_type(results[i]) = Ei for all i in [0,N).
// It follows that:
// 0 < size(inputs) = size(updates) = size(results) = N
return op.getResults().size() == 1 ? success() : failure();
}

mlir::LogicalResult matchAndRewrite(mhlo::ScatterOp op,
mlir::PatternRewriter &rewriter) const override
{

if (failed(onlyOneInputUpdateAndResult(op))) {
// Otherwise it will segfault.
op.emitError() << "Only one input, update, and result";
return failure();
}
// Compute operation hash in case they are more than one scatter and they have different
// update function
auto opHash = OperationEquivalence::computeHash(op);
Expand Down
22 changes: 22 additions & 0 deletions mlir/test/Catalyst/ScatterTest.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -247,3 +247,25 @@ func.func public @example_no_update_dim(%arg0: tensor<4xf64>) -> tensor<4xf64> {
// CHECK: scf.yield [[INSERTED]] : tensor<4xf64>
// CHECK: }
// CHECK: return [[FORRES]] : tensor<4xf64>

// -----

module @test_multiple_inputs {
%inputs = "test.op"() : () -> (tensor<7x131072xf64>)
%scatter_indices = "test.op"() : () -> (tensor<1xi32>)
%updates = "test.op"() : () -> (tensor<131072xf64>)
// expected-error@+1 {{Only one input, update, and result}}
%results:2 = "mhlo.scatter"(%inputs, %inputs, %scatter_indices, %updates, %updates) <{
indices_are_sorted = true,
unique_indices = true,
scatter_dimension_numbers = #mhlo.scatter<
update_window_dims = [0],
inserted_window_dims = [0],
scatter_dims_to_operand_dims = [0]
>
}> ({
^bb0(%arg3: tensor<f64>, %arg4: tensor<f64>, %arg5: tensor<f64>, %arg6: tensor<f64>):
mhlo.return %arg4, %arg6 : tensor<f64>, tensor<f64>
}) : (tensor<7x131072xf64>, tensor<7x131072xf64>, tensor<1xi32>, tensor<131072xf64>, tensor<131072xf64>) -> (tensor<7x131072xf64>, tensor<7x131072xf64>)
"test.op"(%results#0, %results#1) : (tensor<7x131072xf64>, tensor<7x131072xf64>) -> ()
}

0 comments on commit dd276c9

Please sign in to comment.