Skip to content

Commit

Permalink
Correctly pass values from the conditional block to the loop during i…
Browse files Browse the repository at this point in the history
…nversion (#3311)

Co-authored-by: Yong He <yonghe@outlook.com>
  • Loading branch information
expipiplus1 and csyonghe authored Nov 6, 2023
1 parent 79677b8 commit da9e0ad
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 25 deletions.
99 changes: 74 additions & 25 deletions source/slang/slang-ir-loop-inversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ static IRParam* duplicateToParamWithDecorations(IRBuilder& builder, IRCloneEnv&
static void invertLoop(IRBuilder& builder, IRLoop* loop)
{
IRBuilderInsertLocScope builderScope(&builder);
builder.setInsertInto(loop->getParent());

const auto s = as<IRBlock>(loop->getParent());
auto domTree = computeDominatorTree(s->getParent());
SLANG_ASSERT(s);
Expand All @@ -152,34 +154,58 @@ static void invertLoop(IRBuilder& builder, IRLoop* loop)
IRCloneEnv cloneEnv;
cloneEnv.squashChildrenMapping = true;

// We don't expect 'd' to have any parameters, because it used to be the
// target of a conditional branch
SLANG_ASSERT(d->getFirstParam() == nullptr);

// Since we are duplicating the loop break condition block (c1) we must
// introduce phi values for anything in it upon which the rest of the
// program (b onwards) uses. Lift the values fron c1 used in b (and
// onwards) to parameters. To avoid a critical edge, pass these via a new
// block, e1.
builder.setInsertInto(b);
List<IRInst*> c1Params;
// program (inside the loop, and b onwards) uses. Lift the values from c1
// used in b (and onwards) to parameters, the same for those used before b.
// To avoid a critical edge, pass these via a new block, e1.
// For any such values used within the loop we can pass directly to d.
//
// c1PostLoopParams are values from c1 used after the loop
List<IRInst*> c1PostLoopParams;
// c1LoopParams are values from c1 used within the loop itself
List<IRInst*> c1LoopParams;
for(auto i : IRInstList<IRInst>(c1->getFirstInst(), c1->getLastInst()))
{
IRParam* p = nullptr;
IRParam* postLoopParam = nullptr;
IRParam* loopParam = nullptr;
traverseUses(i, [&](IRUse* u){
auto userBlock = u->getUser()->getParent();
if(domTree->dominates(b, userBlock))
{
// A new parameter to replace this 'i'
if(!p)
p = duplicateToParamWithDecorations(builder, cloneEnv, i);
u->set(p);
if(!postLoopParam)
{
postLoopParam = duplicateToParamWithDecorations(builder, cloneEnv, i);
b->addParam(postLoopParam);
}
u->set(postLoopParam);
}
else if(userBlock != c1)
{
// A new parameter to replace this 'i'
if(!loopParam)
{
loopParam = duplicateToParamWithDecorations(builder, cloneEnv, i);
d->addParam(loopParam);
}
u->set(loopParam);
}
});
if(p)
c1Params.add(i);
if(postLoopParam)
c1PostLoopParams.add(i);
if(loopParam)
c1LoopParams.add(i);
}

// Create another break block b2 that will act as the new break block for the
// loop. The original break block b will become the merge point for the outer condition.
//
auto b2 = builder.emitBlock();
const auto b2 = builder.emitBlock();
b2->insertBefore(b);

// Create a copy of the parameters in b. b2 will simply pass these to b.
Expand All @@ -194,7 +220,7 @@ static void invertLoop(IRBuilder& builder, IRLoop* loop)

auto e1 = builder.emitBlock();
e1->insertAfter(c1);
builder.emitBranch(b2, c1Params.getCount(), c1Params.getBuffer());
builder.emitBranch(b2, c1PostLoopParams.getCount(), c1PostLoopParams.getBuffer());
c1bUse.set(e1);
c1Terminator->afterBlock.set(d);

Expand All @@ -204,11 +230,15 @@ static void invertLoop(IRBuilder& builder, IRLoop* loop)
traverseUses(b, [&](IRUse* u){
auto userBlock = u->getUser()->getParent();
// Restrict this to just those blocks within this loop
if(userBlock != e1 && userBlock != b2 && domTree->dominates(s, userBlock) && !domTree->dominates(b, userBlock))
if(userBlock != e1
&& userBlock != b2
&& userBlock != s
&& domTree->dominates(s, userBlock)
&& !domTree->dominates(b, userBlock))
{
auto jumpToB2Block = builder.emitBlock();
const auto jumpToB2Block = builder.emitBlock();
jumpToB2Block->insertAfter(userBlock);
builder.emitBranch(b2, c1Params.getCount(), c1Params.getBuffer());
builder.emitBranch(b2, c1PostLoopParams.getCount(), c1PostLoopParams.getBuffer());
u->set(jumpToB2Block);
}
});
Expand All @@ -231,7 +261,12 @@ static void invertLoop(IRBuilder& builder, IRLoop* loop)
auto& c2eUse = c2Terminator->getTrueBlock() == e1 ? c2Terminator->trueBlock : c2Terminator->falseBlock;
c2eUse.set(e2);
builder.setInsertAfter(c2Terminator);
const auto newC2Terminator = builder.emitIfElse(c2Terminator->getCondition(), c2Terminator->getTrueBlock(), c2Terminator->getFalseBlock(), b);
const auto newC2Terminator = builder.emitIfElse(
c2Terminator->getCondition(),
c2Terminator->getTrueBlock(),
c2Terminator->getFalseBlock(),
b
);
c2Terminator->removeAndDeallocate();
// The cloned e2 will branch into b2 by default, rewrite it to branch to b, the correct merge point.
SLANG_ASSERT(cast<IRUnconditionalBranch>(e2->getTerminator())->getTargetBlock() == b2);
Expand All @@ -251,6 +286,7 @@ static void invertLoop(IRBuilder& builder, IRLoop* loop)
const auto l = builder.emitBlock();
l->insertAfter(e2);
loop->insertAtEnd(l);

// We now have
// s: ...1 no-terminator
// c2: if x then goto e2 else goto d (merge at b)
Expand Down Expand Up @@ -306,22 +342,35 @@ static void invertLoop(IRBuilder& builder, IRLoop* loop)
// Beyond just retargeting the loop instruction, we need to make sure any
// parameters the loop instruction is passing to c1 are instead passed to
// 'd', and because we've added parameters to 'd' we need to forward them
// from c1 also which we will accomplish using a new block, e3,
// from c1 also during the back-edge, which we will accomplish using a new
// block, e3,
//
loop->block.set(d);
loop->breakBlock.set(b2);
// Utilize the cloneenv to make sure that when entering the loop we use
// c1's instructions as cloned into c2
builder.setInsertAfter(loop);
List<IRInst*> loopEntryArgs;
for(const auto p : c1LoopParams)
loopEntryArgs.add(cloneInst(&cloneEnv, &builder, p));
for(UInt i = 0; i < loop->getArgCount(); ++i)
loopEntryArgs.add(loop->getArg(i));
const auto newLoop = builder.emitLoop(d, b2, loop->getContinueBlock(), loopEntryArgs.getCount(), loopEntryArgs.getBuffer());
newLoop->sourceLoc = loop->sourceLoc;
loop->transferDecorationsTo(newLoop);
loop->removeAndDeallocate();

// TODO: This really upsets a few later passes, why isn't it ok to do given
// our "irrelevant continue" condition?
// loop->continueBlock.set(loop->getTargetBlock());
SLANG_ASSERT(d->getFirstParam() == nullptr);
c1->insertBefore(b2);
e1->insertAfter(c1);
List<IRInst*> ps;
for(const auto p : c1->getParams())
ps.add(p);
builder.setInsertInto(d);
for(const auto p : ps)
// Add the necessary parameters for the loop state to d, first the
// parameters for instructions in the duplicated conditional block, then the
// ones from the loop body.
List<IRInst*> ps = c1LoopParams;
for(const auto p : c1->getParams())
{
ps.add(p);
const auto q = duplicateToParamWithDecorations(builder, cloneEnv, p);
// Replace all uses, except for those in c1 and e1
List<IRUse*> uses;
Expand Down
32 changes: 32 additions & 0 deletions tests/bugs/inversion-tricky-phi.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
//TEST:SIMPLE(filecheck=CHECK): -stage compute -entry computeMain -target hlsl

// Annoyingly, the reproducer for this doesn't terminate, so just check that
// compilation succeeds.
// Previously, this would fail in SCCP after loop inversion failed to account
// for condition variables being used in the loop.
// CHECK: computeMain

//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;

// Compute entry point for this test: passes a local float through f and then
// writes it to outputBuffer at index i.
[numthreads(1, 1, 1)]
void computeMain(uint i : SV_GroupIndex)
{
float x = 0;
// x is passed as an inout argument, so f may overwrite it.
f(x);
outputBuffer[i] = x;
}
}

// Reproducer body: an inner do-while nested in an outer do-while, both with a
// constant-true condition. `a` is updated inside the inner loop and read in
// the outer loop after the inner loop exits.
void f(inout float r)
{
r = 0;
float a = 0;
// Outer loop: contains no break of its own, so it never exits.
do {
// Inner loop: only leaves through the break below.
do {
a = a + 1;
if (a > 0)
break;
} while (true);
// Uses the value computed inside the inner loop.
r = a;
} while (true);
}

0 comments on commit da9e0ad

Please sign in to comment.