Skip to content

Commit

Permalink
address feedback:1
Browse files Browse the repository at this point in the history
  • Loading branch information
sdasgup3 committed Dec 20, 2023
1 parent 10d2d4a commit 6fbb8ef
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 21 deletions.
15 changes: 9 additions & 6 deletions stablehlo/dialect/TypeInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,11 @@ unsigned potentiallyComplexBitWidth(Type type) {
: type.getIntOrFloatBitWidth();
}

template <typename T>
bool matchesType(Type a, Type b) {
return a.isa<T>() && b.isa<T>();
}

// Returns true if the element-type of type1 can be promoted to that of type2.
// An element-type 'x' is promotatble to element-type 'y' is they have the same
// base type and bitwidth(x) <= bitwidth(y). When 'x' and 'y' are quantized
Expand All @@ -122,12 +127,10 @@ bool isPromotableElementType(Type type1, Type type2,
tensorEl2.isa<FloatType>())
return true;

bool isSameType =
(tensorEl1.isa<IntegerType>() and tensorEl2.isa<IntegerType>()) ||
(tensorEl1.isa<FloatType>() and tensorEl2.isa<FloatType>()) ||
(tensorEl1.isa<ComplexType>() and tensorEl2.isa<ComplexType>()) ||
(tensorEl1.isa<quant::QuantizedType>() and
tensorEl2.isa<quant::QuantizedType>());
bool isSameType = matchesType<IntegerType>(tensorEl1, tensorEl2) ||
matchesType<FloatType>(tensorEl1, tensorEl2) ||
matchesType<ComplexType>(tensorEl1, tensorEl2) ||
matchesType<quant::QuantizedType>(tensorEl1, tensorEl2);

if (!isSameType) return false;

Expand Down
32 changes: 17 additions & 15 deletions stablehlo/dialect/VhloOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,23 +305,25 @@ void VhloDialect::printAttribute(Attribute attr, DialectAsmPrinter& os) const {
// to represent this sort of constraint in tablegen.

namespace {
Type getVhloElementType(Type tensorType) {
if (auto ranked = tensorType.dyn_cast<RankedTensorV1Type>()) {
return ranked.getElementType();
}
return tensorType.cast<UnrankedTensorV1Type>().getElementType();
}

bool checkIfOperandAndResultElementTypesMatch(TypeRange operandTypes,
TypeRange resultTypes) {
SmallVector<ShapedType> inputShapedTypes{
llvm::map_range(operandTypes, [](Type t) {
return convertTypeToBuiltinForPrint(t).cast<ShapedType>();
})};
SmallVector<ShapedType> resultShapedTypes{
llvm::map_range(resultTypes, [](Type t) {
return convertTypeToBuiltinForPrint(t).cast<ShapedType>();
})};

int64_t numInputs = inputShapedTypes.size();
for (int64_t inputIdx = 0; inputIdx < numInputs; ++inputIdx) {
if (inputShapedTypes[inputIdx].getElementType() !=
resultShapedTypes[inputIdx].getElementType())
return true;
}
SmallVector<Type> inputElementTypes{llvm::map_range(
operandTypes, [](Type t) { return getVhloElementType(t); })};
SmallVector<Type> resultElementTypes{llvm::map_range(
resultTypes, [](Type t) { return getVhloElementType(t); })};

if (llvm::any_of(
llvm::zip(inputElementTypes, resultElementTypes),
[&](auto pair) { return std::get<0>(pair) != std::get<1>(pair); }))
return true;

return false;
}
} // namespace
Expand Down

0 comments on commit 6fbb8ef

Please sign in to comment.