diff --git a/stablehlo/dialect/StablehloOps.cpp b/stablehlo/dialect/StablehloOps.cpp
index cb9bde80f17..4a35df1b27f 100644
--- a/stablehlo/dialect/StablehloOps.cpp
+++ b/stablehlo/dialect/StablehloOps.cpp
@@ -154,7 +154,6 @@ LogicalResult ReduceScatterOp::verify() {
 }
 
 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AddOp)
-INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AllReduceOp)
 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AndOp)
 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(Atan2Op)
 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CbrtOp)
@@ -917,6 +916,15 @@ LogicalResult AllReduceOp::verify() {
                               getComputation());
 }
 
+LogicalResult AllReduceOp::inferReturnTypeComponents(
+    MLIRContext*, std::optional<Location> location, ValueShapeRange operands,
+    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
+  AllReduceOp::Adaptor adaptor(operands, attributes, properties, regions);
+  return hlo::inferAllReduceOp(location, adaptor.getOperand(),
+                               adaptor.getComputation(), inferredReturnShapes);
+}
+
 //===----------------------------------------------------------------------===//
 // BatchNormGradOp
 //===----------------------------------------------------------------------===//
@@ -1378,7 +1386,7 @@ LogicalResult ReduceWindowOp::inferReturnTypeComponents(
       location, adaptor.getInputs(), adaptor.getInitValues(),
       adaptor.getWindowDimensions(), adaptor.getWindowStrides(),
       adaptor.getBaseDilations(), adaptor.getWindowDilations(),
-      adaptor.getPadding(), inferredReturnShapes);
+      adaptor.getPadding(), adaptor.getBody(), inferredReturnShapes);
 }
 
 LogicalResult ReduceWindowOp::verify() {
@@ -1781,7 +1789,8 @@ LogicalResult ReduceOp::inferReturnTypeComponents(
   ReduceOp::Adaptor adaptor(operands, attributes, properties, regions);
   return hlo::inferReduceOp(location, adaptor.getInputs().getTypes(),
                             adaptor.getInitValues().getTypes(),
-                            adaptor.getDimensions(), inferredReturnShapes);
+                            adaptor.getDimensions(), adaptor.getBody(),
+                            inferredReturnShapes);
 }
 
 LogicalResult ReduceOp::verify() {
@@ -2312,8 +2321,8 @@ LogicalResult SelectAndScatterOp::inferReturnTypes(
     SmallVectorImpl<Type>& inferredReturnTypes) {
   SelectAndScatterOp::Adaptor adaptor(operands, attributes, properties,
                                       regions);
-  return hlo::inferSelectAndScatterOp(adaptor.getOperand(),
-                                      inferredReturnTypes);
+  return hlo::inferSelectAndScatterOp(
+      adaptor.getOperand(), adaptor.getScatter(), inferredReturnTypes);
 }
 
 LogicalResult SelectAndScatterOp::verify() {
@@ -2333,6 +2342,7 @@ LogicalResult ScatterOp::inferReturnTypes(
     SmallVectorImpl<Type>& inferredReturnTypes) {
   ScatterOp::Adaptor adaptor(operands, attributes, properties, regions);
   return hlo::inferScatterOp(location, adaptor.getInputs(),
+                             adaptor.getUpdateComputation(),
                              inferredReturnTypes);
 }
 
diff --git a/stablehlo/dialect/StablehloOps.td b/stablehlo/dialect/StablehloOps.td
index 399ecc635d3..377afb15fcb 100644
--- a/stablehlo/dialect/StablehloOps.td
+++ b/stablehlo/dialect/StablehloOps.td
@@ -1326,7 +1326,7 @@ def StableHLO_AllGatherOp : StableHLO_Op<"all_gather",
 }
 
 def StableHLO_AllReduceOp : StableHLO_Op<"all_reduce",
-    [HLO_CompatibleOperandsAndResultType /*all_reduce_c6*/]> {
+    [InferTensorType /*all_reduce_c6, all_reduce_c7*/]> {
   let summary = "AllReduce operation";
   let description = [{
     Within each process group in the process grid, applies a reduction function
@@ -1361,8 +1361,7 @@ def StableHLO_AllReduceOp : StableHLO_Op<"all_reduce",
   let hasVerifier = 1;
 }
 
-def StableHLO_ReduceScatterOp : StableHLO_Op<"reduce_scatter",
-    [SameOperandsAndResultElementType /*reduce_scatter_c8*/]> {
+def StableHLO_ReduceScatterOp : StableHLO_Op<"reduce_scatter", []> {
   let summary = "ReduceScatter operation";
   let description = [{
     Within each process group in the process grid, performs reduction, using
@@ -1447,7 +1446,7 @@ def StableHLO_AllToAllOp : StableHLO_Op<"all_to_all",
 
 def StableHLO_ReduceOp: StableHLO_ShapedInterfaceOp<"reduce", [
     RecursiveMemoryEffects, SameVariadicOperandSize /*reduce_c3*/,
-    InferTensorTypeWithReify /*reduce_c7*/,
+    InferTensorTypeWithReify /*reduce_c7, reduce_c8*/,
     SingleBlockImplicitTerminator<"ReturnOp">
   ]> { /*reduce_c7*/
   let summary = "Reduce operation";
@@ -2512,7 +2511,8 @@ def StableHLO_DynamicReshapeOp: StableHLO_ShapedInterfaceOp<"dynamic_reshape", [
 
 def StableHLO_ScatterOp: StableHLO_Op<"scatter",
     [RecursiveMemoryEffects, SameVariadicOperandSize /*scatter_c5*/,
-     DeclareOpInterfaceMethods<InferTypeOpInterface> /*scatter_c16*/]> {
+     DeclareOpInterfaceMethods<InferTypeOpInterface> /*scatter_c16,
+     scatter_c17*/]> {
   let summary = "Scatter operation";
   let description = [{
     Produces `results` tensors which are equal to `inputs` tensors except that
@@ -2585,8 +2585,8 @@ def StableHLO_SelectOp: StableHLO_Op<"select", [Pure, HLO_BroadcastingElementwis
 }
 
 def StableHLO_SelectAndScatterOp: StableHLO_Op<"select_and_scatter",
-    [DeclareOpInterfaceMethods<InferTypeOpInterface> /*select_and_scatter_c11*/,
-     RecursiveMemoryEffects]> {
+    [DeclareOpInterfaceMethods<InferTypeOpInterface> /*select_and_scatter_c11,
+     select_and_scatter_c12*/, RecursiveMemoryEffects]> {
   let summary = "SelectAndScatter operation";
   let description = [{
     Scatters the values from the `source` tensor using `scatter` based on the
diff --git a/stablehlo/dialect/TypeInference.cpp b/stablehlo/dialect/TypeInference.cpp
index 7dbfbd260da..3d8fe00eecb 100644
--- a/stablehlo/dialect/TypeInference.cpp
+++ b/stablehlo/dialect/TypeInference.cpp
@@ -97,6 +97,51 @@ bool tensorsHaveSameElType(Type type1, Type type2,
   return tensorsHaveSameElType({type1, type2}, ignoreFpPrecision);
 }
 
+unsigned potentiallyComplexBitWidth(Type type) {
+  auto complexTy = type.dyn_cast<ComplexType>();
+  return complexTy ? 2 * complexTy.getElementType().getIntOrFloatBitWidth()
+                   : type.getIntOrFloatBitWidth();
+}
+
+// Returns true if the element-type of type1 can be promoted to that of type2.
+// An element-type 'x' is promotable to element-type 'y' if they have the same
+// base type and bitwidth(x) <= bitwidth(y). When 'x' and 'y' are quantized
+// element-types, then promotion is applied only to the 'storage_type'
+// component.
+bool isPromotableElementType(Type type1, Type type2,
+                             bool ignoreFpPrecision = false) {
+  auto tensorTy1 = type1.dyn_cast<ShapedType>();
+  auto tensorTy2 = type2.dyn_cast<ShapedType>();
+
+  if (!tensorTy1 || !tensorTy2) return false;
+
+  Type tensorEl1 = tensorTy1.getElementType();
+  Type tensorEl2 = tensorTy2.getElementType();
+
+  if (ignoreFpPrecision && tensorEl1.isa<FloatType>() &&
+      tensorTy2.getElementType().isa<FloatType>())
+    return true;
+
+  bool isSameType =
+      (tensorEl1.isa<IntegerType>() && tensorEl2.isa<IntegerType>()) ||
+      (tensorEl1.isa<FloatType>() && tensorEl2.isa<FloatType>()) ||
+      (tensorEl1.isa<ComplexType>() && tensorEl2.isa<ComplexType>()) ||
+      (tensorEl1.isa<quant::QuantizedType>() &&
+       tensorEl2.isa<quant::QuantizedType>());
+
+  if (!isSameType) return false;
+
+  if (!tensorEl1.isa<quant::QuantizedType>())
+    return potentiallyComplexBitWidth(tensorEl1) <=
+           potentiallyComplexBitWidth(tensorEl2);
+
+  auto quantType1 = tensorEl1.cast<quant::QuantizedType>();
+  auto quantType2 = tensorEl2.cast<quant::QuantizedType>();
+  return quantType1.getExpressedType() == quantType2.getExpressedType() &&
+         potentiallyComplexBitWidth(quantType1.getStorageType()) <=
+             potentiallyComplexBitWidth(quantType2.getStorageType());
+}
+
 // Return true if type1 and type2 are shape-compatible and have same element
 // type. If 'ignoreFpPrecision' is True, then allow floats with different
 // precisions while checking element-types.
@@ -405,12 +450,6 @@ SmallVector<int64_t> inferWindowOutputShape(ArrayRef<int64_t> baseShape,
   return outputDimensions;
 }
 
-unsigned potentiallyComplexBitWidth(Type type) {
-  auto complexTy = type.dyn_cast<ComplexType>();
-  return complexTy ? 2 * complexTy.getElementType().getIntOrFloatBitWidth()
-                   : type.getIntOrFloatBitWidth();
-}
-
 LogicalResult verifyReplicaGroups(std::optional<Location> location,
                                   DenseIntElementsAttr replicaGroups,
                                   bool allGroupsMustHaveSameSize,
@@ -530,6 +569,17 @@ LogicalResult verifyReduceOpInputsAndInferShape(
   return success();
 }
 
+// Returns the types of the terminator arguments of the input mlir::Block
+// 'block'.
+SmallVector<ShapedType> getAccumulatorTypes(Block& block) {
+  SmallVector<ShapedType> accumulatorSubShapes;
+  for (Value retOperand : block.getTerminator()->getOperands()) {
+    auto shapedTy = retOperand.getType().dyn_cast<ShapedType>();
+    accumulatorSubShapes.push_back(shapedTy);
+  }
+  return accumulatorSubShapes;
+}
+
 LogicalResult verifyReducerShape(std::optional<Location> loc, Block& block,
                                  ArrayRef<TensorType> inputTypes,
                                  ArrayRef<TensorType> initValueTypes,
@@ -598,24 +648,37 @@ LogicalResult verifyReducerShape(std::optional<Location> loc, Block& block,
 
     // all_reduce_c5, reduce_c6, reduce_scatter_c7, reduce_window_c13,
     // reduce_window_i2, scatter_c6, scatter_c15, select_and_scatter_c10
-    if (!compatibleShapeAndElementType(accumulatorSubShapes[inputIdx],
-                                       initValueTypes[inputIdx],
-                                       /*ignoreFpPrecision=*/true))
+    if (failed(verifyCompatibleShape(initValueTypes[inputIdx],
+                                     accumulatorSubShapes[inputIdx])))
       return emitOptionalError(
-          loc, "The type of reduction-region's result type at index ", inputIdx,
-          " differs from the op's corresponding init-value type: ",
+          loc, "The shape of reduction-region's result type at index ",
+          inputIdx, " differs from the op's corresponding init-value type: ",
+          accumulatorSubShapes[inputIdx], " vs ", initValueTypes[inputIdx]);
+
+    if (!isPromotableElementType(initValueTypes[inputIdx],
+                                 accumulatorSubShapes[inputIdx],
+                                 /*ignoreFpPrecision=*/true))
+      return emitOptionalError(
+          loc, "The element-type of reduction-region's result type at index ",
+          inputIdx,
+          " is expected to be promotable from the op's corresponding "
+          "init-value element-type: ",
          accumulatorSubShapes[inputIdx], " vs ", initValueTypes[inputIdx]);
 
     // reduce_c6, reduce_window_c3, scatter_c6, scatter_c15,
     // select_and_scatter_c10
-    if (!tensorsHaveSameElType(
+    if (!isPromotableElementType(
             inputTypes[inputIdx],
-            block.getArgument(numInputs + inputIdx).getType(), true))
+            block.getArgument(numInputs + inputIdx).getType(),
+            /*ignoreFpPrecision=*/true))
       return emitOptionalError(
          loc, "The element-type of reduction-region's argument at index ",
-          numInputs + inputIdx, " is expected to be ",
+          numInputs + inputIdx, " is expected to be promotable from ",
          inputTypes[inputIdx].getElementType(), ", but got ",
-          block.getArgument(numInputs + inputIdx).getType(), " as its type.");
+          block.getArgument(numInputs + inputIdx)
+              .getType()
+              .cast<ShapedType>()
+              .getElementType());
 
     Type blockArgType = block.getArgument(numInputs + inputIdx).getType();
     auto blockArgTensorTy = blockArgType.cast<TensorType>();
@@ -1453,6 +1516,17 @@ LogicalResult inferAllToAllOp(
   return success();
 }
 
+LogicalResult inferAllReduceOp(
+    std::optional<Location> location, Value operand, Region& body,
+    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
+  // all_reduce_c6, all_reduce_c7
+  SmallVector<ShapedType> accumulatorTypes = getAccumulatorTypes(body.front());
+  auto operandShapedTy = operand.getType().cast<ShapedType>();
+  inferredReturnShapes.emplace_back(getSameShapeTensorType(
+      operandShapedTy, accumulatorTypes[0].getElementType()));
+  return success();
+}
+
 LogicalResult inferBatchNormGradOp(
     std::optional<Location> location, Value operand, Value scale, Value mean,
     Value variance, Value gradOutput, int64_t featureIndex,
@@ -2554,7 +2628,7 @@ LogicalResult inferRealOp(std::optional<Location>, Value operand,
 
 LogicalResult inferReduceOp(
     std::optional<Location> location, TypeRange inputTypes,
-    TypeRange initValueTypes, DenseIntElementsAttr dimensions,
+    TypeRange initValueTypes, DenseIntElementsAttr dimensions, Region& body,
     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
   SmallVector<ShapedType> inputArgTensorTypes{
      llvm::map_range(inputTypes, [](Type t) { return t.cast<ShapedType>(); })};
@@ -2568,10 +2642,11 @@ LogicalResult inferReduceOp(
       initValueTensorTypes, dimensions, newDimensions, encoding)))
     return failure();
 
-  // reduce_c2, reduce_c3, reduce_c7
+  // reduce_c3, reduce_c7, reduce_c8
+  SmallVector<ShapedType> accumulatorTypes = getAccumulatorTypes(body.front());
   for (uint64_t inputIdx = 0; inputIdx < inputTypes.size(); ++inputIdx) {
     ShapedType inputType = inputArgTensorTypes[inputIdx];
-    Type elementType = inputType.getElementType();
+    Type elementType = accumulatorTypes[inputIdx].getElementType();
     if (inputType.hasRank())
       inferredReturnShapes.emplace_back(newDimensions, elementType, encoding);
     else
@@ -2587,7 +2662,7 @@ LogicalResult inferReduceWindowOp(
     std::optional<DenseIntElementsAttr> windowStrides,
     std::optional<DenseIntElementsAttr> baseDilations,
     std::optional<DenseIntElementsAttr> windowDilations,
-    std::optional<DenseIntElementsAttr> padding,
+    std::optional<DenseIntElementsAttr> padding, Region& body,
     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
   SmallVector<ShapedType> inputTypes{llvm::map_range(
       inputs.getTypes(), [](Type t) { return t.cast<ShapedType>(); })};
@@ -2604,21 +2679,22 @@ LogicalResult inferReduceWindowOp(
     return failure();
 
   // reduce_window_c1, reduce_window_c14...reduce_window_c16
+  SmallVector<ShapedType> accumulatorTypes = getAccumulatorTypes(body.front());
   for (size_t i = 0; i < inputTypes.size(); ++i) {
     auto inputRankedType = inputs[i].getType().dyn_cast<RankedTensorType>();
     if (!inputRankedType) {
-      inferredReturnShapes.emplace_back(inputTypes[i].getElementType());
+      inferredReturnShapes.emplace_back(accumulatorTypes[i].getElementType());
     } else {
       auto resultShape =
           inferWindowOutputShape(inputTypes[i].getShape(), inferredWindow);
       auto inputBounds = encodingToBounds(inputRankedType.getEncoding());
       if (inputBounds.empty()) {
         inferredReturnShapes.emplace_back(resultShape,
-                                          inputTypes[i].getElementType());
+                                          accumulatorTypes[i].getElementType());
       } else {
         auto resultBounds = inferWindowOutputShape(inputBounds, inferredWindow);
         inferredReturnShapes.emplace_back(
-            resultShape, inputTypes[i].getElementType(),
+            resultShape, accumulatorTypes[i].getElementType(),
             boundsToEncoding(inputRankedType.getEncoding(), resultBounds));
       }
     }
@@ -2683,8 +2759,15 @@ LogicalResult inferRngOp(
 }
 
 LogicalResult inferScatterOp(std::optional<Location>, ValueRange inputs,
+                             Region& body,
                              SmallVectorImpl<Type>& inferredReturnTypes) {
-  llvm::append_range(inferredReturnTypes, inputs.getTypes());
+  // scatter_c16, scatter_c17
+  SmallVector<ShapedType> accumulatorTypes = getAccumulatorTypes(body.front());
+  for (uint64_t inputIdx = 0; inputIdx < inputs.size(); ++inputIdx) {
+    auto inputShapedTy = inputs[inputIdx].getType().cast<ShapedType>();
+    inferredReturnTypes.push_back(getSameShapeTensorType(
+        inputShapedTy, accumulatorTypes[inputIdx].getElementType()));
+  }
   return success();
 }
 
@@ -2714,9 +2797,12 @@ LogicalResult inferSelectOp(
 }
 
 LogicalResult inferSelectAndScatterOp(
-    Value operand, SmallVectorImpl<Type>& inferredReturnTypes) {
-  // select_and_scatter_c11
-  inferredReturnTypes.push_back(operand.getType());
+    Value operand, Region& body, SmallVectorImpl<Type>& inferredReturnTypes) {
+  // select_and_scatter_c11, select_and_scatter_c12
+  SmallVector<ShapedType> accumulatorTypes = getAccumulatorTypes(body.front());
+  auto operandShapedTy = operand.getType().cast<ShapedType>();
+  inferredReturnTypes.push_back(getSameShapeTensorType(
+      operandShapedTy, accumulatorTypes[0].getElementType()));
   return success();
 }
 
@@ -3871,6 +3957,16 @@ LogicalResult verifyReduceScatterOp(std::optional<Location> location,
                              operandType.getDimSize(index), ") and result (",
                              resultType.getDimSize(index), ")");
   }
+
+  // reduce_scatter_c9
+  SmallVector<ShapedType> accumulatorTypes =
+      getAccumulatorTypes(computation.front());
+  if (resultType.getElementType() != accumulatorTypes[0].getElementType()) {
+    return emitOptionalError(location, "result element-type is expected to be ",
+                             accumulatorTypes[0].getElementType(), ", but got ",
+                             resultType.getElementType());
+  }
+
   return success();
 }
 
diff --git a/stablehlo/dialect/TypeInference.h b/stablehlo/dialect/TypeInference.h
index 3c9dd83e326..5969e5146d6 100644
--- a/stablehlo/dialect/TypeInference.h
+++ b/stablehlo/dialect/TypeInference.h
@@ -120,6 +120,10 @@ LogicalResult inferAllToAllOp(
     DenseIntElementsAttr replicaGroups,
     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes);
 
+LogicalResult inferAllReduceOp(
+    std::optional<Location> location, Value operand, Region& body,
+    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes);
+
 LogicalResult inferBatchNormGradOp(
     std::optional<Location> location, Value operand, Value scale, Value mean,
     Value variance, Value gradOutput, int64_t featureIndex,
@@ -281,7 +285,7 @@ LogicalResult inferRealOp(std::optional<Location> location, Value operand,
 
 LogicalResult inferReduceOp(
     std::optional<Location> location, TypeRange inputTypes,
-    TypeRange initValueTypes, DenseIntElementsAttr dimensions,
+    TypeRange initValueTypes, DenseIntElementsAttr dimensions, Region& body,
     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes);
 
 LogicalResult inferReduceWindowOp(
@@ -290,7 +294,7 @@ LogicalResult inferReduceWindowOp(
     std::optional<DenseIntElementsAttr> windowStrides,
     std::optional<DenseIntElementsAttr> baseDilations,
     std::optional<DenseIntElementsAttr> windowDilations,
-    std::optional<DenseIntElementsAttr> padding,
+    std::optional<DenseIntElementsAttr> padding, Region& body,
     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes);
 
 LogicalResult inferReplicaIdOp(MLIRContext* context, std::optional<Location>,
@@ -306,7 +310,7 @@ LogicalResult inferRngOp(
     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes);
 
 LogicalResult inferScatterOp(std::optional<Location> location,
-                             ValueRange inputs,
+                             ValueRange inputs, Region& body,
                              SmallVectorImpl<Type>& inferredReturnTypes);
 
 LogicalResult inferSelectOp(
@@ -314,7 +318,7 @@ LogicalResult inferSelectOp(
     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes);
 
 LogicalResult inferSelectAndScatterOp(
-    Value operand, SmallVectorImpl<Type>& inferredReturnTypes);
+    Value operand, Region& body, SmallVectorImpl<Type>& inferredReturnTypes);
 
 LogicalResult inferSendOp(HloDialectInterface* dialect,
                           std::optional<Location> location,
diff --git a/stablehlo/reference/Ops.cpp b/stablehlo/reference/Ops.cpp
index 749ac016321..268f87345e0 100644
--- a/stablehlo/reference/Ops.cpp
+++ b/stablehlo/reference/Ops.cpp
@@ -84,7 +84,7 @@ SmallVector<Tensor> evalReduceOp(ArrayRef<Tensor> inputs,
   Builder builder(inputs[0].getType().getContext());
   auto reduceStatus = hlo::inferReduceOp(
       /*location=*/{}, inputTypes, initValueTypes,
-      builder.getI64TensorAttr(dimensions), inferredReduceTypes);
+      builder.getI64TensorAttr(dimensions), body, inferredReduceTypes);
   if (failed(reduceStatus))
     report_fatal_error(
         invalidArgument("Could not infer ReduceOp's return type"));
diff --git a/stablehlo/tests/infer_stablehlo.mlir b/stablehlo/tests/infer_stablehlo.mlir
index 8d8a746266b..e5ae6ee7935 100644
--- a/stablehlo/tests/infer_stablehlo.mlir
+++ b/stablehlo/tests/infer_stablehlo.mlir
@@ -102,6 +102,24 @@ func.func @cholesky(%arg0: tensor<1x2x2xf32>) -> tensor<1x2x2xindex> {
 
 // -----
 
+// CHECK-LABEL: func @all_reduce_c6_c7
+func.func @all_reduce_c6_c7(%operand: tensor<10xf32>) -> tensor<10xindex> {
+
+  %0 = "stablehlo.all_reduce"(%operand) ({
+  ^bb0(%arg0: tensor<f64>, %arg1: tensor<f64>):
+    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<f64>, tensor<f64>) -> tensor<f64>
+    "stablehlo.return"(%0) : (tensor<f64>) -> ()
+  }) {
+    replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
+    channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>
+  } : (tensor<10xf32>) -> tensor<10xf64>
+  // CHECK: types0 = tensor<10xf64>
+  %1 = "hlo_test_infer.get_return_types"(%0) : (tensor<10xf64>) -> tensor<10xindex>
"hlo_test_infer.get_return_types"(%0) : (tensor<10xf64>) -> tensor<10xindex> + func.return %1 : tensor<10xindex> +} + +// ----- + // CHECK-LABEL: func @all_to_all_c9 func.func @all_to_all_c9(%data: tensor<4x16xf32>) -> tensor<16x4xindex> { %0 = "stablehlo.all_to_all"(%data) { @@ -584,8 +602,8 @@ func.func @after_all_empty_arg() -> !stablehlo.token { // ----- -// CHECK: func @select_and_scatter_c11 -func.func @select_and_scatter_c11( +// CHECK: func @select_and_scatter_c11_c12 +func.func @select_and_scatter_c11_c12( %arg0: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xindex> { %0 = stablehlo.constant dense<0.000000e+00> : tensor @@ -612,8 +630,35 @@ func.func @select_and_scatter_c11( // ----- -// CHECK-LABEL: func @scatter_c16 -func.func @scatter_c16(%input_tensor: tensor<200x100x300xf32>, +// CHECK: func @select_and_scatter_c11_c12 +func.func @select_and_scatter_c11_c12( + %arg0: tensor<10x24x24x64xf32>, + %arg1: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xindex> { + %0 = stablehlo.constant dense<0.000000e+00> : tensor + %1 = "stablehlo.select_and_scatter"(%arg0, %arg1, %0) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %2 = "stablehlo.compare"(%arg3, %arg4) { + compare_type = #stablehlo, + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + "stablehlo.return"(%2) : (tensor) -> () + }, { + ^bb0(%arg3: tensor, %arg4: tensor): + %2 = stablehlo.add %arg3, %arg4 : tensor + "stablehlo.return"(%2) : (tensor) -> () + }) { + window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, + window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64> + } : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor) -> + tensor<10x24x24x64xf64> + %2 = "hlo_test_infer.get_return_types"(%1) : (tensor<10x24x24x64xf64>) -> tensor<10x24x24x64xindex> + func.return %2 : tensor<10x24x24x64xindex> +} + +// ----- + +// CHECK-LABEL: func @scatter_c16_c17 +func.func @scatter_c16_c17(%input_tensor: tensor<200x100x300xf32>, %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xf32>) -> tensor<200x100x300xindex> { %0 = "stablehlo.scatter" (%input_tensor, %scatter_indices, %updates) ({ @@ -638,6 +683,32 @@ func.func @scatter_c16(%input_tensor: tensor<200x100x300xf32>, // ----- +// CHECK-LABEL: func @scatter_c16_c17 +func.func @scatter_c16_c17(%input_tensor: tensor<200x100x300xf32>, + %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xf32>) -> + tensor<200x100x300xindex> { + %0 = "stablehlo.scatter" (%input_tensor, %scatter_indices, %updates) ({ + ^bb0(%lhs: tensor, %rhs: tensor): + %add = stablehlo.add %lhs, %rhs : tensor + "stablehlo.return"(%add) : (tensor) -> () + }) { + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [1], + inserted_window_dims = [0, 1], + scatter_dims_to_operand_dims = [0, 1], + index_vector_dim = 1 + >, + indices_are_sorted = true, + unique_indices = true + } : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> + tensor<200x100x300xf64> + // CHECK: types0 = tensor<200x100x300xf64> + %1 = "hlo_test_infer.get_return_types"(%0) : (tensor<200x100x300xf64>) -> tensor<200x100x300xindex> + func.return %1 : tensor<200x100x300xindex> +} + +// ----- + // CHECK-LABEL: func @scatter_bounds func.func @scatter_bounds(%input_tensor: tensor<200x?x?xf32, #stablehlo.bounds>, %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xf32>) -> @@ -864,7 +935,24 @@ func.func @reduce_c7(%arg0: tensor<7x5xf32>, %arg1 : tensor<5xf32>) -> tensor<6x // ----- -func.func @reduce_c7(%arg0: tensor, %arg1: tensor, +func.func 
@reduce_c8(%arg0: tensor<4x4xf32>, %arg1 : tensor) + -> (tensor<4xf32>) { + // expected-error@+2 {{failed to infer returned types}} + // expected-error@+1{{'stablehlo.reduce' op inferred type(s) 'tensor<4xf64>' are incompatible with return type(s) of operation 'tensor<4xf32>'}} + %0 = "stablehlo.reduce"(%arg0, %arg1) ({ + + ^bb0(%arg2: tensor, %arg3: tensor ): + %1 = "stablehlo.add"(%arg2, %arg3) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + + }) {dimensions = dense<[0]> : tensor<1xi64>} : (tensor<4x4xf32>, tensor) -> tensor<4xf32> + + func.return %0: tensor<4xf32> +} + +// ----- + +func.func @reduce_c3_c7(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> (tensor) { // expected-error@+2 {{failed to infer returned types}} // expected-error@+1 {{inferred type(s) 'tensor', 'tensor' are incompatible with return type(s) of operation 'tensor', 'tensor', 'tensor'}} @@ -882,7 +970,7 @@ func.func @reduce_c7(%arg0: tensor, %arg1: tensor, // ----- -func.func @reduce_c7(%arg0: tensor, %arg1: tensor, +func.func @reduce_c7_c8(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> (tensor) { // expected-error@+2 {{failed to infer returned types}} // expected-error@+1 {{'stablehlo.reduce' op inferred type(s) 'tensor', 'tensor' are incompatible with return type(s) of operation 'tensor', 'tensor'}} @@ -900,7 +988,7 @@ func.func @reduce_c7(%arg0: tensor, %arg1: tensor, // ----- -func.func @reduce_c7(%arg0: tensor, %arg1 : tensor) +func.func @reduce_c8(%arg0: tensor, %arg1 : tensor) -> (tensor) { // expected-error@+2 {{failed to infer returned types}} // expected-error@+1 {{'stablehlo.reduce' op inferred type(s) 'tensor' are incompatible with return type(s) of operation 'tensor'}} @@ -1023,6 +1111,25 @@ func.func @reduce_window_c16(%arg0: tensor<4x2xf32>, // ----- +func.func @reduce_window_c16(%arg0: tensor<4x2xf32>, %init0: tensor) -> + (tensor<2x2xf32>) { + // expected-error@+2 {{failed to infer returned types}} + // expected-error@+1 {{inferred type(s) 'tensor<2x2xf64>' are incompatible with return type(s) of operation 'tensor<2x2xf32>'}} + %0 = "stablehlo.reduce_window"(%arg0, %init0) ({ + ^bb0(%a0: tensor, %b0: tensor): + %1 = stablehlo.add %a0, %b0 : tensor + "stablehlo.return"(%1) : (tensor) -> () + }) + { padding = dense<[[2, 2], [0, 0]]> : tensor<2x2xi64>, + window_dimensions = dense<[5, 1]> : tensor<2xi64>, + window_strides = dense<[3, 1]> : tensor<2xi64> + } + : (tensor<4x2xf32>, tensor) -> (tensor<2x2xf32>) + func.return %0 : tensor<2x2xf32> +} + +// ----- + //===----------------------------------------------------------------------===// // Bounded Dynamism //===----------------------------------------------------------------------===// diff --git a/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/tests/ops_stablehlo.mlir index 2d4875409c6..1b19c400366 100644 --- a/stablehlo/tests/ops_stablehlo.mlir +++ b/stablehlo/tests/ops_stablehlo.mlir @@ -27,6 +27,40 @@ func.func @all_reduce(%operand: tensor<10xf32>) -> tensor<10xf32> { func.return %0 : tensor<10xf32> } +// ----- + +// CHECK-LABEL: func @all_reduce_with_promotable_types +func.func @all_reduce_with_promotable_types(%operand: tensor) -> tensor { + + %result = "stablehlo.all_reduce"(%operand) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + "stablehlo.return"(%0) : (tensor) -> () + }) { + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, + channel_handle = #stablehlo.channel_handle + } : (tensor) -> tensor + + func.return %result : tensor 
+}
+
+// -----
+
+// CHECK-LABEL: func @all_reduce_with_promotable_quantized_types
+func.func @all_reduce_with_promotable_quantized_types(%operand: tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>)
+    -> tensor<!quant.uniform<i32:f32, 2.000000e+00:15>> {
+
+  %result = "stablehlo.all_reduce"(%operand) ({
+  ^bb0(%arg0: tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>, %arg1: tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>):
+    %0 = stablehlo.add %arg0, %arg1 : tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>
+    "stablehlo.return"(%0) : (tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>) -> ()
+  }) {
+    replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
+    channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>
+  } : (tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>) -> tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>
+
+  func.return %result : tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>
+}
 
 // -----
 
@@ -186,7 +220,7 @@ func.func @all_reduce_c5(%operand: tensor<10xf32>) -> tensor<10xf32> {
 // -----
 
 func.func @all_reduce_c5(%operand: tensor<10xf32>) -> tensor<10xf32> {
-  // expected-error@+1 {{The type of reduction-region's result type at index 0 differs from the op's corresponding init-value type: 'tensor<i32>' vs 'tensor<f32>'}}
+  // expected-error@+1 {{The element-type of reduction-region's result type at index 0 is expected to be promotable from the op's corresponding init-value element-type: 'tensor<i32>' vs 'tensor<f32>'}}
   %0 = "stablehlo.all_reduce"(%operand) ({
     ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
       %max = stablehlo.maximum %arg0, %arg1 : tensor<i32>
@@ -201,7 +235,7 @@ func.func @all_reduce_c5(%operand: tensor<10xf32>) -> tensor<10xf32> {
 // -----
 
 func.func @all_reduce_c5(%operand: tensor<10xf32>) -> tensor<10xf32> {
-  // expected-error@+1 {{The type of reduction-region's result type at index 0 differs from the op's corresponding init-value type: 'tensor<4xf32>' vs 'tensor<f32>'}}
+  // expected-error@+1 {{The shape of reduction-region's result type at index 0 differs from the op's corresponding init-value type: 'tensor<4xf32>' vs 'tensor<f32>'}}
   %0 = "stablehlo.all_reduce"(%operand) ({
     ^bb0(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>):
       %max = stablehlo.maximum %arg0, %arg1 : tensor<4xf32>
@@ -215,6 +249,57 @@ func.func @all_reduce_c5(%operand: tensor<10xf32>) -> tensor<10xf32> {
 
 // -----
 
+func.func @all_reduce_c5(%operand: tensor<i64>) -> tensor<i64> {
+
+  // expected-error@+1 {{The element-type of reduction-region's result type at index 0 is expected to be promotable from the op's corresponding init-value element-type: 'tensor<i32>' vs 'tensor<i64>'}}
+  %result = "stablehlo.all_reduce"(%operand) ({
+  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
+    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    "stablehlo.return"(%0) : (tensor<i32>) -> ()
+  }) {
+    replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
+    channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>
+  } : (tensor<i64>) -> tensor<i64>
+
+  func.return %result : tensor<i64>
+}
+
+// -----
+
+func.func @all_reduce_c5(%operand: tensor<i32>) -> tensor<i32> {
+
+  // expected-error@+1 {{The element-type of reduction-region's result type at index 0 is expected to be promotable from the op's corresponding init-value element-type: 'tensor<f32>' vs 'tensor<i32>'}}
+  %result = "stablehlo.all_reduce"(%operand) ({
+  ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
+    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    "stablehlo.return"(%0) : (tensor<f32>) -> ()
+  }) {
+    replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
+    channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>
+  } : (tensor<i32>) -> tensor<i32>
+
+  func.return %result : tensor<i32>
+}
+// -----
+
+func.func @all_reduce_c5(%operand: tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>)
+    -> tensor<!quant.uniform<i8:f32, 2.000000e+00:15>> {
+
+  // expected-error@+1 {{The element-type of reduction-region's result type at index 0 is expected to be promotable from the op's corresponding init-value element-type: 'tensor<!quant.uniform<i8:f64, 2.000000e+00:15>>' vs 'tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>'}}
+  %result = "stablehlo.all_reduce"(%operand) ({
+  ^bb0(%arg0: tensor<!quant.uniform<i8:f64, 2.000000e+00:15>>, %arg1: tensor<!quant.uniform<i8:f64, 2.000000e+00:15>>):
+    %0 = stablehlo.add %arg0, %arg1 : tensor<!quant.uniform<i8:f64, 2.000000e+00:15>>
+    "stablehlo.return"(%0) : (tensor<!quant.uniform<i8:f64, 2.000000e+00:15>>) -> ()
+  }) {
+    replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
+    channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>
+  } : (tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>) -> tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>
+
+  func.return %result : tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>
+}
+
+// -----
+
 // CHECK-LABEL: func @reduce_scatter
 func.func @reduce_scatter(%data: tensor<4x16xf32>) -> tensor<4x4xf32> {
   %0 = "stablehlo.reduce_scatter"(%data) ({
@@ -245,6 +330,38 @@ func.func @reduce_scatter_dynamic(%data: tensor<?x?xf32>) -> tensor<?x?xf32> {
 
 // -----
 
+// CHECK-LABEL: func @reduce_scatter_with_promotable_types
+func.func @reduce_scatter_with_promotable_types(%data: tensor<4x16xf32>) -> tensor<4x4xf64> {
+  %0 = "stablehlo.reduce_scatter"(%data) ({
+    ^bb0(%arg2: tensor<f64>, %arg3: tensor<f64>):
+    %1 = stablehlo.add %arg2, %arg3 : tensor<f64>
+    "stablehlo.return"(%1) : (tensor<f64>) -> ()
+  }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>,
+      scatter_dimension = 1 : i64,
+      channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>,
+      use_global_device_ids} : (tensor<4x16xf32>) -> tensor<4x4xf64>
+  func.return %0 : tensor<4x4xf64>
+}
+
+// -----
+
+// CHECK-LABEL: func @reduce_scatter_with_promotable_quantized_types
+func.func @reduce_scatter_with_promotable_quantized_types(
+    %data: tensor<4x16x!quant.uniform<i8:f32, 2.000000e+00:15>>) ->
+    tensor<4x4x!quant.uniform<i32:f32, 2.000000e+00:15>> {
+  %0 = "stablehlo.reduce_scatter"(%data) ({
+    ^bb0(%arg2: tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>, %arg3: tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>):
+    %1 = stablehlo.add %arg2, %arg3 : tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>
+    "stablehlo.return"(%1) : (tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>) -> ()
+  }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>,
+      scatter_dimension = 1 : i64,
+      channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>,
+      use_global_device_ids} : (tensor<4x16x!quant.uniform<i8:f32, 2.000000e+00:15>>) -> tensor<4x4x!quant.uniform<i32:f32, 2.000000e+00:15>>
+  func.return %0 : tensor<4x4x!quant.uniform<i32:f32, 2.000000e+00:15>>
+}
+
+// -----
+
 func.func @reduce_scatter_c2(%data: tensor<4x16xf32>) -> tensor<4x4xf32> {
   // expected-error@+1 {{expects scatter_dimension >= 0}}
   %0 = "stablehlo.reduce_scatter"(%data) ({
@@ -404,7 +521,7 @@ func.func @reduce_scatter_c7(%data: tensor<4x16xf32>) -> tensor<4x4xf32> {
 
 // -----
 
 func.func @reduce_scatter_c7(%data: tensor<4x16xf32>) -> tensor<4x4xf32> {
-  // expected-error@+1 {{The type of reduction-region's result type at index 0 differs from the op's corresponding init-value type: 'tensor<i32>' vs 'tensor<f32>'}}
+  // expected-error@+1 {{The element-type of reduction-region's result type at index 0 is expected to be promotable from the op's corresponding init-value element-type: 'tensor<i32>' vs 'tensor<f32>'}}
   %0 = "stablehlo.reduce_scatter"(%data) ({
     ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>):
     %1 = stablehlo.add %arg2, %arg3 : tensor<i32>
@@ -416,6 +533,39 @@ func.func @reduce_scatter_c7(%data: tensor<4x16xf32>) -> tensor<4x4xf32> {
 
 // -----
 
+func.func @reduce_scatter_c7(%data: tensor<4x16xi32>) -> tensor<4x4xi8> {
+
+  // expected-error@+1 {{The element-type of reduction-region's result type at index 0 is expected to be promotable from the op's corresponding init-value element-type: 'tensor<i8>' vs 'tensor<i32>'}}
+  %0 = "stablehlo.reduce_scatter"(%data) ({
+    ^bb0(%arg2: tensor<i8>, %arg3: tensor<i8>):
+    %1 = stablehlo.add %arg2, %arg3 : tensor<i8>
+    "stablehlo.return"(%1) : (tensor<i8>) -> ()
+  }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>,
+      scatter_dimension = 1 : i64,
+      channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>,
+      use_global_device_ids} : (tensor<4x16xi32>) -> tensor<4x4xi8>
+  func.return %0 : tensor<4x4xi8>
+}
+
+// -----
+
+func.func @reduce_scatter_c7(%data: tensor<4x16x!quant.uniform<i32:f32, 2.000000e+00:15>>)
+    -> tensor<4x4x!quant.uniform<i8:f32, 2.000000e+00:15>> {
+
+  // expected-error@+1 {{The element-type of reduction-region's result type at index 0 is expected to be promotable from the op's corresponding init-value element-type: 'tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>' vs 'tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>'}}
+  %0 = "stablehlo.reduce_scatter"(%data) ({
+    ^bb0(%arg2: tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>, %arg3: tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>):
+    %1 = stablehlo.add %arg2, %arg3 : tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>
+    "stablehlo.return"(%1) : (tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>) -> ()
+  }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>,
+      scatter_dimension = 1 : i64,
+      channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>,
+      use_global_device_ids} : (tensor<4x16x!quant.uniform<i32:f32, 2.000000e+00:15>>) -> tensor<4x4x!quant.uniform<i8:f32, 2.000000e+00:15>>
+  func.return %0 : tensor<4x4x!quant.uniform<i8:f32, 2.000000e+00:15>>
+}
+
+// -----
+
 func.func @reduce_scatter_c8(%data: tensor<4x16xf32>) -> tensor<4xf32> {
   // expected-error@+1 {{operand and result should have same rank}}
   %0 = "stablehlo.reduce_scatter"(%data) ({
@@ -455,6 +605,21 @@ func.func @reduce_scatter_c8(%data: tensor<4x16xf32>) -> tensor<3x4xf32> {
 
 // -----
 
+func.func @reduce_scatter_c9(%data: tensor<4x16xf32>) -> tensor<4x4xf32> {
+  // expected-error@+1 {{result element-type is expected to be 'f64', but got 'f32'}}
+  %0 = "stablehlo.reduce_scatter"(%data) ({
+    ^bb0(%arg2: tensor<f64>, %arg3: tensor<f64>):
+    %1 = stablehlo.add %arg2, %arg3 : tensor<f64>
+    "stablehlo.return"(%1) : (tensor<f64>) -> ()
+  }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>,
+      scatter_dimension = 1 : i64,
+      channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>,
+      use_global_device_ids} : (tensor<4x16xf32>) -> tensor<4x4xf32>
+  func.return %0 : tensor<4x4xf32>
+}
+
+// -----
+
 func.func @reduce_scatter_i3(%data: tensor<4x16xf32>) -> tensor<4x4xf32> {
   // expected-error@+1 {{replica groups should be a rank 2 tensor}}
   %0 = "stablehlo.reduce_scatter"(%data) ({
diff --git a/stablehlo/tests/verify_reduce.mlir b/stablehlo/tests/verify_reduce.mlir
index 8915f4941aa..f48a2db01f7 100644
--- a/stablehlo/tests/verify_reduce.mlir
+++ b/stablehlo/tests/verify_reduce.mlir
@@ -96,6 +96,35 @@ func.func @reduce_mix_rank_and_unranked(%arg0: tensor<4x4xf32>, %arg1: tensor<*x
 
 // -----
 
+// CHECK-LABEL: func @reduce_with_promotable_types
+func.func @reduce_with_promotable_types(%arg0: tensor<4x4xf32>, %arg1 : tensor<f32>)
+    -> (tensor<4xf64>) {
+  %0 = "stablehlo.reduce"(%arg0, %arg1) ({
+
+  ^bb0(%arg2: tensor<f64>, %arg3: tensor<f64> ):
+    %1 = "stablehlo.add"(%arg2, %arg3) : (tensor<f64>, tensor<f64>) -> tensor<f64>
+    "stablehlo.return"(%1) : (tensor<f64>) -> ()
+
+  }) {dimensions = dense<[0]> : tensor<1xi64>} : (tensor<4x4xf32>, tensor<f32>) -> tensor<4xf64>
+
+  func.return %0: tensor<4xf64>
+}
+
+// -----
+
+// CHECK-LABEL: func @reduce_with_promotable_quantized_types
+func.func @reduce_with_promotable_quantized_types(%arg0: tensor<4x4x!quant.uniform<i8:f32, 2.000000e+00:15>>,
+    %arg1: tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>) -> tensor<4x!quant.uniform<i32:f32, 2.000000e+00:15>> {
+  %0 = stablehlo.reduce(%arg0 init: %arg1) across dimensions = [0] : (tensor<4x4x!quant.uniform<i8:f32, 2.000000e+00:15>>, tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>) -> tensor<4x!quant.uniform<i32:f32, 2.000000e+00:15>>
+   reducer(%arg2: tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>, %arg3: tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>) {
+    %1 = stablehlo.add %arg2, %arg3 : tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>
+    stablehlo.return %1 : tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>
+  }
+  return %0 : tensor<4x!quant.uniform<i32:f32, 2.000000e+00:15>>
+}
+
+// -----
+
 func.func @reduce_c1(%arg0: tensor<2x3xf32>, %arg1: tensor<3x2xf32>,
                      %arg2: tensor<f32>, %arg3: tensor<f32>)
     -> (tensor<2xf32>, tensor<2xf32>) {
@@ -307,7 +336,7 @@ func.func @reduce_c6(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xi32>,
 
 func.func @reduce_c6(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xi32>,
                      %arg2: tensor<f32>, %arg3: tensor<i32>)
     -> (tensor<?xf32>, tensor<?xi32>) {
-  // expected-error@+1 {{The type of reduction-region's result type at index 0 differs from the op's corresponding init-value type: 'tensor<i32>' vs 'tensor<f32>'}}
+  // expected-error@+1 {{The element-type of reduction-region's result type at index 0 is expected to be promotable from the op's corresponding init-value element-type: 'tensor<i32>' vs 'tensor<f32>'}}
   %0:2 = "stablehlo.reduce"(%arg0, %arg1, %arg2, %arg3) ({
     ^bb0(%arg4: tensor<f32>, %arg5: tensor<i32>,
          %arg6: tensor<f32>, %arg7: tensor<i32>):
@@ -325,7 +354,7 @@ func.func @reduce_c6(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xi32>,
 
 func.func @reduce_c6(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xi32>,
                      %arg2: tensor<f32>, %arg3: tensor<i32>)
     -> (tensor<?xf32>, tensor<?xi32>) {
-  // expected-error@+1 {{The type of reduction-region's result type at index 1 differs from the op's corresponding init-value type: 'tensor<f32>' vs 'tensor<i32>'}}
+  // expected-error@+1 {{The element-type of reduction-region's result type at index 1 is expected to be promotable from the op's corresponding init-value element-type: 'tensor<f32>' vs 'tensor<i32>'}}
   %0:2 = "stablehlo.reduce"(%arg0, %arg1, %arg2, %arg3) ({
     ^bb0(%arg4: tensor<f32>, %arg5: tensor<i32>,
          %arg6: tensor<f32>, %arg7: tensor<i32>):
@@ -343,7 +372,7 @@ func.func @reduce_c6(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xi32>,
 
 func.func @reduce_c6(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xi32>,
                      %arg2: tensor<f32>, %arg3: tensor<i32>)
     -> (tensor<?xf32>, tensor<?xi32>) {
-  // expected-error@+1 {{The element-type of reduction-region's argument at index 3 is expected to be 'i32', but got 'tensor<f32>' as its type.}}
+  // expected-error@+1 {{The element-type of reduction-region's argument at index 3 is expected to be promotable from 'i32', but got 'f32'}}
   %0:2 = "stablehlo.reduce"(%arg0, %arg1, %arg2, %arg3) ({
     ^bb0(%arg4: tensor<f32>, %arg5: tensor<i32>,
          %arg6: tensor<f32>, %arg7: tensor<f32>):
@@ -392,6 +421,73 @@ func.func @reduce_c6(%arg0: tensor<8x5xf32>, %arg1 : tensor<4xf32>)
 
 // -----
 
+
+func.func @reduce_c6(%arg0: tensor<4x4xi32>, %arg1 : tensor<i32>)
+    -> (tensor<4xi8>) {
+
+  // expected-error@+1 {{The element-type of reduction-region's result type at index 0 is expected to be promotable from the op's corresponding init-value element-type: 'tensor<i8>' vs 'tensor<i32>'}}
+  %0 = "stablehlo.reduce"(%arg0, %arg1) ({
+
+  ^bb0(%arg2: tensor<i8>, %arg3: tensor<i8> ):
+    %1 = "stablehlo.add"(%arg2, %arg3) : (tensor<i8>, tensor<i8>) -> tensor<i8>
+    "stablehlo.return"(%1) : (tensor<i8>) -> ()
+
+  }) {dimensions = dense<[0]> : tensor<1xi64>} : (tensor<4x4xi32>, tensor<i32>) -> tensor<4xi8>
+
+  func.return %0: tensor<4xi8>
+}
+
+// -----
+
+func.func @reduce_c6(%arg0: tensor<4x4xi32>, %arg1 : tensor<i8>)
+    -> (tensor<4xi8>) {
+
+  // expected-error@+1 {{The element-type of reduction-region's argument at index 1 is expected to be promotable from 'i32', but got 'i8'}}
+  %0 = "stablehlo.reduce"(%arg0, %arg1) ({
+
+  ^bb0(%arg2: tensor<i8>, %arg3: tensor<i8> ):
+    %1 = "stablehlo.add"(%arg2, %arg3) : (tensor<i8>, tensor<i8>) -> tensor<i8>
+    "stablehlo.return"(%1) : (tensor<i8>) -> ()
+
+  }) {dimensions = dense<[0]> : tensor<1xi64>} : (tensor<4x4xi32>, tensor<i8>) -> tensor<4xi8>
+
+  func.return %0: tensor<4xi8>
+}
+
+// -----
+
+func.func @reduce_c6(%arg0: tensor<4x4x!quant.uniform<i32:f32, 2.000000e+00:15>>,
+    %arg1: tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>) -> tensor<4x!quant.uniform<i8:f32, 2.000000e+00:15>> {
+
+  // expected-error@+1 {{The element-type of reduction-region's result type at index 0 is expected to be promotable from the op's corresponding init-value element-type: 'tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>' vs 'tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>'}}
+  %0 = stablehlo.reduce(%arg0 init: %arg1) across dimensions = [0] : (tensor<4x4x!quant.uniform<i32:f32, 2.000000e+00:15>>,
+      tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>) -> tensor<4x!quant.uniform<i8:f32, 2.000000e+00:15>>
+
+   reducer(%arg2: tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>, %arg3: tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>) {
+    %1 = stablehlo.add %arg2, %arg3 : tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>
+    stablehlo.return %1 : tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>
+  }
+  return %0 : tensor<4x!quant.uniform<i8:f32, 2.000000e+00:15>>
+}
+
+// -----
+
+func.func @reduce_c6(%arg0: tensor<4x4x!quant.uniform<i32:f32, 2.000000e+00:15>>,
+    %arg1: tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>) -> tensor<4x!quant.uniform<i8:f32, 2.000000e+00:15>> {
+
+  // expected-error@+1 {{The element-type of reduction-region's argument at index 1 is expected to be promotable from '!quant.uniform<i32:f32, 2.000000e+00:15>', but got '!quant.uniform<i8:f32, 2.000000e+00:15>'}}
+  %0 = stablehlo.reduce(%arg0 init: %arg1)
+      across dimensions = [0] : (tensor<4x4x!quant.uniform<i32:f32, 2.000000e+00:15>>,
+      tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>) -> tensor<4x!quant.uniform<i8:f32, 2.000000e+00:15>>
+
+   reducer(%arg2: tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>, %arg3: tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>) {
+    %1 = stablehlo.add %arg2, %arg3 : tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>
+    stablehlo.return %1 : tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>
+  }
+  return %0 : tensor<4x!quant.uniform<i8:f32, 2.000000e+00:15>>
+}
+
+// -----
+
 func.func @reduce_i3(%input: tensor<1x6xi64>, %init_value: tensor<i64>) -> tensor<1xi64> {
   // expected-error@+1 {{dimensions must be rank 1}}
   %0 = "stablehlo.reduce"(%input, %init_value) ({
diff --git a/stablehlo/tests/verify_reduce_window.mlir b/stablehlo/tests/verify_reduce_window.mlir
index f410243b1b9..62d9b4d80e0 100644
--- a/stablehlo/tests/verify_reduce_window.mlir
+++ b/stablehlo/tests/verify_reduce_window.mlir
@@ -82,6 +82,46 @@ func.func @reduce_window_with_unranked_dynamic_dims(%arg0: tensor<*xf32>,
 
 // -----
 
+// CHECK-LABEL: func @reduce_window_with_promotable_types
+func.func @reduce_window_with_promotable_types(%arg0: tensor<4x2xf32>,
+    %arg1: tensor<4x2xf32>, %init0: tensor<f32>, %init1: tensor<f32>) ->
+        (tensor<2x2xf64>, tensor<2x2xf32>) {
+  %0:2 = "stablehlo.reduce_window"(%arg0, %arg1, %init0, %init1) ({
+  ^bb0(%a0: tensor<f64>, %a1: tensor<f32>, %b0: tensor<f64>,
+       %b1: tensor<f32>):
+    %2 = stablehlo.add %a0, %b0 : tensor<f64>
+    %3 = stablehlo.add %a1, %b1 : tensor<f32>
+    "stablehlo.return"(%2,%3) : (tensor<f64>, tensor<f32>) -> ()
+  })
+  { padding = dense<[[2, 2], [0, 0]]> : tensor<2x2xi64>,
+    window_dimensions = dense<[5, 1]> : tensor<2xi64>,
+    window_strides = dense<[3, 1]> : tensor<2xi64> }
+  : (tensor<4x2xf32>, tensor<4x2xf32>, tensor<f32>, tensor<f32>) ->
+        (tensor<2x2xf64>, tensor<2x2xf32>)
+  func.return %0#0, %0#1 : tensor<2x2xf64>, tensor<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @reduce_window_with_promotable_quantized_types
+func.func @reduce_window_with_promotable_quantized_types(%arg0: tensor<4x2x!quant.uniform<i8:f32, 2.000000e+00:15>>,
+    %init0: tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>) -> (tensor<2x2x!quant.uniform<i32:f32, 2.000000e+00:15>>) {
+
+  %0 = "stablehlo.reduce_window"(%arg0, %init0) ({
+  ^bb0(%a0: tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>, %b0: tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>):
+    %1 = stablehlo.add %a0, %b0 : tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>
+    "stablehlo.return"(%1) : (tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>) -> ()
+  })
+  { padding = dense<[[2, 2], [0, 0]]> : tensor<2x2xi64>,
+    window_dimensions = dense<[5, 1]> : tensor<2xi64>,
+    window_strides = dense<[3, 1]> : tensor<2xi64>
+  }
+  : (tensor<4x2x!quant.uniform<i8:f32, 2.000000e+00:15>>, tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>) -> (tensor<2x2x!quant.uniform<i32:f32, 2.000000e+00:15>>)
+  func.return %0 : tensor<2x2x!quant.uniform<i32:f32, 2.000000e+00:15>>
+}
+
+// -----
+
 func.func @reduce_window_c1(%arg0: tensor<4x2xf32>,
    %arg1: tensor<4x2xi32>, %init0: tensor<f32>, %init1: tensor<i32>) ->
        (tensor<2x2xf32>, tensor<2x2xi32>) {
@@ -123,7 +163,7 @@ func.func @reduce_window_c2(%arg0: tensor<4x2xf32>,
 
 func.func @reduce_window_c3(%arg0: tensor<4x2xf32>,
    %arg1: tensor<4x2xi32>, %init0: tensor<f32>, %init1: tensor<i32>) ->
        (tensor<2x2xf32>, tensor<2x2xf32>) {
-  // expected-error@+1 {{The element-type of reduction-region's argument at index 3 is expected to be 'i32', but got 'tensor<f32>' as its type.}}
+  // expected-error@+1 {{The element-type of reduction-region's argument at index 3 is expected to be promotable from 'i32', but got 'f32'}}
   %0:2 = "stablehlo.reduce_window"(%arg0, %arg1, %init0, %init1) ({
          ^bb0(%a0: tensor<f32>, %a1: tensor<f32>, %b0: tensor<f32>,
                 %b1: tensor<f32>):
@@ -490,7 +530,7 @@ func.func @reduce_window_c13(%arg0: tensor<4x2xf32>,
 
 func.func @reduce_window_c13(%arg0: tensor<4x2xf32>,
    %arg1: tensor<4x2xi32>, %init0: tensor<f32>, %init1: tensor<i32>) ->
        (tensor<2x2xf32>, tensor<2x2xi32>) {
-  // expected-error@+1 {{The type of reduction-region's result type at index 1 differs from the op's corresponding init-value type: 'tensor<f32>' vs 'tensor<i32>'}}
+  // expected-error@+1 {{The element-type of reduction-region's result type at index 1 is expected to be promotable from the op's corresponding init-value element-type: 'tensor<f32>' vs 'tensor<i32>'}}
   %0:2 = "stablehlo.reduce_window"(%arg0, %arg1, %init0, %init1) ({
          ^bb0(%a0: tensor<f32>, %a1: tensor<i32>, %b0: tensor<f32>,
                 %b1: tensor<i32>):
@@ -544,10 +584,86 @@ func.func @reduce_window_c13(%arg0: tensor<4x2xf32>, %init0: tensor<4x2xf32>)
 
 // -----
 
+func.func @reduce_window_c13(%arg0: tensor<4x2xi32>, %init0: tensor<i32>) ->
+    (tensor<2x2xi8>) {
+
+  // expected-error@+1 {{The element-type of reduction-region's result type at index 0 is expected to be promotable from the op's corresponding init-value element-type: 'tensor<i8>' vs 'tensor<i32>'}}
+  %0 = "stablehlo.reduce_window"(%arg0, %init0) ({
+  ^bb0(%a0: tensor<i8>, %b0: tensor<i8>):
+    %1 = stablehlo.add %a0, %b0 : tensor<i8>
+    "stablehlo.return"(%1) : (tensor<i8>) -> ()
+  })
+  { padding = dense<[[2, 2], [0, 0]]> : tensor<2x2xi64>,
+    window_dimensions = dense<[5, 1]> : tensor<2xi64>,
+    window_strides = dense<[3, 1]> : tensor<2xi64>
+  }
+  : (tensor<4x2xi32>, tensor<i32>) -> (tensor<2x2xi8>)
+  func.return %0 : tensor<2x2xi8>
+}
+
+// -----
+
+func.func @reduce_window_c13(%arg0: tensor<4x2xi32>, %init0: tensor<i8>) ->
+    (tensor<2x2xi8>) {
+
+  // expected-error@+1 {{The element-type of reduction-region's argument at index 1 is expected to be promotable from 'i32', but got 'i8'}}
+  %0 = "stablehlo.reduce_window"(%arg0, %init0) ({
+  ^bb0(%a0: tensor<i8>, %b0: tensor<i8>):
+    %1 = stablehlo.add %a0, %b0 : tensor<i8>
+    "stablehlo.return"(%1) : (tensor<i8>) -> ()
+  })
+  { padding = dense<[[2, 2], [0, 0]]> : tensor<2x2xi64>,
+    window_dimensions = dense<[5, 1]> : tensor<2xi64>,
+    window_strides = dense<[3, 1]> : tensor<2xi64>
+  }
+  : (tensor<4x2xi32>, tensor<i8>) -> (tensor<2x2xi8>)
+  func.return %0 : tensor<2x2xi8>
+}
+
+// -----
+
+func.func @reduce_window_c13(%arg0: tensor<4x2x!quant.uniform<i32:f32, 2.000000e+00:15>>,
+    %init0: tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>) -> (tensor<2x2x!quant.uniform<i8:f32, 2.000000e+00:15>>) {
+
+  // expected-error@+1 {{The element-type of reduction-region's result type at index 0 is expected to be promotable from the op's corresponding init-value element-type: 'tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>' vs 'tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>'}}
+  %0 = "stablehlo.reduce_window"(%arg0, %init0) ({
+  ^bb0(%a0: tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>, %b0: tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>):
+    %1 = stablehlo.add %a0, %b0 : tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>
+    "stablehlo.return"(%1) : (tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>) -> ()
+  })
+  { padding = dense<[[2, 2], [0, 0]]> : tensor<2x2xi64>,
+    window_dimensions = dense<[5, 1]> : tensor<2xi64>,
+    window_strides = dense<[3, 1]> : tensor<2xi64>
+  }
+  : (tensor<4x2x!quant.uniform<i32:f32, 2.000000e+00:15>>, tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>) -> (tensor<2x2x!quant.uniform<i8:f32, 2.000000e+00:15>>)
+  func.return %0 : tensor<2x2x!quant.uniform<i8:f32, 2.000000e+00:15>>
+}
+
+// -----
+
+func.func @reduce_window_c13(%arg0: tensor<4x2x!quant.uniform<i32:f32, 2.000000e+00:15>>,
+    %init0: tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>) -> (tensor<2x2x!quant.uniform<i8:f32, 2.000000e+00:15>>) {
+
+  // expected-error@+1 {{The element-type of reduction-region's argument at index 1 is expected to be promotable from '!quant.uniform<i32:f32, 2.000000e+00:15>', but got '!quant.uniform<i8:f32, 2.000000e+00:15>'}}
+  %0 = "stablehlo.reduce_window"(%arg0, %init0) ({
+  ^bb0(%a0: tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>, %b0: tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>):
+    %1 = stablehlo.add %a0, %b0 : tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>
+    "stablehlo.return"(%1) : (tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>) -> ()
+  })
+  { padding = dense<[[2, 2], [0, 0]]> : tensor<2x2xi64>,
+    window_dimensions = dense<[5, 1]> : tensor<2xi64>,
+    window_strides = dense<[3, 1]> : tensor<2xi64>
+  }
+  : (tensor<4x2x!quant.uniform<i32:f32, 2.000000e+00:15>>, tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>) -> (tensor<2x2x!quant.uniform<i8:f32, 2.000000e+00:15>>)
+  func.return %0 : tensor<2x2x!quant.uniform<i8:f32, 2.000000e+00:15>>
+}
+
+// -----
+
 func.func @reduce_window_i2(%arg0: tensor<4x2xf32>,
    %arg1: tensor<4x2xi32>, %init0: tensor<1xf32>, %init1: tensor<1xi32>) ->
        (tensor<2x2xf32>, tensor<2x2xi32>) {
-  // expected-error@+1 {{The type of reduction-region's result type at index 0 differs from the op's corresponding init-value type: 'tensor<f32>' vs 'tensor<1xf32>'}}
+  // expected-error@+1 {{The shape of reduction-region's result type at index 0 differs from the op's corresponding init-value type: 'tensor<f32>' vs 'tensor<1xf32>'}}
   %0:2 = "stablehlo.reduce_window"(%arg0, %arg1, %init0, %init1) ({
          ^bb0(%a0: tensor<f32>, %a1: tensor<i32>, %b0: tensor<f32>,
                 %b1: tensor<i32>):
diff --git a/stablehlo/tests/verify_scatter.mlir b/stablehlo/tests/verify_scatter.mlir
index 1b7ec22fdcd..cc7284a6892 100644
--- a/stablehlo/tests/verify_scatter.mlir
+++ b/stablehlo/tests/verify_scatter.mlir
@@ -71,6 +71,55 @@ func.func @valid_scatter_dimensions_with_dynamic_index_vector_dim(
 
 // -----
 
+// CHECK: func @scatter_with_promotable_types
+func.func @scatter_with_promotable_types(%input_tensor: tensor<200x100x300xf32>,
+    %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xf32>) ->
+      tensor<200x100x300xf64> {
+  %0 = "stablehlo.scatter" (%input_tensor, %scatter_indices, %updates) ({
+  ^bb0(%lhs: tensor<f64>, %rhs: tensor<f64>):
+    %add = stablehlo.add %lhs, %rhs : tensor<f64>
+    "stablehlo.return"(%add) : (tensor<f64>) -> ()
+  }) {
+    scatter_dimension_numbers = #stablehlo.scatter<
+      update_window_dims = [1],
+      inserted_window_dims = [0, 1],
+      scatter_dims_to_operand_dims = [0, 1],
+      index_vector_dim = 1
+    >,
+    indices_are_sorted = true,
+    unique_indices = true
+  } : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) ->
+      tensor<200x100x300xf64>
+  func.return %0 : tensor<200x100x300xf64>
+}
+
+// -----
+
+// CHECK: func @scatter_with_promotable_quantized_types
+func.func @scatter_with_promotable_quantized_types(%input_tensor: tensor<200x100x300x!quant.uniform<i8:f32, 2.000000e+00:15>>,
+    %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300x!quant.uniform<i8:f32, 2.000000e+00:15>>) ->
+      tensor<200x100x300x!quant.uniform<i32:f32, 2.000000e+00:15>> {
+  %0 = "stablehlo.scatter" (%input_tensor, %scatter_indices, %updates) ({
+  ^bb0(%lhs: tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>, %rhs: tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>):
+    %add = stablehlo.add %lhs, %rhs : tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>
+    "stablehlo.return"(%add) : (tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>) -> ()
+  }) {
+    scatter_dimension_numbers = #stablehlo.scatter<
+      update_window_dims = [1],
+      inserted_window_dims = [0, 1],
+      scatter_dims_to_operand_dims = [0, 1],
+      index_vector_dim = 1
+    >,
+    indices_are_sorted = true,
+    unique_indices = true
+  } : (tensor<200x100x300x!quant.uniform<i8:f32, 2.000000e+00:15>>, tensor<10x2xi32>,
+      tensor<10x300x!quant.uniform<i8:f32, 2.000000e+00:15>>) ->
+      tensor<200x100x300x!quant.uniform<i32:f32, 2.000000e+00:15>>
+  func.return %0 : tensor<200x100x300x!quant.uniform<i32:f32, 2.000000e+00:15>>
+}
+
+// -----
+
 func.func @scatter_c1(%arg0: tensor<3xi32>, %arg1: tensor<1x1xi32>,
                       %arg2: tensor<1xi32>) -> tensor<3xi32> {
   // expected-error @+1 {{Not all inputs have compatible shapes.}}
@@ -233,7 +282,7 @@ func.func @scatter_c4(%input_tensor: tensor<200x100x300xf32>,
 
 func.func @scatter_c6_c15(%input_tensor: tensor<200x100x300xf32>,
    %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xi32>) ->
      tensor<200x100x300xf32> {
-  // expected-error@+1 {{The type of reduction-region's result type at index 0 differs from the op's corresponding init-value type: 'tensor<f32>' vs 'tensor<i32>'}}
+  // expected-error@+1 {{The element-type of reduction-region's result type at index 0 is expected to be promotable from the op's corresponding init-value element-type: 'tensor<f32>' vs 'tensor<i32>'}}
   %0 = "stablehlo.scatter" (%input_tensor, %scatter_indices, %updates) ({
   ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
     %add = stablehlo.add %lhs, %rhs : tensor<f32>
@@ -257,7 +306,7 @@ func.func @scatter_c6_c15(%input_tensor: tensor<200x100x300xf32>,
 
 func.func @scatter_c6_c15_c16(%input_tensor: tensor<200x100x300xi32>,
    %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xf32>) ->
      tensor<200x100x300xf32> {
-  // expected-error@+1 {{The element-type of reduction-region's argument at index 1 is expected to be 'i32', but got 'tensor<f32>' as its type.}}
+  // expected-error@+1 {{The element-type of reduction-region's argument at index 1 is expected to be promotable from 'i32', but got 'f32'}}
   %0 = "stablehlo.scatter" (%input_tensor, %scatter_indices, %updates) ({
   ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
     %add = stablehlo.add %lhs, %rhs : tensor<f32>
@@ -794,3 +843,103 @@ func.func @scatter_c15(%input_tensor: tensor<200x100x300xf32>,
       tensor<200x100x300xf32>
   func.return %0 : tensor<200x100x300xf32>
 }
+
+// -----
+
+func.func @scatter_c15(%input_tensor: tensor<200x100x300xi32>,
+    %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xi32>) ->
+      tensor<200x100x300xi8> {
+
+  // expected-error@+1 {{The element-type of reduction-region's result type at index 0 is expected to be promotable from the op's corresponding init-value element-type: 'tensor<i8>' vs 'tensor<i32>'}}
+  %0 = "stablehlo.scatter" (%input_tensor, %scatter_indices, %updates) ({
+  ^bb0(%lhs: tensor<i8>, %rhs: tensor<i8>):
+    %add = stablehlo.add %lhs, %rhs : tensor<i8>
+    "stablehlo.return"(%add) : (tensor<i8>) -> ()
+  }) {
+    scatter_dimension_numbers = #stablehlo.scatter<
+      update_window_dims = [1],
+      inserted_window_dims = [0, 1],
+      scatter_dims_to_operand_dims = [0, 1],
+      index_vector_dim = 1
+    >,
+    indices_are_sorted = true,
+    unique_indices = true
+  } : (tensor<200x100x300xi32>, tensor<10x2xi32>, tensor<10x300xi32>) ->
+      tensor<200x100x300xi8>
+  func.return %0 : tensor<200x100x300xi8>
+}
+
+// -----
+
+func.func @scatter_c15(%input_tensor: tensor<200x100x300xi32>,
+    %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xi8>) ->
+      tensor<200x100x300xi8> {
+
+  // expected-error@+1 {{The element-type of reduction-region's argument at index 1 is expected to be promotable from 'i32', but got 'i8'}}
+  %0 = "stablehlo.scatter" (%input_tensor, %scatter_indices, %updates) ({
+  ^bb0(%lhs: tensor<i8>, %rhs: tensor<i8>):
+    %add = stablehlo.add %lhs, %rhs : tensor<i8>
+    "stablehlo.return"(%add) : (tensor<i8>) -> ()
+  }) {
+    scatter_dimension_numbers = #stablehlo.scatter<
+      update_window_dims = [1],
+      inserted_window_dims = [0, 1],
+      scatter_dims_to_operand_dims = [0, 1],
+      index_vector_dim = 1
+    >,
+    indices_are_sorted = true,
+    unique_indices = true
+  } : (tensor<200x100x300xi32>, tensor<10x2xi32>, tensor<10x300xi8>) ->
+      tensor<200x100x300xi8>
+  func.return %0 : tensor<200x100x300xi8>
+}
+
+// -----
+
+func.func @scatter_c15(%input_tensor: tensor<200x100x300x!quant.uniform<i32:f32, 2.000000e+00:15>>,
+    %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300x!quant.uniform<i32:f32, 2.000000e+00:15>>) ->
+      tensor<200x100x300x!quant.uniform<i8:f32, 2.000000e+00:15>> {
+
+  // expected-error@+1 {{The element-type of reduction-region's result type at index 0 is expected to be promotable from the op's corresponding init-value element-type: 'tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>' vs 'tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>'}}
+  %0 = "stablehlo.scatter" (%input_tensor, %scatter_indices, %updates) ({
+  ^bb0(%lhs: tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>, %rhs: tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>):
+    %add = stablehlo.add %lhs, %rhs : tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>
+    "stablehlo.return"(%add) : (tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>) -> ()
+  }) {
+    scatter_dimension_numbers = #stablehlo.scatter<
+      update_window_dims = [1],
+      inserted_window_dims = [0, 1],
+      scatter_dims_to_operand_dims = [0, 1],
+      index_vector_dim = 1
+    >,
+    indices_are_sorted = true,
+    unique_indices = true
+  } : (tensor<200x100x300x!quant.uniform<i32:f32, 2.000000e+00:15>>, tensor<10x2xi32>, tensor<10x300x!quant.uniform<i32:f32, 2.000000e+00:15>>) ->
+      tensor<200x100x300x!quant.uniform<i8:f32, 2.000000e+00:15>>
+  func.return %0 : tensor<200x100x300x!quant.uniform<i8:f32, 2.000000e+00:15>>
+}
+
+// -----
+
+func.func @scatter_c15(%input_tensor: tensor<200x100x300x!quant.uniform<i32:f32, 2.000000e+00:15>>,
+    %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300x!quant.uniform<i8:f32, 2.000000e+00:15>>) ->
+      tensor<200x100x300x!quant.uniform<i8:f32, 2.000000e+00:15>> {
+
+  // expected-error@+1 {{The element-type of reduction-region's argument at index 1 is expected to be promotable from '!quant.uniform<i32:f32, 2.000000e+00:15>', but got '!quant.uniform<i8:f32, 2.000000e+00:15>'}}
+  %0 = "stablehlo.scatter" (%input_tensor, %scatter_indices, %updates) ({
+  ^bb0(%lhs: tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>, %rhs: tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>):
+    %add = stablehlo.add %lhs, %rhs : tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>
+    "stablehlo.return"(%add) : (tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>) -> ()
+  }) {
+    scatter_dimension_numbers = #stablehlo.scatter<
+      update_window_dims = [1],
+      inserted_window_dims = [0, 1],
+      scatter_dims_to_operand_dims = [0, 1],
+      index_vector_dim = 1
+    >,
+    indices_are_sorted = true,
+    unique_indices = true
+  } : (tensor<200x100x300x!quant.uniform<i32:f32, 2.000000e+00:15>>, tensor<10x2xi32>, tensor<10x300x!quant.uniform<i8:f32, 2.000000e+00:15>>) ->
+      tensor<200x100x300x!quant.uniform<i8:f32, 2.000000e+00:15>>
+  func.return %0 : tensor<200x100x300x!quant.uniform<i8:f32, 2.000000e+00:15>>
+}
diff --git a/stablehlo/tests/verify_select_and_scatter.mlir b/stablehlo/tests/verify_select_and_scatter.mlir
index bede69572aa..7857ac85ca1 100644
--- a/stablehlo/tests/verify_select_and_scatter.mlir
+++ b/stablehlo/tests/verify_select_and_scatter.mlir
@@ -26,6 +26,62 @@ func.func @select_and_scatter(
 
 // -----
 
+// CHECK: func @select_and_scatter_with_promotable_types
+func.func @select_and_scatter_with_promotable_types(
+    %arg0: tensor<10x24x24x64xf32>,
+    %arg1: tensor<10x12x12x64xf32>) -> () {
+  %0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
+  %1 = "stablehlo.select_and_scatter"(%arg0, %arg1, %0) ({
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    %2 = "stablehlo.compare"(%arg3, %arg4) {
+      comparison_direction = #stablehlo<comparison_direction GE>
+    } : (tensor<f32>, tensor<f32>) -> tensor<i1>
+    "stablehlo.return"(%2) : (tensor<i1>) -> ()
+  }, {
+  ^bb0(%arg3: tensor<f64>, %arg4: tensor<f64>):
+    %2 = stablehlo.add %arg3, %arg4 : tensor<f64>
+    "stablehlo.return"(%2) : (tensor<f64>) -> ()
+  }) {
+    window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
+    window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>,
+    padding = dense<0> : tensor<4x2xi64>
+  } : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<f32>) ->
+        tensor<10x24x24x64xf64>
+  func.return
+}
+
+// -----
+
+// CHECK-LABEL: func @select_and_scatter_with_promotable_quantized_types
+func.func @select_and_scatter_with_promotable_quantized_types(
+    %arg0: tensor<10x24x24x64x!quant.uniform<i8:f32, 2.000000e+00:15>>,
+    %arg1: tensor<10x12x12x64x!quant.uniform<i8:f32, 2.000000e+00:15>>,
+    %arg2 : tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>) ->
+      tensor<10x24x24x64x!quant.uniform<i32:f32, 2.000000e+00:15>> {
+
+  %1 = "stablehlo.select_and_scatter"(%arg0, %arg1, %arg2) ({
+  ^bb0(%arg3: tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>, %arg4: tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>):
+    %2 = "stablehlo.compare"(%arg3, %arg4) {
+      compare_type = #stablehlo<comparison_type TOTALORDER>,
+      comparison_direction = #stablehlo<comparison_direction GE>
+    } : (tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>, tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>) -> tensor<i1>
+    "stablehlo.return"(%2) : (tensor<i1>) -> ()
+  }, {
+  ^bb0(%arg3: tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>, %arg4: tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>):
+    %2 = stablehlo.add %arg3, %arg4 : tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>
+    "stablehlo.return"(%2) : (tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>) -> ()
+  }) {
+    window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
+    window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>
+  } : (tensor<10x24x24x64x!quant.uniform<i8:f32, 2.000000e+00:15>>,
+       tensor<10x12x12x64x!quant.uniform<i8:f32, 2.000000e+00:15>>,
+       tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>) ->
+        tensor<10x24x24x64x!quant.uniform<i32:f32, 2.000000e+00:15>>
+  func.return %1 : tensor<10x24x24x64x!quant.uniform<i32:f32, 2.000000e+00:15>>
+}
+
+// -----
+
 // CHECK: func @select_and_scatter_with_unranked_dims
 func.func @select_and_scatter_with_unranked_dims(
     %arg0: tensor<4x5x1x1xbf16>,
@@ -607,7 +663,7 @@ func.func @select_and_scatter_c10(
     %arg0: tensor<10x24x24x64xf32>,
     %arg1: tensor<10x12x12x64xf32>) -> () {
   %0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
-  // expected-error @+1 {{The type of reduction-region's result type at index 0 differs from the op's corresponding init-value type: 'tensor<i32>' vs 'tensor<f32>'}}
+  // expected-error @+1 {{The element-type of reduction-region's result type at index 0 is expected to be promotable from the op's corresponding init-value element-type: 'tensor<i32>' vs 'tensor<f32>'}}
   %1 = "stablehlo.select_and_scatter"(%arg0, %arg1, %0) ({
   ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
     %2 = "stablehlo.compare"(%arg3, %arg4) {
@@ -633,7 +689,7 @@ func.func @select_and_scatter_c10(
     %arg0: tensor<10x24x24x64xf32>,
     %arg1: tensor<10x12x12x64xf32>) -> () {
   %0 = stablehlo.constant dense<0> : tensor<i32>
-  // expected-error @+1 {{The element-type of reduction-region's argument at index 1 is expected to be 'f32', but got 'tensor<i32>' as its type.}}
+  // expected-error @+1 {{The element-type of reduction-region's argument at index 1 is expected to be promotable from 'f32', but got 'i32'}}
   %1 = "stablehlo.select_and_scatter"(%arg0, %arg1, %0) ({
   ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
     %2 = "stablehlo.compare"(%arg3, %arg4) {
@@ -678,3 +734,86 @@ func.func @select_and_scatter_c10(
       tensor<10x24x24x64xf32>
   func.return
 }
+
+// -----
+
+func.func @select_and_scatter_c10(
+    %arg0: tensor<10x24x24x64xi32>,
+    %arg1: tensor<10x12x12x64xi32>) -> tensor<10x24x24x64xi8> {
+  %0 = stablehlo.constant dense<0> : tensor<i32>
+
+  // expected-error@+1 {{The element-type of reduction-region's result type at index 0 is expected to be promotable from the op's corresponding init-value element-type: 'tensor<i8>' vs 'tensor<i32>'}}
+  %1 = "stablehlo.select_and_scatter"(%arg0, %arg1, %0) ({
+  ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):
+    %2 = "stablehlo.compare"(%arg3, %arg4) {
+      compare_type = #stablehlo<comparison_type TOTALORDER>,
+      comparison_direction = #stablehlo<comparison_direction GE>
+    } : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    "stablehlo.return"(%2) : (tensor<i1>) -> ()
+  }, {
+  ^bb0(%arg3: tensor<i8>, %arg4: tensor<i8>):
+    %2 = stablehlo.add %arg3, %arg4 : tensor<i8>
+    "stablehlo.return"(%2) : (tensor<i8>) -> ()
+  }) {
+    window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
+    window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>
+  } : (tensor<10x24x24x64xi32>, tensor<10x12x12x64xi32>, tensor<i32>) ->
+        tensor<10x24x24x64xi8>
+  func.return %1 : tensor<10x24x24x64xi8>
+}
+
+// -----
+
+func.func @select_and_scatter_c10(
+    %arg0: tensor<10x24x24x64xf32>,
+    %arg1: tensor<10x12x12x64xf32>) -> () {
+  %0 = stablehlo.constant dense<0> : tensor<i8>
+  // expected-error @+1 {{The element-type of reduction-region's argument at index 1 is expected to be promotable from 'f32', but got 'i8'}}
+  %1 = "stablehlo.select_and_scatter"(%arg0, %arg1, %0) ({
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    %2 = "stablehlo.compare"(%arg3, %arg4) {
+      comparison_direction = #stablehlo<comparison_direction GE>
+    } : (tensor<f32>, tensor<f32>) -> tensor<i1>
+    "stablehlo.return"(%2) : (tensor<i1>) -> ()
+  }, {
+  ^bb0(%arg3: tensor<i8>, %arg4: tensor<i8>):
+    %2 = stablehlo.add %arg3, %arg4 : tensor<i8>
+    "stablehlo.return"(%2) : (tensor<i8>) -> ()
+  }) {
+    window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
+    window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>,
+    padding = dense<0> : tensor<4x2xi64>
+  } : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<i8>) ->
+        tensor<10x24x24x64xf32>
+  func.return
+}
+
+// -----
+
+func.func @select_and_scatter_c10(
+    %arg0: tensor<10x24x24x64x!quant.uniform<i32:f32, 2.000000e+00:15>>,
+    %arg1: tensor<10x12x12x64x!quant.uniform<i32:f32, 2.000000e+00:15>>,
+    %arg2 : tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>) ->
+      tensor<10x24x24x64x!quant.uniform<i8:f32, 2.000000e+00:15>> {
+
+  // expected-error@+1 {{The element-type of reduction-region's result type at index 0 is expected to be promotable from the op's corresponding init-value element-type: 'tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>' vs 'tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>'}}
+  %1 = "stablehlo.select_and_scatter"(%arg0, %arg1, %arg2) ({
+  ^bb0(%arg3: tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>, %arg4: tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>):
+    %2 = "stablehlo.compare"(%arg3, %arg4) {
+      compare_type = #stablehlo<comparison_type TOTALORDER>,
+      comparison_direction = #stablehlo<comparison_direction GE>
+    } : (tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>, tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>) -> tensor<i1>
+    "stablehlo.return"(%2) : (tensor<i1>) -> ()
+  }, {
+  ^bb0(%arg3: tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>, %arg4: tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>):
+    %2 = stablehlo.add %arg3, %arg4 : tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>
+    "stablehlo.return"(%2) : (tensor<!quant.uniform<i8:f32, 2.000000e+00:15>>) -> ()
+  }) {
+    window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
+    window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>
+  } : (tensor<10x24x24x64x!quant.uniform<i32:f32, 2.000000e+00:15>>,
+       tensor<10x12x12x64x!quant.uniform<i32:f32, 2.000000e+00:15>>,
+       tensor<!quant.uniform<i32:f32, 2.000000e+00:15>>) ->
+        tensor<10x24x24x64x!quant.uniform<i8:f32, 2.000000e+00:15>>
+  func.return %1 : tensor<10x24x24x64x!quant.uniform<i8:f32, 2.000000e+00:15>>
+}