From 90cdc03e7f5bda2e31573d48450a8ac8fa856efa Mon Sep 17 00:00:00 2001 From: Jay Foad Date: Fri, 25 Oct 2024 12:56:10 +0100 Subject: [PATCH] [IR] Fix undiagnosed cases of structs containing scalable vectors (#113455) Type::isScalableTy and StructType::containsScalableVectorType failed to detect some cases of structs containing scalable vectors because containsScalableVectorType did not call back into isScalableTy to check the element types. Fix this, which requires sharing the same Visited set in both functions. Also change the external API so that callers are never required to pass in a Visited set, and normalize the naming to isScalableTy. --- llvm/include/llvm/IR/DerivedTypes.h | 4 +-- llvm/include/llvm/IR/Type.h | 1 + llvm/lib/AsmParser/LLParser.cpp | 2 +- llvm/lib/IR/Type.cpp | 29 ++++++++----------- llvm/lib/IR/Verifier.cpp | 3 +- .../InstCombine/InstructionCombining.cpp | 2 +- llvm/test/Verifier/scalable-global-vars.ll | 14 +++++++++ 7 files changed, 32 insertions(+), 23 deletions(-) diff --git a/llvm/include/llvm/IR/DerivedTypes.h b/llvm/include/llvm/IR/DerivedTypes.h index a24801d8bdf834..820b5c0707df6c 100644 --- a/llvm/include/llvm/IR/DerivedTypes.h +++ b/llvm/include/llvm/IR/DerivedTypes.h @@ -290,8 +290,8 @@ class StructType : public Type { bool isSized(SmallPtrSetImpl *Visited = nullptr) const; /// Returns true if this struct contains a scalable vector. - bool - containsScalableVectorType(SmallPtrSetImpl *Visited = nullptr) const; + bool isScalableTy(SmallPtrSetImpl &Visited) const; + using Type::isScalableTy; /// Returns true if this struct contains homogeneous scalable vector types. /// Note that the definition of homogeneous scalable vector type is not diff --git a/llvm/include/llvm/IR/Type.h b/llvm/include/llvm/IR/Type.h index 2f53197df19998..d563b25d600a0c 100644 --- a/llvm/include/llvm/IR/Type.h +++ b/llvm/include/llvm/IR/Type.h @@ -206,6 +206,7 @@ class Type { bool isScalableTargetExtTy() const; /// Return true if this is a type whose size is a known multiple of vscale. + bool isScalableTy(SmallPtrSetImpl &Visited) const; bool isScalableTy() const; /// Return true if this is a FP type or a vector of FP. diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp index 6a2372c9751408..8ddb2efb0e26c2 100644 --- a/llvm/lib/AsmParser/LLParser.cpp +++ b/llvm/lib/AsmParser/LLParser.cpp @@ -8525,7 +8525,7 @@ int LLParser::parseGetElementPtr(Instruction *&Inst, PerFunctionState &PFS) { return error(Loc, "base element of getelementptr must be sized"); auto *STy = dyn_cast(Ty); - if (STy && STy->containsScalableVectorType()) + if (STy && STy->isScalableTy()) return error(Loc, "getelementptr cannot target structure that contains " "scalable vector type"); diff --git a/llvm/lib/IR/Type.cpp b/llvm/lib/IR/Type.cpp index f618263f79c313..912b1a3960ef19 100644 --- a/llvm/lib/IR/Type.cpp +++ b/llvm/lib/IR/Type.cpp @@ -58,16 +58,19 @@ bool Type::isIntegerTy(unsigned Bitwidth) const { return isIntegerTy() && cast(this)->getBitWidth() == Bitwidth; } -bool Type::isScalableTy() const { +bool Type::isScalableTy(SmallPtrSetImpl &Visited) const { if (const auto *ATy = dyn_cast(this)) - return ATy->getElementType()->isScalableTy(); - if (const auto *STy = dyn_cast(this)) { - SmallPtrSet Visited; - return STy->containsScalableVectorType(&Visited); - } + return ATy->getElementType()->isScalableTy(Visited); + if (const auto *STy = dyn_cast(this)) + return STy->isScalableTy(Visited); return getTypeID() == ScalableVectorTyID || isScalableTargetExtTy(); } +bool Type::isScalableTy() const { + SmallPtrSet Visited; + return isScalableTy(Visited); +} + const fltSemantics &Type::getFltSemantics() const { switch (getTypeID()) { case HalfTyID: return APFloat::IEEEhalf(); @@ -394,30 +397,22 @@ StructType *StructType::get(LLVMContext &Context, ArrayRef ETypes, return ST; } -bool StructType::containsScalableVectorType( - SmallPtrSetImpl *Visited) const { +bool StructType::isScalableTy(SmallPtrSetImpl &Visited) const { if ((getSubclassData() & SCDB_ContainsScalableVector) != 0) return true; if ((getSubclassData() & SCDB_NotContainsScalableVector) != 0) return false; - if (Visited && !Visited->insert(const_cast(this)).second) + if (!Visited.insert(this).second) return false; for (Type *Ty : elements()) { - if (isa(Ty)) { + if (Ty->isScalableTy(Visited)) { const_cast(this)->setSubclassData( getSubclassData() | SCDB_ContainsScalableVector); return true; } - if (auto *STy = dyn_cast(Ty)) { - if (STy->containsScalableVectorType(Visited)) { - const_cast(this)->setSubclassData( - getSubclassData() | SCDB_ContainsScalableVector); - return true; - } - } } // For structures that are opaque, return false but do not set the diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp index f34fe7594c8602..60e65392218dad 100644 --- a/llvm/lib/IR/Verifier.cpp +++ b/llvm/lib/IR/Verifier.cpp @@ -4107,8 +4107,7 @@ void Verifier::visitGetElementPtrInst(GetElementPtrInst &GEP) { Check(GEP.getSourceElementType()->isSized(), "GEP into unsized type!", &GEP); if (auto *STy = dyn_cast(GEP.getSourceElementType())) { - SmallPtrSet Visited; - Check(!STy->containsScalableVectorType(&Visited), + Check(!STy->isScalableTy(), "getelementptr cannot target structure that contains scalable vector" "type", &GEP); diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp index c8b9f166b16020..971ace2a4f4716 100644 --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -4087,7 +4087,7 @@ Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) { if (LoadInst *L = dyn_cast(Agg)) { // Bail out if the aggregate contains scalable vector type if (auto *STy = dyn_cast(Agg->getType()); - STy && STy->containsScalableVectorType()) + STy && STy->isScalableTy()) return nullptr; // If the (non-volatile) load only has one use, we can rewrite this to a diff --git a/llvm/test/Verifier/scalable-global-vars.ll b/llvm/test/Verifier/scalable-global-vars.ll index 81882261e664ef..fb9a3067acba98 100644 --- a/llvm/test/Verifier/scalable-global-vars.ll +++ b/llvm/test/Verifier/scalable-global-vars.ll @@ -15,3 +15,17 @@ ; CHECK-NEXT: ptr @ScalableVecStructGlobal @ScalableVecStructGlobal = global { i32, } zeroinitializer +; CHECK-NEXT: Globals cannot contain scalable types +; CHECK-NEXT: ptr @StructTestGlobal +%struct.test = type { , } +@StructTestGlobal = global %struct.test zeroinitializer + +; CHECK-NEXT: Globals cannot contain scalable types +; CHECK-NEXT: ptr @StructArrayTestGlobal +%struct.array.test = type { [2 x ] } +@StructArrayTestGlobal = global %struct.array.test zeroinitializer + +; CHECK-NEXT: Globals cannot contain scalable types +; CHECK-NEXT: ptr @StructTargetTestGlobal +%struct.target.test = type { target("aarch64.svcount"), target("aarch64.svcount") } +@StructTargetTestGlobal = global %struct.target.test zeroinitializer