Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates to shape functions enabling reuse from MHLO #1918

Merged
merged 2 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions stablehlo/dialect/StablehloOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1779,7 +1779,6 @@ LogicalResult ReduceOp::inferReturnTypeComponents(
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
ReduceOp::Adaptor adaptor(operands, attributes, properties, regions);
return hlo::inferReduceOp(location, adaptor.getInputs().getTypes(),
adaptor.getInitValues().getTypes(),
adaptor.getDimensions(), adaptor.getBody(),
inferredReturnShapes);
}
Expand Down Expand Up @@ -2293,8 +2292,9 @@ LogicalResult SelectAndScatterOp::inferReturnTypes(
SmallVectorImpl<Type>& inferredReturnTypes) {
SelectAndScatterOp::Adaptor adaptor(operands, attributes, properties,
regions);
return hlo::inferSelectAndScatterOp(
adaptor.getOperand(), adaptor.getScatter(), inferredReturnTypes);
return hlo::inferSelectAndScatterOp(location, adaptor.getOperand(),
adaptor.getScatter(),
inferredReturnTypes);
}

LogicalResult SelectAndScatterOp::verify() {
Expand Down
96 changes: 53 additions & 43 deletions stablehlo/dialect/TypeInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -505,8 +505,8 @@ LogicalResult verifyReplicaGroups(std::optional<Location> location,

LogicalResult verifyReduceOpInputsAndInferShape(
std::optional<Location> location, SmallVector<ShapedType> inputTypes,
SmallVector<ShapedType> initValueTypes, ArrayRef<int64_t> dimensions,
SmallVector<int64_t>& newDimensions, Attribute& encoding) {
ArrayRef<int64_t> dimensions, SmallVector<int64_t>& newDimensions,
Attribute& encoding) {
// Check for unranked tensors in input operands.
uint64_t numInputs = inputTypes.size();
int64_t rankedInputIdx = -1;
Expand Down Expand Up @@ -568,13 +568,17 @@ LogicalResult verifyReduceOpInputsAndInferShape(

// Returns the types of the terminator arguments of the input mlir::Block
// 'block'.
SmallVector<ShapedType> getAccumulatorTypes(Block& block) {
SmallVector<ShapedType> accumulatorSubShapes;
for (Value retOperand : block.getTerminator()->getOperands()) {
auto shapedTy = retOperand.getType().cast<ShapedType>();
accumulatorSubShapes.push_back(shapedTy);
FailureOr<SmallVector<ShapedType>> getAccumulatorTypes(
std::optional<Location> loc, Region& region) {
if (region.empty()) {
return emitOptionalError(
loc, "Expects non-empty reduction block for type inference");
}
return accumulatorSubShapes;

Block& block = region.front();
return llvm::to_vector(
llvm::map_range(block.getTerminator()->getOperands(),
[&](Value v) { return v.getType().cast<ShapedType>(); }));
}

LogicalResult verifyReducerShape(std::optional<Location> loc, Block& block,
Expand Down Expand Up @@ -1497,14 +1501,20 @@ LogicalResult inferAllToAllOp(
}

LogicalResult inferAllReduceOp(
std::optional<Location> location, Value operand, Region& computation,
std::optional<Location> location, ValueRange operands, Region& computation,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
TypeRange inputTypes = operands.getTypes();
SmallVector<ShapedType> inputArgTensorTypes{
llvm::map_range(inputTypes, [](Type t) { return t.cast<ShapedType>(); })};
// all_reduce_c6, all_reduce_c7
SmallVector<ShapedType> accumulatorTypes =
getAccumulatorTypes(computation.front());
auto operandShapedTy = operand.getType().cast<ShapedType>();
inferredReturnShapes.emplace_back(getSameShapeTensorType(
operandShapedTy, accumulatorTypes[0].getElementType()));
auto accumulatorTypesOrErr = getAccumulatorTypes(location, computation);
if (failed(accumulatorTypesOrErr)) return failure();
for (size_t inputIdx = 0; inputIdx < inputTypes.size(); ++inputIdx) {
inferredReturnShapes.emplace_back(
getSameShapeTensorType(inputArgTensorTypes[inputIdx],
(*accumulatorTypesOrErr)[0].getElementType()));
}

return success();
}

Expand Down Expand Up @@ -2571,25 +2581,23 @@ LogicalResult inferRealOp(std::optional<Location>, Value operand,

LogicalResult inferReduceOp(
std::optional<Location> location, TypeRange inputTypes,
TypeRange initValueTypes, ArrayRef<int64_t> dimensions, Region& body,
ArrayRef<int64_t> dimensions, Region& body,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
SmallVector<ShapedType> inputArgTensorTypes{
llvm::map_range(inputTypes, [](Type t) { return t.cast<ShapedType>(); })};
SmallVector<ShapedType> initValueTensorTypes{llvm::map_range(
initValueTypes, [](Type t) { return t.cast<ShapedType>(); })};

SmallVector<int64_t> newDimensions;
Attribute encoding;
// reduce_c1, reduce_c4, reduce_c5, reduce_i3
if (failed(verifyReduceOpInputsAndInferShape(location, inputArgTensorTypes,
initValueTensorTypes, dimensions,
newDimensions, encoding)))
if (failed(verifyReduceOpInputsAndInferShape(
location, inputArgTensorTypes, dimensions, newDimensions, encoding)))
return failure();
// reduce_c3, reduce_c7, reduce_c8
SmallVector<ShapedType> accumulatorTypes = getAccumulatorTypes(body.front());
auto accumulatorTypesOrErr = getAccumulatorTypes(location, body);
if (failed(accumulatorTypesOrErr)) return failure();
for (uint64_t inputIdx = 0; inputIdx < inputTypes.size(); ++inputIdx) {
ShapedType inputType = inputArgTensorTypes[inputIdx];
Type elementType = accumulatorTypes[inputIdx].getElementType();
Type elementType = (*accumulatorTypesOrErr)[inputIdx].getElementType();
if (inputType.hasRank())
inferredReturnShapes.emplace_back(newDimensions, elementType, encoding);
else
Expand Down Expand Up @@ -2622,22 +2630,24 @@ LogicalResult inferReduceWindowOp(
return failure();

// reduce_window_c1, reduce_window_c14...reduce_window_c16
SmallVector<ShapedType> accumulatorTypes = getAccumulatorTypes(body.front());
auto accumulatorTypesOrErr = getAccumulatorTypes(location, body);
if (failed(accumulatorTypesOrErr)) return failure();
for (size_t i = 0; i < inputTypes.size(); ++i) {
auto inputRankedType = inputs[i].getType().dyn_cast<RankedTensorType>();
if (!inputRankedType) {
inferredReturnShapes.emplace_back(accumulatorTypes[i].getElementType());
inferredReturnShapes.emplace_back(
(*accumulatorTypesOrErr)[i].getElementType());
} else {
auto resultShape =
inferWindowOutputShape(inputTypes[i].getShape(), inferredWindow);
auto inputBounds = encodingToBounds(inputRankedType.getEncoding());
if (inputBounds.empty()) {
inferredReturnShapes.emplace_back(resultShape,
accumulatorTypes[i].getElementType());
inferredReturnShapes.emplace_back(
resultShape, (*accumulatorTypesOrErr)[i].getElementType());
} else {
auto resultBounds = inferWindowOutputShape(inputBounds, inferredWindow);
inferredReturnShapes.emplace_back(
resultShape, accumulatorTypes[i].getElementType(),
resultShape, (*accumulatorTypesOrErr)[i].getElementType(),
boundsToEncoding(inputRankedType.getEncoding(), resultBounds));
}
}
Expand Down Expand Up @@ -2701,16 +2711,16 @@ LogicalResult inferRngOp(
return success();
}

LogicalResult inferScatterOp(std::optional<Location>, ValueRange inputs,
Region& updateComputation,
LogicalResult inferScatterOp(std::optional<Location> location,
ValueRange inputs, Region& updateComputation,
SmallVectorImpl<Type>& inferredReturnTypes) {
// scatter_c16, scatter_c17
SmallVector<ShapedType> accumulatorTypes =
getAccumulatorTypes(updateComputation.front());
auto accumulatorTypesOrErr = getAccumulatorTypes(location, updateComputation);
if (failed(accumulatorTypesOrErr)) return failure();
for (uint64_t inputIdx = 0; inputIdx < inputs.size(); ++inputIdx) {
auto inputShapedTy = inputs[inputIdx].getType().cast<ShapedType>();
inferredReturnTypes.push_back(getSameShapeTensorType(
inputShapedTy, accumulatorTypes[inputIdx].getElementType()));
inputShapedTy, (*accumulatorTypesOrErr)[inputIdx].getElementType()));
}
return success();
}
Expand Down Expand Up @@ -2741,14 +2751,14 @@ LogicalResult inferSelectOp(
}

LogicalResult inferSelectAndScatterOp(
Value operand, Region& scatter,
std::optional<Location> location, Value operand, Region& scatter,
SmallVectorImpl<Type>& inferredReturnTypes) {
// select_and_scatter_c11, select_and_scatter_c12
SmallVector<ShapedType> accumulatorTypes =
getAccumulatorTypes(scatter.front());
auto accumulatorTypesOrErr = getAccumulatorTypes(location, scatter);
if (failed(accumulatorTypesOrErr)) return failure();
auto operandShapedTy = operand.getType().cast<ShapedType>();
inferredReturnTypes.push_back(getSameShapeTensorType(
operandShapedTy, accumulatorTypes[0].getElementType()));
operandShapedTy, (*accumulatorTypesOrErr)[0].getElementType()));
return success();
}

Expand Down Expand Up @@ -3774,8 +3784,7 @@ LogicalResult verifyReduceOp(std::optional<Location> location,
SmallVector<int64_t> newDimensions;
Attribute encoding;
// reduce_c1, reduce_c4, reduce_c5, reduce_i3
if (failed(verifyReduceOpInputsAndInferShape(location, inputTypes,
initValueTypes, dimensions,
if (failed(verifyReduceOpInputsAndInferShape(location, inputTypes, dimensions,
newDimensions, encoding)))
return failure();

Expand Down Expand Up @@ -3876,12 +3885,13 @@ LogicalResult verifyReduceScatterOp(std::optional<Location> location,
}

// reduce_scatter_c9
SmallVector<ShapedType> accumulatorTypes =
getAccumulatorTypes(computation.front());
if (resultType.getElementType() != accumulatorTypes[0].getElementType()) {
auto accumulatorTypesOrErr = getAccumulatorTypes(location, computation);
if (failed(accumulatorTypesOrErr)) return failure();
if (resultType.getElementType() !=
(*accumulatorTypesOrErr)[0].getElementType()) {
return emitOptionalError(location, "result element-type is expected to be ",
accumulatorTypes[0].getElementType(), ", but got ",
resultType.getElementType());
(*accumulatorTypesOrErr)[0].getElementType(),
", but got ", resultType.getElementType());
}

return success();
Expand Down
12 changes: 9 additions & 3 deletions stablehlo/dialect/TypeInference.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ LogicalResult inferAllToAllOp(
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes);

LogicalResult inferAllReduceOp(
std::optional<Location> location, Value operand, Region& computation,
std::optional<Location> location, ValueRange operands, Region& computation,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes);

LogicalResult inferBatchNormGradOp(
Expand Down Expand Up @@ -284,7 +284,7 @@ LogicalResult inferRealOp(std::optional<Location> location, Value operand,

LogicalResult inferReduceOp(
std::optional<Location> location, TypeRange inputTypes,
TypeRange initValueTypes, ArrayRef<int64_t> dimensions, Region& body,
ArrayRef<int64_t> dimensions, Region& body,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes);

LogicalResult inferReduceWindowOp(
Expand Down Expand Up @@ -317,7 +317,8 @@ LogicalResult inferSelectOp(
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes);

LogicalResult inferSelectAndScatterOp(
Value operand, Region& scatter, SmallVectorImpl<Type>& inferredReturnTypes);
std::optional<Location> location, Value operand, Region& scatter,
SmallVectorImpl<Type>& inferredReturnTypes);

LogicalResult inferSendOp(HloDialectInterface* dialect,
std::optional<Location> location,
Expand Down Expand Up @@ -463,6 +464,11 @@ LogicalResult verifyReduceOp(std::optional<Location> location,
ValueRange inputs, ValueRange initValues,
ArrayRef<int64_t> dimensions, Region& body);

LogicalResult verifyReduceOpInputsAndInferShape(
std::optional<Location> location, SmallVector<ShapedType> inputTypes,
ArrayRef<int64_t> dimensions, SmallVector<int64_t>& newDimensions,
Attribute& encoding);

LogicalResult verifyReducePrecisionOp(std::optional<Location> location,
int32_t exponentBits,
int32_t mantissaBits);
Expand Down
3 changes: 1 addition & 2 deletions stablehlo/reference/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,7 @@ SmallVector<Tensor> evalReduceOp(ArrayRef<Tensor> inputs,
SmallVector<ShapedTypeComponents> inferredReduceTypes;
Builder builder(inputs[0].getType().getContext());
auto reduceStatus = hlo::inferReduceOp(
/*location=*/{}, inputTypes, initValueTypes, dimensions, body,
inferredReduceTypes);
/*location=*/{}, inputTypes, dimensions, body, inferredReduceTypes);
if (failed(reduceStatus))
report_fatal_error(
invalidArgument("Could not infer ReduceOp's return type"));
Expand Down