Verifier and Type inference changes for reduction operations (#1869)
Implements the specification changes at
#1796.

The PR adds/updates the verifier and type inference routines for the
following ops: `reduce`, `reduce_window`, `select_and_scatter`, `all_reduce`,
`reduce_scatter`, `scatter`. Please refer to #1796 for the updated
constraints that this PR implements. Note that #1796 is expected to be
merged soon.

Here are the changes for each operation:

- reduce
  - #1796 added a new constraint C8 (a sketch follows below).
  - Updated labels.
  - Added positive and negative tests verifying reduce_c6 in
    verify_reduce.mlir, and type inference tests for reduce_c8 in
    infer_stablehlo.mlir.
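
To make the new semantics concrete, here is a minimal, hypothetical sketch
of a `reduce` that the updated verifier is intended to accept under #1796:
the inputs are `f32` while the reduction body computes in `f64`, and
reduce_c8 ties the result element type to the body's return type. Shapes and
attribute values are illustrative only; the same pattern applies to
`reduce_window`, `select_and_scatter`, and `scatter` below.

```mlir
// Inputs are f32; the body accumulates in f64 (f32 is promotable to f64),
// so the result element type follows the body's return type (reduce_c8).
%result = "stablehlo.reduce"(%input, %init_value) ({
  ^bb0(%arg0: tensor<f64>, %arg1: tensor<f64>):
    %0 = stablehlo.add %arg0, %arg1 : tensor<f64>
    stablehlo.return %0 : tensor<f64>
}) {
  dimensions = dense<1> : tensor<1xi64>
} : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf64>
```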

- reduce_window
  - #1796 updated constraint C16.
  - Added positive and negative tests verifying reduce_window_c13 in
    verify_reduce_window.mlir, and type inference tests for
    reduce_window_c16 in infer_stablehlo.mlir.

- select_and_scatter
  - #1796 added a new constraint C12.
  - Updated labels.
  - Added positive and negative tests verifying select_and_scatter_c10 in
    verify_select_and_scatter.mlir, and type inference tests for
    select_and_scatter_c12 in infer_stablehlo.mlir.

- scatter
  - #1796 added a new constraint C17.
  - Updated labels.
  - Added positive and negative tests verifying scatter_c15 in
    verify_scatter.mlir, and type inference tests for scatter_c17 in
    infer_stablehlo.mlir.

- reduce_scatter
  - #1796 added a new constraint C9 (a sketch follows below).
  - This op does not support type inference. It previously carried the
    trait `SameOperandsAndResultElementType`, which implemented the
    outdated constraint. For the new constraint C9, we added a check to
    `verifyReduceScatterOp`.
  - Added positive and negative tests verifying reduce_scatter_c7, as well
    as tests for reduce_scatter_c9, in ops_stablehlo.mlir.
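
As a hedged illustration of the new reduce_scatter_c9 check in
`verifyReduceScatterOp`: the result element type must match the element type
returned by the reduction body, while the scatter dimension is still split
across the process group. The replica-group layout below is a placeholder.

```mlir
// Two replicas; dim 0 is scattered (4 -> 2). The body returns f64, so
// reduce_scatter_c9 requires the result element type to be f64 as well.
%result = "stablehlo.reduce_scatter"(%operand) ({
  ^bb0(%arg0: tensor<f64>, %arg1: tensor<f64>):
    %0 = stablehlo.add %arg0, %arg1 : tensor<f64>
    stablehlo.return %0 : tensor<f64>
}) {
  scatter_dimension = 0 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
} : (tensor<4xf32>) -> tensor<2xf64>
```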

- all_reduce
  - #1796 added a new constraint C7 (a negative-case sketch follows below).
  - This op previously implemented an outdated type inference constraint
    via `inferReturnTypeComponentsFromOperands`. For the new constraint C7,
    we added the trait `InferTensorType` to the op's tablegen definition.
  - Updated labels.
  - Added positive and negative tests verifying all_reduce_c5, as well as
    type inference tests for all_reduce_c7, in ops_stablehlo.mlir.
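
Finally, a sketch of the negative case for all_reduce_c7: with
`InferTensorType`, the result element type is inferred from the reduction
body, so declaring an `f32` result while the body returns `f64` should fail
verification. The exact diagnostic is approximate.

```mlir
// Invalid under all_reduce_c7: the body returns f64, so the inferred result
// type is tensor<4xf64>, which is incompatible with the declared
// tensor<4xf32>.
%result = "stablehlo.all_reduce"(%operand) ({
  ^bb0(%arg0: tensor<f64>, %arg1: tensor<f64>):
    %0 = stablehlo.add %arg0, %arg1 : tensor<f64>
    stablehlo.return %0 : tensor<f64>
}) {
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
} : (tensor<4xf32>) -> tensor<4xf32>
```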
sdasgup3 authored Dec 20, 2023
1 parent cb42882 commit f8dcebf
Showing 20 changed files with 3,699 additions and 75 deletions.
20 changes: 15 additions & 5 deletions stablehlo/dialect/StablehloOps.cpp
@@ -156,7 +156,6 @@ LogicalResult ReduceScatterOp::verify() {
}

INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AddOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AllReduceOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AndOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(Atan2Op)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CbrtOp)
@@ -919,6 +918,15 @@ LogicalResult AllReduceOp::verify() {
getComputation());
}

LogicalResult AllReduceOp::inferReturnTypeComponents(
MLIRContext*, std::optional<Location> location, ValueShapeRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
AllReduceOp::Adaptor adaptor(operands, attributes, properties, regions);
return hlo::inferAllReduceOp(location, adaptor.getOperand(),
adaptor.getComputation(), inferredReturnShapes);
}

//===----------------------------------------------------------------------===//
// BatchNormGradOp
//===----------------------------------------------------------------------===//
@@ -1379,7 +1387,7 @@ LogicalResult ReduceWindowOp::inferReturnTypeComponents(
location, adaptor.getInputs(), adaptor.getInitValues(),
adaptor.getWindowDimensions(), adaptor.getWindowStrides(),
adaptor.getBaseDilations(), adaptor.getWindowDilations(),
adaptor.getPadding(), inferredReturnShapes);
adaptor.getPadding(), adaptor.getBody(), inferredReturnShapes);
}

LogicalResult ReduceWindowOp::verify() {
@@ -1782,7 +1790,8 @@ LogicalResult ReduceOp::inferReturnTypeComponents(
ReduceOp::Adaptor adaptor(operands, attributes, properties, regions);
return hlo::inferReduceOp(location, adaptor.getInputs().getTypes(),
adaptor.getInitValues().getTypes(),
adaptor.getDimensions(), inferredReturnShapes);
adaptor.getDimensions(), adaptor.getBody(),
inferredReturnShapes);
}

LogicalResult ReduceOp::verify() {
@@ -2295,8 +2304,8 @@ LogicalResult SelectAndScatterOp::inferReturnTypes(
SmallVectorImpl<Type>& inferredReturnTypes) {
SelectAndScatterOp::Adaptor adaptor(operands, attributes, properties,
regions);
return hlo::inferSelectAndScatterOp(adaptor.getOperand(),
inferredReturnTypes);
return hlo::inferSelectAndScatterOp(
adaptor.getOperand(), adaptor.getScatter(), inferredReturnTypes);
}

LogicalResult SelectAndScatterOp::verify() {
@@ -2316,6 +2325,7 @@ LogicalResult ScatterOp::inferReturnTypes(
SmallVectorImpl<Type>& inferredReturnTypes) {
ScatterOp::Adaptor adaptor(operands, attributes, properties, regions);
return hlo::inferScatterOp(location, adaptor.getInputs(),
adaptor.getUpdateComputation(),
inferredReturnTypes);
}

16 changes: 7 additions & 9 deletions stablehlo/dialect/StablehloOps.td
@@ -1327,7 +1327,7 @@ def StableHLO_AllGatherOp : StableHLO_Op<"all_gather",
}

def StableHLO_AllReduceOp : StableHLO_Op<"all_reduce",
[HLO_CompatibleOperandsAndResultType /*all_reduce_c6*/]> {
[InferTensorType /*all_reduce_c6, all_reduce_c7*/]> {
let summary = "AllReduce operation";
let description = [{
Within each process group in the process grid, applies a reduction function
@@ -1362,8 +1362,7 @@ def StableHLO_AllReduceOp : StableHLO_Op<"all_reduce",
let hasVerifier = 1;
}

def StableHLO_ReduceScatterOp : StableHLO_Op<"reduce_scatter",
[SameOperandsAndResultElementType /*reduce_scatter_c8*/]> {
def StableHLO_ReduceScatterOp : StableHLO_Op<"reduce_scatter", []> {
let summary = "ReduceScatter operation";
let description = [{
Within each process group in the process grid, performs reduction, using
@@ -1448,7 +1447,7 @@ def StableHLO_AllToAllOp : StableHLO_Op<"all_to_all",
def StableHLO_ReduceOp: StableHLO_ShapedInterfaceOp<"reduce", [
RecursiveMemoryEffects,
SameVariadicOperandSize /*reduce_c3*/,
InferTensorTypeWithReify /*reduce_c7*/,
InferTensorTypeWithReify /*reduce_c7, reduce_c8*/,
SingleBlockImplicitTerminator<"ReturnOp">
]> { /*reduce_c7*/
let summary = "Reduce operation";
@@ -1475,8 +1474,6 @@ def StableHLO_ReduceOp: StableHLO_ShapedInterfaceOp<"reduce", [
Variadic<HLO_Tensor>:$init_values, /*reduce_i2*/
I64ElementsAttr:$dimensions /*reduce_i3*/
);
// TODO(hinsu): Verify that the attached body arguments and results are
// compatible with reduce op's operands.
let regions = (region SizedRegion<1>:$body /*reduce_i4*/);

let results = (outs Variadic<HLO_Tensor>);
@@ -2514,7 +2511,8 @@ def StableHLO_DynamicReshapeOp: StableHLO_ShapedInterfaceOp<"dynamic_reshape", [

def StableHLO_ScatterOp: StableHLO_Op<"scatter", [RecursiveMemoryEffects,
SameVariadicOperandSize /*scatter_c5*/,
DeclareOpInterfaceMethods<InferTypeOpInterface> /*scatter_c16*/]> {
DeclareOpInterfaceMethods<InferTypeOpInterface> /*scatter_c16,
scatter_c17*/]> {
let summary = "Scatter operation";
let description = [{
Produces `results` tensors which are equal to `inputs` tensors except that
@@ -2587,8 +2585,8 @@ def StableHLO_SelectOp: StableHLO_Op<"select", [Pure, HLO_BroadcastingElementwis
}

def StableHLO_SelectAndScatterOp: StableHLO_Op<"select_and_scatter",
[DeclareOpInterfaceMethods<InferTypeOpInterface> /*select_and_scatter_c11*/,
RecursiveMemoryEffects]> {
[DeclareOpInterfaceMethods<InferTypeOpInterface> /*select_and_scatter_c11,
select_and_scatter_c12*/, RecursiveMemoryEffects]> {
let summary = "SelectAndScatter operation";
let description = [{
Scatters the values from the `source` tensor using `scatter` based on the
155 changes: 127 additions & 28 deletions stablehlo/dialect/TypeInference.cpp
@@ -97,6 +97,52 @@ bool tensorsHaveSameElType(Type type1, Type type2,
return tensorsHaveSameElType({type1, type2}, ignoreFpPrecision);
}

unsigned getBitWidth(Type type) {
if (auto complexTy = type.dyn_cast<ComplexType>())
return 2 * getBitWidth(complexTy.getElementType());
if (auto quantTy = type.dyn_cast<quant::QuantizedType>())
return getBitWidth(quantTy.getStorageType());
return type.getIntOrFloatBitWidth();
}

template <typename T>
bool matchesType(Type a, Type b) {
bool matches = a.isa<T>() && b.isa<T>();
// Check that expressed type matches for quantized types
if constexpr (std::is_same<T, quant::QuantizedType>::value) {
return matches && (a.cast<quant::QuantizedType>().getExpressedType() ==
b.cast<quant::QuantizedType>().getExpressedType());
}
return matches;
}

// Returns true if the element-type of type1 can be promoted to that of type2.
// An element-type 'x' is promotable to element-type 'y' if they have the same
// base type and bitwidth(x) <= bitwidth(y). When 'x' and 'y' are quantized
// element-types, then promotion is applied only to the 'storage_type'
// component.
bool isPromotableElementType(Type type1, Type type2,
bool ignoreFpPrecision = false) {
auto tensorTy1 = type1.dyn_cast<ShapedType>();
auto tensorTy2 = type2.dyn_cast<ShapedType>();

if (!tensorTy1 || !tensorTy2) return false;

Type tensorEl1 = tensorTy1.getElementType();
Type tensorEl2 = tensorTy2.getElementType();

bool isSameType = matchesType<IntegerType>(tensorEl1, tensorEl2) ||
matchesType<FloatType>(tensorEl1, tensorEl2) ||
matchesType<ComplexType>(tensorEl1, tensorEl2) ||
matchesType<quant::QuantizedType>(tensorEl1, tensorEl2);

if (!isSameType) return false;

if (ignoreFpPrecision && tensorEl1.isa<FloatType>()) return true;

return getBitWidth(tensorEl1) <= getBitWidth(tensorEl2);
}

// Return true if type1 and type2 are shape-compatible and have same element
// type. If 'ignoreFpPrecision' is True, then allow floats with different
// precisions while checking element-types.
Expand Down Expand Up @@ -405,12 +451,6 @@ SmallVector<int64_t> inferWindowOutputShape(ArrayRef<int64_t> baseShape,
return outputDimensions;
}

unsigned potentiallyComplexBitWidth(Type type) {
auto complexTy = type.dyn_cast<ComplexType>();
return complexTy ? 2 * complexTy.getElementType().getIntOrFloatBitWidth()
: type.getIntOrFloatBitWidth();
}

LogicalResult verifyReplicaGroups(std::optional<Location> location,
DenseIntElementsAttr replicaGroups,
bool allGroupsMustHaveSameSize,
@@ -530,6 +570,17 @@ LogicalResult verifyReduceOpInputsAndInferShape(
return success();
}

// Returns the types of the terminator arguments of the input mlir::Block
// 'block'.
SmallVector<ShapedType> getAccumulatorTypes(Block& block) {
SmallVector<ShapedType> accumulatorSubShapes;
for (Value retOperand : block.getTerminator()->getOperands()) {
auto shapedTy = retOperand.getType().cast<ShapedType>();
accumulatorSubShapes.push_back(shapedTy);
}
return accumulatorSubShapes;
}

LogicalResult verifyReducerShape(std::optional<Location> loc, Block& block,
ArrayRef<ShapedType> inputTypes,
ArrayRef<ShapedType> initValueTypes,
@@ -598,24 +649,35 @@ LogicalResult verifyReducerShape(std::optional<Location> loc, Block& block,

// all_reduce_c5, reduce_c6, reduce_scatter_c7, reduce_window_c13,
// reduce_window_i2, scatter_c6, scatter_c15, select_and_scatter_c10
if (!compatibleShapeAndElementType(accumulatorSubShapes[inputIdx],
initValueTypes[inputIdx],
/*ignoreFpPrecision=*/true))
if (failed(verifyCompatibleShape(initValueTypes[inputIdx],
accumulatorSubShapes[inputIdx])))
return emitOptionalError(
loc, "The shape of reduction-region's result type at index ",
inputIdx, " differs from the op's corresponding init-value type: ",
accumulatorSubShapes[inputIdx], " vs ", initValueTypes[inputIdx]);

if (!isPromotableElementType(initValueTypes[inputIdx],
accumulatorSubShapes[inputIdx],
/*ignoreFpPrecision=*/true))
return emitOptionalError(
loc, "The type of reduction-region's result type at index ", inputIdx,
" differs from the op's corresponding init-value type: ",
loc, "The element-type of reduction-region's result type at index ",
inputIdx,
" is expected to be promotable from the op's corresponding "
"init-value element-type: ",
accumulatorSubShapes[inputIdx], " vs ", initValueTypes[inputIdx]);

// reduce_c6, reduce_window_c3, scatter_c6, scatter_c15,
// select_and_scatter_c10
if (!tensorsHaveSameElType(
if (!isPromotableElementType(
inputTypes[inputIdx],
block.getArgument(numInputs + inputIdx).getType(), true))
block.getArgument(numInputs + inputIdx).getType(),
/*ignoreFpPrecision=*/true))
return emitOptionalError(
loc, "The element-type of reduction-region's argument at index ",
numInputs + inputIdx, " is expected to be ",
numInputs + inputIdx, " is expected to be promotable from ",
inputTypes[inputIdx].getElementType(), ", but got ",
block.getArgument(numInputs + inputIdx).getType(), " as its type.");
getElementTypeOrSelf(
block.getArgument(numInputs + inputIdx).getType()));

Type blockArgType = block.getArgument(numInputs + inputIdx).getType();
auto blockArgTensorTy = blockArgType.cast<ShapedType>();
@@ -1453,6 +1515,18 @@ LogicalResult inferAllToAllOp(
return success();
}

LogicalResult inferAllReduceOp(
std::optional<Location> location, Value operand, Region& computation,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
// all_reduce_c6, all_reduce_c7
SmallVector<ShapedType> accumulatorTypes =
getAccumulatorTypes(computation.front());
auto operandShapedTy = operand.getType().cast<ShapedType>();
inferredReturnShapes.emplace_back(getSameShapeTensorType(
operandShapedTy, accumulatorTypes[0].getElementType()));
return success();
}

LogicalResult inferBatchNormGradOp(
std::optional<Location> location, Value operand, Value scale, Value mean,
Value variance, Value gradOutput, int64_t featureIndex,
@@ -2532,7 +2606,7 @@ LogicalResult inferRealOp(std::optional<Location>, Value operand,

LogicalResult inferReduceOp(
std::optional<Location> location, TypeRange inputTypes,
TypeRange initValueTypes, DenseIntElementsAttr dimensions,
TypeRange initValueTypes, DenseIntElementsAttr dimensions, Region& body,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
SmallVector<ShapedType> inputArgTensorTypes{
llvm::map_range(inputTypes, [](Type t) { return t.cast<ShapedType>(); })};
@@ -2546,10 +2620,11 @@ LogicalResult inferReduceOp(
initValueTensorTypes, dimensions,
newDimensions, encoding)))
return failure();
// reduce_c2, reduce_c3, reduce_c7
// reduce_c3, reduce_c7, reduce_c8
SmallVector<ShapedType> accumulatorTypes = getAccumulatorTypes(body.front());
for (uint64_t inputIdx = 0; inputIdx < inputTypes.size(); ++inputIdx) {
ShapedType inputType = inputArgTensorTypes[inputIdx];
Type elementType = inputType.getElementType();
Type elementType = accumulatorTypes[inputIdx].getElementType();
if (inputType.hasRank())
inferredReturnShapes.emplace_back(newDimensions, elementType, encoding);
else
@@ -2565,7 +2640,7 @@ LogicalResult inferReduceWindowOp(
std::optional<DenseIntElementsAttr> windowStrides,
std::optional<DenseIntElementsAttr> baseDilations,
std::optional<DenseIntElementsAttr> windowDilations,
std::optional<DenseIntElementsAttr> padding,
std::optional<DenseIntElementsAttr> padding, Region& body,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
SmallVector<ShapedType> inputTypes{llvm::map_range(
inputs.getTypes(), [](Type t) { return t.cast<ShapedType>(); })};
@@ -2582,21 +2657,22 @@
return failure();

// reduce_window_c1, reduce_window_c14...reduce_window_c16
SmallVector<ShapedType> accumulatorTypes = getAccumulatorTypes(body.front());
for (size_t i = 0; i < inputTypes.size(); ++i) {
auto inputRankedType = inputs[i].getType().dyn_cast<RankedTensorType>();
if (!inputRankedType) {
inferredReturnShapes.emplace_back(inputTypes[i].getElementType());
inferredReturnShapes.emplace_back(accumulatorTypes[i].getElementType());
} else {
auto resultShape =
inferWindowOutputShape(inputTypes[i].getShape(), inferredWindow);
auto inputBounds = encodingToBounds(inputRankedType.getEncoding());
if (inputBounds.empty()) {
inferredReturnShapes.emplace_back(resultShape,
inputTypes[i].getElementType());
accumulatorTypes[i].getElementType());
} else {
auto resultBounds = inferWindowOutputShape(inputBounds, inferredWindow);
inferredReturnShapes.emplace_back(
resultShape, inputTypes[i].getElementType(),
resultShape, accumulatorTypes[i].getElementType(),
boundsToEncoding(inputRankedType.getEncoding(), resultBounds));
}
}
@@ -2661,8 +2737,16 @@ LogicalResult inferRngOp(
}

LogicalResult inferScatterOp(std::optional<Location>, ValueRange inputs,
Region& updateComputation,
SmallVectorImpl<Type>& inferredReturnTypes) {
llvm::append_range(inferredReturnTypes, inputs.getTypes());
// scatter_c16, scatter_c17
SmallVector<ShapedType> accumulatorTypes =
getAccumulatorTypes(updateComputation.front());
for (uint64_t inputIdx = 0; inputIdx < inputs.size(); ++inputIdx) {
auto inputShapedTy = inputs[inputIdx].getType().cast<ShapedType>();
inferredReturnTypes.push_back(getSameShapeTensorType(
inputShapedTy, accumulatorTypes[inputIdx].getElementType()));
}
return success();
}

@@ -2692,9 +2776,14 @@ LogicalResult inferSelectOp(
}

LogicalResult inferSelectAndScatterOp(
Value operand, SmallVectorImpl<Type>& inferredReturnTypes) {
// select_and_scatter_c11
inferredReturnTypes.push_back(operand.getType());
Value operand, Region& scatter,
SmallVectorImpl<Type>& inferredReturnTypes) {
// select_and_scatter_c11, select_and_scatter_c12
SmallVector<ShapedType> accumulatorTypes =
getAccumulatorTypes(scatter.front());
auto operandShapedTy = operand.getType().cast<ShapedType>();
inferredReturnTypes.push_back(getSameShapeTensorType(
operandShapedTy, accumulatorTypes[0].getElementType()));
return success();
}

@@ -3139,8 +3228,8 @@ LogicalResult verifyBitcastConvertOp(std::optional<Location> location,
location, "cannot convert between real and complex types, but got: ",
operandShapedType, " and ", targetShapedType);

auto targetEltBitWidth = potentiallyComplexBitWidth(targetElt);
auto operandEltBitWidth = potentiallyComplexBitWidth(operandElt);
auto targetEltBitWidth = getBitWidth(targetElt);
auto operandEltBitWidth = getBitWidth(operandElt);

auto operandType = operandShapedType.dyn_cast<RankedTensorType>();
auto targetType = targetShapedType.dyn_cast<RankedTensorType>();
@@ -3821,6 +3910,16 @@ LogicalResult verifyReduceScatterOp(std::optional<Location> location,
operandType.getDimSize(index), ") and result (",
resultType.getDimSize(index), ")");
}

// reduce_scatter_c9
SmallVector<ShapedType> accumulatorTypes =
getAccumulatorTypes(computation.front());
if (resultType.getElementType() != accumulatorTypes[0].getElementType()) {
return emitOptionalError(location, "result element-type is expected to be ",
accumulatorTypes[0].getElementType(), ", but got ",
resultType.getElementType());
}

return success();
}
