diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index b4cce680e2876f..42dea13e16246e 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -10225,9 +10225,11 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry( for (const TreeEntry *TE : ForRemoval) Set.erase(TE); } + bool NeedToRemapValues = false; for (auto *It = UsedTEs.begin(); It != UsedTEs.end();) { if (It->empty()) { UsedTEs.erase(It); + NeedToRemapValues = true; continue; } std::advance(It, 1); @@ -10236,6 +10238,19 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry( Entries.clear(); return std::nullopt; } + // Recalculate the mapping between the values and entries sets. + if (NeedToRemapValues) { + DenseMap PrevUsedValuesEntry; + PrevUsedValuesEntry.swap(UsedValuesEntry); + for (auto [Idx, Set] : enumerate(UsedTEs)) { + DenseSet Values; + for (const TreeEntry *E : Set) + Values.insert(E->Scalars.begin(), E->Scalars.end()); + for (const auto &P : PrevUsedValuesEntry) + if (Values.contains(P.first)) + UsedValuesEntry.try_emplace(P.first, Idx); + } + } } unsigned VF = 0; @@ -14001,6 +14016,33 @@ bool BoUpSLP::collectValuesToDemote( }; unsigned Start = 0; unsigned End = I->getNumOperands(); + + auto FinalAnalysis = [&](const TreeEntry *ITE = nullptr) { + if (!IsProfitableToDemote) + return false; + return (ITE && ITE->UserTreeIndices.size() > 1) || + IsPotentiallyTruncated(I, BitWidth); + }; + auto ProcessOperands = [&](ArrayRef Operands, bool &NeedToExit) { + NeedToExit = false; + unsigned InitLevel = MaxDepthLevel; + for (Value *IncValue : Operands) { + unsigned Level = InitLevel; + if (!collectValuesToDemote(IncValue, IsProfitableToDemoteRoot, BitWidth, + ToDemote, DemotedConsts, Visited, Level, + IsProfitableToDemote, IsTruncRoot)) { + if (!IsProfitableToDemote) + return false; + NeedToExit = true; + if (!FinalAnalysis(ITE)) + return false; + continue; + } + MaxDepthLevel = std::max(MaxDepthLevel, Level); + } + return true; + }; + bool NeedToExit = false; switch (I->getOpcode()) { // We can always demote truncations and extensions. Since truncations can @@ -14026,35 +14068,21 @@ bool BoUpSLP::collectValuesToDemote( case Instruction::And: case Instruction::Or: case Instruction::Xor: { - unsigned Level1, Level2; - if ((ITE->UserTreeIndices.size() > 1 && - !IsPotentiallyTruncated(I, BitWidth)) || - !collectValuesToDemote(I->getOperand(0), IsProfitableToDemoteRoot, - BitWidth, ToDemote, DemotedConsts, Visited, - Level1, IsProfitableToDemote, IsTruncRoot) || - !collectValuesToDemote(I->getOperand(1), IsProfitableToDemoteRoot, - BitWidth, ToDemote, DemotedConsts, Visited, - Level2, IsProfitableToDemote, IsTruncRoot)) + if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth)) + return false; + if (!ProcessOperands({I->getOperand(0), I->getOperand(1)}, NeedToExit)) return false; - MaxDepthLevel = std::max(Level1, Level2); break; } // We can demote selects if we can demote their true and false values. case Instruction::Select: { + if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth)) + return false; Start = 1; - unsigned Level1, Level2; - SelectInst *SI = cast(I); - if ((ITE->UserTreeIndices.size() > 1 && - !IsPotentiallyTruncated(I, BitWidth)) || - !collectValuesToDemote(SI->getTrueValue(), IsProfitableToDemoteRoot, - BitWidth, ToDemote, DemotedConsts, Visited, - Level1, IsProfitableToDemote, IsTruncRoot) || - !collectValuesToDemote(SI->getFalseValue(), IsProfitableToDemoteRoot, - BitWidth, ToDemote, DemotedConsts, Visited, - Level2, IsProfitableToDemote, IsTruncRoot)) + auto *SI = cast(I); + if (!ProcessOperands({SI->getTrueValue(), SI->getFalseValue()}, NeedToExit)) return false; - MaxDepthLevel = std::max(Level1, Level2); break; } @@ -14065,22 +14093,20 @@ bool BoUpSLP::collectValuesToDemote( MaxDepthLevel = 0; if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth)) return false; - for (Value *IncValue : PN->incoming_values()) { - unsigned Level; - if (!collectValuesToDemote(IncValue, IsProfitableToDemoteRoot, BitWidth, - ToDemote, DemotedConsts, Visited, Level, - IsProfitableToDemote, IsTruncRoot)) - return false; - MaxDepthLevel = std::max(MaxDepthLevel, Level); - } + SmallVector Ops(PN->incoming_values().begin(), + PN->incoming_values().end()); + if (!ProcessOperands(Ops, NeedToExit)) + return false; break; } // Otherwise, conservatively give up. default: MaxDepthLevel = 1; - return IsProfitableToDemote && IsPotentiallyTruncated(I, BitWidth); + return FinalAnalysis(); } + if (NeedToExit) + return true; ++MaxDepthLevel; // Gather demoted constant operands. @@ -14119,6 +14145,7 @@ void BoUpSLP::computeMinimumValueSizes() { // The first value node for store/insertelement is sext/zext/trunc? Skip it, // resize to the final type. + bool IsTruncRoot = false; bool IsProfitableToDemoteRoot = !IsStoreOrInsertElt; if (NodeIdx != 0 && VectorizableTree[NodeIdx]->State == TreeEntry::Vectorize && @@ -14126,8 +14153,9 @@ void BoUpSLP::computeMinimumValueSizes() { VectorizableTree[NodeIdx]->getOpcode() == Instruction::SExt || VectorizableTree[NodeIdx]->getOpcode() == Instruction::Trunc)) { assert(IsStoreOrInsertElt && "Expected store/insertelement seeded graph."); - ++NodeIdx; + IsTruncRoot = VectorizableTree[NodeIdx]->getOpcode() == Instruction::Trunc; IsProfitableToDemoteRoot = true; + ++NodeIdx; } // Analyzed in reduction already and not profitable - exit. @@ -14259,7 +14287,6 @@ void BoUpSLP::computeMinimumValueSizes() { ReductionBitWidth = bit_ceil(ReductionBitWidth); } bool IsTopRoot = NodeIdx == 0; - bool IsTruncRoot = false; while (NodeIdx < VectorizableTree.size() && VectorizableTree[NodeIdx]->State == TreeEntry::Vectorize && VectorizableTree[NodeIdx]->getOpcode() == Instruction::Trunc) { diff --git a/llvm/test/Transforms/SLPVectorizer/AArch64/horizontal.ll b/llvm/test/Transforms/SLPVectorizer/AArch64/horizontal.ll index 1986b51ec94828..02d1f9f60d0ca1 100644 --- a/llvm/test/Transforms/SLPVectorizer/AArch64/horizontal.ll +++ b/llvm/test/Transforms/SLPVectorizer/AArch64/horizontal.ll @@ -228,7 +228,7 @@ for.end: ; preds = %for.end.loopexit, % ; YAML-NEXT: Function: test_unrolled_select ; YAML-NEXT: Args: ; YAML-NEXT: - String: 'Vectorized horizontal reduction with cost ' -; YAML-NEXT: - Cost: '-36' +; YAML-NEXT: - Cost: '-40' ; YAML-NEXT: - String: ' and with tree size ' ; YAML-NEXT: - TreeSize: '10' @@ -246,15 +246,17 @@ define i32 @test_unrolled_select(ptr noalias nocapture readonly %blk1, ptr noali ; CHECK-NEXT: [[P2_045:%.*]] = phi ptr [ [[BLK2:%.*]], [[FOR_BODY_LR_PH]] ], [ [[ADD_PTR88:%.*]], [[IF_END_86]] ] ; CHECK-NEXT: [[P1_044:%.*]] = phi ptr [ [[BLK1:%.*]], [[FOR_BODY_LR_PH]] ], [ [[ADD_PTR:%.*]], [[IF_END_86]] ] ; CHECK-NEXT: [[TMP0:%.*]] = load <8 x i8>, ptr [[P1_044]], align 1 -; CHECK-NEXT: [[TMP1:%.*]] = zext <8 x i8> [[TMP0]] to <8 x i32> +; CHECK-NEXT: [[TMP1:%.*]] = zext <8 x i8> [[TMP0]] to <8 x i16> ; CHECK-NEXT: [[TMP2:%.*]] = load <8 x i8>, ptr [[P2_045]], align 1 -; CHECK-NEXT: [[TMP3:%.*]] = zext <8 x i8> [[TMP2]] to <8 x i32> -; CHECK-NEXT: [[TMP4:%.*]] = sub nsw <8 x i32> [[TMP1]], [[TMP3]] -; CHECK-NEXT: [[TMP5:%.*]] = icmp slt <8 x i32> [[TMP4]], zeroinitializer -; CHECK-NEXT: [[TMP6:%.*]] = sub nsw <8 x i32> zeroinitializer, [[TMP4]] -; CHECK-NEXT: [[TMP7:%.*]] = select <8 x i1> [[TMP5]], <8 x i32> [[TMP6]], <8 x i32> [[TMP4]] -; CHECK-NEXT: [[TMP8:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP7]]) -; CHECK-NEXT: [[OP_RDX]] = add i32 [[TMP8]], [[S_047]] +; CHECK-NEXT: [[TMP3:%.*]] = zext <8 x i8> [[TMP2]] to <8 x i16> +; CHECK-NEXT: [[TMP4:%.*]] = sub <8 x i16> [[TMP1]], [[TMP3]] +; CHECK-NEXT: [[TMP5:%.*]] = trunc <8 x i16> [[TMP4]] to <8 x i1> +; CHECK-NEXT: [[TMP6:%.*]] = icmp slt <8 x i1> [[TMP5]], zeroinitializer +; CHECK-NEXT: [[TMP7:%.*]] = sub <8 x i16> zeroinitializer, [[TMP4]] +; CHECK-NEXT: [[TMP8:%.*]] = select <8 x i1> [[TMP6]], <8 x i16> [[TMP7]], <8 x i16> [[TMP4]] +; CHECK-NEXT: [[TMP9:%.*]] = zext <8 x i16> [[TMP8]] to <8 x i32> +; CHECK-NEXT: [[TMP10:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP9]]) +; CHECK-NEXT: [[OP_RDX]] = add i32 [[TMP10]], [[S_047]] ; CHECK-NEXT: [[CMP83:%.*]] = icmp slt i32 [[OP_RDX]], [[LIM:%.*]] ; CHECK-NEXT: br i1 [[CMP83]], label [[IF_END_86]], label [[FOR_END_LOOPEXIT:%.*]] ; CHECK: if.end.86: