From b284312d0d121828e3c8e5956049ba3a2477d4a6 Mon Sep 17 00:00:00 2001 From: Sandeep Dasgupta Date: Wed, 20 Dec 2023 20:49:01 +0000 Subject: [PATCH] addressed feedback:II --- stablehlo/dialect/StablehloOps.td | 2 - stablehlo/dialect/TypeInference.cpp | 49 ++++++++-------- stablehlo/dialect/VhloOps.cpp | 58 +++++-------------- .../stablehlo_legalize_to_vhlo.0_17_0.mlir | 10 ++-- stablehlo/transforms/VhloToVersion.cpp | 5 +- 5 files changed, 48 insertions(+), 76 deletions(-) diff --git a/stablehlo/dialect/StablehloOps.td b/stablehlo/dialect/StablehloOps.td index 1d37c7e3628..e5dbfed5936 100644 --- a/stablehlo/dialect/StablehloOps.td +++ b/stablehlo/dialect/StablehloOps.td @@ -1474,8 +1474,6 @@ def StableHLO_ReduceOp: StableHLO_ShapedInterfaceOp<"reduce", [ Variadic:$init_values, /*reduce_i2*/ I64ElementsAttr:$dimensions /*reduce_i3*/ ); - // TODO(hinsu): Verify that the attached body arguments and results are - // compatible with reduce op's operands. let regions = (region SizedRegion<1>:$body /*reduce_i4*/); let results = (outs Variadic); diff --git a/stablehlo/dialect/TypeInference.cpp b/stablehlo/dialect/TypeInference.cpp index 82fba18aaa8..9018d69b913 100644 --- a/stablehlo/dialect/TypeInference.cpp +++ b/stablehlo/dialect/TypeInference.cpp @@ -97,15 +97,24 @@ bool tensorsHaveSameElType(Type type1, Type type2, return tensorsHaveSameElType({type1, type2}, ignoreFpPrecision); } -unsigned potentiallyComplexBitWidth(Type type) { - auto complexTy = type.dyn_cast(); - return complexTy ? 2 * complexTy.getElementType().getIntOrFloatBitWidth() - : type.getIntOrFloatBitWidth(); +unsigned getBitWidth(Type type) { + if (auto complexTy = type.dyn_cast()) + return 2 * getBitWidth(complexTy.getElementType()); + if (auto quantTy = type.dyn_cast()) + return getBitWidth(quantTy.getStorageType()); + return type.getIntOrFloatBitWidth(); } template bool matchesType(Type a, Type b) { - return a.isa() && b.isa(); + bool matches = a.isa() && b.isa(); + if constexpr (std::is_same:: + value) { // Check that expressed type matches for quantized + // types + return matches && (a.cast().getExpressedType() == + b.cast().getExpressedType()); + } + return matches; } // Returns true if the element-type of type1 can be promoted to that of type2. @@ -123,10 +132,6 @@ bool isPromotableElementType(Type type1, Type type2, Type tensorEl1 = tensorTy1.getElementType(); Type tensorEl2 = tensorTy2.getElementType(); - if (ignoreFpPrecision && tensorEl1.isa() && - tensorEl2.isa()) - return true; - bool isSameType = matchesType(tensorEl1, tensorEl2) || matchesType(tensorEl1, tensorEl2) || matchesType(tensorEl1, tensorEl2) || @@ -134,15 +139,9 @@ bool isPromotableElementType(Type type1, Type type2, if (!isSameType) return false; - if (!tensorEl1.isa()) - return potentiallyComplexBitWidth(tensorEl1) <= - potentiallyComplexBitWidth(tensorEl2); + if (ignoreFpPrecision && tensorEl1.isa()) return true; - auto quantType1 = tensorEl1.cast(); - auto quantType2 = tensorEl2.cast(); - return quantType1.getExpressedType() == quantType2.getExpressedType() && - potentiallyComplexBitWidth(quantType1.getStorageType()) <= - potentiallyComplexBitWidth(quantType2.getStorageType()); + return getBitWidth(tensorEl1) <= getBitWidth(tensorEl2); } // Return true if type1 and type2 are shape-compatible and have same element @@ -577,7 +576,7 @@ LogicalResult verifyReduceOpInputsAndInferShape( SmallVector getAccumulatorTypes(Block& block) { SmallVector accumulatorSubShapes; for (Value retOperand : block.getTerminator()->getOperands()) { - auto shapedTy = retOperand.getType().dyn_cast(); + auto shapedTy = retOperand.getType().cast(); accumulatorSubShapes.push_back(shapedTy); } return accumulatorSubShapes; @@ -678,10 +677,8 @@ LogicalResult verifyReducerShape(std::optional loc, Block& block, loc, "The element-type of reduction-region's argument at index ", numInputs + inputIdx, " is expected to be promotable from ", inputTypes[inputIdx].getElementType(), ", but got ", - block.getArgument(numInputs + inputIdx) - .getType() - .cast() - .getElementType()); + getElementTypeOrSelf( + block.getArgument(numInputs + inputIdx).getType())); Type blockArgType = block.getArgument(numInputs + inputIdx).getType(); auto blockArgTensorTy = blockArgType.cast(); @@ -2741,11 +2738,11 @@ LogicalResult inferRngOp( } LogicalResult inferScatterOp(std::optional, ValueRange inputs, - Region& update_computation, + Region& updateComputation, SmallVectorImpl& inferredReturnTypes) { // scatter_c16, scatter_c17 SmallVector accumulatorTypes = - getAccumulatorTypes(update_computation.front()); + getAccumulatorTypes(updateComputation.front()); for (uint64_t inputIdx = 0; inputIdx < inputs.size(); ++inputIdx) { auto inputShapedTy = inputs[inputIdx].getType().cast(); inferredReturnTypes.push_back(getSameShapeTensorType( @@ -3232,8 +3229,8 @@ LogicalResult verifyBitcastConvertOp(std::optional location, location, "cannot convert between real and complex types, but got: ", operandShapedType, " and ", targetShapedType); - auto targetEltBitWidth = potentiallyComplexBitWidth(targetElt); - auto operandEltBitWidth = potentiallyComplexBitWidth(operandElt); + auto targetEltBitWidth = getBitWidth(targetElt); + auto operandEltBitWidth = getBitWidth(operandElt); auto operandType = operandShapedType.dyn_cast(); auto targetType = targetShapedType.dyn_cast(); diff --git a/stablehlo/dialect/VhloOps.cpp b/stablehlo/dialect/VhloOps.cpp index 2c2ea9e0220..e8c651e58a6 100644 --- a/stablehlo/dialect/VhloOps.cpp +++ b/stablehlo/dialect/VhloOps.cpp @@ -326,72 +326,46 @@ bool checkIfOperandAndResultElementTypesMatch(TypeRange operandTypes, return false; } -} // namespace -LogicalResult AllReduceOpV1::validateConstraint(mlir::Operation* op, - Version targetVersion) { - // Allow mismatched operand and result types in v0.17.0 - if (checkIfOperandAndResultElementTypesMatch(getOperand().getType(), - getResult().getType()) && +// Allow mismatched operand and result types in reduce ops in v0.17.0 +LogicalResult verifyConstraint_0_17_0(mlir::Operation* op, + Version targetVersion) { + if (checkIfOperandAndResultElementTypesMatch(op->getOperandTypes(), + op->getResultTypes()) && targetVersion < Version(0, 17, 0)) return failure(); - return success(); } +} // namespace + +LogicalResult AllReduceOpV1::validateConstraint(mlir::Operation* op, + Version targetVersion) { + return verifyConstraint_0_17_0(op, targetVersion); +} LogicalResult ReduceOpV1::validateConstraint(mlir::Operation* op, Version targetVersion) { - // Allow mismatched operand and result types in v0.17.0 - if (checkIfOperandAndResultElementTypesMatch(getInputs().getTypes(), - getResultTypes()) && - targetVersion < Version(0, 17, 0)) - return failure(); - - return success(); + return verifyConstraint_0_17_0(op, targetVersion); } LogicalResult ReduceScatterOpV1::validateConstraint(mlir::Operation* op, Version targetVersion) { - // Allow mismatched operand and result types in v0.17.0 - if (checkIfOperandAndResultElementTypesMatch(getOperand().getType(), - getResult().getType()) && - targetVersion < Version(0, 17, 0)) - return failure(); - - return success(); + return verifyConstraint_0_17_0(op, targetVersion); } LogicalResult ReduceWindowOpV1::validateConstraint(mlir::Operation* op, Version targetVersion) { - // Allow mismatched operand and result types in v0.17.0 - if (checkIfOperandAndResultElementTypesMatch(getInputs().getTypes(), - getResultTypes()) && - targetVersion < Version(0, 17, 0)) - return failure(); - - return success(); + return verifyConstraint_0_17_0(op, targetVersion); } LogicalResult ScatterOpV1::validateConstraint(mlir::Operation* op, Version targetVersion) { - // Allow mismatched operand and result types in v0.17.0 - if (checkIfOperandAndResultElementTypesMatch(getInputs().getTypes(), - getResultTypes()) && - targetVersion < Version(0, 17, 0)) - return failure(); - - return success(); + return verifyConstraint_0_17_0(op, targetVersion); } LogicalResult SelectAndScatterOpV1::validateConstraint(mlir::Operation* op, Version targetVersion) { - // Allow mismatched operand and result types in v0.17.0 - if (checkIfOperandAndResultElementTypesMatch(getOperand().getType(), - getResult().getType()) && - targetVersion < Version(0, 17, 0)) - return failure(); - - return success(); + return verifyConstraint_0_17_0(op, targetVersion); } } // namespace vhlo diff --git a/stablehlo/tests/stablehlo_legalize_to_vhlo.0_17_0.mlir b/stablehlo/tests/stablehlo_legalize_to_vhlo.0_17_0.mlir index 70efa90c0b3..acd729402ff 100644 --- a/stablehlo/tests/stablehlo_legalize_to_vhlo.0_17_0.mlir +++ b/stablehlo/tests/stablehlo_legalize_to_vhlo.0_17_0.mlir @@ -509,7 +509,7 @@ func.func @default_dynamic_broadcast_in_dim(%arg0: tensor, %arg1: tenso // CHECK-SAME: known_nonexpanding_dimensions = #vhlo.tensor_v1 : tensor<0xi64>> // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1 %0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { - broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> + broadcast_dimensions = array } : (tensor, tensor<2xindex>) -> tensor func.return %0 : tensor } @@ -933,7 +933,7 @@ func.func @op_broadcast_in_dim(%arg0: tensor<16xf32>) -> tensor<16x16xf32> { // CHECK-SAME: broadcast_dimensions = #vhlo.tensor_v1 : tensor<1xi64>> // CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x!vhlo.f32_v1> %0 = "stablehlo.broadcast_in_dim"(%arg0) { - broadcast_dimensions = dense<1> : tensor<1xi64> + broadcast_dimensions = array } : (tensor<16xf32>) -> tensor<16x16xf32> func.return %0 : tensor<16x16xf32> } @@ -1214,9 +1214,9 @@ func.func @op_dynamic_broadcast_in_dim(%arg0: tensor, %arg1: tensor<2xi // CHECK-SAME: known_nonexpanding_dimensions = #vhlo.tensor_v1 : tensor<1xi64>> // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1 %0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { - broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>, - known_expanding_dimensions = dense<0> : tensor<1xi64>, - known_nonexpanding_dimensions = dense<1> : tensor<1xi64> + broadcast_dimensions = array, + known_expanding_dimensions = array, + known_nonexpanding_dimensions = array } : (tensor, tensor<2xindex>) -> tensor func.return %0 : tensor } diff --git a/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/transforms/VhloToVersion.cpp index 4e4b106651c..a8cf57d1aa6 100644 --- a/stablehlo/transforms/VhloToVersion.cpp +++ b/stablehlo/transforms/VhloToVersion.cpp @@ -186,8 +186,11 @@ bool isLegalOperation(Operation* op, const Version& targetVersion) { // Validate op constraints auto constraintInterface = dyn_cast(op); if (constraintInterface && - failed(constraintInterface.validateConstraint(op, targetVersion))) + failed(constraintInterface.validateConstraint(op, targetVersion))) { + LLVM_DEBUG(llvm::dbgs() + << "Op failed to satisfy versioned constraints. " << op << '\n'); return false; + } LLVM_DEBUG(llvm::dbgs() << "Legal constraints for target. " << op << '\n'); // Validate attributes