diff --git a/llvm/include/llvm/IR/DerivedTypes.h b/llvm/include/llvm/IR/DerivedTypes.h index a24801d8bdf834..820b5c0707df6c 100644 --- a/llvm/include/llvm/IR/DerivedTypes.h +++ b/llvm/include/llvm/IR/DerivedTypes.h @@ -290,8 +290,8 @@ class StructType : public Type { bool isSized(SmallPtrSetImpl *Visited = nullptr) const; /// Returns true if this struct contains a scalable vector. - bool - containsScalableVectorType(SmallPtrSetImpl *Visited = nullptr) const; + bool isScalableTy(SmallPtrSetImpl &Visited) const; + using Type::isScalableTy; /// Returns true if this struct contains homogeneous scalable vector types. /// Note that the definition of homogeneous scalable vector type is not diff --git a/llvm/include/llvm/IR/Type.h b/llvm/include/llvm/IR/Type.h index 2f53197df19998..d563b25d600a0c 100644 --- a/llvm/include/llvm/IR/Type.h +++ b/llvm/include/llvm/IR/Type.h @@ -206,6 +206,7 @@ class Type { bool isScalableTargetExtTy() const; /// Return true if this is a type whose size is a known multiple of vscale. + bool isScalableTy(SmallPtrSetImpl &Visited) const; bool isScalableTy() const; /// Return true if this is a FP type or a vector of FP. diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp index 6a2372c9751408..8ddb2efb0e26c2 100644 --- a/llvm/lib/AsmParser/LLParser.cpp +++ b/llvm/lib/AsmParser/LLParser.cpp @@ -8525,7 +8525,7 @@ int LLParser::parseGetElementPtr(Instruction *&Inst, PerFunctionState &PFS) { return error(Loc, "base element of getelementptr must be sized"); auto *STy = dyn_cast(Ty); - if (STy && STy->containsScalableVectorType()) + if (STy && STy->isScalableTy()) return error(Loc, "getelementptr cannot target structure that contains " "scalable vector type"); diff --git a/llvm/lib/IR/Type.cpp b/llvm/lib/IR/Type.cpp index f618263f79c313..912b1a3960ef19 100644 --- a/llvm/lib/IR/Type.cpp +++ b/llvm/lib/IR/Type.cpp @@ -58,16 +58,19 @@ bool Type::isIntegerTy(unsigned Bitwidth) const { return isIntegerTy() && cast(this)->getBitWidth() == Bitwidth; } -bool Type::isScalableTy() const { +bool Type::isScalableTy(SmallPtrSetImpl &Visited) const { if (const auto *ATy = dyn_cast(this)) - return ATy->getElementType()->isScalableTy(); - if (const auto *STy = dyn_cast(this)) { - SmallPtrSet Visited; - return STy->containsScalableVectorType(&Visited); - } + return ATy->getElementType()->isScalableTy(Visited); + if (const auto *STy = dyn_cast(this)) + return STy->isScalableTy(Visited); return getTypeID() == ScalableVectorTyID || isScalableTargetExtTy(); } +bool Type::isScalableTy() const { + SmallPtrSet Visited; + return isScalableTy(Visited); +} + const fltSemantics &Type::getFltSemantics() const { switch (getTypeID()) { case HalfTyID: return APFloat::IEEEhalf(); @@ -394,30 +397,22 @@ StructType *StructType::get(LLVMContext &Context, ArrayRef ETypes, return ST; } -bool StructType::containsScalableVectorType( - SmallPtrSetImpl *Visited) const { +bool StructType::isScalableTy(SmallPtrSetImpl &Visited) const { if ((getSubclassData() & SCDB_ContainsScalableVector) != 0) return true; if ((getSubclassData() & SCDB_NotContainsScalableVector) != 0) return false; - if (Visited && !Visited->insert(const_cast(this)).second) + if (!Visited.insert(this).second) return false; for (Type *Ty : elements()) { - if (isa(Ty)) { + if (Ty->isScalableTy(Visited)) { const_cast(this)->setSubclassData( getSubclassData() | SCDB_ContainsScalableVector); return true; } - if (auto *STy = dyn_cast(Ty)) { - if (STy->containsScalableVectorType(Visited)) { - const_cast(this)->setSubclassData( - getSubclassData() | SCDB_ContainsScalableVector); - return true; - } - } } // For structures that are opaque, return false but do not set the diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp index f34fe7594c8602..60e65392218dad 100644 --- a/llvm/lib/IR/Verifier.cpp +++ b/llvm/lib/IR/Verifier.cpp @@ -4107,8 +4107,7 @@ void Verifier::visitGetElementPtrInst(GetElementPtrInst &GEP) { Check(GEP.getSourceElementType()->isSized(), "GEP into unsized type!", &GEP); if (auto *STy = dyn_cast(GEP.getSourceElementType())) { - SmallPtrSet Visited; - Check(!STy->containsScalableVectorType(&Visited), + Check(!STy->isScalableTy(), "getelementptr cannot target structure that contains scalable vector" "type", &GEP); diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp index c8b9f166b16020..971ace2a4f4716 100644 --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -4087,7 +4087,7 @@ Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) { if (LoadInst *L = dyn_cast(Agg)) { // Bail out if the aggregate contains scalable vector type if (auto *STy = dyn_cast(Agg->getType()); - STy && STy->containsScalableVectorType()) + STy && STy->isScalableTy()) return nullptr; // If the (non-volatile) load only has one use, we can rewrite this to a diff --git a/llvm/test/Verifier/scalable-global-vars.ll b/llvm/test/Verifier/scalable-global-vars.ll index 81882261e664ef..fb9a3067acba98 100644 --- a/llvm/test/Verifier/scalable-global-vars.ll +++ b/llvm/test/Verifier/scalable-global-vars.ll @@ -15,3 +15,17 @@ ; CHECK-NEXT: ptr @ScalableVecStructGlobal @ScalableVecStructGlobal = global { i32, } zeroinitializer +; CHECK-NEXT: Globals cannot contain scalable types +; CHECK-NEXT: ptr @StructTestGlobal +%struct.test = type { , } +@StructTestGlobal = global %struct.test zeroinitializer + +; CHECK-NEXT: Globals cannot contain scalable types +; CHECK-NEXT: ptr @StructArrayTestGlobal +%struct.array.test = type { [2 x ] } +@StructArrayTestGlobal = global %struct.array.test zeroinitializer + +; CHECK-NEXT: Globals cannot contain scalable types +; CHECK-NEXT: ptr @StructTargetTestGlobal +%struct.target.test = type { target("aarch64.svcount"), target("aarch64.svcount") } +@StructTargetTestGlobal = global %struct.target.test zeroinitializer