Skip to content

Commit

Permalink
[IR] Fix undiagnosed cases of structs containing scalable vectors (ll…
Browse files Browse the repository at this point in the history
…vm#113455)

Type::isScalableTy and StructType::containsScalableVectorType failed to
detect some cases of structs containing scalable vectors because
containsScalableVectorType did not call back into isScalableTy to check
the element types. Fix this, which requires sharing the same Visited set
in both functions. Also change the external API so that callers are
never required to pass in a Visited set, and normalize the naming to
isScalableTy.
  • Loading branch information
jayfoad authored Oct 25, 2024
1 parent 2c0b348 commit 90cdc03
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 23 deletions.
4 changes: 2 additions & 2 deletions llvm/include/llvm/IR/DerivedTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -290,8 +290,8 @@ class StructType : public Type {
bool isSized(SmallPtrSetImpl<Type *> *Visited = nullptr) const;

/// Returns true if this struct contains a scalable vector.
bool
containsScalableVectorType(SmallPtrSetImpl<Type *> *Visited = nullptr) const;
bool isScalableTy(SmallPtrSetImpl<const Type *> &Visited) const;
using Type::isScalableTy;

/// Returns true if this struct contains homogeneous scalable vector types.
/// Note that the definition of homogeneous scalable vector type is not
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/IR/Type.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ class Type {
bool isScalableTargetExtTy() const;

/// Return true if this is a type whose size is a known multiple of vscale.
bool isScalableTy(SmallPtrSetImpl<const Type *> &Visited) const;
bool isScalableTy() const;

/// Return true if this is a FP type or a vector of FP.
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/AsmParser/LLParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8525,7 +8525,7 @@ int LLParser::parseGetElementPtr(Instruction *&Inst, PerFunctionState &PFS) {
return error(Loc, "base element of getelementptr must be sized");

auto *STy = dyn_cast<StructType>(Ty);
if (STy && STy->containsScalableVectorType())
if (STy && STy->isScalableTy())
return error(Loc, "getelementptr cannot target structure that contains "
"scalable vector type");

Expand Down
29 changes: 12 additions & 17 deletions llvm/lib/IR/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,16 +58,19 @@ bool Type::isIntegerTy(unsigned Bitwidth) const {
return isIntegerTy() && cast<IntegerType>(this)->getBitWidth() == Bitwidth;
}

bool Type::isScalableTy() const {
bool Type::isScalableTy(SmallPtrSetImpl<const Type *> &Visited) const {
if (const auto *ATy = dyn_cast<ArrayType>(this))
return ATy->getElementType()->isScalableTy();
if (const auto *STy = dyn_cast<StructType>(this)) {
SmallPtrSet<Type *, 4> Visited;
return STy->containsScalableVectorType(&Visited);
}
return ATy->getElementType()->isScalableTy(Visited);
if (const auto *STy = dyn_cast<StructType>(this))
return STy->isScalableTy(Visited);
return getTypeID() == ScalableVectorTyID || isScalableTargetExtTy();
}

bool Type::isScalableTy() const {
SmallPtrSet<const Type *, 4> Visited;
return isScalableTy(Visited);
}

const fltSemantics &Type::getFltSemantics() const {
switch (getTypeID()) {
case HalfTyID: return APFloat::IEEEhalf();
Expand Down Expand Up @@ -394,30 +397,22 @@ StructType *StructType::get(LLVMContext &Context, ArrayRef<Type*> ETypes,
return ST;
}

bool StructType::containsScalableVectorType(
SmallPtrSetImpl<Type *> *Visited) const {
bool StructType::isScalableTy(SmallPtrSetImpl<const Type *> &Visited) const {
if ((getSubclassData() & SCDB_ContainsScalableVector) != 0)
return true;

if ((getSubclassData() & SCDB_NotContainsScalableVector) != 0)
return false;

if (Visited && !Visited->insert(const_cast<StructType *>(this)).second)
if (!Visited.insert(this).second)
return false;

for (Type *Ty : elements()) {
if (isa<ScalableVectorType>(Ty)) {
if (Ty->isScalableTy(Visited)) {
const_cast<StructType *>(this)->setSubclassData(
getSubclassData() | SCDB_ContainsScalableVector);
return true;
}
if (auto *STy = dyn_cast<StructType>(Ty)) {
if (STy->containsScalableVectorType(Visited)) {
const_cast<StructType *>(this)->setSubclassData(
getSubclassData() | SCDB_ContainsScalableVector);
return true;
}
}
}

// For structures that are opaque, return false but do not set the
Expand Down
3 changes: 1 addition & 2 deletions llvm/lib/IR/Verifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4107,8 +4107,7 @@ void Verifier::visitGetElementPtrInst(GetElementPtrInst &GEP) {
Check(GEP.getSourceElementType()->isSized(), "GEP into unsized type!", &GEP);

if (auto *STy = dyn_cast<StructType>(GEP.getSourceElementType())) {
SmallPtrSet<Type *, 4> Visited;
Check(!STy->containsScalableVectorType(&Visited),
Check(!STy->isScalableTy(),
"getelementptr cannot target structure that contains scalable vector"
"type",
&GEP);
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4087,7 +4087,7 @@ Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) {
if (LoadInst *L = dyn_cast<LoadInst>(Agg)) {
// Bail out if the aggregate contains scalable vector type
if (auto *STy = dyn_cast<StructType>(Agg->getType());
STy && STy->containsScalableVectorType())
STy && STy->isScalableTy())
return nullptr;

// If the (non-volatile) load only has one use, we can rewrite this to a
Expand Down
14 changes: 14 additions & 0 deletions llvm/test/Verifier/scalable-global-vars.ll
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,17 @@
; CHECK-NEXT: ptr @ScalableVecStructGlobal
@ScalableVecStructGlobal = global { i32, <vscale x 4 x i32> } zeroinitializer

; CHECK-NEXT: Globals cannot contain scalable types
; CHECK-NEXT: ptr @StructTestGlobal
%struct.test = type { <vscale x 1 x double>, <vscale x 1 x double> }
@StructTestGlobal = global %struct.test zeroinitializer

; CHECK-NEXT: Globals cannot contain scalable types
; CHECK-NEXT: ptr @StructArrayTestGlobal
%struct.array.test = type { [2 x <vscale x 1 x double>] }
@StructArrayTestGlobal = global %struct.array.test zeroinitializer

; CHECK-NEXT: Globals cannot contain scalable types
; CHECK-NEXT: ptr @StructTargetTestGlobal
%struct.target.test = type { target("aarch64.svcount"), target("aarch64.svcount") }
@StructTargetTestGlobal = global %struct.target.test zeroinitializer

0 comments on commit 90cdc03

Please sign in to comment.