Skip to content

Commit

Permalink
Addressed review feedback (round II)
Browse files Browse the repository at this point in the history
  • Loading branch information
sdasgup3 committed Dec 20, 2023
1 parent 6fbb8ef commit b284312
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 76 deletions.
2 changes: 0 additions & 2 deletions stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1474,8 +1474,6 @@ def StableHLO_ReduceOp: StableHLO_ShapedInterfaceOp<"reduce", [
Variadic<HLO_Tensor>:$init_values, /*reduce_i2*/
I64ElementsAttr:$dimensions /*reduce_i3*/
);
// TODO(hinsu): Verify that the attached body arguments and results are
// compatible with reduce op's operands.
let regions = (region SizedRegion<1>:$body /*reduce_i4*/);

let results = (outs Variadic<HLO_Tensor>);
Expand Down
49 changes: 23 additions & 26 deletions stablehlo/dialect/TypeInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,15 +97,24 @@ bool tensorsHaveSameElType(Type type1, Type type2,
return tensorsHaveSameElType({type1, type2}, ignoreFpPrecision);
}

// Returns the bit width of `type`.
// A complex type counts both its real and imaginary components (2x the
// element width); a quantized type is measured by its storage type's width.
// Recurses so that e.g. a complex-of-quantized storage is handled uniformly.
// Note: falls through to getIntOrFloatBitWidth(), so `type` must ultimately
// resolve to an int or float type.
unsigned getBitWidth(Type type) {
  if (auto complexTy = type.dyn_cast<ComplexType>())
    return 2 * getBitWidth(complexTy.getElementType());
  if (auto quantTy = type.dyn_cast<quant::QuantizedType>())
    return getBitWidth(quantTy.getStorageType());
  return type.getIntOrFloatBitWidth();
}

// Returns true if both `a` and `b` are instances of type class `T`.
// For quantized types, additionally requires that the expressed types match;
// other type classes only need class membership.
template <typename T>
bool matchesType(Type a, Type b) {
  bool matches = a.isa<T>() && b.isa<T>();
  // Check that expressed type matches for quantized types. The `matches &&`
  // short-circuit guarantees the casts below only run when both are quantized.
  if constexpr (std::is_same_v<T, quant::QuantizedType>) {
    return matches && (a.cast<quant::QuantizedType>().getExpressedType() ==
                       b.cast<quant::QuantizedType>().getExpressedType());
  }
  return matches;
}

// Returns true if the element-type of type1 can be promoted to that of type2.
Expand All @@ -123,26 +132,16 @@ bool isPromotableElementType(Type type1, Type type2,
Type tensorEl1 = tensorTy1.getElementType();
Type tensorEl2 = tensorTy2.getElementType();

if (ignoreFpPrecision && tensorEl1.isa<FloatType>() &&
tensorEl2.isa<FloatType>())
return true;

bool isSameType = matchesType<IntegerType>(tensorEl1, tensorEl2) ||
matchesType<FloatType>(tensorEl1, tensorEl2) ||
matchesType<ComplexType>(tensorEl1, tensorEl2) ||
matchesType<quant::QuantizedType>(tensorEl1, tensorEl2);

if (!isSameType) return false;

if (!tensorEl1.isa<quant::QuantizedType>())
return potentiallyComplexBitWidth(tensorEl1) <=
potentiallyComplexBitWidth(tensorEl2);
if (ignoreFpPrecision && tensorEl1.isa<FloatType>()) return true;

auto quantType1 = tensorEl1.cast<quant::QuantizedType>();
auto quantType2 = tensorEl2.cast<quant::QuantizedType>();
return quantType1.getExpressedType() == quantType2.getExpressedType() &&
potentiallyComplexBitWidth(quantType1.getStorageType()) <=
potentiallyComplexBitWidth(quantType2.getStorageType());
return getBitWidth(tensorEl1) <= getBitWidth(tensorEl2);
}

// Return true if type1 and type2 are shape-compatible and have same element
Expand Down Expand Up @@ -577,7 +576,7 @@ LogicalResult verifyReduceOpInputsAndInferShape(
SmallVector<ShapedType> getAccumulatorTypes(Block& block) {
SmallVector<ShapedType> accumulatorSubShapes;
for (Value retOperand : block.getTerminator()->getOperands()) {
auto shapedTy = retOperand.getType().dyn_cast<ShapedType>();
auto shapedTy = retOperand.getType().cast<ShapedType>();
accumulatorSubShapes.push_back(shapedTy);
}
return accumulatorSubShapes;
Expand Down Expand Up @@ -678,10 +677,8 @@ LogicalResult verifyReducerShape(std::optional<Location> loc, Block& block,
loc, "The element-type of reduction-region's argument at index ",
numInputs + inputIdx, " is expected to be promotable from ",
inputTypes[inputIdx].getElementType(), ", but got ",
block.getArgument(numInputs + inputIdx)
.getType()
.cast<ShapedType>()
.getElementType());
getElementTypeOrSelf(
block.getArgument(numInputs + inputIdx).getType()));

Type blockArgType = block.getArgument(numInputs + inputIdx).getType();
auto blockArgTensorTy = blockArgType.cast<ShapedType>();
Expand Down Expand Up @@ -2741,11 +2738,11 @@ LogicalResult inferRngOp(
}

LogicalResult inferScatterOp(std::optional<Location>, ValueRange inputs,
Region& update_computation,
Region& updateComputation,
SmallVectorImpl<Type>& inferredReturnTypes) {
// scatter_c16, scatter_c17
SmallVector<ShapedType> accumulatorTypes =
getAccumulatorTypes(update_computation.front());
getAccumulatorTypes(updateComputation.front());
for (uint64_t inputIdx = 0; inputIdx < inputs.size(); ++inputIdx) {
auto inputShapedTy = inputs[inputIdx].getType().cast<ShapedType>();
inferredReturnTypes.push_back(getSameShapeTensorType(
Expand Down Expand Up @@ -3232,8 +3229,8 @@ LogicalResult verifyBitcastConvertOp(std::optional<Location> location,
location, "cannot convert between real and complex types, but got: ",
operandShapedType, " and ", targetShapedType);

auto targetEltBitWidth = potentiallyComplexBitWidth(targetElt);
auto operandEltBitWidth = potentiallyComplexBitWidth(operandElt);
auto targetEltBitWidth = getBitWidth(targetElt);
auto operandEltBitWidth = getBitWidth(operandElt);

auto operandType = operandShapedType.dyn_cast<RankedTensorType>();
auto targetType = targetShapedType.dyn_cast<RankedTensorType>();
Expand Down
58 changes: 16 additions & 42 deletions stablehlo/dialect/VhloOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,72 +326,46 @@ bool checkIfOperandAndResultElementTypesMatch(TypeRange operandTypes,

return false;
}
// Allow mismatched operand and result types in reduce ops in v0.17.0.
// Shared helper for the reduce-family validateConstraint overrides below:
// returns failure() when the target version predates 0.17.0 and the
// operand/result element-type check trips.
// NOTE(review): checkIfOperandAndResultElementTypesMatch appears to return
// true when serialization must be blocked despite its name — confirm against
// its definition before renaming.
LogicalResult verifyConstraint_0_17_0(mlir::Operation* op,
                                      Version targetVersion) {
  if (checkIfOperandAndResultElementTypesMatch(op->getOperandTypes(),
                                               op->getResultTypes()) &&
      targetVersion < Version(0, 17, 0))
    return failure();

  return success();
}
}  // namespace

LogicalResult AllReduceOpV1::validateConstraint(mlir::Operation* op,
                                                Version targetVersion) {
  return verifyConstraint_0_17_0(op, targetVersion);
}

// Delegates to the shared v0.17.0 element-type constraint check.
LogicalResult ReduceOpV1::validateConstraint(mlir::Operation* op,
                                             Version targetVersion) {
  return verifyConstraint_0_17_0(op, targetVersion);
}

// Delegates to the shared v0.17.0 element-type constraint check.
LogicalResult ReduceScatterOpV1::validateConstraint(mlir::Operation* op,
                                                    Version targetVersion) {
  return verifyConstraint_0_17_0(op, targetVersion);
}

// Delegates to the shared v0.17.0 element-type constraint check.
LogicalResult ReduceWindowOpV1::validateConstraint(mlir::Operation* op,
                                                   Version targetVersion) {
  return verifyConstraint_0_17_0(op, targetVersion);
}

// Delegates to the shared v0.17.0 element-type constraint check.
LogicalResult ScatterOpV1::validateConstraint(mlir::Operation* op,
                                              Version targetVersion) {
  return verifyConstraint_0_17_0(op, targetVersion);
}

// Delegates to the shared v0.17.0 element-type constraint check.
LogicalResult SelectAndScatterOpV1::validateConstraint(mlir::Operation* op,
                                                       Version targetVersion) {
  return verifyConstraint_0_17_0(op, targetVersion);
}

} // namespace vhlo
Expand Down
10 changes: 5 additions & 5 deletions stablehlo/tests/stablehlo_legalize_to_vhlo.0_17_0.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ func.func @default_dynamic_broadcast_in_dim(%arg0: tensor<?x?xf32>, %arg1: tenso
// CHECK-SAME: known_nonexpanding_dimensions = #vhlo.tensor_v1<dense<> : tensor<0xi64>>
// CHECK-SAME: }> : (!vhlo.tensor_v1<?x?x!vhlo.f32_v1>, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1<?x?x!vhlo.f32_v1>
%0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %arg1) {
broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
broadcast_dimensions = array<i64: 0, 1>
} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
func.return %0 : tensor<?x?xf32>
}
Expand Down Expand Up @@ -933,7 +933,7 @@ func.func @op_broadcast_in_dim(%arg0: tensor<16xf32>) -> tensor<16x16xf32> {
// CHECK-SAME: broadcast_dimensions = #vhlo.tensor_v1<dense<1> : tensor<1xi64>>
// CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x!vhlo.f32_v1>
%0 = "stablehlo.broadcast_in_dim"(%arg0) {
broadcast_dimensions = dense<1> : tensor<1xi64>
broadcast_dimensions = array<i64: 1>
} : (tensor<16xf32>) -> tensor<16x16xf32>
func.return %0 : tensor<16x16xf32>
}
Expand Down Expand Up @@ -1214,9 +1214,9 @@ func.func @op_dynamic_broadcast_in_dim(%arg0: tensor<?x?xf32>, %arg1: tensor<2xi
// CHECK-SAME: known_nonexpanding_dimensions = #vhlo.tensor_v1<dense<1> : tensor<1xi64>>
// CHECK-SAME: }> : (!vhlo.tensor_v1<?x?x!vhlo.f32_v1>, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1<?x?x!vhlo.f32_v1>
%0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %arg1) {
broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>,
known_expanding_dimensions = dense<0> : tensor<1xi64>,
known_nonexpanding_dimensions = dense<1> : tensor<1xi64>
broadcast_dimensions = array<i64: 0, 1>,
known_expanding_dimensions = array<i64: 0>,
known_nonexpanding_dimensions = array<i64: 1>
} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
func.return %0 : tensor<?x?xf32>
}
Expand Down
5 changes: 4 additions & 1 deletion stablehlo/transforms/VhloToVersion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,11 @@ bool isLegalOperation(Operation* op, const Version& targetVersion) {
// Validate op constraints
auto constraintInterface = dyn_cast<VersionedOpConstraintInterface>(op);
if (constraintInterface &&
failed(constraintInterface.validateConstraint(op, targetVersion)))
failed(constraintInterface.validateConstraint(op, targetVersion))) {
LLVM_DEBUG(llvm::dbgs()
<< "Op failed to satisfy versioned constraints. " << op << '\n');
return false;
}
LLVM_DEBUG(llvm::dbgs() << "Legal constraints for target. " << op << '\n');

// Validate attributes
Expand Down

0 comments on commit b284312

Please sign in to comment.