Updates to shape functions enabling reuse from MHLO (#1918)
The upstream change #1869 in
StableHLO updates various APIs related to shape inference. The MHLO shape
inference functions in
[hlo_ops.cc](https://github.com/openxla/xla/blob/main/xla/mlir_hlo/mhlo/IR/hlo_ops.cc)
use those APIs. This PR updates the visibility and signatures of those
APIs for cleaner integration.

Specifically, the PR does the following:
1. **Updates `getAccumulatorTypes` to return an error status when the
input region is empty**: This function is used in the type inference of
various reduction-based operations
([e.g.](https://github.com/openxla/stablehlo/blob/d5b464925371092095ac934b46ba93ebd4284223/stablehlo/dialect/TypeInference.cpp#L2589)).
It enables inferring the result type from the reduction block of the
operation, as proposed in the
[RFC](#1664). However, there are
[instances](https://github.com/openxla/xla/blob/a91877b9c9aa1edf307c5927782111b1a81cd81d/xla/mlir_hlo/mhlo/transforms/group_reduction_dimensions/group_reduction_dimensions.cc#L228)
where type inference can be called with an empty region, in which case we
would like to report a meaningful diagnostic message (a caller-side sketch
of the new pattern follows this list).

2. **Allows `hlo::inferAllReduceOp` to accept multiple operands**: In
StableHLO, the `all_reduce` op has a single operand
([e.g.](https://github.com/openxla/stablehlo/blob/d5b464925371092095ac934b46ba93ebd4284223/stablehlo/dialect/StablehloOps.td#L1355)),
whereas in MHLO the op can take multiple operands
([e.g.](https://github.com/openxla/xla/blob/79aba0801ef75c1c2dffbb4ecc506a0d8144c9ac/xla/mlir_hlo/mhlo/IR/hlo_ops.td#L1528)).
The `hlo::inferAllReduceOp` signature is updated to accommodate both
cases (see the updated declaration after this list).

3. **Removes unused arguments** from `verifyReduceOpInputsAndInferShape`
and `inferReduceOp`; the trimmed signatures are shown after this list.
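
To make (1) concrete, the caller-side pattern after this change looks as follows. This is a minimal sketch distilled from the updated `inferReduceOp` in the diff below; `getAccumulatorTypes` now returns `FailureOr<SmallVector<ShapedType>>` rather than reading `region.front()` unconditionally:

```cpp
// getAccumulatorTypes now takes the (optional) location and the region itself;
// on an empty region it emits a diagnostic and returns failure.
auto accumulatorTypesOrErr = getAccumulatorTypes(location, body);
if (failed(accumulatorTypesOrErr)) return failure();
// On success, dereference the FailureOr to reach the accumulator types.
Type elementType = (*accumulatorTypesOrErr)[0].getElementType();
```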
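
For (2), the updated declaration (quoted from the `TypeInference.h` diff below) takes a `ValueRange`, so the single-operand StableHLO op and the multi-operand MHLO op go through the same entry point, producing one inferred shape per operand:

```cpp
LogicalResult inferAllReduceOp(
    std::optional<Location> location, ValueRange operands, Region& computation,
    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes);
```

Each result keeps its operand's shape and takes the element type of the computation's accumulator, as in the implementation below.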
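
For (3), the trimmed signatures with the init-value types dropped (also quoted from the `TypeInference.h` diff); note that `verifyReduceOpInputsAndInferShape` is newly declared in the header so MHLO can reuse it:

```cpp
// Shared reduce helpers after this change: the init-value types are no
// longer part of either signature.
LogicalResult verifyReduceOpInputsAndInferShape(
    std::optional<Location> location, SmallVector<ShapedType> inputTypes,
    ArrayRef<int64_t> dimensions, SmallVector<int64_t>& newDimensions,
    Attribute& encoding);

LogicalResult inferReduceOp(
    std::optional<Location> location, TypeRange inputTypes,
    ArrayRef<int64_t> dimensions, Region& body,
    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes);
```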
sdasgup3 authored Jan 19, 2024
1 parent f4e084c commit 3bd2fad
Showing 4 changed files with 66 additions and 51 deletions.
6 changes: 3 additions & 3 deletions stablehlo/dialect/StablehloOps.cpp

```diff
@@ -1779,7 +1779,6 @@ LogicalResult ReduceOp::inferReturnTypeComponents(
     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
   ReduceOp::Adaptor adaptor(operands, attributes, properties, regions);
   return hlo::inferReduceOp(location, adaptor.getInputs().getTypes(),
-                            adaptor.getInitValues().getTypes(),
                             adaptor.getDimensions(), adaptor.getBody(),
                             inferredReturnShapes);
 }
@@ -2293,8 +2292,9 @@ LogicalResult SelectAndScatterOp::inferReturnTypes(
     SmallVectorImpl<Type>& inferredReturnTypes) {
   SelectAndScatterOp::Adaptor adaptor(operands, attributes, properties,
                                       regions);
-  return hlo::inferSelectAndScatterOp(
-      adaptor.getOperand(), adaptor.getScatter(), inferredReturnTypes);
+  return hlo::inferSelectAndScatterOp(location, adaptor.getOperand(),
+                                      adaptor.getScatter(),
+                                      inferredReturnTypes);
 }
 
 LogicalResult SelectAndScatterOp::verify() {
```
96 changes: 53 additions & 43 deletions stablehlo/dialect/TypeInference.cpp

```diff
@@ -505,8 +505,8 @@ LogicalResult verifyReplicaGroups(std::optional<Location> location,
 
 LogicalResult verifyReduceOpInputsAndInferShape(
     std::optional<Location> location, SmallVector<ShapedType> inputTypes,
-    SmallVector<ShapedType> initValueTypes, ArrayRef<int64_t> dimensions,
-    SmallVector<int64_t>& newDimensions, Attribute& encoding) {
+    ArrayRef<int64_t> dimensions, SmallVector<int64_t>& newDimensions,
+    Attribute& encoding) {
   // Check for unranked tensors in input operands.
   uint64_t numInputs = inputTypes.size();
   int64_t rankedInputIdx = -1;
@@ -568,13 +568,17 @@ LogicalResult verifyReduceOpInputsAndInferShape(
 
 // Returns the types of the terminator arguments of the input mlir::Block
 // 'block'.
-SmallVector<ShapedType> getAccumulatorTypes(Block& block) {
-  SmallVector<ShapedType> accumulatorSubShapes;
-  for (Value retOperand : block.getTerminator()->getOperands()) {
-    auto shapedTy = retOperand.getType().cast<ShapedType>();
-    accumulatorSubShapes.push_back(shapedTy);
-  }
-  return accumulatorSubShapes;
+FailureOr<SmallVector<ShapedType>> getAccumulatorTypes(
+    std::optional<Location> loc, Region& region) {
+  if (region.empty()) {
+    return emitOptionalError(
+        loc, "Expects non-empty reduction block for type inference");
+  }
+
+  Block& block = region.front();
+  return llvm::to_vector(
+      llvm::map_range(block.getTerminator()->getOperands(),
+                      [&](Value v) { return v.getType().cast<ShapedType>(); }));
 }
 
 LogicalResult verifyReducerShape(std::optional<Location> loc, Block& block,
@@ -1497,14 +1501,20 @@ LogicalResult inferAllToAllOp(
 }
 
 LogicalResult inferAllReduceOp(
-    std::optional<Location> location, Value operand, Region& computation,
+    std::optional<Location> location, ValueRange operands, Region& computation,
     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
+  TypeRange inputTypes = operands.getTypes();
+  SmallVector<ShapedType> inputArgTensorTypes{
+      llvm::map_range(inputTypes, [](Type t) { return t.cast<ShapedType>(); })};
   // all_reduce_c6, all_reduce_c7
-  SmallVector<ShapedType> accumulatorTypes =
-      getAccumulatorTypes(computation.front());
-  auto operandShapedTy = operand.getType().cast<ShapedType>();
-  inferredReturnShapes.emplace_back(getSameShapeTensorType(
-      operandShapedTy, accumulatorTypes[0].getElementType()));
+  auto accumulatorTypesOrErr = getAccumulatorTypes(location, computation);
+  if (failed(accumulatorTypesOrErr)) return failure();
+  for (size_t inputIdx = 0; inputIdx < inputTypes.size(); ++inputIdx) {
+    inferredReturnShapes.emplace_back(
+        getSameShapeTensorType(inputArgTensorTypes[inputIdx],
+                               (*accumulatorTypesOrErr)[0].getElementType()));
+  }
+
   return success();
 }
 
@@ -2571,25 +2581,23 @@ LogicalResult inferRealOp(std::optional<Location>, Value operand,
 
 LogicalResult inferReduceOp(
     std::optional<Location> location, TypeRange inputTypes,
-    TypeRange initValueTypes, ArrayRef<int64_t> dimensions, Region& body,
+    ArrayRef<int64_t> dimensions, Region& body,
     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
   SmallVector<ShapedType> inputArgTensorTypes{
       llvm::map_range(inputTypes, [](Type t) { return t.cast<ShapedType>(); })};
-  SmallVector<ShapedType> initValueTensorTypes{llvm::map_range(
-      initValueTypes, [](Type t) { return t.cast<ShapedType>(); })};
 
   SmallVector<int64_t> newDimensions;
   Attribute encoding;
   // reduce_c1, reduce_c4, reduce_c5, reduce_i3
-  if (failed(verifyReduceOpInputsAndInferShape(location, inputArgTensorTypes,
-                                               initValueTensorTypes, dimensions,
-                                               newDimensions, encoding)))
+  if (failed(verifyReduceOpInputsAndInferShape(
+          location, inputArgTensorTypes, dimensions, newDimensions, encoding)))
     return failure();
   // reduce_c3, reduce_c7, reduce_c8
-  SmallVector<ShapedType> accumulatorTypes = getAccumulatorTypes(body.front());
+  auto accumulatorTypesOrErr = getAccumulatorTypes(location, body);
+  if (failed(accumulatorTypesOrErr)) return failure();
   for (uint64_t inputIdx = 0; inputIdx < inputTypes.size(); ++inputIdx) {
     ShapedType inputType = inputArgTensorTypes[inputIdx];
-    Type elementType = accumulatorTypes[inputIdx].getElementType();
+    Type elementType = (*accumulatorTypesOrErr)[inputIdx].getElementType();
     if (inputType.hasRank())
       inferredReturnShapes.emplace_back(newDimensions, elementType, encoding);
     else
@@ -2622,22 +2630,24 @@ LogicalResult inferReduceWindowOp(
     return failure();
 
   // reduce_window_c1, reduce_window_c14...reduce_window_c16
-  SmallVector<ShapedType> accumulatorTypes = getAccumulatorTypes(body.front());
+  auto accumulatorTypesOrErr = getAccumulatorTypes(location, body);
+  if (failed(accumulatorTypesOrErr)) return failure();
   for (size_t i = 0; i < inputTypes.size(); ++i) {
     auto inputRankedType = inputs[i].getType().dyn_cast<RankedTensorType>();
     if (!inputRankedType) {
-      inferredReturnShapes.emplace_back(accumulatorTypes[i].getElementType());
+      inferredReturnShapes.emplace_back(
+          (*accumulatorTypesOrErr)[i].getElementType());
     } else {
       auto resultShape =
           inferWindowOutputShape(inputTypes[i].getShape(), inferredWindow);
       auto inputBounds = encodingToBounds(inputRankedType.getEncoding());
       if (inputBounds.empty()) {
-        inferredReturnShapes.emplace_back(resultShape,
-                                          accumulatorTypes[i].getElementType());
+        inferredReturnShapes.emplace_back(
+            resultShape, (*accumulatorTypesOrErr)[i].getElementType());
       } else {
         auto resultBounds = inferWindowOutputShape(inputBounds, inferredWindow);
         inferredReturnShapes.emplace_back(
-            resultShape, accumulatorTypes[i].getElementType(),
+            resultShape, (*accumulatorTypesOrErr)[i].getElementType(),
             boundsToEncoding(inputRankedType.getEncoding(), resultBounds));
       }
     }
@@ -2701,16 +2711,16 @@ LogicalResult inferRngOp(
   return success();
 }
 
-LogicalResult inferScatterOp(std::optional<Location>, ValueRange inputs,
-                             Region& updateComputation,
+LogicalResult inferScatterOp(std::optional<Location> location,
+                             ValueRange inputs, Region& updateComputation,
                              SmallVectorImpl<Type>& inferredReturnTypes) {
   // scatter_c16, scatter_c17
-  SmallVector<ShapedType> accumulatorTypes =
-      getAccumulatorTypes(updateComputation.front());
+  auto accumulatorTypesOrErr = getAccumulatorTypes(location, updateComputation);
+  if (failed(accumulatorTypesOrErr)) return failure();
   for (uint64_t inputIdx = 0; inputIdx < inputs.size(); ++inputIdx) {
     auto inputShapedTy = inputs[inputIdx].getType().cast<ShapedType>();
     inferredReturnTypes.push_back(getSameShapeTensorType(
-        inputShapedTy, accumulatorTypes[inputIdx].getElementType()));
+        inputShapedTy, (*accumulatorTypesOrErr)[inputIdx].getElementType()));
   }
   return success();
 }
@@ -2741,14 +2751,14 @@ LogicalResult inferSelectOp(
 }
 
 LogicalResult inferSelectAndScatterOp(
-    Value operand, Region& scatter,
+    std::optional<Location> location, Value operand, Region& scatter,
     SmallVectorImpl<Type>& inferredReturnTypes) {
   // select_and_scatter_c11, select_and_scatter_c12
-  SmallVector<ShapedType> accumulatorTypes =
-      getAccumulatorTypes(scatter.front());
+  auto accumulatorTypesOrErr = getAccumulatorTypes(location, scatter);
+  if (failed(accumulatorTypesOrErr)) return failure();
   auto operandShapedTy = operand.getType().cast<ShapedType>();
   inferredReturnTypes.push_back(getSameShapeTensorType(
-      operandShapedTy, accumulatorTypes[0].getElementType()));
+      operandShapedTy, (*accumulatorTypesOrErr)[0].getElementType()));
   return success();
 }
 
@@ -3774,8 +3784,7 @@ LogicalResult verifyReduceOp(std::optional<Location> location,
   SmallVector<int64_t> newDimensions;
   Attribute encoding;
   // reduce_c1, reduce_c4, reduce_c5, reduce_i3
-  if (failed(verifyReduceOpInputsAndInferShape(location, inputTypes,
-                                               initValueTypes, dimensions,
+  if (failed(verifyReduceOpInputsAndInferShape(location, inputTypes, dimensions,
                                                newDimensions, encoding)))
     return failure();
 
@@ -3876,12 +3885,13 @@ LogicalResult verifyReduceScatterOp(std::optional<Location> location,
   }
 
   // reduce_scatter_c9
-  SmallVector<ShapedType> accumulatorTypes =
-      getAccumulatorTypes(computation.front());
-  if (resultType.getElementType() != accumulatorTypes[0].getElementType()) {
+  auto accumulatorTypesOrErr = getAccumulatorTypes(location, computation);
+  if (failed(accumulatorTypesOrErr)) return failure();
+  if (resultType.getElementType() !=
+      (*accumulatorTypesOrErr)[0].getElementType()) {
     return emitOptionalError(location, "result element-type is expected to be ",
-                             accumulatorTypes[0].getElementType(), ", but got ",
-                             resultType.getElementType());
+                             (*accumulatorTypesOrErr)[0].getElementType(),
+                             ", but got ", resultType.getElementType());
   }
 
   return success();
```
12 changes: 9 additions & 3 deletions stablehlo/dialect/TypeInference.h

```diff
@@ -121,7 +121,7 @@ LogicalResult inferAllToAllOp(
     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes);
 
 LogicalResult inferAllReduceOp(
-    std::optional<Location> location, Value operand, Region& computation,
+    std::optional<Location> location, ValueRange operands, Region& computation,
     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes);
 
 LogicalResult inferBatchNormGradOp(
@@ -284,7 +284,7 @@ LogicalResult inferRealOp(std::optional<Location> location, Value operand,
 
 LogicalResult inferReduceOp(
     std::optional<Location> location, TypeRange inputTypes,
-    TypeRange initValueTypes, ArrayRef<int64_t> dimensions, Region& body,
+    ArrayRef<int64_t> dimensions, Region& body,
     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes);
 
 LogicalResult inferReduceWindowOp(
@@ -317,7 +317,8 @@ LogicalResult inferSelectOp(
     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes);
 
 LogicalResult inferSelectAndScatterOp(
-    Value operand, Region& scatter, SmallVectorImpl<Type>& inferredReturnTypes);
+    std::optional<Location> location, Value operand, Region& scatter,
+    SmallVectorImpl<Type>& inferredReturnTypes);
 
 LogicalResult inferSendOp(HloDialectInterface* dialect,
                           std::optional<Location> location,
@@ -463,6 +464,11 @@ LogicalResult verifyReduceOp(std::optional<Location> location,
                              ValueRange inputs, ValueRange initValues,
                              ArrayRef<int64_t> dimensions, Region& body);
 
+LogicalResult verifyReduceOpInputsAndInferShape(
+    std::optional<Location> location, SmallVector<ShapedType> inputTypes,
+    ArrayRef<int64_t> dimensions, SmallVector<int64_t>& newDimensions,
+    Attribute& encoding);
+
 LogicalResult verifyReducePrecisionOp(std::optional<Location> location,
                                       int32_t exponentBits,
                                       int32_t mantissaBits);
```
3 changes: 1 addition & 2 deletions stablehlo/reference/Ops.cpp

```diff
@@ -81,8 +81,7 @@ SmallVector<Tensor> evalReduceOp(ArrayRef<Tensor> inputs,
   SmallVector<ShapedTypeComponents> inferredReduceTypes;
   Builder builder(inputs[0].getType().getContext());
   auto reduceStatus = hlo::inferReduceOp(
-      /*location=*/{}, inputTypes, initValueTypes, dimensions, body,
-      inferredReduceTypes);
+      /*location=*/{}, inputTypes, dimensions, body, inferredReduceTypes);
   if (failed(reduceStatus))
     report_fatal_error(
         invalidArgument("Could not infer ReduceOp's return type"));
```
