Commit 3fed4306 by Olli Etuaho

Unfold short-circuiting operators in loop conditions correctly

Sometimes short-circuiting operators need to be unfolded to if statements. If the unfolded operator is inside a loop condition or expression, it needs to be evaluated repeatedly inside the loop. Add logic to UnfoldShortCircuitToIf that can move or copy the unfolded part of loop conditions or expressions to inside the loop. The exact changes that need to be done depend on the type of the loop. For loops may require also moving the initializer to outside the loop. The unfolded expression inside a loop condition or expression is moved or copied to inside the loop on the first traversal of the loop node, and unfolding to if is deferred until a second traversal. This keeps the code relatively simple. BUG=angleproject:1167 TEST=WebGL 2 conformance tests, dEQP-GLES2.functional.shaders.*select_iteration_count* Change-Id: Ieffc0ea858186054378d387dca9aa64a5fa95137 Reviewed-on: https://chromium-review.googlesource.com/310230Reviewed-by: 's avatarCorentin Wallez <cwallez@chromium.org> Reviewed-by: 's avatarZhenyao Mo <zmo@chromium.org> Tested-by: 's avatarOlli Etuaho <oetuaho@nvidia.com>
parent 15c2ac30
...@@ -261,7 +261,7 @@ bool TIntermLoop::replaceChildNode( ...@@ -261,7 +261,7 @@ bool TIntermLoop::replaceChildNode(
REPLACE_IF_IS(mInit, TIntermNode, original, replacement); REPLACE_IF_IS(mInit, TIntermNode, original, replacement);
REPLACE_IF_IS(mCond, TIntermTyped, original, replacement); REPLACE_IF_IS(mCond, TIntermTyped, original, replacement);
REPLACE_IF_IS(mExpr, TIntermTyped, original, replacement); REPLACE_IF_IS(mExpr, TIntermTyped, original, replacement);
REPLACE_IF_IS(mBody, TIntermNode, original, replacement); REPLACE_IF_IS(mBody, TIntermAggregate, original, replacement);
return false; return false;
} }
......
...@@ -173,14 +173,13 @@ class TIntermLoop : public TIntermNode ...@@ -173,14 +173,13 @@ class TIntermLoop : public TIntermNode
{ {
public: public:
TIntermLoop(TLoopType type, TIntermLoop(TLoopType type,
TIntermNode *init, TIntermTyped *cond, TIntermTyped *expr, TIntermNode *init,
TIntermNode *body) TIntermTyped *cond,
: mType(type), TIntermTyped *expr,
mInit(init), TIntermAggregate *body)
mCond(cond), : mType(type), mInit(init), mCond(cond), mExpr(expr), mBody(body), mUnrollFlag(false)
mExpr(expr), {
mBody(body), }
mUnrollFlag(false) { }
TIntermLoop *getAsLoopNode() override { return this; } TIntermLoop *getAsLoopNode() override { return this; }
void traverse(TIntermTraverser *it) override; void traverse(TIntermTraverser *it) override;
...@@ -190,7 +189,7 @@ class TIntermLoop : public TIntermNode ...@@ -190,7 +189,7 @@ class TIntermLoop : public TIntermNode
TIntermNode *getInit() { return mInit; } TIntermNode *getInit() { return mInit; }
TIntermTyped *getCondition() { return mCond; } TIntermTyped *getCondition() { return mCond; }
TIntermTyped *getExpression() { return mExpr; } TIntermTyped *getExpression() { return mExpr; }
TIntermNode *getBody() { return mBody; } TIntermAggregate *getBody() { return mBody; }
void setUnrollFlag(bool flag) { mUnrollFlag = flag; } void setUnrollFlag(bool flag) { mUnrollFlag = flag; }
bool getUnrollFlag() const { return mUnrollFlag; } bool getUnrollFlag() const { return mUnrollFlag; }
...@@ -200,7 +199,7 @@ class TIntermLoop : public TIntermNode ...@@ -200,7 +199,7 @@ class TIntermLoop : public TIntermNode
TIntermNode *mInit; // for-loop initialization TIntermNode *mInit; // for-loop initialization
TIntermTyped *mCond; // loop exit condition TIntermTyped *mCond; // loop exit condition
TIntermTyped *mExpr; // for-loop expression TIntermTyped *mExpr; // for-loop expression
TIntermNode *mBody; // loop body TIntermAggregate *mBody; // loop body
bool mUnrollFlag; // Whether the loop should be unrolled or not. bool mUnrollFlag; // Whether the loop should be unrolled or not.
}; };
......
...@@ -24,19 +24,39 @@ class UnfoldShortCircuitTraverser : public TIntermTraverser ...@@ -24,19 +24,39 @@ class UnfoldShortCircuitTraverser : public TIntermTraverser
bool visitBinary(Visit visit, TIntermBinary *node) override; bool visitBinary(Visit visit, TIntermBinary *node) override;
bool visitAggregate(Visit visit, TIntermAggregate *node) override; bool visitAggregate(Visit visit, TIntermAggregate *node) override;
bool visitSelection(Visit visit, TIntermSelection *node) override; bool visitSelection(Visit visit, TIntermSelection *node) override;
bool visitLoop(Visit visit, TIntermLoop *node) override;
void nextIteration(); void nextIteration();
bool foundShortCircuit() const { return mFoundShortCircuit; } bool foundShortCircuit() const { return mFoundShortCircuit; }
protected: protected:
// Check if the traversal is inside a loop condition or expression, in which case the unfolded
// expression needs to be copied inside the loop. Returns true if the copying is done, in which
// case no further unfolding should be done on the same traversal.
// The parameters are the node that will be unfolded to multiple statements and so can't remain
// inside a loop condition, and its parent.
bool copyLoopConditionOrExpression(TIntermNode *parent, TIntermTyped *node);
// Marked to true once an operation that needs to be unfolded has been found. // Marked to true once an operation that needs to be unfolded has been found.
// After that, no more unfolding is performed on that traversal. // After that, no more unfolding is performed on that traversal.
bool mFoundShortCircuit; bool mFoundShortCircuit;
// Set to the loop node while a loop condition or expression is being traversed.
TIntermLoop *mParentLoop;
// Parent of the loop node while a loop condition or expression is being traversed.
TIntermNode *mLoopParent;
bool mInLoopCondition;
bool mInLoopExpression;
}; };
UnfoldShortCircuitTraverser::UnfoldShortCircuitTraverser() UnfoldShortCircuitTraverser::UnfoldShortCircuitTraverser()
: TIntermTraverser(true, false, true), : TIntermTraverser(true, false, true),
mFoundShortCircuit(false) mFoundShortCircuit(false),
mParentLoop(nullptr),
mLoopParent(nullptr),
mInLoopCondition(false),
mInLoopExpression(false)
{ {
} }
...@@ -56,10 +76,12 @@ bool UnfoldShortCircuitTraverser::visitBinary(Visit visit, TIntermBinary *node) ...@@ -56,10 +76,12 @@ bool UnfoldShortCircuitTraverser::visitBinary(Visit visit, TIntermBinary *node)
{ {
case EOpLogicalOr: case EOpLogicalOr:
mFoundShortCircuit = true; mFoundShortCircuit = true;
if (!copyLoopConditionOrExpression(getParentNode(), node))
// "x || y" is equivalent to "x ? true : y", which unfolds to "bool s; if(x) s = true; else s = y;",
// and then further simplifies down to "bool s = x; if(!s) s = y;".
{ {
// "x || y" is equivalent to "x ? true : y", which unfolds to "bool s; if(x) s = true;
// else s = y;",
// and then further simplifies down to "bool s = x; if(!s) s = y;".
TIntermSequence insertions; TIntermSequence insertions;
TType boolType(EbtBool, EbpUndefined, EvqTemporary); TType boolType(EbtBool, EbpUndefined, EvqTemporary);
...@@ -83,10 +105,11 @@ bool UnfoldShortCircuitTraverser::visitBinary(Visit visit, TIntermBinary *node) ...@@ -83,10 +105,11 @@ bool UnfoldShortCircuitTraverser::visitBinary(Visit visit, TIntermBinary *node)
return false; return false;
case EOpLogicalAnd: case EOpLogicalAnd:
mFoundShortCircuit = true; mFoundShortCircuit = true;
if (!copyLoopConditionOrExpression(getParentNode(), node))
// "x && y" is equivalent to "x ? y : false", which unfolds to "bool s; if(x) s = y; else s = false;",
// and then further simplifies down to "bool s = x; if(s) s = y;".
{ {
// "x && y" is equivalent to "x ? y : false", which unfolds to "bool s; if(x) s = y;
// else s = false;",
// and then further simplifies down to "bool s = x; if(s) s = y;".
TIntermSequence insertions; TIntermSequence insertions;
TType boolType(EbtBool, EbpUndefined, EvqTemporary); TType boolType(EbtBool, EbpUndefined, EvqTemporary);
...@@ -120,29 +143,35 @@ bool UnfoldShortCircuitTraverser::visitSelection(Visit visit, TIntermSelection * ...@@ -120,29 +143,35 @@ bool UnfoldShortCircuitTraverser::visitSelection(Visit visit, TIntermSelection *
if (visit == PreVisit && node->usesTernaryOperator()) if (visit == PreVisit && node->usesTernaryOperator())
{ {
mFoundShortCircuit = true; mFoundShortCircuit = true;
TIntermSequence insertions; if (!copyLoopConditionOrExpression(getParentNode(), node))
{
TIntermSequence insertions;
TIntermSymbol *tempSymbol = createTempSymbol(node->getType()); TIntermSymbol *tempSymbol = createTempSymbol(node->getType());
TIntermAggregate *tempDeclaration = new TIntermAggregate(EOpDeclaration); TIntermAggregate *tempDeclaration = new TIntermAggregate(EOpDeclaration);
tempDeclaration->getSequence()->push_back(tempSymbol); tempDeclaration->getSequence()->push_back(tempSymbol);
insertions.push_back(tempDeclaration); insertions.push_back(tempDeclaration);
TIntermAggregate *trueBlock = new TIntermAggregate(EOpSequence); TIntermAggregate *trueBlock = new TIntermAggregate(EOpSequence);
TIntermBinary *trueAssignment = createTempAssignment(node->getTrueBlock()->getAsTyped()); TIntermBinary *trueAssignment =
trueBlock->getSequence()->push_back(trueAssignment); createTempAssignment(node->getTrueBlock()->getAsTyped());
trueBlock->getSequence()->push_back(trueAssignment);
TIntermAggregate *falseBlock = new TIntermAggregate(EOpSequence); TIntermAggregate *falseBlock = new TIntermAggregate(EOpSequence);
TIntermBinary *falseAssignment = createTempAssignment(node->getFalseBlock()->getAsTyped()); TIntermBinary *falseAssignment =
falseBlock->getSequence()->push_back(falseAssignment); createTempAssignment(node->getFalseBlock()->getAsTyped());
falseBlock->getSequence()->push_back(falseAssignment);
TIntermSelection *ifNode = new TIntermSelection(node->getCondition()->getAsTyped(), trueBlock, falseBlock); TIntermSelection *ifNode =
insertions.push_back(ifNode); new TIntermSelection(node->getCondition()->getAsTyped(), trueBlock, falseBlock);
insertions.push_back(ifNode);
insertStatementsInParentBlock(insertions); insertStatementsInParentBlock(insertions);
TIntermSymbol *ternaryResult = createTempSymbol(node->getType()); TIntermSymbol *ternaryResult = createTempSymbol(node->getType());
NodeUpdateEntry replaceVariable(getParentNode(), node, ternaryResult, false); NodeUpdateEntry replaceVariable(getParentNode(), node, ternaryResult, false);
mReplacements.push_back(replaceVariable); mReplacements.push_back(replaceVariable);
}
return false; return false;
} }
...@@ -170,25 +199,148 @@ bool UnfoldShortCircuitTraverser::visitAggregate(Visit visit, TIntermAggregate * ...@@ -170,25 +199,148 @@ bool UnfoldShortCircuitTraverser::visitAggregate(Visit visit, TIntermAggregate *
mMultiReplacements.clear(); mMultiReplacements.clear();
mInsertions.clear(); mInsertions.clear();
TIntermSequence insertions; if (!copyLoopConditionOrExpression(getParentNode(), node))
TIntermSequence *seq = node->getSequence(); {
TIntermSequence insertions;
TIntermSequence *seq = node->getSequence();
TIntermSequence::size_type i = 0;
ASSERT(!seq->empty());
while (i < seq->size() - 1)
{
TIntermTyped *child = (*seq)[i]->getAsTyped();
insertions.push_back(child);
++i;
}
insertStatementsInParentBlock(insertions);
NodeUpdateEntry replaceVariable(getParentNode(), node, (*seq)[i], false);
mReplacements.push_back(replaceVariable);
}
}
}
return true;
}
bool UnfoldShortCircuitTraverser::visitLoop(Visit visit, TIntermLoop *node)
{
if (visit == PreVisit)
{
if (mFoundShortCircuit)
return false; // No need to traverse further
TIntermSequence::size_type i = 0; mLoopParent = getParentNode();
ASSERT(!seq->empty()); mParentLoop = node;
while (i < seq->size() - 1) incrementDepth(node);
if (node->getInit())
{
node->getInit()->traverse(this);
if (mFoundShortCircuit)
{ {
TIntermTyped *child = (*seq)[i]->getAsTyped(); decrementDepth();
insertions.push_back(child); return false;
++i;
} }
}
insertStatementsInParentBlock(insertions); if (node->getCondition())
{
mInLoopCondition = true;
node->getCondition()->traverse(this);
mInLoopCondition = false;
NodeUpdateEntry replaceVariable(getParentNode(), node, (*seq)[i], false); if (mFoundShortCircuit)
mReplacements.push_back(replaceVariable); {
decrementDepth();
return false;
}
}
if (node->getExpression())
{
mInLoopExpression = true;
node->getExpression()->traverse(this);
mInLoopExpression = false;
if (mFoundShortCircuit)
{
decrementDepth();
return false;
}
} }
if (node->getBody())
node->getBody()->traverse(this);
decrementDepth();
} }
return true; return false;
}
bool UnfoldShortCircuitTraverser::copyLoopConditionOrExpression(TIntermNode *parent,
TIntermTyped *node)
{
if (mInLoopCondition)
{
mReplacements.push_back(
NodeUpdateEntry(parent, node, createTempSymbol(node->getType()), false));
TIntermAggregate *body = mParentLoop->getBody();
TIntermSequence empty;
if (mParentLoop->getType() == ELoopDoWhile)
{
// Declare the temporary variable before the loop.
TIntermSequence insertionsBeforeLoop;
insertionsBeforeLoop.push_back(createTempDeclaration(node->getType()));
insertStatementsInParentBlock(insertionsBeforeLoop);
// Move a part of do-while loop condition to inside the loop.
TIntermSequence insertionsInLoop;
insertionsInLoop.push_back(createTempAssignment(node));
mInsertions.push_back(NodeInsertMultipleEntry(body, body->getSequence()->size() - 1,
empty, insertionsInLoop));
}
else
{
// The loop initializer expression and one copy of the part of the loop condition are
// executed before the loop. They need to be in a new scope.
TIntermAggregate *loopScope = new TIntermAggregate(EOpSequence);
TIntermNode *initializer = mParentLoop->getInit();
if (initializer != nullptr)
{
// Move the initializer to the newly created outer scope, so that condition can
// depend on it.
mReplacements.push_back(NodeUpdateEntry(mParentLoop, initializer, nullptr, false));
loopScope->getSequence()->push_back(initializer);
}
loopScope->getSequence()->push_back(createTempInitDeclaration(node));
loopScope->getSequence()->push_back(mParentLoop);
mReplacements.push_back(NodeUpdateEntry(mLoopParent, mParentLoop, loopScope, true));
// The second copy of the part of the loop condition is executed inside the loop.
TIntermSequence insertionsInLoop;
insertionsInLoop.push_back(createTempAssignment(node->deepCopy()));
mInsertions.push_back(NodeInsertMultipleEntry(body, body->getSequence()->size() - 1,
empty, insertionsInLoop));
}
return true;
}
if (mInLoopExpression)
{
TIntermTyped *movedExpression = mParentLoop->getExpression();
mReplacements.push_back(NodeUpdateEntry(mParentLoop, movedExpression, nullptr, false));
TIntermAggregate *body = mParentLoop->getBody();
TIntermSequence empty;
TIntermSequence insertions;
insertions.push_back(movedExpression);
mInsertions.push_back(
NodeInsertMultipleEntry(body, body->getSequence()->size() - 1, empty, insertions));
return true;
}
return false;
} }
void UnfoldShortCircuitTraverser::nextIteration() void UnfoldShortCircuitTraverser::nextIteration()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment