diff --git a/mlir/lib/Catalyst/Transforms/ScatterPatterns.cpp b/mlir/lib/Catalyst/Transforms/ScatterPatterns.cpp
index 013cdec2a3..d92fc10788 100644
--- a/mlir/lib/Catalyst/Transforms/ScatterPatterns.cpp
+++ b/mlir/lib/Catalyst/Transforms/ScatterPatterns.cpp
@@ -32,9 +32,33 @@ namespace catalyst {
 struct ScatterOpRewritePattern : public mlir::OpRewritePattern<mhlo::ScatterOp> {
     using mlir::OpRewritePattern<mhlo::ScatterOp>::OpRewritePattern;
 
+    mlir::LogicalResult onlyOneInputUpdateAndResult(mhlo::ScatterOp op) const
+    {
+        // Assumption 1: only one input, one update, and one result
+        // * size(inputs) == 1
+        // * size(updates) == 1
+        // * size(results) == 1
+        // All ScatterOp ops with N inputs, N updates, and N results can be split
+        // into N ScatterOp ops with 1 input, 1 update and 1 result.
+        // This simplifies the analysis of the update_computation.
+
+        // From:
+        // C5: 0 < size(inputs) = size(updates) = N
+        // C24: element_type(results[i]) = Ei for all i in [0,N).
+        // It follows that:
+        // 0 < size(inputs) = size(updates) = size(results) = N
+        return op.getResults().size() == 1 ? success() : failure();
+    }
+
     mlir::LogicalResult matchAndRewrite(mhlo::ScatterOp op,
                                         mlir::PatternRewriter &rewriter) const override
     {
+
+        if (failed(onlyOneInputUpdateAndResult(op))) {
+            // Otherwise it will segfault.
+            op.emitError() << "Only one input, update, and result";
+            return failure();
+        }
         // Compute operation hash in case they are more than one scatter and they have different
         // update function
         auto opHash = OperationEquivalence::computeHash(op);
diff --git a/mlir/test/Catalyst/ScatterTest.mlir b/mlir/test/Catalyst/ScatterTest.mlir
index b6e953ad05..217db6fcc1 100644
--- a/mlir/test/Catalyst/ScatterTest.mlir
+++ b/mlir/test/Catalyst/ScatterTest.mlir
@@ -247,3 +247,25 @@ func.func public @example_no_update_dim(%arg0: tensor<4xf64>) -> tensor<4xf64>
 // CHECK: scf.yield [[INSERTED]] : tensor<4xf64>
 // CHECK: }
 // CHECK: return [[FORRES]] : tensor<4xf64>
+
+// -----
+
+module @test_multiple_inputs {
+  %inputs = "test.op"() : () -> (tensor<7x131072xf64>)
+  %scatter_indices = "test.op"() : () -> (tensor<1xi32>)
+  %updates = "test.op"() : () -> (tensor<131072xf64>)
+  // expected-error@+1 {{Only one input, update, and result}}
+  %results:2 = "mhlo.scatter"(%inputs, %inputs, %scatter_indices, %updates, %updates) <{
+    indices_are_sorted = true,
+    unique_indices = true,
+    scatter_dimension_numbers = #mhlo.scatter<
+      update_window_dims = [0],
+      inserted_window_dims = [0],
+      scatter_dims_to_operand_dims = [0]
+    >
+  }> ({
+  ^bb0(%arg3: tensor<f64>, %arg4: tensor<f64>, %arg5: tensor<f64>, %arg6: tensor<f64>):
+    mhlo.return %arg4, %arg6 : tensor<f64>, tensor<f64>
+  }) : (tensor<7x131072xf64>, tensor<7x131072xf64>, tensor<1xi32>, tensor<131072xf64>, tensor<131072xf64>) -> (tensor<7x131072xf64>, tensor<7x131072xf64>)
+  "test.op"(%results#0, %results#1) : (tensor<7x131072xf64>, tensor<7x131072xf64>) -> ()
+}