Commit 1d9dcc24 by Olli Etuaho Committed by Commit Bot

Make AST path always include the current node being traversed

AST traversers tend to sometimes call traverse() functions manually during PreVisit. Change TIntermTraverser so that even if this happens, all the nodes are automatically added to the traversal path, instead of having to add them manually in each individual AST traverser. This also makes calling getParentNode() return the correct node during InVisit. This does cause the same node being added to the traversal path twice in some cases, where nodes are repeatedly traversed, like in OutputHLSL, but this should not have adverse side effects. The more common case is that the traverse() function is called on the children of the node being currently traversed. This fixes a bug in OVR_multiview validation, which did not previously call incrementDepth and decrementDepth when it should have. BUG=angleproject:1725 TEST=angle_unittests, angle_end2end_tests Change-Id: I6ae762eef760509ebe853eefa37dac28c16e7a9b Reviewed-on: https://chromium-review.googlesource.com/430732 Commit-Queue: Olli Etuaho <oetuaho@nvidia.com> Reviewed-by: 's avatarJamie Madill <jmadill@chromium.org>
parent 0894288d
...@@ -962,6 +962,7 @@ class TIntermTraverser : angle::NonCopyable ...@@ -962,6 +962,7 @@ class TIntermTraverser : angle::NonCopyable
void useTemporaryIndex(unsigned int *temporaryIndex); void useTemporaryIndex(unsigned int *temporaryIndex);
protected: protected:
// Should only be called from traverse*() functions
void incrementDepth(TIntermNode *current) void incrementDepth(TIntermNode *current)
{ {
mDepth++; mDepth++;
...@@ -969,20 +970,35 @@ class TIntermTraverser : angle::NonCopyable ...@@ -969,20 +970,35 @@ class TIntermTraverser : angle::NonCopyable
mPath.push_back(current); mPath.push_back(current);
} }
// Should only be called from traverse*() functions
void decrementDepth() void decrementDepth()
{ {
mDepth--; mDepth--;
mPath.pop_back(); mPath.pop_back();
} }
TIntermNode *getParentNode() { return mPath.size() == 0 ? NULL : mPath.back(); } // RAII helper for incrementDepth/decrementDepth
class ScopedNodeInTraversalPath
{
public:
ScopedNodeInTraversalPath(TIntermTraverser *traverser, TIntermNode *current)
: mTraverser(traverser)
{
mTraverser->incrementDepth(current);
}
~ScopedNodeInTraversalPath() { mTraverser->decrementDepth(); }
private:
TIntermTraverser *mTraverser;
};
TIntermNode *getParentNode() { return mPath.size() <= 1 ? nullptr : mPath[mPath.size() - 2u]; }
// Return the nth ancestor of the node being traversed. getAncestorNode(0) == getParentNode() // Return the nth ancestor of the node being traversed. getAncestorNode(0) == getParentNode()
TIntermNode *getAncestorNode(unsigned int n) TIntermNode *getAncestorNode(unsigned int n)
{ {
if (mPath.size() > n) if (mPath.size() > n + 1u)
{ {
return mPath[mPath.size() - n - 1u]; return mPath[mPath.size() - n - 2u];
} }
return nullptr; return nullptr;
} }
...@@ -991,11 +1007,6 @@ class TIntermTraverser : angle::NonCopyable ...@@ -991,11 +1007,6 @@ class TIntermTraverser : angle::NonCopyable
void incrementParentBlockPos(); void incrementParentBlockPos();
void popParentBlock(); void popParentBlock();
bool parentNodeIsBlock()
{
return !mParentBlockStack.empty() && getParentNode() == mParentBlockStack.back().node;
}
// To replace a single node with multiple nodes on the parent aggregate node // To replace a single node with multiple nodes on the parent aggregate node
struct NodeReplaceWithMultipleEntry struct NodeReplaceWithMultipleEntry
{ {
...@@ -1086,9 +1097,6 @@ class TIntermTraverser : angle::NonCopyable ...@@ -1086,9 +1097,6 @@ class TIntermTraverser : angle::NonCopyable
int mDepth; int mDepth;
int mMaxDepth; int mMaxDepth;
// All the nodes from root to the current node's parent during traversing.
TVector<TIntermNode *> mPath;
bool mInGlobalScope; bool mInGlobalScope;
// During traversing, save all the changes that need to happen into // During traversing, save all the changes that need to happen into
...@@ -1131,6 +1139,9 @@ class TIntermTraverser : angle::NonCopyable ...@@ -1131,6 +1139,9 @@ class TIntermTraverser : angle::NonCopyable
std::vector<NodeUpdateEntry> mReplacements; std::vector<NodeUpdateEntry> mReplacements;
// All the nodes from root to the current node during traversing.
TVector<TIntermNode *> mPath;
// All the code blocks from the root to the current node's parent during traversal. // All the code blocks from the root to the current node's parent during traversal.
std::vector<ParentBlock> mParentBlockStack; std::vector<ParentBlock> mParentBlockStack;
......
...@@ -761,7 +761,6 @@ bool TOutputGLSLBase::visitIfElse(Visit visit, TIntermIfElse *node) ...@@ -761,7 +761,6 @@ bool TOutputGLSLBase::visitIfElse(Visit visit, TIntermIfElse *node)
node->getCondition()->traverse(this); node->getCondition()->traverse(this);
out << ")\n"; out << ")\n";
incrementDepth(node);
visitCodeBlock(node->getTrueBlock()); visitCodeBlock(node->getTrueBlock());
if (node->getFalseBlock()) if (node->getFalseBlock())
...@@ -769,7 +768,6 @@ bool TOutputGLSLBase::visitIfElse(Visit visit, TIntermIfElse *node) ...@@ -769,7 +768,6 @@ bool TOutputGLSLBase::visitIfElse(Visit visit, TIntermIfElse *node)
out << "else\n"; out << "else\n";
visitCodeBlock(node->getFalseBlock()); visitCodeBlock(node->getFalseBlock());
} }
decrementDepth();
return false; return false;
} }
...@@ -812,7 +810,6 @@ bool TOutputGLSLBase::visitBlock(Visit visit, TIntermBlock *node) ...@@ -812,7 +810,6 @@ bool TOutputGLSLBase::visitBlock(Visit visit, TIntermBlock *node)
out << "{\n"; out << "{\n";
} }
incrementDepth(node);
for (TIntermSequence::const_iterator iter = node->getSequence()->begin(); for (TIntermSequence::const_iterator iter = node->getSequence()->begin();
iter != node->getSequence()->end(); ++iter) iter != node->getSequence()->end(); ++iter)
{ {
...@@ -823,7 +820,6 @@ bool TOutputGLSLBase::visitBlock(Visit visit, TIntermBlock *node) ...@@ -823,7 +820,6 @@ bool TOutputGLSLBase::visitBlock(Visit visit, TIntermBlock *node)
if (isSingleStatement(curNode)) if (isSingleStatement(curNode))
out << ";\n"; out << ";\n";
} }
decrementDepth();
// Scope the blocks except when at the global scope. // Scope the blocks except when at the global scope.
if (mDepth > 0) if (mDepth > 0)
...@@ -835,11 +831,9 @@ bool TOutputGLSLBase::visitBlock(Visit visit, TIntermBlock *node) ...@@ -835,11 +831,9 @@ bool TOutputGLSLBase::visitBlock(Visit visit, TIntermBlock *node)
bool TOutputGLSLBase::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) bool TOutputGLSLBase::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
{ {
incrementDepth(node);
TIntermFunctionPrototype *prototype = node->getFunctionPrototype(); TIntermFunctionPrototype *prototype = node->getFunctionPrototype();
prototype->traverse(this); prototype->traverse(this);
visitCodeBlock(node->getBody()); visitCodeBlock(node->getBody());
decrementDepth();
// Fully processed; no need to visit children. // Fully processed; no need to visit children.
return false; return false;
...@@ -986,8 +980,6 @@ bool TOutputGLSLBase::visitLoop(Visit visit, TIntermLoop *node) ...@@ -986,8 +980,6 @@ bool TOutputGLSLBase::visitLoop(Visit visit, TIntermLoop *node)
{ {
TInfoSinkBase &out = objSink(); TInfoSinkBase &out = objSink();
incrementDepth(node);
TLoopType loopType = node->getType(); TLoopType loopType = node->getType();
if (loopType == ELoopFor) // for loop if (loopType == ELoopFor) // for loop
...@@ -1029,8 +1021,6 @@ bool TOutputGLSLBase::visitLoop(Visit visit, TIntermLoop *node) ...@@ -1029,8 +1021,6 @@ bool TOutputGLSLBase::visitLoop(Visit visit, TIntermLoop *node)
out << ");\n"; out << ");\n";
} }
decrementDepth();
// No need to visit children. They have been already processed in // No need to visit children. They have been already processed in
// this function. // this function.
return false; return false;
......
...@@ -919,11 +919,9 @@ void OutputHLSL::outputEqual(Visit visit, const TType &type, TOperator op, TInfo ...@@ -919,11 +919,9 @@ void OutputHLSL::outputEqual(Visit visit, const TType &type, TOperator op, TInfo
} }
} }
bool OutputHLSL::ancestorEvaluatesToSamplerInStruct(Visit visit) bool OutputHLSL::ancestorEvaluatesToSamplerInStruct()
{ {
// Inside InVisit the current node is already in the path. for (unsigned int n = 0u; getAncestorNode(n) != nullptr; ++n)
const unsigned int initialN = visit == InVisit ? 1u : 0u;
for (unsigned int n = initialN; getAncestorNode(n) != nullptr; ++n)
{ {
TIntermNode *ancestor = getAncestorNode(n); TIntermNode *ancestor = getAncestorNode(n);
const TIntermBinary *ancestorBinary = ancestor->getAsBinaryNode(); const TIntermBinary *ancestorBinary = ancestor->getAsBinaryNode();
...@@ -1126,7 +1124,7 @@ bool OutputHLSL::visitBinary(Visit visit, TIntermBinary *node) ...@@ -1126,7 +1124,7 @@ bool OutputHLSL::visitBinary(Visit visit, TIntermBinary *node)
return false; return false;
} }
} }
else if (ancestorEvaluatesToSamplerInStruct(visit)) else if (ancestorEvaluatesToSamplerInStruct())
{ {
// All parts of an expression that access a sampler in a struct need to use _ as // All parts of an expression that access a sampler in a struct need to use _ as
// separator to access the sampler variable that has been moved out of the struct. // separator to access the sampler variable that has been moved out of the struct.
...@@ -1163,7 +1161,7 @@ bool OutputHLSL::visitBinary(Visit visit, TIntermBinary *node) ...@@ -1163,7 +1161,7 @@ bool OutputHLSL::visitBinary(Visit visit, TIntermBinary *node)
{ {
// All parts of an expression that access a sampler in a struct need to use _ as // All parts of an expression that access a sampler in a struct need to use _ as
// separator to access the sampler variable that has been moved out of the struct. // separator to access the sampler variable that has been moved out of the struct.
indexingReturnsSampler = ancestorEvaluatesToSamplerInStruct(visit); indexingReturnsSampler = ancestorEvaluatesToSamplerInStruct();
} }
if (visit == InVisit) if (visit == InVisit)
{ {
......
...@@ -239,7 +239,7 @@ class OutputHLSL : public TIntermTraverser ...@@ -239,7 +239,7 @@ class OutputHLSL : public TIntermTraverser
private: private:
TString samplerNamePrefixFromStruct(TIntermTyped *node); TString samplerNamePrefixFromStruct(TIntermTyped *node);
bool ancestorEvaluatesToSamplerInStruct(Visit visit); bool ancestorEvaluatesToSamplerInStruct();
}; };
} }
......
...@@ -135,9 +135,7 @@ void SimplifyLoopConditionsTraverser::traverseLoop(TIntermLoop *node) ...@@ -135,9 +135,7 @@ void SimplifyLoopConditionsTraverser::traverseLoop(TIntermLoop *node)
// Mark that we're inside a loop condition or expression, and transform the loop if needed. // Mark that we're inside a loop condition or expression, and transform the loop if needed.
incrementDepth(node); ScopedNodeInTraversalPath addToPath(this, node);
// Note: No need to traverse the loop init node.
mInsideLoopInitConditionOrExpression = true; mInsideLoopInitConditionOrExpression = true;
TLoopType loopType = node->getType(); TLoopType loopType = node->getType();
...@@ -274,8 +272,7 @@ void SimplifyLoopConditionsTraverser::traverseLoop(TIntermLoop *node) ...@@ -274,8 +272,7 @@ void SimplifyLoopConditionsTraverser::traverseLoop(TIntermLoop *node)
ELoopWhile, nullptr, createTempSymbol(conditionInitializer->getType()), nullptr, ELoopWhile, nullptr, createTempSymbol(conditionInitializer->getType()), nullptr,
whileLoopBody); whileLoopBody);
loopScope->getSequence()->push_back(whileLoop); loopScope->getSequence()->push_back(whileLoop);
queueReplacementWithParent(getAncestorNode(1), node, loopScope, queueReplacement(node, loopScope, OriginalNode::IS_DROPPED);
OriginalNode::IS_DROPPED);
} }
} }
...@@ -283,8 +280,6 @@ void SimplifyLoopConditionsTraverser::traverseLoop(TIntermLoop *node) ...@@ -283,8 +280,6 @@ void SimplifyLoopConditionsTraverser::traverseLoop(TIntermLoop *node)
if (!mFoundLoopToChange && node->getBody()) if (!mFoundLoopToChange && node->getBody())
node->getBody()->traverse(this); node->getBody()->traverse(this);
decrementDepth();
} }
} // namespace } // namespace
......
...@@ -381,3 +381,19 @@ TEST_F(WEBGLMultiviewVertexShaderTest, ValidUseOfExtensionMacros) ...@@ -381,3 +381,19 @@ TEST_F(WEBGLMultiviewVertexShaderTest, ValidUseOfExtensionMacros)
FAIL() << "Shader compilation failed, expecting success:\n" << mInfoLog; FAIL() << "Shader compilation failed, expecting success:\n" << mInfoLog;
} }
} }
// Test that the parent node is tracked correctly when validating assignment to gl_Position.
TEST_F(WEBGLMultiviewVertexShaderTest, AssignmentWithViewIDInsideAssignment)
{
const std::string &shaderString =
"#version 300 es\n"
"#extension GL_OVR_multiview : require\n"
"void main()\n"
"{\n"
" gl_Position.y = (gl_Position.x = (gl_ViewID_OVR == 0u) ? 1.0 : 0.0);\n"
"}\n";
if (compile(shaderString))
{
FAIL() << "Shader compilation succeeded, expecting failure:\n" << mInfoLog;
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment